diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index de4fded6ae6e66995aa9f1687a9d598017416f7a..3dad41a88c8212b7445c32f241d887306d3c19ad 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -41,7 +41,7 @@ TensorFlow coding style. #### General guidelines and philosophy for contribution * Include unit tests when you contribute new features, as they help to - a) prove that your code works correctly, b) guard against future breaking + a) prove that your code works correctly, and b) guard against future breaking changes to lower the maintenance cost. * Bug fixes also generally require unit tests, because the presence of bugs usually indicates insufficient test coverage. @@ -51,7 +51,7 @@ TensorFlow coding style. non-backward-compatible API changes without a major release. Reviewers of your pull request will comment on any API compatibility issues. * When you contribute a new feature to TensorFlow, the maintenance burden is (by - default) transferred to the TensorFlow team. This means that benefit of + default) transferred to the TensorFlow team. This means that benefit of the contribution must be compared against the cost of maintaining the feature. * Full new features (e.g., a new op implementing a cutting-edge algorithm) typically will live in @@ -68,8 +68,8 @@ Include a license at the top of new files. * [Java license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/Graph.java#L1) * [Go license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/operation.go#L1) * [Bash license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/ci_build/ci_sanity.sh#L2) -* [HTML license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/dist/index.html#L2) -* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/components/tf_backend/backend.ts#L1) +* [HTML license example](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/components/tf_backend/tf-backend.html#L2) +* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/components/tf_backend/backend.ts#L1) Bazel BUILD files also need to include a license section, e.g., [BUILD example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/BUILD#L61). @@ -163,7 +163,7 @@ There are two ways to run TensorFlow unit tests. bazel test ${flags} //tensorflow/python/... ``` -2. Using [Docker](www.docker.com) and TensorFlow's CI scripts. +2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts. ```bash # Install Docker first, then this will build and run cpu tests diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 1a401997c649518766acb2ebb0dea1c128bd0ba4..2f3df7cda9cec29ed0c2266629022f0a22b37df9 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -4,7 +4,7 @@ https://stackoverflow.com/questions/tagged/tensorflow If you open a GitHub issue, here is our policy: -1. It must be a bug or a feature request. +1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead). 2. The form below must be filled out. 3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues). diff --git a/README.md b/README.md index 0c93813e584d4e41fe80d50e047069b2dad8311a..ef5bdc66ef03131318e1dde627e0224cca9137fd 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,10 @@ ----------------- -| **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | -|-----------------|---------------------|------------------|-------------------|---------------| -| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | + +| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | +|-----------------|---------------------|------------------|-------------------|---------------|---------------| +| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while @@ -21,20 +22,6 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. -**If you want to contribute to TensorFlow, be sure to review the [contribution -guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's -[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to -uphold this code.** - -**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for -tracking requests and bugs. So please see -[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions -and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** - -The TensorFlow project strives to abide by generally accepted best practices in open-source software development: - -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) - ## Installation *See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* @@ -75,6 +62,22 @@ $ python >>> sess.close() ``` +## Contribution guidelines + +**If you want to contribute to TensorFlow, be sure to review the [contribution +guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's +[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to +uphold this code.** + +**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for +tracking requests and bugs. So please see +[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions +and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** + +The TensorFlow project strives to abide by generally accepted best practices in open-source software development: + +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) + ## For more information * [TensorFlow Website](https://www.tensorflow.org) diff --git a/RELEASE.md b/RELEASE.md index fdf10407fda21444f1d0ee6cf20650d2659b146f..6f54dee58f75c29a16545ba25de12fe059baf1eb 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,9 +1,98 @@ +# Release 1.6.0 + +## Breaking Changes +* Prebuilt binaries are now built against CUDA 9.0 and cuDNN 7. +* Prebuilt binaries will use AVX instructions. This may break TF on older CPUs. + +## Major Features And Improvements +* New Optimizer internal API for non-slot variables. Descendants of AdamOptimizer that access _beta[12]_power will need to be updated. +* `tf.estimator.{FinalExporter,LatestExporter}` now export stripped SavedModels. This improves forward compatibility of the SavedModel. +* FFT support added to XLA CPU/GPU. + +## Bug Fixes and Other Changes +* Documentation updates: + * Added a second version of Getting Started, which is aimed at ML +newcomers. + * Clarified documentation on `resize_images.align_corners` parameter. + * Additional documentation for TPUs. +* Google Cloud Storage (GCS): + * Add client-side throttle. + * Add a `FlushCaches()` method to the FileSystem interface, with an implementation for GcsFileSystem. +* Other: + * Add `tf.contrib.distributions.Kumaraswamy`. + * `RetryingFileSystem::FlushCaches()` calls the base FileSystem's `FlushCaches()`. + * Add `auto_correlation` to distributions. + * Add `tf.contrib.distributions.Autoregressive`. + * Add SeparableConv1D layer. + * Add convolutional Flipout layers. + * When both inputs of `tf.matmul` are bfloat16, it returns bfloat16, instead of float32. + * Added `tf.contrib.image.connected_components`. + * Add `tf.contrib.framework.CriticalSection` that allows atomic variable access. + * Output variance over trees predictions for classifications tasks. + * For `pt` and `eval` commands, allow writing tensor values to filesystem as numpy files. + * gRPC: Propagate truncated errors (instead of returning gRPC internal error). + * Augment `parallel_interleave` to support 2 kinds of prefetching. + * Improved XLA support for C64-related ops log, pow, atan2, tanh. + * Add probabilistic convolutional layers. + +## API Changes +* Introducing `prepare_variance` boolean with default setting to False for backward compatibility. +* Move `layers_dense_variational_impl.py` to `layers_dense_variational.py`. + +## Known Bugs +* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or + `CUDA_ILLEGAL_ADDRESS` failures. + + Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9 + and CUDA 9.1 sometimes does not properly compute the carry bit when + decomposing 64-bit address calculations with large offsets (e.g. `load [x + + large_constant]`) into 32-bit arithmetic in SASS. + + As a result, these versions of `ptxas` miscompile most XLA programs which use + more than 4GB of temp memory. This results in garbage results and/or + `CUDA_ERROR_ILLEGAL_ADDRESS` failures. + + A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a + fix for CUDA 9.0.x. Until the fix is available, the only workaround is to + [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x + or disable XLA:GPU. + + TensorFlow will print a warning if you use XLA:GPU with a known-bad version of + CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Ag Ramesh, Aiden Scandella, Akimasa Kimura, Alex Rothberg, Allen Goodman, +amilioto, Andrei Costinescu, Andrei Nigmatulin, Anjum Sayed, Anthony Platanios, +Anush Elangovan, Armando Fandango, Ashish Kumar Ram, Ashwini Shukla, Ben, Bhavani Subramanian, +Brett Koonce, Carl Thomé, cclauss, Cesc, Changming Sun, Christoph Boeddeker, Clayne Robison, +Clemens Schulz, Clint (Woonhyuk Baek), codrut3, Cole Gerdemann, Colin Raffel, Daniel Trebbien, +Daniel Ylitalo, Daniel Zhang, Daniyar, Darjan Salaj, Dave Maclachlan, David Norman, Dong--Jian, +dongsamb, dssgsra, Edward H, eladweiss, elilienstein, Eric Lilienstein, error.d, Eunji Jeong, fanlu, +Florian Courtial, fo40225, Fred, Gregg Helt, Guozhong Zhuang, Hanchen Li, hsm207, hyunyoung2, +ImSheridan, Ishant Mrinal Haloi, Jacky Ko, Jay Young, Jean Flaherty, Jerome, JerrikEph, Jesse +Kinkead, jfaath, Jian Lin, jinghuangintel, Jiongyan Zhang, Joel Hestness, Joel Shor, Johnny Chan, +Julian Niedermeier, Julian Wolff, JxKing, K-W-W, Karl Lessard, Kasper Marstal, Keiji Ariyama, +Koan-Sin Tan, Loki Der Quaeler, Loo Rong Jie, Luke Schaefer, Lynn Jackson, ManHyuk, Matt Basta, +Matt Smith, Matthew Schulkind, Michael, michaelkhan3, Miguel Piedrafita, Mikalai Drabovich, +Mike Knapp, mjwen, mktozk, Mohamed Aly, Mohammad Ashraf Bhuiyan, Myungjoo Ham, Naman Bhalla, +Namrata-Ibm, Nathan Luehr, nathansilberman, Netzeband, Niranjan Hasabnis, Omar Aflak, Ozge +Yalcinkaya, Parth P Panchal, patrickzzy, Patryk Chrabaszcz, Paul Van Eck, Paweł Kapica, Peng Yu, +Philip Yang, Pierre Blondeau, Po-Hsien Chu, powderluv, Puyu Wang, Rajendra Arora, Rasmus, Renat +Idrisov, resec, Robin Richtsfeld, Ronald Eddy Jr, Sahil Singh, Sam Matzek, Sami Kama, sandipmgiri, +Santiago Castro, Sayed Hadi Hashemi, Scott Tseng, Sergii Khomenko, Shahid, Shengpeng Liu, Shreyash +Sharma, Shrinidhi Kl, Simone Cirillo, simsicon, Stanislav Levental, starsblinking, Stephen Lumenta, +Steven Hickson, Su Tang, Taehoon Lee, Takuya Wakisaka, Ted Chang, Ted Ying, Tijmen Verhulsdonck, +Timofey Kondrashov, vade, vaibhav, Valentin Khrulkov, vchigrin, Victor Costan, Viraj Navkal, +Vivek Rane, wagonhelm, Yan Facai (颜发才), Yanbo Liang, Yaroslav Bulatov, yegord, Yong Tang, +Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田传武 + # Release 1.5.0 ## Breaking Changes -* Prebuilt binaries are now built against CUDA 9 and cuDNN 7. -* Our Linux binaries are built using ubuntu 16 containers, potentially - introducing glibc incompatibility issues with ubuntu 14. +* Prebuilt binaries are now built against CUDA 9.0 and cuDNN 7. * Starting from 1.6 release, our prebuilt binaries will use AVX instructions. This may break TF on older CPUs. @@ -12,7 +101,7 @@ preview version is now available. * [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite) dev preview is now available. -* CUDA 9 and cuDNN 7 support. +* CUDA 9.0 and cuDNN 7 support. * Accelerated Linear Algebra (XLA): * Add `complex64` support to XLA compiler. * `bfloat` support is now added to XLA infrastructure. @@ -125,6 +214,27 @@ * Minor refactor: move stats files from `stochastic` to `common` and remove `stochastic`. +## Known Bugs +* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or + `CUDA_ILLEGAL_ADDRESS` failures. + + Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9 + and CUDA 9.1 sometimes does not properly compute the carry bit when + decomposing 64-bit address calculations with large offsets (e.g. `load [x + + large_constant]`) into 32-bit arithmetic in SASS. + + As a result, these versions of `ptxas` miscompile most XLA programs which use + more than 4GB of temp memory. This results in garbage results and/or + `CUDA_ERROR_ILLEGAL_ADDRESS` failures. + + A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a + fix for CUDA 9.0.x. Until the fix is available, the only workaround is to + [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x + or disable XLA:GPU. + + TensorFlow will print a warning if you use XLA:GPU with a known-bad version of + CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122. + ## Thanks to our Contributors This release contains contributions from many people at Google, as well as: @@ -523,7 +633,7 @@ answered questions, and were part of inspiring discussions. * Fixed LIBXSMM integration. * Make decode_jpeg/decode_png/decode_gif handle all formats, since users frequently try to decode an image as the wrong type. * Improve implicit broadcasting lowering. -* Improving stability of GCS/Bigquery clients by a faster retrying of stale transmissions. +* Improving stability of GCS/BigQuery clients by a faster retrying of stale transmissions. * Remove OpKernelConstruction::op_def() as part of minimizing proto dependencies. * VectorLaplaceDiag distribution added. * Android demo no longer requires libtensorflow_demo.so to run (libtensorflow_inference.so still required) diff --git a/configure b/configure index 9c21d2b03a27714f05094667691e74c16fa89f35..66b66ba54ed68a9aa0ee556f84f68c3a83a495ab 100755 --- a/configure +++ b/configure @@ -8,7 +8,8 @@ if [ -z "$PYTHON_BIN_PATH" ]; then fi # Set all env variables -"$PYTHON_BIN_PATH" configure.py +CONFIGURE_DIR=$(dirname "$0") +"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@" echo "Configuration finished" diff --git a/configure.py b/configure.py index cf16ef483763733cc12c838ea92b144c6493f0b1..b5436dba20ad1aadeffe8057c0a85709914f603e 100644 --- a/configure.py +++ b/configure.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import errno import os import platform @@ -32,10 +33,6 @@ except ImportError: from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), - '.tf_configure.bazelrc') -_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'WORKSPACE') _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' @@ -43,6 +40,7 @@ _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) +_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu' _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' @@ -50,6 +48,11 @@ _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 +_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__)) +_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' +_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) +_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE') + class UserInputError(Exception): pass @@ -118,22 +121,6 @@ def sed_in_place(filename, old, new): f.write(newdata) -def remove_line_with(filename, token): - """Remove lines that contain token from file. - - Args: - filename: string for filename. - token: string token to check if to remove a line from file or not. - """ - with open(filename, 'r') as f: - filedata = f.read() - - with open(filename, 'w') as f: - for line in filedata.strip().split('\n'): - if token not in line: - f.write(line + '\n') - - def write_to_bazelrc(line): with open(_TF_BAZELRC, 'a') as f: f.write(line + '\n') @@ -244,25 +231,26 @@ def setup_python(environ_cp): environ_cp['PYTHON_BIN_PATH'] = python_bin_path # Write tools/python_bin_path.sh - with open('tools/python_bin_path.sh', 'w') as f: + with open(os.path.join( + _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) -def reset_tf_configure_bazelrc(): +def reset_tf_configure_bazelrc(workspace_path): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() + bazelrc_path = os.path.join(workspace_path, '.bazelrc') - home = os.path.expanduser('~') - if not os.path.exists('.bazelrc'): - if os.path.exists(os.path.join(home, '.bazelrc')): - with open('.bazelrc', 'a') as f: - f.write('import %s/.bazelrc\n' % home.replace('\\', '/')) - else: - open('.bazelrc', 'w').close() - - remove_line_with('.bazelrc', 'tf_configure') - with open('.bazelrc', 'a') as f: - f.write('import %workspace%/.tf_configure.bazelrc\n') + data = [] + if os.path.exists(bazelrc_path): + with open(bazelrc_path, 'r') as f: + data = f.read().splitlines() + with open(bazelrc_path, 'w') as f: + for l in data: + if _TF_BAZELRC_FILENAME in l: + continue + f.write('%s\n' % l) + f.write('import %s\n' % _TF_BAZELRC) def cleanup_makefile(): @@ -270,7 +258,8 @@ def cleanup_makefile(): These files could interfere with Bazel parsing. """ - makefile_download_dir = 'tensorflow/contrib/makefile/downloads' + makefile_download_dir = os.path.join( + _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads') if os.path.isdir(makefile_download_dir): for root, _, filenames in os.walk(makefile_download_dir): for f in filenames: @@ -297,7 +286,7 @@ def get_var(environ_cp, System". enabled_by_default: boolean for default behavior. question: optional string for how to ask for user input. - yes_reply: optionanl string for reply when feature is enabled. + yes_reply: optional string for reply when feature is enabled. no_reply: optional string for reply when feature is disabled. Returns: @@ -410,7 +399,7 @@ def set_action_env_var(environ_cp, System". enabled_by_default: boolean for default behavior. question: optional string for how to ask for user input. - yes_reply: optionanl string for reply when feature is enabled. + yes_reply: optional string for reply when feature is enabled. no_reply: optional string for reply when feature is disabled. """ var = int( @@ -444,7 +433,7 @@ def convert_version_to_int(version): def check_bazel_version(min_version): - """Check installed bezel version is at least min_version. + """Check installed bazel version is at least min_version. Args: min_version: string for minimum bazel version. @@ -501,7 +490,8 @@ def set_cc_opt_flags(environ_cp): for opt in cc_opt_flags.split(): write_to_bazelrc('build:opt --copt=%s' % opt) # It should be safe on the same build host. - write_to_bazelrc('build:opt --host_copt=-march=native') + if not is_ppc64le(): + write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') # TODO(mikecase): Remove these default defines once we are able to get # TF Lite targets building without them. @@ -826,6 +816,28 @@ def set_gcc_host_compiler_path(environ_cp): write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) +def reformat_version_sequence(version_str, sequence_count): + """Reformat the version string to have the given number of sequences. + + For example: + Given (7, 2) -> 7.0 + (7.0.1, 2) -> 7.0 + (5, 1) -> 5 + (5.0.3.2, 1) -> 5 + + Args: + version_str: String, the version string. + sequence_count: int, an integer. + Returns: + string, reformatted version string. + """ + v = version_str.split('.') + if len(v) < sequence_count: + v = v + (['0'] * (sequence_count - len(v))) + + return '.'.join(v[:sequence_count]) + + def set_tf_cuda_version(environ_cp): """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION.""" ask_cuda_version = ( @@ -836,6 +848,7 @@ def set_tf_cuda_version(environ_cp): # Configure the Cuda SDK version to use. tf_cuda_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION) + tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2) # Find out where the CUDA toolkit is installed default_cuda_path = _DEFAULT_CUDA_PATH @@ -892,6 +905,7 @@ def set_tf_cudnn_version(environ_cp): tf_cudnn_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version, _DEFAULT_CUDNN_VERSION) + tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1) default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH') ask_cudnn_path = (r'Please specify the location where cuDNN %s library is ' @@ -959,6 +973,128 @@ def set_tf_cudnn_version(environ_cp): write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) +def set_tf_tensorrt_install_path(environ_cp): + """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION. + + Adapted from code contributed by Sami Kama (https://github.com/samikama). + + Args: + environ_cp: copy of the os.environ. + + Raises: + ValueError: if this method was called under non-Linux platform. + UserInputError: if user has provided invalid input multiple times. + """ + if not is_linux(): + raise ValueError('Currently TensorRT is only supported on Linux platform.') + + # Ask user whether to add TensorRT support. + if str(int(get_var( + environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1': + return + + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): + ask_tensorrt_path = (r'Please specify the location where TensorRT is ' + 'installed. [Default is %s]:') % ( + _DEFAULT_TENSORRT_PATH_LINUX) + trt_install_path = get_from_env_or_user_or_default( + environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path, + _DEFAULT_TENSORRT_PATH_LINUX) + + # Result returned from "read" will be used unexpanded. That make "~" + # unusable. Going through one more level of expansion to handle that. + trt_install_path = os.path.realpath( + os.path.expanduser(trt_install_path)) + + def find_libs(search_path): + """Search for libnvinfer.so in "search_path".""" + fl = set() + if os.path.exists(search_path) and os.path.isdir(search_path): + fl.update([os.path.realpath(os.path.join(search_path, x)) + for x in os.listdir(search_path) if 'libnvinfer.so' in x]) + return fl + + possible_files = find_libs(trt_install_path) + possible_files.update(find_libs(os.path.join(trt_install_path, 'lib'))) + possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64'))) + + def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver): + """Check the compatibility between tensorrt and cudnn/cudart libraries.""" + ldd_bin = which('ldd') or '/usr/bin/ldd' + ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep) + cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') + cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') + cudnn = None + cudart = None + for line in ldd_out: + if 'libcudnn.so' in line: + cudnn = cudnn_pattern.search(line) + elif 'libcudart.so' in line: + cudart = cuda_pattern.search(line) + if cudnn and len(cudnn.group(1)): + cudnn = convert_version_to_int(cudnn.group(1)) + if cudart and len(cudart.group(1)): + cudart = convert_version_to_int(cudart.group(1)) + return (cudnn == cudnn_ver) and (cudart == cuda_ver) + + cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION']) + cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) + nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$') + highest_ver = [0, None, None] + + for lib_file in possible_files: + if is_compatible(lib_file, cuda_ver, cudnn_ver): + ver_str = nvinfer_pattern.search(lib_file).group(1) + ver = convert_version_to_int(ver_str) if len(ver_str) else 0 + if ver > highest_ver[0]: + highest_ver = [ver, ver_str, lib_file] + if highest_ver[1] is not None: + trt_install_path = os.path.dirname(highest_ver[2]) + tf_tensorrt_version = highest_ver[1] + break + + # Try another alternative from ldconfig. + ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' + ldconfig_output = run_shell([ldconfig_bin, '-p']) + search_result = re.search( + '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output) + if search_result: + libnvinfer_path_from_ldconfig = search_result.group(2) + if os.path.exists(libnvinfer_path_from_ldconfig): + if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver): + trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) + tf_tensorrt_version = search_result.group(1) + break + + # Reset and Retry + if len(possible_files): + print('TensorRT libraries found in one the following directories', + 'are not compatible with selected cuda and cudnn installations') + print(trt_install_path) + print(os.path.join(trt_install_path, 'lib')) + print(os.path.join(trt_install_path, 'lib64')) + if search_result: + print(libnvinfer_path_from_ldconfig) + else: + print('Invalid path to TensorRT. None of the following files can be found:') + print(trt_install_path) + print(os.path.join(trt_install_path, 'lib')) + print(os.path.join(trt_install_path, 'lib64')) + if search_result: + print(libnvinfer_path_from_ldconfig) + + else: + raise UserInputError('Invalid TF_TENSORRT setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) + + # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION + environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path + write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path) + environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version + write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version) + + def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1081,7 +1217,7 @@ def set_host_c_compiler(environ_cp): environ_cp, var_name='HOST_C_COMPILER', var_default=default_c_host_compiler, - ask_for_var=('Please specify which C compiler should be used as the host' + ask_for_var=('Please specify which C compiler should be used as the host ' 'C compiler.'), check_success=os.path.exists, error_msg='Invalid C compiler path. %s cannot be found.', @@ -1225,13 +1361,20 @@ def config_info_line(name, help_text): def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--workspace", + type=str, + default=_TF_WORKSPACE_ROOT, + help="The absolute path to your active Bazel workspace.") + args = parser.parse_args() + # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. environ_cp = dict(os.environ) check_bazel_version('0.5.4') - reset_tf_configure_bazelrc() + reset_tf_configure_bazelrc(args.workspace) cleanup_makefile() setup_python(environ_cp) @@ -1240,13 +1383,16 @@ def main(): environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' + environ_cp['TF_NEED_KAFKA'] = '0' environ_cp['TF_NEED_OPENCL_SYCL'] = '0' environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' + environ_cp['TF_NEED_TENSORRT'] = '0' if is_macos(): environ_cp['TF_NEED_JEMALLOC'] = '0' + environ_cp['TF_NEED_TENSORRT'] = '0' set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', 'with_jemalloc', True) @@ -1256,6 +1402,8 @@ def main(): 'with_hdfs_support', True, 'hdfs') set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', 'with_s3_support', True, 's3') + set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', + 'with_kafka_support', False, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', @@ -1278,7 +1426,13 @@ def main(): 'TF_CUDA_CONFIG_REPO' not in environ_cp): set_tf_cuda_version(environ_cp) set_tf_cudnn_version(environ_cp) + if is_linux(): + set_tf_tensorrt_install_path(environ_cp) set_tf_cuda_compute_capabilities(environ_cp) + if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( + 'LD_LIBRARY_PATH') != '1': + write_action_env_to_bazelrc('LD_LIBRARY_PATH', + environ_cp.get('LD_LIBRARY_PATH')) set_tf_cuda_clang(environ_cp) if environ_cp.get('TF_CUDA_CLANG') == '1': diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 63849943e4bdef132a9fdaead3d57811e24e686b..dc995d231d3e591771f801e28024a76610cdba26 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -211,6 +211,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_kafka_support", + define_values = {"with_kafka_support": "true"}, + visibility = ["//visibility:public"], +) + # Crosses between platforms and file system libraries not supported on those # platforms due to limitations in nested select() statements. config_setting( @@ -370,6 +376,14 @@ config_setting( visibility = ["//visibility:public"], ) +# TODO(laigd): consider removing this option and make TensorRT enabled +# automatically when CUDA is enabled. +config_setting( + name = "with_tensorrt_support", + values = {"define": "with_tensorrt_support=true"}, + visibility = ["//visibility:public"], +) + package_group( name = "internal", packages = [ @@ -469,6 +483,7 @@ filegroup( "//tensorflow/contrib/factorization:all_files", "//tensorflow/contrib/factorization/examples:all_files", "//tensorflow/contrib/factorization/kernels:all_files", + "//tensorflow/contrib/feature_column:all_files", "//tensorflow/contrib/ffmpeg:all_files", "//tensorflow/contrib/ffmpeg/default:all_files", "//tensorflow/contrib/framework:all_files", @@ -528,7 +543,6 @@ filegroup( "//tensorflow/contrib/model_pruning:all_files", "//tensorflow/contrib/model_pruning/examples/cifar10:all_files", "//tensorflow/contrib/nccl:all_files", - "//tensorflow/contrib/ndlstm:all_files", "//tensorflow/contrib/nearest_neighbor:all_files", "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", @@ -536,8 +550,10 @@ filegroup( "//tensorflow/contrib/predictor:all_files", "//tensorflow/contrib/py2tf:all_files", "//tensorflow/contrib/py2tf/converters:all_files", + "//tensorflow/contrib/py2tf/impl:all_files", "//tensorflow/contrib/py2tf/pyct:all_files", "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files", + "//tensorflow/contrib/py2tf/utils:all_files", "//tensorflow/contrib/quantize:all_files", "//tensorflow/contrib/receptive_field:all_files", "//tensorflow/contrib/reduce_slice_ops:all_files", @@ -566,6 +582,7 @@ filegroup( "//tensorflow/contrib/tensor_forest/proto:all_files", "//tensorflow/contrib/tensorboard:all_files", "//tensorflow/contrib/tensorboard/db:all_files", + "//tensorflow/contrib/tensorrt:all_files", "//tensorflow/contrib/testing:all_files", "//tensorflow/contrib/text:all_files", "//tensorflow/contrib/tfprof:all_files", diff --git a/tensorflow/SECURITY.md b/tensorflow/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..fea24b273920885ba8a1ae96aafbf7710df46e1f --- /dev/null +++ b/tensorflow/SECURITY.md @@ -0,0 +1,239 @@ +# Using TensorFlow Securely + +This document discusses how to safely deal with untrusted programs (models or +model parameters), and input data. Below, we also provide guidelines on how to +report vulnerabilities in TensorFlow. + +## TensorFlow models are programs + +TensorFlow's runtime system interprets and executes programs. What machine +learning practitioners term +[**models**](https://developers.google.com/machine-learning/glossary/#model) are +expressed as programs that TensorFlow executes. TensorFlow programs are encoded +as computation +[**graphs**](https://developers.google.com/machine-learning/glossary/#graph). +The model's parameters are often stored separately in **checkpoints**. + +At runtime, TensorFlow executes the computation graph using the parameters +provided. Note that the behavior of the computation graph may change +depending on the parameters provided. TensorFlow itself is not a sandbox. When +executing the computation graph, TensorFlow may read and write files, send and +receive data over the network, and even spawn additional processes. All these +tasks are performed with the permissions of the TensorFlow process. Allowing +for this flexibility makes for a powerful machine learning platform, +but it has implications for security. + +The computation graph may also accept **inputs**. Those inputs are the +data you supply to TensorFlow to train a model, or to use a model to run +inference on the data. + +**TensorFlow models are programs, and need to be treated as such from a security +perspective.** + +## Running untrusted models + +As a general rule: **Always** execute untrusted models inside a sandbox (e.g., +[nsjail](https://github.com/google/nsjail)). + +There are several ways in which a model could become untrusted. Obviously, if an +untrusted party supplies TensorFlow kernels, arbitrary code may be executed. +The same is true if the untrusted party provides Python code, such as the +Python code that generates TensorFlow graphs. + +Even if the untrusted party only supplies the serialized computation +graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the +set of computation primitives available to TensorFlow is powerful enough that +you should assume that the TensorFlow process effectively executes arbitrary +code. One common solution is to whitelist only a few safe Ops. While this is +possible in theory, we still recommend you sandbox the execution. + +It depends on the computation graph whether a user provided checkpoint is safe. +It is easily possible to create computation graphs in which malicious +checkpoints can trigger unsafe behavior. For example, consider a graph that +contains a `tf.cond` depending on the value of a `tf.Variable`. One branch of +the `tf.cond` is harmless, but the other is unsafe. Since the `tf.Variable` is +stored in the checkpoint, whoever provides the checkpoint now has the ability to +trigger unsafe behavior, even though the graph is not under their control. + +In other words, graphs can contain vulnerabilities of their own. To allow users +to provide checkpoints to a model you run on their behalf (e.g., in order to +compare model quality for a fixed model architecture), you must carefully audit +your model, and we recommend you run the TensorFlow process in a sandbox. + +## Accepting untrusted Inputs + +It is possible to write models that are secure in a sense that they can safely +process untrusted inputs assuming there are no bugs. There are two main reasons +to not rely on this: first, it is easy to write models which must not be exposed +to untrusted inputs, and second, there are bugs in any software system of +sufficient complexity. Letting users control inputs could allow them to trigger +bugs either in TensorFlow or in dependent libraries. + +In general, it is good practice to isolate parts of any system which is exposed +to untrusted (e.g., user-provided) inputs in a sandbox. + +A useful analogy to how any TensorFlow graph is executed is any interpreted +programming language, such as Python. While it is possible to write secure +Python code which can be exposed to user supplied inputs (by, e.g., carefully +quoting and sanitizing input strings, size-checking input blobs, etc.), it is +very easy to write Python programs which are insecure. Even secure Python code +could be rendered insecure by a bug in the Python interpreter, or in a bug in a +Python library used (e.g., +[this one](https://www.cvedetails.com/cve/CVE-2017-12852/)). + +## Running a TensorFlow server + +TensorFlow is a platform for distributed computing, and as such there is a +TensorFlow server (`tf.train.Server`). **The TensorFlow server is meant for +internal communication only. It is not built for use in an untrusted network.** + +For performance reasons, the default TensorFlow server does not include any +authorization protocol and sends messages unencrypted. It accepts connections +from anywhere, and executes the graphs it is sent without performing any checks. +Therefore, if you run a `tf.train.Server` in your network, anybody with +access to the network can execute what you should consider arbitrary code with +the privileges of the process running the `tf.train.Server`. + +When running distributed TensorFlow, you must isolate the network in which the +cluster lives. Cloud providers provide instructions for setting up isolated +networks, which are sometimes branded as "virtual private cloud." Refer to the +instructions for +[GCP](https://cloud.google.com/compute/docs/networks-and-firewalls) and +[AWS](https://aws.amazon.com/vpc/)) for details. + +Note that `tf.train.Server` is different from the server created by +`tensorflow/serving` (the default binary for which is called `ModelServer`). +By default, `ModelServer` also has no built-in mechanism for authentication. +Connecting it to an untrusted network allows anyone on this network to run the +graphs known to the `ModelServer`. This means that an attacker may run +graphs using untrusted inputs as described above, but they would not be able to +execute arbitrary graphs. It is possible to safely expose a `ModelServer` +directly to an untrusted network, **but only if the graphs it is configured to +use have been carefully audited to be safe**. + +Similar to best practices for other servers, we recommend running any +`ModelServer` with appropriate privileges (i.e., using a separate user with +reduced permisisons). In the spirit of defense in depth, we recommend +authenticating requests to any TensorFlow server connected to an untrusted +network, as well as sandboxing the server to minimize the adverse effects of +any breach. + +## Vulnerabilities in TensorFlow + +TensorFlow is a large and complex system. It also depends on a large set of +third party libraries (e.g., `numpy`, `libjpeg-turbo`, PNG parsers, `protobuf`). +It is possible that TensorFlow or its dependent libraries contain +vulnerabilities that would allow triggering unexpected or dangerous behavior +with specially crafted inputs. + +### What is a vulnerability? + +Given TensorFlow's flexibility, it is possible to specify computation graphs +which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models +can perform arbitrary computations means that they may read and write files, +communicate via the network, produce deadlocks and infinite loops, or run out +of memory. It is only when these behaviors are outside the specifications of the +operations involved that such behavior is a vulnerability. + +A `FileWriter` writing a file is not unexpected behavior and therefore is not a +vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution +**is** a vulnerability. + +This is more subtle from a system perspective. For example, it is easy to cause +a TensorFlow process to try to allocate more memory than available by specifying +a computation graph containing an ill-considered `tf.tile` operation. TensorFlow +should exit cleanly in this case (it would raise an exception in Python, or +return an error `Status` in C++). However, if the surrounding system is not +expecting the possibility, such behavior could be used in a denial of service +attack (or worse). Because TensorFlow behaves correctly, this is not a +vulnerability in TensorFlow (although it would be a vulnerability of this +hypothetical system). + +As a general rule, it is incorrect behavior for Tensorflow to access memory it +does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to +such behaviors constitute a vulnerability. + +One of the most critical parts of any system is input handling. If malicious +input can trigger side effects or incorrect behavior, this is a bug, and likely +a vulnerability. + +### Reporting vulnerabilities + +Please email reports about any security related issues you find to +`security@tensorflow.org`. This mail is delivered to a small security team. Your +email will be acknowledged within one business day, and you'll receive a more +detailed response to your email within 7 days indicating the next steps in +handling your report. For critical problems, you may encrypt your report (see +below). + +Please use a descriptive subject line for your report email. After the initial +reply to your report, the security team will endeavor to keep you informed of +the progress being made towards a fix and announcement. + +If you believe that an existing (public) issue is security-related, please send +an email to `security@tensorflow.org`. The email should include the issue ID and +a short description of why it should be handled according to this security +policy. + +Once an issue is reported, TensorFlow uses the following disclosure process: + +* When a report is received, we confirm the issue and determine its severity. +* If we know of specific third-party services or software based on TensorFlow + that require mitigation before publication, those projects will be notified. +* An advisory is prepared (but not published) which details the problem and + steps for mitigation. +* Wherever possible, fixes are prepared for the last minor release of the two + latest major releases, as well as the master branch. We will attempt to + commit these fixes as soon as possible, and as close together as + possible. +* Patch releases are published for all fixed released versions, a + notification is sent to discuss@tensorflow.org, and the advisory is published. + +Past security advisories are listed below. We credit reporters for identifying +security issues, although we keep your name confidential if you request it. + +#### Encryption key for `security@tensorflow.org` + +If your disclosure is extremely sensitive, you may choose to encrypt your +report using the key below. Please only use this for critical security +reports. + +``` +-----BEGIN PGP PUBLIC KEY BLOCK----- + +mQENBFpqdzwBCADTeAHLNEe9Vm77AxhmGP+CdjlY84O6DouOCDSq00zFYdIU/7aI +LjYwhEmDEvLnRCYeFGdIHVtW9YrVktqYE9HXVQC7nULU6U6cvkQbwHCdrjaDaylP +aJUXkNrrxibhx9YYdy465CfusAaZ0aM+T9DpcZg98SmsSml/HAiiY4mbg/yNVdPs +SEp/Ui4zdIBNNs6at2gGZrd4qWhdM0MqGJlehqdeUKRICE/mdedXwsWLM8AfEA0e +OeTVhZ+EtYCypiF4fVl/NsqJ/zhBJpCx/1FBI1Uf/lu2TE4eOS1FgmIqb2j4T+jY +e+4C8kGB405PAC0n50YpOrOs6k7fiQDjYmbNABEBAAG0LVRlbnNvckZsb3cgU2Vj +dXJpdHkgPHNlY3VyaXR5QHRlbnNvcmZsb3cub3JnPokBTgQTAQgAOBYhBEkvXzHm +gOJBnwP4Wxnef3wVoM2yBQJaanc8AhsDBQsJCAcCBhUKCQgLAgQWAgMBAh4BAheA +AAoJEBnef3wVoM2yNlkIAICqetv33MD9W6mPAXH3eon+KJoeHQHYOuwWfYkUF6CC +o+X2dlPqBSqMG3bFuTrrcwjr9w1V8HkNuzzOJvCm1CJVKaxMzPuXhBq5+DeT67+a +T/wK1L2R1bF0gs7Pp40W3np8iAFEh8sgqtxXvLGJLGDZ1Lnfdprg3HciqaVAiTum +HBFwszszZZ1wAnKJs5KVteFN7GSSng3qBcj0E0ql2nPGEqCVh+6RG/TU5C8gEsEf +3DX768M4okmFDKTzLNBm+l08kkBFt+P43rNK8dyC4PXk7yJa93SmS/dlK6DZ16Yw +2FS1StiZSVqygTW59rM5XNwdhKVXy2mf/RtNSr84gSi5AQ0EWmp3PAEIALInfBLR +N6fAUGPFj+K3za3PeD0fWDijlC9f4Ety/icwWPkOBdYVBn0atzI21thPRbfuUxfe +zr76xNNrtRRlbDSAChA1J5T86EflowcQor8dNC6fS+oHFCGeUjfEAm16P6mGTo0p +osdG2XnnTHOOEFbEUeWOwR/zT0QRaGGknoy2pc4doWcJptqJIdTl1K8xyBieik/b +nSoClqQdZJa4XA3H9G+F4NmoZGEguC5GGb2P9NHYAJ3MLHBHywZip8g9oojIwda+ +OCLL4UPEZ89cl0EyhXM0nIAmGn3Chdjfu3ebF0SeuToGN8E1goUs3qSE77ZdzIsR +BzZSDFrgmZH+uP0AEQEAAYkBNgQYAQgAIBYhBEkvXzHmgOJBnwP4Wxnef3wVoM2y +BQJaanc8AhsMAAoJEBnef3wVoM2yX4wIALcYZbQhSEzCsTl56UHofze6C3QuFQIH +J4MIKrkTfwiHlCujv7GASGU2Vtis5YEyOoMidUVLlwnebE388MmaJYRm0fhYq6lP +A3vnOCcczy1tbo846bRdv012zdUA+wY+mOITdOoUjAhYulUR0kiA2UdLSfYzbWwy +7Obq96Jb/cPRxk8jKUu2rqC/KDrkFDtAtjdIHh6nbbQhFuaRuWntISZgpIJxd8Bt +Gwi0imUVd9m9wZGuTbDGi6YTNk0GPpX5OMF5hjtM/objzTihSw9UN+65Y/oSQM81 +v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= +=CDME +-----END PGP PUBLIC KEY BLOCK----- +``` + +### Known vulnerabilities + +| Type | Versions affected | Reported by | Additional Information | +|-------------------|:-----------------:|--------------------|-----------------------------| +| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | + diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index 083634bd7964b0c12e10a1f3c71be5eab597a6c4..78ad6aec19f3bbbfcb389012ac1577573b3e4901 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -21,7 +21,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import -from tensorflow.python import * +from tensorflow.python import * # pylint: disable=redefined-builtin # pylint: enable=wildcard-import from tensorflow.python.util.lazy_loader import LazyLoader diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index c46cb32aa46af474c889095564d46c5f2399c3ad..5dfb743681255d8c03e91ea43fd441d94fdee59d 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -6,17 +6,12 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", "tf_cc_test", + "tf_cuda_cc_test", "tf_copts", "tf_cuda_library", "tf_custom_op_library", ) -# For platform specific build config -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_kernel_tests_linkstatic", -) - # ----------------------------------------------------------------------------- # Public targets @@ -33,7 +28,11 @@ filegroup( "*.cc", "*.h", ], - exclude = ["*test*"], + exclude = [ + "c_api_experimental.cc", + "c_api_experimental.h", + "*test*", + ], ), visibility = ["//visibility:public"], ) @@ -91,9 +90,33 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], + }) + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + ], + "//conditions:default": [], }), ) +tf_cuda_library( + name = "c_api_experimental", + srcs = [ + "c_api_experimental.cc", + ], + hdrs = [ + "c_api_experimental.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":c_api", + ":c_api_internal", + "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/core:protos_all_cc", + ], +) + exports_files( [ "version_script.lds", @@ -135,15 +158,21 @@ tf_cuda_library( testonly = 1, srcs = ["c_test_util.cc"], hdrs = ["c_test_util.h"], + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + ], deps = [ ":c_api", + ":c_api_experimental", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", "//tensorflow/core:test", ], ) -tf_cc_test( +tf_cuda_cc_test( name = "c_api_test", size = "small", srcs = ["c_api_test.cc"], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 3c7f041b39f01d9b8b187079b00e0c5ad99a38cc..85f1d1639b4d09f2de77d326481a86ec246270d0 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -64,6 +64,7 @@ using tensorflow::AllocationDescription; using tensorflow::DataType; using tensorflow::Graph; using tensorflow::GraphDef; +using tensorflow::mutex_lock; using tensorflow::NameRangeMap; using tensorflow::NameRangesForNode; using tensorflow::NewSession; @@ -77,6 +78,7 @@ using tensorflow::RunMetadata; using tensorflow::RunOptions; using tensorflow::Session; using tensorflow::Status; +using tensorflow::string; using tensorflow::Tensor; using tensorflow::TensorBuffer; using tensorflow::TensorId; @@ -87,8 +89,6 @@ using tensorflow::error::Code; using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; using tensorflow::gtl::ArraySlice; -using tensorflow::mutex_lock; -using tensorflow::string; using tensorflow::strings::StrCat; extern "C" { @@ -109,6 +109,10 @@ TF_Status* TF_NewStatus() { return new TF_Status; } void TF_DeleteStatus(TF_Status* s) { delete s; } void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) { + if (code == TF_OK) { + s->status = Status::OK(); + return; + } s->status = Status(static_cast(code), tensorflow::StringPiece(msg)); } @@ -195,11 +199,11 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, reinterpret_cast(data) % EIGEN_MAX_ALIGN_BYTES != 0) { // TF_STRING and TF_RESOURCE tensors have a different representation in // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste - // (any alignement requirements will be taken care of by TF_TensorToTensor + // (any alignment requirements will be taken care of by TF_TensorToTensor // and TF_TensorFromTensor). // - // Other types have the same represntation, so copy only if it is safe to do - // so. + // Other types have the same representation, so copy only if it is safe to + // do so. buf->data_ = allocate_tensor("TF_NewTensor", len); std::memcpy(buf->data_, data, len); buf->deallocator_ = deallocate_buffer; @@ -211,7 +215,13 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, buf->deallocator_ = deallocator; buf->deallocator_arg_ = deallocator_arg; } - return new TF_Tensor{dtype, TensorShape(dimvec), buf}; + TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf}; + size_t elem_size = TF_DataTypeSize(dtype); + if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) { + delete ret; + return nullptr; + } + return ret; } TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { @@ -2144,7 +2154,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, opts.return_tensors.push_back(ToTensorId(nodes_to_return[i])); } - // TOOD(skyewm): change to OutputTensor + // TODO(skyewm): change to OutputTensor tensorflow::ImportGraphDefResults results; TF_RETURN_IF_ERROR( ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results)); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index d2e45341bf1b9ee4579f84064550ce26041dd04a..ad592ef70961ef427bfe9fd322a82bd64df7f9f1 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -226,6 +226,10 @@ typedef struct TF_Tensor TF_Tensor; // (*deallocator)(data, len, deallocator_arg) // Clients must provide a custom deallocator function so they can pass in // memory managed by something like numpy. +// +// May return NULL (and invoke the deallocator) if the provided data buffer +// (data, len) is inconsistent with a tensor of the given TF_DataType +// and the shape specified by (dima, num_dims). TF_CAPI_EXPORT extern TF_Tensor* TF_NewTensor( TF_DataType, const int64_t* dims, int num_dims, void* data, size_t len, void (*deallocator)(void* data, size_t len, void* arg), @@ -1283,11 +1287,12 @@ TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); typedef struct TF_Session TF_Session; -// Return a new execution session with the associated graph, or NULL on error. +// Return a new execution session with the associated graph, or NULL on +// error. Does not take ownership of any input parameters. // -// *graph must be a valid graph (not deleted or nullptr). This function will -// prevent the graph from being deleted until TF_DeleteSession() is called. -// Does not take ownership of opts. +// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be +// kept alive for the lifetime of the returned TF_Session. New nodes can still +// be added to `graph` after this call. TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opts, TF_Status* status); diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc new file mode 100644 index 0000000000000000000000000000000000000000..be7f85a5bb06dce84579b109d506ded049042b50 --- /dev/null +++ b/tensorflow/c/c_api_experimental.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api_experimental.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/protobuf/config.pb.h" + +void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { + tensorflow::ConfigProto& config = options->options.config; + auto* optimizer_options = + config.mutable_graph_options()->mutable_optimizer_options(); + if (enable) { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1); + + // These XLA flags are needed to trigger XLA properly from C (more generally + // non-Python) clients. If this API is called again with `enable` set to + // false, it is safe to keep these flag values as is. + tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = + tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + flags->tf_xla_cpu_global_jit = true; + flags->tf_xla_min_cluster_size = 1; + } else { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF); + } +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h new file mode 100644 index 0000000000000000000000000000000000000000..5a7b007e40aa199889b2d00b2bde5976c19e2966 --- /dev/null +++ b/tensorflow/c/c_api_experimental.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_C_C_API_EXPERIMENTAL_H_ + +#include +#include + +#include "tensorflow/c/c_api.h" + +// -------------------------------------------------------------------------- +// Experimental C API for TensorFlow. +// +// The API here is subject to changes in the future. + +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes.$a +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(COMPILER_MSVC) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // COMPILER_MSVC +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif + +// When `enable` is true, set +// tensorflow.ConfigProto.OptimizerOptions.global_jit_level to ON_1, and also +// set XLA flag values to prepare for XLA compilation. Otherwise set +// global_jit_level to OFF. +// +// This API is syntax sugar over TF_SetConfig(), and is used by clients that +// cannot read/write the tensorflow.ConfigProto proto. +TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, + unsigned char enable); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 46271e0514f473099848a8573cb7cb6fad33f7dc..384e6c8cb97022264c5327da5ca5861057608fbe 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -44,8 +44,12 @@ class NodeNameMapping { public: NodeNameMapping() = default; - // Normalize the input/output name and make it unique. - string GetIOName(const string& name); + // Normalize the input name and make it unique. This is the same as the + // function for output, expect that it adds a name mapping for the name. + string GetInputName(const string& name); + + // Normalize the output name and make it unique. + string GetOutputName(const string& name); // Make the node name unique. string Uniquify(const string& name); @@ -107,7 +111,13 @@ string NodeNameMapping::UniquifyHelper(const string& name) const { } } -string NodeNameMapping::GetIOName(const string& name) { +string NodeNameMapping::GetInputName(const string& name) { + const string& input_name = GetOutputName(name); + name_mapping_[name] = input_name; + return input_name; +} + +string NodeNameMapping::GetOutputName(const string& name) { const string& input_name = UniquifyHelper(Normalize(name)); // Record that we used this name, but don't add it to name_mapping_ // since this name is not for a node. @@ -214,10 +224,11 @@ Status FillFunctionBody( // Add control inputs. for (const Edge* edge : control_edges) { - // Add this control input only if the src node is in the body. + // Add this control input only if the src node is in the body or a part of + // the inputs. const string normalized = node_names.Lookup(edge->src()->name()); // If we did not find a name for the source of control edge, this - // source must be outside of the body. Raise an error. + // source must be outside of the body, and not an input. Raise an error. if (normalized.empty()) { return InvalidArgument( "The source of control edge ", edge->DebugString(), @@ -279,7 +290,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i])); argdef->set_name(output_names[i]); } else { - argdef->set_name(node_names.GetIOName(node->name())); + argdef->set_name(node_names.GetOutputName(node->name())); } } @@ -289,7 +300,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, int idx = inputs[i].index; OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg(); argdef->set_type(node->output_type(idx)); - const string& input_name = node_names.GetIOName(node->name()); + const string& input_name = node_names.GetInputName(node->name()); argdef->set_name(input_name); tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name; } @@ -467,7 +478,7 @@ Status ComputeBodyNodes( return Status::OK(); } -} // anonymous namespace +} // namespace } // namespace tensorflow using tensorflow::Node; diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index dbce66d2317a8e89288fab932cf69055f8b5a7f0..7ca50119eafe299b307f06c555aec1388e7e82e2 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -331,6 +331,11 @@ class CApiFunctionTest : public ::testing::Test { << "Failed to find expected edge " << e.ToString() << " in fdef: " << fdef.DebugString(); } + for (const EdgeSpec& e : c_edges) { + ASSERT_TRUE(a_edges.find(e) != a_edges.end()) + << "Failed to find expected control edge " << e.ToString() + << " in fdef: " << fdef.DebugString(); + } // If caller specified all edges, check that we have seen all if (is_exact_edges) { @@ -980,7 +985,7 @@ TEST_F(CApiFunctionTest, ControlDependency) { VerifyFDef( {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, - {{"scalar", "add_0"}}); + {{"^scalar", "add_0:2"}}); } TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) { @@ -1023,12 +1028,17 @@ TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) { TF_Operation* add = AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_); EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - Define(-1, {}, {feed1, feed2}, {add}, {}, true); - EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); - EXPECT_EQ(string("The source of control edge [id=3 feed1:-1 -> add:-1] " - "is not in the body. Encountered while creating " - "function 'MyFunc'"), - string(TF_Message(s_))); + Define(-1, {}, {feed1, feed2}, {add}, {}); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), + {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, + {{"^feed1", "add_0:2"}}); } TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 01954eb235f1a93d943c2ec7ea4c5ca44785d402..028f146be31790b211e546978302e81afe26b231 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -57,6 +57,52 @@ static void ExpectHasSubstr(StringPiece s, StringPiece expected) { << "'" << s << "' does not contain '" << expected << "'"; } +// Returns the GPU device name if there is one (with arbitrary tie breaking if +// there are more than one), or "" otherwise. +string GPUDeviceName(TF_Session* session) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Status* s = status.get(); + std::unique_ptr list( + TF_SessionListDevices(session, s), TF_DeleteDeviceList); + TF_DeviceList* device_list = list.get(); + + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + const int num_devices = TF_DeviceListCount(device_list); + LOG(INFO) << "There are " << num_devices << " devices."; + for (int i = 0; i < num_devices; ++i) { + const char* device_name = TF_DeviceListName(device_list, i, s); + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + const char* device_type = TF_DeviceListType(device_list, i, s); + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + LOG(INFO) << "Device " << i << " has name " << device_name << ", type " + << device_type; + if (string(device_type) == DEVICE_GPU) { + return device_name; + } + } + // No GPU device found. + return ""; +} + +string GPUDeviceName() { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Status* s = status.get(); + std::unique_ptr graph(TF_NewGraph(), + TF_DeleteGraph); + + TF_SessionOptions* opts = TF_NewSessionOptions(); + TF_Session* sess = TF_NewSession(graph.get(), opts, s); + TF_DeleteSessionOptions(opts); + + const string gpu_device_name = GPUDeviceName(sess); + TF_DeleteSession(sess, s); + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return gpu_device_name; +} + TEST(CAPI, Version) { EXPECT_STRNE("", TF_Version()); } TEST(CAPI, Status) { @@ -94,6 +140,17 @@ TEST(CAPI, Tensor) { EXPECT_TRUE(deallocator_called); } +void NoOpDeallocator(void* data, size_t, void*) {} + +TEST(CAPI, MalformedTensor) { + // See https://github.com/tensorflow/tensorflow/issues/7394 + // num_dims = 0 implies a scalar, so should be backed by at least 4 bytes of + // data. + TF_Tensor* t = + TF_NewTensor(TF_FLOAT, nullptr, 0, nullptr, 0, &NoOpDeallocator, nullptr); + ASSERT_TRUE(t == nullptr); +} + TEST(CAPI, AllocateTensor) { const int num_bytes = 6 * sizeof(float); int64_t dims[] = {2, 3}; @@ -123,6 +180,10 @@ TEST(CAPI, MaybeMove) { } TEST(CAPI, LibraryLoadFunctions) { + // TODO(b/73318067): Fix linking for the GPU test generated by the + // tf_cuda_cc_test() bazel rule and remove the next line. + if (!GPUDeviceName().empty()) return; + // Load the library. TF_Status* status = TF_NewStatus(); TF_Library* lib = @@ -912,6 +973,70 @@ TEST(CAPI, Session) { TF_DeleteStatus(s); } +// If `device` is non-empty, run Min op on that device. +// Otherwise run it on the default device (CPU). +void RunMinTest(const string& device, bool use_XLA) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Make a placeholder operation. + TF_Operation* feed = Placeholder(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Make a constant operation with the scalar "0", for axis. + TF_Operation* one = ScalarConst(0, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Create a session for this graph. + CSession csession(graph, s, use_XLA); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + if (!device.empty()) { + LOG(INFO) << "Setting op Min on device " << device; + } + TF_Operation* min = MinWithDevice(feed, one, graph, device, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. + csession.SetInputs({{feed, Int32Tensor({3, 2, 5})}}); + csession.SetOutputs({min}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* output_contents = static_cast(TF_TensorData(out)); + EXPECT_EQ(2, *output_contents); + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST(CAPI, Session_Min_CPU) { RunMinTest(/*device=*/"", /*use_XLA=*/false); } + +TEST(CAPI, Session_Min_XLA_CPU) { RunMinTest(/*device=*/"", /*use_XLA=*/true); } + +TEST(CAPI, Session_Min_GPU) { + const string gpu_device = GPUDeviceName(); + // Skip this test if no GPU is available. + if (gpu_device.empty()) return; + + RunMinTest(gpu_device, /*use_XLA=*/false); +} + +TEST(CAPI, Session_Min_XLA_GPU) { + const string gpu_device = GPUDeviceName(); + // Skip this test if no GPU is available. + if (gpu_device.empty()) return; + + RunMinTest(gpu_device, /*use_XLA=*/true); +} + TEST(CAPI, SessionPRun) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -1956,7 +2081,7 @@ TEST_F(CApiAttributesTest, Tensor) { } TEST_F(CApiAttributesTest, StringTensor) { - // Create the string-Tensor "atttribute" value. + // Create the string-Tensor "attribute" value. char encoded[] = { 0, 0, 0, 0, 0, 0, 0, 0, // array[uint64] offsets 1, // varint encoded string length @@ -2054,6 +2179,10 @@ TEST_F(CApiAttributesTest, Errors) { } TEST(TestApiDef, TestCreateApiDef) { + // TODO(b/73318067): Fix linking for the GPU test generated by the + // tf_cuda_cc_test() bazel rule and remove the next line. + if (!GPUDeviceName().empty()) return; + TF_Status* status = TF_NewStatus(); TF_Library* lib = TF_LoadLibrary("tensorflow/c/test_op.so", status); @@ -2084,6 +2213,10 @@ TEST(TestApiDef, TestCreateApiDef) { } TEST(TestApiDef, TestCreateApiDefWithOverwrites) { + // TODO(b/73318067): Fix linking for the GPU test generated by the + // tf_cuda_cc_test() bazel rule and remove the next line. + if (!GPUDeviceName().empty()) return; + TF_Status* status = TF_NewStatus(); TF_Library* lib = TF_LoadLibrary("tensorflow/c/test_op.so", status); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 37439ff0beac5a5220460465e954b6c093ee1ba9..3db2852ce6560ba493d60ef54a110161c112d110 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -15,11 +15,13 @@ limitations under the License. #include "tensorflow/c/c_test_util.h" +#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session_options.h" using tensorflow::GraphDef; using tensorflow::NodeDef; @@ -124,8 +126,9 @@ TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, return Const(tensor.get(), graph, s, name); } -void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, - const char* name, TF_Operation** op, bool check) { +void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name, TF_Operation** op, + bool check) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; TF_AddInputList(desc, add_inputs, 2); @@ -139,14 +142,14 @@ void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name) { TF_Operation* op; - AddHelper(l, r, graph, s, name, &op, true); + AddOpHelper(l, r, graph, s, name, &op, true); return op; } TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name) { TF_Operation* op; - AddHelper(l, r, graph, s, name, &op, false); + AddOpHelper(l, r, graph, s, name, &op, false); return op; } @@ -160,6 +163,36 @@ TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, return TF_FinishOperation(desc, s); } +// If `op_device` is non-empty, set the created op on that device. +void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r, + TF_Graph* graph, TF_Status* s, const char* name, + TF_Operation** op, const string& op_device, bool check) { + TF_OperationDescription* desc = TF_NewOperation(graph, op_name, name); + if (!op_device.empty()) { + TF_SetDevice(desc, op_device.c_str()); + } + TF_AddInput(desc, {l, 0}); + TF_AddInput(desc, {r, 0}); + *op = TF_FinishOperation(desc, s); + if (check) { + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); + } +} + +TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + const string& op_device, TF_Status* s, + const char* name) { + TF_Operation* op; + BinaryOpHelper("Min", l, r, graph, s, name, &op, op_device, true); + return op; +} + +TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + return MinWithDevice(l, r, graph, /*op_device=*/"", s, name); +} + TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); @@ -369,8 +402,9 @@ std::vector GetFuncNames(const tensorflow::GraphDef& graph_def) { return names; } -CSession::CSession(TF_Graph* graph, TF_Status* s) { +CSession::CSession(TF_Graph* graph, TF_Status* s, bool use_XLA) { TF_SessionOptions* opts = TF_NewSessionOptions(); + TF_EnableXLACompilation(opts, use_XLA); session_ = TF_NewSession(graph, opts, s); TF_DeleteSessionOptions(opts); } diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 6acc2fec0063a8592e8e22a00b530df05a08cdb8..2a70177c724c569844a5d8ad42b99bed20209946 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -69,6 +69,14 @@ TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name = "add"); +TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "min"); + +// If `op_device` is non-empty, set the created op on that device. +TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + const string& op_device, TF_Status* s, + const char* name = "min"); + TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, const char* name = "neg"); @@ -108,7 +116,7 @@ std::vector GetFuncNames(const tensorflow::GraphDef& graph_def); class CSession { public: - CSession(TF_Graph* graph, TF_Status* s); + CSession(TF_Graph* graph, TF_Status* s, bool use_XLA = false); explicit CSession(TF_Session* session); ~CSession(); @@ -124,6 +132,8 @@ class CSession { TF_Tensor* output_tensor(int i) { return output_values_[i]; } + TF_Session* mutable_session() { return session_; } + private: void DeleteInputValues(); void ResetOutputValues(); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 74190cb135ac6c17bfcc9d8bd2f7c75ac5e8c076..e55cb672e97e1403a3dd864c91c176426eb3f067 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -6,6 +6,7 @@ load( "tf_cuda_cc_test", "tf_cc_test", "tf_copts", + "tfe_xla_copts", "tf_cuda_library", ) @@ -16,7 +17,7 @@ tf_cuda_library( "c_api_internal.h", ], hdrs = ["c_api.h"], - copts = tf_copts(), + copts = tf_copts() + tfe_xla_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ @@ -33,7 +34,15 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", ], - }) + ["//tensorflow/core:gpu_runtime"], + }) + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + ], + "//conditions:default": [], + }) + [ + "//tensorflow/core:gpu_runtime", + ], ) tf_cuda_library( @@ -46,6 +55,7 @@ tf_cuda_library( "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:framework_lite", "//tensorflow/core:lib_internal", @@ -55,12 +65,14 @@ tf_cuda_library( tf_cuda_cc_test( name = "c_api_test", srcs = ["c_api_test.cc"], + extra_copts = tfe_xla_copts(), tags = [ "guitar", "multi_gpu", ], deps = [ ":c_api", + "//tensorflow/c:c_test_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index a76c8f5ec05fc3199addc67857d7bb2ea0e263c2..98ef6f0d0ab094eae3e2e21624c3a4ba30d1c3d3 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -25,6 +25,9 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" +#ifdef TENSORFLOW_EAGER_USE_XLA +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#endif // TENSORFLOW_EAGER_USE_XLA #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -34,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" @@ -44,13 +48,23 @@ using tensorflow::int64; using tensorflow::string; namespace { -bool IsCPU(tensorflow::Device* d) { +bool IsCPU(const tensorflow::Device* d) { return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; } -string DeviceName(tensorflow::Device* d) { +bool IsXLA(const tensorflow::Device* d) { + if (d == nullptr) return false; + const auto& device_type = d->attributes().device_type(); + return device_type.find("XLA") != std::string::npos; +} + +string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? "cpu:0" : d->name(); } + +#ifdef TENSORFLOW_EAGER_USE_XLA +std::atomic_int_fast64_t func_id_generator(0); +#endif // TENSORFLOW_EAGER_USE_XLA } // namespace extern "C" { @@ -85,15 +99,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { return nullptr; } - TFE_Context* ret = new TFE_Context(session); - ret->policy = opts->policy; - ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime( - ret->session->device_mgr, opts->session_options.options.env, - TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {})); - ret->rendezvous = - new tensorflow::IntraProcessRendezvous(ret->session->device_mgr); - - return ret; + return new TFE_Context(*opts, session); } void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { @@ -155,10 +161,11 @@ int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) { } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) { - // This might be a bit confusing as a tensor on CPU can sometimes return - // "CPU:0" and sometimes "/job:localhost/replica:0/task:0/cpu:0". - // TODO(ashankar): Figure out which one would be nicer. - return (h->d == nullptr) ? "CPU:0" : h->d->name().c_str(); + // TODO(apassos) this will be potentially incorrect in the distributed case as + // our local device will have a name which depends on the ClusterSpec and + // hence will require the context to resolve. + return (h->d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" + : h->d->name().c_str(); } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { @@ -191,7 +198,10 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); const bool dst_cpu = IsCPU(dstd); const bool src_cpu = IsCPU(srcd); - if (is_same_device) { + // both_on_cpu can be true and yet is_same_device is false, if one of src/dst + // has device type XLA_CPU, and the other CPU. + const bool both_on_cpu = src_cpu && dst_cpu; + if (is_same_device || both_on_cpu) { return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd); } tensorflow::Tensor* src = &(h->t); @@ -261,15 +271,6 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, void TFE_DeleteOp(TFE_Op* op) { delete op; } -static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device, - TF_Status* status) { - // Questionable heuristic: Place the op on the same device as the first input - // placed outside of host memory? - if (IsCPU(op->device) && !IsCPU(device)) { - op->device = device; - } -} - void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { tensorflow::Device* d = nullptr; if (device_name != nullptr && strlen(device_name) > 0) { @@ -277,11 +278,32 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { op->ctx->session->device_mgr->LookupDevice(device_name, &d); if (!status->status.ok()) return; } - TFE_OpSetDeviceHelper(op, d, status); + op->device = d; +} + +const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { + tensorflow::Device* device = + (op->device == nullptr) ? op->ctx->devices()[0] : op->device; + return device->name().c_str(); +} + +void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { + op->use_xla = enable; +#ifndef TENSORFLOW_EAGER_USE_XLA + LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " + "built with XLA support."; +#endif // TENSORFLOW_EAGER_USE_XLA } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { - TFE_OpSetDeviceHelper(op, h->d, status); + // Questionable heuristic ... + // + // Motivation: After an 'op' is placed on GPU because some of its earlier + // inputs are on GPU, we want to keep the 'op' there, even if some later + // inputs of it are not on GPU. + if (IsCPU(op->device) && !IsCPU(h->d)) { + op->device = h->d; + } if (!status->status.ok()) return; op->inputs.push_back(h->t); op->input_devices.push_back(h->d); @@ -298,7 +320,7 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, return TF_ATTR_INT; // The compiler requires that we return something. } status->status = - tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list); + tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list); return ret; } @@ -434,6 +456,19 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, proto.get(), num_values)); } +void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, + const TFE_Op** value, int num_values) { + std::unique_ptr funcs( + new tensorflow::NameAttrList[num_values]); + for (int i = 0; i < num_values; i++) { + funcs[i].set_name(value[i]->name); + value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr()); + } + op->attrs.Set(attr_name, + tensorflow::gtl::ArraySlice( + funcs.get(), num_values)); +} + namespace { tensorflow::Status ValidateInputTypeAndPlacement( @@ -515,6 +550,228 @@ tensorflow::Status ValidateInputTypeAndPlacement( } return tensorflow::Status::OK(); } + +#ifdef TENSORFLOW_EAGER_USE_XLA +// Synthesizes and returns a wrapper function over `op`, which must be a +// primitive op (e.g. matmul). +// +// The wrapper function conforms to the function signature expected by +// _XlaLaunchOp, with input params ordered by . For example, if the op has input params , they will be reordered to as the input params to the synthesized function. +// +// It populates `const_input_types`, `arg_input_types` and +// `op_input_to_func_input` based on the reordering results, that the caller can +// use them to build an _XlaLaunchOp. On error, it returns NULL, and sets +// `status` accordingly. +const tensorflow::FunctionDef* OpToFunction( + TFE_Op* op, std::vector* const_input_types, + std::vector* arg_input_types, + tensorflow::gtl::FlatMap* op_input_to_func_input, + TF_Status* status) { + DCHECK(!op->is_function()); + + tensorflow::FunctionDef fdef; + + // Get the OpDef of the op we are trying to encapsulate. + TFE_Context* ctx = op->ctx; + const tensorflow::OpRegistrationData* op_data; + { + tensorflow::tf_shared_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.LookUp(op->name, &op_data); + if (!status->status.ok()) { + return nullptr; + } + } + const tensorflow::OpDef& op_def = op_data->op_def; + + tensorflow::OpDef* signature = fdef.mutable_signature(); + + // Handle constant inputs. + const std::unordered_set const_inputs( + *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name)); + + // First add place holders for the input args, so that we can refer to them by + // position in the next loop. Also tally up the resource inputs. + int num_resource_inputs = 0; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) { + ++num_resource_inputs; + } + signature->add_input_arg(); + } + + // Now we map the input params from `op_def` to `signature`, where the param + // ordering for `signature` is: . + int const_index = 0; + int arg_index = const_inputs.size(); + int resource_index = op_def.input_arg_size() - num_resource_inputs; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i); + tensorflow::OpDef::ArgDef* func_input_arg = nullptr; + if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) { + VLOG(1) << "For const input, mapping op input " << i << " to func input " + << const_index; + (*op_input_to_func_input)[i] = const_index; + func_input_arg = signature->mutable_input_arg(const_index++); + const_input_types->push_back( + static_cast(op->inputs[i].dtype())); + } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) { + VLOG(1) << "For resource input, mapping op input " << i + << " to func input " << resource_index; + (*op_input_to_func_input)[i] = resource_index; + func_input_arg = signature->mutable_input_arg(resource_index++); + } else { + VLOG(1) << "For arg input, mapping op input " << i << " to func input " + << arg_index; + (*op_input_to_func_input)[i] = arg_index; + func_input_arg = signature->mutable_input_arg(arg_index++); + arg_input_types->push_back( + static_cast(op->inputs[i].dtype())); + } + + func_input_arg->set_name(op_input_arg.name()); + func_input_arg->set_type(op->inputs[i].dtype()); + } + VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString(); + + // Resources args are at the end of the function input params, and we should + // have iterated over all of them. + DCHECK_EQ(signature->input_arg_size(), resource_index); + + // Make the synthesized function's name unique. + signature->set_name(tensorflow::strings::StrCat( + op_def.name(), func_id_generator.fetch_add(1))); + + // Add the node def and set its input names to match op_def's names. + const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); + DCHECK_EQ(signature->input_arg_size(), ndef.input_size()); + *fdef.add_node_def() = ndef; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name()); + } + VLOG(1) << "Added NodeDef: " << fdef.DebugString(); + + // Fix the output names and set output types. + for (int i = 0; i < op_def.output_arg_size(); ++i) { + tensorflow::OpDef::ArgDef* arg = signature->add_output_arg(); + const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i); + const string& out_tensor_name = tensorflow::strings::StrCat( + ndef.name(), ":", op_def_arg.name(), ":", 0); + arg->set_name(op_def_arg.name()); + (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name; + const string& type_attr = op_def_arg.type_attr(); + if (!type_attr.empty()) { + auto i = ndef.attr().find(type_attr); + if (i == ndef.attr().end()) { + status->status = tensorflow::errors::InvalidArgument( + tensorflow::strings::StrCat("Could not find attr ", type_attr, + " in NodeDef ", ndef.DebugString())); + return nullptr; + } + arg->set_type(i->second.type()); + } + } + VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); + + tensorflow::mutex_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.AddFunctionDef(fdef); + if (!status->status.ok()) return nullptr; + const auto ret = ctx->func_lib_def.Find(signature->name()); + DCHECK(ret != nullptr); + return ret; +} + +// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed +// via XLA. +std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { + VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name; + auto launch_op = + std::unique_ptr(TFE_NewOp(op->ctx, "_XlaLaunch", status)); + if (TF_GetCode(status) != TF_OK) return nullptr; + if (op->device) { + TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + + const tensorflow::FunctionDef* fdef; + { + tensorflow::tf_shared_lock l(op->ctx->functions_mu); + fdef = op->ctx->func_lib_def.Find(op->name); + } + std::vector const_input_types; + std::vector arg_input_types; + tensorflow::gtl::FlatMap op_input_to_func_input; + if (fdef == nullptr) { + // See if this is a primitive op, and if so create a function for it, so + // that _XlaLaunchOp can access it. + fdef = OpToFunction(op, &const_input_types, &arg_input_types, + &op_input_to_func_input, status); + if (!status->status.ok()) return nullptr; + } else { + // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for + // functions, so we need to find another way to handle constant inputs. + for (int i = const_input_types.size(); + i < fdef->signature().input_arg_size(); ++i) { + VLOG(1) << "Adding Targs from input arg " << i; + const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i); + arg_input_types.push_back(static_cast(arg.type())); + } + } + DCHECK(fdef != nullptr); + + // Copy inputs and their devices. + // Since input param reordering may have occurred between `op` and `launch_op` + // via `op_input_to_func_input`, adjust the actual inputs accordingly. + launch_op->inputs = op->inputs; + launch_op->input_devices = op->input_devices; + if (!op_input_to_func_input.empty()) { + DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size()); + if (!op->input_devices.empty()) { + DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size()); + } + for (int i = 0; i < op_input_to_func_input.size(); ++i) { + VLOG(1) << "mapping op input " << i << " to func input " + << op_input_to_func_input[i]; + + launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i]; + if (!op->input_devices.empty()) { + launch_op->input_devices[op_input_to_func_input[i]] = + op->input_devices[i]; + } + } + } + launch_op->attrs.NumInputs(op->inputs.size()); + + TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(), + const_input_types.size()); + + // Set Targs and Nresources attrs. + TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(), + arg_input_types.size()); + const int num_resource_inputs = fdef->signature().input_arg_size() - + const_input_types.size() - + arg_input_types.size(); + TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs); + + // Set Tresults attr. + std::vector tresults; + for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) { + tresults.push_back(static_cast(arg.type())); + } + TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(), + tresults.size()); + + // Set function attr. + tensorflow::AttrValue attr_value; + tensorflow::NameAttrList* func = attr_value.mutable_func(); + func->set_name(fdef->signature().name()); + launch_op->attrs.Set("function", attr_value); + + return launch_op; +} +#endif // TENSORFLOW_EAGER_USE_XLA } // namespace void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, @@ -523,6 +780,18 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU tensorflow::Device* device = (op->device == nullptr) ? ctx->devices()[0] : op->device; + +#ifdef TENSORFLOW_EAGER_USE_XLA + std::unique_ptr xla_launch_op; + if (op->use_xla && op->name != "_XlaLaunch") { + xla_launch_op = BuildXlaLaunch(op, status); + if (!status->status.ok()) { + return; + } + op = xla_launch_op.get(); + } +#endif // TENSORFLOW_EAGER_USE_XLA + std::vector outputs(1); const tensorflow::MemoryTypeVector* output_memory_types = nullptr; tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name()); diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 387de078948e5076d0b069d6380dfc04ea6254df..7a321b54da343fd2b8912187bc620c1e7456db0c 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -87,7 +87,8 @@ typedef struct TFE_Context TFE_Context; TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext( const TFE_ContextOptions* opts, TF_Status* status); -TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, + TF_Status* status); TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status); @@ -119,8 +120,10 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); -TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index); -TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h); +TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, + int dim_index); +TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( + TFE_TensorHandle* h); TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status); @@ -130,10 +133,9 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, // that shares the underlying buffer. Otherwise, it currently requires at least // one of the source or destination devices to be CPU (i.e., for the source or // destination tensor to be placed in host memory). -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, - TFE_Context* ctx, - const char* device_name, - TF_Status* status); +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( + TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, + TF_Status* status); // Description of the TensorFlow op to execute. // @@ -148,17 +150,31 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorH // the additional sanity checks there seem unnecessary; typedef struct TFE_Op TFE_Op; -TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, +TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, + const char* op_or_function_name, TF_Status* status); TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status); +// The returned string remains valid throughout the lifetime of 'op'. +TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op, + TF_Status* status); + +// When 'enable' is set to 1, and if TensorFlow library is built with XLA +// support, a subsequent TFE_Execute() call on `op` will run the op via XLA. +// +// If the library is not built with XLA support, this call would be a no-op. +TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op, + unsigned char enable); -TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, + TF_Status* status); -TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, - unsigned char* is_list, TF_Status* status); +TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, + const char* attr_name, + unsigned char* is_list, + TF_Status* status); // Get an attribute type given an op name; a fusion of TFE_NewOp and // TFE_OpGetAttrType for use from Python without the overhead of the individual // calls and memory management of TFE_Op. @@ -166,10 +182,13 @@ TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( TFE_Context* ctx, const char* op_or_function_name, const char* attr_name, unsigned char* is_list, TF_Status* status); -TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, +TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, + const char* attr_name, const char* value); -TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value); -TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, + int64_t value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, + float value); TF_CAPI_EXPORT extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value); TF_CAPI_EXPORT extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, @@ -178,7 +197,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, // -1 and `dims` can be null. If a dimension is unknown, the // corresponding entry in the `dims` array must be -1. TF_CAPI_EXPORT extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, - const int64_t* dims, const int num_dims, + const int64_t* dims, + const int num_dims, TF_Status* out_status); // Sets the attribute attr_name to be a function specified by 'function'. @@ -189,19 +209,33 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value); -TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, - const char** value, int num_values); -TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, - const int64_t* values, int num_values); -TF_CAPI_EXPORT extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, - const float* values, int num_values); -TF_CAPI_EXPORT extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, - const unsigned char* values, int num_values); -TF_CAPI_EXPORT extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, - const TF_DataType* values, int num_values); -TF_CAPI_EXPORT extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, - const int64_t** dims, const int* num_dims, - int num_values, TF_Status* out_status); +TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, + const char* attr_name, + const char** value, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, + const char* attr_name, + const int64_t* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFloatList(TFE_Op* op, + const char* attr_name, + const float* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrBoolList(TFE_Op* op, + const char* attr_name, + const unsigned char* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrTypeList(TFE_Op* op, + const char* attr_name, + const TF_DataType* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrShapeList( + TFE_Op* op, const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values, TF_Status* out_status); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op, + const char* attr_name, + const TFE_Op** value, + int num_values); // Execute the operation defined by 'op' and return handles to computed // tensors in 'retvals'. @@ -216,9 +250,9 @@ TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, // Add a function (serialized FunctionDef protocol buffer) to ctx so // that it can be invoked using TFE_Execute. -TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx, - const char* serialized_function_def, - size_t size, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef( + TFE_Context* ctx, const char* serialized_function_def, size_t size, + TF_Status* status); // Adds a function (created from TF_GraphToFunction or // TF_FunctionImportFunctionDef) to the context, allowing it to be executed with diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index a6f76c732f2a4c2402a27cd69c101d028dbb8fcc..7b9f1db02ed9c53a280c7bd1284165cac4fb6353 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/version.h" struct TFE_ContextOptions { TF_SessionOptions session_options; @@ -43,9 +44,15 @@ struct TFE_ContextOptions { }; struct TFE_Context { - explicit TFE_Context(TF_Session* s) : session(s) {} + explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s) + : policy(opts.policy), + session(s), + rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)), + pflr(new tensorflow::ProcessFunctionLibraryRuntime( + session->device_mgr, opts.session_options.options.env, + TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {} - TFE_ContextDevicePlacementPolicy policy; + const TFE_ContextDevicePlacementPolicy policy; // Note: we cannot use C++11 thread_local here as there is no concept of a // thread-local-object-local variable in C++11. @@ -54,8 +61,8 @@ struct TFE_Context { thread_local_policies GUARDED_BY(policy_map_mu); // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. - TF_Session* session; - tensorflow::Rendezvous* rendezvous; + TF_Session* const session; + tensorflow::Rendezvous* const rendezvous; tensorflow::mutex functions_mu; tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ @@ -64,14 +71,14 @@ struct TFE_Context { // One FunctionLibraryRuntime per device. // func_libs[i] is the FunctionLibraryRuntime corresponding to // session->devices[i]. - std::unique_ptr pflr; + const std::unique_ptr pflr; tensorflow::mutex cache_mu; std::unordered_map kernel_cache GUARDED_BY(cache_mu); - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { + tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const { return pflr->GetFLR(d->name()); } @@ -100,6 +107,8 @@ struct TFE_TensorHandle { }; struct TFE_Op { + // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a + // primitive operation. TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} @@ -112,6 +121,7 @@ struct TFE_Op { std::vector inputs; std::vector input_devices; tensorflow::Device* device; + bool use_xla = false; }; #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 18e7a64435e6c7e51998a744abd615edc7ad4318..4a3ecbc0abb16296a84c0d2184dc3fc9f7f3ebb4 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -60,6 +60,63 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { return op; } +TFE_TensorHandle* TestAxisTensorHandle() { + int64_t dims[] = {1}; + int data[] = {1}; + TF_Tensor* t = TF_AllocateTensor( + TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Min", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, axis, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrBool(op, "keep_dims", 1); + TFE_OpSetAttrType(op, "Tidx", TF_INT32); + TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); + + return op; +} + +// If there is a GPU device, returns true and sets 'gpu_device_name' +// accordingly. +bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + const int num_devices = TF_DeviceListCount(devices); + for (int i = 0; i < num_devices; ++i) { + const string device_type(TF_DeviceListType(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + const string device_name(TF_DeviceListName(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + if (device_type == "GPU") { + *gpu_device_name = device_name; + LOG(INFO) << "Found GPU device " << device_name; + TF_DeleteDeviceList(devices); + return true; + } + } + TF_DeleteDeviceList(devices); + return false; +} + void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); @@ -288,22 +345,15 @@ TEST(CAPI, TensorHandleSilentCopy) { TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - const int num_devices = TF_DeviceListCount(devices); - // Disable the test if no GPU is present. - if (num_devices > 1) { - const int device_to_use = 1; - const string name(TF_DeviceListName(devices, device_to_use, status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - TFE_TensorHandle* hgpu = - TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + string gpu_device_name; + if (GetGPUDeviceName(ctx, &gpu_device_name)) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); - TFE_OpSetDevice(matmul, name.c_str(), status.get()); + TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_TensorHandle* retvals[1]; int num_retvals = 1; @@ -314,7 +364,6 @@ TEST(CAPI, TensorHandleSilentCopy) { TFE_DeleteTensorHandle(hgpu); } - TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); TFE_DeleteContext(ctx, status.get()); @@ -337,22 +386,15 @@ TEST(CAPI, TensorHandleSilentCopyLocal) { TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - const int num_devices = TF_DeviceListCount(devices); - // Disable the test if no GPU is present. - if (num_devices > 1) { - const int device_to_use = 1; - const string name(TF_DeviceListName(devices, device_to_use, status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - TFE_TensorHandle* hgpu = - TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + string gpu_device_name; + if (GetGPUDeviceName(ctx, &gpu_device_name)) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); - TFE_OpSetDevice(matmul, name.c_str(), status.get()); + TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_TensorHandle* retvals[1]; int num_retvals = 1; @@ -363,14 +405,44 @@ TEST(CAPI, TensorHandleSilentCopyLocal) { TFE_DeleteTensorHandle(hgpu); } - TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, Execute) { +TEST(CAPI, SetAndGetOpDevices) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + + // Disable the test if no GPU is present. + string gpu_device_name; + if (GetGPUDeviceName(ctx, &gpu_device_name)) { + TFE_OpSetDevice(matmul, "GPU:0", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + const char* device_name = TFE_OpGetDevice(matmul, status); + ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr); + + TFE_OpSetDevice(matmul, "CPU:0", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + device_name = TFE_OpGetDevice(matmul, status); + ASSERT_TRUE(strstr(device_name, "CPU:0") != nullptr); + } + + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); +} + +TEST(CAPI, Execute_MatMul_CPU) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); @@ -403,6 +475,117 @@ TEST(CAPI, Execute) { TF_DeleteStatus(status); } +TEST(CAPI, Execute_Min_CPU) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* minOp = MinOp(ctx, input, axis); + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_Execute(minOp, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(minOp); + TFE_DeleteTensorHandle(input); + TFE_DeleteTensorHandle(axis); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float output[2] = {0}; + EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); + memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(1, output[0]); + EXPECT_EQ(3, output[1]); + TF_DeleteStatus(status); +} + +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST(CAPI, Execute_MatMul_XLA_CPU) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + + TFE_OpSetXLACompilation(matmul, true); + + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + // Running a primitive TF operator via XLA is not yet supported. + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + EXPECT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TF_DeleteStatus(status); +} + +TEST(CAPI, Execute_Min_XLA_CPU) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* minOp = MinOp(ctx, input, axis); + + TFE_OpSetXLACompilation(minOp, true); + + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_Execute(minOp, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(minOp); + TFE_DeleteTensorHandle(input); + TFE_DeleteTensorHandle(axis); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float output[2] = {0}; + EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); + memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(1, output[0]); + EXPECT_EQ(3, output[1]); + TF_DeleteStatus(status); +} +#endif // TENSORFLOW_EAGER_USE_XLA + TEST(CAPI, ExecuteWithTracing) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -444,7 +627,69 @@ TEST(CAPI, ExecuteWithTracing) { TF_DeleteStatus(status); } -TEST(CAPI, Function) { +TEST(CAPI, Function_ident_CPU) { + // First create a simple identity function. + TF_Graph* function_graph = TF_NewGraph(); + TF_OperationDescription* arg_descr = + TF_NewOperation(function_graph, "Placeholder", "arg"); + TF_SetAttrType(arg_descr, "dtype", TF_INT32); + TF_Status* status = TF_NewStatus(); + TF_Operation* arg = TF_FinishOperation(arg_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_OperationDescription* id_descr = + TF_NewOperation(function_graph, "Identity", "id"); + TF_SetAttrType(id_descr, "T", TF_INT32); + TF_AddInput(id_descr, {arg, 0}); + TF_Operation* id = TF_FinishOperation(id_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_Output input{arg, 0}; + TF_Output output{id, 0}; + TF_Function* fn = + TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, + &output, nullptr, nullptr, "test", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteGraph(function_graph); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_ContextAddFunction(ctx, fn, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteFunction(fn); + + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); + + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); + + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + TFE_DeleteContext(ctx, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteStatus(status); +} + +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST(CAPI, Function_ident_XLA_CPU) { // First create a simple identity function. TF_Graph* function_graph = TF_NewGraph(); TF_OperationDescription* arg_descr = @@ -486,6 +731,9 @@ TEST(CAPI, Function) { TFE_OpAddInput(op, h, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + // Now run it via XLA. + TFE_OpSetXLACompilation(op, true); + std::vector result; result.push_back(nullptr); int num_retvals = 1; @@ -504,6 +752,7 @@ TEST(CAPI, Function) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); } +#endif // TENSORFLOW_EAGER_USE_XLA string MatMulFunction() { tensorflow::FunctionDef def; diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index 3a9951e14de3a70e0b9e47fa62e6342e063c4bed..f77a937f1ffc2d146224cb3191a5ca127daefc22 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -86,10 +86,9 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { return Status::OK(); } -Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name, +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list) { - CHECK(m); - auto* t = gtl::FindOrNull(*m, attr_name); + auto* t = gtl::FindOrNull(m, attr_name); if (t == nullptr) { return errors::InvalidArgument("Attribute '", attr_name, "' does not exist for this operation"); @@ -173,14 +172,14 @@ void CombineUnordered(const tensorflow::Fprint128& a, b->high64 += a.high64; } -inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, +inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, const tensorflow::Fprint128& b) { // TODO(agarwal): avoid ToString(). tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString()); return FingerprintCat128(a, b); } -inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) { +inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) { return CacheKeyHelper(s, {b, b}); } @@ -316,7 +315,7 @@ Status KernelAndDevice::Run(std::vector* input_tensors, allocator_pair.second->GetRecordsAndUnRef(); } auto* ms = stats->mutable_memory_stats(); - ms->set_temp_memory_size(context.temp_memory_size()); + ms->set_temp_memory_size(context.temp_memory_allocated()); for (const auto& alloc_id : context.persistent_alloc_ids()) { ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); } diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index e28a416e67f8382dbd490648106a7eb6e5fcfd13..4d20b5244a46fcde2eed0a429dced2a77b86aedd 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -43,7 +43,7 @@ typedef std::unordered_map AttrTypeMap; Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. -Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name, +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list); // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index 2ccca66f672b96b3c782ddbfc828eeda270cebee..643153058ce3d6f0c88dd23a0dec4c6eff060319 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -63,17 +63,17 @@ TEST(AttrTypeMap, Lookup) { TF_AttrType t; unsigned char is_list = 1; - s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list); + s = AttrTypeByName(*m, "ThisAttribyteCannotPossiblyExist", &t, &is_list); EXPECT_FALSE(s.ok()); EXPECT_NE(is_list, 0); - s = AttrTypeByName(m, "transpose_a", &t, &is_list); + s = AttrTypeByName(*m, "transpose_a", &t, &is_list); ASSERT_TRUE(s.ok()) << s; EXPECT_EQ(TF_ATTR_BOOL, t); EXPECT_EQ(is_list, 0); s = AttrTypeMapForOp("Squeeze", &m); ASSERT_TRUE(s.ok()) << s; - s = AttrTypeByName(m, "squeeze_dims", &t, &is_list); + s = AttrTypeByName(*m, "squeeze_dims", &t, &is_list); ASSERT_TRUE(s.ok()) << s; EXPECT_EQ(TF_ATTR_INT, t); EXPECT_NE(is_list, 0); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 2b65e38f54090af6731685f78d5f7f914a875e3c..bdb0815d6b68444ec1c89b835d563db20ce4d8a1 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -18,12 +18,12 @@ limitations under the License. // Language-agnostic gradient tape. Does not perform backpropagation, just // maintains the data structures required to do so. -#include -#include #include #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -54,11 +54,11 @@ struct OpTapeEntry { // Map from tensor_id to internally-defined operation-id of the operation which // produced this tensor. A value of -1 means that the tensor was directly // watched and not the result of any operation in the tape. -using TensorTape = std::unordered_map; +using TensorTape = gtl::FlatMap; // Map from operation-id to tape entry. template -using OpTape = std::unordered_map>; +using OpTape = gtl::FlatMap>; // Operations the tape needs to perform on tensors to do backpropagation. Named // "vspace" because a subset of these are related to a vector space, such as @@ -159,7 +159,7 @@ class GradientTape { // Map from tensor id to number of remaining usages (i.e. how many entries in // the tape refer to it); to aid in tape garbage collection. - std::unordered_map tensor_usage_; + gtl::FlatMap tensor_usage_; // If false, all activations are deleted in the first call to ComputeGradient. // Else, only when this is destructed. @@ -286,11 +286,11 @@ struct BackpropInitialState { // Map from tensor ID to how many references still exist for this tensor in // the tape. - std::unordered_map tensor_usage_counts; + gtl::FlatMap tensor_usage_counts; // Maps from op ID to how many output tensors of this op still need to have // their gradients computed. - std::unordered_map op_missing_tensor; + gtl::FlatMap op_missing_tensor; }; // If `persistent_tape` is true, op_tape is not changed and none of the @@ -301,8 +301,8 @@ struct BackpropInitialState { template BackpropInitialState PrepareBackprop( gtl::ArraySlice target, const TensorTape& tensor_tape, - OpTape* op_tape, - const std::unordered_set& sources_set, bool persistent_tape) { + OpTape* op_tape, const gtl::FlatSet& sources_set, + bool persistent_tape) { std::vector tensor_stack; tensor_stack.reserve(target.size()); for (auto t : target) { @@ -362,7 +362,7 @@ BackpropInitialState PrepareBackprop( template std::vector InitialStack( const OpTape& op_tape, - const std::unordered_map& op_missing_tensor) { + const gtl::FlatMap& op_missing_tensor) { std::vector result; for (auto& op_entry : op_tape) { if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) { @@ -373,13 +373,13 @@ std::vector InitialStack( } template -Status InitialGradients( - const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, - const OpTape& op_tape, - const std::unordered_map& tensor_usage_counts, - std::unordered_map>* result) { +Status InitialGradients(const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice output_gradients, + const TensorTape& tensor_tape, + const OpTape& op_tape, + const gtl::FlatMap& tensor_usage_counts, + gtl::FlatMap>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { @@ -441,13 +441,13 @@ Status GradientTape::ComputeGradient( gtl::ArraySlice source_tensor_ids, gtl::ArraySlice output_gradients, std::vector* result) { - std::unordered_set sources_set(source_tensor_ids.begin(), - source_tensor_ids.end()); + gtl::FlatSet sources_set(source_tensor_ids.begin(), + source_tensor_ids.end()); BackpropInitialState state = PrepareBackprop( target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); - std::unordered_map> gradients; + gtl::FlatMap> gradients; Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, tensor_tape_, state.op_tape, state.tensor_usage_counts, &gradients); @@ -463,7 +463,7 @@ Status GradientTape::ComputeGradient( cleanup(); return s; } - std::unordered_map gradients_size; + gtl::FlatMap gradients_size; // TODO(apassos) multiple threads could be dequeuing from op_stack at the same // time, for better CPU backprop performance. VLOG(1) << "Initial stack:"; @@ -472,11 +472,10 @@ Status GradientTape::ComputeGradient( VLOG(1) << " " << t; } } - std::unordered_map> - functions_accept_none_for_indices({ - {"SoftmaxCrossEntropyWithLogits", {1}}, - {"FusedBatchNorm", {1, 2, 3, 4}}, - }); + gtl::FlatMap> functions_accept_none_for_indices({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); while (!op_stack.empty()) { const int64 op = op_stack.back(); VLOG(1) << "Popped " << op; diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index c9ade5fb83ff5b80a62bc960d1af1dc55f458c4e..9060c19e9d2cf965c2b9be07be07c42017da45a8 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -433,6 +433,7 @@ tf_gen_op_wrappers_cc( "linalg_ops", "logging_ops", "lookup_ops", + "manip_ops", "math_ops", "nn_ops", "no_op", diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index acef098c7d07f45d171679bff7c41e13ef0424f1..faa1e378d07ea94ad08ee084d18bf6a113f054af 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -96,7 +96,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr* session) { - session->reset(NewSession(session_options)); + Session* session_p = nullptr; + TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); + session->reset(session_p); return (*session)->Create(meta_graph_def.graph_def()); } diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 0ad6b33bba5fcceaca68e2f179cef2232c689a80..4c64d2cfe3c10e6c7ed82a2d72460a0b34283bb2 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -155,6 +155,24 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { << st.error_message(); } +TEST_F(LoaderTest, SessionCreationFailure) { + SavedModelBundle bundle; + // Use invalid SessionOptions to cause session creation to fail. Default + // options work, so provide an invalid value for the target field. + SessionOptions session_options; + constexpr char kInvalidTarget[] = "invalid target"; + session_options.target = kInvalidTarget; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget)) + << st.error_message(); +} + TEST_F(LoaderTest, PbtxtFormat) { SavedModelBundle bundle; SessionOptions session_options; diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index 0a7c37383f96ca65bf5ae05cf0827c01dc4d799b..97f66e79b8ad9f383b22f56e9385fc6d2080e1f8 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -23,7 +23,6 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", ], ) diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 0540260efd83e18258ec6e93c514d14e328791b1..0900e87ebabd378e6237b77ca0ef01677c07c244 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -132,7 +132,9 @@ tf_library( config = "test_graph_tfadd.config.pbtxt", cpp_class = "AddComp", graph = "test_graph_tfadd.pbtxt", - tags = ["manual"], + tags = [ + "manual", + ], ) # A test of tf_library that includes a graph with an unknown op, but where @@ -143,7 +145,9 @@ tf_library( config = "test_graph_tfunknownop.config.pbtxt", cpp_class = "UnknownOpAddComp", graph = "test_graph_tfunknownop.pbtxt", - tags = ["manual"], + tags = [ + "manual", + ], ) # A test of tf_library that includes a graph with an unknown op, but where @@ -155,7 +159,9 @@ tf_library( config = "test_graph_tfunknownop2.config.pbtxt", cpp_class = "UnknownOpAddComp", graph = "test_graph_tfunknownop.pbtxt", - tags = ["manual"], + tags = [ + "manual", + ], ) # A test of tf_library that includes a graph with an unknown op, but where @@ -166,7 +172,9 @@ tf_library( config = "test_graph_tfunknownop3.config.pbtxt", cpp_class = "UnknownOpAddComp", graph = "test_graph_tfunknownop.pbtxt", - tags = ["manual"], + tags = [ + "manual", + ], ) # Utility library for benchmark binaries, used by the *_benchmark rules that are @@ -189,7 +197,6 @@ cc_library( name = "benchmark_extra_android", tags = [ "manual", - "notap", ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 7dfd49cc3b92f83fd64ca62bd2230938ce2d0a65..28aab6eb614ca7123d9e00f7f5cc3661b62e23f7 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -74,7 +74,9 @@ tf_library( # compile but the others in this directory succeed, you may need to # expand the "required by all tf_library targets" list in tfcompile.bzl. include_standard_runtime_deps = False, - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -84,7 +86,9 @@ tf_library( cpp_class = "AddWithCkptComp", freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt", graph = "test_graph_tfadd_with_ckpt.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -95,7 +99,9 @@ tf_library( freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt", freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver", graph = "test_graph_tfadd_with_ckpt_saver.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -104,7 +110,9 @@ tf_library( config = "test_graph_tffunction.config.pbtxt", cpp_class = "FunctionComp", graph = "test_graph_tffunction.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -113,7 +121,9 @@ tf_library( config = "test_graph_tfgather.config.pbtxt", cpp_class = "GatherComp", graph = "test_graph_tfgather.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -122,7 +132,9 @@ tf_library( config = "test_graph_tfmatmul.config.pbtxt", cpp_class = "foo::bar::MatMulComp", graph = "test_graph_tfmatmul.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -131,7 +143,9 @@ tf_library( config = "test_graph_tfmatmulandadd.config.pbtxt", cpp_class = "MatMulAndAddComp", graph = "test_graph_tfmatmulandadd.pb", - tags = ["manual"], + tags = [ + "manual", + ], tfcompile_flags = "--gen_name_to_index --gen_program_shape", ) @@ -141,13 +155,17 @@ tf_library( config = "test_graph_tfsplits.config.pbtxt", cpp_class = "SplitsComp", graph = "test_graph_tfsplits.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], - tags = ["manual"], + tags = [ + "manual", + ], deps = [ ":test_graph_tfadd", ":test_graph_tfadd_with_ckpt", diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 2b9c83ba149adf9e089786b91039e256216579c8..9dff1be09fede6f65f82c2f36d94be07e781949f 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -4,7 +4,7 @@ To use from your BUILD file, add the following line to load the macro: -load("@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") Then call the macro like this: @@ -16,14 +16,15 @@ tf_library( ) """ -load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_android", "tf_copts") +load("//tensorflow:tensorflow.bzl", + "if_android", "tf_cc_test", "tf_copts") def tf_library(name, graph, config, freeze_checkpoint=None, freeze_saver=None, cpp_class=None, gen_test=True, gen_benchmark=True, visibility=None, testonly=None, tfcompile_flags=None, - tfcompile_tool="@org_tensorflow//tensorflow/compiler/aot:tfcompile", + tfcompile_tool="//tensorflow/compiler/aot:tfcompile", include_standard_runtime_deps=True, deps=None, tags=None): """Runs tfcompile to compile a TensorFlow graph into executable code. @@ -102,6 +103,7 @@ def tf_library(name, graph, config, # Now run freeze_graph to convert variables into constants. freeze_args = (" --input_graph=$(location " + graph + ")" + + " --checkpoint_version=1" + " --input_binary=" + str(not graph.endswith(".pbtxt")) + " --input_checkpoint=$(location " + freeze_checkpoint + ")" + " --output_graph=$(location " + freeze_file + ")" + @@ -119,9 +121,9 @@ def tf_library(name, graph, config, out_nodes_file, ] + freeze_saver_srcs, outs=[freeze_file], - cmd=("$(location @org_tensorflow//tensorflow/python/tools:freeze_graph)" + + cmd=("$(location //tensorflow/python/tools:freeze_graph)" + freeze_args), - tools=["@org_tensorflow//tensorflow/python/tools:freeze_graph"], + tools=["//tensorflow/python/tools:freeze_graph"], tags=tags, ) tfcompile_graph = freeze_file @@ -213,22 +215,19 @@ def tf_library(name, graph, config, # These deps are required by all tf_library targets even if # include_standard_runtime_deps is False. Without them, the # generated code will fail to compile. - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", - "@org_tensorflow//tensorflow/core:framework_lite", + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", + "//tensorflow/core:framework_lite", ] + (need_xla_data_proto and [ # If we're generating the program shape, we must depend on the proto. - "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_data_proto", ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually needed. - "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", - "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_conv2d", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_matmul", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", + "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_matmul", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//third_party/eigen3", ] or []) + (deps or []), tags=tags, @@ -254,28 +253,32 @@ def tf_library(name, graph, config, name=("gen_" + test_name), testonly=1, srcs=[ - "@org_tensorflow//tensorflow/compiler/aot:test.cc", + "//tensorflow/compiler/aot:test.cc", header_file, ], outs=[test_file], cmd=("sed " + sed_replace + - " $(location @org_tensorflow//tensorflow/compiler/aot:test.cc) " + + " $(location //tensorflow/compiler/aot:test.cc) " + "> $(OUTS)"), tags=tags, ) - # The cc_test rule for the generated code. - native.cc_test( + # The cc_test rule for the generated code. To ensure that this works + # reliably across build configurations, we must use tf_cc_test instead of + # native.cc_test. This is related to how we build + # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD + # for more details. + tf_cc_test( name=test_name, srcs=[test_file], deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/aot:runtime", - "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main", - "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/aot:tf_library_test_main", + "//tensorflow/compiler/xla:executable_run_options", "//third_party/eigen3", - "@org_tensorflow//tensorflow/core:lib", - "@org_tensorflow//tensorflow/core:test", + "//tensorflow/core:lib", + "//tensorflow/core:test", ], tags=tags, ) @@ -283,7 +286,7 @@ def tf_library(name, graph, config, if gen_benchmark: benchmark_name = name + "_benchmark" benchmark_file = benchmark_name + ".cc" - benchmark_main = ("@org_tensorflow//tensorflow/compiler/aot:" + + benchmark_main = ("//tensorflow/compiler/aot:" + "benchmark_main.template") # Rule to rewrite benchmark.cc to produce the benchmark_file. @@ -301,7 +304,9 @@ def tf_library(name, graph, config, tags=tags, ) - # The cc_benchmark rule for the generated code. + # The cc_benchmark rule for the generated code. This does not need the + # tf_cc_binary since we (by deliberate design) do not depend on + # //tensorflow/core:lib. # # Note: to get smaller size on android for comparison, compile with: # --copt=-fvisibility=hidden @@ -315,12 +320,12 @@ def tf_library(name, graph, config, linkopts = if_android(["-pie", "-s"]), deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/aot:benchmark", - "@org_tensorflow//tensorflow/compiler/aot:runtime", - "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/aot:benchmark", + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/xla:executable_run_options", "//third_party/eigen3", ] + if_android([ - "@org_tensorflow//tensorflow/compiler/aot:benchmark_extra_android", + "//tensorflow/compiler/aot:benchmark_extra_android", ]), tags=tags, ) @@ -330,11 +335,11 @@ def target_llvm_triple(): # TODO(toddw): Add target_triple for other targets. For details see: # http://llvm.org/docs/doxygen/html/Triple_8h_source.html return select({ - "@org_tensorflow//tensorflow:android_armeabi": "armv5-none-android", - "@org_tensorflow//tensorflow:android_arm": "armv7-none-android", - "@org_tensorflow//tensorflow:android_arm64": "aarch64-none-android", - "@org_tensorflow//tensorflow:android_x86": "i686-none-android", - "@org_tensorflow//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", - "@org_tensorflow//tensorflow:darwin": "x86_64-none-darwin", + "//tensorflow:android_armeabi": "armv5-none-android", + "//tensorflow:android_arm": "armv7-none-android", + "//tensorflow:android_arm64": "aarch64-none-android", + "//tensorflow:android_x86": "i686-none-android", + "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", + "//tensorflow:darwin": "x86_64-none-darwin", "//conditions:default": "x86_64-pc-linux", }) diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 0de163d3a8f082eab4d8d802485da1bbc56e8180..9c372a012789fc25ca0a711349c09ca62edc6754 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -30,12 +30,14 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -141,8 +143,7 @@ struct NodeSlot { // everything to use it. static const char* const kArgOp = "_Arg"; static const char* const kRetValOp = "_Retval"; -static const char* const kSendToHostOp = "_XlaSendToHost"; -static const char* const kRecvFromHostOp = "_XlaRecvFromHost"; +static const char* const kHostComputeOp = "_XlaHostCompute"; static const char* const kSendFromHostOp = "_XlaSendFromHost"; static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; @@ -171,7 +172,8 @@ class Encapsulator { // Write a copy of the input graph to 'graph_out', where the subgraphs are // replaced with calls to the new functions. - Status BuildOutputGraph(bool parallel_checking, Graph* graph_out); + Status BuildOutputGraph(bool parallel_checking, Graph* graph_out, + FunctionLibraryDefinition* library); private: // A subgraph of the input, all marked with a common 'group_attribute' @@ -201,21 +203,29 @@ class Encapsulator { // .. . // RAH --> C --> SFH // - // The compiled cluster is as follows. STH is a SendToHost node which is the - // source of a channel to the RAH node above. RFH is a RecvFromHost node which - // is the destination of a channel from the SFH node above. There is a control - // edge that ensures RFH follows STH, which is used in shape inference to - // ensure that the shapes on the STH host channel are known before the RFH - // channel is compiled. + // The compiled cluster is as follows. HC is a HostCompute node which is the + // source of a channel to the RAH node above and the destination of a channel + // from the SFH node above. // - // Arg --> B --> STH ..> RFH --> D --> Retval + // Arg --> B --> HC --> D --> Retval // - // The channels STH/RAH and SFH/RFH each transmit a tuple, so there is at most - // one RAH and SFH in each compiled cluster. This design is preferred over - // adding separate Arg/Retval nodes for each transmitted value because it - // simplifies the host code that would like to limit communication between - // host and device and, e.g., raise only one interrupt per channel rather than - // one per transmitted value. + // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is + // at most one RAH and SFH in each outside_compilation cluster. This design is + // preferred over adding separate Arg/Retval nodes for each transmitted value + // because it allows optimizations to the host code that would like to limit + // communication between host and device and, e.g., raise only one interrupt + // per channel rather than one per transmitted value. + // + // The shapes of the outputs from the HC node in general cannot be determined + // until the shapes of its inputs are known at compile time, since e.g., + // above, the shape of C's outputs aren't known until the shape of its inputs + // are known. If the shapes of the HC's outputs can be determined during the + // rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal + // graph is stored in the shape_inference_graph attr. This graph can be used + // when compiling the HC Op to determined the shape of the SFH inputs given + // the shapes of any ancestor RAH outputs. If it can be determined that the + // shape of the SFH inputs will not be inferrable even once the shapes of the + // RAH outputs are known, an error is returned by the rewriter. class Subgraph { public: // Creates a graph to build the subgraph in, if it doesn't already exist, @@ -246,6 +256,10 @@ class Encapsulator { const std::unordered_map& node_images, Graph* graph_out); + // Returns the names of all the outside_compilation subgraphs in this + // Subgraph. + void GetOutsideCompilationSubgraphNames(std::vector* names) const; + // Returns the Node that inputs to the function should be wired up to. Node* GetCallNodeForInputs() const; @@ -305,15 +319,9 @@ class Encapsulator { void RecordOutsideCompilationOutputOrControl( const string& outside_compilation_id, const Edge* edge); - // Adds the SendToHost nodes for each outside_compilation subgraph once the - // edges have all been recorded via RecordOutsideCompilationInputOrControl. - Status AddSendsToOutsideCompilation( - const std::unordered_map& node_images); - - // Adds the RecvFromHost nodes for each outside_compilation subgraph once - // the edges have all been recorded via - // RecordOutsideCompilationOutputOrControl. - Status AddRecvsFromOutsideCompilation( + // Adds the HostCompute nodes for each outside_compilation subgraph. + Status AddHostComputes( + const string& subgraph_name, const std::unordered_map& node_images); // Creates the sequencer node if it doesn't exist, adding it to graph_out. @@ -323,10 +331,16 @@ class Encapsulator { // all the downstream nodes of call_node_outputs. void ConnectSequencerToOutputs(Graph* graph_out); + Status AddShapeInferenceInfo( + const string& outside_compilation_subgraph_name, + const std::vector& shapes, GraphDef* inference_graph); + + Status ReplaceFunctionDef(FunctionLibraryDefinition* library); + private: struct OutsideCompilationSubgraph { // Map from source (producer node/slot) tensors in the original graph to - // input index (slot number in the SendToHost/RecvAtHost nodes that will + // input index (slot number in the HostCompute/RecvAtHost nodes that will // be created) for the outside_compilation subgraph. std::unordered_map inputs; @@ -335,14 +349,14 @@ class Encapsulator { // outside_compilation subgraph. These are recorded by // RecordOutsideCompilationInputOrControl while walking all the subgraph // edges, and lifted control edges within the subgraph are added by - // AddSendsToOutsideCompilation once the _SendToHost node has been + // AddSendsToOutsideCompilation once the _HostCompute node has been // created. The matching control edge from _RecvAtHost to the // destination is added by CopyEdgeToOutputGraph. std::unordered_set control_inputs; // Maps from source (producer node/slot) and destination (consumer // node/slot) tensors in the original graph to output index (slot number - // in the SendFromHost/RecvFromHost nodes that will be created) for the + // in the SendFromHost/HostCompute nodes that will be created) for the // outside_compilation subgraph. std::unordered_map outputs_by_src; std::unordered_map outputs_by_dst; @@ -352,13 +366,13 @@ class Encapsulator { // containing compiled subgraph. These are recorded by // RecordOutsideCompilationOutputOrControl while walking all the subgraph // edges, and lifted control edges within the subgraph are added by - // AddRecvsFromToOutsideCompilation once the _RecvFromHost node has been + // AddRecvsFromToOutsideCompilation once the _HostCompute node has been // created. The matching control edge from the source to _SendFromHost to // the destination is added by CopyEdgeToOutputGraph. std::unordered_set control_outputs; - // _SendToHost node in the subgraph. Not owned. - Node* send_to_host = nullptr; + // Name of the _HostCompute node in the subgraph. + string host_compute_name; // _RecvAtHost node in the output graph. Not owned. Node* recv_at_host = nullptr; @@ -516,6 +530,59 @@ class Encapsulator { const std::unordered_map& node_images, bool parallel_checking, Graph* graph_out); + // Constructs a minimal shape inference graph that can be used to determine + // the shape of send_node at the time that the subgraph is compiled. + // recv_at_host_nodes contains the names of all the recv_at_host nodes that + // send_node might depend on. These recv_at_host nodes have shapes that are + // not known during the rewrite pass, but will be known at compile time. + // + // If the shapes of all the inputs to send_node can be determined during the + // rewrite pass, on exit graphdef_out is empty and the shapes are returned in + // static_shape_out. Otherwise graphdef_out contains a graph that can be used + // for shape inference at compile time, where all the source nodes of the + // graph are either constants with known shapes, or nodes named in + // recv_at_host_nodes. + // + // A non-OK status is returned if neither of the above conditions can be + // satisfied, e.g., because send_node depends on a node that doesn't have a + // registered shape inference function. + Status DoStaticShapeInferenceForOutsideCompilationSend( + const Graph& graph_in, const ShapeRefiner& shape_refiner, + const std::unordered_set& recv_at_host_nodes, Node* send_node, + FunctionLibraryDefinition* library, + std::vector* static_shape_out, + std::unique_ptr* graphdef_out); + + // Makes a copy of graph containing only nodes that are ancestors of at least + // one node in send_from_host_nodes and store it in pruned_graph. On exit + // nodes_images contains a mapping from nodes in graph to nodes in + // pruned_graph. All functions in the copied graph are inlined. + Status MakePrunedGraphCopyAndInline( + const Graph& graph, const std::vector& sink_nodes, + std::unique_ptr* pruned_graph, + std::unordered_map* node_images, + FunctionLibraryDefinition* library); + + // Makes a copy of graph containing only nodes that are ancestors of a + // send_from_host node in an outside_compilation subgraph, and store it in + // pruned_graph. Also perform shape inference on the pruned graph, using + // shape_refiner. On exit node_images contains a mapping from nodes in graph + // to nodes in pruned_graph. + Status MakeGraphForOutsideCompilationSends( + const Graph& graph, std::unique_ptr* pruned_graph, + ShapeRefiner* shape_refiner, + std::unordered_map* node_images, + FunctionLibraryDefinition* library); + + // Performs static shape inference, as far as possible, for the send_from_host + // nodes in each outside_compilation subgraph. Where it is not possible to + // determine the shape statically, stores a serialized GraphDef in the + // HostCompute 'shape_inference_graph' attr, to be used at compile time for + // final inference. If the shapes are known statically they are stored in the + // HostCompute 'shapes' attr. + Status GetShapeInfoForOutsideCompilationSends( + Graph* graph_out, FunctionLibraryDefinition* library); + const string group_attribute_; const string outside_compilation_attribute_; const Graph* graph_in_; @@ -682,16 +749,20 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( } } -Status Encapsulator::Subgraph::AddSendsToOutsideCompilation( +Status Encapsulator::Subgraph::AddHostComputes( + const string& subgraph_name, const std::unordered_map& node_images) { for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { const string& oc_subgraph_name = oc_subgraph_iter.first; OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; - if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { - // Build a _SendToHost node sending all the args of the appropriate - // types. - std::vector dtypes(oc_subgraph.inputs.size(), DT_INVALID); + if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() || + !oc_subgraph.outputs_by_src.empty() || + !oc_subgraph.control_outputs.empty()) { + // Build a _HostCompute node. std::vector inputs(oc_subgraph.inputs.size()); + std::vector input_dtypes(oc_subgraph.inputs.size(), DT_INVALID); + std::vector output_dtypes(oc_subgraph.outputs_by_src.size(), + DT_INVALID); for (const auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; @@ -700,94 +771,64 @@ Status Encapsulator::Subgraph::AddSendsToOutsideCompilation( int input_index = input_src.second; DataType dtype = src_node->output_type(src_slot); - dtypes[input_index] = dtype; inputs[input_index].Reset(src_image->name(), src_slot, dtype); + input_dtypes[input_index] = dtype; } - NodeDef send_def; - NodeDefBuilder builder( - strings::StrCat("outside_compilation_", oc_subgraph_name, "_send"), - kSendToHostOp); - builder.Attr("dtypes", dtypes); + for (const auto& output : oc_subgraph.outputs_by_src) { + DataType dtype = output.first.dtype; + int output_index = output.second; + output_dtypes[output_index] = dtype; + } + + NodeDef host_compute_def; + NodeDefBuilder builder(strings::StrCat("outside_compilation_", + oc_subgraph_name, "_host_compute"), + kHostComputeOp); builder.Input(inputs); - Status s = builder.Finalize(&send_def); + builder.Attr("Tinputs", input_dtypes); + builder.Attr("Toutputs", output_dtypes); + builder.Attr("key", + strings::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); + Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; - oc_subgraph.send_to_host = graph_->AddNode(send_def, &s); + Node* host_compute = graph_->AddNode(host_compute_def, &s); if (!s.ok()) return s; + oc_subgraph.host_compute_name = host_compute->name(); - // Connect the _SendToHost node to its producers in the subgraph. + // Connect the _HostCompute node to its producers in the subgraph. for (auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); int src_slot = input_src.first.slot; int input_index = input_src.second; - graph_->AddEdge(src_image, src_slot, oc_subgraph.send_to_host, - input_index); + graph_->AddEdge(src_image, src_slot, host_compute, input_index); } - // Connect the _SendToHost node to its control edge producers in the + // Connect the _HostCompute node to its control edge producers in the // subgraph. for (const auto& src_node : oc_subgraph.control_inputs) { Node* src_image = node_images.at(src_node); - graph_->AddControlEdge(src_image, oc_subgraph.send_to_host); - } - } - } - - return Status::OK(); -} - -Status Encapsulator::Subgraph::AddRecvsFromOutsideCompilation( - const std::unordered_map& node_images) { - for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { - const string& oc_subgraph_name = oc_subgraph_iter.first; - OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; - if (!oc_subgraph.outputs_by_src.empty() || - !oc_subgraph.control_outputs.empty()) { - // Build a _RecvFromHost node producing all the outputs of the appropriate - // types. - std::vector dtypes(oc_subgraph.outputs_by_src.size(), - DT_INVALID); - - for (const auto& output : oc_subgraph.outputs_by_src) { - DataType dtype = output.first.dtype; - int output_index = output.second; - dtypes[output_index] = dtype; + graph_->AddControlEdge(src_image, host_compute); } - NodeDef recv_def; - NodeDefBuilder builder( - strings::StrCat("outside_compilation_", oc_subgraph_name, "_recv"), - kRecvFromHostOp); - builder.Attr("dtypes", dtypes); - Status s = builder.Finalize(&recv_def); - if (!s.ok()) return s; - - Node* recv = graph_->AddNode(recv_def, &s); - if (!s.ok()) return s; - - // Connect the consumers in the subgraph to the _RecvFromHost node. + // Connect the consumers in the subgraph to the _HostCompute node. for (const auto& output : oc_subgraph.outputs_by_dst) { const Node* dst_node = output.first.node; Node* dst_image = node_images.at(dst_node); int dst_slot = output.first.slot; int output_index = output.second; - graph_->AddEdge(recv, output_index, dst_image, dst_slot); + graph_->AddEdge(host_compute, output_index, dst_image, dst_slot); } - // Connect the control edge consumers in the subgraph to the _RecvFromHost + // Connect the control edge consumers in the subgraph to the _HostCompute // node. for (const auto& dst_node : oc_subgraph.control_outputs) { Node* dst_image = node_images.at(dst_node); - graph_->AddControlEdge(recv, dst_image); - } - - // Add a control edge in the subgraph so that the _SendToHost node, if - // any, is compiled before the _RecvFromHost node. - if (oc_subgraph.send_to_host != nullptr) { - graph_->AddControlEdge(oc_subgraph.send_to_host, recv); + graph_->AddControlEdge(host_compute, dst_image); } } } @@ -882,6 +923,63 @@ Status Encapsulator::Subgraph::BuildFunctionDef( return Status::OK(); } +Status Encapsulator::Subgraph::AddShapeInferenceInfo( + const string& outside_compilation_subgraph_name, + const std::vector& shapes, GraphDef* inference_graph) { + OutsideCompilationSubgraph& oc_subgraph = + outside_compilation_subgraphs_.at(outside_compilation_subgraph_name); + + Node* host_compute = nullptr; + for (Node* n : graph_->nodes()) { + if (n->name() == oc_subgraph.host_compute_name) { + host_compute = n; + break; + } + } + if (host_compute == nullptr) { + return errors::InvalidArgument( + "After rewriting subgraph ", outside_compilation_subgraph_name, + " there is no HostCompute Op for outside compilation subgraph ", + oc_subgraph.host_compute_name); + } + + if (inference_graph == nullptr) { + host_compute->AddAttr("shape_inference_graph", ""); + host_compute->AddAttr("shapes", shapes); + } else { + string serialized_graph; + if (!inference_graph->SerializeToString(&serialized_graph)) { + return errors::Internal( + "Failed to serialize graph for outside compilation subgraph ", + oc_subgraph.host_compute_name); + } + host_compute->AddAttr("shape_inference_graph", serialized_graph); + host_compute->AddAttr("shapes", std::vector()); + } + return Status::OK(); +} + +Status Encapsulator::Subgraph::ReplaceFunctionDef( + FunctionLibraryDefinition* library) { + const string& name = call_node_def_.name(); + + FunctionDef fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); + + if (VLOG_IS_ON(1)) { + VLOG(2) << "Replace function def " << name; + dump_graph::DumpGraphToFile( + strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, + library); + dump_graph::DumpFunctionDefToFile( + strings::StrCat("replace_encapsulate_fdef_", name), fdef); + } + + TF_RETURN_IF_ERROR(library->RemoveFunction(name)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + return Status::OK(); +} + Status Encapsulator::Subgraph::BuildParallelCheckOp( const std::unordered_map& node_images, Graph* graph_out) { @@ -980,7 +1078,9 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); - builder.Attr("dtypes", dtypes); + builder.Attr("Toutputs", dtypes); + builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); Status s = builder.Finalize(&recv_def); if (!s.ok()) return s; @@ -1020,7 +1120,9 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); - builder.Attr("dtypes", dtypes); + builder.Attr("Tinputs", dtypes); + builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); builder.Input(inputs); Status s = builder.Finalize(&send_def); if (!s.ok()) return s; @@ -1062,6 +1164,13 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( return Status::OK(); } +void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( + std::vector* names) const { + for (auto& entry : outside_compilation_subgraphs_) { + names->push_back(entry.first); + } +} + Status Encapsulator::GetFunctionNameAttr( Node const* node, string* attr, string* outside_compilation_attr) const { Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); @@ -1220,8 +1329,7 @@ Status Encapsulator::SplitIntoSubgraphs() { // single input and output node for it. for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; - TF_RETURN_IF_ERROR(subgraph.AddSendsToOutsideCompilation(node_images)); - TF_RETURN_IF_ERROR(subgraph.AddRecvsFromOutsideCompilation(node_images)); + TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images)); } MarkGuaranteedConstants(*graph_in_, src_arg_pairs); @@ -1509,8 +1617,346 @@ Status Encapsulator::AddEdgesToOutputGraph( return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, - Graph* graph_out) { +namespace { + +// Adds a dummy Const node to graph_out. The "constant" has the type of +// data_type and the shape indicated in 'shape'. The dummy node is not a valid +// Const node because it does not have any value defined, but this doesn't +// matter because it will only be used subsequently for shape inference. (It +// would be possible to add a switch statement over data_type to create a value +// for the constant, but that would entail maintaining the logic as new types +// are added, and is not necessary.) +Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, + Graph* graph_out) { + TensorProto dummy_proto; + dummy_proto.set_dtype(data_type); + *dummy_proto.mutable_tensor_shape() = shape; + // Don't set any value field in the proto, since it is only going to be used + // for shape inference. + + GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); + NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const", + options.op_registry()); + node_builder.Attr("dtype", data_type).Attr("value", dummy_proto); + return options.FinalizeBuilder(&node_builder); +} + +// Adds a copy of node_in to graph_out and adds the mapping to +// copied_node_images. +Status CopyShapeInferenceNodeToGraph( + Node* node_in, const Node* send_node, + const std::unordered_map& dummy_node_images, + FunctionLibraryDefinition* library, + std::unordered_map* copied_node_images, Graph* graph_out) { + // Once all the ancestor nodes have been added to graph_out, add this node + // and connect it to its ancestors. + Node* node_out = graph_out->CopyNode(node_in); + (*copied_node_images)[node_in] = node_out; + // Don't bother to build the shape inference graph if there's a node with no + // shape inference function, since it would just result in an error later at + // compile time. + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data)); + if (op_reg_data->shape_inference_fn == nullptr) { + return errors::InvalidArgument( + "Shape inference is not possible for outside_compilation " + "SendFromHost node ", + send_node->name(), " because it depends on node ", node_in->name(), + " which does not have a shape inference function registered."); + } + // Add all the edges to the newly copied node. + for (const Edge* in_edge : node_in->in_edges()) { + if (!in_edge->IsControlEdge()) { + Node* src = in_edge->src(); + const auto iter = dummy_node_images.find(src); + if (iter == dummy_node_images.end()) { + // The src is a copied node so use the original output port. + graph_out->AddEdge((*copied_node_images)[in_edge->src()], + in_edge->src_output(), node_out, + in_edge->dst_input()); + } else { + // The src is a dummy node so use output port 0. + graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input()); + } + } + } + return Status::OK(); +} + +} // namespace + +Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( + const Graph& graph_in, const ShapeRefiner& shape_refiner, + const std::unordered_set& recv_at_host_nodes, Node* send_node, + FunctionLibraryDefinition* library, + std::vector* static_shape_out, + std::unique_ptr* graphdef_out) { + // Maps from nodes in graph_in to nodes in graph_out. + // + // When an edge has fully defined shape the source node in graph_in is + // replaced in graph_out by a dummy constant node. The mapping from nodes + // in graph_in to dummy nodes is stored in dummy_node_images. + // + // When a node in graph_in has at least one ancestor that doesn't have fully + // defined shape, it is copied into graph_out. The mapping from nodes in + // graph_in to copied nodes is stored in copied_node_images. + // + // The two types of node are treated differently because, when adding edges to + // graph_out, an output from a dummy node always uses port 0, whereas an + // output from a copied node uses the same port that was used in graph_in. + std::unordered_map dummy_node_images; + std::unordered_map copied_node_images; + + std::unique_ptr graph_out(new Graph(graph_in.op_registry())); + graph_out->set_versions(graph_in.versions()); + static_shape_out->resize(send_node->num_inputs()); + + // We don't use the standard ReverseDFS because we want to cut off traversal + // whenever we find an output with fully defined shape. + // TODO(misard) make this work properly in the presence of control flow. + struct Work { + Node* node; + bool leave; // Are we entering or leaving node? + }; + std::vector stack({{send_node, false}}); + std::vector visited(graph_in.num_node_ids(), false); + while (!stack.empty()) { + Work w = stack.back(); + stack.pop_back(); + Node* n = w.node; + + if (w.leave) { + TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph( + n, send_node, dummy_node_images, library, &copied_node_images, + graph_out.get())); + } else { + if (visited[n->id()]) continue; + visited[n->id()] = true; + + // Arrange to revisit when all done with all inputs. + stack.push_back(Work{n, true}); + + bool has_parent_with_unknown_shape = false; + for (const Edge* in_edge : n->in_edges()) { + if (!in_edge->IsControlEdge()) { + Node* src_node = in_edge->src(); + int src_port = in_edge->src_output(); + shape_inference::InferenceContext* context = + shape_refiner.GetContext(src_node); + shape_inference::ShapeHandle shape = context->output(src_port); + if (context->FullyDefined(shape)) { + // This ancestor has known shape, so instead of adding it to the + // stack, add a dummy node with that shape to graph_out and + // continue. + TensorShapeProto proto; + context->ShapeHandleToProto(shape, &proto); + dummy_node_images[src_node] = AddDummyShapedNode( + src_node->output_type(src_port), proto, graph_out.get()); + if (n == send_node) { + (*static_shape_out)[in_edge->dst_input()] = proto; + } + } else { + if (!visited[src_node->id()]) { + has_parent_with_unknown_shape = true; + stack.push_back({src_node, false}); + } + } + } + } + if (!has_parent_with_unknown_shape) { + if (n == send_node) { + // The shapes of all the inputs to send_node are statically known. We + // won't have to do any inference at compile time so return now: the + // shapes were stored in static_shape_out above. + graphdef_out->reset(); + return Status::OK(); + } else { + // Any shape that is being processed is either the original send node + // or has at least one output with statically-unknown shape. If the + // latter and it doesn't have any inputs with statically-unknown + // shape, then check that it is of the recv nodes that we can fill in + // the shape of at run-time later. If it isn't one of those, then we + // won't have any additional knowledge at compile time, so we already + // know we won't be able to do shape inference and we can return an + // error now. + if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) { + return errors::InvalidArgument( + "Shape inference is not possible for outside_compilation " + "SendFromHost node ", + send_node->name(), " because shape of node ", n->name(), + " will not be known at compilation time."); + } + } + } + } + } + + graphdef_out->reset(new GraphDef()); + graph_out->ToGraphDef(graphdef_out->get()); + + return Status::OK(); +} + +Status Encapsulator::MakePrunedGraphCopyAndInline( + const Graph& graph, const std::vector& sink_nodes, + std::unique_ptr* pruned_graph, + std::unordered_map* node_images, + FunctionLibraryDefinition* library) { + // First copy all ancestor nodes of sink_nodes into a new graph. + pruned_graph->reset(new Graph(library)); + (*pruned_graph)->set_versions(graph.versions()); + ReverseDFSFrom(graph, sink_nodes, + /*enter=*/nullptr, + /*leave=*/[&](Node* n) { + if (!n->IsSource()) { + Node* copied = (*pruned_graph)->CopyNode(n); + node_images->emplace(n, copied); + } + }); + + // Add all the edges between copied nodes. + for (auto entry : *node_images) { + const Node* orig = entry.first; + Node* image = entry.second; + for (const Edge* out_edge : orig->out_edges()) { + auto iter = node_images->find(out_edge->dst()); + if (iter != node_images->end()) { + // The source and destination are both in the copied graph. + (*pruned_graph) + ->AddEdge(image, out_edge->src_output(), iter->second, + out_edge->dst_input()); + } + } + } + + // Find all the function call nodes, and inline them. + std::vector function_nodes; + for (auto node : (*pruned_graph)->nodes()) { + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data)); + if (op_reg_data->is_function_op) { + function_nodes.push_back(node); + } + } + for (auto node : function_nodes) { + VLOG(2) << "Inlining function " << node->name(); + const FunctionDef* fdef = library->Find(node->type_string()); + if (fdef == nullptr) { + return errors::Internal("Failed to find function ", node->type_string(), + " in function library."); + } + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR( + FunctionDefToBodyHelper(*fdef, node->attrs(), library, + [library](const string& op, const OpDef** sig) { + return library->LookUpOpDef(op, sig); + }, + &fbody)); + InlineFunctionBody(*library, pruned_graph->get(), node, fbody); + delete fbody; + } + + return Status::OK(); +} + +Status Encapsulator::MakeGraphForOutsideCompilationSends( + const Graph& graph, std::unique_ptr* pruned_graph, + ShapeRefiner* shape_refiner, + std::unordered_map* node_images, + FunctionLibraryDefinition* library) { + // Find all the send_from_host nodes in all subgraphs, to use as roots for the + // pruning. + std::vector send_from_host_nodes; + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + std::vector outside_compilation_names; + subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); + for (const auto& name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(name); + if (send_node != nullptr) { + send_from_host_nodes.push_back(send_node); + } + } + } + + // Make a copy of all the graph nodes needed to evaluate the send_from_host + // nodes, inlining any functions as needed. + TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline( + graph, send_from_host_nodes, pruned_graph, node_images, library)); + + // Perform shape inference on the pruned graph. + shape_refiner->set_require_shape_inference_fns(false); + FixupSourceAndSinkEdges(pruned_graph->get()); + std::vector post_order; + GetReversePostOrder(*(*pruned_graph), &post_order); + for (auto node : post_order) { + // Ignore the status returned by the shape_refiner. At this point we want + // the best effort shapes, even if no shape function is registered for a + // node. + Status status = shape_refiner->AddNode(node); + if (!status.ok()) { + VLOG(1) << "Shape inference failed for node: " << status; + } + } + + return Status::OK(); +} + +Status Encapsulator::GetShapeInfoForOutsideCompilationSends( + Graph* graph_out, FunctionLibraryDefinition* library) { + std::unique_ptr pruned_graph; + ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry()); + std::unordered_map node_images; + TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( + *graph_out, &pruned_graph, &shape_refiner, &node_images, library)); + + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + // Find all the recv_at_host nodes in this subgraph. + std::vector outside_compilation_names; + subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); + std::unordered_set recv_at_host_names; + for (const auto& name : outside_compilation_names) { + Node* recv_node = subgraph.GetRecvAtHostNode(name); + if (recv_node != nullptr) { + recv_at_host_names.insert(recv_node->name()); + } + } + // For each send_from_host node, do as much shape inference as possible + // without knowing the shape of the recv_at_host nodes, and store the + // result, along with enough information to complete the job at compile time + // once the recv_at_host shapes are known. + for (const auto& name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(name); + std::vector static_shape; + std::unique_ptr graphdef; + if (send_node != nullptr) { + TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( + *pruned_graph, shape_refiner, recv_at_host_names, + node_images[send_node], library, &static_shape, &graphdef)); + if (graphdef == nullptr) { + VLOG(2) << "Send node " << send_node->name() << " shapes"; + for (int i = 0; i < static_shape.size(); ++i) { + VLOG(2) << static_shape[i].DebugString(); + } + } else { + VLOG(2) << "Send node " << send_node->name() << " graph\n" + << graphdef->DebugString(); + } + } + TF_RETURN_IF_ERROR( + subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get())); + } + if (!outside_compilation_names.empty()) { + TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library)); + } + } + + return Status::OK(); +} + +Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, + FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. std::unordered_map node_images; @@ -1522,6 +1968,9 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, TF_RETURN_IF_ERROR( AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR( + GetShapeInfoForOutsideCompilationSends(graph_out, library)); + return Status::OK(); } @@ -1545,7 +1994,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); TF_RETURN_IF_ERROR( - encapsulator.BuildOutputGraph(parallel_checking, out.get())); + encapsulator.BuildOutputGraph(parallel_checking, out.get(), library)); *graph_out = std::move(out); return Status::OK(); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index b100861d5e9c04a8f9d32d486e0ee7252b79c62b..aed9cae0f1799c4524da8ee309344849798755d5 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -29,17 +29,181 @@ limitations under the License. namespace tensorflow { namespace { +template +bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, + const ::tensorflow::protobuf::Map& b, + const std::function& key_to_string, + const std::function& value_to_string, + const std::function& compare, + const string& map_name, string* diff) { + for (const auto& elt_a : a) { + const auto iter = b.find(elt_a.first); + if (iter == b.end()) { + if (diff) { + *diff = strings::StrCat( + map_name, " expected: contains element with key '", + key_to_string(elt_a.first), "' got: map has no such element"); + } + return false; + } + if (!compare(elt_a.first, elt_a.second, iter->second)) { + if (diff) { + *diff = strings::StrCat(map_name, " expected: element with key '", + key_to_string(elt_a.first), " has value '", + value_to_string(elt_a.second), "' got: '", + value_to_string(iter->second), "'"); + } + return false; + } + } + for (const auto& elt_b : b) { + const auto iter = a.find(elt_b.first); + if (iter == a.end()) { + if (diff) { + *diff = strings::StrCat(map_name, " got: contains element with key '", + key_to_string(elt_b.first), + "' expected: map has no such element"); + } + return false; + } + } + return true; +} + +bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, + const string& diff_preamble, string* diff) { + if (a.op() != b.op()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected op '", a.op(), "' got '", b.op()); + } + return false; + } + if (a.device() != b.device()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected device '", a.device(), "' got '", + b.device()); + } + return false; + } + if (a.input_size() != b.input_size()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected ", a.input_size(), " inputs got ", + b.input_size(), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); + } + return false; + } + for (int i = 0; i < a.input_size(); ++i) { + if (a.input(i) != b.input(i)) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected ", a.input(i), + " got ", b.input(i), " expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); + } + return false; + } + } + return EqualProtoMap( + a.attr(), b.attr(), [](const string& s) { return s; }, + [](const AttrValue& v) { return v.DebugString(); }, + [](const string& key, const AttrValue& av, const AttrValue& bv) { + if (key == "shape_inference_graph") { + // Default serialization of GraphDef is unstable because maps don't + // serialize deterministically. Rather than go through the hoops to + // turn on deterministic serialization of this attr just for this + // test, add logic here to compare determinstically. + GraphDef ga; + if (!ga.ParseFromString(av.s())) { + return false; + } + GraphDef gb; + if (!gb.ParseFromString(bv.s())) { + return false; + } + return EqualGraphDef(ga, gb, nullptr); + } else { + return av.DebugString() == bv.DebugString(); + } + }, + strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), + diff); +} + bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, string* diff) { - // TODO(phawkins) use a more sophisticated equality test. - if (a.DebugString() != b.DebugString()) { + if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { - *diff = strings::StrCat("Definition mismatch for function ", + *diff = strings::StrCat("Signature mismatch for function ", a.signature().name(), ", expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + a.signature().DebugString(), "\ngot:\n", + b.signature().DebugString()); } return false; } + if (!EqualProtoMap( + a.attr(), b.attr(), [](const string& s) { return s; }, + [](const AttrValue& v) { return v.DebugString(); }, + [](const string& key, const AttrValue& av, const AttrValue& bv) { + return av.DebugString() == bv.DebugString(); + }, + strings::StrCat("attr mismatch for function ", a.signature().name()), + diff)) { + return false; + } + if (!EqualProtoMap( + a.ret(), b.ret(), [](const string& s) { return s; }, + [](const string& s) { return s; }, + [](const string& key, const string& av, const string& bv) { + return av == bv; + }, + strings::StrCat("ret mismatch for function ", a.signature().name()), + diff)) { + return false; + } + for (int i = 0; i < a.node_def_size(); ++i) { + bool found = false; + for (int j = 0; j < b.node_def_size(); ++j) { + if (a.node_def(i).name() == b.node_def(j).name()) { + if (!EqualFunctionNodeDef( + a.node_def(i), b.node_def(j), + strings::StrCat("Function ", a.signature().name()), diff)) { + return false; + } + found = true; + break; + } + } + if (!found) { + if (diff) { + *diff = strings::StrCat("Function ", a.signature().name(), + ", expected: has node '", a.node_def(i).name(), + "' got: no node of that name"); + } + return false; + } + } + for (int i = 0; i < b.node_def_size(); ++i) { + bool found = false; + for (int j = 0; j < a.node_def_size(); ++j) { + if (b.node_def(i).name() == a.node_def(j).name()) { + found = true; + break; + } + } + if (!found) { + if (diff) { + *diff = strings::StrCat("Function ", a.signature().name(), + ", got: has node '", b.node_def(i).name(), + "' expected: no node of that name"); + } + return false; + } + } return true; } @@ -84,29 +248,64 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, // TODO(misard): remove these fake registrations once there are real Ops to be // compiled. -REGISTER_OP("_XlaSendToHost") - .Input("input: dtypes") - .Attr("dtypes: list(type) >= 0"); - -REGISTER_OP("_XlaRecvFromHost") - .Output("output: dtypes") - .Attr("dtypes: list(type) >= 0"); +REGISTER_OP("_XlaHostCompute") + .Input("inputs: Tinputs") + .Output("outputs: Toutputs") + .Attr("Tinputs: list(type) >= 0") + .Attr("Toutputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaSendFromHost") - .Input("input: dtypes") - .Attr("dtypes: list(type) >= 0"); + .Input("input: Tinputs") + .Attr("Tinputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaRecvAtHost") - .Output("output: dtypes") - .Attr("dtypes: list(type) >= 0"); - -REGISTER_OP("InputTest").Output("o: float"); - -REGISTER_OP("UnaryTest").Input("a: float").Output("o: float"); + .Output("output: Toutputs") + .Attr("Toutputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); + +REGISTER_OP("InputTest") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }); + +REGISTER_OP("InputTestShaped") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); + }); + +REGISTER_OP("UnaryTest") + .Input("a: float") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ::tensorflow::shape_inference::ShapeHandle o; + TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); + c->set_output(0, o); + return Status::OK(); + }); REGISTER_OP("BinaryTest") .Input("a: float") .Input("b: float") - .Output("o: float"); + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ::tensorflow::shape_inference::ShapeHandle o; + TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); + c->set_output(0, o); + return Status::OK(); + }); +REGISTER_OP("BinaryTest2") + .Input("a: float") + .Input("b: float") + .Output("o: float") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("AddNLikeTest") .Input("inputs: N * T") @@ -124,22 +323,48 @@ Node* Input(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTest", opts); } -Node* RecvAtHost(const gtl::ArraySlice& dtypes, +Node* InputShaped(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("InputTestShaped", opts); +} + +Node* KnownShape(const gtl::ArraySlice& shape, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", + opts.op_registry()); + TensorProto value; + value.set_dtype(DT_FLOAT); + for (int dim : shape) { + value.mutable_tensor_shape()->add_dim()->set_size(dim); + } + return opts.WithAttr("value", value) + .WithAttr("dtype", DT_FLOAT) + .FinalizeBuilder(&node_builder); +} + +Node* RecvAtHost(const string& key, const gtl::ArraySlice& dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); - return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder); + return opts.WithAttr("Toutputs", dtypes) + .WithAttr("key", key) + .FinalizeBuilder(&node_builder); } -Node* SendFromHost(const std::vector& inputs, - const gtl::ArraySlice& dtypes, +Node* SendFromHost(const string& key, const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); - return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder); + std::vector dtypes; + for (const auto& node : inputs) { + dtypes.push_back(node.dt); + } + return opts.WithAttr("key", key) + .WithAttr("Tinputs", dtypes) + .FinalizeBuilder(&node_builder); } Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { @@ -151,6 +376,11 @@ Node* Binary(ops::NodeOut a, ops::NodeOut b, return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts); } +Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b, + const GraphDefBuilder::Options& opts) { + return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts); +} + Node* AddNLike(const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; @@ -576,6 +806,21 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, @@ -584,19 +829,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, {{"F"}, "BinaryTest", - {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, {}, - {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"C:o:0", "c:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, {"c"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -612,11 +856,11 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* call = b2.opts().FinalizeBuilder(&node_builder); Node* recv = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost({e}, {DT_FLOAT}, + Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); @@ -674,37 +918,71 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + string shape_string_expected_1; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape1.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape1.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape1.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape1_graph; + TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph)); + EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1)); + } + + string shape_string_expected_2; + { + GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + shape2.opts().WithName("E")); + Node* recv2 = + RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithName("outside_compilation_F1_O2_recv")); + Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H")); + SendFromHost("host_compute_channel_F1_O2", {h}, + shape2.opts().WithName("outside_compilation_F1_O2_send")); + GraphDef shape2_graph; + TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph)); + EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2)); + } + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, - {{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}}, + {{"I"}, + "UnaryTest", + {"outside_compilation_O2_host_compute:outputs:0"}}, {{"F"}, "BinaryTest", - {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, {}, - {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O2_send"}, - "_XlaSendToHost", + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O2_host_compute"}, + "_XlaHostCompute", {"D:o:0", "F:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", shape_string_expected_2}, + {"shapes", gtl::ArraySlice({})}}, {"F"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected_1}, + {"shapes", gtl::ArraySlice({})}}, {"D"}}, - {{"outside_compilation_O2_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O2_send"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"i_0_retval", "I:o:0"}}); @@ -720,23 +998,24 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* call = b2.opts().FinalizeBuilder(&node_builder); Node* recv1 = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost({e}, {DT_FLOAT}, + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); Node* recv2 = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_recv")); Node* g = Binary(e, ops::NodeOut(recv2, 1), b2.opts().WithName("G").WithControlInputs({recv2, e})); Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H")); - Node* send2 = SendFromHost( - {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send")); + Node* send2 = + SendFromHost("host_compute_channel_F1_O2", {h}, + b2.opts().WithName("outside_compilation_F1_O2_send")); Node* s = NoOp(b2.opts() .WithName("F1_sequencer") @@ -758,8 +1037,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { { GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); - Node* a = Input(b1.opts().WithName("A")); - Node* b = Input(b1.opts().WithName("B")); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = InputShaped(b1.opts().WithName("B")); Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); Node* d = Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); @@ -791,6 +1070,24 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float", "d_0_retval:float"}, {}, @@ -799,19 +1096,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "BinaryTest", - {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, {}, - {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, {"D"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); @@ -822,16 +1118,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, {{"I"}, "BinaryTest", - {"f_0_arg", "outside_compilation_O1_recv:output:0"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"G:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F2_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -839,15 +1135,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { std::unique_ptr lib_def( new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); - Node* a = Input(b2.opts().WithName("A")); - Node* b = Input(b2.opts().WithName("B")); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = InputShaped(b2.opts().WithName("B")); Node* recv1 = - RecvAtHost({DT_FLOAT, DT_FLOAT}, + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost({e}, {DT_FLOAT}, + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); @@ -857,12 +1153,14 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* s1 = NoOp( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); - Node* recv2 = RecvAtHost( - {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv")); + Node* recv2 = + RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F2_O1_recv")); Node* h = Binary(ops::NodeOut(call1, 1), recv2, b2.opts().WithName("H").WithControlInput(s1)); - Node* send2 = SendFromHost( - {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send")); + Node* send2 = + SendFromHost("host_compute_channel_F2_O1", {h}, + b2.opts().WithName("outside_compilation_F2_O1_send")); NodeBuilder node_builder2("F2", "F2", lib_def.get()); node_builder2.Input(e).Input(call1); @@ -888,7 +1186,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { { GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); - Node* a = Input(b1.opts().WithName("A")); + Node* a = InputShaped(b1.opts().WithName("A")); Node* b = Input(b1.opts().WithName("B")); Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); Node* d = @@ -908,6 +1206,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, { @@ -915,11 +1216,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "BinaryTest", - {"D:o:0", "outside_compilation_O1_recv:output:0"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", + {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + {{"Tinputs", gtl::ArraySlice({})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -927,12 +1233,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { std::unique_ptr lib_def( new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); - Node* a = Input(b2.opts().WithName("A")); + Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); Node* e = Unary(a, b2.opts().WithName("E")); - Node* send1 = SendFromHost( - {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* send1 = + SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts().WithName("outside_compilation_F1_O1_send")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); @@ -954,7 +1261,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { { GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); - Node* a = Input(b1.opts().WithName("A")); + Node* a = InputShaped(b1.opts().WithName("A")); Node* b = Input(b1.opts().WithName("B")); Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); Node* d = @@ -975,6 +1282,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, { @@ -982,17 +1292,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "BinaryTest", - {"D:o:0", "outside_compilation_O1_recv:output:0"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {}, - {{"dtypes", gtl::ArraySlice({})}}, + {{"Tinputs", gtl::ArraySlice({})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}, {"D"}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", - {}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, - {"outside_compilation_O1_send"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1000,14 +1310,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { std::unique_ptr lib_def( new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); - Node* a = Input(b2.opts().WithName("A")); + Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); Node* recv1 = - RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + RecvAtHost("host_compute_channel_F1_O1", {}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1)); - Node* send1 = SendFromHost( - {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* send1 = + SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts().WithName("outside_compilation_F1_O1_send")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); @@ -1055,10 +1367,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", {"D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1069,8 +1385,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = RecvAtHost( - {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(recv1, b2.opts().WithName("E")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); @@ -1118,16 +1435,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, - {{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}}, - {{"outside_compilation_O1_send"}, - "_XlaSendToHost", + {{"F"}, + "UnaryTest", {"D:o:0"}, - {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, - {{"outside_compilation_O1_recv"}, - "_XlaRecvFromHost", {}, - {{"dtypes", gtl::ArraySlice({})}}, - {"outside_compilation_O1_send"}}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1138,10 +1458,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = RecvAtHost( - {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); Node* e = Unary(recv1, b2.opts().WithName("E")); - Node* send1 = SendFromHost({}, {}, + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {}, b2.opts() .WithName("outside_compilation_F1_O1_send") .WithControlInput(e)); @@ -1215,5 +1536,110 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } +// Test for shape inference of outside compilation. +TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + *library.add_function() = test::function::XTimesTwo(); + + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + // Give nodes 'c' and 'd' names that collide after lowercasing. + Node* c = Unary(a, b1.opts().WithName("C")); + Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); + Node* e = BinaryUnknownShape(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Binary(a, f, b1.opts().WithName("G").WithControlInput(e)); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0")); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + + *library_expected.add_function() = test::function::XTimesTwo(); + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}}, + {{"F"}, + "BinaryTest", + {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"}, + {}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"c:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, + {"c"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + Node* c = Unary(a, b2.opts().WithName("C")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(b).Input(c); + Node* call = + b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder); + + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = BinaryUnknownShape( + c, ops::NodeOut(recv, 0), + b2.opts().WithName("E").WithControlInputs({recv, b})); + Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + + Node* s = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + + Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 4842877d9af332bdaa4a142867dde89ba66bd9a2..6353149e4afdf739fe44dd5c76502ef5d98b8477 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -45,7 +45,7 @@ namespace tensorflow { // see comment on `AllowsAsynchronousDeallocation()`. class XlaAllocator : public xla::DeviceMemoryAllocator { public: - XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context); + XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context); ~XlaAllocator() override; xla::StatusOr Allocate(int device_ordinal, uint64 size, bool retry_on_failure) override; @@ -79,7 +79,8 @@ class XlaAllocator : public xla::DeviceMemoryAllocator { std::unordered_map tensors_; }; -XlaAllocator::XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context) +XlaAllocator::XlaAllocator(const gpu::Platform* platform, + OpKernelContext* op_context) : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {} XlaAllocator::~XlaAllocator() = default; @@ -248,12 +249,16 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { xla::LocalClient* client = static_cast(cache->client()); + // Builds an XLA allocator for the device. + XlaAllocator xla_allocator(client->platform(), ctx); + XlaCompiler::Options options; options.client = client; options.device_type = &cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); + options.device_allocator = &xla_allocator; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; @@ -264,9 +269,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "Executing XLA Computation..."; - // Builds an XLA allocator for the device. - XlaAllocator xla_allocator(client->platform(), ctx); - std::unique_ptr output; // Build xla::ShapedBuffers that point directly to the Tensor buffers. std::vector> arg_buffers; @@ -374,8 +376,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { OP_REQUIRES(ctx, write.input_index >= 0 && write.input_index < ctx->num_inputs(), errors::Internal("Invalid input index for variable write.")); - TensorShape write_shape; - OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape)); gpu::DeviceMemoryBase buffer = output->buffer({output_num}); @@ -397,7 +397,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // Looks up the owning Tensor by buffer address. OP_REQUIRES_OK( - ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write_shape, + ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape, variable->tensor())); ++output_num; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 79b02baba83cb47f4f2f16544ad711a4b6937d90..a0211acbbe9eec77d30c7d14293650de8826f41c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -190,6 +190,9 @@ Status FindCompilationCandidates( pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); for (Node* node : graph.op_nodes()) { + VLOG(2) << "FindCompilationCandidates(): Processing " + << node->DebugString(); + DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceTypeOfDevice(node->assigned_device_name(), &device_type)); @@ -216,6 +219,13 @@ Status FindCompilationCandidates( !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { continue; } + // _Arg nodes in a top-level function represent feeds. + // Do not compile them. + if (node->type_string() == "_Arg") { + VLOG(2) << "Skipping jit compilation for '_Arg'-typed node " + << node->DebugString(); + continue; + } // _Retval nodes in a top-level function represent fetches. // Do not compile them. if (node->type_string() == "_Retval") { @@ -304,6 +314,7 @@ Status MarkForCompilationPass::Run( static_cast(flags->tf_xla_auto_jit); } bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; const FunctionLibraryDefinition* fld = options.flib_def; auto is_compilable = [global_jit_level, cpu_global_jit, fld]( diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 454f0aeae98d7afd51f12b2cfb1810de275a57f7..1a8858cccef623185709ab5dc2187a313dd130f7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -80,7 +81,7 @@ TEST(XlaCompilationTest, Chains) { ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); ops::UnaryOp("Relu", e, builder.opts().WithName("F")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -105,7 +106,7 @@ TEST(XlaCompilationTest, UncompilableCycles) { Node* b = ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -125,7 +126,7 @@ TEST(XlaCompilationTest, CompilableCycles) { .WithAttr("value", Tensor())); Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -148,7 +149,7 @@ TEST(XlaCompilationTest, UnsupportedTypes) { .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape()))); Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -177,7 +178,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { concat_builder.Input(dim).Input({a, a}).Attr("N", 2); builder.opts().FinalizeBuilder(&concat_builder); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -212,7 +213,7 @@ TEST(XlaCompilationTest, FunctionCalls) { Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D")); ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def)); @@ -244,7 +245,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C")); Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D")); ops::UnaryOp("Shape", d, builder.opts().WithName("E")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); @@ -330,7 +331,7 @@ TEST(XlaCompilationTest, SymbolicGradients) { d_builder.Input({c, c}); builder.opts().FinalizeBuilder(&d_builder); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -382,7 +383,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { ops::BinaryOp( "MatMul", a, b, builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC")); - TF_CHECK_OK(builder.ToGraph(graph.get())); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -413,7 +414,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) { ops::BinaryOp( "Add", b, c, builder.opts().WithName("D").WithAttr(kXlaScopeAttr, "Scope2")); - TF_CHECK_OK(builder.ToGraph(graph.get())); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -443,7 +444,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { "Relu", a, builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_CHECK_OK(builder.ToGraph(graph.get())); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -484,7 +485,7 @@ TEST(XlaCompilationTest, Resources) { Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C")); Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D")); ops::UnaryOp("Relu", d, builder.opts().WithName("E")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); @@ -541,7 +542,7 @@ TEST(XlaCompilationTest, Retval) { .WithAttr("T", DT_FLOAT) .WithAttr("index", 0)); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index bfff52c55a7d5a4490224347019db9b3333f7e2e..6d854a920eb0b4c01b09024ceaef5035e847d392 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -148,8 +148,7 @@ Status BuildArguments(int num_constant_args, XlaCompiler::Argument& arg = (*args)[input_num]; arg.kind = XlaCompiler::Argument::kConstant; arg.type = input.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape)); + arg.shape = input.shape(); arg.constant_value = input; ++input_num; } @@ -170,8 +169,7 @@ Status BuildArguments(int num_constant_args, arg.constant_value = input; } arg.type = input.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape)); + arg.shape = input.shape(); ++input_num; } @@ -189,8 +187,7 @@ Status BuildArguments(int num_constant_args, if (variable_args[variable_id].present) { const Tensor& value = variable_args[variable_id].value; arg.type = value.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape)); + arg.shape = value.shape(); arg.initialized = true; } else { // The values of uninitialized variables are not passed as inputs, since @@ -199,7 +196,7 @@ Status BuildArguments(int num_constant_args, // uninitialized variables. arg.initialized = false; arg.type = DT_INVALID; - arg.shape = xla::Shape(); + arg.shape = TensorShape(); } ++input_num; } @@ -223,6 +220,7 @@ Status XlaCompilationCache::BuildExecutable( xla::ExecutableBuildOptions build_options; build_options.set_device_ordinal(client_->default_device_ordinal()); build_options.set_result_layout(result.xla_output_shape); + build_options.set_device_allocator(options.device_allocator); auto compile_result = client_->Compile(*result.computation, argument_layouts, build_options); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 314f5506b16e2c28736d9d39aa6c856d50885108..782bf82d4149968d5e5fbfb93bbd4ff1dcd75494 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -144,6 +144,21 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "matrix_triangular_solve_op_test", + size = "small", + srcs = ["matrix_triangular_solve_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "clustering_test", size = "small", @@ -240,6 +255,18 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "extract_image_patches_op_test", + size = "small", + srcs = ["extract_image_patches_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "fft_test", size = "medium", @@ -326,6 +353,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "matrix_band_part_test", + size = "medium", + srcs = ["matrix_band_part_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "momentum_test", size = "small", @@ -437,6 +477,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "reverse_sequence_op_test", + size = "medium", + srcs = ["reverse_sequence_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "rmsprop_test", size = "small", @@ -587,6 +640,7 @@ tf_xla_py_test( name = "variable_ops_test", size = "small", srcs = ["variable_ops_test.py"], + tags = ["optonly"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -613,6 +667,31 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "gather_nd_op_test", + size = "medium", + srcs = ["gather_nd_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "scatter_nd_op_test", + size = "medium", + srcs = ["scatter_nd_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "xla_device_test", size = "small", @@ -737,6 +816,17 @@ tf_library( tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"], ) +tf_xla_py_test( + name = "fake_quant_ops_test", + size = "medium", + srcs = ["fake_quant_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 16856bd736ed408da29c3199c4593eb578775128..30a6d3a74d64f90ad33062df6d1e16e3a575bd63 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -774,15 +774,15 @@ class BinaryOpsTest(XLATestCase): def DISABLED_testSparseMatMul(self): # Binary wrappers for sparse_matmul with different hints def SparseMatmulWrapperTF(a, b): - return tf.sparse_matmul(a, b, a_is_sparse=True) + return math_ops.sparse_matmul(a, b, a_is_sparse=True) def SparseMatmulWrapperFT(a, b): - return tf.sparse_matmul(a, b, b_is_sparse=True) + return math_ops.sparse_matmul(a, b, b_is_sparse=True) def SparseMatmulWrapperTT(a, b): - return tf.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True) + return math_ops.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True) - self._testMatMul(tf.sparse_matmul) + self._testMatMul(math_ops.sparse_matmul) self._testMatMul(SparseMatmulWrapperTF) self._testMatMul(SparseMatmulWrapperFT) self._testMatMul(SparseMatmulWrapperTT) @@ -1181,6 +1181,50 @@ class BinaryOpsTest(XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) + def testMatrixSetDiag(self): + for dtype in self.numeric_types: + # Square + self._testBinary( + array_ops.matrix_set_diag, + np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]], + dtype=dtype), + np.array([1.0, 2.0, 3.0], dtype=dtype), + expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]], + dtype=dtype)) + + self._testBinary( + array_ops.matrix_set_diag, + np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]], + [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]], + dtype=dtype), + np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype), + expected=np.array( + [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]], + [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]], + dtype=dtype)) + + # Rectangular + self._testBinary( + array_ops.matrix_set_diag, + np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype), + np.array([3.0, 4.0], dtype=dtype), + expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype)) + + self._testBinary( + array_ops.matrix_set_diag, + np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype), + np.array([3.0, 4.0], dtype=dtype), + expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype)) + + self._testBinary( + array_ops.matrix_set_diag, + np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]], + [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype), + np.array([[-1.0, -2.0], [-4.0, -5.0]], + dtype=dtype), + expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]], + [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]], + dtype=dtype)) if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0361702e7af778176daed941d64e61198090daf2 --- /dev/null +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -0,0 +1,134 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for ExtractImagePatches op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ExtractImagePatches(XLATestCase): + """Functional tests for ExtractImagePatches op.""" + + def _VerifyValues(self, image, ksizes, strides, rates, padding, patches): + """Tests input-output pairs for the ExtractImagePatches op. + + Args: + image: Input tensor with shape: [batch, in_rows, in_cols, depth]. + ksizes: Patch size specified as: [ksize_rows, ksize_cols]. + strides: Output strides, specified as [stride_rows, stride_cols]. + rates: Atrous rates, specified as [rate_rows, rate_cols]. + padding: Padding type. + patches: Expected output. + """ + ksizes = [1] + ksizes + [1] + strides = [1] + strides + [1] + rates = [1] + rates + [1] + + with self.test_session(): + image_placeholder = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + out_tensor = array_ops.extract_image_patches( + image_placeholder, + ksizes=ksizes, + strides=strides, + rates=rates, + padding=padding, + name="im2col") + feed_dict = {image_placeholder: image} + self.assertAllClose(patches, out_tensor.eval(feed_dict=feed_dict)) + + def testKsize1x1Stride1x1Rate1x1(self): + """Verifies that for 1x1 kernel the output equals the input.""" + # [2, 3, 4, 5] + image = np.reshape(range(120), [2, 3, 4, 5]) + # [2, 3, 4, 5] + patches = np.reshape(range(120), [2, 3, 4, 5]) + for padding in ["VALID", "SAME"]: + self._VerifyValues( + image, + ksizes=[1, 1], + strides=[1, 1], + rates=[1, 1], + padding=padding, + patches=patches) + + def testKsize1x1Stride2x3Rate1x1(self): + """Test for 1x1 kernel and strides.""" + # [2, 4, 5, 3] + image = np.reshape(range(120), [2, 4, 5, 3]) + # [2, 2, 2, 3] + patches = image[:, ::2, ::3, :] + for padding in ["VALID", "SAME"]: + self._VerifyValues( + image, + ksizes=[1, 1], + strides=[2, 3], + rates=[1, 1], + padding=padding, + patches=patches) + + def testKsize2x2Stride1x1Rate1x1Valid(self): + """Test for 2x2 kernel with VALID padding.""" + # [1, 2, 2, 1] + image = [[[[1], [2]], [[3], [4]]]] + # [1, 1, 1, 4] + patches = [[[[1, 2, 3, 4]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[1, 1], + padding="VALID", + patches=patches) + + def testKsize2x2Stride1x1Rate1x1Same(self): + """Test for 2x2 kernel with SAME padding.""" + # [1, 2, 2, 1] + image = [[[[1], [2]], [[3], [4]]]] + # [1, 2, 2, 4] + patches = [[[[1, 2, 3, 4], [2, 0, 4, 0]], [[3, 4, 0, 0], [4, 0, 0, 0]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + patches=patches) + + def testKsize2x2Stride1x1Rate2x2Valid(self): + """Test for 2x2 kernel with 2x2 dilation.""" + # [1, 2, 2, 1] + image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32) + # [1, 2, 2, 4] + patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]], + [[4, 6, 12, 14], [5, 7, 13, 15]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[2, 2], + padding="VALID", + patches=patches) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe9400ef0f55ca011d4e23ba5d735899ca2e054 --- /dev/null +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -0,0 +1,452 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.platform import googletest + + +class FakeQuantWithMinMaxArgsTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxArgs operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. + def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + expected = np.array( + [ + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_max, expected_nudged_input_max, + expected_nudged_input_max, expected_nudged_input_max + ], + dtype=np.float32) + + with self.test_session() as session: + with self.test_scope(): + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + outputs = array_ops.fake_quant_with_min_max_args( + input_placeholder, + min=input_min, + max=input_max, + num_bits=num_bits, + narrow_range=narrow_range) + result = session.run(outputs, {input_placeholder: inputs}) + self.assertAllCloseAccordingToType( + result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) + + +class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxArgsGradient operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. + def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + gradients = np.arange(1, len(inputs) + 1, dtype=np.float32) + expected_backprops = np.array( + [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], + dtype=np.float32) + + with self.test_session() as session: + with self.test_scope(): + gradient_placeholder = array_ops.placeholder( + dtypes.float32, gradients.shape, name="gradients") + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + outputs = gen_array_ops.fake_quant_with_min_max_args_gradient( + gradient_placeholder, + input_placeholder, + min=input_min, + max=input_max, + num_bits=num_bits, + narrow_range=narrow_range) + backprops = session.run(outputs, { + gradient_placeholder: gradients, + input_placeholder: inputs + }) + self.assertAllCloseAccordingToType( + backprops, + expected_backprops, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + + +class FakeQuantWithMinMaxVarsTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxVars operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. + def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + expected = np.array( + [ + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_max, expected_nudged_input_max, + expected_nudged_input_max, expected_nudged_input_max + ], + dtype=np.float32) + + with self.test_session() as session: + with self.test_scope(): + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min") + max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max") + outputs = array_ops.fake_quant_with_min_max_vars( + input_placeholder, + min_placeholder, + max_placeholder, + num_bits=num_bits, + narrow_range=narrow_range) + result = session.run( + outputs, { + input_placeholder: inputs, + min_placeholder: input_min, + max_placeholder: input_max + }) + self.assertAllCloseAccordingToType( + result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) + + +class FakeQuantWithMinMaxVarsGradientTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxVarsGradient operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. + def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + gradients = np.arange(1, len(inputs) + 1, dtype=np.float32) + expected_backprops_wrt_input = np.array( + [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], + dtype=np.float32) + expected_backprops_wrt_min = 1.0 + 2.0 + expected_backprops_wrt_max = 10.0 + 11.0 + + with self.test_session() as session: + with self.test_scope(): + gradient_placeholder = array_ops.placeholder( + dtypes.float32, gradients.shape, name="gradients") + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min") + max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max") + outputs = array_ops.fake_quant_with_min_max_vars_gradient( + gradient_placeholder, + input_placeholder, + min_placeholder, + max_placeholder, + num_bits=num_bits, + narrow_range=narrow_range) + backprops_wrt_input, backprops_wrt_min, backprops_wrt_max = session.run( + outputs, { + gradient_placeholder: gradients, + input_placeholder: inputs, + min_placeholder: input_min, + max_placeholder: input_max + }) + self.assertAllCloseAccordingToType( + backprops_wrt_input, + expected_backprops_wrt_input, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + self.assertAllCloseAccordingToType( + backprops_wrt_min, + expected_backprops_wrt_min, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + self.assertAllCloseAccordingToType( + backprops_wrt_max, + expected_backprops_wrt_max, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9378b1db7245c0da3e8298e7dcd972491616b0cd --- /dev/null +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -0,0 +1,147 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.gather_nd.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class GatherNdTest(XLATestCase): + + def _runGather(self, params, indices): + with self.test_session(): + paramsp = array_ops.placeholder(params.dtype) + indicesp = array_ops.placeholder(indices.dtype) + with self.test_scope(): + gather_nd_t = array_ops.gather_nd(paramsp, indicesp) + feed_dict = {paramsp: params, indicesp: indices} + return gather_nd_t.eval(feed_dict=feed_dict) + + def testSimpleDtype(self): + for dtype in self.numeric_types: + self.assertAllEqual( + np.array([7, 7, 8], dtype=dtype), + self._runGather( + np.array([8, 1, 2, 3, 7, 5], dtype=dtype), + np.array([[4], [4], [0]], np.int32))) + + def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): + with self.test_session(): + params = np.ones((3, 3), dtype=np.float32) + + indices_empty = np.empty((0, 2), dtype=np.int32) + gather_nd_ok_val = self._runGather(params, indices_empty) + self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val) + + indices_empty = np.empty((0, 1), dtype=np.int32) + gather_nd_ok_val = self._runGather(params, indices_empty) + self.assertAllClose(np.empty((0, 3), dtype=np.float32), gather_nd_ok_val) + + params_empty = np.empty((0, 3), dtype=np.float32) + indices_empty = np.empty((0, 2), dtype=np.int32) + gather_nd_ok_val = self._runGather(params_empty, indices_empty) + self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val) + + params_empty = np.empty((0, 3), dtype=np.float32) + indices_nonempty = np.zeros((1, 2), dtype=np.int32) + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, r"Gather dimension 0 is of size zero"): + self._runGather(params_empty, indices_nonempty) + + def testIndexScalar(self): + params = np.array( + [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T + indices = np.array([4, 1], dtype=np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(np.array(7), gather_nd_val) + + def testParamsRankLargerThanIndexIndexScalarSlices(self): + params = np.array( + [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T + indices = np.array( + [ + 4, + ], dtype=np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(np.array([-7, 7]), gather_nd_val) + + def testParamsRankLargerThanIndexSlices(self): + params = np.array( + [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T + indices = np.array([[4], [4], [0]], np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(np.array([[-7, 7], [-7, 7], [-8, 8]]), gather_nd_val) + + def testHigherRankParamsLargerThanIndexSlices(self): + params = np.array( + [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], + [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]], + dtype=np.float32).T + indices = np.array([[4], [4], [0]], np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(params[[4, 4, 0]], gather_nd_val) + + def testEmptyIndicesLastRankMeansCopyEntireTensor(self): + params = np.array( + [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], + [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]], + dtype=np.float32).T + indices = np.array([[], []], dtype=np.int32) # Size (2, 0) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual( + np.vstack((params[np.newaxis, :], params[np.newaxis, :])), + gather_nd_val) + + def testHigherRankParamsAndIndicesLargerThanIndexSlices(self): + params = np.array( + [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], + [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]], + dtype=np.float32).T + indices = np.array([[[3], [2], [1]], [[4], [4], [0]]], np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(params[[3, 2, 1, 4, 4, 0]].reshape(2, 3, 2, 2), + gather_nd_val) + + def testHigherRankParams(self): + shape = (10, 20, 5, 1, 17) + params = np.random.rand(*shape).astype(np.float32) + indices = np.vstack( + [np.random.randint(0, s, size=2000, dtype=np.int32) for s in shape]).T + gather_nd_val = self._runGather(params, indices) + + expected = params[tuple(indices.T)] + self.assertAllEqual(expected, gather_nd_val) + + def testHigherRankParamsAndIndices(self): + shape = (10, 20, 5, 1, 17) + params = np.random.rand(*shape).astype(np.float32) + indices = np.vstack( + [np.random.randint(0, s, size=2000, dtype=np.int32) for s in shape]).T + indices_reshaped = indices.reshape([10, 10, 20, 5]) + gather_nd_val = self._runGather(params, indices_reshaped) + expected = params[tuple(indices.T)] + self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 13cbe6f312f5175edaec28fa7a8f28064194b0e9..1a8c4519118f69ce51ca9a5eb95a9d706c7766cc 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -122,6 +122,20 @@ class GatherTest(xla_test.XLATestCase): gather_np = np.take(params, indices, axis=axis) self.assertAllEqual(gather_np, gather_value) + def testIndicesWithDifferentDimensions(self): + with self.test_session(): + for dtype in self.numeric_tf_types: + params = array_ops.placeholder(dtype=dtype) + indices = array_ops.placeholder(dtype=np.int32) + with self.test_scope(): + gather = array_ops.gather(params, indices) + self.assertAllEqual( + 7, gather.eval(feed_dict={params: [4, 7, 2], indices: 1})) + self.assertAllEqual( + [7], gather.eval(feed_dict={params: [4, 7, 2], indices: [1]})) + self.assertAllEqual( + [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) + class GatherBenchmark(test.Benchmark): """Microbenchmarks for the gather op.""" diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py new file mode 100644 index 0000000000000000000000000000000000000000..29394f9ea5139b30f88f53de0469b27e37d79195 --- /dev/null +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class MatrixBandPartTest(XLATestCase): + + def _testMatrixBandPart(self, dtype, shape): + with self.test_session(): + batch_shape = shape[:-2] + mat = np.ones(shape).astype(dtype) + batch_mat = np.tile(mat, batch_shape + [1, 1]) + for lower in -1, 0, 1, shape[-2] - 1: + for upper in -1, 0, 1, shape[-1] - 1: + band_np = mat + if lower >= 0: + band_np = np.triu(band_np, -lower) + if upper >= 0: + band_np = np.tril(band_np, upper) + if batch_shape: + band_np = np.tile(band_np, batch_shape + [1, 1]) + + placeholder = array_ops.placeholder(dtype) + with self.test_scope(): + band = array_ops.matrix_band_part( + placeholder, + constant_op.constant(lower, dtype=dtypes.int32), + constant_op.constant(upper, dtype=dtypes.int32)) + feed_dict = {placeholder: batch_mat} + self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) + + def testMatrixBandPart(self): + for dtype in self.float_types: + for batch_shape in [[], [2,], [1, 3, 2]]: + for rows in 1, 2, 7: + for cols in 1, 2, 7: + self._testMatrixBandPart(dtype, batch_shape + [rows, cols]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cccb7f5789dce39ef8c3d4b3a7573aaa983b3fbd --- /dev/null +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -0,0 +1,130 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.MatrixTriangularSolve.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def MakePlaceholder(x): + return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape) + + +class MatrixTriangularSolveOpTest(XLATestCase): + + def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca, + placeholder_b, a, clean_a, b, verification, + atol): + feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b} + verification_np = sess.run(verification, feed_dict) + self.assertAllClose(b, verification_np, atol=atol) + + def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): + clean_a = np.tril(a) if lower else np.triu(a) + with self.test_session() as sess: + placeholder_a = MakePlaceholder(a) + placeholder_ca = MakePlaceholder(clean_a) + placeholder_b = MakePlaceholder(b) + with self.test_scope(): + x = linalg_ops.matrix_triangular_solve( + placeholder_a, placeholder_b, lower=lower, adjoint=adjoint) + verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint) + self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca, + placeholder_b, a, clean_a, b, + verification, atol) + + def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4): + transp = lambda x: np.swapaxes(x, -1, -2) + for lower, adjoint in itertools.product([True, False], repeat=2): + self._VerifyTriangularSolve( + a if lower else transp(a), b, lower, adjoint, atol) + + def testBasic(self): + rng = np.random.RandomState(0) + a = np.tril(rng.randn(5, 5)) + b = rng.randn(5, 7) + for dtype in self.float_types: + self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) + + def testBasicNotActuallyTriangular(self): + rng = np.random.RandomState(0) + a = rng.randn(5, 5) # the `a` matrix is not lower-triangular + b = rng.randn(5, 7) + for dtype in self.float_types: + self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) + + def testBasicComplexDtypes(self): + rng = np.random.RandomState(0) + a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j) + b = rng.randn(5, 7) + rng.randn(5, 7) * 1j + for dtype in self.complex_types: + self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) + + def testBatch(self): + rng = np.random.RandomState(0) + shapes = [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)), + ((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))] + tuples = itertools.product(self.float_types, shapes) + for dtype, (a_shape, b_shape) in tuples: + n = a_shape[-1] + a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n) + b = rng.randn(*b_shape) + self._VerifyTriangularSolveCombo( + a.astype(dtype), b.astype(dtype), atol=1e-3) + + def testLarge(self): + n = 1024 + rng = np.random.RandomState(0) + a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n) + b = rng.randn(n, n) + self._VerifyTriangularSolve( + a.astype(np.float32), b.astype(np.float32), True, False, 1e-4) + + def testNonSquareCoefficientMatrix(self): + rng = np.random.RandomState(0) + for dtype in self.float_types: + a = rng.randn(3, 4).astype(dtype) + b = rng.randn(4, 4).astype(dtype) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(a, b) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(a, b) + + def testWrongDimensions(self): + randn = np.random.RandomState(0).randn + for dtype in self.float_types: + lhs = constant_op.constant(randn(3, 3), dtype=dtype) + rhs = constant_op.constant(randn(4, 3), dtype=dtype) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(lhs, rhs) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(lhs, rhs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5d05094e53cfecd9476d7d87f023e8a02d7458 --- /dev/null +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -0,0 +1,93 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.reverse_sequence_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ReverseSequenceTest(XLATestCase): + + def _testReverseSequence(self, + x, + batch_axis, + seq_axis, + seq_lengths, + truth, + expected_err_re=None): + with self.test_session(): + p = array_ops.placeholder(dtypes.as_dtype(x.dtype)) + lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype)) + with self.test_scope(): + ans = array_ops.reverse_sequence( + p, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths) + if expected_err_re is None: + tf_ans = ans.eval(feed_dict={p: x, lengths: seq_lengths}) + self.assertAllClose(tf_ans, truth, atol=1e-10) + else: + with self.assertRaisesOpError(expected_err_re): + ans.eval(feed_dict={p: x, lengths: seq_lengths}) + + def testSimple(self): + x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32) + expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32) + self._testReverseSequence( + x, + batch_axis=0, + seq_axis=1, + seq_lengths=np.array([1, 3, 2], np.int32), + truth=expected) + + def _testBasic(self, dtype, len_dtype): + x = np.asarray( + [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]], + [[17, 18, 19, 20], [21, 22, 23, 24]]], + dtype=dtype) + x = x.reshape(3, 2, 4, 1, 1) + x = x.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 + + # reverse dim 2 up to (0:3, none, 0:4) along dim=0 + seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype) + + truth_orig = np.asarray( + [ + [[3, 2, 1, 4], [7, 6, 5, 8]], # reverse 0:3 + [[9, 10, 11, 12], [13, 14, 15, 16]], # reverse none + [[20, 19, 18, 17], [24, 23, 22, 21]] + ], # reverse 0:4 (all) + dtype=dtype) + truth_orig = truth_orig.reshape(3, 2, 4, 1, 1) + truth = truth_orig.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 + + seq_axis = 0 # permute seq_axis and batch_axis (originally 2 and 0, resp.) + batch_axis = 2 + self._testReverseSequence(x, batch_axis, seq_axis, seq_lengths, truth) + + def testSeqLength(self): + for dtype in self.all_types: + for seq_dtype in self.int_types: + self._testBasic(dtype, seq_dtype) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..638946e234daf28dc4a34e6c33fc0f78b8e8699b --- /dev/null +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -0,0 +1,188 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.scatter_nd.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +def _AsType(v, vtype): + return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v) + + +def _FlatInnerDims(tensor, ndims=2): + shape = list(tensor.shape) + return tensor.reshape( + [functools.reduce(lambda x, y: x * y, shape[:-ndims + 1], 1)] + + shape[-ndims + 1:]) + + +def _FlatOuterDims(tensor, ndims=2): + shape = list(tensor.shape) + return tensor.reshape( + shape[:ndims - 1] + + [functools.reduce(lambda x, y: x * y, shape[ndims - 1:], 1)]) + + +def _NumpyScatterNd(ref, indices, updates, op): + ixdim = indices.shape[-1] + num_updates = indices.size // ixdim + total_nd = len(ref.shape) + slice_size = 1 + for i in range(ixdim, total_nd): + slice_size *= ref.shape[i] + flat_indices = _FlatInnerDims(indices) + flat_updates = updates.reshape((num_updates, slice_size)) + output_flat = _FlatOuterDims(ref, ixdim + 1) + for ix_updates, ix_output in enumerate(flat_indices): + ix_output = tuple(ix_output) + output_flat[ix_output] = op(output_flat[ix_output], + flat_updates[ix_updates]) + return output_flat.reshape(ref.shape) + + +def _NumpyUpdate(indices, updates, shape): + ref = np.zeros(shape, dtype=updates.dtype) + return _NumpyScatterNd(ref, indices, updates, lambda p, u: u) + + +class ScatterNdTest(XLATestCase): + + def _VariableRankTest(self, + np_scatter, + tf_scatter, + vtype, + itype, + repeat_indices=False): + np.random.seed(8) + ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)] + indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)] + for ref_shape, indices_shape in zip(ref_shapes, indices_shapes): + num_updates = indices_shape[0] + ixdim = indices_shape[-1] + + indexable_area_shape = () + for i in range(ixdim): + indexable_area_shape += (ref_shape[i],) + all_indices = [ + list(coord) + for coord, _ in np.ndenumerate(np.empty(indexable_area_shape, vtype)) + ] + np.random.shuffle(all_indices) + indices = np.array(all_indices[:num_updates]) + + if num_updates > 1 and repeat_indices: + indices = indices[:num_updates // 2] + for _ in range(num_updates - num_updates // 2): + indices = np.append( + indices, [indices[np.random.randint(num_updates // 2)]], axis=0) + np.random.shuffle(indices) + indices = _AsType(indices[:num_updates], itype) + + updates_shape = (num_updates,) + for i in range(ixdim, len(ref_shape)): + updates_shape += (ref_shape[i],) + updates = _AsType(np.random.randn(*(updates_shape)), vtype) + + # Scatter via numpy + np_out = np_scatter(indices, updates, ref_shape) + # Scatter via tensorflow + tf_out = tf_scatter(indices, updates, ref_shape) + + self.assertAllClose(np_out, tf_out) + + def _VariableRankTests(self, np_scatter, tf_scatter): + for vtype in self.numeric_types: + for itype in set([np.int32, np.int64]).intersection(set(self.int_types)): + self._VariableRankTest(np_scatter, tf_scatter, vtype, itype) + + def _runScatterNd(self, indices, updates, shape): + with self.test_session(): + updates_placeholder = array_ops.placeholder(updates.dtype) + indices_placeholder = array_ops.placeholder(indices.dtype) + with self.test_scope(): + output = array_ops.scatter_nd(indices_placeholder, updates_placeholder, + shape) + feed_dict = {updates_placeholder: updates, indices_placeholder: indices} + return output.eval(feed_dict=feed_dict) + + def testSimple(self): + indices = np.array([[4], [3], [1], [7]], dtype=np.int32) + updates = np.array([9, 10, 11, 12], dtype=np.float32) + expected = np.array([0, 11, 0, 10, 9, 0, 0, 12], dtype=np.int32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [8])) + + def testSimple2(self): + indices = np.array([[1, 0], [1, 1]], dtype=np.int32) + updates = np.array([11., 12.], dtype=np.float32) + expected = np.array([[0., 0.], [11., 12.], [0., 0.]], dtype=np.float32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2])) + + def testSimple3(self): + indices = np.array([[1]], dtype=np.int32) + updates = np.array([[11., 12.]], dtype=np.float32) + expected = np.array([[0., 0.], [11., 12.], [0., 0.]]) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2])) + + def testVariableRankUpdate(self): + self._VariableRankTests(_NumpyUpdate, self._runScatterNd) + + def testExtraIndicesDimensions(self): + indices = np.zeros([1, 1, 2], np.int32) + updates = np.zeros([1, 1], np.int32) + expected = np.zeros([2, 2], dtype=np.int32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2])) + + def testRank3InvalidShape1(self): + indices = np.zeros([3, 2, 2], np.int32) + updates = np.zeros([2, 2, 2], np.int32) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "Must have updates.shape"): + self._runScatterNd(indices, updates, [2, 2, 2]) + + def testRank3InvalidShape2(self): + indices = np.zeros([2, 2, 1], np.int32) + updates = np.zeros([2, 2], np.int32) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "Must have updates.shape"): + self._runScatterNd(indices, updates, [2, 2, 2]) + + def testScatterOutOfRange(self): + updates = np.array([-3, -4, -5]).astype(np.float32) + + # Indices all in range, no problem. + indices = np.array([[2], [0], [5]], dtype=np.int32) + self._runScatterNd(indices, updates, [6]) + + # Indices out of range should not fail. It produces implementation-defined + # output. + indices = np.array([[-1], [0], [5]], dtype=np.int32) + self._runScatterNd(indices, updates, [6]) + indices = np.array([[2], [0], [6]], dtype=np.int32) + self._runScatterNd(indices, updates, [6]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 260a04421b62310c109d8f0ea72875a50c234bb0..4a9c0e7471f9cdb2a47b54705495d2dda9748890 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -60,6 +60,14 @@ class SegmentReductionOpsTest(XLATestCase): np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) + def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array([6, 3, 0, 6], dtype=dtype), + self.UnsortedSegmentSum( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) + def testUnsortedSegmentSum1DIndices2DDataDisjoint(self): for dtype in self.numeric_types: data = np.array( diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 8e4b8a38336c5e8b2e10edc4c81502eeebb628d2..3d3e112f4821ea8e57ea9589a5b4433647ad294b 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -154,6 +154,21 @@ class UnaryOpsTest(XLATestCase): def testFloatOps(self): for dtype in self.float_types: + x = np.arange(-0.90, 0.90, 0.25) + self._assertOpOutputMatchesExpected( + math_ops.acos, + x.astype(dtype), + expected=np.arccos(x).astype(dtype)) + self._assertOpOutputMatchesExpected( + math_ops.asin, + x.astype(dtype), + expected=np.arcsin(x).astype(dtype)) + x = np.arange(-3, 3).reshape(1, 3, 2) + self._assertOpOutputMatchesExpected( + math_ops.atan, + x.astype(dtype), + expected=np.arctan(x).astype(dtype)) + self._assertOpOutputMatchesExpected( math_ops.acosh, np.array([1, 2, 3, 4], dtype=dtype), diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 3c7dfef03dfb5d86dd63fd4aa84ad56081833035..fb82c2601c432cee425a46a3b6dc2c55febeda87 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -312,6 +312,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 0249500910c6ae441f038fe9ad6178794f1997ac..82923722c54d235716b9138d95a75a441df924ca 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -64,7 +64,7 @@ Status BackwardsConstAnalysis(const Graph& g, // Mark any compile-time constant operator arguments as const. const std::unordered_set* const_inputs = XlaOpRegistry::CompileTimeConstantInputs(node->type_string()); - if (!const_inputs) return; + if (!const_inputs || const_inputs->empty()) return; NameRangeMap input_name_ranges; status = diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 1d9e0fb33ee4a4229c78d116831e95391a5ac3f8..f8169795ddfb7fd4e93d3f136c51623385868951 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -285,7 +285,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, Status FunctionalizeLoop(Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " - << dump_graph::DumpGraphToFile("functionalize_before", *graph); + << dump_graph::DumpGraphToFile("functionalize_before", *graph, + library); // Split loop-varying Enter nodes with multiple successors. If the same // Tensor is fed as input to multiple loop arguments, we may end up with a @@ -427,16 +428,36 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, // identity nodes are values used by the loop body or condition. // The Identity node may have the wrong device so copy the device from // one of its outputs instead. + std::deque possible_exit; for (const Edge* edge : arg.switch_node->out_edges()) { - if (edge->src_output() == 0 && IsExit(edge->dst())) { + if (edge->src_output() == 0) { + possible_exit.push_back(edge); + } + if (IsIdentity(edge->dst())) { + TF_RETURN_IF_ERROR( + SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); + } + } + // TODO(b/67425339): Allow general graph between switch and exit. + while (!possible_exit.empty()) { + const Edge* edge = possible_exit.front(); + possible_exit.pop_front(); + if (IsExit(edge->dst())) { if (arg.exit != nullptr) { return errors::InvalidArgument("Duplicate Exit successors to ", arg.switch_node->name()); } arg.exit = edge->dst(); - } else if (StringPiece(edge->dst()->type_string()) == "Identity") { - TF_RETURN_IF_ERROR( - SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); + } else { + if (!IsIdentity(edge->dst())) { + return errors::Unimplemented("General graph between switch (", + arg.switch_node->name(), + ") and exit node of frame ", + frame->name, " not supported yet."); + } + for (const Edge* out : edge->dst()->out_edges()) { + possible_exit.push_back(out); + } } } } @@ -450,7 +471,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); VLOG(2) << "Frame " << frame->name << " condition: " - << dump_graph::DumpGraphToFile("loop_condition", *cond_graph) + << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); static std::atomic sequence_num(0LL); @@ -531,7 +552,8 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, frame->parent->nodes.insert(while_node); VLOG(2) << "Frame " << frame->name << " after: " - << dump_graph::DumpGraphToFile("functionalize_after", *graph); + << dump_graph::DumpGraphToFile("functionalize_after", *graph, + library); return Status::OK(); } @@ -564,11 +586,11 @@ class FunctionalizeCond { explicit CondArgNode(Node* input) : input(input) {} string ToString() const { return strings::StrCat("input=", input->name(), - " switches=", NodesToString(switch_nodes)); + " switches=", NodesToString(switches)); } Node* input; - std::vector switch_nodes; + std::vector switches; }; using CondArgNodes = std::vector; @@ -582,15 +604,22 @@ class FunctionalizeCond { int count; }; - struct PredicateSwitches { - explicit PredicateSwitches(Node* predicate) : predicate(predicate) {} + // Group of switch nodes that will be part of the same XlaIf. + struct SwitchCluster { + explicit SwitchCluster(Node* predicate) : predicate(predicate) {} + string ToString() const { + return strings::StrCat(name, " predicate=", predicate->name(), + " switches=", NodesToString(switches)); + } + string name; Node* predicate; std::vector switches; }; - FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : library_(library), graph_(graph) {} + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + bool dump_graphs) + : library_(library), graph_(graph), dump_graphs_(dump_graphs) {} // Perform the actual cond functionalization. Iterate over groups of switch // nodes (linked by common predicate), from innermost to outermost, and @@ -601,27 +630,25 @@ class FunctionalizeCond { // frontier (the nodes where the cond ends). StatusOr, std::unordered_set>> - DetermineBranchMapAndFrontier(const std::vector& switches); + DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster); // Returns XlaIf node created from subgraph of merge and switch nodes. This // encapsulates the process of extracting the bodies needed for the then and // else branch, creates a XlaIf node, removing the nodes of the branches from // the graph and replacing the merge node with a XlaIf. StatusOr ConvertToXlaIf(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, - const std::vector& merge_nodes, - Node* predicate); + const SwitchCluster& switch_cluster, + const std::vector& switches); // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. StatusOr BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, - const std::vector& merge_nodes, - Node* predicate); + const SwitchCluster& switch_cluster, + const std::vector& merge_nodes); // Extracts a function body corresponding to the given input edge of the merge // node. Status ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, + const std::vector& switches, const std::vector& merge_nodes, int input_edge, Graph* body); @@ -632,9 +659,9 @@ class FunctionalizeCond { // Adds all output edges from the `if_node`. Status AddOutputEdges(const std::vector& outputs, Node* if_node); - // Returns the switches of graph_ (along with grouping predicates) in - // postorder. Dead switch nodes are skipped and removed from the graph. - std::vector DeterminePredicateSwitchOrder(); + // Returns the switch clusters of graph_ in postorder. Dead switch nodes are + // skipped and removed from the graph. + StatusOr> DeterminePredicateSwitchOrder(); // Update the state for destination based on the state of source and the node // being updated. @@ -657,6 +684,7 @@ class FunctionalizeCond { FunctionLibraryDefinition* library_; Graph* graph_; + bool dump_graphs_; }; bool IsDeadSwitch(const Node* node) { @@ -704,10 +732,13 @@ Status FunctionalizeCond::ValidateFrontier( ") in both Else and Then branch should be in Both."); } } - if (pending[kBoth].empty() && pending[kThenBranch].empty() && - pending[kElseBranch].empty()) { - return errors::Internal("Unexpected empty frontier for switch nodes"); - } + // An empty frontier indicates a dead switch. Above we attempt to remove dead + // switch nodes, but not all are removed so don't treat it as an error yet. + // TODO(jpienaar): Find out why dead switch nodes remain. + // if (pending[kBoth].empty() && pending[kThenBranch].empty() && + // pending[kElseBranch].empty()) { + // return errors::Internal("Unexpected empty frontier for switch nodes"); + // } return Status::OK(); } @@ -734,33 +765,191 @@ Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, return Status::OK(); } -std::vector +StatusOr> FunctionalizeCond::DeterminePredicateSwitchOrder() { + struct Cluster { + bool operator==(const Cluster& other) const { + return representative == other.representative; + } + int representative = -1; + }; + + // Perform a DFS over the graph and + // * Determine the reverse topological order of the nodes (there should be no + // cycles at this point so the post-order numbering corresponds to the + // reverse topological sorting); + // * Identify dead switches; + // * Initialize the cluster's representative; + std::vector> clusters(graph_->num_node_ids()); std::vector dead_switches; std::vector switch_order; - DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) { + std::vector rev_topo_sorted_nodes; + DFS(*graph_, nullptr, [&](Node* n) { + clusters[n->id()].Get().representative = n->id(); if (IsSwitch(n)) { if (IsDeadSwitch(n)) { dead_switches.push_back(n); } else { + rev_topo_sorted_nodes.push_back(n); switch_order.push_back(n); } + } else if (n->IsOp()) { + // Exclude src and sink nodes from further consideration. + rev_topo_sorted_nodes.push_back(n); } }); + std::vector switch_clusters; + // Return early if there are no switches in the graph. + if (switch_order.empty()) { + return switch_clusters; + } + // Remove all dead switch nodes. for (Node* n : dead_switches) { VLOG(2) << "Removing dead switch: " << n->DebugString(); graph_->RemoveNode(n); } - std::vector predicate_switch_order; - if (switch_order.empty()) { - return predicate_switch_order; + // Identify switch nodes that are part of the same control flow context by + // considering the operands of operations: an operation is part of the same + // control context as its operands unless the operation is a switch. Control + // dependencies are considered part of the same control flow context if the + // switch depth is the same (see comment below). + + // entry_cluster records the input cluster to a switch node. This is used when + // merging with a merge node where the dst's cluster is merged with the entry + // cluster of the merge node's cluster (which corresponds to a switch cluster + // and so has an entry cluster). + std::unordered_map*> entry_cluster; + + // Returns the output cluster of a node. Where the output cluster is cluster + // where the output of the node is used. For non-merge nodes this is simply + // the cluster they are part of, while for merge nodes it is the entry cluster + // of the cluster they are part of (this will correspond to the entry node of + // a switch node that dominates the merge). + auto find_output_cluster = [&](Node* n) { + UnionFind* cluster = &clusters[n->id()]; + if (!IsMerge(n)) return cluster; + auto it = entry_cluster.find(clusters[n->id()].Get().representative); + // If the cluster is not found in the entry_cluster map then an + // instruction not dominated by a switch node has been merged into the + // cluster of the merge. This indicates a failure of the clustering. + CHECK(it != entry_cluster.end()) + << "Unable to find entry for n=" << n->id() << " (" + << cluster->Get().representative << ")"; + return it->second; + }; + + // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier. + std::vector switch_depth(graph_->num_node_ids()); + for (auto it = rev_topo_sorted_nodes.rbegin(); + it != rev_topo_sorted_nodes.rend(); ++it) { + Node* n = *it; + + // Compute switch depth. + int new_switch_depth = 0; + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + new_switch_depth = std::max( + new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0)); + } + switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0); + + // Only merge the input operands of a switch. The switch's clustering itself + // is determined by the interaction of the switch's outputs. + if (IsSwitch(n)) { + Node* input; + TF_CHECK_OK(n->input_node(0, &input)); + entry_cluster[n->id()] = &clusters[input->id()]; + UnionFind* cluster = find_output_cluster(input); + int cluster_depth = switch_depth[cluster->Get().representative]; + // Merge the inputs of the switch node with one another. This results in + // predicates and control input residing in the same cluster. + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + UnionFind* src_cluster = find_output_cluster(src); + int src_cluster_depth = switch_depth[src_cluster->Get().representative]; + if (cluster_depth != src_cluster_depth) { + return errors::InvalidArgument( + "Unable to functionalize control flow in graph: Switch ('", + n->name(), "') has operands ('", input->name(), "' and '", + src->name(), "') that have different switch depths (", + cluster_depth, " != ", src_cluster_depth, ")"); + } + cluster->Merge(src_cluster); + } + continue; + } + + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + if (!src->IsOp()) continue; + UnionFind* cluster = find_output_cluster(src); + // Merge a node with its data operands and with its control operands if + // the src and dst are in the same ControlContext. The ControlContext is + // not explicitly available here, and instead the switch depth is used as + // a proxy here. Due to the invariant that control edges can only be from + // a containing scope to an inner scope or from the inner scope to its + // containing scope (for exit nodes), the switch depth will only match if + // the src and dst are in the same ControlContext. Control edges between + // ControlContexts are handled during the extraction. + int src_id = cluster->Get().representative; + int src_depth = switch_depth[src_id]; + if (!e->IsControlEdge() || new_switch_depth == src_depth) { + if (src_depth != new_switch_depth) { + return errors::InvalidArgument( + "Unable to functionalize control flow in graph: Operand ('", + src->name(), "') and operator ('", n->name(), + "') have different switch depths (", src_depth, + " != ", new_switch_depth, ")"); + } + cluster->Merge(&clusters[n->id()]); + } + } + } + + if (dump_graphs_) { + // Mark the switch cluster each node is part of. + for (Node* n : graph_->nodes()) { + n->ClearAttr("_XlaFunctionalizeSwitchGroup"); + n->AddAttr("_XlaFunctionalizeSwitchGroup", + clusters[n->id()].Get().representative); + } + LOG(INFO) << "FunctionalizeControlFlow (with_clusters): " + << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_, + library_); + } + + // Verify all the nodes of a cluster are at the same depth. + std::unordered_map> cluster_to_depth_node; + for (Node* n : graph_->nodes()) { + int depth = switch_depth[n->id()]; + int cluster_rep = clusters[n->id()].Get().representative; + auto it = cluster_to_depth_node.find(cluster_rep); + if (it == cluster_to_depth_node.end()) { + cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n); + } else { + if (it->second.first != depth) { + return errors::Internal( + "Illegal clustering created, mismatch in depths:", "\n\t", + n->DebugString(), "(", clusters[n->id()].Get().representative, + ") at depth=", depth, " vs\n\t", it->second.second->DebugString(), + "(", clusters[n->id()].Get().representative, ") at depth ", + it->second.first); + } + } } + struct Hash { + size_t operator()(const std::pair& item) const { + return Hash64Combine(hash()(item.first), + std::hash()(item.second.representative)); + } + }; + // Merge Switch nodes with common predicate. - std::unordered_map predicate_index; + std::unordered_map, int, Hash> predicate_index; // The nodes in switch_order are in reverse topological order, but the // clustered switches need not be (i.e., when considered as a cluster one // element of a cluster may be later in the topological order than another @@ -769,13 +958,19 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() { for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { Node* pred; TF_CHECK_OK((*it)->input_node(1, &pred)); - if (predicate_index.find(pred) == predicate_index.end()) { - predicate_index[pred] = predicate_switch_order.size(); - predicate_switch_order.emplace_back(pred); + auto repr = std::make_pair(pred, clusters[(*it)->id()].Get()); + if (predicate_index.find(repr) == predicate_index.end()) { + predicate_index[repr] = switch_clusters.size(); + switch_clusters.emplace_back(pred); + // Generate a name by concatenating with the cluster representative as + // there could be multiple switch clusters with the same predicate. + switch_clusters[predicate_index[repr]].name = + strings::StrCat(pred->name(), "_", repr.second.representative, "_If"); } - predicate_switch_order[predicate_index[pred]].switches.push_back(*it); + switch_clusters[predicate_index[repr]].switches.push_back(*it); } - return predicate_switch_order; + + return switch_clusters; } StatusOr> @@ -823,10 +1018,10 @@ StatusOr< std::pair, std::unordered_set>> FunctionalizeCond::DetermineBranchMapAndFrontier( - const std::vector& switches) { + const SwitchCluster& switch_cluster) { std::unordered_map branch_map; std::unordered_set frontier; - std::vector stack = switches; + std::vector stack = switch_cluster.switches; std::vector visited(graph_->num_node_ids(), false); while (!stack.empty()) { Node* n = stack.back(); @@ -868,7 +1063,7 @@ FunctionalizeCond::DetermineBranchMapAndFrontier( } } - if (VLOG_IS_ON(2)) { + if (dump_graphs_) { for (const auto& kv : branch_map) { // Append attribute to the graph if running with logging to make the // changes clearer in the visualization. @@ -880,8 +1075,8 @@ FunctionalizeCond::DetermineBranchMapAndFrontier( } Status FunctionalizeCond::FunctionalizeInternal() { - std::vector predicate_switch_order = - DeterminePredicateSwitchOrder(); + TF_ASSIGN_OR_RETURN(std::vector predicate_switch_order, + DeterminePredicateSwitchOrder()); // Iterate from innermost set of clustered switches to outermost, replacing // matching switch->merge subgraphs with single XlaIf nodes. @@ -894,10 +1089,12 @@ Status FunctionalizeCond::FunctionalizeInternal() { std::unordered_map branch_map; std::unordered_set frontier; TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier), - DetermineBranchMapAndFrontier(ps.switches)); + DetermineBranchMapAndFrontier(ps)); - VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_bc", *graph_); + if (dump_graphs_) + LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_bc", *graph_, + library_); TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); // Sort the merge and switch nodes using NodeCmp. The switch-nodes are @@ -914,7 +1111,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { input_index[in] = cond_arg_nodes.size(); cond_arg_nodes.emplace_back(in); } - cond_arg_nodes.at(input_index.at(in)).switch_nodes.push_back(switch_node); + cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node); } std::vector merge_nodes(frontier.begin(), frontier.end()); std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); @@ -923,9 +1120,8 @@ Status FunctionalizeCond::FunctionalizeInternal() { EnsureDominanceAndReturnNonDominatedControlNodes( branch_map, ps.switches)); - TF_ASSIGN_OR_RETURN( - Node * if_node, - ConvertToXlaIf(cond_arg_nodes, ps.switches, merge_nodes, ps.predicate)); + TF_ASSIGN_OR_RETURN(Node * if_node, + ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes)); for (Node* old : old_control_nodes) { graph_->AddControlEdge(old, if_node); } @@ -934,25 +1130,26 @@ Status FunctionalizeCond::FunctionalizeInternal() { graph_->RemoveNode(del_kv.first); } for (auto& kv : cond_arg_nodes) { - for (Node* node : kv.switch_nodes) { + for (Node* node : kv.switches) { graph_->RemoveNode(node); } } - VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_ac", *graph_); + if (dump_graphs_) + LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_ac", *graph_, + library_); } return Status::OK(); } StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( - const CondArgNodes& cond_arg_nodes, const std::vector& switch_nodes, - const std::vector& merge_nodes, Node* predicate) { - VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input " - << NodesToString(switch_nodes); + const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, + const std::vector& merge_nodes) { + VLOG(2) << "Build if op for " << switch_cluster.name; NodeDef if_def; // Create a new If node using the name of the merge node. - NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf"); + NodeDefBuilder builder(switch_cluster.name, "XlaIf"); string branch[] = {"else_branch", "then_branch"}; for (int i = 0; i < 2; ++i) { static std::atomic sequence_num(0LL); @@ -962,12 +1159,9 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( body_name.set_name( strings::StrCat("_functionalize_if_", branch[i], "_", id)); auto body = xla::MakeUnique(graph_->op_registry()); - TF_RETURN_IF_ERROR( - ExtractBody(cond_arg_nodes, switch_nodes, merge_nodes, i, body.get())); + TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches, + merge_nodes, i, body.get())); VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); - VLOG(4) << "FunctionalizeControlFlow (" << branch[i] << "): " - << dump_graph::DumpGraphToFile( - strings::StrCat("functionalize_", branch[i]), *body); FunctionDef body_fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); @@ -979,7 +1173,7 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( DataTypeVector in_arg_types; for (auto& kv : cond_arg_nodes) { bool inserted = false; - for (const Node* arg : kv.switch_nodes) { + for (const Node* arg : kv.switches) { const Edge* in_edge; TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -1006,10 +1200,11 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( builder.Attr("Tout", out_type); builder.Attr("Tcond", DT_BOOL); - builder.Device(predicate->assigned_device_name()); + builder.Device(switch_cluster.predicate->assigned_device_name()); // Conditional should be the first input ... builder.Input( - NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0))); + NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0, + switch_cluster.predicate->output_type(0))); // ... followed by the other inputs. builder.Input(inputs); @@ -1019,7 +1214,7 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( } Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, + const std::vector& switches, const std::vector& merge_nodes, int input_edge, Graph* body) { VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " @@ -1029,7 +1224,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, int arg_count = 0; for (auto& kv : cond_arg_nodes) { Node* arg_node = nullptr; - for (const auto* arg : kv.switch_nodes) { + for (const auto* arg : kv.switches) { DataType dtype = arg->input_type(0); if (arg_node == nullptr) { TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++)); @@ -1053,8 +1248,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, node_map.at(in->id()) = body->CopyNode(in); } - if (std::find(switch_nodes.begin(), switch_nodes.end(), in) == - switch_nodes.end()) { + if (std::find(switches.begin(), switches.end(), in) == switches.end()) { body->AddEdge(node_map.at(in->id()), in_edge->src_output(), node_map.at(node->id()), 0); } else { @@ -1076,7 +1270,7 @@ Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, graph_->AddEdge(predicate, 0, if_node, index++); for (auto& kv : cond_arg_nodes) { bool inserted = false; - for (const Node* arg : kv.switch_nodes) { + for (const Node* arg : kv.switches) { const Edge* in_edge; TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -1119,16 +1313,17 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, } StatusOr FunctionalizeCond::ConvertToXlaIf( - const CondArgNodes& cond_arg_nodes, const std::vector& switch_nodes, - const std::vector& merge_nodes, Node* predicate) { - VLOG(1) << "ConvertToXlaIf for " << NodesToString(switch_nodes) << " -> " + const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, + const std::vector& merge_nodes) { + VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> " << NodesToString(merge_nodes); // Extract bodies and builds a If operator. TF_ASSIGN_OR_RETURN( Node * if_node, - BuildAndAddXlaIfOp(cond_arg_nodes, switch_nodes, merge_nodes, predicate)); - TF_RETURN_IF_ERROR(AddInputEdges(cond_arg_nodes, predicate, if_node)); + BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes)); + TF_RETURN_IF_ERROR( + AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node)); TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); return if_node; @@ -1137,18 +1332,19 @@ StatusOr FunctionalizeCond::ConvertToXlaIf( Status FunctionalizeCond::Functionalize(Graph* graph, FunctionLibraryDefinition* library) { VLOG(1) << "FunctionalizeCond::Functionalize"; - FunctionalizeCond fc(graph, library); + FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2)); return fc.FunctionalizeInternal(); } } // namespace -// Transformation that converts Tensorflow's graph control flow constructs into +// Transformation that converts TensorFlow's graph control flow constructs into // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " - << dump_graph::DumpGraphToFile("functionalize_initial", *graph); + << dump_graph::DumpGraphToFile("functionalize_initial", *graph, + library); // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this // invariant. @@ -1160,7 +1356,8 @@ Status FunctionalizeControlFlow(Graph* graph, for (Node* node : graph->op_nodes()) { const ControlFlowInfo& cf = cf_info[node->id()]; - VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name + VLOG(2) << "node: " << node->name() << " (" << node->id() + << ") frame_name: " << cf.frame_name << " frame: " << (cf.frame ? cf.frame->name() : "---") << " parent_frame: " << (cf.parent_frame ? cf.parent_frame->name() : "---"); @@ -1228,7 +1425,8 @@ Status FunctionalizeControlFlow(Graph* graph, TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); VLOG(2) << "FunctionalizeControlFlow (final): " - << dump_graph::DumpGraphToFile("functionalize_final", *graph); + << dump_graph::DumpGraphToFile("functionalize_final", *graph, + library); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 71f12a13339b9b5495631b8f9350579f6a0785a3..bc7276c3afd5060d6faeceb4d479416299ecc5da 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -38,10 +38,11 @@ namespace { // Returns the names of the "then" and "else" functions for the XlaIf node in a // graph. -Status FindIfThenAndElse(const GraphDef& graph, NameAttrList* then_fn, - NameAttrList* else_fn) { +Status FindIfThenAndElse(const GraphDef& graph, string* op_name, + NameAttrList* then_fn, NameAttrList* else_fn) { for (const NodeDef& node : graph.node()) { if (node.op() == "XlaIf") { + *op_name = node.name(); const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); *then_fn = *result; @@ -96,9 +97,10 @@ TEST(FunctionalizeControlFlow, Conditional) { GraphDef graph_def; graph.ToGraphDef(&graph_def); + string op_name; NameAttrList then_fn; NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &then_fn, &else_fn)); + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); InstantiationResultForTest else_result; TF_EXPECT_OK( InstantiateFunctionForTest(else_fn.name(), library, &else_result)); @@ -109,7 +111,7 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName("cond/Less_If"), less, + auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, std::initializer_list{less, y, x}, then_fn, else_fn, {DT_INT32}); GraphDef expected; diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md index 82b3b46a2f1e97001d1e0c6b993ec243170bc7d8..91351421bcacd26c41b5c9f98ea833730e4aef30 100644 --- a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -6,6 +6,9 @@ Operator | Type Constraint `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`AdjustContrastv2` | +`AdjustHue` | +`AdjustSaturation` | `All` | `Tidx={int32,int64}` `Angle` | `Tout={double,float}`
`T={complex64}` `Any` | `Tidx={int32,int64}` @@ -34,7 +37,7 @@ Operator | Type Constraint `BroadcastGradientArgs` | `T={int32,int64}` `Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` `Ceil` | `T={double,float}` -`Cholesky` | `T={complex64,double,float}` +`Cholesky` | `T={double,float}` `Complex` | `Tout={complex64}`
`T={double,float}` `ComplexAbs` | `Tout={double,float}`
`T={complex64}` `Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` @@ -68,7 +71,11 @@ Operator | Type Constraint `Exp` | `T={complex64,double,float}` `ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Expm1` | `T={complex64,double,float}` -`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}` +`FFT` | +`FFT2D` | +`FFT3D` | +`Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` `FloorMod` | `T={double,float,int32,int64}` @@ -80,6 +87,13 @@ Operator | Type Constraint `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`HSVToRGB` | `T={double,float}` +`IFFT` | +`IFFT2D` | +`IFFT3D` | +`IRFFT` | +`IRFFT2D` | +`IRFFT3D` | `Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Imag` | `Tout={double,float}`
`T={complex64}` @@ -105,11 +119,14 @@ Operator | Type Constraint `MatMul` | `T={complex64,double,float}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` `Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` @@ -131,6 +148,10 @@ Operator | Type Constraint `PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `QuantizeAndDequantizeV2` | `T={double,float}` +`RFFT` | +`RFFT2D` | +`RFFT3D` | +`RGBToHSV` | `T={double,float}` `RandomStandardNormal` | `dtype={float}` `RandomUniform` | `T={int32,int64}`
`dtype={double,float}` `RandomUniformInt` | `T={int32,int64}`
`Tout={int32,int64}` @@ -146,6 +167,8 @@ Operator | Type Constraint `Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` `ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` `Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResizeBilinear` | `T={double,float,int32,int64}` +`ResizeBilinearGrad` | `T={double,float}` `ResourceApplyAdagrad` | `T={double,float}` `ResourceApplyAdam` | `T={double,float}` `ResourceApplyFtrl` | `T={double,float}` @@ -156,6 +179,7 @@ Operator | Type Constraint `ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` `ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseSequence` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` `RightShift` | `T={int32,int64,uint32,uint64}` `Rint` | `T={double,float}` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md index d4b7621ad2858fe17e93d292dd807e4f7c1c336b..b9bdb829d773825005a8921f48d28b6892d8f0cd 100644 --- a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -6,6 +6,9 @@ Operator | Type Constraint `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`AdjustContrastv2` | +`AdjustHue` | +`AdjustSaturation` | `All` | `Tidx={int32,int64}` `Angle` | `Tout={double,float}`
`T={complex64}` `Any` | `Tidx={int32,int64}` @@ -34,7 +37,7 @@ Operator | Type Constraint `BroadcastGradientArgs` | `T={int32,int64}` `Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` `Ceil` | `T={double,float}` -`Cholesky` | `T={complex64,double,float}` +`Cholesky` | `T={double,float}` `Complex` | `Tout={complex64}`
`T={double,float}` `ComplexAbs` | `Tout={double,float}`
`T={complex64}` `Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` @@ -68,7 +71,11 @@ Operator | Type Constraint `Exp` | `T={complex64,double,float}` `ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Expm1` | `T={complex64,double,float}` -`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}` +`FFT` | +`FFT2D` | +`FFT3D` | +`Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` `FloorMod` | `T={double,float,int32,int64}` @@ -80,6 +87,13 @@ Operator | Type Constraint `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`HSVToRGB` | `T={double,float}` +`IFFT` | +`IFFT2D` | +`IFFT3D` | +`IRFFT` | +`IRFFT2D` | +`IRFFT3D` | `Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Imag` | `Tout={double,float}`
`T={complex64}` @@ -105,11 +119,14 @@ Operator | Type Constraint `MatMul` | `T={complex64,double,float}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` `Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` @@ -131,6 +148,10 @@ Operator | Type Constraint `PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `QuantizeAndDequantizeV2` | `T={double,float}` +`RFFT` | +`RFFT2D` | +`RFFT3D` | +`RGBToHSV` | `T={double,float}` `Range` | `Tidx={double,float,int32,int64}` `Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` @@ -143,6 +164,8 @@ Operator | Type Constraint `Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` `ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` `Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResizeBilinear` | `T={double,float,int32,int64}` +`ResizeBilinearGrad` | `T={double,float}` `ResourceApplyAdagrad` | `T={double,float}` `ResourceApplyAdam` | `T={double,float}` `ResourceApplyFtrl` | `T={double,float}` @@ -153,6 +176,7 @@ Operator | Type Constraint `ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` `ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseSequence` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` `RightShift` | `T={int32,int64,uint32,uint64}` `Rint` | `T={double,float}` diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 02215b5112d37f726604da2c2caa4f804388d6e5..058a1f2621c64a735bd9d9c9d0ae007f93aa4dea 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -60,9 +60,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, for (int i = 0; i < args->size(); ++i) { XlaCompiler::Argument& arg = (*args)[i]; arg.type = ctx->input_type(i); - - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + arg.shape = ctx->InputShape(i); if (arg.type == DT_RESOURCE) { return errors::InvalidArgument( @@ -136,7 +134,7 @@ Status GraphCompiler::Compile() { TF_RET_CHECK(src->id() < output_registry.size()); const NodeOutputs& src_outputs = output_registry[src->id()]; - tensor_inputs_[e->dst_input()] = src_outputs[e->src_output()]; + tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output()); } OpKernelContext op_context(¶ms, n->num_outputs()); diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index ba00160b6d78c1e55cc2e053cd5285344e0179fb..127562eb23d775f17179cc9ee968ec2255cf3a14 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -70,7 +70,7 @@ class GraphCompiler { private: // Partially sets params. This partially set params can be reused - // across multple nodes visit. + // across multiple nodes visit. void PartiallySetupParams(OpKernelContext::Params* params); // Tests if a node is a functional node. A functional node represents a diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 5e1b01878b74f2fbc2e84f8c2db1fa37c2c1eb0e..d2fa933cf9c085f92b2f442827a94d72938e4bb2 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -31,6 +31,8 @@ tf_kernel_library( "diag_op.cc", "dynamic_stitch_op.cc", "elu_op.cc", + "extract_image_patches_op.cc", + "fake_quantize_ops.cc", "fft_ops.cc", "fill_op.cc", "function_ops.cc", @@ -43,6 +45,9 @@ tf_kernel_library( "l2loss_op.cc", "lrn_ops.cc", "matmul_op.cc", + "matrix_band_part_op.cc", + "matrix_set_diag_op.cc", + "matrix_triangular_solve_op.cc", "mirror_pad_op.cc", "no_op.cc", "one_hot_op.cc", @@ -58,7 +63,9 @@ tf_kernel_library( "reshape_op.cc", "retval_op.cc", "reverse_op.cc", + "reverse_sequence_op.cc", "scan_ops.cc", + "scatter_nd_op.cc", "segment_reduction_ops.cc", "select_op.cc", "sendrecv_ops.cc", @@ -82,7 +89,6 @@ tf_kernel_library( "variable_ops.cc", ], hdrs = [ - "gather_op.h", "index_ops.h", "shape_util.h", ], @@ -92,11 +98,15 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:scatter", + "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", + "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index a015b8e0e8949f8aaa03a78b0f88b7ea8d6aaa1c..b0ba25b9983c3a9af26728ce4b1c263c844327db 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -28,8 +28,9 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = - BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), adj_x_, adj_y_); + auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), + /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, + /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); OP_REQUIRES_OK(ctx, result.status()); ctx->SetOutput(0, result.ValueOrDie()); } diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 87d858f763560be454c162e0cf40307c68217663..fe6651793dc763d13f4a4b0ac294ec3ecf64af8f 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -33,7 +33,7 @@ class CholeskyOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Cholesky"), CholeskyOp); +REGISTER_XLA_OP(Name("Cholesky").TypeConstraint("T", kFloatTypes), CholeskyOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2970eae20a3fb71f06619f476a49d41b22bca56 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -0,0 +1,169 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace { + +class ExtractImagePatchesOp : public XlaOpKernel { + public: + explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorFormat data_format = FORMAT_NHWC; + const int num_dims = ksizes_.size(); + + OP_REQUIRES( + ctx, num_dims >= 3, + errors::InvalidArgument("Kernel size must have at least 3 dimensions")); + const int num_spatial_dims = num_dims - 2; + + OP_REQUIRES(ctx, strides_.size() == num_dims, + errors::InvalidArgument("Sliding window strides field must " + "specify ", + num_dims, " dimensions")); + OP_REQUIRES(ctx, dilations_.size() == num_dims, + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims, " dimensions")); + + int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format); + OP_REQUIRES( + ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "kernel sizes > 1 in the batch and depth " + "dimensions.")); + OP_REQUIRES( + ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not support " + "dilations in the batch and depth dimensions.")); + + for (int i = 0; i < num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + OP_REQUIRES( + ctx, ksizes_[input_dim] >= 0, + errors::Unimplemented("Kernel size values must be non-negative; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + OP_REQUIRES(ctx, strides_[input_dim] >= 1, + errors::Unimplemented("Stride values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + OP_REQUIRES(ctx, dilations_[input_dim] >= 1, + errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + } + + xla::PrimitiveType type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type)); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES( + ctx, input_shape.dims() == num_dims, + errors::InvalidArgument("input must be ", num_dims, "-dimensional", + input_shape.DebugString())); + const int64 depth = input_shape.dim_size(feature_dim); + + xla::ComputationBuilder* builder = ctx->builder(); + + // The following code is equivalent to: + // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD]) + int64 kernel_size = 1; + std::vector lhs_shape(num_dims, 1); + for (int i = 0; i < num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + lhs_shape[i] = ksizes_[input_dim]; + kernel_size *= ksizes_[input_dim]; + } + lhs_shape[num_spatial_dims] = depth; + lhs_shape[num_spatial_dims + 1] = 1; + + // Builds an identity matrix as a broadcast equality of iotas. + // iota = np.arange(np.prod(ksize), depth) + // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) + xla::ComputationDataHandle iota; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, + kernel_size * depth, &iota)); + + auto lhs = builder->Reshape(iota, lhs_shape); + auto filter = builder->ConvertElementType( + builder->Eq(lhs, iota, {num_spatial_dims + 1}), type); + + xla::ConvolutionDimensionNumbers dims; + std::vector window_strides(num_spatial_dims); + std::vector lhs_dilation(num_spatial_dims, 1); + std::vector rhs_dilation(num_spatial_dims); + std::vector> padding(num_spatial_dims); + + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); + dims.set_kernel_input_feature_dimension(num_spatial_dims); + dims.set_kernel_output_feature_dimension(num_spatial_dims + 1); + + for (int i = 0; i < num_spatial_dims; ++i) { + const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + dims.add_input_spatial_dimensions(dim); + dims.add_kernel_spatial_dimensions(i); + dims.add_output_spatial_dimensions(dim); + window_strides[i] = strides_.at(dim); + rhs_dilation[i] = dilations_.at(dim); + + int64 unused_output_size; + OP_REQUIRES_OK( + ctx, GetWindowedOutputSizeVerboseV2( + input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i], + window_strides[i], padding_, &unused_output_size, + &padding[i].first, &padding[i].second)); + } + + xla::ComputationDataHandle conv = + builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, + padding, lhs_dilation, rhs_dilation, dims); + ctx->SetOutput(0, conv); + } + + protected: + std::vector ksizes_; + std::vector dilations_; + std::vector strides_; + Padding padding_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp); +}; + +REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..453a32c494b42e9922bc35fc526f3306530054fd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -0,0 +1,289 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace { + +// Gymnastics with nudged zero point is to ensure that the real zero maps to +// an integer, which is required for e.g. zero-padding in convolutional layers. +void CpuNudge(const float min, const float max, const float quant_min, + const float quant_max, float* nudged_min, float* nudged_max, + float* scale) { + *scale = (max - min) / (quant_max - quant_min); + + const float zero_point_from_min = quant_min - min / *scale; + float nudged_zero_point; + if (zero_point_from_min <= quant_min) { + nudged_zero_point = quant_min; + } else if (zero_point_from_min >= quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = std::round(zero_point_from_min); + } + + *nudged_min = (quant_min - nudged_zero_point) * (*scale); + *nudged_max = (quant_max - nudged_zero_point) * (*scale); +} + +// An XLA version of CpuNudge(). +void XlaNudge(xla::ComputationBuilder* b, const DataType data_type, + const xla::ComputationDataHandle& min, + const xla::ComputationDataHandle& max, + const float quant_min_value, const float quant_max_value, + xla::ComputationDataHandle* nudged_min, + xla::ComputationDataHandle* nudged_max, + xla::ComputationDataHandle* scale) { + *scale = b->Div(b->Sub(max, min), + XlaHelpers::FloatLiteral(b, data_type, + quant_max_value - quant_min_value)); + xla::ComputationDataHandle quant_min = + XlaHelpers::FloatLiteral(b, data_type, quant_min_value); + xla::ComputationDataHandle zero_point_from_min = + b->Sub(quant_min, b->Div(min, *scale)); + xla::ComputationDataHandle quant_max = + XlaHelpers::FloatLiteral(b, data_type, quant_max_value); + xla::ComputationDataHandle nudged_zero_point = + b->Select(b->Le(zero_point_from_min, quant_min), quant_min, + b->Select(b->Ge(zero_point_from_min, quant_max), quant_max, + b->Round(zero_point_from_min))); + *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale); + *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale); +} + +xla::ComputationDataHandle Quantize( + xla::ComputationBuilder* b, const xla::ComputationDataHandle& input, + const DataType data_type, + const xla::ComputationDataHandle& nudged_input_min, + const xla::ComputationDataHandle& nudged_input_max, + const xla::ComputationDataHandle& input_scale) { + xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); + xla::ComputationDataHandle inv_scale = b->Div(one, input_scale); + xla::ComputationDataHandle half = + XlaHelpers::FloatLiteral(b, data_type, 0.5f); + + xla::ComputationDataHandle clamped = + b->Clamp(nudged_input_min, input, nudged_input_max); + xla::ComputationDataHandle clamped_shifted = + b->Sub(clamped, nudged_input_min); + xla::ComputationDataHandle rounded = + b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half)); + return b->Add(b->Mul(rounded, input_scale), nudged_input_min); +} + +class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + quant_min_ = narrow_range ? 1 : 0; + quant_max_ = (1 << num_bits) - 1; + + float input_min, input_max; + OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max)); + CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_, + &nudged_input_max_, &input_scale_); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle input = ctx->Input(0); + const DataType data_type = ctx->input_type(0); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); + xla::ComputationDataHandle nudged_input_max = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); + xla::ComputationDataHandle input_scale = + XlaHelpers::FloatLiteral(b, data_type, input_scale_); + xla::ComputationDataHandle output = Quantize( + b, input, data_type, nudged_input_min, nudged_input_max, input_scale); + ctx->SetOutput(0, output); + } + + private: + float quant_min_; + float quant_max_; + float nudged_input_min_; + float nudged_input_max_; + float input_scale_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp); + +class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + const float quant_min = narrow_range ? 1 : 0; + const float quant_max = (1 << num_bits) - 1; + + float input_min, input_max, scale; + OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max)); + CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_, + &nudged_input_max_, &scale); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle gradient = ctx->Input(0); + const TensorShape gradient_shape = ctx->InputShape(0); + xla::ComputationDataHandle input = ctx->Input(1); + const DataType data_type = ctx->input_type(1); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); + xla::ComputationDataHandle nudged_input_max = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); + + xla::ComputationDataHandle between_nudged_min_max = + b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); + xla::ComputationDataHandle zeroes = b->Broadcast( + XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes()); + xla::ComputationDataHandle output = + b->Select(between_nudged_min_max, gradient, zeroes); + ctx->SetOutput(0, output); + } + + private: + float nudged_input_min_; + float nudged_input_max_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"), + FakeQuantWithMinMaxArgsGradOp); + +class FakeQuantWithMinMaxVarsOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + quant_min_ = narrow_range ? 1 : 0; + quant_max_ = (1 << num_bits) - 1; + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle input = ctx->Input(0); + const DataType data_type = ctx->input_type(0); + xla::ComputationDataHandle input_min = ctx->Input(1); + xla::ComputationDataHandle input_max = ctx->Input(2); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, + &nudged_input_min, &nudged_input_max, &input_scale); + + xla::ComputationDataHandle output = Quantize( + b, input, data_type, nudged_input_min, nudged_input_max, input_scale); + ctx->SetOutput(0, output); + } + + private: + float quant_min_; + float quant_max_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp); + +class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + quant_min_ = narrow_range ? 1 : 0; + quant_max_ = (1 << num_bits) - 1; + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle gradient = ctx->Input(0); + const TensorShape gradient_shape = ctx->InputShape(0); + xla::ComputationDataHandle input = ctx->Input(1); + const DataType data_type = ctx->input_type(1); + xla::ComputationDataHandle input_min = ctx->Input(2); + xla::ComputationDataHandle input_max = ctx->Input(3); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, + &nudged_input_min, &nudged_input_max, &input_scale); + + xla::ComputationDataHandle between_nudged_min_max = + b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type); + xla::ComputationDataHandle zeroes = + b->Broadcast(zero, gradient_shape.dim_sizes()); + xla::ComputationDataHandle output0 = + b->Select(between_nudged_min_max, gradient, zeroes); + ctx->SetOutput(0, output0); + + xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min); + xla::ComputationDataHandle output1 = + b->ReduceAll(b->Select(below_min, gradient, zeroes), zero, + *ctx->GetOrCreateAdd(data_type)); + ctx->SetOutput(1, output1); + + xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max); + xla::ComputationDataHandle output2 = + b->ReduceAll(b->Select(above_max, gradient, zeroes), zero, + *ctx->GetOrCreateAdd(data_type)); + ctx->SetOutput(2, output2); + } + + private: + float quant_min_; + float quant_max_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"), + FakeQuantWithMinMaxVarsGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index ffed38249416766850ba10f1069e706570b995fe..7945c05af40df21a798a2cff51fe7f8e935793f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/kernels/gather_op.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -26,25 +26,38 @@ limitations under the License. namespace tensorflow { -xla::ComputationDataHandle XlaComputeGatherDynamicSlice( - XlaOpKernelContext* context, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 axis, DataType dtype, - DataType index_type, xla::ComputationBuilder* builder) { +Status XlaGather(const xla::ComputationDataHandle& input, + const TensorShape& input_shape, + const xla::ComputationDataHandle& indices, + TensorShape indices_shape, int64 axis, bool indices_are_nd, + DataType dtype, DataType index_type, + xla::ComputationBuilder* builder, + xla::ComputationDataHandle* gather_output) { + // If the indices are N-dimensional, then the minor dimension of indices + // should be of size N and correspond to the N indices. + int64 num_index_dims = 1; + if (indices_are_nd) { + CHECK_GE(indices_shape.dims(), 1); + num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); + indices_shape.RemoveLastDims(1); + } + // Although the indices Tensor is flattened into rank 1 during the lookup, // and each scalar entry is used as an index into the first dimension of the // input, the output is returned with shape: // input.shape[:axis] + indices.shape + input.shape[axis+1:] - const int num_indices = indices_shape.num_elements(); + + const int64 num_indices = indices_shape.num_elements(); TensorShape input_shape_pre_axis(input_shape); input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims()); TensorShape input_shape_post_axis(input_shape); - input_shape_post_axis.RemoveDimRange(0, axis + 1); - + input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); // Each slice of the input tensor has shape: - // [, 1, ] + // [, 1, ..., 1, ] TensorShape slice_shape(input_shape); - slice_shape.set_dim(axis, 1); + for (int64 i = 0; i < num_index_dims; ++i) { + slice_shape.set_dim(axis + i, 1); + } TensorShape loop_out_shape; loop_out_shape.AppendShape(input_shape_pre_axis); @@ -62,131 +75,176 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Degenerate case: empty indices. if (num_indices == 0) { - return builder->Broadcast(XlaHelpers::Zero(builder, dtype), - out_shape.dim_sizes()); + *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype), + out_shape.dim_sizes()); + return Status::OK(); + } + + for (int64 i = 0; i < num_index_dims; ++i) { + if (input_shape.dim_size(axis + i) == 0) { + return errors::InvalidArgument("Gather dimension ", axis + i, + " is of size zero in tensor with shape ", + input_shape.DebugString()); + } + } + + // Flatten the major dimensions of indices into a single dimension for ease of + // iteration. If there is an axis dimension, we must leave it alone. + std::vector flat_indices_shape = {num_indices}; + if (indices_are_nd) { + flat_indices_shape.push_back(num_index_dims); } // Specify the shape of the loop-carried Tensor tuple. - xla::PrimitiveType ptype; - TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); - xla::PrimitiveType idxtype; - TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &idxtype)); - std::vector tuple_shapes( - {// The iteration counter i is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(idxtype, {}), - // The input array has shape input_shape. Loop invariant. - xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()), - // The gather indices are reshaped to rank 1. Loop invariant. - xla::ShapeUtil::MakeShape(idxtype, {num_indices}), - // The output array, which is updated on each loop iteration. - xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())}); - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); // Construct the initial values of the loop-carried Tensors. - auto init_i = XlaHelpers::Zero(builder, index_type); + auto flat_indices = builder->Reshape(indices, flat_indices_shape); auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), loop_out_shape.dim_sizes()); - // Flatten the indices into 1-D for ease of iteration. - auto indices_1d = builder->Reshape(indices, {num_indices}); - auto init = builder->Tuple({init_i, input, indices_1d, init_out}); - - // Construct the while loop condition (i < num_indices) - xla::ComputationBuilder condb(context->builder()->client(), - "GatherWhileCond"); - condb.Lt(condb.GetTupleElement( - condb.Parameter(0, tuple_shape, "GatherWhileTuple"), 0), - XlaHelpers::IntegerLiteral(&condb, index_type, num_indices)); - auto cond_status = condb.Build(); - auto cond = cond_status.ConsumeValueOrDie(); + auto init = {input, flat_indices, init_out}; // Construct the while loop body's function. The implementation of gather is: // for i in range(num_indices): // index = dynamic-slice(indices, i) // xi = dynamic-slice(input, index) // output = dynamic-update-slice(output, xi, i) - xla::ComputationBuilder bodyb(context->builder()->client(), - "GatherWhileBody"); - { - // The four loop carried values. - auto loop_tuple = bodyb.Parameter(0, tuple_shape, "GatherWhileTuple"); - auto i = bodyb.GetTupleElement(loop_tuple, 0); - auto input = bodyb.GetTupleElement(loop_tuple, 1); - auto indices = bodyb.GetTupleElement(loop_tuple, 2); - auto output = bodyb.GetTupleElement(loop_tuple, 3); - - // Slice from the input array. - auto index = bodyb.DynamicSlice(indices, bodyb.Reshape(i, {1}), {1}); - auto start_indices = bodyb.Pad( - bodyb.Reshape(index, {1}), XlaHelpers::Zero(&bodyb, index_type), + auto body_fn = [&](xla::ComputationDataHandle i, + gtl::ArraySlice loop_vars, + xla::ComputationBuilder* bodyb) { + auto input = loop_vars[0]; + auto indices = loop_vars[1]; + auto output = loop_vars[2]; + + auto zero_index = XlaHelpers::Zero(bodyb, index_type); + + // Slice the i-th index from the indices array. + xla::ComputationDataHandle index; + auto indices_offset = bodyb->Reshape(i, {1}); + if (indices_are_nd) { + // Slice out the entire nd index, if applicable. + indices_offset = bodyb->Pad(indices_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, 1}})); + index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims}); + index = bodyb->Collapse(index, {0, 1}); + } else { + index = bodyb->DynamicSlice(indices, indices_offset, {1}); + } + + // Slice the corresponding data from the input array. + auto start_indices = bodyb->Pad( + index, zero_index, xla::MakeEdgePaddingConfig( {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}})); - auto slice_i = bodyb.Reshape( - bodyb.DynamicSlice(input, start_indices, slice_shape.dim_sizes()), + auto slice_i = bodyb->Reshape( + bodyb->DynamicSlice(input, start_indices, slice_shape.dim_sizes()), loop_out_slice_shape.dim_sizes()); // Construct the index into the output Tensor 0, ..., , 0, ... std::vector out_index_vals( - loop_out_shape.dims(), - bodyb.Reshape(XlaHelpers::Zero(&bodyb, index_type), {1})); - out_index_vals[input_shape_pre_axis.dims()] = bodyb.Reshape(i, {1}); - auto out_index = bodyb.ConcatInDim(out_index_vals, 0); + loop_out_shape.dims(), bodyb->Reshape(zero_index, {1})); + out_index_vals[input_shape_pre_axis.dims()] = bodyb->Reshape(i, {1}); + auto out_index = bodyb->ConcatInDim(out_index_vals, 0); // Update the output Tensor - auto updated_output = bodyb.DynamicUpdateSlice(output, slice_i, out_index); + auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index); - bodyb.Tuple({bodyb.Add(i, XlaHelpers::One(&bodyb, index_type)), input, - indices, updated_output}); - } - auto body_status = bodyb.Build(); - auto body = body_status.ConsumeValueOrDie(); + return std::vector{input, indices, + updated_output}; + }; // Construct the While loop, extract and reshape the output. - auto gather_while = builder->While(cond, body, init); - auto gather_output = builder->GetTupleElement(gather_while, 3); - return builder->Reshape(gather_output, out_shape.dim_sizes()); + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype)); + TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn, + init, "gather", builder)); + *gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes()); + return Status::OK(); } -GatherOpDynamicSlice::GatherOpDynamicSlice(OpKernelConstruction* context) - : XlaOpKernel(context) {} - -void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) { - xla::ComputationBuilder* builder = context->builder(); - auto input = context->Input(0); - auto input_shape = context->InputShape(0); - auto indices = context->Input(1); - auto indices_shape = context->InputShape(1); - int64 axis = 0; - if (context->num_inputs() == 3) { - const TensorShape axis_shape = context->InputShape(2); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), - errors::InvalidArgument("axis must be scalar")); - DataType axis_type = input_type(2); - OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, - errors::InvalidArgument("axis must be int32 or int64")); - - OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); - const auto params_dims = input_shape.dims(); - if (axis < 0) { - axis += params_dims; +class GatherOp : public XlaOpKernel { + public: + explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::ComputationBuilder* builder = context->builder(); + auto input = context->Input(0); + auto input_shape = context->InputShape(0); + auto indices = context->Input(1); + auto indices_shape = context->InputShape(1); + int64 axis = 0; + if (context->num_inputs() == 3) { + const TensorShape axis_shape = context->InputShape(2); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), + errors::InvalidArgument("axis must be scalar")); + DataType axis_type = input_type(2); + OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, + errors::InvalidArgument("axis must be int32 or int64")); + + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); + const auto params_dims = input_shape.dims(); + if (axis < 0) { + axis += params_dims; + } + OP_REQUIRES( + context, 0 <= axis && axis < params_dims, + errors::InvalidArgument("Expected axis in the range [", -params_dims, + ", ", params_dims, "), but got ", axis)); } - OP_REQUIRES( - context, 0 <= axis && axis < params_dims, - errors::InvalidArgument("Expected axis in the range [", -params_dims, - ", ", params_dims, "), but got ", axis)); - } - DataType index_type = input_type(1); - OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("indices must be int32 or int64")); + DataType index_type = input_type(1); + OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, + errors::InvalidArgument("indices must be int32 or int64")); - xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - context, input, input_shape, indices, indices_shape, axis, input_type(0), - index_type, builder); - context->SetOutput(0, gather); -} + xla::ComputationDataHandle gather; + OP_REQUIRES_OK( + context, XlaGather(input, input_shape, indices, indices_shape, axis, + /*indices_are_nd=*/false, input_type(0), index_type, + builder, &gather)); + context->SetOutput(0, gather); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(GatherOp); +}; + +REGISTER_XLA_OP(Name("Gather"), GatherOp); +REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), GatherOp); + +class GatherNdOp : public XlaOpKernel { + public: + explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + DataType params_type = context->input_type(0); + DataType indices_type = context->input_type(1); + + TensorShape params_shape = context->InputShape(0); + TensorShape indices_shape = context->InputShape(1); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape), + errors::InvalidArgument("params must be at least a vector")); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape), + errors::InvalidArgument("indices must be at least a vector")); + const int64 num_index_dims = + indices_shape.dim_size(indices_shape.dims() - 1); + OP_REQUIRES( + context, num_index_dims <= params_shape.dims(), + errors::InvalidArgument( + "index innermost dimension length must be <= params rank; saw: ", + indices_shape.dim_size(indices_shape.dims() - 1), " vs. ", + params_shape.dims())); + + xla::ComputationBuilder* builder = context->builder(); + auto params = context->Input(0); + auto indices = context->Input(1); + xla::ComputationDataHandle gather; + OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices, + indices_shape, /*axis=*/0, + /*indices_are_nd=*/true, params_type, + indices_type, builder, &gather)); + context->SetOutput(0, gather); + } +}; -REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice); -REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), - GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.h b/tensorflow/compiler/tf2xla/kernels/gather_op.h deleted file mode 100644 index df86e1fcdd1a4860ed7ee0c5017d25ccf9d227ea..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Declaration of the Gather Op using the XLA dynamic slice implementation. - -#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_H_ -#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_H_ - -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/util/bcast.h" - -namespace tensorflow { - -class GatherOpDynamicSlice : public XlaOpKernel { - public: - explicit GatherOpDynamicSlice(OpKernelConstruction* context); - - void Compile(XlaOpKernelContext* context) override; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(GatherOpDynamicSlice); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index 2c80395c56d73adad7dc1679ba6423fbe103605a..bd8b92c22d71fe89ab8951ec79f411feef6505e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -30,11 +30,16 @@ namespace tensorflow { // shape input_shape) keyed on indices (of shape indices_shape). // // index_type must be must be DT_INT32 or DT_INT64. -xla::ComputationDataHandle XlaComputeGatherDynamicSlice( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 axis, DataType dtype, - DataType index_type, xla::ComputationBuilder* builder); +// If `indices_are_nd` is true, the last dimension of `indices` are treated as +// a multidimensional index values. Otherwise, `indices` is treated as a tensor +// of scalar indices. +Status XlaGather(const xla::ComputationDataHandle& input, + const TensorShape& input_shape, + const xla::ComputationDataHandle& indices, + TensorShape indices_shape, int64 axis, bool indices_are_nd, + DataType dtype, DataType index_type, + xla::ComputationBuilder* builder, + xla::ComputationDataHandle* gather_output); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index d2b1f7913ecc9113284827b53de8fb0e5b711322..39af662b638cb9d723118e58fcfc983633fed497 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -40,6 +40,7 @@ REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); +REGISTER_XLA_OP(Name("Snapshot"), IdentityOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..faa415a97b053b4b11d015fefcd430210b98118a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class MatrixBandPartOp : public XlaOpKernel { + public: + explicit MatrixBandPartOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input_shape.DebugString())); + + const TensorShape num_lower_in_shape = context->InputShape(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in_shape), + errors::InvalidArgument("num_lower must be scalar, got shape ", + num_lower_in_shape.DebugString())); + + const TensorShape num_upper_in_shape = context->InputShape(2); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in_shape), + errors::InvalidArgument("num_upper must be scalar, got shape ", + num_upper_in_shape.DebugString())); + + xla::ComputationBuilder* builder = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle num_lower = context->Input(1); + xla::ComputationDataHandle num_upper = context->Input(2); + DataType input_type = context->input_type(0); + DataType index_type = context->input_type(1); + + TensorShape batch_shape = input_shape; + batch_shape.RemoveLastDims(2); + const int64 m = input_shape.dim_size(input_shape.dims() - 2); + const int64 n = input_shape.dim_size(input_shape.dims() - 1); + + // Compute 'offset', which is how many diagonals we are above/below the + // diagonal. + xla::ComputationDataHandle iota_m; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); + + xla::ComputationDataHandle iota_n; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); + + auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m, + /*broadcast_dimensions=*/{0}); + + // If num_lower or num_upper are negative, include all lower/upper + // diagonals. + auto zero_index = XlaHelpers::Zero(builder, index_type); + num_lower = builder->Select( + builder->Lt(num_lower, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower); + num_upper = builder->Select( + builder->Lt(num_upper, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper); + + auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset), + builder->Le(offset, num_upper)); + indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + + auto zero_input = XlaHelpers::Zero(builder, input_type); + auto output = builder->Select( + indicator, input, + builder->Broadcast(zero_input, input_shape.dim_sizes())); + + context->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp); +}; +REGISTER_XLA_OP(Name("MatrixBandPart"), MatrixBandPartOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2940bdcff75a087c914fdad0cb2426276e41aff --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +class MatrixSetDiagOp : public XlaOpKernel { + public: + explicit MatrixSetDiagOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const TensorShape diag_shape = context->InputShape(1); + + const int rank = input_shape.dims(); + + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input_shape.DebugString())); + + // Check to make sure the last dimension of diag is equal to the smaller of + // the last two dimensions of input. + const int64 m = input_shape.dim_size(rank - 2); + const int64 n = input_shape.dim_size(rank - 1); + const int64 min_dim = std::min(m, n); + + TensorShape batch_shape = input_shape; + batch_shape.RemoveLastDims(2); + + TensorShape expected_diag_shape = batch_shape; + expected_diag_shape.AddDim(min_dim); + OP_REQUIRES(context, expected_diag_shape == diag_shape, + errors::InvalidArgument( + "must have diagonal.shape == input.shape[:-2] + " + "min(input.shape[-2:]), but received input shape: ", + input_shape.DebugString(), + " and diagonal shape: ", diag_shape.DebugString())); + + xla::ComputationBuilder* builder = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle diag = context->Input(1); + + auto zero = XlaHelpers::Zero(builder, context->input_type(0)); + + // Create an indicator tensor that is true only on the diagonal. + xla::ComputationDataHandle iota_m; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); + xla::ComputationDataHandle iota_n; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); + auto indicator = builder->Eq(iota_m, + builder->Broadcast(iota_n, {m}), + /*broadcast_dimensions=*/{0}); + indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + + // Broadcast diag up to the input shape. Use an implicit broadcast (Add) + // because we need to broadcast on the right. + std::vector diag_broadcast_dims(rank - 1); + std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0); + if (min_dim != m) { + diag_broadcast_dims.back() = rank - 1; + } + diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()), + /*broadcast_dimensions=*/diag_broadcast_dims); + + auto output = builder->Select(indicator, diag, input); + context->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp); +}; + +REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eaed93146460de5a6e8328432302cc75bf36a534 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class MatrixTriangularSolveOp : public XlaOpKernel { + public: + explicit MatrixTriangularSolveOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint", &adjoint_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + auto result = TriangularSolve( + ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true, + /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); + if (!result.ok()) { + ctx->SetStatus(result.status()); + return; + } + ctx->SetOutput(0, result.ValueOrDie()); + } + + private: + bool lower_; + bool adjoint_; +}; + +REGISTER_XLA_OP(Name("MatrixTriangularSolve"), MatrixTriangularSolveOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 0b5a38967aeb5b4cd66de5220e2c764371440c2d..d4fb5dd4e06c7c70591262c0d63a91c383a2a6e0 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -37,21 +37,23 @@ class PoolingOp : public XlaOpKernel { public: PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims) : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { - std::vector ksize_int; - std::vector stride_int; - OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); - OP_REQUIRES(ctx, ksize_int.size() == num_dims(), - errors::InvalidArgument("Sliding window ksize field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int)); - OP_REQUIRES(ctx, stride_int.size() == num_dims(), - errors::InvalidArgument("Sliding window stride field must " - "specify ", - num_dims(), " dimensions")); - for (int i = 0; i < num_dims(); ++i) { - ksize_.push_back(ksize_int[i]); - stride_.push_back(stride_int[i]); + if (ctx->num_inputs() == 1) { + std::vector ksize_int; + std::vector stride_int; + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); + OP_REQUIRES(ctx, ksize_int.size() == num_dims(), + errors::InvalidArgument("Sliding window ksize field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int)); + OP_REQUIRES(ctx, stride_int.size() == num_dims(), + errors::InvalidArgument("Sliding window stride field must " + "specify ", + num_dims(), " dimensions")); + for (int i = 0; i < num_dims(); ++i) { + ksize_.push_back(ksize_int[i]); + stride_.push_back(stride_int[i]); + } } Padding padding; OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); @@ -77,6 +79,33 @@ class PoolingOp : public XlaOpKernel { xla::ComputationDataHandle input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); + std::vector ksize = ksize_; + std::vector stride = stride_; + if (ctx->num_inputs() != 1) { + const TensorShape ksize_shape = ctx->InputShape(1); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), + errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString())); + OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), + errors::InvalidArgument("Sliding window ksize field must " + "specify ", + num_dims(), " dimensions")); + ksize.clear(); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); + + const TensorShape stride_shape = ctx->InputShape(2); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), + errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString())); + OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(), + errors::InvalidArgument("Sliding window stride field must " + "specify ", + num_dims(), " dimensions")); + stride.clear(); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); + } OP_REQUIRES(ctx, input_shape.dims() == num_dims(), errors::InvalidArgument("Input to ", type_string(), " operator must have ", num_dims(), @@ -84,8 +113,8 @@ class PoolingOp : public XlaOpKernel { const DataType type = input_type(0); xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow( - input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize_, - stride_, padding_); + input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize, + stride, padding_); ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape)); } @@ -130,6 +159,10 @@ class MaxPool2DOp : public MaxPoolOp { } }; REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); +REGISTER_XLA_OP(Name("MaxPoolV2") + .CompileTimeConstInput("ksize") + .CompileTimeConstInput("strides"), + MaxPool2DOp); class MaxPool3DOp : public MaxPoolOp { public: @@ -243,22 +276,44 @@ class MaxPoolGradOp : public XlaOpKernel { public: MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims) : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + if (ctx->num_inputs() == 3) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + int num_dims() const { return num_spatial_dims_ + 2; } + + void Compile(XlaOpKernelContext* ctx) override { + if (ctx->num_inputs() != 3) { + OP_REQUIRES( + ctx, ctx->num_inputs() == 5, + errors::InvalidArgument("Must supply ksize and stride arguments.")); + const TensorShape ksize_shape = ctx->InputShape(3); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), + errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_)); + + const TensorShape stride_shape = ctx->InputShape(4); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), + errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_)); + } + OP_REQUIRES(ctx, ksize_.size() == num_dims(), errors::InvalidArgument("Sliding window ksize field must " "specify ", num_dims(), " dimensions")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); OP_REQUIRES(ctx, stride_.size() == num_dims(), errors::InvalidArgument("Sliding window strides field must " "specify ", num_dims(), " dimensions")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - } - int num_dims() const { return num_spatial_dims_ + 2; } - - void Compile(XlaOpKernelContext* ctx) override { const TensorShape tensor_in_shape = ctx->InputShape(0); const TensorShape tensor_out_shape = ctx->InputShape(1); const TensorShape out_backprop_shape = ctx->InputShape(2); @@ -315,6 +370,10 @@ class MaxPool2DGradOp : public MaxPoolGradOp { } }; REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp); +REGISTER_XLA_OP(Name("MaxPoolGradV2") + .CompileTimeConstInput("ksize") + .CompileTimeConstInput("strides"), + MaxPool2DGradOp); class MaxPool3DGradOp : public MaxPoolGradOp { public: diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bc5d3adb091cd238974c5b69b7a2f8fe639cc68 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class ReverseSequenceOp : public XlaOpKernel { + public: + explicit ReverseSequenceOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_)); + OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_)); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const TensorShape seq_lens_shape = context->InputShape(1); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape), + errors::InvalidArgument("seq_lens input must be 1-dim, not ", + seq_lens_shape.dims())); + OP_REQUIRES(context, batch_dim_ != seq_dim_, + errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_)); + OP_REQUIRES( + context, seq_dim_ < input_shape.dims(), + errors::InvalidArgument("seq_dim must be < input.dims()", "( ", + seq_dim_, " vs. ", input_shape.dims(), ")")); + OP_REQUIRES( + context, batch_dim_ < input_shape.dims(), + errors::InvalidArgument("batch_dim must be < input.dims()", "( ", + batch_dim_, " vs. ", input_shape.dims(), ")")); + OP_REQUIRES( + context, + seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_), + errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_, + "), ", "(", seq_lens_shape.num_elements(), + " vs. ", input_shape.dim_size(batch_dim_))); + + xla::ComputationBuilder* builder = context->builder(); + const auto input = context->Input(0); + const auto seq_lens = context->Input(1); + + const int64 batch_size = input_shape.dim_size(batch_dim_); + + const DataType input_type = context->input_type(0); + const DataType seq_lens_type = context->input_type(1); + const int64 max_seq_len = input_shape.dim_size(seq_dim_); + + xla::Shape input_xla_shape; + OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape, + &input_xla_shape)); + xla::Shape seq_lens_xla_shape; + OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape, + &seq_lens_xla_shape)); + + const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({ + xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}), + seq_lens_xla_shape, + input_xla_shape, + }); + + // For each entry in the batch, reverse the sequence. + // TODO(b/65689298): generalize the Map() operator to non-scalar cases and + // use it here, instead of a While loop. + + // Condition: lambda (i, _, _): i < batch_size + auto condition_builder = + builder->CreateSubBuilder("reverse_sequence_condition"); + { + auto param = condition_builder->Parameter(0, tuple_shape, "param"); + auto i = condition_builder->GetTupleElement(param, 0); + condition_builder->Lt( + i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type, + batch_size)); + } + auto condition = condition_builder->Build(); + OP_REQUIRES_OK(context, condition.status()); + + auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); + { + auto param = body_builder->Parameter(0, tuple_shape, "param"); + auto i = body_builder->GetTupleElement(param, 0); + auto seq_lens = body_builder->GetTupleElement(param, 1); + auto output = body_builder->GetTupleElement(param, 2); + + // seq_len is the sequence length of the current batch element (rank 1) + auto seq_len = body_builder->DynamicSlice( + seq_lens, body_builder->Reshape(i, {1}), {1}); + + // Indices is the offset of the batch element in the input. + auto indices = body_builder->Broadcast( + XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {input_shape.dims()}); + indices = body_builder->DynamicUpdateSlice( + indices, body_builder->Reshape(i, {1}), + body_builder->Reshape( + XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + batch_dim_), + {1})); + + // slice_indices is the offset of the start of the reversed sequence in + // the input. + auto slice_indices = body_builder->DynamicUpdateSlice( + indices, + body_builder->Sub(XlaHelpers::IntegerLiteral( + body_builder.get(), seq_lens_type, max_seq_len), + seq_len), + body_builder->Reshape( + XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + seq_dim_), + {1})); + + // Slice out the reversed sequence. The slice will overflow the end of the + // sequence, and the contents of the overflow are implementation-defined. + // However, we will mask off these elements and replace them with elements + // from the original input so their values do not matter. + TensorShape slice_shape = input_shape; + slice_shape.set_dim(batch_dim_, 1); + auto slice = body_builder->DynamicSlice(output, slice_indices, + slice_shape.dim_sizes()); + + // Shift the reversed sequence to the left. + output = body_builder->DynamicUpdateSlice(output, slice, indices); + + body_builder->Tuple( + {body_builder->Add( + i, XlaHelpers::One(body_builder.get(), seq_lens_type)), + seq_lens, output}); + } + auto body = body_builder->Build(); + OP_REQUIRES_OK(context, body.status()); + + auto loop_output = builder->While( + condition.ValueOrDie(), body.ValueOrDie(), + builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens, + builder->Rev(input, {seq_dim_})})); + auto output = builder->GetTupleElement(loop_output, 2); + + // Mask out elements after the sequence length. + xla::ComputationDataHandle iota; + OP_REQUIRES_OK( + context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); + std::vector dims(input_shape.dims(), 1); + dims[batch_dim_] = batch_size; + auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_}); + + // Broadcast the mask up to the input shape. + mask = + builder->Or(mask, builder->Broadcast(builder->ConstantR0(false), + input_shape.dim_sizes())); + + output = builder->Select(mask, output, input); + context->SetOutput(0, output); + } + + private: + int32 batch_dim_; + int32 seq_dim_; +}; + +REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8433a29c4e203cac726ee6bf7f67a863447326ed --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +// Check whether updates.shape = indices.shape[:batch_dim] + +// buffer_shape[num_index_dims:] +Status ValidateUpdateShape(const TensorShape& buffer_shape, + const TensorShape& indices_shape, + const TensorShape& updates_shape) { + if (indices_shape.dims() < 1) { + return errors::InvalidArgument( + "indices shape must have >= 1 dimension; got ", + indices_shape.DebugString()); + } + + const int64 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); + const int64 batch_dim = indices_shape.dims() - 1; + + auto shape_err = [&]() { + return errors::InvalidArgument( + "Must have updates.shape = indices.shape[:batch_dim] + ", + "buffer_shape[num_index_dims:], got updates.shape: ", + updates_shape.DebugString(), + ", indices.shape: ", indices_shape.DebugString(), + ", buffer_shape: ", buffer_shape.DebugString(), + ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim); + }; + + if (updates_shape.dims() < batch_dim) return shape_err(); + if (buffer_shape.dims() < + num_index_dims + (updates_shape.dims() - batch_dim)) { + return shape_err(); + } + if (updates_shape.dims() != + batch_dim + buffer_shape.dims() - num_index_dims) { + return shape_err(); + } + for (int d = 0; d < batch_dim; ++d) { + if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) { + return shape_err(); + } + } + for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) { + if (updates_shape.dim_size(d + batch_dim) != + buffer_shape.dim_size(d + num_index_dims)) { + return shape_err(); + } + } + return Status::OK(); +} + +class ScatterNdOp : public XlaOpKernel { + public: + explicit ScatterNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + DataType dtype = context->input_type(1); + + TensorShape indices_shape = context->InputShape(0); + TensorShape updates_shape = context->InputShape(1); + + TensorShape buffer_shape; + OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape)); + + OP_REQUIRES( + context, TensorShapeUtils::IsVectorOrHigher(buffer_shape), + errors::InvalidArgument("Output must be at least 1-D, ", + "got shape: ", buffer_shape.DebugString())); + + OP_REQUIRES( + context, + buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 && + updates_shape.num_elements() == 0), + errors::InvalidArgument( + "Indices and updates specified for empty output. indices shape: ", + indices_shape.DebugString())); + + OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape, + updates_shape)); + + xla::ComputationBuilder* builder = context->builder(); + auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype), + buffer_shape.dim_sizes()); + auto indices = context->Input(0); + auto updates = context->Input(1); + auto result = + XlaScatter(buffer, updates, indices, + /*indices_are_vectors=*/true, /*combiner=*/{}, builder); + OP_REQUIRES_OK(context, result.status()); + context->SetOutput(0, result.ValueOrDie()); + } +}; + +REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstInput("shape"), ScatterNdOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/scatter_op_helpers.h deleted file mode 100644 index a5ab7de17adb734014fe2dcbd60ae5c219c8e486..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/scatter_op_helpers.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Helper methods for XLA Scatter Ops. -#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_ -#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_ - -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/util/bcast.h" - -namespace tensorflow { - -// Adds to builder an XLA computation that performs a scatter-add of input (of -// shape input_shape) keyed on indices (of shape indices_shape). The shape -// of the Tensor returned by this is num_segments input_shape[indices.dims():] -// -static xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 num_segments, DataType dtype, - xla::ComputationBuilder* builder); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index c220edd588071ef262621784015d34cd475b2918..80d6df6c48b0141734dcee1c2a3c413926931feb 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,125 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/types.h" namespace tensorflow { - -xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 num_segments, DataType dtype, - xla::ComputationBuilder* builder) { - // Flatten data for dynamic indexing via indices_1d. - TensorShape input_shape_i(input_shape); - for (int64 d = 0; d < indices_shape.dims(); ++d) { - input_shape_i.RemoveDim(0); - } - TensorShape flat_shape({indices_shape.num_elements()}); - flat_shape.AppendShape(input_shape_i); - - // output is same as flattened input shape with dim_size(0) = num_segments. - TensorShape out_shape(flat_shape); - out_shape.set_dim(0, num_segments); - - // Slices from the input data are same shape as the input data, except dim 0. - TensorShape slice_shape(flat_shape); - slice_shape.set_dim(0, 1); - TensorShape loop_out_slice_shape(out_shape); - loop_out_slice_shape.set_dim(0, 1); - - // Construct the initial values of the loop-carried variables - // Flatten the indices into 1-D for ease of iteration. - auto indices_1d = builder->Reshape(indices, {indices_shape.num_elements()}); - // Flatten the data for ease of indexing via values in indices_1d. - auto data_flat = builder->Reshape(input, flat_shape.dim_sizes()); - - auto init_i = builder->ConstantR0(0); - auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - out_shape.dim_sizes()); - - xla::PrimitiveType ptype; - TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); - - std::vector tuple_shapes( - {// The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The flattened input data is loop invariant. - xla::ShapeUtil::MakeShape(ptype, flat_shape.dim_sizes()), - // The scatter indices tensor is loop invariant. - xla::ShapeUtil::MakeShape(xla::S32, {indices_shape.num_elements()}), - // The output data array is updated each loop iteration. - xla::ShapeUtil::MakeShape(ptype, out_shape.dim_sizes())}); - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - - auto init = builder->Tuple({init_i, data_flat, indices_1d, init_out}); - - // Construct the while loop condition (i < num_indices) - xla::ComputationBuilder condb(ctx->builder()->client(), - "ScatterAddWhileCond"); - condb.Lt(condb.GetTupleElement( - condb.Parameter(0, tuple_shape, "ScatterAddWhileTuple"), 0), - condb.ConstantR0(indices_shape.num_elements())); - auto cond_status = condb.Build(); - auto cond = cond_status.ConsumeValueOrDie(); - - // Construct the while loop body's function. The implementation of scatter is: - // for i in range(num_indices): - // index = dynamic-slice(indices, i) - // xi = dynamic-slice(input, i) - // output = dynamic-update-slice(output, xi, index) - xla::ComputationBuilder bodyb(ctx->builder()->client(), - "ScatterAddWhileBody"); - { - auto input_tuple = bodyb.Parameter(0, tuple_shape, "ScatterAddWhileTuple"); - auto i = bodyb.GetTupleElement(input_tuple, 0); - auto data = bodyb.GetTupleElement(input_tuple, 1); - auto idcs = bodyb.GetTupleElement(input_tuple, 2); - auto output = bodyb.GetTupleElement(input_tuple, 3); - - // Index into the data array at i. - auto zero = bodyb.ConstantR1({0}); - std::vector index_vals(flat_shape.dims(), zero); - index_vals[0] = bodyb.Reshape(i, {1}); - auto index = bodyb.ConcatInDim(index_vals, 0); - - auto data_slice = - bodyb.Reshape(bodyb.DynamicSlice(data, index, slice_shape.dim_sizes()), - loop_out_slice_shape.dim_sizes()); - - // Index into the output array. - std::vector out_index_vals(out_shape.dims(), - zero); - out_index_vals[0] = bodyb.DynamicSlice(idcs, bodyb.Reshape(i, {1}), {1}); - auto out_index = bodyb.ConcatInDim(out_index_vals, 0); - - // Slice the output array, update value, and update the output slice. - auto updated_output = bodyb.DynamicUpdateSlice( - output, - bodyb.Add(data_slice, - bodyb.DynamicSlice(output, out_index, - loop_out_slice_shape.dim_sizes())), - out_index); - - auto ip1 = bodyb.Add(i, bodyb.ConstantR0(1)); - bodyb.Tuple({ip1, data, idcs, updated_output}); - } - auto body_status = bodyb.Build(); - auto body = body_status.ConsumeValueOrDie(); - - auto gather_while = builder->While(cond, body, init); - return builder->GetTupleElement(gather_while, 3); -} - namespace { class UnsortedSegmentSum : public XlaOpKernel { @@ -153,10 +41,10 @@ class UnsortedSegmentSum : public XlaOpKernel { // as data with the first indices.rank dimensions are replaced // by a single dimension with size num_segments. auto data = ctx->Input(0); - auto data_shape = ctx->InputShape(0); + TensorShape data_shape = ctx->InputShape(0); auto indices = ctx->Input(1); - auto indices_shape = ctx->InputShape(1); + TensorShape indices_shape = ctx->InputShape(1); int64 num_segments; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments)); @@ -174,10 +62,21 @@ class UnsortedSegmentSum : public XlaOpKernel { d, " differs ", data_shape.dim_size(d), " vs. ", indices_shape.dim_size(d))); } - auto result = XlaComputeScatterAddDynamicSlice( - ctx, data, data_shape, indices, indices_shape, num_segments, dtype_, - ctx->builder()); - ctx->SetOutput(0, result); + xla::ComputationBuilder* builder = ctx->builder(); + TensorShape buffer_shape = data_shape; + buffer_shape.RemoveDimRange(0, indices_shape.dims()); + buffer_shape.InsertDim(0, num_segments); + auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), + buffer_shape.dim_sizes()); + + auto combiner = + [](xla::ComputationDataHandle a, xla::ComputationDataHandle b, + xla::ComputationBuilder* builder) { return builder->Add(a, b); }; + + auto result = XlaScatter(buffer, /*updates=*/data, indices, + /*indices_are_vectors=*/false, combiner, builder); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index d77fb768ef4d124c403a1dc9b321c4f29571d806..1a78c7ab9be701d3d02285ed21604f0f856b3f1f 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -77,10 +77,8 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder, // Stack has not been initialized. xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type()); - TF_RETURN_IF_ERROR(resource->SetValue( - dtype, - builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()), - builder->ConstantR0(0)}))); + TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); + TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { // Checks the expected shape matches the actual shape. TensorShape actual_shape; @@ -119,8 +117,8 @@ class StackOp : public XlaOpKernel { string name = strings::StrCat("Stack: ", stack_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, - value, &resource)); - resource->set_tensor_array_size(size); + TensorShape(), value, /*tensor_array_size=*/size, + /*tensor_array_gradients=*/{}, &resource)); ctx->SetResourceOutput(0, resource); } @@ -164,11 +162,9 @@ class StackPushOp : public XlaOpKernel { // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - OP_REQUIRES_OK( - ctx, - resource->SetValue( - dtype_, b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices), - b->Add(index, b->ConstantR0(1))}))); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple( + {b->DynamicUpdateSlice(ta, update, start_indices), + b->Add(index, b->ConstantR0(1))}))); ctx->SetOutput(0, value); } @@ -208,7 +204,7 @@ class StackPopOp : public XlaOpKernel { xla::ComputationDataHandle index = b->GetTupleElement(state, 1); index = b->Sub(index, b->ConstantR0(1)); - OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, b->Tuple({ta, index}))); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index f0525a5fb86d6d6f0aae954a916186cffc7f3a9f..91c169428c7a88a8d107a97445aeea999946e3e9 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -231,6 +231,7 @@ class StridedSliceAssignOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); } void Compile(XlaOpKernelContext* ctx) override { @@ -252,9 +253,9 @@ class StridedSliceAssignOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, &strides_tensor)); - DataType lhs_type; TensorShape lhs_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape)); + xla::ComputationDataHandle lhs; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); const TensorShape rhs_shape = ctx->InputShape(4); @@ -282,9 +283,6 @@ class StridedSliceAssignOp : public XlaOpKernel { " does not match r-value shape ", rhs_shape.DebugString(), ". Automatic broadcasting not yet implemented.")); - xla::ComputationDataHandle lhs; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs)); - xla::ComputationDataHandle rhs = ctx->Input(4); gtl::InlinedVector dimensions_to_reverse; @@ -320,13 +318,14 @@ class StridedSliceAssignOp : public XlaOpKernel { lhs, rhs, ctx->builder()->ConstantR1(slice_begin)); } - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); } private: int32 begin_mask_, end_mask_; int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; DataType index_type_; + DataType dtype_; }; REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 9224072a3cb92b8ff0e99c79e568ca1a76966ed6..000b50af6bd86b7268c016865fb0856c16053ece 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -62,15 +62,13 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, TF_RET_CHECK(resource->tensor_array_size() >= 0) << resource->name() << " size " << resource->tensor_array_size(); - TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size()); - ta_shape.AppendShape(elem_shape); if (!resource->initialized()) { xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type()); - TF_RETURN_IF_ERROR(resource->SetValue( - dtype, builder->Broadcast(zero, ta_shape.dim_sizes()))); + + TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); + TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { // Checks the elem_shape matches the TensorArray shape. auto shape_or_status = builder->GetShape(resource->value()); @@ -80,6 +78,10 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, TensorShape shape; TF_RETURN_IF_ERROR( XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + + TensorShape ta_shape; + ta_shape.AddDim(resource->tensor_array_size()); + ta_shape.AppendShape(elem_shape); if (ta_shape != shape) { return errors::InvalidArgument( "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ", @@ -114,10 +116,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name, Status GetTensorArrayShape(const XlaResource* resource, xla::ComputationBuilder* builder, TensorShape* shape) { - TF_RETURN_IF_ERROR(resource->GetShape(builder, shape)); - if (shape->dims() < 1) { - return errors::InvalidArgument("TensorArray rank must be >= 1"); - } + *shape = resource->shape(); + shape->InsertDim(0, resource->tensor_array_size()); return Status::OK(); } @@ -160,8 +160,8 @@ class TensorArrayOp : public XlaOpKernel { // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. xla::ComputationDataHandle value; + TensorShape shape; if (element_shape_.IsFullyDefined()) { - TensorShape shape; CHECK(element_shape_.AsTensorShape(&shape)); TensorShape ta_shape; ta_shape.AddDim(size); @@ -175,8 +175,8 @@ class TensorArrayOp : public XlaOpKernel { string name = strings::StrCat("TensorArray: ", tensor_array_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), - dtype_, value, &var)); - var->set_tensor_array_size(size); + dtype_, shape, value, /*tensor_array_size=*/size, + /*tensor_array_gradients=*/{}, &var)); ctx->SetResourceOutput(0, var); Tensor flow(DT_FLOAT, TensorShape({})); @@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written)); + OP_REQUIRES_OK(ctx, resource->SetValue(written)); ctx->SetOutput(0, flow); } @@ -337,8 +337,11 @@ class TensorArrayGatherOp : public XlaOpKernel { } } - xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); + xla::ComputationDataHandle gather; + OP_REQUIRES_OK( + ctx, + XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0, + /*indices_are_nd=*/false, dtype_, index_type, b, &gather)); ctx->SetOutput(0, gather); } @@ -421,7 +424,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } } - OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta)); + OP_REQUIRES_OK(ctx, resource->SetValue(ta)); ctx->SetOutput(0, flow); } @@ -525,9 +528,8 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - OP_REQUIRES_OK( - ctx, resource->SetValue( - dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())))); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Add( + ta, b->Reshape(value, ta_shape.dim_sizes())))); ctx->SetOutput(0, flow); } diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 5534d1bfa1338c7fe3647cd6aa281c4907dfdf8c..f750f7003be288461f5f10455e58932d1b4e4524 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -32,9 +32,24 @@ class ResourceApplyGradientDescent : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationDataHandle handle; xla::ComputationBuilder* b = ctx->builder(); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + DataType type = ctx->input_type(1); + TensorShape var_shape; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); + + TensorShape alpha_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("alpha is not a scalar: ", + alpha_shape.DebugString())); + + TensorShape delta_shape = ctx->InputShape(2); + OP_REQUIRES( + ctx, var_shape.IsSameSize(delta_shape), + errors::InvalidArgument("var and delta do not have the same shape: ", + var_shape.DebugString(), " vs ", + delta_shape.DebugString())); + handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2))); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -52,18 +67,10 @@ class ResourceApplyMomentum : public XlaOpKernel { DataType type = ctx->input_type(2); - DataType var_type, accum_type; TensorShape var_shape, accum_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == accum_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyMomentum must match: ", - DataTypeString(type), " vs. ", DataTypeString(var_type), " and ", - DataTypeString(accum_type))); + xla::ComputationDataHandle var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -86,10 +93,6 @@ class ResourceApplyMomentum : public XlaOpKernel { errors::InvalidArgument("momentum is not a scalar: ", momentum_shape.DebugString())); - xla::ComputationDataHandle var, accum; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); - xla::ComputationDataHandle lr = ctx->Input(2); xla::ComputationDataHandle grad = ctx->Input(3); xla::ComputationDataHandle momentum = ctx->Input(4); @@ -122,18 +125,10 @@ class ResourceApplyAdagrad : public XlaOpKernel { DataType type = ctx->input_type(2); - DataType var_type, accum_type; TensorShape var_shape, accum_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == accum_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyAdagrad must match: ", - DataTypeString(type), " vs. ", DataTypeString(var_type), " and ", - DataTypeString(accum_type))); + xla::ComputationDataHandle var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -151,9 +146,6 @@ class ResourceApplyAdagrad : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, accum; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); xla::ComputationDataHandle lr = ctx->Input(2); xla::ComputationDataHandle grad = ctx->Input(3); @@ -175,18 +167,11 @@ class ResourceApplyAdam : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType var_type, m_type, v_type; TensorShape var_shape, m_shape, v_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape)); - - OP_REQUIRES( - ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyRMSProp must match: ", - DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ", - DataTypeString(m_type), " vs. ", DataTypeString(v_type))); + xla::ComputationDataHandle var, m, v; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v)); TensorShape beta1_power_shape = ctx->InputShape(3); TensorShape beta2_power_shape = ctx->InputShape(4); @@ -228,10 +213,6 @@ class ResourceApplyAdam : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, m, v; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v)); xla::ComputationDataHandle beta1_power = ctx->Input(3); xla::ComputationDataHandle beta2_power = ctx->Input(4); xla::ComputationDataHandle lr = ctx->Input(5); @@ -278,18 +259,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { DataType type = ctx->input_type(3); - DataType var_type, ms_type, mom_type; TensorShape var_shape, ms_shape, mom_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == ms_type && type == mom_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyRMSProp must match: ", - DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ", - DataTypeString(ms_type), " vs. ", DataTypeString(mom_type))); + xla::ComputationDataHandle var, ms, mom; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom)); TensorShape lr_shape = ctx->InputShape(3); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), @@ -323,10 +297,6 @@ class ResourceApplyRMSProp : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, ms, mom; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom)); xla::ComputationDataHandle lr = ctx->Input(3); xla::ComputationDataHandle rho = ctx->Input(4); xla::ComputationDataHandle momentum = ctx->Input(5); @@ -373,20 +343,11 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { xla::ComputationBuilder* b = ctx->builder(); - DataType var_type, accum_type, linear_type; TensorShape var_shape, accum_shape, linear_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape)); - - OP_REQUIRES( - ctx, dtype == var_type && dtype == accum_type && dtype == linear_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyFtrlV2 must match: ", - DataTypeString(dtype), " vs. ", DataTypeString(var_type), " and ", - DataTypeString(accum_type), " and ", DataTypeString(linear_type))); + xla::ComputationDataHandle var, accum, linear; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -438,10 +399,6 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, errors::InvalidArgument("lr_power is not a scalar: ", lr_power_shape.DebugString())); - xla::ComputationDataHandle var, accum, linear; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear)); xla::ComputationDataHandle grad = ctx->Input(3); xla::ComputationDataHandle lr = ctx->Input(4); xla::ComputationDataHandle l1 = ctx->Input(5); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index a266e9013c41b88788dbc99849f01c09f3d61348..0c5ad9e5255ffc3dfcfb83335060ae833937b3ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -50,18 +50,41 @@ XLAJIT_MAKE_UNARY(Conj, b->Conj(x)); // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) +XLAJIT_MAKE_UNARY( + Acos, + b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), + b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), + b->Mul(x, x)), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), + b->Add(XlaHelpers::One(b, input_type(0)), x)))); + // acosh(x) = log(x + sqrt(x^2 - 1)) XLAJIT_MAKE_UNARY( Acosh, b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x), XlaHelpers::One(b, input_type(0))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + +// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) +XLAJIT_MAKE_UNARY( + Asin, + b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), + b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)), + b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), + b->Mul(x, x)), + XlaHelpers::FloatLiteral(b, input_type(0), + 0.5)))))); + // asinh(x) = log(x + sqrt(x^2 + 1)) XLAJIT_MAKE_UNARY( Asinh, b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x), XlaHelpers::One(b, input_type(0))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + +XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0)))); + // atanh(x) = 0.5 * log((1 + x) / (1 - x)) XLAJIT_MAKE_UNARY( Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x), diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 68847ae7a2cb926edd9d29007e24b0db7fb5a75f..71173f5aead47702f0ed9e95b827a6fefd9b7efd 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -33,21 +33,29 @@ class VarIsInitializedOp : public XlaOpKernel { public: explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle handle; - bool initialized = ctx->ReadVariableInput(0, &handle).ok(); - ctx->SetOutput(0, ctx->builder()->ConstantR0(initialized)); + XlaResource* variable; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable)); + ctx->SetOutput(0, + ctx->builder()->ConstantR0(variable->initialized())); } }; REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); class ReadVariableOp : public XlaOpKernel { public: - explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + void Compile(XlaOpKernelContext* ctx) override { xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK( + ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle)); ctx->SetOutput(0, handle); } + + private: + DataType dtype_; }; REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); @@ -65,10 +73,12 @@ class AssignAddVariableOp : public XlaOpKernel { public: explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(1); xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Add(handle, ctx->Input(1)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -79,10 +89,12 @@ class AssignSubVariableOp : public XlaOpKernel { public: explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(1); xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Sub(handle, ctx->Input(1)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -95,28 +107,21 @@ class ResourceGatherOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); - // Get the shape of the resource tensor. - TensorShape resource_shape; - DataType resource_dtype; - OP_REQUIRES_OK( - ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape)); - - DataType expected_output_dtype = ctx->expected_output_dtype(0); - OP_REQUIRES(ctx, resource_dtype == expected_output_dtype, - errors::InvalidArgument( - "Variable dtype is ", DataTypeString(resource_dtype), - " but expected output dtype is ", - DataTypeString(expected_output_dtype), ".")); + DataType type = ctx->expected_output_dtype(0); + TensorShape resource_shape; xla::ComputationDataHandle resource_handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, + &resource_handle)); auto indices = ctx->Input(1); auto indices_shape = ctx->InputShape(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, resource_handle, resource_shape, indices, indices_shape, 0, - resource_dtype, index_type, builder); + xla::ComputationDataHandle gather; + OP_REQUIRES_OK( + ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, index_type, + builder, &gather)); ctx->SetOutput(0, gather); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 4a711e4d9b7aedb166a8a0ec9fe9ec2390f01b17..0ff1b65ae9179d506e453f98097cd88083eb2be7 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -58,9 +58,8 @@ Status MakeXlaCompilerArgumentsFromInputs( } arg.type = resource->type(); - if (arg.initialized) { - TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape)); - } else { + arg.shape = resource->shape(); + if (!arg.initialized) { *has_uninitialized_vars = true; } arg.tensor_array_size = resource->tensor_array_size(); @@ -70,14 +69,13 @@ Status MakeXlaCompilerArgumentsFromInputs( arg.name = resource->name(); VLOG(2) << " resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << xla::ShapeUtil::HumanString(arg.shape) + << " shape: " << arg.shape.DebugString() << " initialized: " << arg.initialized; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = ctx->input_type(i); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + arg.shape = ctx->InputShape(i); } } return Status::OK(); @@ -154,17 +152,14 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler::Argument& arg = arguments[update.input_index]; if (!arg.initialized) { VLOG(2) << "Update shape for argument " << update.input_index << " " - << xla::ShapeUtil::HumanString(update.shape); + << update.shape.DebugString(); arg.initialized = true; - xla::Shape shape = update.shape; - if (!update.tensor_array_gradients_accessed.empty()) { - shape = xla::ShapeUtil::GetTupleElementShape(shape, 0); - } - std::unique_ptr zero = - xla::Literal::CreateFromShape(shape); - OP_REQUIRES_OK(ctx, resource->SetValue( - update.type, builder->ConstantLiteral(*zero))); + arg.shape = update.shape; + OP_REQUIRES_OK(ctx, + resource->SetTypeAndShape(update.type, update.shape)); + + OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder)); } // Add any TensorArray gradients touched by the body to the enclosing @@ -182,9 +177,6 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } - - // Recompute the argument shape. - OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape)); } // Recompile the body with the "correct" resource shapes. VLOG(1) << "Recompiling body with corrected resource shapes"; @@ -292,13 +284,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - builder->GetTupleElement(while_result, pos), - /*reset_initial_values=*/false, builder)); + builder->GetTupleElement(while_result, pos), builder)); } VLOG(2) << "Loop-carried variable: pos: " << update.input_index << " name: " << resource->name() << " modified: " << update.modified << " type: " << DataTypeString(update.type) - << " shape: " << xla::ShapeUtil::HumanString(update.shape); + << " shape: " << update.shape.DebugString(); // Copies the identity of the resource variable from input to output // unchanged, even if the variable was not modified. ctx->op_kernel_context()->set_output( diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 21ad21f73737a289390ed1ea767db1078d05b466..488fda74bf7b5c1d66f8d706a1be3cc1fc29a492 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -49,6 +49,25 @@ cc_library( ], ) +cc_library( + name = "scatter", + srcs = ["scatter.cc"], + hdrs = ["scatter.h"], + deps = [ + ":util", + ":while_loop", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/core:lib", + ], +) + cc_library( name = "triangular_solve", srcs = ["triangular_solve.cc"], @@ -60,6 +79,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/core:lib", @@ -105,6 +126,21 @@ cc_library( ], ) +cc_library( + name = "while_loop", + srcs = ["while_loop.cc"], + hdrs = ["while_loop.h"], + deps = [ + ":util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 9b0e6174475c22e325c090bec5f1d56822e106bc..798f0fa78055e800038e8bf41b4f410b670be7dd 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -25,11 +25,10 @@ limitations under the License. namespace tensorflow { -// The current implementation simply unrolls the computation along the batch -// dimension. xla::StatusOr BatchDot( xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) { + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, + bool conjugate_x, bool conjugate_y) { TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape, builder->GetShape(x)); TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape, @@ -89,10 +88,10 @@ xla::StatusOr BatchDot( dimensions); } - if (x_shape->element_type() == xla::C64 && transpose_x) { + if (x_shape->element_type() == xla::C64 && conjugate_x) { x = builder->Conj(x); } - if (y_shape->element_type() == xla::C64 && transpose_y) { + if (y_shape->element_type() == xla::C64 && conjugate_y) { y = builder->Conj(y); } diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index b46bc7417d29dc5b7e9649ac28cc78b57d4b619c..b230e885f10f45a78cdd6e455da3ba55ce589b96 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -27,7 +27,10 @@ namespace tensorflow { // viewed as an element of a batch), and arranges the individual results // in a single output tensor of the same batch size. Each of the // individual slices can optionally be transposed before multiplication by -// setting the `transpose_x` or `transpose_y` flag to `true`. +// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each +// can be elementwise-complex-conjugated by setting the `conjugate_x` or +// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both +// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`. // // The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` // and `[..., r_y, c_y]`. @@ -40,11 +43,10 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -// TODO(phawkins): add an option to take the complex conjugate of the LHS or -// RHS. xla::StatusOr BatchDot( xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y); + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, + bool conjugate_x = false, bool conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index b3cc489adf6042acb3f56b3a0a6c8fbe43bde629..e795701181dd80a2ff544743d513bffd52fd2399 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -71,11 +71,14 @@ xla::StatusOr CholeskyUnblocked( SliceInMinorDims(builder, l, {j + 1, 0}, {n, j})); TF_ASSIGN_OR_RETURN(auto r_squared, BatchDot(builder, r, r, /*transpose_x=*/false, - /*transpose_y=*/true)); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false)); new_d_squared = builder->Sub(new_d_squared, r_squared); TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false, - /*transpose_y=*/true)); + /*transpose_y=*/true, + /*conjugate_x=*/false, + /*conjugate_y=*/false)); } auto new_d_inv = builder->Pow( new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5)); @@ -134,7 +137,8 @@ xla::StatusOr Cholesky( SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); TF_ASSIGN_OR_RETURN(auto delta, BatchDot(builder, lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true)); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false)); TF_ASSIGN_OR_RETURN(auto before, SliceInMinorDims(builder, a, {i, i}, {n, i + k})); TF_ASSIGN_OR_RETURN( @@ -155,6 +159,10 @@ xla::StatusOr Cholesky( SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); TF_ASSIGN_OR_RETURN(auto update, TriangularSolve(builder, factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*transpose_a=*/true, + /*conjugate_a=*/false, /*block_size=*/8)); TF_ASSIGN_OR_RETURN( l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 2bead7359baaf3582c1230adf0cd4a90046859d2..e083a383be4be0d1b556b63214fe5f70323b4149 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -29,6 +29,7 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. +// TODO(mattjj): handle the complex Hermitian case xla::StatusOr Cholesky( xla::ComputationBuilder* builder, xla::ComputationDataHandle a, int64 block_size = 256); diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc new file mode 100644 index 0000000000000000000000000000000000000000..45699233ea8b2a75e3850098250307b95546cc28 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -0,0 +1,192 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/scatter.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +xla::StatusOr XlaScatter( + const xla::ComputationDataHandle& buffer, + const xla::ComputationDataHandle& updates, + const xla::ComputationDataHandle& indices, bool indices_are_vectors, + const std::function& combiner, + xla::ComputationBuilder* builder) { + TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_shape, + builder->GetShape(buffer)); + TF_ASSIGN_OR_RETURN(std::unique_ptr updates_shape, + builder->GetShape(updates)); + TF_ASSIGN_OR_RETURN(std::unique_ptr indices_shape, + builder->GetShape(indices)); + gtl::ArraySlice indices_dims = + xla::AsInt64Slice(indices_shape->dimensions()); + gtl::ArraySlice buffer_dims = + xla::AsInt64Slice(buffer_shape->dimensions()); + + // If the indices are N-dimensional, the minor dimension of indices contains + // the indices to update. Otherwise the indices are all scalars. + int64 num_index_dims = 1; + if (indices_are_vectors) { + TF_RET_CHECK(!indices_dims.empty()); + num_index_dims = indices_dims.back(); + if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) { + return errors::InvalidArgument( + "The size of the minor dimension of the indices (shape: ", + xla::ShapeUtil::HumanString(*indices_shape), + ") must be <= the rank of the buffer (shape: ", + xla::ShapeUtil::HumanString(*buffer_shape), ")"); + } + indices_dims.pop_back(); + } + + int64 num_indices = 1; + for (int64 dim : indices_dims) { + num_indices *= dim; + } + + // Degenerate case: nothing to update. Return the buffer unchanged. + if (num_indices == 0) { + return buffer; + } + + // If any of the indexed dimensions are zero in the buffer, the update cannot + // succeed since it updates a slice of size 1. + for (int64 i = 0; i < num_index_dims; ++i) { + if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) { + return errors::InvalidArgument( + "Scatter dimension ", i, " is of size zero in tensor with shape ", + xla::ShapeUtil::HumanString(*buffer_shape)); + } + } + + // Shape of the non-indexed dimensions of the buffer. + std::vector buffer_shape_post_axes( + buffer_dims.begin() + num_index_dims, buffer_dims.end()); + + // Flatten the major dimensions of indices and updates into a single dimension + // for ease of iteration. + std::vector flat_indices_shape({num_indices}); + if (indices_are_vectors) { + flat_indices_shape.push_back(num_index_dims); + } + + std::vector flat_updates_shape({num_indices}); + flat_updates_shape.insert(flat_updates_shape.end(), + buffer_shape_post_axes.begin(), + buffer_shape_post_axes.end()); + + // Construct the initial values of the loop-carried Tensors. + auto flat_indices = builder->Reshape(indices, flat_indices_shape); + auto flat_updates = builder->Reshape(updates, flat_updates_shape); + auto init = {flat_indices, flat_updates, buffer}; + + // Constructs the loop body. The implementation of scatter is essentially: + // for i in range(num_indices): + // index = dynamic-slice(indices, i) + // update = dynamic-slice(updates, i) + // buffer = dynamic-update-slice(buffer, update, index) + auto body_fn = [&](xla::ComputationDataHandle i, + gtl::ArraySlice loop_vars, + xla::ComputationBuilder* body_builder) { + auto indices = loop_vars[0]; + auto updates = loop_vars[1]; + auto buffer = loop_vars[2]; + + auto zero_index = body_builder->ConstantLiteral( + xla::Literal::Zero(indices_shape->element_type())); + + // Slice the i-th index from the indices array. + xla::ComputationDataHandle index; + auto indices_offset = body_builder->Reshape(i, {1}); + if (indices_are_vectors) { + indices_offset = body_builder->Pad(indices_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, 1}})); + + index = body_builder->DynamicSlice(indices, indices_offset, + {1, num_index_dims}); + index = body_builder->Collapse(index, {0, 1}); + } else { + index = body_builder->DynamicSlice(indices, indices_offset, {1}); + } + + // Discard updates with negative indices, since some users expect this. + auto index_in_range = + body_builder->ReduceAll(body_builder->Le(zero_index, index), + body_builder->ConstantR0(true), + xla::CreateScalarAndComputation(body_builder)); + + // Make the index in bounds to prevent implementation defined behavior. + index = body_builder->Max(index, zero_index); + index = body_builder->Pad( + index, zero_index, + xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); + + // Slice the i-th index from the updates array. + auto updates_offset = body_builder->Reshape(i, {1}); + updates_offset = body_builder->Pad( + updates_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); + std::vector flat_updates_slice_shape({1}); + flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), + buffer_shape_post_axes.begin(), + buffer_shape_post_axes.end()); + auto update = body_builder->DynamicSlice(updates, updates_offset, + flat_updates_slice_shape); + + // Unflatten the major (iteration) dimensions of the slice to their + // original shape. + std::vector updates_slice_shape(num_index_dims, 1); + updates_slice_shape.insert(updates_slice_shape.end(), + buffer_shape_post_axes.begin(), + buffer_shape_post_axes.end()); + update = body_builder->Reshape(update, updates_slice_shape); + + // Apply the update to the buffer. If there is a combiner, use it to merge + // the current values with the update. + auto current_value = + body_builder->DynamicSlice(buffer, index, updates_slice_shape); + if (combiner) { + update = combiner(current_value, update, body_builder); + } + // Use the current value instead of the update if the index is out of + // bounds. + update = body_builder->Select(index_in_range, update, current_value); + // Apply the update. + buffer = body_builder->DynamicUpdateSlice(buffer, update, index); + + return std::vector{indices, updates, buffer}; + }; + + TF_ASSIGN_OR_RETURN( + auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(), + body_fn, init, "scatter", builder)); + return outputs[2]; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h new file mode 100644 index 0000000000000000000000000000000000000000..41e6d3b195ebf90662c7b9b42c53fcb0133ab29e --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ + +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { + +// Builds an XLA computation that performs a scatter operation on `buffer`, +// returning an updated buffer. +// For each i0, i1, ..., sets +// buffer[indices[i0, i1, ...], ...] := updates[i0, i1, ...] +// +// If `indices_are_vectors` is false, then each index in indices is a scalar, +// and the shape of `indices` must be a prefix of the shape of updates. +// Otherwise, `indices_are_vectors`, then indices are multidimensional and the +// minor dimension of `indices` represents a vector of indices. +// +// If any indices are negative, the corresponding update is discarded. +// +// If a `combiner` is provided, updates are combined with the existing values in +// the buffer using the combiner function. Otherwise, the updates replace the +// existing values. The order of updates is implementation-defined. +xla::StatusOr XlaScatter( + const xla::ComputationDataHandle& buffer, + const xla::ComputationDataHandle& updates, + const xla::ComputationDataHandle& indices, bool indices_are_vectors, + const std::function& combiner, + xla::ComputationBuilder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 579944c3a381e7018b7fee5013d0509158ce21cc..7f72a6073df218b9e2bd4cc0c0b5bb10b5cd4b84 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -24,13 +24,15 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { xla::StatusOr TriangularSolve( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, int64 block_size) { + xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size) { TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, @@ -60,14 +62,15 @@ xla::StatusOr TriangularSolve( batch_dimensions.push_back(a_size); } - const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); - const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); - if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + if (xla::ShapeUtil::GetDimension(*a_shape, -1) != + xla::ShapeUtil::GetDimension(*a_shape, -2)) { return errors::InvalidArgument( "The 'a' arguments to TriangularSolve must be square matrices: ", xla::ShapeUtil::HumanString(*a_shape)); } - if (n != xla::ShapeUtil::GetDimension(*b_shape, -1)) { + const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) { return errors::InvalidArgument( "Arguments to TriangularSolve have incompatible matrix shapes: ", xla::ShapeUtil::HumanString(*a_shape), " vs ", @@ -89,6 +92,14 @@ xla::StatusOr TriangularSolve( return output; }; + // Applies a complex conjugation operation if `a` is complex and `conjugate_a` + // is true, otherwise returns its argument. + auto maybe_conj = [&](xla::ComputationBuilder* builder, + xla::ComputationDataHandle x) { + auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; + return perform_conj ? builder->Conj(x) : x; + }; + std::map base_computations; auto get_base_triangular_solve = [&](int k) -> xla::StatusOr { @@ -103,19 +114,35 @@ xla::StatusOr TriangularSolve( prepend_batch_dims({k, k})), "a"); + std::array b_lastd; + if (left_side) { + b_lastd = {k, n}; + } else { + b_lastd = {m, k}; + } auto b_param = sub->Parameter(1, xla::ShapeUtil::MakeShape(b_shape->element_type(), - prepend_batch_dims({m, k})), + prepend_batch_dims(b_lastd)), "b"); - // TODO(phawkins): it might make sense to use a while loop here, rather - // than unrolling. - // TODO(phawkins): the left-looking variant of the algorithm might be more - // efficient at block size 1. - TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, - /*block_size=*/1) - .status()); + // We use a left-looking subroutine on the block diagonal in some common + // cases, while falling back to a recursive call in unsupported cases. The + // left-looking subroutine is written with a While loop and so yields much + // faster compile times. Moreover, the left-looking variant can give + // higher performance on smaller (sub)problems. + if (left_side && lower) { + TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, + b_param, transpose_a, + conjugate_a) + .status()); + } else { + TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, + left_side, lower, transpose_a, + conjugate_a, + /*block_size=*/1) + .status()); + } TF_ASSIGN_OR_RETURN(computation, sub->Build()); } @@ -129,47 +156,396 @@ xla::StatusOr TriangularSolve( // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 // (2008): 4. - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - // if k > 1: - // output[..., :, i:i+k] = triangular_solve( - // a[..., i:i+k, ..., i:i+k], b[..., :, i:i+k], side='Right', - // kind='Lower', transpose=True, block_size=1) - // else: - // output[..., :, i] = b[..., :, i] / a[..., i, i] + // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if + // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if + // conjugate_a is True. + + if (!left_side && lower == transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {0, i}, {m, i + k})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {0, i})); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) + if (i + k < n) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); + } else { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, + BatchDot(builder, update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, i + k}, {m, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + } + } + + } else if (left_side && lower != transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < m; i += block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) + if (i + k < m) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); + } else { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {i + k, 0}, {m, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0})); + } + } + } else if (!left_side && lower != transpose_a) { + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {0, i}, {m, i + k})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {0, i})); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) + if (i - k >= 0) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + } else { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, + BatchDot(builder, update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, 0}, {m, i})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + } + } + } else { // left_side && lower == transpose_a + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) + if (i - k >= 0) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + } else { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, 0}, {i, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + } + } + } + + return output; +} + +xla::StatusOr TriangularSolveLeftLooking( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) { + TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, + builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, + builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(*a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape->dimensions(i); + batch_dimensions.push_back(a_size); + } + + auto prepend_batch_dims = [&](std::array indices) { + std::vector output(ndims); + std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); + std::copy(indices.begin(), indices.end(), + output.begin() + batch_dimensions.size()); + return output; + }; + + auto maybe_conj = [&](xla::ComputationBuilder* builder, + xla::ComputationDataHandle x) { + auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; + return perform_conj ? builder->Conj(x) : x; + }; + + // The main computation is performed in a While loop. + + // Allocate the output and set its first or last row, + // output = np.zeros_like(b) + // if transpose_a: + // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] + // else: + // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] + xla::ComputationDataHandle output = Zeros(builder, *b_shape); + { + auto i = transpose_a ? m - 1 : 0; TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::ComputationDataHandle update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); + SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); + auto update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); + } + + // Construct the initial loop carry tuple, + // if transpose_a: + // init = (m-2, output, a, b) + // else: + // init = (1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + *b_shape, + // The coefficient matrix a is a loop invariant. + *a_shape, + // The right-hand-side matrix b is a loop invariant. + *b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = builder->ConstantR0(transpose_a ? m - 2 : 1); + auto init = builder->Tuple({init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i >= 0 if transpose_a else i < m + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); + { + auto i = condb->GetTupleElement( + condb->Parameter(0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"), + 0); + if (transpose_a) { + condb->Ge(i, condb->ConstantR0(0)); } else { - update = builder->Div(b_slice, a_slice); + condb->Lt(i, condb->ConstantR0(m)); } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - // b[..., :, i+k:] -= np.dot(output[..., :, i:i+k], - // np.transpose(..., a[i+k:, i:i+k])) - if (i + k < n) { - TF_ASSIGN_OR_RETURN(auto a_slice_2, - SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/true)); - - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, i + k}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) + // else: + // a_row = a[..., i:i+1, :i] + // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) + // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop counter. + std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); + { + auto input_tuple = bodyb->Parameter(0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = bodyb->GetTupleElement(input_tuple, 0); + auto body_out = bodyb->GetTupleElement(input_tuple, 1); + auto body_a = bodyb->GetTupleElement(input_tuple, 2); + auto body_b = bodyb->GetTupleElement(input_tuple, 3); + auto zero = bodyb->ConstantR0(0); + + // Set up some helper functions. + auto prepend_zeros = [&](std::array starts) { + auto zero = bodyb->Reshape(bodyb->ConstantR0(0), {1}); + std::vector padded_starts(ndims, zero); + padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1}); + padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1}); + return bodyb->ConcatInDim(padded_starts, 0); + }; + + auto dynamic_slice = [&](xla::ComputationDataHandle x, + std::array starts, + std::array sizes) { + auto padded_starts = prepend_zeros(starts); + auto padded_sizes = prepend_batch_dims(sizes); + return bodyb->DynamicSlice(x, padded_starts, padded_sizes); + }; + + auto update = [&](xla::ComputationDataHandle x, + xla::ComputationDataHandle update, + std::array starts) { + auto padded_starts = prepend_zeros(starts); + return bodyb->DynamicUpdateSlice(x, update, padded_starts); + }; + + // We'd like to implement this: + // if transpose_a: + // a_row = T(a[..., i+1:, i:i+1]) + // result_row = (b[..., i:i+1, :] + // - np.matmul(a_row, body_out[..., i+1:, :])) + // else: + // result_row = (b[..., i:i+1, :] + // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) + // But since we can't have intermediate array sizes depend on the loop + // counter, we instead exploit the fact that we initialized the output to + // all zeros and use that as zero-padding (doing unnecessary FLOPs). + xla::ComputationDataHandle a_row; + if (transpose_a) { + a_row = dynamic_slice(body_a, {zero, i}, {m, 1}); + } else { + a_row = dynamic_slice(body_a, {i, zero}, {1, m}); } + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false)); + auto result_row = + bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update); + + // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1}); + auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt)); + body_out = update(body_out, div_result, {i, zero}); + + // if transpose_a: + // return (i - 1, body_out, a, b) + // else: + // return (i + 1, body_out, a, b) + auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? -1 : 1)); + bodyb->Tuple({next_i, body_out, body_a, body_b}); } - return output; + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = builder->While(cond, body, init); + return builder->GetTupleElement(triangular_solve_left_looking_while, 1); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 501d026411c80359c7efa406ece5929a2e46ac1f..e32223bfdddda800b1fd4de3e4f0c8061e0f81d8 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -21,25 +21,50 @@ limitations under the License. namespace tensorflow { -// Solves systems of linear equations with upper or lower triangular matrices by -// backsubstitution. +// Solves systems of linear equations with lower or upper triangular coefficient +// matrices by forward- or back-substitution. Broadcasting along leading +// dimensions, this routine solves one of the matrix systems +// `op(a) * x = b`, or `x * op(a) = b`, +// for the variable `x` given `a` and `b`, where `op(a)` is either +// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. +// That is, the innermost matrices in the output satisfy a scalar system +// depending on the value of the value of (left_side, transpose_a, conjugate_a) +// according to: +// (F, F, F) => `output[..., i, k] a[..., k, j] = b[..., i, j]`, +// (F, F, T) => `output[..., i, k] a*[..., k, j] = b[..., i, j]`, +// (F, T, F) => `output[..., i, k] a[..., j, k] = b[..., i, j]`, +// (F, T, T) => `output[..., i, k] a*[..., j, k] = b[..., i, j]`, +// (T, F, F) => ` a[..., i, k] output[..., k, j] = b[..., i, j]`, +// (T, F, T) => `a*[..., i, k] output[..., k, j] = b[..., i, j]`, +// (T, T, F) => ` a[..., i, k] output[..., j, k] = b[..., i, j]`, +// (T, T, T) => `a*[..., i, k] output[..., j, k] = b[..., i, j]`, +// where * denotes complex conjugation and where the index `k` is summed over. // -// `a` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form -// square matrices. The strictly upper triangular part of each inner-most matrix -// is assumed to be zero and not accessed. -// `b` is a tensor of shape `[..., M, K]`. -// -// The innermost matrices in the output satisfy matrix equations -// `output[..., i, j] * adjoint(a[..., k, j]) = b[..., i, k]`. +// `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form +// square matrices. If lower is true (false), then the strictly upper (lower) +// triangular part of each innermost matrix in `a` is assumed to be zero and is +// not accessed. +// `b` is a tensor of shape `[..., M, K]` if left_side is true, otherwise a +// tensor of shape `[..., K, M]`. +// `left_side` is a boolean, indicating whether to solve a system of the form +// op(a) * x = b (true) or x * op(a) = b (false). +// `lower` is a boolean, indicating whether the argument `a` is lower-triangular +// (true) or upper-triangular (false). +// `transpose_a` is a boolean indicating whether the matrix `a` is transposed. +// `conjugate_a` is a boolean indicating whether the entries of `a` are complex +// conjugated (independently of whether they are transposed), so that when both +// transpose_a and conjugate_a are true the effect is a Hermitian adjoint. // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -// TODO(phawkins): equivalent to the BLAS TRSM routine with side=right, -// kind=lower, and transposed_a=true. Implement the other possible combinations -// of side, kind and transposed_a. xla::StatusOr TriangularSolve( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, int64 block_size = 256); + xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size = 256); + +xla::StatusOr TriangularSolveLeftLooking( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 671d9aa4fe0c042a3cc44468074653d51c2be75d..661707062916263fd0d5d935ce41698a7655df02 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -27,32 +27,134 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { namespace { using TriangularSolveTest = xla::ClientLibraryTestBase; +using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase; +using complex64 = xla::complex64; -XLA_TEST_F(TriangularSolveTest, Simple) { +xla::Array2D AValsLower() { + return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +xla::Array2D AValsUpper() { + return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}}; +} + +xla::Array2D BValsRight() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeft() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsLowerComplex() { + return {{2, 0, 0, 0}, + {complex64(3, 1), 6, 0, 0}, + {4, complex64(7, 2), 9, 0}, + {5, 8, complex64(10, 3), 11}}; +} + +xla::Array2D AValsUpperComplex() { + return {{2, 3, complex64(4, 3), 5}, + {0, 6, complex64(7, 2), 8}, + {0, 0, complex64(9, 1), 10}, + {0, 0, 0, 11}}; +} + +xla::Array2D BValsRightComplex() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeftComplex() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsFull() { + return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { xla::ComputationBuilder builder(client_, TestName()); - xla::Array2D a_vals({ - {2, 0, 0, 0}, - {3, 6, 0, 0}, - {4, 7, 9, 0}, - {5, 8, 10, 11}, + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 0.08333334, 0.04629629, 0.03367003}, + {2.5, -0.25, -0.1388889, -0.1010101}, + {4.5, -0.58333331, -0.32407406, -0.23569024}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, + {0.64393939, 0.06565657, -0.03030303, 0.72727273}, + {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); - xla::Array2D b_vals({ - {1, 2, 3, 4}, - {5, 6, 7, 8}, - {9, 10, 11, 12}, + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, + {0.64393939, 0.06565657, -0.03030303, 0.72727273}, + {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + xla::ComputationDataHandle a, b; - auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(b_vals, 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, /*block_size=*/2); + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); TF_ASSERT_OK(result.status()); xla::Array2D expected({ @@ -62,7 +164,201 @@ XLA_TEST_F(TriangularSolveTest, Simple) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(2e-3, 2e-3)); + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.89646465, -0.69444444, -0.49242424}, + {-0.27441077, -0.24074074, -0.20707071}, + {-0.23232323, -0.22222222, -0.21212121}, + {0.90909091, 1., 1.09090909}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.89646465, -0.69444444, -0.49242424}, + {-0.27441077, -0.24074074, -0.20707071}, + {-0.23232323, -0.22222222, -0.21212121}, + {0.90909091, 1., 1.09090909}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = + CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/true, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, complex64(0.08333333, 0.08333333), + complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, + {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), + complex64(0.08670034, -0.02104377)}, + {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296), + complex64(0.11026936, -0.03114478)}, + }); + + ComputeAndCompareR2(&builder, expected, + {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = + CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1., 1.5}, + {0.41666667, 0.33333333, 0.25}, + {complex64(0.20020325, -2.81504065e-01), + complex64(0.13821138, -4.22764228e-01), + complex64(0.07621951, -5.64024390e-01)}, + {complex64(0.19678492, 2.55912786e-01), + complex64(0.17738359, 3.84331116e-01), + complex64(0.15798226, 5.12749446e-01)}, + }); + + ComputeAndCompareR2(&builder, expected, + {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolveLeftLooking(&builder, a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolveLeftLooking(&builder, a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); } } // namespace diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index ce24b61b5dc7176f3caa05e3eb9257399fef7926..f579669bbd852b514e021ce71d635f8ce5e4fe4d 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -57,6 +57,61 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, } } +xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, + int64 value) { + xla::Literal literal; + switch (type) { + case xla::U8: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::U32: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::U64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::S8: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::S32: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::S64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::F32: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::F64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::C64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::PRED: + LOG(FATAL) << "pred element type is not integral"; + case xla::S16: + case xla::U16: + LOG(FATAL) << "u16/s16 literals not yet implemented"; + case xla::BF16: + literal = std::move( + *xla::Literal::CreateR0(static_cast(value))); + break; + case xla::F16: + literal = std::move( + *xla::Literal::CreateR0(static_cast(value))); + break; + case xla::TUPLE: + LOG(FATAL) << "tuple element type is not integral"; + case xla::OPAQUE: + LOG(FATAL) << "opaque element type is not integral"; + default: + LOG(FATAL) << "unhandled element type " << type; + } + return builder->ConstantLiteral(literal); +} + xla::StatusOr SliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, gtl::ArraySlice start, gtl::ArraySlice end) { @@ -107,4 +162,15 @@ xla::StatusOr UpdateSliceInMinorDims( return UpdateSlice(builder, x, update, padded_start); } +xla::StatusOr TransposeInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return builder->Transpose(x, permutation); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index fb138b4f736500aac8184770d97fbf930ced69ea..51f8baaf00bd8fd25baa1a87be8cb0089dfb22b5 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -32,6 +32,11 @@ xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, xla::PrimitiveType type, double value); +// Returns a integer scalar constant of 'type' with 'value'. +// If 'type' is complex, returns a real value with zero imaginary component. +xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, int64 value); + // Performs a slice in the minor dimensions of a Tensor. xla::StatusOr SliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, @@ -49,6 +54,10 @@ xla::StatusOr UpdateSliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, const xla::ComputationDataHandle& update, gtl::ArraySlice start); +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::StatusOr TransposeInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc new file mode 100644 index 0000000000000000000000000000000000000000..86c02ac2e65c12d3527c4022df0cc603e522ef7a --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { + +xla::StatusOr> XlaWhileLoop( + const LoopConditionFunction& condition_function, + const LoopBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder) { + int arity = initial_values.size(); + std::vector var_shapes; + var_shapes.reserve(arity); + for (const xla::ComputationDataHandle& input : initial_values) { + TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); + var_shapes.push_back(std::move(*shape)); + } + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); + + // Unpacks a tuple into its component parts. + auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity, + xla::ComputationBuilder* builder) { + std::vector elements(arity); + for (int i = 0; i < arity; ++i) { + elements[i] = builder->GetTupleElement(tuple, i); + } + return elements; + }; + + // Build the condition. + std::unique_ptr cond_builder = + builder->CreateSubBuilder(strings::StrCat(name, "_condition")); + { + auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); + + TF_ASSIGN_OR_RETURN( + auto result, + condition_function(unpack_tuple(parameter, arity, cond_builder.get()), + cond_builder.get())); + TF_RETURN_IF_ERROR(cond_builder->SetReturnValue(result)); + } + TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); + + // Build the body. + std::unique_ptr body_builder = + builder->CreateSubBuilder(strings::StrCat(name, "_body")); + { + auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); + + TF_ASSIGN_OR_RETURN( + auto result, + body_function(unpack_tuple(parameter, arity, body_builder.get()), + body_builder.get())); + + TF_RET_CHECK(result.size() == initial_values.size()); + body_builder->Tuple(result); + } + TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); + + auto outputs = builder->While(cond, body, builder->Tuple(initial_values)); + + return unpack_tuple(outputs, arity, builder); +} + +xla::StatusOr> XlaForEachIndex( + int64 num_iterations, xla::PrimitiveType num_iterations_type, + const ForEachIndexBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder) { + auto while_cond_fn = [&](gtl::ArraySlice values, + xla::ComputationBuilder* cond_builder) + -> xla::StatusOr { + return cond_builder->Lt( + values[0], + IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); + }; + auto while_body_fn = [&](gtl::ArraySlice values, + xla::ComputationBuilder* body_builder) + -> xla::StatusOr> { + xla::ComputationDataHandle iteration = values[0]; + + std::vector updated_values; + updated_values.reserve(values.size()); + updated_values.push_back(body_builder->Add( + iteration, + body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); + + values.remove_prefix(1); + TF_ASSIGN_OR_RETURN(std::vector body_outputs, + body_function(iteration, values, body_builder)); + updated_values.insert(updated_values.end(), body_outputs.begin(), + body_outputs.end()); + return updated_values; + }; + + std::vector values; + values.reserve(initial_values.size() + 1); + values.push_back( + builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); + values.insert(values.end(), initial_values.begin(), initial_values.end()); + + TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, + name, builder)); + values.erase(values.begin(), values.begin() + 1); + return values; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h new file mode 100644 index 0000000000000000000000000000000000000000..2e67a0c99b6deb65fa16ab2dec1727f5cb5fcb92 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Function that builds a loop condition. Takes as input a sequence of input +// values, and returns a boolean value representing if the condition succeeds. +typedef std::function( + gtl::ArraySlice, xla::ComputationBuilder*)> + LoopConditionFunction; + +// Function that builds a loop body. Takes as input a sequence of input values +// and returns a sequence of output values. +typedef std::function>( + gtl::ArraySlice, xla::ComputationBuilder*)> + LoopBodyFunction; + +// Helper function for building an XLA while loop, where the values carried by +// the loop are a tuple of values, e.g., (a, b, c): +// while( +// condition: (a, b, c) -> bool, +// body: (a, b, c) -> (a, b, c) +// init: (a, b, c) +// ) +// 'name' is a descriptive name for the loop. +xla::StatusOr> XlaWhileLoop( + const LoopConditionFunction& condition_function, + const LoopBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder); + +// Builds an XLA loop that repeats a computation `num_iterations` times. +// +// The body function (ForEachIndexBodyFunction) takes as input a pair of +// (current iteration number, loop-carried values), and returns an updated +// vector of the loop-carried values. +typedef std::function>( + xla::ComputationDataHandle, gtl::ArraySlice, + xla::ComputationBuilder*)> + ForEachIndexBodyFunction; + +xla::StatusOr> XlaForEachIndex( + int64 num_iterations, xla::PrimitiveType num_iterations_type, + const ForEachIndexBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index fcbd157c6191655865d5e250fdf71338780bc2a6..2c3cd658e0462368ac0b51938979b7a6815a7574 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,20 +40,20 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor) { +Status CopyLiteralToHostTensor(const xla::Literal& literal, + Tensor* host_tensor) { + TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && + xla::ShapeUtil::ElementsIn(literal.shape()) == + host_tensor->NumElements()); xla::PrimitiveType primitive_type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(target_type, &primitive_type)); + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(host_tensor->dtype(), &primitive_type)); if (literal.shape().element_type() != primitive_type) { return errors::InvalidArgument( "Cannot convert literal of type ", xla::PrimitiveType_Name(literal.shape().element_type()), - " to tensor of type ", DataTypeString(target_type)); + " to tensor of type ", DataTypeString(host_tensor->dtype())); } - - TensorShape shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); - *host_tensor = Tensor(target_type, shape); size_t total_bytes = host_tensor->TotalBytes(); if (total_bytes > 0) { const void* src_ptr = literal.untyped_data(); @@ -63,4 +63,12 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, return Status::OK(); } +Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, + Tensor* host_tensor) { + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); + *host_tensor = Tensor(target_type, shape); + return CopyLiteralToHostTensor(literal, host_tensor); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index fe08e83c2391a8b24696961cacfd909d46e49e7d..f283b0236811f8d52e8fe2982a74c11c92cd20d8 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -29,7 +29,8 @@ namespace tensorflow { // unsupported type. Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); -// Copies 'literal' to 'host_tensor', which is allocated of type . +// Copies 'literal' to freshly allocated 'host_tensor', which is allocated of +// type . // Fails if the literal's primitive type != // DataTypeToPrimitiveType(target_type). Note that is not // derivable from the type of , because multiple tensorflow types map @@ -38,6 +39,12 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, Tensor* host_tensor); +// Copies the contents of 'literal' to a previously allocated tensor +// 'host_tensor'. The tensor and the literal must have the same number of +// elements and the same type. +Status CopyLiteralToHostTensor(const xla::Literal& literal, + Tensor* host_tensor); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f0a2ef0651ff6115bd201a3b1c34b3c061a22a3d --- /dev/null +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -0,0 +1,24 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//learning/tfx:__subpackages__", + "//tensorflow:internal", + ], +) + +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_py_clif_cc", +) + +tf_py_clif_cc( + name = "xla_op_registry", + srcs = ["xla_op_registry.clif"], + pyclif_deps = [ + "//tensorflow/core:framework/kernel_def_pyclif", + ], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + ], +) diff --git a/tensorflow/compiler/tf2xla/python/xla_op_registry.clif b/tensorflow/compiler/tf2xla/python/xla_op_registry.clif new file mode 100644 index 0000000000000000000000000000000000000000..e1ee6cc656a314876fc1fabbebe1bee39a6b2831 --- /dev/null +++ b/tensorflow/compiler/tf2xla/python/xla_op_registry.clif @@ -0,0 +1,7 @@ +from "third_party/tensorflow/core/framework/kernel_def_pyclif.h" import * # KernelDef + +from "third_party/tensorflow/compiler/tf2xla/xla_op_registry.h": + namespace `tensorflow`: + def `XlaOpRegistry::DeviceKernels` as + device_kernels(device: str, include_compilation_only_kernels: bool) -> + list diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 906f2290433face4cce3296b2f815d50d8c496ce..6051d7dffd7493d8cffb07c1b5d10500e7e75522 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -241,9 +241,7 @@ Status CreateXlaArgs(const Graph& graph, XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); - TensorShape shape; - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 69b265436bb19bbbdd9deb872f4097d4bac7ea52..15bba46ac62a97592656942afc767a303c9b97f3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -66,13 +66,14 @@ Status CheckSignature(const DataTypeVector& types, bool XlaCompiler::Argument::operator==( const XlaCompiler::Argument& other) const { - if (std::tie(kind, resource_kind, type, name, tensor_array_size, + if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size, tensor_array_gradients) != std::tie(other.kind, other.resource_kind, other.type, other.name, - other.tensor_array_size, other.tensor_array_gradients)) { + other.initialized, other.tensor_array_size, + other.tensor_array_gradients)) { return false; } - if (!xla::ShapeUtil::Equal(shape, other.shape)) { + if (shape != other.shape) { return false; } if (constant_value.shape() != other.constant_value.shape()) { @@ -108,6 +109,12 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); + + // The default variable representation shape is the identity function. + if (!options_.variable_representation_shape_fn) { + options_.variable_representation_shape_fn = + [](const TensorShape& shape, DataType type) { return shape; }; + } } XlaCompiler::~XlaCompiler() = default; @@ -152,7 +159,8 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); OptimizerOptions opts; - opts.set_do_common_subexpression_elimination(true); + opts.set_opt_level(OptimizerOptions::L0); + opts.set_do_common_subexpression_elimination(false); opts.set_do_function_inlining(true); opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); @@ -183,8 +191,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, CheckSignature(fbody->arg_types, args), "Signature check failure while compiling: ", function.name()); - std::unique_ptr graph(new Graph(options_.flib_def)); - CopyGraph(*fbody->graph, graph.get()); + std::unique_ptr graph = GetGraph(fbody); // _Arg and _Retval nodes don't exist in the stored subgraph for the function; // they are added by the function body looked up. Therefore, they don't have @@ -212,15 +219,6 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, *graph); } - // Optimize the graph before running the compiler. - OptimizerOptions opts; - opts.set_do_common_subexpression_elimination(true); - opts.set_do_function_inlining(true); - opts.set_do_constant_folding(true); - GraphOptimizer optimizer(opts); - optimizer.Optimize(flib_runtime_, flib_runtime_->env(), - /*device=*/nullptr, &graph, /*shape_map=*/nullptr); - VLOG(1) << "===================================================="; TF_RETURN_IF_ERROR( CompileGraph(options, function_id, std::move(graph), args, result)); @@ -230,6 +228,68 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, return Status::OK(); } +// Computes the XLA shape for argument 'arg'. +Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, + xla::Shape* xla_shape) { + switch (arg.kind) { + case XlaCompiler::Argument::kConstant: + return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), + xla_shape); + case XlaCompiler::Argument::kParameter: + return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + case XlaCompiler::Argument::kResource: { + TF_RET_CHECK(arg.initialized); + + switch (arg.resource_kind) { + case XlaResource::kVariable: { + TensorShape representation_shape = + options_.variable_representation_shape_fn(arg.shape, arg.type); + return TensorShapeToXLAShape(arg.type, representation_shape, + xla_shape); + } + case XlaResource::kTensorArray: { + if (arg.tensor_array_size < 0) { + return errors::InvalidArgument( + "Negative tensor_array_size in XLAShapeForArgument"); + } + TensorShape shape; + shape.AddDim(arg.tensor_array_size); + shape.AppendShape(arg.shape); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); + + if (!arg.tensor_array_gradients.empty()) { + std::vector tuple_shape( + arg.tensor_array_gradients.size() + 1, *xla_shape); + *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape); + } + return Status::OK(); + } + case XlaResource::kStack: { + if (arg.tensor_array_size < 0) { + return errors::InvalidArgument( + "Negative tensor_array_size in XLAShapeForArgument"); + } + TensorShape shape; + shape.AddDim(arg.tensor_array_size); + shape.AppendShape(arg.shape); + xla::Shape buffer_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); + *xla_shape = xla::ShapeUtil::MakeTupleShape( + {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})}); + return Status::OK(); + } + + case XlaResource::kInvalid: + return errors::Internal( + "Invalid resource type in XLAShapeForArgument()"); + } + } + case XlaCompiler::Argument::kInvalid: + return errors::Internal("Invalid argument type in XLAShapeForArgument()"); + } +} + namespace { Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, @@ -260,23 +320,124 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, return Status::OK(); } +// Builds the XLA computation. +// +// `retvals` is the list of retvals produced by _Retval operators, in index +// order. `variable_map` is a map from variable ID numbers to XlaOpContext +// variable states, generated by the symbolic evaluation. +// If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. +// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. +// Sets `*resource_updates` to a description of resources whose values are +// written by the computation; the variable writes are the last +// `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a (input_index, type) pair, where `input_index` is the +// index of a resource variable argument to the computation, and `type` is the +// type of the final output. +Status BuildComputation( + const std::vector& args, + const std::vector& arg_cores, + const std::vector& retvals, + const std::vector>& resources, + bool return_updated_values_for_all_resources, + xla::ComputationBuilder* builder, xla::Computation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* resource_updates) { + std::vector elems; + elems.reserve(retvals.size()); + for (const XlaExpression& retval : retvals) { + if (!retval.has_constant_value()) { + elems.push_back(retval.handle()); + } + } + *num_nonconst_outputs = elems.size(); + + // Add return values for resources whose values have changed. + std::vector arg_resources; + arg_resources.reserve(resources.size()); + for (const auto& resource : resources) { + if (resource->arg_num() >= 0) { + arg_resources.push_back(resource.get()); + } + } + std::sort(arg_resources.begin(), arg_resources.end(), + [](const XlaResource* a, const XlaResource* b) { + return a->arg_num() < b->arg_num(); + }); + + for (const XlaResource* resource : arg_resources) { + const XlaCompiler::Argument& arg = args[resource->arg_num()]; + const int core = arg_cores[resource->arg_num()]; + DCHECK_LT(resource->arg_num(), arg_cores.size()); + bool modified = + resource->value().handle() != resource->initial_value().handle(); + // TensorArray gradients were modified if their values changed or there are + // any newly created gradients. + for (const auto& grad : resource->tensor_array_gradients()) { + modified = modified || + grad.second->value().handle() != + grad.second->initial_value().handle() || + arg.tensor_array_gradients.count(grad.first) == 0; + } + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); + update.input_index = resource->arg_num(); + update.type = resource->type(); + update.shape = resource->shape(); + update.modified = modified; + for (const auto& grad : resource->tensor_array_gradients()) { + update.tensor_array_gradients_accessed.insert(grad.first); + } + + // Request that the value be returned on a specific core. + xla::ScopedShardingAssignment assign_sharding( + builder, core == -1 ? tensorflow::gtl::optional() + : xla::sharding_builder::AssignDevice(core)); + + xla::ComputationDataHandle handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + + // Since we can't change the sharding metadata of as this point, + // create a tuple/get-tuple-element combination so that sharding + // assignment will be placed on this value, which will cause the resource + // update to be returned from the same device that provided the resource. + handle = builder->GetTupleElement(builder->Tuple({handle}), 0); + + elems.push_back(handle); + } + } + + *num_computation_outputs = elems.size(); + + // Builds the XLA computation. + builder->Tuple(elems); + xla::StatusOr computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status(); + } + *computation = computation_status.ConsumeValueOrDie(); + return Status::OK(); +} + +} // namespace + // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. -Status BuildArguments(const Graph& graph, - const std::vector& args, - bool use_tuple_arg, xla::ComputationBuilder* builder, - XlaContext* context, std::vector* arg_cores, - std::vector* arg_expressions, - std::vector* input_mapping, - std::vector* input_shapes, - bool is_entry_computation) { +Status XlaCompiler::BuildArguments( + const Graph& graph, const std::vector& args, + bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context, + std::vector* arg_cores, std::vector* arg_expressions, + std::vector* input_mapping, std::vector* input_shapes, + bool is_entry_computation) { arg_expressions->resize(args.size()); *arg_cores = std::vector(args.size(), -1); // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. - std::vector parameters, resources; - parameters.reserve(args.size()); + input_mapping->clear(); + input_mapping->reserve(args.size()); + std::vector resources; resources.reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. @@ -290,18 +451,20 @@ Status BuildArguments(const Graph& graph, // TODO(phawkins): this code assumes that resource arguments do not // alias. XlaResource* resource; - TF_RETURN_IF_ERROR( - context->CreateResource(arg.resource_kind, i, arg.name, arg.type, - xla::ComputationDataHandle(), &resource)); - resource->set_tensor_array_size(arg.tensor_array_size); + TF_RETURN_IF_ERROR(context->CreateResource( + arg.resource_kind, i, arg.name, arg.type, arg.shape, + xla::ComputationDataHandle(), + /*tensor_array_size=*/arg.tensor_array_size, + /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); arg_expression.set_resource(resource); if (arg.initialized) { resources.push_back(i); } break; - case XlaCompiler::Argument::kParameter: - parameters.push_back(i); + case XlaCompiler::Argument::kParameter: { + input_mapping->push_back(i); break; + } case XlaCompiler::Argument::kConstant: arg_expression.set_constant_value(arg.constant_value); break; @@ -312,19 +475,17 @@ Status BuildArguments(const Graph& graph, // Append parameters containing variable values after the other runtime // parameters. - parameters.insert(parameters.end(), resources.begin(), resources.end()); - if (parameters.empty()) { + input_mapping->insert(input_mapping->end(), resources.begin(), + resources.end()); + if (input_mapping->empty()) { return Status::OK(); } - std::vector arg_shapes; - arg_shapes.reserve(parameters.size()); - input_mapping->resize(parameters.size()); - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const XlaCompiler::Argument& arg = args[parameters[i]]; + std::vector arg_shapes(input_mapping->size()); + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - arg_shapes.push_back(arg.shape); - (*input_mapping)[i] = parameters[i]; + TF_RETURN_IF_ERROR( + XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); } if (use_tuple_arg) { @@ -354,13 +515,13 @@ Status BuildArguments(const Graph& graph, } // Build parameter handles for non-constant arguments. - std::vector arg_handles(parameters.size()); + std::vector arg_handles(input_mapping->size()); if (use_tuple_arg) { xla::ComputationDataHandle tuple; if (is_entry_computation) { xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); - for (int64 parameter : parameters) { + for (int64 parameter : *input_mapping) { const int core = (*arg_cores)[parameter]; const int root_device = 0; *tuple_sharding.add_tuple_shardings() = @@ -373,16 +534,16 @@ Status BuildArguments(const Graph& graph, } else { tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); } - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const int core = (*arg_cores)[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const int core = (*arg_cores)[input_mapping->at(i)]; xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = builder->GetTupleElement(tuple, i); } } else { - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const int core = (*arg_cores)[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const int core = (*arg_cores)[input_mapping->at(i)]; xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -393,19 +554,18 @@ Status BuildArguments(const Graph& graph, // Fill in the handles in non-constant arguments. VLOG(2) << "XLA computation inputs:"; - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const XlaCompiler::Argument& arg = args[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; VLOG(2) << " XLA arg " << i << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) - << " name: " << arg.name << " TF arg " << parameters[i]; - XlaExpression& arg_expression = (*arg_expressions)[parameters[i]]; + << " name: " << arg.name << " TF arg " << input_mapping->at(i); + XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)]; switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); XlaResource* resource = arg_expression.resource(); - TF_RETURN_IF_ERROR( - resource->SetFromPack(arg.tensor_array_gradients, arg_handles[i], - /*reset_initial_values=*/true, builder)); + TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients, + arg_handles[i], builder)); VLOG(2) << " resource: num_gradients: " << arg.tensor_array_gradients.size(); break; @@ -422,107 +582,6 @@ Status BuildArguments(const Graph& graph, return Status::OK(); } -// Builds the XLA computation. -// -// `retvals` is the list of retvals produced by _Retval operators, in index -// order. `variable_map` is a map from variable ID numbers to XlaOpContext -// variable states, generated by the symbolic evaluation. -// If `return_updated_values_for_all_resources` is true, all resources will be -// included in `resource_updates`, regardless of whether their value changed. -// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. -// Sets `*resource_updates` to a description of resources whose values are -// written by the computation; the variable writes are the last -// `resource_updates.size()` return values from the computation. Each entry in -// `resource_updates` is a (input_index, type) pair, where `input_index` is the -// index of a resource variable argument to the computation, and `type` is the -// type of the final output. -Status BuildComputation( - const std::vector& args, - const std::vector& arg_cores, - const std::vector& retvals, - const std::vector>& resources, - bool return_updated_values_for_all_resources, - xla::ComputationBuilder* builder, xla::Computation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, - std::vector* resource_updates) { - std::vector elems; - elems.reserve(retvals.size()); - for (const XlaExpression& retval : retvals) { - if (!retval.has_constant_value()) { - elems.push_back(retval.handle()); - } - } - *num_nonconst_outputs = elems.size(); - - // Add return values for resources whose values have changed. - std::vector arg_resources; - arg_resources.reserve(resources.size()); - for (const auto& resource : resources) { - if (resource->arg_num() >= 0) { - arg_resources.push_back(resource.get()); - } - } - std::sort(arg_resources.begin(), arg_resources.end(), - [](const XlaResource* a, const XlaResource* b) { - return a->arg_num() < b->arg_num(); - }); - - for (const XlaResource* resource : arg_resources) { - const XlaCompiler::Argument& arg = args[resource->arg_num()]; - const int core = arg_cores[resource->arg_num()]; - DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = - resource->value().handle() != resource->initial_value().handle(); - // TensorArray gradients were modified if their values changed or there are - // any newly created gradients. - for (const auto& grad : resource->tensor_array_gradients()) { - modified = modified || - grad.second->value().handle() != - grad.second->initial_value().handle() || - arg.tensor_array_gradients.count(grad.first) == 0; - } - if (return_updated_values_for_all_resources || modified) { - resource_updates->emplace_back(); - XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = resource->arg_num(); - update.type = resource->type(); - update.modified = modified; - for (const auto& grad : resource->tensor_array_gradients()) { - update.tensor_array_gradients_accessed.insert(grad.first); - } - - // Request that the value be returned on a specific core. - xla::ScopedShardingAssignment assign_sharding( - builder, core == -1 ? tensorflow::gtl::optional() - : xla::sharding_builder::AssignDevice(core)); - - xla::ComputationDataHandle handle; - TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); - - // Since we can't change the sharding metadata of as this point, - // create a tuple/get-tuple-element combination so that sharding - // assignment will be placed on this value, which will cause the resource - // update to be returned from the same device that provided the resource. - handle = builder->GetTupleElement(builder->Tuple({handle}), 0); - - elems.push_back(handle); - } - } - - *num_computation_outputs = elems.size(); - - // Builds the XLA computation. - builder->Tuple(elems); - xla::StatusOr computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status(); - } - *computation = computation_status.ConsumeValueOrDie(); - return Status::OK(); -} - -} // namespace - Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -547,7 +606,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, xla::ComputationBuilder builder(client(), name); XlaContext* context = new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants); + options.resolve_compile_time_constants, + &options_.variable_representation_shape_fn); core::ScopedUnref context_unref(context); std::vector arg_expressions; @@ -616,13 +676,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, ++computation_output; } } - - for (std::vector::size_type i = 0; - i < result->resource_updates.size(); ++i) { - result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape( - result->xla_output_shape, computation_output); - ++computation_output; - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 6a46e54f61cb4dbb2a2c1916696655a4e3d85fff..c4449bc4be06daff856eff70c6d89be6ddbcf0ee 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -29,6 +29,9 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { + +class XlaContext; + // The XlaCompiler class is responsible for compilation of a self-contained // subgraph of a TensorFlow computation using the XLA linear algebra runtime. // It does a symbolic execution of the graph starting from specific input @@ -104,9 +107,17 @@ class XlaCompiler { // is the type of the variable's value, not DT_RESOURCE. DataType type; - // The shape of the argument. If the argument is a resource, this is the - // shape of the resource's value. - xla::Shape shape; + // The shape of the argument. For: + // * a parameter: the shape of the parameter. + // * a constant: ignored; the shape given by constant_value is used + // instead. + // * an uninitialized resource: ignored. We don't yet know the shape of an + // uninitialized resource (otherwise we would have initialized it!) + // * an initialized variable: the shape of the variable's value. + // * an initialized TensorArray or Stack resource: the shape of an entry in + // the TensorArray/Stack. Note this is the size of a single entry, not the + // XLA data structure that represents the complete stack/array. + TensorShape shape; // The value of the argument, if it is a compile-time constant. Must be a // host-memory tensor. @@ -175,8 +186,9 @@ class XlaCompiler { int input_index; // Type and shape of the tensor to be written back. + // The `shape` field has the same meaning as the Argument::shape field. DataType type; - xla::Shape shape; + TensorShape shape; // Was the value of the variable modified by the computation? // (Always true, unless `return_updated_values_for_all_resources` is true.) @@ -230,11 +242,30 @@ class XlaCompiler { // for CPU. bool allow_cpu_custom_calls = false; + // If set, the XLA representation of variables represented to XLA as the + // shape given by this shape function. Variables are reshaped to this shape + // on write, and reshaped to their original shape on read. + std::function + variable_representation_shape_fn; + // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation // device is created, and can be used to create metadata objects // that can be accessed by XLA op kernels. std::function* populate_resource_manager = nullptr; + + // If not nullptr, this memory allocator can be used by the compiler for + // temporary allocations it might want to make during compilation. + // + // For example, the compiler may want to try out different algorithms and + // choose the fastest one, and it might run those algorithms over buffers + // created using this allocator. + // + // The compiler can function correctly without an explicit allocator given + // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly + // allocate most or all available memory on the device, leaving none for the + // compiler to access, unless it can use TensorFlow's allocator. + xla::DeviceMemoryAllocator* device_allocator = nullptr; }; explicit XlaCompiler(Options options); @@ -253,11 +284,10 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); - Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func, - const std::vector& types, - const std::vector& shapes, - const std::vector& expressions, - std::vector* args); + // Returns the shape of the XLA parameter for an argument 'arg'. + // See the class comment for more details about the argument passing + // convention. + Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. @@ -278,6 +308,17 @@ class XlaCompiler { // Returns the optimized graph object in this function body. std::unique_ptr GetGraph(const FunctionBody* fbody); + // Builds XLA computations for each of the arguments to the computation. + // `args` are the arguments to the computation. + Status BuildArguments(const Graph& graph, + const std::vector& args, + bool use_tuple_arg, xla::ComputationBuilder* builder, + XlaContext* context, std::vector* arg_cores, + std::vector* arg_expressions, + std::vector* input_mapping, + std::vector* input_shapes, + bool is_entry_computation); + // Graph compiler needs to know how to get an optimized graph from a function // body. friend class GraphCompiler; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 7ebe4b75bc1e33e506624314b11163e36a2477de..a18eeacd41808884fac9ec5d617cb0d274ea27d8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -191,10 +192,10 @@ TEST_F(XlaCompilerTest, Simple) { std::vector args(2); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); args[1].kind = XlaCompiler::Argument::kParameter; args[1].type = DT_INT32; - args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[1].shape = TensorShape({2}); // Compiles the graph. XlaCompiler compiler(DefaultOptions()); @@ -242,10 +243,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::vector args(2); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); args[1].kind = XlaCompiler::Argument::kParameter; args[1].type = DT_INT32; - args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[1].shape = TensorShape({2}); // Compiles the graph. XlaCompiler compiler(DefaultOptions()); @@ -281,7 +282,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); XlaCompiler::Options options = DefaultOptions(); XlaCompiler compiler(options); @@ -373,7 +374,7 @@ TEST_F(XlaCompilerTest, ResourceManager) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); DummyResourceForTest* resource = new DummyResourceForTest(); @@ -420,7 +421,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); // Compiles the graph. auto options = DefaultOptions(); @@ -472,9 +473,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad2"}; @@ -540,9 +539,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; @@ -574,9 +571,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; @@ -689,5 +684,128 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { << status.error_message(); } +// Tests a simple graph that reads and writes a variable. +TEST_F(XlaCompilerTest, Variables) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + auto write = ops::AssignAddVariableOp(scope, var, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + // Tests that the generated computation works. + std::unique_ptr param0_literal = + xla::Literal::CreateR1({7, 42}); + std::unique_ptr param1_literal = + xla::Literal::CreateR1({-3, 101}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::Literal::CreateR1({5, 144}); + std::unique_ptr expected1 = + xla::Literal::CreateR1({4, 143}); + std::unique_ptr expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); +} + +// Tests a simple graph that reads and writes a variable, with a +// variable_representation_shape_fn passed to the compiler that flattens all +// variable tensors to vectors. +TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + auto write = ops::AssignAddVariableOp(scope, var, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 2}); + + // Compiles the graph. + XlaCompiler::Options options = DefaultOptions(); + options.variable_representation_shape_fn = [](const TensorShape& shape, + DataType type) { + return TensorShape({shape.num_elements()}); + }; + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + // Tests that the generated computation works. + std::unique_ptr param0_literal = + xla::Literal::CreateR2({{4, 55}, {1, -3}}); + std::unique_ptr param1_literal = + xla::Literal::CreateR1({22, 11, 33, 404}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::Literal::CreateR2({{27, 67}, {35, 402}}); + std::unique_ptr expected1 = + xla::Literal::CreateR1({26, 66, 34, 401}); + std::unique_ptr expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index e8d17e2e0a1ba01f16d4bbbd2895b112f4dd1989..8423921086fec1cf534cf613102fc3839035cb85 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -62,13 +62,16 @@ void XlaContext::set_args(std::vector args) { args_ = std::move(args); } -XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, - bool allow_cpu_custom_calls, - bool resolve_compile_time_constants) +XlaContext::XlaContext( + XlaCompiler* compiler, xla::ComputationBuilder* builder, + bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + const std::function* + variable_representation_shape_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), - resolve_compile_time_constants_(resolve_compile_time_constants) {} + resolve_compile_time_constants_(resolve_compile_time_constants), + variable_representation_shape_fn_(variable_representation_shape_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } @@ -103,16 +106,23 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, xla::ComputationBuilder* XlaContext::builder() { return builder_; } -Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num, - string name, DataType type, - const xla::ComputationDataHandle& handle, - XlaResource** resource) { +Status XlaContext::CreateResource( + XlaResource::Kind kind, int arg_num, string name, DataType type, + TensorShape shape, const xla::ComputationDataHandle& handle, + int64 tensor_array_size, const std::set& tensor_array_gradients, + XlaResource** resource) { resources_.emplace_back( - new XlaResource(kind, arg_num, std::move(name), type, handle)); + new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), + handle, tensor_array_size, tensor_array_gradients)); *resource = resources_.back().get(); return Status::OK(); } +TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, + DataType type) const { + return (*variable_representation_shape_fn_)(shape, type); +} + const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { return LookupOrCreate(type, &max_func_, [this, type] { const string type_string = DataTypeString(type); diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 1a7dafe8cdb56cc9b8fcd3ba6e262c21c2a07d90..00fbaba37c542954f690b310a184cff985a05156 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -44,7 +44,9 @@ class XlaContext : public ResourceBase { // Creates a new XlaContext. XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants); + bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + const std::function* + variable_representation_shape_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -71,17 +73,26 @@ class XlaContext : public ResourceBase { Status AddConstRetval(int retval_index, DataType dtype, const xla::Literal& literal); - // Creates a resource with resource `kind` and initial type `type` and - // value `handle`. `name` is a descriptive name for use in error messages. + // Creates a resource with resource `kind` and initial value `handle`. `name` + // is a descriptive name for use in error messages. See the `XlaResource` + // constructor for a description of the remaining arguments. // Fails if the resource already exists. Status CreateResource(XlaResource::Kind kind, int arg_num, string name, - DataType type, const xla::ComputationDataHandle& handle, + DataType type, TensorShape shape, + const xla::ComputationDataHandle& handle, + int64 tensor_array_size, + const std::set& tensor_array_gradients, XlaResource** resource); const std::vector>& resources() { return resources_; } + // Returns the XLA shape to be used to represent a variable of TF `shape` + // and `type`. + TensorShape VariableRepresentationShape(const TensorShape& shape, + DataType type) const; + // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. @@ -129,6 +140,11 @@ class XlaContext : public ResourceBase { // Holds ownership of resources. The resources are not ordered. std::vector> resources_; + // A function that describes how variable shapes should be represented + // in XLA. Variable values will be reshaped to this shape. Must be non-null. + const std::function* + variable_representation_shape_fn_; + // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 77e24162676045b88dc8b62d2c6a4ecc1e738e96..f048662953e20b2a612271e2daeef6e370c4822a 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -135,58 +135,9 @@ xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::IntegerLiteral( xla::ComputationBuilder* b, DataType data_type, int64 value) { - xla::Literal literal; xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - switch (type) { - case xla::U8: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::U32: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::U64: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::S8: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::S32: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::S64: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::F32: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::F64: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::C64: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::PRED: - LOG(FATAL) << "pred element type is not integral"; - case xla::S16: - case xla::U16: - LOG(FATAL) << "u16/s16 literals not yet implemented"; - case xla::BF16: - literal = std::move( - *xla::Literal::CreateR0(static_cast(value))); - break; - case xla::F16: - literal = std::move( - *xla::Literal::CreateR0(static_cast(value))); - break; - case xla::TUPLE: - LOG(FATAL) << "tuple element type is not integral"; - case xla::OPAQUE: - LOG(FATAL) << "opaque element type is not integral"; - default: - LOG(FATAL) << "unhandled element type " << type; - } - return b->ConstantLiteral(literal); + return ::tensorflow::IntegerLiteral(b, type, value); } xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index ee0aed672e1b264fee0a7f381c334400c55f3581..c4bb90d58755f16672ca7c6a6738065be6330485 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -286,7 +286,8 @@ Status XlaOpKernelContext::ConstantInputList( } Status XlaOpKernelContext::ReadVariableInput( - int index, xla::ComputationDataHandle* value) { + int index, DataType type, TensorShape* shape, + xla::ComputationDataHandle* value) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -296,7 +297,24 @@ Status XlaOpKernelContext::ReadVariableInput( return errors::InvalidArgument("Read of uninitialized variable ", variable->name()); } - *value = variable->value(); + if (variable->type() != type) { + return errors::InvalidArgument( + "Type mismatch for read of variable ", variable->name(), ". Expected ", + DataTypeString(type), "; got ", DataTypeString(variable->type())); + } + if (shape) { + *shape = variable->shape(); + } + + XlaContext& xla_context = XlaContext::Get(context_); + TensorShape representation_shape = xla_context.VariableRepresentationShape( + variable->shape(), variable->type()); + if (representation_shape == variable->shape()) { + *value = variable->value(); + } else { + *value = + builder()->Reshape(variable->value(), variable->shape().dim_sizes()); + } return Status::OK(); } @@ -312,12 +330,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, variable->name()); } *type = variable->type(); - auto shape_or_status = builder()->GetShape(variable->value()); - if (!shape_or_status.ok()) { - return shape_or_status.status(); - } - TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape)); + *shape = variable->shape(); return Status::OK(); } @@ -396,8 +409,8 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { return Status::OK(); } -Status XlaOpKernelContext::AssignVariable( - int input_index, DataType type, const xla::ComputationDataHandle& handle) { +Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, + xla::ComputationDataHandle handle) { TF_RET_CHECK(handle.handle() != 0); const XlaExpression* expression = @@ -405,7 +418,24 @@ Status XlaOpKernelContext::AssignVariable( XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); - return variable->SetValue(type, handle); + + auto shape_or_status = builder()->GetShape(handle); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + TensorShape shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + + TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); + + XlaContext& xla_context = XlaContext::Get(context_); + TensorShape representation_shape = + xla_context.VariableRepresentationShape(shape, type); + if (shape != representation_shape) { + handle = builder()->Reshape(handle, representation_shape.dim_sizes()); + } + return variable->SetValue(handle); } XlaCompiler* XlaOpKernelContext::compiler() const { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 6d3b6db2289d6c0b8f266062f9f3baca1145154a..4e4b97e0cec8d16b9b5686a779b1285906765dbd 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -164,13 +164,18 @@ class XlaOpKernelContext { TensorShape* shape) const; // Reads the current value of the resouce variable referred to by input - // 'index'. - Status ReadVariableInput(int index, xla::ComputationDataHandle* value); + // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the + // variable. Returns an error if the variable has not been initialized, or if + // its type does not match `type`. + Status ReadVariableInput(int index, DataType type, TensorShape* shape, + xla::ComputationDataHandle* value); // Assigns the value `handle` to the variable referenced by input - // `input_index`. Marks the operator as having side effects. + // `input_index`. The variable must be of `type`. Returns an error if the + // variable has been initialized with a different type or with a + // different shape. Status AssignVariable(int input_index, DataType type, - const xla::ComputationDataHandle& handle); + xla::ComputationDataHandle handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 0dde6a986c61bdd5b0b2e6d7a16b29ab95be98ab..bbe808595d958346bd55bf8419306bf3de4cd1d0 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -255,6 +255,8 @@ void XlaOpRegistry::RegisterCompilationKernels() { std::vector XlaOpRegistry::DeviceKernels( const string& compilation_device_name, bool include_compilation_only_kernels) { + // Ensure compilation kernels registered. + RegisterCompilationKernels(); std::vector kernels; XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 9abac8bdaa77c99a57b2f8ac66fe6ed06fbcd102..c2075b44b82ba279d1246ec6bfcf305d12c418a6 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -25,51 +25,99 @@ limitations under the License. namespace tensorflow { -XlaResource::XlaResource(Kind kind, int arg_num, string name, - DataType initial_type, - const xla::ComputationDataHandle& initial_value) +XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, + TensorShape shape, + const xla::ComputationDataHandle& initial_value, + int64 tensor_array_size, + const std::set& tensor_array_gradients) : kind_(kind), arg_num_(arg_num), name_(std::move(name)), - type_(initial_type), + type_(type), + shape_(std::move(shape)), value_(initial_value), - initial_value_(initial_value) { + initial_value_(initial_value), + tensor_array_size_(tensor_array_size) { CHECK(kind_ != kInvalid); + + for (const string& gradient : tensor_array_gradients) { + tensor_array_gradients_[gradient].reset( + new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, + /*name=*/strings::StrCat("TensorArrayGrad: ", name_), + type_, shape_, xla::ComputationDataHandle(), + tensor_array_size_, /*tensor_array_gradients=*/{})); + } } -Status XlaResource::SetValue(DataType type, - const xla::ComputationDataHandle& value) { - if (type_ == DT_INVALID && type == DT_INVALID) { - return errors::InvalidArgument("Attempted to initialized resource ", name_, - " to an invalid type"); +Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { + if (type == DT_INVALID) { + return errors::InvalidArgument("Attempted to set type of resource '", name_, + "'' to an invalid type"); } - if (type_ != DT_INVALID && type_ != type) { + if (initialized() && type_ != type) { return errors::InvalidArgument("Type of resource ", name_, " cannot be changed after initialization: " "old type was ", DataTypeString(type_), ", new type is ", DataTypeString(type)); } + if (initialized() && shape_ != shape) { + return errors::InvalidArgument("Shape of resource ", name_, + " cannot be changed after initialization: " + "old shape was ", + shape_.DebugString(), ", new shape is ", + shape.DebugString()); + } type_ = type; - value_ = value; + shape_ = shape; return Status::OK(); } -Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder, - xla::Shape* shape) const { - auto shape_or_status = builder->GetShape(value_); - if (!shape_or_status.ok()) { - return shape_or_status.status(); +Status XlaResource::SetValue(const xla::ComputationDataHandle& value) { + if (type_ == DT_INVALID) { + return errors::InvalidArgument( + "Resource '", name_, + "' must be initialized with a valid type before use."); } - *shape = *shape_or_status.ValueOrDie(); + value_ = value; return Status::OK(); } -Status XlaResource::GetShape(xla::ComputationBuilder* builder, - TensorShape* shape) const { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape)); - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape)); +Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { + if (type_ == DT_INVALID) { + return errors::InvalidArgument( + "Resource '", name_, + "' must be initialized with a valid type before use."); + } + switch (kind_) { + case kVariable: { + value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), + shape_.dim_sizes()); + break; + } + case kTensorArray: { + TensorShape ta_shape; + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); + value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()); + break; + } + case kStack: { + TensorShape ta_shape; + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); + value_ = + builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()), + builder->ConstantR0(0)}); + break; + } + + case kInvalid: + default: + LOG(FATAL) << "Invalid resource type"; + } return Status::OK(); } @@ -82,36 +130,20 @@ Status XlaResource::GetOrCreateTensorArrayGradient( std::unique_ptr& gradient = tensor_array_gradients_[source]; if (!gradient) { TensorShape ta_shape; - TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape)); + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); xla::ComputationDataHandle gradient_value = builder->Broadcast( XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/strings::StrCat("TensorArrayGrad: ", name_), - type_, gradient_value)); - gradient->tensor_array_size_ = tensor_array_size_; + type_, shape_, gradient_value, tensor_array_size_, + /*tensor_array_gradients=*/{})); } *gradient_out = gradient.get(); return Status::OK(); } -Status XlaResource::PackedShape(xla::ComputationBuilder* builder, - xla::Shape* packed_shape) const { - if (tensor_array_gradients_.empty()) { - return GetXlaShape(builder, packed_shape); - } - TF_RET_CHECK(kind_ == kTensorArray); - std::vector elem_shapes(1 + tensor_array_gradients_.size()); - int pos = 0; - TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++])); - for (const auto& gradient : tensor_array_gradients_) { - TF_RETURN_IF_ERROR( - gradient.second->GetXlaShape(builder, &elem_shapes[pos++])); - } - *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes); - return Status::OK(); -} - Status XlaResource::Pack(xla::ComputationDataHandle* pack, xla::ComputationBuilder* builder) const { if (tensor_array_gradients_.empty()) { @@ -130,27 +162,32 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack, Status XlaResource::SetFromPack(const std::set& gradient_sources, const xla::ComputationDataHandle& pack, - bool reset_initial_values, xla::ComputationBuilder* builder) { if (gradient_sources.empty()) { + if (!initialized()) { + initial_value_ = pack; + } value_ = pack; } else { TF_RET_CHECK(kind_ == kTensorArray); int pos = 0; - value_ = builder->GetTupleElement(pack, pos++); + auto v = builder->GetTupleElement(pack, pos++); + if (!initialized()) { + initial_value_ = v; + } + value_ = v; + for (const auto& source : gradient_sources) { XlaResource* gradient; TF_RETURN_IF_ERROR( GetOrCreateTensorArrayGradient(source, builder, &gradient)); - gradient->value_ = builder->GetTupleElement(pack, pos++); - if (reset_initial_values) { - gradient->initial_value_ = gradient->value_; + auto v = builder->GetTupleElement(pack, pos++); + if (!gradient->initialized()) { + gradient->initial_value_ = v; } + gradient->value_ = v; } } - if (reset_initial_values) { - initial_value_ = value_; - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 6b46089e4f5e10c195bb59f78c33305c2fa3f84d..1bb2c7274ecdf0954768fd96def51194e52deee8 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -36,8 +36,11 @@ class XlaResource { kStack, }; - XlaResource(Kind kind, int arg_num, string name, DataType initial_type, - const xla::ComputationDataHandle& initial_value); + XlaResource(Kind kind, int arg_num, string name, DataType type, + TensorShape shape, + const xla::ComputationDataHandle& initial_value, + int64 tensor_array_size, + const std::set& tensor_array_gradients); XlaResource(const XlaResource&) = delete; XlaResource(XlaResource&&) = delete; @@ -60,6 +63,12 @@ class XlaResource { // a resource is first initialized we do not yet know its type, so we keep // track of its type dynamically. DataType type() const { return type_; } + + // Shape of the resource. For an uninitialized resource, this is ignored. + // For a Variable, this is the shape of the value. For a TensorArray or Stack + // this is the shape of each entry in the TensorArray/Stack. + const TensorShape& shape() const { return shape_; } + const xla::ComputationDataHandle& value() const { return value_; } // Value of the resource at computation entry. Used to detect which @@ -68,17 +77,19 @@ class XlaResource { return initial_value_; } + // A variable is initialized if it has a value. bool initialized() const { return value_.handle() > 0; } - // Sets the current type/value of the resource. - Status SetValue(DataType type, const xla::ComputationDataHandle& value); + // Sets the type and shape of the resource. The type and shape of a resource + // must not change once the variable has been initialized. + Status SetTypeAndShape(DataType type, const TensorShape& shape); - // Returns the shape of the resource as an xla::Shape. - Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const; + // Sets the current value of the resource. Returns an error if the type is not + // set to a valid value. + Status SetValue(const xla::ComputationDataHandle& value); - // Returns the shape of the resource as an TensorShape. Fails if the shape is - // not representable as a TensorShape. - Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const; + // Sets the current value of the resource to an all-zero value. + Status SetZeroValue(xla::ComputationBuilder* builder); // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. A @@ -96,10 +107,6 @@ class XlaResource { Status Pack(xla::ComputationDataHandle* pack, xla::ComputationBuilder* builder) const; - // Returns the shape of the `pack` value computed by `Pack()`. - Status PackedShape(xla::ComputationBuilder* builder, - xla::Shape* packed_shape) const; - // Updates the resource with values from `pack`. If `gradient_sources` is // non-empty, treats `pack` as a tuple that represents a TensorArray and // its gradients, and unpacks and updates the gradient resources. @@ -108,14 +115,14 @@ class XlaResource { // Opposite of Pack(). Status SetFromPack(const std::set& gradient_sources, const xla::ComputationDataHandle& pack, - bool reset_initial_values, xla::ComputationBuilder* builder); - // TensorArray-specific fields + // TensorArray and Stack specific fields // 'tensor_array_size' stores the expected size of the TensorArray or Stack. // We need to store this since sometimes TensorArrays must be initialized // lazily since we do not know the element shape at construction time. + // Used by both TensorArrays and Stacks. int64 tensor_array_size() const { return tensor_array_size_; } void set_tensor_array_size(int64 size) { tensor_array_size_ = size; } @@ -136,6 +143,7 @@ class XlaResource { const string name_; DataType type_; + TensorShape shape_; xla::ComputationDataHandle value_; xla::ComputationDataHandle initial_value_; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 438f1443f17717a3806827abcb36d4ccbbbf756c..34e733bc8d80b364cec1783006eba0a5468b55ea 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -88,7 +88,6 @@ cc_library( visibility = [":friends"], deps = [ "//tensorflow/core:framework_lite", - "//tensorflow/core:lib", "//third_party/eigen3", ], ) @@ -182,6 +181,7 @@ cc_library( deps = [ ":status", ":status_macros", + ":statusor", ":types", ":xla_data_proto", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 71aa057cd3a1c273c0e851497a78f94ba37c778e..46ee4e64c9ae7ca111d9d04bedcb74ff02a42386 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -121,6 +121,23 @@ class Array { CHECK(idx == num_elements()); } + // Creates a 2D array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array(std::initializer_list> values) + : Array(ToInt64Vector({values.size(), values.begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + values_[idx] = static_cast(it2); + ++idx; + } + } + CHECK(idx == num_elements()); + } + // Creates a 3D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. Array(InitializerList3D values) @@ -138,6 +155,27 @@ class Array { CHECK(idx == num_elements()); } + // Creates a 3D array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array(std::initializer_list>> + values) + : Array(ToInt64Vector({values.size(), values.begin()->size(), + values.begin()->begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + for (const auto& it3 : it2) { + values_[idx] = static_cast(it3); + ++idx; + } + } + } + CHECK(idx == num_elements()); + } + // Creates a 4D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. Array(InitializerList4D values) @@ -158,6 +196,31 @@ class Array { CHECK(idx == num_elements()); } + // Creates a 4D array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array(std::initializer_list< + std::initializer_list>>> + values) + : Array(ToInt64Vector({values.size(), values.begin()->size(), + values.begin()->begin()->size(), + values.begin()->begin()->begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + for (const auto& it3 : it2) { + for (const auto& it4 : it3) { + values_[idx] = static_cast(it4); + ++idx; + } + } + } + } + CHECK(idx == num_elements()); + } + Array(const Array& other) : sizes_(other.sizes_), values_(new T[num_elements()]) { std::copy(&other.values_[0], &other.values_[0] + num_elements(), @@ -185,7 +248,7 @@ class Array { // Fills the array with the sequence i*multiplier for i=0,1,... void FillWithMultiples(const T& multiplier) { for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = i * multiplier; + values_[i] = static_cast(i) * multiplier; } } diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index bb85fbee9b97fd6b9b0bf7223a9b820989dcbfa7..41f563486d21e42e88dcf6c751ce4a64da5e3213 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -52,6 +52,14 @@ class Array2D : public Array { Array2D(std::initializer_list> values) : Array(values) {} + // Creates an array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array2D(std::initializer_list> values) + : Array(values) {} + Array2D(const Array2D& other) : Array(other) {} int64 n1() const { return this->dim(0); } diff --git a/tensorflow/compiler/xla/array2d_test.cc b/tensorflow/compiler/xla/array2d_test.cc index c08e42c20ee684dfad8268aa8223440fbfad8a33..93034a719bfbd6724c007059715754677f3f1e62 100644 --- a/tensorflow/compiler/xla/array2d_test.cc +++ b/tensorflow/compiler/xla/array2d_test.cc @@ -63,6 +63,20 @@ TEST(Array2dTest, InitializerListCtor) { EXPECT_EQ(arr(1, 2), 6); } +TEST(Array2dTest, InitializerListCtorHalf) { + Array2D arr = {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}; + + EXPECT_EQ(arr.n1(), 2); + EXPECT_EQ(arr.n2(), 3); + + EXPECT_EQ(arr(0, 0), static_cast(1)); + EXPECT_EQ(arr(0, 1), static_cast(2)); + EXPECT_EQ(arr(0, 2), static_cast(3)); + EXPECT_EQ(arr(1, 0), static_cast(4)); + EXPECT_EQ(arr(1, 1), static_cast(5)); + EXPECT_EQ(arr(1, 2), static_cast(6)); +} + TEST(Array2dTest, Accessors) { Array2D arr = {{1, 2, 3}, {4, 5, 6}}; diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index a1c5840a5f3874e27043c821ed4684da2fa6c542..e5eb235d45d160d486d1499db665ed14a8509043 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -57,6 +57,16 @@ class Array3D : public Array { values) : Array(values) {} + // Creates an array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array3D( + std::initializer_list>> + values) + : Array(values) {} + int64 n1() const { return this->dim(0); } int64 n2() const { return this->dim(1); } int64 n3() const { return this->dim(2); } diff --git a/tensorflow/compiler/xla/array3d_test.cc b/tensorflow/compiler/xla/array3d_test.cc index 6b5f4b343b2113652758bbd5ce0fc803239c1266..691ff6c03594a98a12e0fdd2151c4c2a2c9c128a 100644 --- a/tensorflow/compiler/xla/array3d_test.cc +++ b/tensorflow/compiler/xla/array3d_test.cc @@ -69,6 +69,29 @@ TEST(Array3dTest, InitializerListCtor) { EXPECT_EQ(arr(2, 3, 1), 24); } +TEST(Array3dTest, InitializerListCtorHalf) { + Array3D arr = { + {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}, {7.0f, 8.0f}}, + {{9.0f, 10.0f}, {11.0f, 12.0f}, {13.0f, 14.0f}, {15.0f, 16.0f}}, + {{17.0f, 18.0f}, {19.0f, 20.0f}, {21.0f, 22.0f}, {23.0f, 24.0f}}}; + + EXPECT_EQ(arr.n1(), 3); + EXPECT_EQ(arr.n2(), 4); + EXPECT_EQ(arr.n3(), 2); + EXPECT_EQ(arr.num_elements(), 24); + + EXPECT_EQ(arr(0, 0, 0), static_cast(1)); + EXPECT_EQ(arr(0, 0, 1), static_cast(2)); + EXPECT_EQ(arr(0, 1, 0), static_cast(3)); + EXPECT_EQ(arr(0, 3, 1), static_cast(8)); + EXPECT_EQ(arr(1, 0, 0), static_cast(9)); + EXPECT_EQ(arr(1, 1, 1), static_cast(12)); + EXPECT_EQ(arr(2, 0, 0), static_cast(17)); + EXPECT_EQ(arr(2, 1, 1), static_cast(20)); + EXPECT_EQ(arr(2, 2, 0), static_cast(21)); + EXPECT_EQ(arr(2, 3, 1), static_cast(24)); +} + TEST(Array3dTest, Fill) { Array3D fullof7(2, 3, 4, 7); for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) { diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index f8b2b2afe5fed9c465c2a1f39308b7f44311b16a..cff70e54bad0116bdd08674b626b3bf99dc89e1f 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -82,6 +82,16 @@ class Array4D : public Array { values) : Array(values) {} + // Creates an array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array4D(std::initializer_list>>> + values) + : Array(values) {} + // Numerically-named aliases for the various dimensions. This matches the // dimension names used in array3d. int64 n4() const { return this->dim(3); } diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc index 3bc8148c911df0aeade364e4ac2e2ee828bacb53..927733ea1eab43feff643c35535cc6d9ea59ba5a 100644 --- a/tensorflow/compiler/xla/array4d_test.cc +++ b/tensorflow/compiler/xla/array4d_test.cc @@ -97,6 +97,36 @@ TEST(Array3dTest, InitializerListCtor) { EXPECT_EQ(arr(2, 3, 1, 0), 24); } +TEST(Array3dTest, InitializerListCtorHalf) { + Array4D arr = { + {{{1.0f}, {2.0f}}, {{3.0f}, {4.0f}}, {{5.0f}, {6.0f}}, {{7.0f}, {8.0f}}}, + {{{9.0f}, {10.0f}}, + {{11.0f}, {12.0f}}, + {{13.0f}, {14.0f}}, + {{15.0f}, {16.0f}}}, + {{{17.0f}, {18.0f}}, + {{19.0f}, {20.0f}}, + {{21.0f}, {22.0f}}, + {{23.0f}, {24.0f}}}}; + + EXPECT_EQ(arr.n1(), 3); + EXPECT_EQ(arr.n2(), 4); + EXPECT_EQ(arr.n3(), 2); + EXPECT_EQ(arr.n4(), 1); + EXPECT_EQ(arr.num_elements(), 24); + + EXPECT_EQ(arr(0, 0, 0, 0), static_cast(1)); + EXPECT_EQ(arr(0, 0, 1, 0), static_cast(2)); + EXPECT_EQ(arr(0, 1, 0, 0), static_cast(3)); + EXPECT_EQ(arr(0, 3, 1, 0), static_cast(8)); + EXPECT_EQ(arr(1, 0, 0, 0), static_cast(9)); + EXPECT_EQ(arr(1, 1, 1, 0), static_cast(12)); + EXPECT_EQ(arr(2, 0, 0, 0), static_cast(17)); + EXPECT_EQ(arr(2, 1, 1, 0), static_cast(20)); + EXPECT_EQ(arr(2, 2, 0, 0), static_cast(21)); + EXPECT_EQ(arr(2, 3, 1, 0), static_cast(24)); +} + TEST(Array4dTest, Fill) { Array4D fullof7(2, 3, 4, 5, 7); fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc index 8b9419477479d952126fd831eb44899e7649ca71..e8356c9832d34135f5ffb1a5c7a9d6db6db3a051 100644 --- a/tensorflow/compiler/xla/array_test.cc +++ b/tensorflow/compiler/xla/array_test.cc @@ -60,6 +60,25 @@ TEST(ArrayTest, InitializerListCtor) { EXPECT_EQ(arr(1, 2), 6); } +TEST(ArrayTest, InitializerListCtorHalf) { + Array d2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + EXPECT_EQ(d2.dim(0), 2); + EXPECT_EQ(d2.dim(1), 3); + + Array d3({{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}}); + EXPECT_EQ(d3.dim(0), 3); + EXPECT_EQ(d3.dim(1), 2); + EXPECT_EQ(d3.dim(2), 1); + + Array d4( + {{{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}}, + {{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}}}); + EXPECT_EQ(d4.dim(0), 2); + EXPECT_EQ(d4.dim(1), 3); + EXPECT_EQ(d4.dim(2), 2); + EXPECT_EQ(d4.dim(3), 1); +} + TEST(ArrayTest, IndexingReadWrite) { Array arr({2, 3}); diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index d6b4ebfc39ae039ff27fe9fb8a3487c870832f3e..02356699a25e47be50eb15872df4c9c302fc289b 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -80,6 +80,18 @@ cc_library( ], ) +cc_library( + name = "executable_build_options", + srcs = ["executable_build_options.cc"], + hdrs = ["executable_build_options.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/core:lib", + ], +) + cc_library( name = "local_client", srcs = ["local_client.cc"], @@ -87,6 +99,7 @@ cc_library( deps = [ ":client", ":computation", + ":executable_build_options", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -98,6 +111,7 @@ cc_library( "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:source_map_util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:support", diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc index 4baea8df6e3331200ee52f500fb7b961428e56be..e6c57bda0f0c4cb969939883efebcf3a6d6be381 100644 --- a/tensorflow/compiler/xla/client/computation.cc +++ b/tensorflow/compiler/xla/client/computation.cc @@ -64,4 +64,14 @@ void Computation::ResetWithoutFreeing() { parent_ = nullptr; } +StatusOr Computation::GetProgramShape() const { + GetComputationShapeRequest request; + *request.mutable_computation() = handle_; + GetComputationShapeResponse response; + + TF_RETURN_IF_ERROR(parent_->GetComputationShape(&request, &response)); + + return std::move(*response.mutable_program_shape()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h index b595172486950bf08b057625d7b2dd97ac9b2278..a53fc9e9cf34704bd08ddb5bf062c1ec1107f5fb 100644 --- a/tensorflow/compiler/xla/client/computation.h +++ b/tensorflow/compiler/xla/client/computation.h @@ -60,6 +60,10 @@ class Computation { // Returns true if this object is a null Computation. bool IsNull() const { return parent_ == nullptr; } + // Returns the "program shape" (parameter and return shapes) for this + // computation. + StatusOr GetProgramShape() const; + private: void ResetWithoutFreeing(); diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 46f2ed4836eda6bf6d5b68f2e29ac6888cd1749b..2a6e02649d15bc9fd47a893c41f9c8a62ac076c6 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -233,6 +233,26 @@ StatusOr> ComputationBuilder::GetShape( return status_or_shape; } +StatusOr ComputationBuilder::GetProgramShape() { + TF_RETURN_IF_ERROR(first_error_); + + GetComputationShapeRequest request; + *request.mutable_computation() = computation_.handle(); + GetComputationShapeResponse response; + + VLOG(2) << "making get-program-shape-request"; + Status status = client_->stub()->GetComputationShape(&request, &response); + VLOG(2) << "done with get-program-shape-request"; + + if (!status.ok()) { + first_error_ = status; + return status; + } + + TF_RET_CHECK(response.has_program_shape()); + return std::move(*response.mutable_program_shape()); +} + ComputationDataHandle ComputationBuilder::CheckShape( const ComputationDataHandle& operand, const Shape& expected_shape) { std::unique_ptr actual_shape = GetShape(operand).ConsumeValueOrDie(); @@ -769,6 +789,20 @@ ComputationDataHandle ComputationBuilder::CustomCall( return RunOpAndParseResponse(&op_request); } +ComputationDataHandle ComputationBuilder::HostCompute( + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { + OpRequest op_request; + HostComputeRequest* request = op_request.mutable_host_compute_request(); + for (const ComputationDataHandle& operand : operands) { + *request->add_operands() = operand; + } + *request->mutable_shape() = shape; + request->set_channel_name(channel_name); + request->set_cost_estimate_ns(cost_estimate_ns); + return RunOpAndParseResponse(&op_request); +} + ComputationDataHandle ComputationBuilder::Complex( const ComputationDataHandle& real, const ComputationDataHandle& imag, tensorflow::gtl::ArraySlice broadcast_dimensions) { @@ -1200,6 +1234,22 @@ ComputationDataHandle ComputationBuilder::While( return RunOpAndParseResponse(&op_request); } +ComputationDataHandle ComputationBuilder::Gather( + const ComputationDataHandle& input, + const ComputationDataHandle& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + OpRequest op_request; + GatherRequest* gather_request = op_request.mutable_gather_request(); + *gather_request->mutable_input() = input; + *gather_request->mutable_gather_indices() = gather_indices; + *gather_request->mutable_dimension_numbers() = dimension_numbers; + for (int64 window_bound : window_bounds) { + gather_request->add_window_bounds(window_bound); + } + return RunOpAndParseResponse(&op_request); +} + ComputationDataHandle ComputationBuilder::Conditional( const ComputationDataHandle& predicate, const ComputationDataHandle& true_operand, diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index ea4cdb76673b1c99036224bcd754ce4fe1360945..e3facb3f258bb0bdf4b2dd3648e55421dfd56e79 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -101,6 +101,9 @@ class ComputationBuilder { StatusOr> GetShape( const ComputationDataHandle& operand); + // Retrieves the (inferred) result for the current computation's shape. + StatusOr GetProgramShape(); + // Checks that the operand has the given expected shape. Returns the operand // if yes, fails with a CHECK error if no. ComputationDataHandle CheckShape(const ComputationDataHandle& operand, @@ -443,6 +446,16 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice operands, const Shape& shape); + // Enqueues a pseudo-op to represent host-side computation data-dependencies. + // During code generation, host send and receive operations will be generated + // to transfer |operands| to the host and a single result of |shape| back to + // the device. Host send/recv operations are emitted using |channel_name|. + // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO + // instruction scheduling. + ComputationDataHandle HostCompute( + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, const Shape& shape); + // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one // of the operands is a scalar, or an explicit broadcast dimension is given @@ -705,6 +718,13 @@ class ComputationBuilder { const int exponent_bits, const int mantissa_bits); + // Enqueues a Gather node onto the computation. + ComputationDataHandle Gather( + const ComputationDataHandle& input, + const ComputationDataHandle& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + // Enqueues a Send node onto the computation, to send the given operand to // a Recv instruction that shares the same channel handle. void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc new file mode 100644 index 0000000000000000000000000000000000000000..804e34f5e75ce2d153ac7627b94a543fda88e810 --- /dev/null +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/executable_build_options.h" + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { + +ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator( + DeviceMemoryAllocator* allocator) { + device_allocator_ = allocator; + return *this; +} + +DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const { + return device_allocator_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( + int device_ordinal) { + CHECK_GE(device_ordinal, 0); + device_ordinal_ = device_ordinal; + return *this; +} + +int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; } + +ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( + const Shape& shape_with_layout) { + result_layout_set_ = true; + result_layout_ = shape_with_layout; + return *this; +} + +const Shape* ExecutableBuildOptions::result_layout() const { + return result_layout_set_ ? &result_layout_ : nullptr; +} + +string ExecutableBuildOptions::ToString() const { + string result_layout = "nullopt"; + if (result_layout_set_) { + result_layout = ShapeUtil::HumanStringWithLayout(result_layout_); + } + string generate_hlo_graph = "nullopt"; + if (generate_hlo_graph_.has_value()) { + generate_hlo_graph = generate_hlo_graph_.value(); + } + return tensorflow::strings::Printf( + "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " + "generate_hlo_graph=%s}", + device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str()); +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( + string regex) { + generate_hlo_graph_ = std::move(regex); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::generate_hlo_graph() const { + return generate_hlo_graph_; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h new file mode 100644 index 0000000000000000000000000000000000000000..3a52dbac9adb155ad9a7d91a8102707f70fe2fbf --- /dev/null +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// Class containing options for building an LocalExecutable with +// LocalClient::Compile. +class ExecutableBuildOptions { + public: + // If set, this is the device to build the computation for. Valid + // device_ordinal values are: 0 to # of devices - 1. These values are + // identical to the device ordinal values used by StreamExecutor. The built + // executable will be executable on any device equivalent to the specified + // device as determined by Backend::devices_equivalent(). A value of -1 + // indicates this option has not been set. + ExecutableBuildOptions& set_device_ordinal(int device_ordinal); + int device_ordinal() const; + + // If set, this specifies the layout of the result of the computation. If not + // set, the service will chose the layout of the result. A Shape is used to + // store the layout to accommodate tuple result shapes. A value of nullptr + // indicates the option has not been set. + ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); + const Shape* result_layout() const; + + // If set, this specifies an allocator that can be used to allocate temporary + // space on the device during compilation. For example, the compiler might + // want to run various algorithms on the device and pick the fastest one -- it + // might allocate buffers for use by these algorithms using this allocator. + // + // This does not need to be the same as the DeviceMemoryAllocator passed when + // running the executable. + ExecutableBuildOptions& set_device_allocator( + DeviceMemoryAllocator* allocator); + DeviceMemoryAllocator* device_allocator() const; + + // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions). + ExecutableBuildOptions& set_generate_hlo_graph(string regex); + const tensorflow::gtl::optional& generate_hlo_graph() const; + + // Returns a string representation of the build options, suitable for + // debugging. + string ToString() const; + + private: + int device_ordinal_ = -1; + Shape result_layout_; + bool result_layout_set_ = false; + tensorflow::gtl::optional generate_hlo_graph_; + DeviceMemoryAllocator* device_allocator_ = nullptr; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 5f2b55713e342aa3d0251386d57cb52481fe748d..b63a1465ea755b906853860d47768ecbeaa0dcdd 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -31,14 +31,43 @@ limitations under the License. namespace xla { namespace { +// Calculates the number of bytes required to store the data within the +// specified shape. In case of a (nested) tuple shape this is the total byte +// size of all sub-shapes within the tuple. +int64 DataSizeOfShape(const Shape& shape) { + if (ShapeUtil::IsArray(shape)) { + return ShapeUtil::ByteSizeOf(shape); + } + + int64 total_size = 0; + for (const Shape& s : shape.tuple_shapes()) { + total_size += DataSizeOfShape(s); + } + return total_size; +} + +// Create a ComputationDataHandle for an op what generates fake data with the +// given shape. +ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, + ComputationBuilder* builder) { + if (ShapeUtil::IsArray(shape)) { + return builder->Broadcast( + builder->ConstantLiteral(Literal::One(shape.element_type())), + AsInt64Slice(shape.dimensions())); + } + std::vector parts; + for (const Shape& s : shape.tuple_shapes()) { + parts.push_back(BuildFakeDataOpOnDevice(s, builder)); + } + return builder->Tuple(parts); +} + std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { ComputationBuilder b( client, tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); - // TODO(b/26811613): Replace this when RNG is supported on all backends. - b.Broadcast(b.ConstantLiteral(Literal::One(shape.element_type())), - AsInt64Slice(shape.dimensions())); + BuildFakeDataOpOnDevice(shape, &b); Computation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); @@ -51,7 +80,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { - if (ShapeUtil::ByteSizeOf(shape) < (1LL << 20)) { + if (DataSizeOfShape(shape) < (1LL << 20)) { StatusOr> literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 523169fdd266d445c9d0d056ba20091f77610ad9..91396f055fe4a3ecbd436139be9470e2a35e1c63 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -21,30 +21,14 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" +#include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/status_macros.h" namespace se = ::perftools::gputools; -namespace xla { - -ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( - int device_ordinal) { - device_ordinal_ = device_ordinal; - return *this; -} - -int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; } +using xla::source_map_util::InvalidParameterArgument; -ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( - const Shape& shape_with_layout) { - result_layout_set_ = true; - result_layout_ = shape_with_layout; - return *this; -} - -const Shape* ExecutableBuildOptions::result_layout() const { - return result_layout_set_ ? &result_layout_ : nullptr; -} +namespace xla { namespace { StatusOr BorrowStreamForDevice(int device_ordinal, @@ -57,16 +41,18 @@ StatusOr BorrowStreamForDevice(int device_ordinal, } // namespace LocalExecutable::LocalExecutable(std::unique_ptr executable, - Backend* backend, int device_ordinal, - const ExecutableBuildOptions& build_options) + Backend* backend, + ExecutableBuildOptions build_options) : executable_(std::move(executable)), backend_(backend), - build_device_ordinal_(device_ordinal), - build_options_(build_options) {} + build_options_(std::move(build_options)) { + CHECK_GE(build_options_.device_ordinal(), 0) + << "Must have a valid device ordinal that the executable was built for."; +} tensorflow::Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options, const Backend& backend) { + const ExecutableRunOptions& run_options, const Backend& backend) { const ComputationLayout& computation_layout = executable_->module_config().entry_computation_layout(); @@ -79,9 +65,10 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( for (int i = 0; i < arguments.size(); ++i) { if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( arguments[i]->on_host_shape())) { - return InvalidArgument( - "argument does not match shape or layout of computation parameter " - "%d: expected %s, got %s", + return InvalidParameterArgument( + executable_.get(), i, + "Argument does not match shape or layout of computation parameter " + "%d: want %s, got %s", i, ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) .c_str(), @@ -89,14 +76,14 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( } } - if (options.stream() != nullptr) { - if (!options.stream()->ok()) { + if (run_options.stream() != nullptr) { + if (!run_options.stream()->ok()) { return InvalidArgument("stream is uninitialized or in an error state"); } // Check stream matches service platform. const se::Platform* stream_platform = - options.stream()->parent()->platform(); + run_options.stream()->parent()->platform(); if (stream_platform != backend_->platform()) { return InvalidArgument( "stream is for platform %s, but service targets platform %s", @@ -106,7 +93,7 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( // Cannot specify device_ordinal with a stream. The stream determines these // values. - if (options.device_ordinal() != -1) { + if (run_options.device_ordinal() != -1) { return InvalidArgument( "cannot set both device ordinal and stream options in " "ExecutableRunOptions; the stream determines the device ordinal"); @@ -115,34 +102,34 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( // Verify that the device the executable was built for is equivalent to the // device it will run on. - int run_device_ordinal = options.device_ordinal() == -1 + int run_device_ordinal = run_options.device_ordinal() == -1 ? backend_->default_device_ordinal() - : options.device_ordinal(); - TF_ASSIGN_OR_RETURN( - bool devices_equivalent, - backend_->devices_equivalent(run_device_ordinal, build_device_ordinal_)); + : run_options.device_ordinal(); + TF_ASSIGN_OR_RETURN(bool devices_equivalent, + backend_->devices_equivalent( + run_device_ordinal, build_options_.device_ordinal())); if (!devices_equivalent) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor, backend_->stream_executor(run_device_ordinal)); TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor, - backend_->stream_executor(build_device_ordinal_)); + backend_->stream_executor(build_device_ordinal())); return InvalidArgument( "executable is built for device %s of type \"%s\"; cannot run it on " "device %s of type \"%s\"", - backend_->device_name(build_device_ordinal_).c_str(), + backend_->device_name(build_device_ordinal()).c_str(), build_executor->GetDeviceDescription().name().c_str(), backend_->device_name(run_device_ordinal).c_str(), run_executor->GetDeviceDescription().name().c_str()); } - if (!options.allocator()) { + if (!run_options.allocator()) { return InvalidArgument("an allocator must be provided to ExecuteLocally"); } - if (options.allocator()->platform() != backend.platform()) { + if (run_options.allocator()->platform() != backend.platform()) { return InvalidArgument( "allocator platform (%s) does not match service platform (%s)", - options.allocator()->platform()->Name().c_str(), + run_options.allocator()->platform()->Name().c_str(), backend.platform()->Name().c_str()); } @@ -151,23 +138,22 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( StatusOr> LocalExecutable::Run( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options) { - TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_)); - - ExecutableRunOptions actual_options = options; + ExecutableRunOptions run_options) { + TF_RETURN_IF_ERROR( + ValidateExecutionOptions(arguments, run_options, *backend_)); Backend::StreamPtr stream; - if (options.stream() == nullptr) { + if (run_options.stream() == nullptr) { // NB! The lifetime of `stream` needs to match the lifetime of // `actual_options` (otherwise we will end up using a returned stream in // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if" // scope. TF_ASSIGN_OR_RETURN( - stream, BorrowStreamForDevice(options.device_ordinal(), backend_)); - actual_options.set_stream(stream.get()); + stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_)); + run_options.set_stream(stream.get()); } - if (options.allocator() == nullptr) { - actual_options.set_allocator(backend_->memory_allocator()); + if (run_options.allocator() == nullptr) { + run_options.set_allocator(backend_->memory_allocator()); } // For local client execution on CPU backends: @@ -176,7 +162,7 @@ StatusOr> LocalExecutable::Run( // *) The thread pool used for XLA CPU ops is from // backend_->eigen_intra_op_thread_pool(). ServiceExecutableRunOptions service_options( - actual_options, backend_->StreamBorrower(), + run_options, backend_->StreamBorrower(), backend_->eigen_intra_op_thread_pool()); if (executable_->dumping()) { @@ -185,9 +171,10 @@ StatusOr> LocalExecutable::Run( TF_ASSIGN_OR_RETURN( std::unique_ptr result, executable_->ExecuteOnStreamWrapper( - &service_options, options.execution_profile(), arguments)); - return ScopedShapedBuffer::MakeScoped(result.get(), - actual_options.allocator()); + &service_options, run_options.execution_profile(), arguments)); + + return MakeUnique(std::move(*result), + run_options.allocator()); } StatusOr> LocalExecutable::ExecuteAndDump( @@ -263,16 +250,19 @@ StatusOr> LocalClient::Compile( const Computation& computation, const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options) { - int device_ordinal = options.device_ordinal() == -1 - ? default_device_ordinal() - : options.device_ordinal(); - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - local_service_->CompileExecutable( - computation.handle(), argument_layouts, - options.result_layout(), device_ordinal)); + ExecutableBuildOptions updated_options = options; + if (options.device_ordinal() == -1) { + updated_options.set_device_ordinal(default_device_ordinal()); + VLOG(3) << "Set device ordinal to default value of: " + << updated_options.device_ordinal(); + } + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + local_service_->CompileExecutable(computation.handle(), argument_layouts, + updated_options)); return WrapUnique(new LocalExecutable(std::move(executable), local_service_->mutable_backend(), - device_ordinal, options)); + updated_options)); } StatusOr> diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 19fd14f76bc69d528193f7981a51a305f03f987e..b52a30f5a0b92e0094e6b0de3241c10a5a909cad 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -33,39 +34,13 @@ limitations under the License. namespace xla { -// Class containing options for building an LocalExecutable with -// LocalClient::Compile. -class ExecutableBuildOptions { - public: - // If set, this is the device to build the computation for. Valid - // device_ordinal values are: 0 to # of devices - 1. These values are - // identical to the device ordinal values used by StreamExecutor. The built - // executable will be executable on any device equivalent to the specified - // device as determined by Backend::devices_equivalent(). A value of -1 - // indicates this option has not been set. - ExecutableBuildOptions& set_device_ordinal(int device_ordinal); - int device_ordinal() const; - - // If set, this specifies the layout of the result of the computation. If not - // set, the service will chose the layout of the result. A Shape is used to - // store the layout to accommodate tuple result shapes. A value of nullptr - // indicates the option has not been set. - ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); - const Shape* result_layout() const; - - private: - int device_ordinal_ = -1; - Shape result_layout_; - bool result_layout_set_ = false; -}; - class LocalExecutable { public: // Run the compiled computation with the given arguments and options and // return the result. StatusOr> Run( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options); + ExecutableRunOptions run_options); // Return the layout (contained in a shape) of the result produced by the // computation. @@ -88,8 +63,7 @@ class LocalExecutable { // Constructor invoked by LocalClient. LocalExecutable(std::unique_ptr executable, Backend* backend, - int device_ordinal, - const ExecutableBuildOptions& build_options); + ExecutableBuildOptions build_options); // Validates that the given arguments and options satisfy various constraints // of the computation. @@ -117,19 +91,19 @@ class LocalExecutable { StatusOr> LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer); + // The ordinal of the device which this executable was compiled for. The + // executable can run on all equivalent devices (as determined by + // Backend::devices_equivalent). + int build_device_ordinal() const { return build_options_.device_ordinal(); } + // Compiled computation. std::unique_ptr executable_; // Execution backend. - Backend* backend_; - - // The ordinal of the device which this executable was compiled for. The - // executable can run on all equivalent devices (as determined by - // Backend::devices_equivalent). - int build_device_ordinal_; + Backend* backend_ = nullptr; // Options used to build the executable. - const ExecutableBuildOptions& build_options_; + const ExecutableBuildOptions build_options_; }; // An XLA Client specialization for use when the client and service run in diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 0b9188e8524d6f1367541496dc5a86a250a0d530..142006f2626e83d3254f2de65fc28fd5d6694e53 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -37,7 +37,7 @@ class IndexUtil { static int64 MultidimensionalIndexToLinearIndex( const Shape& shape, tensorflow::gtl::ArraySlice multi_index); - // Coverts a linear index into multidimensional index (eg {x, y, z}) based on + // Converts a linear index into multidimensional index (eg {x, y, z}) based on // the shape and its layout. The first index in the returned multidimensional // index is dimension 0. static std::vector LinearIndexToMultidimensionalIndex( diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index fe3a4d2f6df47d9f156529e55198a5f339bc8e3c..c8ed3e3a2b009ddffdfb79a9a6ced8d5e736bee6 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -221,13 +221,19 @@ void AllocateFlags() { flag_values->xla_gpu_disable_multi_streaming(), "If true, multi-streaming in the GPU backend is disabled."), tensorflow::Flag( - "xla_dump_hlo_proto_to", flag_values->mutable_xla_dump_hlo_proto_to(), - "Dump compilation artifacts as proto binary into this directory."), + "xla_dump_optimized_hlo_proto_to", + flag_values->mutable_xla_dump_optimized_hlo_proto_to(), + "Dump Hlo after all hlo passes are executed as proto binary into " + "this directory."), tensorflow::Flag( - "xla_dump_prepass_hlo_proto_to", - flag_values->mutable_xla_dump_prepass_hlo_proto_to(), - "Dump compilation artifacts, before hlo passes are executed, as " - "proto binary into this directory."), + "xla_dump_unoptimized_hlo_proto_to", + flag_values->mutable_xla_dump_unoptimized_hlo_proto_to(), + "Dump HLO before any hlo passes are executed as proto binary into " + "this directory."), + tensorflow::Flag("xla_dump_per_pass_hlo_proto_to", + flag_values->mutable_xla_dump_per_pass_hlo_proto_to(), + "Dump HLO after each pass as an HloProto in binary file " + "format into this directory."), tensorflow::Flag( "xla_test_all_output_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 89279b659c75ce4775581dfbfa8d830f54ae6fe8..e0a9b148b443e90a0c4f3e19660b6234d49eef84 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -234,7 +234,8 @@ Status Literal::CopySliceFromInternal( int64 src_index = linear_index(src_literal.shape(), src_indexes); int64 dest_index = linear_index(shape(), dest_indexes); - StridedCopy(data(), dest_index, stride_config.dest_stride, + // `this->` is needed to workaround MSVC bug: #16882 + StridedCopy(this->data(), dest_index, stride_config.dest_stride, src_literal.data(), src_index, stride_config.source_stride, stride_config.minor_loop_size); return true; @@ -1257,11 +1258,17 @@ string Literal::ToString(bool print_layout) const { /* static */ std::unique_ptr Literal::MakeTupleOwned( std::vector> elements) { - std::vector element_ptrs; + std::vector element_shapes; + element_shapes.reserve(elements.size()); for (const auto& element : elements) { - element_ptrs.push_back(element.get()); + element_shapes.push_back(element->shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int64 i = 0; i < elements.size(); ++i) { + TF_CHECK_OK( + literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); } - return MakeTuple(element_ptrs); + return literal; } void Literal::EachCellAsString( diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e0196509a7483abac3d9c0e59a54b591a327b980..d996004888ab521790b4c5a10da2a93f0d98d12f 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -485,7 +485,29 @@ class Literal { static std::unique_ptr MakeTupleOwned( std::vector> elements); + // This overload lets you pass a braced list of unique_ptrs to + // MakeTupleOwned: + // + // Literal::MakeTupleOwned(Literal::CreateR1(...), ...). + // + // Simply relying on the MakeTupleOwned(std::vector>) + // overload doesn't work because std::initializer_list's elements are always + // const. + // + // The arguments to this function must all be unique_ptr. + template + static std::unique_ptr MakeTupleOwned( + std::unique_ptr... elements) { + std::array, sizeof...(Ts)> arr{ + std::move(elements)...}; + std::vector> v; + v.insert(v.begin(), std::make_move_iterator(arr.begin()), + std::make_move_iterator(arr.end())); + return MakeTupleOwned(std::move(v)); + } + // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. string ToString(bool print_layout = false) const; // Invokes the "per cell" callback for each element in the provided diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 50659c12405f2a29c69b03b3c7de5bd6cb6af9c2..8db8c6f3de84a6c46625eadbb6b0f83d2262e5f7 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -16,6 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -44,6 +49,41 @@ typename Collection::value_type::second_type& FindOrDie( return it->second; } +// Like FindOrDie but returns an error instead of dying if `key` is not in +// `container`. +template +StatusOr< + std::reference_wrapper> +MaybeFind(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + std::ostringstream os; + os << key; + return NotFound("key not found: %s", os.str().c_str()); + } + return {it->second}; +} + +// Returns a const reference to the value associated with the given key if it +// exists, otherwise returns a const reference to the provided default value. +// +// WARNING: If a temporary object is passed as the default "value," +// this function will return a reference to that temporary object, +// which will be destroyed at the end of the statement. A common +// example: if you have a map with string values, and you pass a char* +// as the default "value," either use the returned value immediately +// or store it in a string (not string&). +template +const typename Collection::value_type::second_type& FindOrDefault( + const Collection& collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + auto it = collection.find(key); + if (it != collection.end()) return it->second; + return value; +} + // Inserts the key-value pair into the collection. Dies if key was already // present. template diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index a8ca0e3ea0115d412e96ebacb320cc0dde061dff..e2972f06016ab3555c4fc0cc4616993fe6764b1e 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -49,6 +49,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 37f1eada2bc9f5ef72d99a835a17b4e78a354ae6..cb7bb21e092c80d7360c23f3d6b00409a75dce23 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -98,15 +98,25 @@ const std::unique_ptr& LocalShapedBuffer::shaped_buffer() return shaped_buffer_; } +static StatusOr> ToBuffer( + LocalClient* client, int device_ordinal, const Literal& arg) { + return client->LiteralToShapedBuffer(arg, device_ordinal, + client->backend().memory_allocator()); +} + /* static */ -LocalShapedBuffer* LocalShapedBuffer::FromLiteral(const Literal& argument) { +LocalShapedBuffer* LocalShapedBuffer::FromLiteral( + const Literal& argument, + const tensorflow::gtl::optional& shape_with_layout) { LocalClient* client = GetOrCreateLocalClient(); - std::unique_ptr buf = - client - ->LiteralToShapedBuffer(argument, - /*device_ordinal=*/0, - client->backend().memory_allocator()) - .ConsumeValueOrDie(); + std::unique_ptr buf; + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + buf = ToBuffer(client, /*device_ordinal=*/0, *relaid).ConsumeValueOrDie(); + } else { + buf = ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie(); + } return new LocalShapedBuffer(std::move(buf)); } @@ -120,7 +130,8 @@ CompiledLocalComputation::CompiledLocalComputation( : executable_(std::move(executable)) {} StatusOr> CompiledLocalComputation::Execute( - const std::vector& arguments) { + const std::vector& arguments, + const std::vector>& shapes_with_layout) { LocalClient* client = GetOrCreateLocalClient(); VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; @@ -133,7 +144,8 @@ StatusOr> CompiledLocalComputation::Execute( GetReplicaCount()); for (int replica = 0; replica < GetReplicaCount(); ++replica) { - pool.Schedule([this, client, replica, &arguments, &results] { + pool.Schedule([this, client, replica, &arguments, &shapes_with_layout, + &results] { StatusOr device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(replica); if (!device_ordinal_status.ok()) { @@ -144,18 +156,28 @@ StatusOr> CompiledLocalComputation::Execute( VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; + // Transfer arguments in std::vector> scoped_buffers; scoped_buffers.reserve(arguments.size()); - for (const Literal& argument : arguments) { - StatusOr> pushed = - client->LiteralToShapedBuffer( - argument, device_ordinal, - client->backend().memory_allocator()); + for (int i = 0; i < arguments.size(); ++i) { + const Literal& argument = arguments[i]; + const tensorflow::gtl::optional& shape_with_layout = + shapes_with_layout[i]; + + StatusOr> pushed; + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, *relaid); + } else { + pushed = ToBuffer(client, device_ordinal, argument); + } if (!pushed.ok()) { results[replica] = pushed.status(); return; } + scoped_buffers.push_back(std::move(pushed).ValueOrDie()); } @@ -233,7 +255,8 @@ LocalComputation::LocalComputation(Computation computation) : computation_(std::move(computation)) {} StatusOr LocalComputation::Compile( - const std::vector& argument_shapes) { + const std::vector& argument_shapes, + const ExecutableBuildOptions* build_options) { std::vector argument_shape_pointers; argument_shape_pointers.reserve(argument_shapes.size()); for (auto& argument_shape : argument_shapes) { @@ -242,6 +265,9 @@ StatusOr LocalComputation::Compile( LocalClient* client = GetOrCreateLocalClient(); ExecutableBuildOptions options; + if (build_options != nullptr) { + options = *build_options; + } TF_ASSIGN_OR_RETURN( auto local_executable, client->Compile(computation_, argument_shape_pointers, options)); @@ -252,6 +278,12 @@ const Computation& LocalComputation::computation() const { return computation_; } +StatusOr LocalComputation::GetReturnValueShape() const { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation_.GetProgramShape()); + return std::move(*program_shape.mutable_result()); +} + LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) : builder_(GetOrCreateLocalClient(), computation_name) {} @@ -277,6 +309,11 @@ std::unique_ptr LocalComputationBuilder::GetShape( return builder_.GetShape(operand).ConsumeValueOrDie(); } +StatusOr LocalComputationBuilder::GetReturnValueShape() { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); + return program_shape.result(); +} + ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } @@ -363,12 +400,6 @@ LocalComputationBuilder::SelectAndScatterWithGeneralPadding( source, init_value, scatter.computation()); } -ComputationDataHandle LocalComputationBuilder::Select( - const ComputationDataHandle& pred, const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false) { - return builder_.Select(pred, on_true, on_false); -} - ComputationDataHandle LocalComputationBuilder::Tuple( tensorflow::gtl::ArraySlice elements) { return builder_.Tuple(elements); @@ -384,6 +415,12 @@ ComputationDataHandle LocalComputationBuilder::Dot( return builder_.Dot(lhs, rhs); } +ComputationDataHandle LocalComputationBuilder::DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers) { + return builder_.DotGeneral(lhs, rhs, dimension_numbers); +} + ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice window_strides, @@ -467,6 +504,17 @@ ComputationDataHandle LocalComputationBuilder::While( return builder_.While(condition.computation(), body.computation(), init); } +ComputationDataHandle LocalComputationBuilder::Conditional( + const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const LocalComputation& true_computation, + const ComputationDataHandle& false_operand, + const LocalComputation& false_computation) { + return builder_.Conditional(predicate, true_operand, + true_computation.computation(), false_operand, + false_computation.computation()); +} + #define _FORWARD(method_name, return_sig, args_sig, args) \ return_sig LocalComputationBuilder::method_name args_sig { \ return builder_.method_name args; \ @@ -483,6 +531,15 @@ ComputationDataHandle LocalComputationBuilder::While( tensorflow::gtl::ArraySlice broadcast_dimensions), \ (lhs, rhs, broadcast_dimensions)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + const ComputationDataHandle& ehs), \ + (lhs, rhs, ehs)) + +_FORWARD_TRIOP(Select) +_FORWARD_TRIOP(Clamp) _FORWARD_BINOP(Eq) _FORWARD_BINOP(Ne) _FORWARD_BINOP(Ge) @@ -503,6 +560,7 @@ _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) +_FORWARD_UNOP(Round) _FORWARD_UNOP(Log) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) @@ -519,6 +577,7 @@ _FORWARD_UNOP(Sort) #undef _FORWARD #undef _FORWARD_UNOP #undef _FORWARD_BINOP +#undef _FORWARD_TRIOP void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { delete local_shaped_buffer; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index e5503cd52fa60eff30eea38c83aafe0f0ff1efc8..d3e9503ea10b011520ec5148a756ef4d421f244c 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -58,7 +59,9 @@ StatusOr > TransferFromOutfeedLocalReplica( // client. class LocalShapedBuffer { public: - static LocalShapedBuffer* FromLiteral(const Literal& argument); + static LocalShapedBuffer* FromLiteral( + const Literal& argument, + const tensorflow::gtl::optional& shape_with_layout); LocalShapedBuffer(std::unique_ptr shaped_buffer); const std::unique_ptr& shaped_buffer() const; std::unique_ptr ToLiteral() const; @@ -76,8 +79,15 @@ class LocalShapedBuffer { class CompiledLocalComputation { public: CompiledLocalComputation(std::unique_ptr executable); + + // Execute the computation with the given argument literals, and + // with optionally-specified argument layouts. The literals will be + // re-laid out according to the corresponding elements of + // shapes_with_layout. StatusOr > Execute( - const std::vector& arguments); + const std::vector& arguments, + const std::vector >& shapes_with_layout); + LocalShapedBuffer* ExecuteWithShapedBuffers( tensorflow::gtl::ArraySlice argument_handles); @@ -92,10 +102,16 @@ class CompiledLocalComputation { class LocalComputation { public: LocalComputation(Computation computation); + StatusOr Compile( - const std::vector& argument_shapes); + const std::vector& argument_shapes, + const ExecutableBuildOptions* build_options); + const Computation& computation() const; + // Returns the return-value shape for this computation. + StatusOr GetReturnValueShape() const; + private: Computation computation_; }; @@ -122,6 +138,9 @@ class LocalComputationBuilder { std::unique_ptr GetShape(const ComputationDataHandle& operand); + // Returns the shape of the current return value for the computation. + StatusOr GetReturnValueShape(); + ComputationDataHandle Infeed(const Shape& shape); void Outfeed(const ComputationDataHandle& operand, const Shape& shape, @@ -172,10 +191,6 @@ class LocalComputationBuilder { const ComputationDataHandle& source, const ComputationDataHandle& init_value, const LocalComputation& scatter); - ComputationDataHandle Select(const ComputationDataHandle& pred, - const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false); - ComputationDataHandle Tuple( tensorflow::gtl::ArraySlice elements); @@ -185,6 +200,10 @@ class LocalComputationBuilder { ComputationDataHandle Dot(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs); + ComputationDataHandle DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers); + ComputationDataHandle ConvGeneralDilated( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice window_strides, @@ -239,6 +258,12 @@ class LocalComputationBuilder { const LocalComputation& body, const ComputationDataHandle& init); + ComputationDataHandle Conditional(const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const LocalComputation& true_computation, + const ComputationDataHandle& false_operand, + const LocalComputation& false_computation); + #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; @@ -252,6 +277,14 @@ class LocalComputationBuilder { (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ tensorflow::gtl::ArraySlice broadcast_dimensions)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + const ComputationDataHandle& ehs)) + + _FORWARD_TRIOP(Select) + _FORWARD_TRIOP(Clamp) _FORWARD_BINOP(Eq) _FORWARD_BINOP(Ne) _FORWARD_BINOP(Ge) @@ -272,6 +305,7 @@ class LocalComputationBuilder { _FORWARD_UNOP(Exp) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) + _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) @@ -288,6 +322,7 @@ class LocalComputationBuilder { #undef _FORWARD #undef _FORWARD_UNOP #undef _FORWARD_BINOP +#undef _FORWARD_TRIOP private: ComputationBuilder builder_; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 31789259609714e7d20247eec072e05a181715e6..456e341f877e529f7fc5ebc81d85862bfa291943 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -27,12 +27,14 @@ limitations under the License. // ArraySlice <- sequence of int // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray -// Shape <-> pair holding (dtype, dimensions) -// std::vector <- sequence of shape information pairs +// Shape -> pair holding (dtype, dimensions) +// <- object duck-typed as xla_client.Shape +// std::vector <- sequence of xla_client.Shape objects // PrimitiveType <- int // ArraySlice> <- sequence of int pairs // PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto +// DotDimensionNumbers proto <- corresponding Python proto // // Arrows indicate whether a conversion only ever occurs in one // direction, or whether it is maintained bidirectionally. @@ -55,7 +57,7 @@ limitations under the License. // translates to a tuple-shaped XLA Literal, whose component subshapes // are a 2x3 F32-shaped literal followed by two tuple-shaped literals. // -// The Python objects corresponding to C++ Shapes have the type: +// Shapes output by C++ become Python objects with the type: // // T = (dtype, S) // S = DIMENSIONS | TUPLE_SHAPES @@ -176,6 +178,16 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr< std::unique_ptr > { + if ($1.ok()) { + std::unique_ptr value = $1.ConsumeValueOrDie(); + $result = numpy::PyObjectFromXlaLiteral(*value); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + %typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); @@ -189,6 +201,15 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + %typemap(out) Status { if (!$1.ok()) { PyErr_SetString( @@ -343,15 +364,31 @@ tensorflow::ImportNumpy(); // Shape %typemap(in) const Shape& (Shape temp) { - Status shape_status = numpy::CheckPyShapeInfo($input); - if (!shape_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str()); + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); return NULL; } - temp = numpy::XlaShapeFromPyShapeInfo($input); + temp = std::move(statusor).ValueOrDie(); $1 = &temp; } +%typemap(in) const tensorflow::gtl::optional& ( + tensorflow::gtl::optional temp) { + if ($input == Py_None) { + temp = tensorflow::gtl::nullopt; + $1 = &temp; + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + return NULL; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; + } +} + %typemap(out) std::unique_ptr { $result = numpy::PyShapeInfoFromXlaShape(*$1); } @@ -364,14 +401,37 @@ tensorflow::ImportNumpy(); const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - Status shape_status = numpy::CheckPyShapeInfo(o); - if (!shape_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str()); - Py_DECREF(o); + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); return NULL; } - temps.push_back(numpy::XlaShapeFromPyShapeInfo(o)); - Py_DECREF(o); + temps.push_back(statusor.ConsumeValueOrDie()); + } + $1 = &temps; +} + +%typemap(in) const std::vector >& ( + std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (o == Py_None) { + temps.push_back(tensorflow::gtl::nullopt); + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + return NULL; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } } $1 = &temps; } @@ -461,6 +521,135 @@ tensorflow::ImportNumpy(); $1 = temps; } +// DotDimensionNumbers + +%typemap(in) const DotDimensionNumbers& + (DotDimensionNumbers dimension_numbers) { + int length; + + /* lhs_contracting_dimensions */ + PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( + $input, "lhs_contracting_dimensions"); + if (!lhs_contracting_dimensions) { + return NULL; + } + + length = PySequence_Size(lhs_contracting_dimensions); + if (length == -1) { + Py_DECREF(lhs_contracting_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); + if (!item) { + Py_DECREF(lhs_contracting_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(lhs_contracting_dimensions); + return NULL; + } + dimension_numbers.add_lhs_contracting_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(lhs_contracting_dimensions); + + /* rhs_contracting_dimensions */ + PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( + $input, "rhs_contracting_dimensions"); + if (!lhs_contracting_dimensions) { + return NULL; + } + + length = PySequence_Size(rhs_contracting_dimensions); + if (length == -1) { + Py_DECREF(rhs_contracting_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); + if (!item) { + Py_DECREF(rhs_contracting_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(rhs_contracting_dimensions); + return NULL; + } + dimension_numbers.add_rhs_contracting_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(rhs_contracting_dimensions); + + /* lhs_batch_dimensions */ + PyObject* lhs_batch_dimensions = PyObject_GetAttrString( + $input, "lhs_batch_dimensions"); + if (!lhs_batch_dimensions) { + return NULL; + } + + length = PySequence_Size(lhs_batch_dimensions); + if (length == -1) { + Py_DECREF(lhs_batch_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); + if (!item) { + Py_DECREF(lhs_batch_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(lhs_batch_dimensions); + return NULL; + } + dimension_numbers.add_lhs_batch_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(lhs_batch_dimensions); + + /* rhs_batch_dimensions */ + PyObject* rhs_batch_dimensions = PyObject_GetAttrString( + $input, "rhs_batch_dimensions"); + if (!rhs_batch_dimensions) { + return NULL; + } + + length = PySequence_Size(rhs_batch_dimensions); + if (length == -1) { + Py_DECREF(rhs_batch_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); + if (!item) { + Py_DECREF(rhs_batch_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(rhs_batch_dimensions); + return NULL; + } + dimension_numbers.add_rhs_batch_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(rhs_batch_dimensions); + + $1 = &dimension_numbers; +} + // PaddingConfig %typemap(in) const PaddingConfig& @@ -623,6 +812,45 @@ tensorflow::ImportNumpy(); $1 = &dimension_numbers; } +// ExecutableBuildOptions + +%typemap(in) const ExecutableBuildOptions* + (ExecutableBuildOptions build_options) { + if ($input == Py_None) { + $1 = NULL; + } else { + PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph"); + if (!o) { + return NULL; + } + if (o != Py_None) { + if (!PyString_Check(o)) { + PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None."); + return NULL; + } + build_options.set_generate_hlo_graph(PyString_AsString(o)); + } + Py_DECREF(o); + + o = PyObject_GetAttrString($input, "result_shape"); + if (o == nullptr) { + return nullptr; + } + if (o != Py_None) { + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); + Py_DECREF(o); + return NULL; + } + build_options.set_result_layout(statusor.ValueOrDie()); + } + Py_DECREF(o); + + $1 = &build_options; + } +} + %ignoreall %unignore xla; %unignore xla::swig; @@ -639,6 +867,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers; %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; +%unignore xla::swig::LocalComputation::GetReturnValueShape; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::Build; @@ -646,6 +875,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; %unignore xla::swig::LocalComputationBuilder::Parameter; %unignore xla::swig::LocalComputationBuilder::GetShape; +%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape; %unignore xla::swig::LocalComputationBuilder::Infeed; %unignore xla::swig::LocalComputationBuilder::Outfeed; %unignore xla::swig::LocalComputationBuilder::ConstantLiteral; @@ -667,6 +897,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Call; %unignore xla::swig::LocalComputationBuilder::Transpose; %unignore xla::swig::LocalComputationBuilder::Rev; +%unignore xla::swig::LocalComputationBuilder::Clamp; %unignore xla::swig::LocalComputationBuilder::Map; %unignore xla::swig::LocalComputationBuilder::Reduce; %unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding; @@ -674,6 +905,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::RngUniform; %unignore xla::swig::LocalComputationBuilder::RngBernoulli; %unignore xla::swig::LocalComputationBuilder::While; +%unignore xla::swig::LocalComputationBuilder::Conditional; %unignore xla::swig::LocalComputationBuilder::Eq; %unignore xla::swig::LocalComputationBuilder::Ne; %unignore xla::swig::LocalComputationBuilder::Ge; @@ -681,6 +913,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Lt; %unignore xla::swig::LocalComputationBuilder::Le; %unignore xla::swig::LocalComputationBuilder::Dot; +%unignore xla::swig::LocalComputationBuilder::DotGeneral; %unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; %unignore xla::swig::LocalComputationBuilder::Add; %unignore xla::swig::LocalComputationBuilder::Sub; @@ -696,6 +929,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Exp; %unignore xla::swig::LocalComputationBuilder::Floor; %unignore xla::swig::LocalComputationBuilder::Ceil; +%unignore xla::swig::LocalComputationBuilder::Round; %unignore xla::swig::LocalComputationBuilder::Log; %unignore xla::swig::LocalComputationBuilder::Sign; %unignore xla::swig::LocalComputationBuilder::Cos; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 5c722623e318ece9eca6bdc8750195ce5fd5defb..3d87480728aab1d4ebbc71c6c7504d37cae5edaf 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -176,85 +176,107 @@ static string PyObjectCppRepr(PyObject* o) { return ExtractStringAndDecref(r); } -Status CheckPyShapeInfo(PyObject* o) { +StatusOr XlaShapeFromPyShape(PyObject* o) { auto error = [o](const string& prefix) { return InvalidArgument("%s; got %s", prefix.c_str(), PyObjectCppRepr(o).c_str()); }; - // The object is a tuple (a pair) - if (!PyTuple_Check(o)) { - return error("Shape record must be a tuple"); - } - if (PyTuple_Size(o) != 2) { - return error("Shape record tuple must be of length 2"); - } - // It has a first element, which is a numpy dtype object - PyObject* first = PyTuple_GetItem(o, 0); - if (first == nullptr) { - return error("Tuple has no item 0 (shape dtype)"); - } - if (first->ob_type != &PyArrayDescr_Type) { - return error( - "Shape record does not have a numpy dtype as its first element"); - } - const int np_type = NumpyTypenum(first); - if (!NumpyTypeIsValid(np_type)) { - return error("Shape record has an invalid integer dtype"); - } + auto get_attr = [o, &error](const string& field) -> StatusOr { + PyObject* result = + PyObject_GetAttrString(o, const_cast(field.c_str())); + if (result == nullptr) { + return error(tensorflow::strings::StrCat( + "Failed to get attribute of Shape object:", field)); + } + return result; + }; - // It has a second element, which is a tuple, either of shape - // records or of Python ints - PyObject* second = PyTuple_GetItem(o, 1); - if (!second) { - return error("Tuple has no item 0 (shape dimensions)"); - } - if (!PyTuple_Check(second)) { - return error("Shape record does not have a tuple as its second element"); - } - const int length = PyTuple_Size(second); - const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type); - for (int i = 0; i < length; i++) { - PyObject* dimension = PyTuple_GetItem(second, i); - if (element_type == TUPLE) { - VLOG(3) << "element_type is tuple, checking member: " << i; - Status result = CheckPyShapeInfo(dimension); - if (!result.ok()) { - return AddStatus( - result, tensorflow::strings::StrCat("Validating tuple member ", i, - " of ", PyObjectCppRepr(o))); - } - } else if (!CheckPyIntOrLong(dimension)) { - return error("Non-tuple shape record has a non-integer dimension"); + auto call_method = [o, &error](const string& method) -> StatusOr { + PyObject* result = + PyObject_CallMethod(o, const_cast(method.c_str()), nullptr); + if (result == nullptr) { + return error(tensorflow::strings::StrCat( + "Failed to call method of shape object:", method)); } - } + return result; + }; - return Status::OK(); -} + PyObject* np_type; + TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype")); + if (np_type->ob_type != &PyArrayDescr_Type) { + return error("Shape attribute np_dtype is not an integer numpy dtype"); + } + if (!NumpyTypeIsValid(NumpyTypenum(np_type))) { + return error("Shape attribute np_dtype is not a valid integer numpy dtype"); + } + const PrimitiveType element_type = + NumpyTypeToPrimitiveType(NumpyTypenum(np_type)); + Py_DECREF(np_type); -// Precondition: CheckPyShapeInfo(o) -Shape XlaShapeFromPyShapeInfo(PyObject* o) { - const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0)); - const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type); - PyObject* py_dimensions = PyTuple_GetItem(o, 1); - const int length = PyTuple_Size(py_dimensions); if (element_type == TUPLE) { + PyObject* py_subshapes; + TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes")); + if (!PyTuple_Check(py_subshapes)) { + return error( + "Return value of Shape method tuple_shapes() is not a tuple"); + } + const int length = PyTuple_Size(py_subshapes); std::vector subshapes; subshapes.reserve(length); for (int i = 0; i < length; i++) { - subshapes.push_back( - XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i))); + TF_ASSIGN_OR_RETURN( + const Shape& subshape, + XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i))); + subshapes.push_back(subshape); } + Py_DECREF(py_subshapes); return ShapeUtil::MakeTupleShape(subshapes); } else { + PyObject* py_dimensions; + PyObject* py_minor_to_major; + TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions")); + TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major")); + if (!PyTuple_Check(py_dimensions)) { + return error("Return value of Shape method dimensions() is not a tuple"); + } + if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) { + return error( + "Return value of Shape method minor_to_major() is neither a tuple " + "nor None"); + } + const int length = PyTuple_Size(py_dimensions); + if (py_minor_to_major != Py_None && + length != PyTuple_Size(py_minor_to_major)) { + return error( + "Shape methods dimensions() and minor_to_major() return " + "different-length tuples"); + } std::vector dimensions(length); + std::vector minor_to_major(length); for (int i = 0; i < length; i++) { dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i)); - if (dimensions[i] == -1) { - CHECK(!PyErr_Occurred()); + if (dimensions[i] == -1 && PyErr_Occurred()) { + return error("Dimension is not an int"); } + + if (py_minor_to_major != Py_None) { + minor_to_major[i] = + PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i)); + if (minor_to_major[i] == -1 && PyErr_Occurred()) { + return error("Minor-to-major value is not an int"); + } + } + } + bool with_layout = py_minor_to_major != Py_None; + Py_DECREF(py_dimensions); + Py_DECREF(py_minor_to_major); + if (with_layout) { + return ShapeUtil::MakeShapeWithLayout(element_type, dimensions, + minor_to_major); + } else { + return ShapeUtil::MakeShape(element_type, dimensions); } - return ShapeUtil::MakeShape(element_type, dimensions); } } diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 6ff1c34cfc5e0323a6729bdfd5572239f4966211..adfcc3b8588dce01718bb19dea936bace483be4d 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -56,15 +56,11 @@ bool NumpyTypeIsValid(int np_type); // The return value is a new reference. PyObject* PyShapeInfoFromXlaShape(const Shape& shape); -// Returns the outcome of a best-effort check that the Python object -// is a pair of the form (numpy dtype, dimensions), as produced by -// PyShapeInfoFromXlaShape. -Status CheckPyShapeInfo(PyObject* o); - -// Performs the inverse conversion to that of PyShapeInfoFromXlaShape. +// Converts a Python object with a method interface mathing that of +// xla_client.Shape into an XLA Shape object. // // The return value is a new reference. -Shape XlaShapeFromPyShapeInfo(PyObject* o); +StatusOr XlaShapeFromPyShape(PyObject* o); // Converts a PyObject that represents operation metadata into protocol buffer // form. diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 9cfe1249f50fd3c4b09d5af0c0e17a6f40b024a2..9bda9d09294bc75acaa35d8e4a512820046e8920 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -36,15 +36,22 @@ from tensorflow.compiler.xla.python import pywrap_xla as c_api # pylint: disable=invalid-name -OpMetadata = collections.namedtuple( - 'OpMetadata', - [ - 'op_type', - 'op_name', - 'source_file', - 'source_line', - ], -) +_OP_METADATA_FIELDS = [ + 'op_type', + 'op_name', + 'source_file', + 'source_line', +] +OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) + + +def OpMetadataToProto(pyobj): + proto = xla_data_pb2.OpMetadata() + for field in _OP_METADATA_FIELDS: + attr = getattr(pyobj, field) + if attr is not None: + setattr(proto, field, attr) + return proto def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): @@ -82,6 +89,7 @@ _UNARY_OPS = [ 'Abs', 'Exp', 'Floor', + 'Round', 'Ceil', 'Log', 'Sign', @@ -148,9 +156,14 @@ class LocalBuffer(object): self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_py(npval): + def from_py(npval, layout_fn=None): npval = require_numpy_array_layout(npval) - return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval)) + if layout_fn: + shape = Shape.from_numpy(npval) + shape = shape.map_leaves(layout_fn) + else: + shape = None + return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape)) def to_py(self): return self.c_local_shaped_buffer.ToLiteral() @@ -175,13 +188,23 @@ class Shape(object): represents an XLA tuple. """ - def __init__(self, np_dtype, dimensions): + def __init__(self, np_dtype, dimensions, minor_to_major=None): + assert isinstance(dimensions, tuple) self.np_dtype = np_dtype self._dimensions = dimensions + self._minor_to_major = minor_to_major + self._check_minor_to_major() + + def __eq__(self, other): + # pylint: disable=protected-access + return (self.np_dtype == other.np_dtype and + self._dimensions == other._dimensions and + self._minor_to_major == other._minor_to_major) def __repr__(self): - return 'xla_client.Shape(np_dtype={!r}, dimensions={!r})'.format( - self.np_dtype, self._dimensions) + return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, ' + 'minor_to_major={!r})').format(self.np_dtype, self._dimensions, + self._minor_to_major) def element_type(self): return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)] @@ -194,11 +217,49 @@ class Shape(object): raise ValueError('Tuple shape has no dimensions') return self._dimensions + def minor_to_major(self): + return self._minor_to_major + def tuple_shapes(self): if not self.is_tuple(): raise ValueError('Shape is not a tuple shape') return self._dimensions + def rank(self): + return len(self.dimensions()) + + def map_leaves(self, f): + """Map f over each leaf-level array subshape. + + Args: + f: The function to apply. Whenever f returns None, the identity is + applied instead. + + Returns: + A new Shape with the mapped leaves. + """ + if self.is_tuple(): + children = tuple(child.map_leaves(f) for child in self.tuple_shapes()) + return Shape(np.dtype('O'), children) + else: + mapped = f(self) + return self if mapped is None else mapped + + def _check_minor_to_major(self): + mtm = self._minor_to_major + if self.is_tuple(): + assert mtm is None, self + if mtm is not None: + assert self.rank() == len(mtm), self + assert sorted(mtm) == range(len(mtm)), self + + def update_minor_to_major(self, minor_to_major): + if not isinstance(minor_to_major, tuple): + raise TypeError('minor_to_major must be a tuple') + updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major) + updated._check_minor_to_major() # pylint: disable=protected-access + return updated + @staticmethod def from_numpy(npval): @@ -215,23 +276,10 @@ def _wrap_shape(shape_info): dtype, dims = shape_info element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] if element_type == xla_data_pb2.TUPLE: - dims = [_wrap_shape(subshape_info) for subshape_info in dims] + dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims) return Shape(dtype, dims) -def _unwrap_shape(shape): - if shape.is_tuple(): - components = tuple( - _unwrap_shape(subshape) for subshape in shape.tuple_shapes()) - else: - components = shape.dimensions() - return (shape.np_dtype, components) - - -def _unwrap_shapes(shapes): - return [_unwrap_shape(shape) for shape in shapes] - - def _wrap_data_handle(handle): cdh = xla_data_pb2.ComputationDataHandle() cdh.handle = handle @@ -253,6 +301,17 @@ def require_numpy_array_layout(value): return np.require(value, requirements=['C', 'A']) +class CompileOptions(object): + """Python object for XLA compile options. + + These options can be passed to the 'compile' step when using a local XLA + client. + """ + + def __init__(self): + self.generate_hlo_graph = None + + def transfer_to_infeed(value, replica_number=None): """Transfers the given value into the XLA infeed queue. @@ -284,8 +343,7 @@ def transfer_from_outfeed(shape, replica_number=None): Returns: The literal value that is produced from the outfeed queue. """ - return c_api.TransferFromOutfeedLocalReplica( - _unwrap_shape(shape), replica_number or 0) + return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0) class LocalComputation(object): @@ -302,26 +360,70 @@ class LocalComputation(object): # Ensure a reference to C-based destructor for use in __del__. if is_compiled: + assert isinstance(c_local_computation, c_api.CompiledLocalComputation) self._delete = c_api.DeleteCompiledLocalComputation else: + assert isinstance(c_local_computation, c_api.LocalComputation) self._delete = c_api.DeleteLocalComputation - def Compile(self, argument_shapes=()): + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): + """Compiles an un-compiled local computation. + + Local computations are the result of a "LocalComputationBuild'ing" process + -- they start in uncompiled form, and via a call to Compile() turn into a + compiled local computation. + + Raises: + ValueError: if this is already a compiled local computation. + + Arguments: + argument_shapes: parameter shapes -- they are first laid out by layout_fn + if layout_fn is provided. Otherwise, the default layout for those shapes + will be used. + compile_options: options to use for compilation, includes an optional + laid out result shape for the computation. + layout_fn: lambda that is used to lay out the argument/result shapes. + + Returns: + A newly *compiled* local computation instance. + """ if self.is_compiled: raise ValueError('Attempt to compile a compiled local XLA computation.') + + if layout_fn: + argument_shapes = [ + shape.map_leaves(layout_fn) for shape in argument_shapes + ] + result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) + result_shape = result_shape.map_leaves(layout_fn) + compile_options = compile_options or CompileOptions() + compile_options.result_shape = result_shape return LocalComputation( - self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)), + self.c_local_computation.Compile(argument_shapes, compile_options), is_compiled=True) - def CompileWithExampleArguments(self, arguments=()): + def CompileWithExampleArguments(self, + arguments=(), + compile_options=None, + layout_fn=None): return self.Compile( - argument_shapes=[Shape.from_numpy(arg) for arg in arguments]) + argument_shapes=[Shape.from_numpy(arg) for arg in arguments], + compile_options=compile_options, + layout_fn=layout_fn) - def Execute(self, arguments=()): + def Execute(self, arguments=(), layout_fn=None): + """Execute with Python values as arguments and return value.""" if not self.is_compiled: raise ValueError('Cannot execute an uncompiled local XLA computation.') + argument_shapes = [Shape.from_numpy(arg) for arg in arguments] + if layout_fn: + argument_shapes = [ + shape.map_leaves(layout_fn) for shape in argument_shapes + ] + else: + argument_shapes = [None for shape in argument_shapes] arguments = tuple(map(require_numpy_array_layout, arguments)) - return self.c_local_computation.Execute(arguments) + return self.c_local_computation.Execute(arguments, argument_shapes) def ExecuteWithLocalBuffers(self, arguments=()): """Execute with LocalBuffer arguments and return value.""" @@ -377,7 +479,7 @@ class ComputationBuilder(object): Returns: A ComputationDataHandle message. """ - return _wrap_data_handle(self._client.Infeed(_unwrap_shape(shape))) + return _wrap_data_handle(self._client.Infeed(shape)) def Outfeed(self, operand): """Enqueues an outfeed op onto the computation. @@ -386,7 +488,7 @@ class ComputationBuilder(object): outfeed queue for subsequent dequeue via the client API. """ self._client.Outfeed( - _unwrap_data_handle(operand), _unwrap_shape(self.GetShape(operand)), + _unwrap_data_handle(operand), self.GetShape(operand), ''.encode('utf-8')) def Constant(self, value): @@ -477,8 +579,7 @@ class ComputationBuilder(object): parameter_num = next(self._parameter_numbering) return _wrap_data_handle( - self._client.Parameter( - parameter_num, _unwrap_shape(shape), name.encode('utf8'))) + self._client.Parameter(parameter_num, shape, name.encode('utf8'))) def ParameterFromNumpy(self, value, name=None, parameter_num=None): """Enqueues a Parameter op onto the computation. @@ -538,6 +639,9 @@ class ComputationBuilder(object): def GetShape(self, operand): return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + def GetReturnValueShape(self): + return _wrap_shape(self._client.GetReturnValueShape()) + def GetComputationStats(self): raise NotImplementedError() @@ -599,6 +703,13 @@ class ComputationBuilder(object): return _wrap_data_handle( self._client.Rev(_unwrap_data_handle(operand), dimensions)) + def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin + """Clamp op.""" + return _wrap_data_handle( + self._client.Clamp(_unwrap_data_handle(min), + _unwrap_data_handle(operand), + _unwrap_data_handle(max))) + def SelectAndScatter(self, operand, select, window_dimensions, window_strides, padding, source, init_value, scatter): """Select and scatter op, used by the gradient of ReduceWindow. @@ -818,8 +929,7 @@ class ComputationBuilder(object): shape = Shape(self.GetShape(mu).np_dtype, dims) return _wrap_data_handle( self._client.RngNormal( - _unwrap_data_handle(mu), _unwrap_data_handle(sigma), - _unwrap_shape(shape))) + _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) def RngUniform(self, a, b, dims): """Enqueues an RngUniform operation onto the computation. @@ -839,8 +949,7 @@ class ComputationBuilder(object): shape = Shape(self.GetShape(a).np_dtype, dims) return _wrap_data_handle( self._client.RngUniform( - _unwrap_data_handle(a), _unwrap_data_handle(b), - _unwrap_shape(shape))) + _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) def While(self, cond, body, init): """Enqueues a While operation onto the computation. @@ -848,7 +957,7 @@ class ComputationBuilder(object): Args: cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T - init: an ComputationDataHandle for the initial parameter, which has type T + init: a ComputationDataHandle for the initial parameter, which has type T Returns: a ComputationDataHandle representing the While operation. """ @@ -857,11 +966,58 @@ class ComputationBuilder(object): body.c_local_computation, _unwrap_data_handle(init))) + def Conditional(self, pred, true_operand, true_computation, false_operand, + false_computation): + """Enqueues a Conditional operation onto the computation. + + Args: + predicate: a ComputationDataHandle to test, which has scalar type PRED + true_operand: a ComputationDataHandle of type T_0 + true_computation: a Computation to apply to true_operand, type T_0 -> S + false_operand: a ComputationDatahandle of type T_1 + false_computation: a Computation to apply to false_operand, type T_1 -> S + + Returns: a ComputationDataHandle representing the Conditional operation. + """ + return _wrap_data_handle( + self._client.Conditional( + _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), + true_computation.c_local_computation, + _unwrap_data_handle(false_operand), + false_computation.c_local_computation)) + def Dot(self, lhs, rhs): - """Matrix multiplication between lhs and rhs.""" + """Enqueues a dot operation onto the computation. + + Args: + lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. + rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + + Returns: a ComputationDataHandle representing the Dot operation. + """ return _wrap_data_handle( self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + def DotGeneral(self, lhs, rhs, dimension_numbers): + """Enqueues a general dot operation onto the computation. + + Args: + lhs: ComputationDataHandle for the left-hand-side array. + rhs: ComputationDataHandle for the right-hand-side array. + dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested + tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of + integers representing the dimensions to treat as contracting dimensions + and batch dimensions on each input operand. + + Returns: a ComputationDataHandle representing the DotGeneral operation. + """ + if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): + dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) + return _wrap_data_handle( + self._client.DotGeneral( + _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), + dimension_numbers)) + def Conv(self, lhs, rhs, window_strides, padding): """Enqueues a Conv operation onto the computation. @@ -972,7 +1128,7 @@ def initialize_replica_count(replica_count): Args: replica_count: number of replicas that are desired for set up during XLA - initalization. + initialization. Raises: A runtime exception if the XLA service has already been initialized. @@ -998,3 +1154,13 @@ def GetPaddingConfigFromTriples(triples): dimension.edge_padding_high = hi dimension.interior_padding = interior return padding_config + + +def GetDotDimensionsFromLists(dimension_numbers): + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + dot_dims_proto = xla_data_pb2.DotDimensionNumbers() + dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) + dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) + dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) + dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) + return dot_dims_proto diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c0413b9bbc3b7f8b63e4cf7a8f24980322cffc47..c9d09cd5d57e001fd48d2dba9f2b0ee18374231b 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -86,7 +86,8 @@ class ComputationsWithConstantsTest(LocalComputationTest): def testConstantScalarSumF32(self): c = self._NewComputation() - c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) self._ExecuteAndCompareClose(c, expected=4.25) def testConstantScalarSumF64(self): @@ -444,6 +445,30 @@ class SingleOpTest(LocalComputationTest): c.Dot(c.Constant(lhs), c.Constant(rhs)) self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + def testDotGeneral(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = (([2], [1]), ([0], [0])) + c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) + + def testDotGeneralWithDotDimensionNumbersProto(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + + dimension_numbers = xla_client.xla_data_pb2.DotDimensionNumbers() + dimension_numbers.lhs_contracting_dimensions.append(2) + dimension_numbers.rhs_contracting_dimensions.append(1) + dimension_numbers.lhs_batch_dimensions.append(0) + dimension_numbers.rhs_batch_dimensions.append(0) + + c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) + def testConvF32Same(self): c = self._NewComputation() a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") @@ -496,6 +521,12 @@ class SingleOpTest(LocalComputationTest): c.Exp(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + def testRound(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Round(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.round(arr)) + def testLog(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -699,6 +730,23 @@ class SingleOpTest(LocalComputationTest): self._ExecuteAndCompareExact( c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]) + def testClampF32(self): + c = self._NewComputation() + c.Clamp( + c.Constant(NumpyArrayF32(-1)), + c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])), + c.Constant(NumpyArrayF32(2))) + self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) + + # TODO(b/72689392): re-enable when bug S32 resolved + def DISABLED_testClampS32(self): + c = self._NewComputation() + c.Clamp( + c.Constant(NumpyArrayS32(-1)), + c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])), + c.Constant(NumpyArrayS32(2))) + self._ExecuteAndCompareExact(c, expected=[-1, 0, 1, 2, 2]) + def testSelect(self): c = self._NewComputation() c.Select( @@ -834,6 +882,13 @@ class EmbeddedComputationsTest(LocalComputationTest): c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0)) return c.Build() + def _CreateMulF32ByParamComputation(self): + """Computation (f32) -> f32 that multiplies one parameter by the other.""" + c = self._NewComputation("mul_f32_by_param") + c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + def _CreateMulF64By2Computation(self): """Computation (f64) -> f64 that multiplies its parameter by 2.""" c = self._NewComputation("mul_f64_by2") @@ -974,6 +1029,14 @@ class EmbeddedComputationsTest(LocalComputationTest): self._CreateBinaryDivF64Computation(), [0]) self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + def DISABLED_testMapWithStaticOperands(self): + c = self._NewComputation() + factor = c.ConstantF32Scalar(3.0) + c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF32ByParamComputation(), [0], + static_operands=[factor]) + self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0]) + def testSelectAndScatterF32(self): c = self._NewComputation() c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), @@ -1172,6 +1235,28 @@ class EmbeddedComputationsTest(LocalComputationTest): c.While(cond, body, init) self._ExecuteAndCompareClose(c, expected=16.) + def testConditionalTrue(self): + c = self._NewComputation() + pred = c.ConstantPredScalar(True) + true_operand = c.ConstantF32Scalar(3.) + true_computation = self._CreateMulF32By2Computation() + false_operand = c.ConstantF32Scalar(2.) + false_computation = self._CreateConstantF32Computation() + c.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=6.) + + def testConditionalFalse(self): + c = self._NewComputation() + pred = c.ConstantPredScalar(False) + true_operand = c.ConstantF32Scalar(3.) + true_computation = self._CreateMulF32By2Computation() + false_operand = c.ConstantF32Scalar(2.) + false_computation = self._CreateConstantF32Computation() + c.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=1.) + def testInfeedS32Values(self): to_infeed = NumpyArrayS32([1, 2, 3, 4]) c = self._NewComputation() diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 9a0acda94fb08ee0accfba6c5380f628c07ebaa2..4a076ac0909d24f6c7355a323d4b78151d3fe2ac 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -43,6 +43,81 @@ filegroup( ]), ) +cc_library( + name = "bfloat16_support", + srcs = ["bfloat16_support.cc"], + hdrs = ["bfloat16_support.h"], + deps = [ + ":hlo", + ], +) + +cc_library( + name = "bfloat16_conversion_folding", + srcs = ["bfloat16_conversion_folding.cc"], + hdrs = ["bfloat16_conversion_folding.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_conversion_folding_test", + srcs = ["bfloat16_conversion_folding_test.cc"], + deps = [ + ":bfloat16_conversion_folding", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "bfloat16_normalization", + srcs = ["bfloat16_normalization.cc"], + hdrs = ["bfloat16_normalization.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_normalization_test", + srcs = ["bfloat16_normalization_test.cc"], + deps = [ + ":bfloat16_normalization", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], @@ -70,7 +145,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", ], ) @@ -460,6 +536,7 @@ cc_library( ":hlo_proto_util", ":platform_util", ":session_proto", + ":source_map_util", ":transfer_manager", ":user_computation", ":versioned_computation_handle", @@ -508,6 +585,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -641,6 +719,7 @@ cc_library( hdrs = ["llvm_compiler.h"], deps = [ ":compiler", + "//tensorflow/core:lib_internal", "@llvm//:core", ], ) @@ -1109,8 +1188,6 @@ cc_library( ":hlo", ":hlo_evaluator", ":hlo_pass", - ":tuple_util", - ":while_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", ], @@ -1155,6 +1232,34 @@ tf_cc_test( ], ) +cc_library( + name = "implicit_broadcast_remover", + srcs = ["implicit_broadcast_remover.cc"], + hdrs = ["implicit_broadcast_remover.h"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "implicit_broadcast_remover_test", + srcs = ["implicit_broadcast_remover_test.cc"], + deps = [ + ":hlo_matchers", + ":implicit_broadcast_remover", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + ], +) + cc_library( name = "dot_decomposer", srcs = ["dot_decomposer.cc"], @@ -1824,7 +1929,9 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -1855,6 +1962,7 @@ cc_library( ":hlo", ":hlo_graph_dumper", ":hlo_pass", + ":hlo_proto_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -2348,6 +2456,18 @@ tf_cc_test( ], ) +cc_library( + name = "source_map_util", + srcs = ["source_map_util.cc"], + hdrs = ["source_map_util.h"], + deps = [ + ":executable", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index ba82e822b216528c28536181059bc2417048de01..fb857559f972a220a19b108baa4c441e09b90e1f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1618,9 +1618,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { reduce, HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); } + // A Transpose feeding a reduce can simply permute the reduction dimensions - // field. - if (arg->opcode() == HloOpcode::kTranspose) { + // field if the output of the reduce is a vector or scalar. Higher ranked + // result may require a transpose of the output. + if (ShapeUtil::Rank(reduce->shape()) <= 1 && + arg->opcode() == HloOpcode::kTranspose) { auto transpose_dimensions = arg->dimensions(); std::vector new_reduce_dimensions; for (auto dim : dimensions) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index e43ea50af45318adf2c95aa69b3e53a5225c5579..0f08eb3a3267c4b7b04958270a5788fc48d3fa04 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -61,13 +61,12 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -83,13 +82,12 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Constant())); } @@ -110,13 +108,12 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); } @@ -133,13 +130,12 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -156,13 +152,12 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -178,13 +173,12 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -200,13 +194,12 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kSubtract, param0, constant)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); } @@ -226,15 +219,14 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Divide(param0, param1), param2)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Multiply(param1, param2))); @@ -255,15 +247,14 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Divide(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Multiply(param0, param2), param1)); @@ -289,8 +280,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), @@ -298,7 +288,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -320,15 +310,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Exp(op::Negate(param1)))); @@ -349,15 +338,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -380,15 +368,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -411,12 +398,11 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Divide(op::Constant(), constant))); @@ -438,11 +424,10 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, inner_power, exp2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Power(base, op::Multiply(exp1, exp2))); } @@ -451,24 +436,23 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { // numbers. TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { Shape r0c64 = ShapeUtil::MakeShape(C64, {}); - Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); + Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction::CreateParameter(0, r1c64, "param0")); HloInstruction* exp1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r0c64, "param1")); HloInstruction* exp2 = builder.AddInstruction( HloInstruction::CreateParameter(2, r0c64, "param2")); HloInstruction* inner_power = builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); - builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, + HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1)); + builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, inner_power, exp2)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); } // Test that A/1 is simplified to A for a scalar. @@ -482,13 +466,12 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -504,13 +487,12 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -529,13 +511,12 @@ TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { HloInstruction* cplx = builder.AddInstruction( HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, cplx); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -554,13 +535,12 @@ TEST_F(AlgebraicSimplifierTest, RealOfComplex) { HloInstruction* real = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, real); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -579,13 +559,12 @@ TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { HloInstruction* imag = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, imag); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param1); } @@ -607,13 +586,12 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param1, param2)); } @@ -633,15 +611,14 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Exp(param0), op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Subtract(param0, param1))); @@ -662,15 +639,14 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Exp(param0), op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Add(param0, param1))); @@ -689,15 +665,14 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(op::Exp(param0), param1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Multiply(param0, param1))); @@ -716,15 +691,14 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Power(param0, param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Log(param0), param1)); @@ -741,14 +715,13 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } @@ -770,15 +743,14 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); } @@ -795,14 +767,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); @@ -820,14 +791,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast()); @@ -849,14 +819,13 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } @@ -872,14 +841,13 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); } @@ -895,14 +863,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); @@ -941,16 +908,15 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { dim->set_base_dilation(1); dim->set_window_reversal(false); // Create add computation. - std::unique_ptr module = CreateNewModule(); builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); - module->AddEntryComputation(builder.Build()); + module().AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Convolution(lhs, rhs)); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } @@ -969,7 +935,6 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { dim->set_base_dilation(1); } // Create add computation. - std::unique_ptr module = CreateNewModule(); HloComputation* add_computation = nullptr; { HloComputation::Builder builder(TestName() + ".add"); @@ -980,20 +945,20 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module->AddEmbeddedComputation(builder.Build()); + add_computation = module().AddEmbeddedComputation(builder.Build()); } builder.AddInstruction(HloInstruction::CreateReduceWindow( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), window, add_computation)); - module->AddEntryComputation(builder.Build()); + module().AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + EXPECT_THAT(module().entry_computation()->root_instruction(), op::ReduceWindow(param, op::Constant())); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } @@ -1014,14 +979,13 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), padding)); - std::unique_ptr module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + module().AddEntryComputation(builder.Build()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Pad(param, op::Constant())); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } @@ -1039,17 +1003,16 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); auto computation = builder.Build(); - auto module = CreateNewModule(); - module->AddEntryComputation(std::move(computation)); + module().AddEntryComputation(std::move(computation)); - EXPECT_THAT(module->entry_computation()->root_instruction(), + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reshape(op::Broadcast(op::Reshape(op)))); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), op); + EXPECT_THAT(module().entry_computation()->root_instruction(), op); } // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. @@ -1060,14 +1023,13 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), input); } @@ -1081,14 +1043,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } @@ -1102,14 +1063,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { builder.AddInstruction( HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } @@ -1132,8 +1092,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), @@ -1141,7 +1100,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0, param0, param1)); @@ -1163,15 +1122,14 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, empty_slice}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(empty_literal, empty_slice)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), empty_literal); } @@ -1188,14 +1146,13 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { HloInstruction* broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r1f32, param1, {})); builder.AddInstruction(HloInstruction::CreateConcatenate( - param0->shape(), {broadcast, param0}, 0)); + ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); } @@ -1209,8 +1166,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); // Set to different layouts. *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -1220,7 +1176,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); // Copy has not been removed. EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); @@ -1236,8 +1192,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); // Set to same layouts. *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -1247,7 +1202,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); // Copy has been removed. EXPECT_THAT(computation->root_instruction(), param0); @@ -1268,14 +1223,13 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { *reshape->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); // Reshape is not replaced with a bitcast. EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); @@ -1314,8 +1268,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { builder.AddInstruction(HloInstruction::CreateTuple( {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Tuple(transformable_reshape, dimensions_wrong_reshape, @@ -1323,7 +1276,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - simplifier.Run(module.get()).ValueOrDie(); + simplifier.Run(&module()).ValueOrDie(); // Verify that only the first reshape is replaced. EXPECT_THAT( @@ -1344,8 +1297,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), HloOpcode::kMaximum, movable_reshape, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Reshape(param), zero)); @@ -1353,7 +1305,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - simplifier.Run(module.get()).ValueOrDie(); + simplifier.Run(&module()).ValueOrDie(); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Maximum(param, zero))); } @@ -1371,8 +1323,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { HloInstruction::CreateConstant(Literal::CreateR1({1., 2., 3.}))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Reshape(param), zero)); @@ -1380,7 +1331,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - simplifier.Run(module.get()).ValueOrDie(); + simplifier.Run(&module()).ValueOrDie(); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Reshape(param), zero)); @@ -1405,9 +1356,8 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + module().AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); } // Regression test for a bug where if we failed to sink a reshape, we'd set the @@ -1424,14 +1374,14 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2({{0, 0}, {0, 0}}))))); - builder.AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0})); + builder.AddInstruction( + HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, + /*broadcast_dimensions=*/{0, 1})); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + module().AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { @@ -1448,14 +1398,13 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); // Verify that the reshape is replaced. EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); @@ -1475,14 +1424,13 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({3, 1, 2, 0}); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); // Verify that the reshape is replaced. EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); @@ -1501,15 +1449,14 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Reshape(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } @@ -1529,14 +1476,13 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}), HloOpcode::kCopy, copy1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); } @@ -1554,14 +1500,13 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); EXPECT_EQ(std::vector({2, 1, 0}), @@ -1576,17 +1521,16 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 5, 1}), param0)); builder.AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3})); + ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(op::Reshape(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } @@ -1601,15 +1545,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } @@ -1623,15 +1566,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); @@ -1646,15 +1588,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); EXPECT_THAT(computation->root_instruction()->dimensions(), @@ -1670,15 +1611,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); const std::vector broadcast_dims = @@ -1696,15 +1636,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); @@ -2410,12 +2349,11 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); - auto module = CreateNewModule(); - module->AddEmbeddedComputation(std::move(dot_computation)); - module->AddEntryComputation(call_builder.Build()); + module().AddEmbeddedComputation(std::move(dot_computation)); + module().AddEntryComputation(call_builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); } // Test that a constant with tuple shape becomes a tuple of constants. @@ -2428,12 +2366,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { Literal::CreateR1(constant_vector).get()}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } @@ -2453,11 +2390,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))), /*slice_sizes=*/{10, 100, 1000})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Parameter()); } @@ -2487,11 +2423,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))))); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::DynamicSlice(op::Parameter(), op::Parameter())); } @@ -2554,15 +2489,16 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { PaddingConfig padding = window_util::MakeSymmetricPadding( decorate_spatials(param.symmetric_pad_spatials, 0, 0)); + TF_ASSERT_OK_AND_ASSIGN( + const Shape pad_shape, + ShapeInference::InferPadShape(input->shape(), + ShapeUtil::MakeShape(F32, {}), padding)); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( - ShapeUtil::MakeShape( - F32, decorate_spatials(param.reduce_window_spatials, 128, 2048)), - input, + pad_shape, input, builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), padding)); - std::unique_ptr module = CreateNewModule(); HloComputation* add_computation = nullptr; { HloComputation::Builder builder(TestName() + ".add"); @@ -2573,24 +2509,24 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module->AddEmbeddedComputation(builder.Build()); + add_computation = module().AddEmbeddedComputation(builder.Build()); } - TF_ASSERT_OK_AND_ASSIGN( - const Shape output_shape, - ShapeInference::InferPadShape(input_shape, ShapeUtil::MakeShape(F32, {}), - padding)); Window window = window_util::MakeWindow( decorate_spatials(param.reduce_window_spatials, 1, 1)); auto zero = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, + ShapeInference::InferReduceWindowShape( + pad->shape(), zero->shape(), window, + add_computation->ComputeProgramShape())); builder.AddInstruction(HloInstruction::CreateReduceWindow( output_shape, pad, zero, window, add_computation)); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -2667,11 +2603,10 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { dot_dnums.add_rhs_contracting_dimensions(0); builder.AddInstruction( HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module())); const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; const bool computation_should_be_modified = dot_should_be_transformed || (transpose_lhs && transpose_rhs); @@ -2699,7 +2634,7 @@ struct DotOfConcatTestSpec { }; class DotOfConcatSimplificationTest - : public HloTestBase, + : public HloVerifiedTestBase, public ::testing::WithParamInterface {}; // Test that we transform @@ -2745,11 +2680,10 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { builder.AddInstruction( HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -2790,17 +2724,17 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { HloInstruction* lhs2 = builder.AddInstruction( HloInstruction::CreateParameter(2, lhs2_shape, "lhs2")); HloInstruction* lhs3 = builder.AddInstruction( - HloInstruction::CreateParameter(3, lhs2_shape, "lhs3")); + HloInstruction::CreateParameter(3, lhs3_shape, "lhs3")); Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConcatenate( lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1)); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.m}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); auto* rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( - /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.m))); + /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n))); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -2810,11 +2744,10 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { builder.AddInstruction( HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc new file mode 100644 index 0000000000000000000000000000000000000000..cde990e176ddb57a8e93ecc3c60260b2dbae32a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -0,0 +1,184 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { + public: + explicit BFloat16ConversionFoldingVisitor( + HloComputation* computation, const BFloat16Support* bfloat16_support) + : computation_(computation), bfloat16_support_(bfloat16_support) {} + + Status DefaultAction(HloInstruction* hlo) override; + + static bool Run(HloComputation* computation, + const BFloat16Support* bfloat16_support) { + BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; + } + + private: + // Checks if the HLO has a BF16 -> F32 conversion as input, or a F32 -> BF16 + // conversion as output, and folds them to the HLO itself if feasible. + Status TryFoldBF16Conversions(HloInstruction* hlo); + + // Folds the F32 -> BF16 conversions from the HLO's output. + // + // Precondition: all of the HLO's users are F32 -> BF16 conversions. + Status FoldOutputConversions(HloInstruction* hlo); + + // Folds the BF16 -> F32 conversion operand to the HLO. + // + // Precondition: the operand is a F32 -> BF16 conversion. + Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index); + + HloComputation* computation_; + const BFloat16Support* bfloat16_support_; + bool changed_ = false; +}; + +Status BFloat16ConversionFoldingVisitor::FoldOutputConversions( + HloInstruction* hlo) { + std::vector materialized_users = hlo->users(); + hlo->mutable_shape()->set_element_type(BF16); + for (auto user : materialized_users) { + CHECK_EQ(user->opcode(), HloOpcode::kConvert); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); + changed_ = true; + } + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::FoldOperandConversion( + HloInstruction* hlo, int64 operand_index) { + // The operand is a convert from BF16 to F32. + auto operand = hlo->mutable_operand(operand_index); + CHECK_EQ(operand->opcode(), HloOpcode::kConvert); + TF_RETURN_IF_ERROR( + hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0))); + changed_ = true; + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( + HloInstruction* hlo) { + std::vector bf16_to_f32_operands; + bool has_other_f32_operands = false; + for (int64 i = 0; i < hlo->operands().size(); ++i) { + auto operand = hlo->operand(i); + if (operand->shape().element_type() == F32) { + if (operand->opcode() == HloOpcode::kConvert && + operand->operand(0)->shape().element_type() == BF16 && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + // Operand is a convert from BF16 to F32 and we support BF16 input + // directly in the current HLO at the operand index. + bf16_to_f32_operands.push_back(i); + } else { + has_other_f32_operands = true; + } + continue; + } + } + + bool fold_output_conversion = hlo->user_count() > 0 && + hlo->shape().element_type() == F32 && + bfloat16_support_->SupportsBF16Output(*hlo) && + hlo != computation_->root_instruction(); + if (fold_output_conversion) { + for (auto user : hlo->users()) { + if (user->opcode() == HloOpcode::kConvert && + user->shape().element_type() == BF16) { + continue; + } + // We should not change the output type if any user is not a conversion + // from F32 to BF16. + fold_output_conversion = false; + break; + } + } + + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) { + if (has_other_f32_operands || + (!fold_output_conversion && hlo->shape().element_type() == F32)) { + // Some of the operands/output will remain F32, but we cannot use mixed + // precisions, so we cannot do anything here. + return Status::OK(); + } + } + + if (fold_output_conversion) { + TF_RETURN_IF_ERROR(FoldOutputConversions(hlo)); + } + + for (int64 i : bf16_to_f32_operands) { + TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i)); + } + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { + // Do not fold BF16 conversions for instructions related to tuples, entry and + // exit of a computation, fusion, convert, and control flow. + if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional) { + return Status::OK(); + } + if (hlo == computation_->root_instruction() && + !bfloat16_support_->SupportsMixedPrecisions(*hlo)) { + // If hlo is the root instruction, we cannot change its output, so folding + // can only happen when it supports mixed precision so that we can change + // its operands. + return Status::OK(); + } + return TryFoldBF16Conversions(hlo); +} + +StatusOr BFloat16ConversionFolding::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) { + changed = true; + } + } + XLA_VLOG_LINES( + 2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h new file mode 100644 index 0000000000000000000000000000000000000000..c9398387098fad84ba28735c30e426fedd9b0cb0 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which folds F32 <-> BF16 conversions to their operands or users, when +// it is supported by the backend. +// +// This pass follows the passed-in backend-specific BF16 support rules, but can +// introduce mixed precision in individual HLOs which breaks the assumption of +// some other HLO passes. So it should be used at the end of the HLO +// optimization pipeline followed by a DCE pass. If other passes are needed +// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the +// changed made by this pass. +class BFloat16ConversionFolding : public HloPassInterface { + public: + explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + + ~BFloat16ConversionFolding() override = default; + tensorflow::StringPiece name() const override { return "bfloat16-fold"; } + + // Run BF16 conversion folding on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + private: + const BFloat16Support* bfloat16_support_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb37759439debf41a305ec7dccaa548e1bf234cd --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -0,0 +1,209 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } +}; + +class BFloat16ConversionFoldingTest : public HloTestBase { + protected: + bool FoldConversions(HloModule* module) { + TestBFloat16Support bfloat16_support_; + BFloat16ConversionFolding fold(&bfloat16_support_); + StatusOr result = fold.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } +}; + +TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c)); + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), add1); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), BF16); + EXPECT_EQ(add1->operand(0), add0); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* mul0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_shape, HloOpcode::kMultiply, convert1, c)); + HloInstruction* convert2 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert2); + EXPECT_EQ(mul0->shape().element_type(), F32); + EXPECT_EQ(mul1->shape().element_type(), F32); + EXPECT_EQ(mul1->operand(0), convert1); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* sub0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_shape, HloOpcode::kSubtract, convert1, c)); + HloInstruction* convert2 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert2); + EXPECT_EQ(sub0->shape().element_type(), F32); + EXPECT_EQ(sub1->shape().element_type(), F32); + EXPECT_EQ(sub1->operand(0), convert1); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b)); + + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({a, convert0})); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0)); + HloInstruction* convert1 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert1); + EXPECT_EQ(gte->shape().element_type(), F32); + EXPECT_EQ(tuple->operand(1), convert0); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc new file mode 100644 index 0000000000000000000000000000000000000000..b032c040e8aff49f9e0fc1ff9a1c1e79ea4bb77f --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -0,0 +1,351 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { + public: + explicit BFloat16NormalizationVisitor(HloComputation* computation, + const BFloat16Support* bfloat16_support) + : computation_(computation), bfloat16_support_(bfloat16_support) {} + + Status DefaultAction(HloInstruction* hlo) override; + + // Special handling for cross-replica-sum which can have a tuple output. + Status HandleCrossReplicaSum(HloInstruction* crs) override; + + static bool Run(HloComputation* computation, + const BFloat16Support* bfloat16_support) { + BFloat16NormalizationVisitor visitor(computation, bfloat16_support); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; + } + + private: + // Checks if the HLO uses BF16 in an unsupported way, and if so, inserts + // conversions between F32 and BF16 to make it supported. + Status HandleInstruction(HloInstruction* hlo); + + // Inserts a conversion HLO that changes the given HLO's output type. + Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to, + HloComputation* computation); + + // Changes the output type to the specified type, then inserts a conversion + // to the original type. + Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo, + PrimitiveType to, + HloComputation* computation); + + // Inserts a conversion HLO that changes the given HLO's operand type. + Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx, + PrimitiveType to, + HloComputation* computation); + + // Inserts conversion HLOs to replace the called computations' BF16 + // operands/outputs to F32. + Status ConvertCalledComputations( + HloInstruction* hlo, + tensorflow::gtl::ArraySlice bf16_called_comps); + + HloComputation* computation_; + const BFloat16Support* bfloat16_support_; + bool changed_ = false; +}; + +Status BFloat16NormalizationVisitor::InsertConvertAfterOutput( + HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { + bool is_root = computation->root_instruction() == hlo; + std::vector materialized_users = hlo->users(); + // Use inst's shape temporarily, in order to pass checks in ReplaceUseWith. + auto convert = computation->AddInstruction( + HloInstruction::CreateConvert(hlo->shape(), hlo)); + for (auto* user : materialized_users) { + TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert)); + } + if (is_root) { + computation->set_root_instruction(convert); + } + convert->mutable_shape()->set_element_type(to); + changed_ = true; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( + HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { + auto original_type = hlo->shape().element_type(); + hlo->mutable_shape()->set_element_type(to); + return InsertConvertAfterOutput(hlo, original_type, computation); +} + +Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( + HloInstruction* hlo, int64 operand_idx, PrimitiveType to, + HloComputation* computation) { + auto operand = hlo->mutable_operand(operand_idx); + auto convert = computation->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(operand->shape(), to), operand)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert)); + changed_ = true; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::ConvertCalledComputations( + HloInstruction* hlo, + tensorflow::gtl::ArraySlice bf16_called_comps) { + std::map cloned_computations; + for (auto& comp : bf16_called_comps) { + auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone()); + cloned_computations[comp] = cloned; + changed_ = true; + } + hlo->ReplaceCalledComputations([&](HloComputation* comp) { + auto it = cloned_computations.find(comp); + if (it != cloned_computations.end()) { + return it->second; + } + return comp; + }); + for (auto& comp_pair : cloned_computations) { + auto comp = comp_pair.second; + if (comp->root_instruction()->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + InsertConvertAfterOutput(comp->root_instruction(), F32, comp)); + } + for (auto* param : comp->parameter_instructions()) { + if (param->shape().element_type() == BF16) { + // This changes the parameter to F32 then inserts a convert after it. + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(param, F32, comp)); + } + } + } + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( + HloInstruction* crs) { + if (!ShapeUtil::IsTuple(crs->shape())) { + return HandleInstruction(crs); + } + + std::vector operand_types(crs->operand_count()); + std::vector output_types(crs->operand_count()); + bool has_f32 = false; + bool has_bf16 = false; + bool has_bf16_output = false; + for (int64 i = 0; i < crs->operand_count(); ++i) { + operand_types[i] = crs->operand(i)->shape().element_type(); + output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); + if (operand_types[i] == F32 || output_types[i] == F32) { + has_f32 = true; + } else if (operand_types[i] == BF16) { + has_bf16 = true; + } + if (output_types[i] == BF16) { + has_bf16 = true; + has_bf16_output = true; + } + } + + for (int64 i = 0; i < crs->operand_count(); ++i) { + if (operand_types[i] != BF16) { + continue; + } + if (bfloat16_support_->SupportsBF16Operand(*crs, i) && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + continue; + } + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + has_f32 = true; + } + + if (!has_bf16_output) { + return Status::OK(); + } + + if (bfloat16_support_->SupportsBF16Output(*crs) && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + return Status::OK(); + } + + std::vector output_elements(crs->operand_count()); + auto original_shape = crs->shape(); + for (int64 i = 0; i < crs->operand_count(); ++i) { + auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}); + if (output_types[i] != BF16) { + output_elements[i] = computation_->AddInstruction( + HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + continue; + } + subshape->set_element_type(F32); + auto gte = computation_->AddInstruction( + HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + output_elements[i] = + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(*subshape, BF16), gte)); + } + auto tuple = computation_->AddInstruction( + HloInstruction::CreateTuple(output_elements)); + + std::vector materialized_users = crs->users(); + // Use the crs' shape temporarily, in order to pass checks in + // ReplaceUseWith. + *tuple->mutable_shape() = crs->shape(); + for (auto* user : materialized_users) { + TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple)); + } + *tuple->mutable_shape() = original_shape; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { + std::vector bf16_operands; + std::vector f32_operands; + bool has_f32 = false; + bool has_bf16 = false; + + for (int64 i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == F32) { + f32_operands.push_back(i); + has_f32 = true; + } else if (hlo->operand(i)->shape().element_type() == BF16) { + bf16_operands.push_back(i); + has_bf16 = true; + } + } + + if (hlo->shape().element_type() == F32) { + has_f32 = true; + } else if (hlo->shape().element_type() == BF16) { + has_bf16 = true; + } + + std::vector bf16_called_comps; + for (auto* comp : hlo->called_computations()) { + bool comp_has_bf16 = false; + if (comp->root_instruction()->shape().element_type() == F32) { + has_f32 = true; + } else if (comp->root_instruction()->shape().element_type() == BF16) { + has_bf16 = true; + comp_has_bf16 = true; + } + for (auto* param : comp->parameter_instructions()) { + if (param->shape().element_type() == F32) { + has_f32 = true; + } else if (param->shape().element_type() == BF16) { + has_bf16 = true; + comp_has_bf16 = true; + } + } + if (comp_has_bf16) { + bf16_called_comps.push_back(comp); + } + } + + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 && + has_f32) { + // Resolve unsupported mixed precision. + // + // See if we can change everything to BF16. + if (hlo->called_computations().empty() && + hlo->shape().element_type() == BF16) { + bool can_use_bf16 = true; + for (int i : f32_operands) { + if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, + i) && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + continue; + } + can_use_bf16 = false; + break; + } + if (can_use_bf16) { + for (int i : f32_operands) { + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, BF16, computation_)); + } + return Status::OK(); + } + } + if (hlo->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + for (int i : bf16_operands) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + return ConvertCalledComputations(hlo, bf16_called_comps); + } + + for (int i : bf16_operands) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + } + + if (hlo->shape().element_type() == BF16 && + !bfloat16_support_->SupportsBF16Output(*hlo)) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { + // Do not change instructions related to entry and exit of a computation, + // tuples, fusion, convert, and control flow. + if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional) { + return Status::OK(); + } + return HandleInstruction(hlo); +} + +StatusOr BFloat16Normalization::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "BFloat16Normalization::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeComputationPostOrder()) { + if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) { + changed = true; + } + } + XLA_VLOG_LINES(2, + "BFloat16Normalization::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h new file mode 100644 index 0000000000000000000000000000000000000000..2a60fe0af3218484acb95e6c69815d551350764c --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not +// support BF16 input/output or mixed precision, according to the passed-in +// backend-specific BF16 support rules. +class BFloat16Normalization : public HloPassInterface { + public: + explicit BFloat16Normalization(const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + + ~BFloat16Normalization() override = default; + tensorflow::StringPiece name() const override { return "bf16-normalization"; } + + // Run BF16 normalization on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + private: + const BFloat16Support* bfloat16_support_; +}; + +// A pass that unconditionally removes the mixed F32/BF16 uses in HLO +// instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike +// BFloat16Normalization, this pass does not use a backend-specific +// BFloat16Support, and does not change HLOs that have BF16 data if they do not +// use mixed precision; it removes mixed precision even if the backend supports +// it. This pass is used to make the HLO module valid for other HLO passes which +// do not support mixed precision. +class BFloat16MixedPrecisionRemoval : public HloPassInterface { + public: + BFloat16MixedPrecisionRemoval() {} + + ~BFloat16MixedPrecisionRemoval() override = default; + + tensorflow::StringPiece name() const override { + return "bf16-mixed-precision-removal"; + } + + // Run mixed precision removal on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override { + BFloat16Normalization normalization(&no_mixed_precision_support_); + return normalization.Run(module); + } + + private: + class BFloat16SupportForMixedPrecisionRemoval : public BFloat16Support { + public: + BFloat16SupportForMixedPrecisionRemoval() {} + + ~BFloat16SupportForMixedPrecisionRemoval() override = default; + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + return true; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + return true; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + return false; + } + } no_mixed_precision_support_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..66c3085842c4afe7ffc4d5891883e4cce9389d45 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -0,0 +1,248 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kReduce || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } +}; + +class BFloat16NormalizationTest : public HloTestBase { + protected: + bool Normalize(HloModule* module) { + TestBFloat16Support bfloat16_support_; + BFloat16Normalization normalization(&bfloat16_support_); + StatusOr result = normalization.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } +}; + +TEST_F(BFloat16NormalizationTest, NoopIfSupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b)); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), add1); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), F32); +} + +TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* mul0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b)); + + HloInstruction* mul1 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(computation->root_instruction()->operand(0), mul1); + EXPECT_EQ(mul0->shape().element_type(), F32); + EXPECT_EQ(mul1->shape().element_type(), F32); + EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert); +} + +TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* sub0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b)); + + HloInstruction* sub1 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(computation->root_instruction()->operand(0), sub1); + EXPECT_EQ(sub0->shape().element_type(), F32); + EXPECT_EQ(sub1->shape().element_type(), F32); + EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert); +} + +TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { + Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4}); + + Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + auto reduce_comp_builder = HloComputation::Builder("reduce_comp"); + auto reduce_comp_param0 = reduce_comp_builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0")); + auto reduce_comp_param1 = reduce_comp_builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1")); + reduce_comp_builder.AddInstruction( + HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd, + reduce_comp_param0, reduce_comp_param1)); + + auto module = CreateNewModule(); + auto reduce_computation = + module->AddEmbeddedComputation(reduce_comp_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_input_shape, "a")); + HloInstruction* init = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_scalar_shape, "init")); + HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce( + f32_output_shape, input, init, {0}, reduce_computation)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), reduce); + EXPECT_EQ(reduce->called_computations().size(), 1); + EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2); + EXPECT_EQ(reduce->called_computations()[0] + ->parameter_instruction(0) + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->called_computations()[0] + ->parameter_instruction(1) + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->called_computations()[0] + ->root_instruction() + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->shape().element_type(), F32); + EXPECT_EQ(reduce->operand(0), input); + EXPECT_EQ(input->shape().element_type(), F32); + EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert); + EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32); +} + +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + + HloInstruction* crs = + builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b})); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), gte); + EXPECT_EQ(gte->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(1)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc new file mode 100644 index 0000000000000000000000000000000000000000..3fd9e24601f27633c8063e4574c7c4f91f30dcff --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { + +bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + case HloOpcode::kConvert: + CHECK_EQ(operand_index, 0); + return hlo.operand(0)->shape().element_type() == BF16; + default: + break; + } + return false; +} + +bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + case HloOpcode::kConvert: + return hlo.shape().element_type() == BF16; + default: + break; + } + return false; +} + +bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + default: + break; + } + return false; +} + +/* static */ +bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( + const HloInstruction& hlo, int64 operand_index) { + switch (hlo.opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kBroadcast: + case HloOpcode::kClamp: + case HloOpcode::kConcatenate: + case HloOpcode::kCopy: + case HloOpcode::kGetTupleElement: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return true; + case HloOpcode::kDynamicSlice: + return operand_index == 0; + case HloOpcode::kDynamicUpdateSlice: + return operand_index == 0 || operand_index == 1; + case HloOpcode::kSelect: + return operand_index == 1 || operand_index == 2; + default: + break; + } + return false; +} + +bool BFloat16Support::EffectiveOperandPrecisionIsBF16( + const HloInstruction& hlo, int64 operand_index) const { + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.h b/tensorflow/compiler/xla/service/bfloat16_support.h new file mode 100644 index 0000000000000000000000000000000000000000..29f662d22b4e5486662a1387407d41e0fd2ed1b3 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_support.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { + +class BFloat16Support { + public: + BFloat16Support() {} + virtual ~BFloat16Support() {} + + // Returns whether the backend supports BF16 operand for the HLO instruction + // at the given index. + virtual bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const; + + // Returns whether the backend supports BF16 output for the HLO instruction. + virtual bool SupportsBF16Output(const HloInstruction& hlo) const; + + // Returns whether the backend support mixed precision: the operands, output, + // and parameters/output of the called computations can have different + // precisions (BF16 and F32). + virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const; + + // Returns whether the given HLO inherits its BF16 operand precision at the + // given index, so even if the output is F32, elements in the output that + // depend on the BF16 operand will still have BF16 effective precision even if + // they have F32 format. Similarly, this also means if the output is BF16 then + // increasing the operand precision from BF16 to F32 will not change the + // output. This typically includes HLOs that pass elements from the operand to + // the output without arithmetic operations. + static bool EffectiveOperandPrecisionIsOutputPrecision( + const HloInstruction& hlo, int64 operand_index); + + // Returns if the backend only uses BF16 precision for the operand at the + // specified index, even if the operand is F32. + virtual bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo, + int64 operand_index) const; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 323620c13186ed5f3c8613adb7e736f33674c270..b1e693da9d5af4babe619b8796007f2da318f6a8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -45,6 +45,8 @@ using ::tensorflow::gtl::FlatMap; using ::tensorflow::gtl::FlatSet; using ::tensorflow::strings::Appendf; using ::tensorflow::strings::HumanReadableNumBytes; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { uint64 h = std::hash()(s.index()); @@ -93,6 +95,9 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto.set_color(color_.value()); if (is_entry_computation_parameter_) { proto.set_is_entry_computation_parameter(true); + for (int64 idx : param_shape_index()) { + proto.add_parameter_shape_index(idx); + } proto.set_parameter_number(parameter_number_); } proto.set_maybe_live_out(maybe_live_out_); @@ -112,25 +117,24 @@ BufferAllocationProto BufferAllocation::ToProto() const { string BufferAllocation::ToString() const { string output; - tensorflow::strings::StrAppend( - &output, tensorflow::strings::Printf("allocation %lld: %p, size %lld", - index_, this, size())); + Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); if (color().value() != 0) { - tensorflow::strings::StrAppend(&output, ", color ", color().value()); + StrAppend(&output, ", color ", color().value()); } if (is_entry_computation_parameter()) { - tensorflow::strings::StrAppend(&output, ", parameter ", parameter_number()); + StrAppend(&output, ", parameter ", parameter_number(), " at ShapeIndex ", + param_shape_index().ToString()); } if (is_thread_local()) { - tensorflow::strings::StrAppend(&output, ", thread-local"); + StrAppend(&output, ", thread-local"); } if (maybe_live_out()) { - tensorflow::strings::StrAppend(&output, ", maybe-live-out"); + StrAppend(&output, ", maybe-live-out"); } if (IsPreallocatedTempBuffer()) { - tensorflow::strings::StrAppend(&output, ", preallocated-temp"); + StrAppend(&output, ", preallocated-temp"); } - tensorflow::strings::StrAppend(&output, ":\n"); + StrAppend(&output, ":\n"); // Dump the assigned buffers ordered by id. std::vector sorted_buffers; for (const auto& buffer_offset_size : assigned_buffers_) { @@ -142,12 +146,11 @@ string BufferAllocation::ToString() const { }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); - tensorflow::strings::StrAppend( - &output, - tensorflow::strings::Printf( - " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), - offset_size.offset, offset_size.size, - ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); + StrAppend(&output, + tensorflow::strings::Printf( + " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), + offset_size.offset, offset_size.size, + ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); } return output; } @@ -840,7 +843,7 @@ Status BufferAssigner::AssignBuffersForComputation( /*is_thread_local=*/false, /*is_reusable=*/false); allocation->set_entry_computation_parameter( - instruction->parameter_number()); + instruction->parameter_number(), buffer->index()); VLOG(3) << "New allocation #" << allocation->index() << " for entry computation parameter: " << *buffer; continue; @@ -997,14 +1000,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( auto color = single_colored_set.first; VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); + HeapSimulator::Options options; + options.buffers_to_assign = &single_colored_set.second; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( MakeUnique(alignment)), assignment->module(), module_sequence, assignment->points_to_analysis(), - assignment->buffer_size_, - &single_colored_set.second)); + assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1024,14 +1028,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( auto color = single_colored_set.first; VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); + HeapSimulator::Options options; + options.buffers_to_assign = &single_colored_set.second; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( MakeUnique(alignment)), *computation, *instruction_sequence, assignment->points_to_analysis(), - assignment->buffer_size_, - &single_colored_set.second)); + assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1120,140 +1125,6 @@ void BufferAssigner::AddSetToColocatedBufferSets( } } -// Conceptually the same as AddSetToColocatedBufferSets, but specific to the -// colocated buffers for while instructions. 'colocated_set' contains the -// buffers for a single while instruction that must be colocated. The idea here -// is to apply a memory-saving heuristic for separate while instructions whose -// buffers are disjoint in liveness, by using the colocation mechanism to force -// buffer sharing. This often reduces memory for multi-layer RNNs. -// -// TODO(b/32491382): We should be able to remove this heuristic after we -// implement module-level liveness analysis, which would let us directly detect -// buffer sharing opportunities between the while instruction buffer and the -// buffers from the predicate and body computation, as well as sharing across -// different while instructions. -void BufferAssigner::AddWhileSetToColocatedBufferSets( - const std::vector& colocated_set, - const LogicalBuffer* while_init_buffer, - const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo, - const HloComputation& computation, const BufferLiveness& buffer_liveness, - const LogicalBuffer::SizeFunction& buffer_size, - std::vector* colocated_buffer_sets) { - CHECK(!colocated_set.empty()); - const TuplePointsToAnalysis& points_to_analysis = - buffer_liveness.points_to_analysis(); - - // Parallel while loops cannot safely share colocated buffer sets. - if (buffer_liveness.hlo_ordering().SequentialOrder(computation) == nullptr) { - AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); - return; - } - - // Scan 'colocated_buffer_sets' in reverse order for locality; colocated sets - // are added in postorder over computations and instructions. - const int64 init_buffer_size = buffer_size(*while_init_buffer); - const bool is_live_out = buffer_liveness.MaybeLiveOut(*while_result_buffer); - for (int i = colocated_buffer_sets->size() - 1; i >= 0; --i) { - const ColocatedBufferSet& predecessor_set = (*colocated_buffer_sets)[i]; - - // Skip predecessor sets not associated with while loops. - if (std::all_of(predecessor_set.begin(), predecessor_set.end(), - [](const LogicalBuffer* buffer) { - return buffer->instruction()->opcode() != - HloOpcode::kWhile; - })) { - continue; - } - - // Skip predecessor sets already associated with 'while_hlo'. - if (std::any_of(predecessor_set.begin(), predecessor_set.end(), - [&while_hlo](const LogicalBuffer* buffer) { - return buffer->instruction() == while_hlo; - })) { - continue; - } - - // Skip predecessor sets with entry parameter if the while result is live - // out. - if (is_live_out && - std::any_of(predecessor_set.begin(), predecessor_set.end(), - [](const LogicalBuffer* buffer) { - auto* instruction = buffer->instruction(); - auto* computation = instruction->parent(); - auto* module = computation->parent(); - return instruction->opcode() == HloOpcode::kParameter && - computation == module->entry_computation(); - })) { - continue; - } - - // Build vector of predecessor while result and init buffers, which are - // checked for liveness interference below. We must check both the result - // and init buffers because they're aliased together, but - // TuplePointsToAnalysis is unaware of this aliasing. - std::vector predecessor_while_buffers; - for (const LogicalBuffer* buffer : predecessor_set) { - const HloInstruction* instruction = buffer->instruction(); - if (instruction->opcode() == HloOpcode::kWhile && - buffer_size(*buffer) == init_buffer_size && - instruction->parent() == &computation) { - predecessor_while_buffers.push_back(buffer); - // Add the init buffer at the same index, which must also exist in the - // predecessor set, and must be unambiguous. - const PointsToSet& init_points_to = - points_to_analysis.GetPointsToSet(instruction->operand(0)); - const auto& init_buffers = init_points_to.element(buffer->index()); - CHECK_EQ(init_buffers.size(), 1); - CHECK_GT(predecessor_set.count(init_buffers[0]), 0); - predecessor_while_buffers.push_back(init_buffers[0]); - } - } - if (predecessor_while_buffers.empty()) { - continue; - } - - // Skip predecessor set if the live range of any predecessor - // buffers overlaps with 'while_init_buffer' or - // 'while_result_buffer' (we need to check both since they're - // aliased together, but the points-to analysis is unaware of this - // aliasing). Note that tuple element buffer forwarding can cause - // the same buffer to appear on both sides of the interference - // comparison below. - auto may_interfere_with_init_or_result = [&](const LogicalBuffer* buffer) { - if (while_init_buffer->id() != buffer->id() && - buffer_liveness.MayInterfere(*while_init_buffer, *buffer)) { - return true; - } - - if (while_result_buffer->id() != buffer->id() && - buffer_liveness.MayInterfere(*while_result_buffer, *buffer)) { - return true; - } - - return false; - }; - - if (std::any_of(predecessor_while_buffers.begin(), - predecessor_while_buffers.end(), - may_interfere_with_init_or_result)) { - continue; - } - - // All our checks have passed; merge 'predecessor_set' with 'colocated_set', - // and add the merged set to 'colocated_buffer_sets'. This forces the - // colocation of buffers across different while instructions. - FlatSet unique; - unique.insert(predecessor_set.begin(), predecessor_set.end()); - unique.insert(colocated_set.begin(), colocated_set.end()); - std::vector merged_set(unique.begin(), unique.end()); - AddSetToColocatedBufferSets(merged_set, colocated_buffer_sets); - return; - } - - // Failed to merge into predecessor set; add 'colocated_set' as-is. - AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); -} - namespace { // Checks that points-to set of 'instruction' is unambiguous and distinct @@ -1270,8 +1141,130 @@ const LogicalBuffer* AddBufferToColocatedSet( return colocated_set->back(); } +// Given the interference map of a graph (the list of interfering node indices +// for each node), perform graph coloring such that interfering nodes are +// assigned to different colors. Returns the assigned color of the nodes, where +// the colors are represented as integer values [0, color_count). +std::vector ColorInterferenceGraph( + const std::vector>& interference_map) { + const int64 node_count = interference_map.size(); + + // Sort the nodes such that we assign nodes with more interference first. This + // relies on the common heuristic of assigning the most constrained node + // first, but it would be good to investigate other ordering heuristics too. + std::vector nodes(node_count); + std::iota(nodes.begin(), nodes.end(), 0); + std::sort(nodes.begin(), nodes.end(), + [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); + + const int64 kColorUnassigned = -1; + std::vector assigned_colors(node_count, kColorUnassigned); + for (int64 node : nodes) { + // Mark the colors that are already assigned to the neighbors. + std::vector available_colors(node_count, true); + for (int64 neighbor : interference_map[node]) { + int64 color = assigned_colors[neighbor]; + if (color != kColorUnassigned) { + available_colors[color] = false; + } + } + + // Find the color that is not yet assigned to the neighbors. + int64 color = kColorUnassigned; + for (color = 0; color < available_colors.size(); ++color) { + if (available_colors[color]) { + break; + } + } + CHECK_NE(color, kColorUnassigned); + assigned_colors[node] = color; + } + return assigned_colors; +} + } // namespace +std::vector +BufferAssigner::MergeColocatedBufferSets( + const std::vector& colocated_buffer_sets, + const BufferLiveness& buffer_liveness, + const LogicalBuffer::SizeFunction& buffer_size) { + VLOG(1) << "colocation sets count before coalescing:" + << colocated_buffer_sets.size(); + + // Returns true if the given buffer is for the entry parameter. + auto is_entry_parameter = [](const LogicalBuffer& buffer) { + auto* instruction = buffer.instruction(); + auto* computation = instruction->parent(); + auto* module = computation->parent(); + return instruction->opcode() == HloOpcode::kParameter && + computation == module->entry_computation(); + }; + + // Returns true if the two colocated buffer sets (specified by their indices + // into the colocated_buffer_sets) can be merged into a single set. + auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, + &buffer_size, + &is_entry_parameter](int64 i, int64 j) { + for (auto& buffer_a : colocated_buffer_sets[i]) { + for (auto& buffer_b : colocated_buffer_sets[j]) { + // Do not merge if the set includes live outs or entry parameters. + if ((buffer_liveness.MaybeLiveOut(*buffer_a) && + is_entry_parameter(*buffer_b)) || + (buffer_liveness.MaybeLiveOut(*buffer_b) && + is_entry_parameter(*buffer_a))) { + return true; + } + // Do not merge if the buffers interfere with each other. + if (buffer_a->id() != buffer_b->id() && + buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) { + return true; + } + // Do not merge if the buffer sizes are different. + if (buffer_size(*buffer_a) != buffer_size(*buffer_b)) { + return true; + } + } + } + return false; + }; + + // Build the interference map among the colocated buffer sets (nodes), by + // adding an edge between any two nodes that cannot be merged into a single + // colocated buffer set. + std::vector> interference_map( + colocated_buffer_sets.size()); + for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { + for (int64 j = i + 1; j < colocated_buffer_sets.size(); ++j) { + if (cannot_merge_buffer_sets(i, j)) { + interference_map[i].push_back(j); + interference_map[j].push_back(i); + } + } + } + + // Assign a color to each colocation set in colocated_buffer_sets, such that + // the sets that can be merged are assigned with the same color. + auto assigned_colors = ColorInterferenceGraph(interference_map); + + // Merge the buffer sets with the same color. + CHECK(!assigned_colors.empty()); + int64 num_sets = + *std::max_element(assigned_colors.begin(), assigned_colors.end()) + 1; + std::vector new_colocated_buffer_sets(num_sets); + for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { + const auto& buffer_set = colocated_buffer_sets[i]; + new_colocated_buffer_sets[assigned_colors[i]].insert(buffer_set.begin(), + buffer_set.end()); + } + + VLOG(1) << "colocation sets count after coalescing:" + << colocated_buffer_sets.size(); + return new_colocated_buffer_sets; +} + // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated // in the same allocation (currently just supports kWhile, kCall, and // kConditional). @@ -1297,12 +1290,11 @@ void BufferAssigner::BuildColocatedBufferSets( const Shape& /*subshape*/, const ShapeIndex& index) { std::vector colocated_set; // Add while.init. - auto* init_buffer = - AddBufferToColocatedSet(while_hlo->operand(0), index, - points_to_analysis, &colocated_set); + AddBufferToColocatedSet(while_hlo->operand(0), index, + points_to_analysis, &colocated_set); // Add while.result. - auto* result_buffer = AddBufferToColocatedSet( - while_hlo, index, points_to_analysis, &colocated_set); + AddBufferToColocatedSet(while_hlo, index, points_to_analysis, + &colocated_set); // Add while.cond.parameter. AddBufferToColocatedSet( while_hlo->while_condition()->parameter_instruction(0), index, @@ -1315,10 +1307,7 @@ void BufferAssigner::BuildColocatedBufferSets( AddBufferToColocatedSet( while_hlo->while_body()->root_instruction(), index, points_to_analysis, &colocated_set); - AddWhileSetToColocatedBufferSets( - colocated_set, init_buffer, result_buffer, while_hlo, - *computation, buffer_liveness, buffer_size, - colocated_buffer_sets); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); } else if (opcode == HloOpcode::kCall) { const HloInstruction* call_hlo = instruction; @@ -1358,9 +1347,62 @@ void BufferAssigner::BuildColocatedBufferSets( index, points_to_analysis, &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); + + // Add true_operand and conditional.true_computation.parameter(0) as a + // colocated buffer set. Note that this has to be done for each subshape + // in the true_operand of the conditional. + ShapeUtil::ForEachSubshape( + conditional_hlo->operand(1)->shape(), + [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( + const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector true_set; + // Add conditional.true_operand. + AddBufferToColocatedSet(conditional_hlo->operand(1), index, + points_to_analysis, &true_set); + // Add conditional.true_computation.parameter_instruction(0). + AddBufferToColocatedSet( + conditional_hlo->true_computation()->parameter_instruction(0), + index, points_to_analysis, &true_set); + AddSetToColocatedBufferSets(true_set, colocated_buffer_sets); + }); + + // Add false_operand and conditional.false_computation.parameter(0) as a + // colocated buffer set. Note that this has to be done for each subshape + // in the false_operand of the conditional. + ShapeUtil::ForEachSubshape( + conditional_hlo->operand(2)->shape(), + [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( + const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector false_set; + // Add conditional.false_operand. + AddBufferToColocatedSet(conditional_hlo->operand(2), index, + points_to_analysis, &false_set); + // Add conditional.false_computation.parameter_instruction(0). + AddBufferToColocatedSet( + conditional_hlo->false_computation()->parameter_instruction( + 0), + index, points_to_analysis, &false_set); + AddSetToColocatedBufferSets(false_set, colocated_buffer_sets); + }); } } } + + if (colocated_buffer_sets->empty()) { + return; + } + + // Try to find more coalescing opportunities among the colocated buffer sets. + // + // TODO(b/32491382): We should be able to remove this by using the + // module-level liveness analysis, which would let us directly detect buffer + // sharing opportunities between the while instruction buffer and the buffers + // from the predicate and body computation, as well as sharing across + // different while instructions. + std::vector new_colocated_buffer_sets = + MergeColocatedBufferSets(*colocated_buffer_sets, buffer_liveness, + buffer_size); + std::swap(*colocated_buffer_sets, new_colocated_buffer_sets); } // Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same @@ -1372,14 +1414,17 @@ void BufferAssigner::AssignColocatedBufferSets( FlatSet* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; - // Set 'entry_parameter_number' if entry param in 'colocated_buffer_set'. + // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry + // param in 'colocated_buffer_set'. int64 entry_parameter_number = -1; + const ShapeIndex* entry_parameter_shape_idx = nullptr; for (const LogicalBuffer* buffer : colocated_buffer_set) { const HloInstruction* instruction = buffer->instruction(); const HloComputation* computation = instruction->parent(); if (instruction->opcode() == HloOpcode::kParameter && computation == computation->parent()->entry_computation()) { entry_parameter_number = instruction->parameter_number(); + entry_parameter_shape_idx = &buffer->index(); break; } } @@ -1400,7 +1445,8 @@ void BufferAssigner::AssignColocatedBufferSets( // body computation (which updates in place). // Set 'entry_computation_parameter' to indicate that it contains // an entry parameter, and to prevent reuse in MaybeAssignBuffer. - allocation->set_entry_computation_parameter(entry_parameter_number); + allocation->set_entry_computation_parameter( + entry_parameter_number, *entry_parameter_shape_idx); } colocated_allocations->insert(allocation->index()); } else { diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 08a40bfeb2a2a78c25805308e73154c6cc667f21..6b7fd0014d103ef0617afcc5cb3f663554a01aa4 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -91,6 +91,13 @@ class BufferAllocation { return parameter_number_; } + // If this allocation is for a parameter of the entry computation, this + // function returns which subshape of the parameter the allocation is for. + const ShapeIndex& param_shape_index() const { + CHECK(is_entry_computation_parameter_); + return param_shape_index_; + } + // Returns whether this allocation is assigned a LogicalBuffer which may // be live out of the entry computation. bool maybe_live_out() const { return maybe_live_out_; } @@ -203,9 +210,11 @@ class BufferAllocation { // Adds a LogicalBuffer to the set assigned to this buffer. void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size); - void set_entry_computation_parameter(int64 parameter_number) { + void set_entry_computation_parameter(int64 parameter_number, + ShapeIndex param_shape_index) { is_entry_computation_parameter_ = true; parameter_number_ = parameter_number; + param_shape_index_ = std::move(param_shape_index); } void set_maybe_live_out(bool value) { maybe_live_out_ = value; } void set_index(Index index) { index_ = index; } @@ -235,6 +244,10 @@ class BufferAllocation { // indicates the index (starting from 0) of the parameter. int64 parameter_number_ = 0; + // If this buffer is for an entry computation parameter, which subshape of the + // parameter is it for? + ShapeIndex param_shape_index_; + // Whether the allocation contains a LogicalBuffer which may be live-out of // the entry computation. Note that this flag is conservatively computed by // TuplePointsToAnalysis. That is, an allocation marked `maybe_live_out_` @@ -528,15 +541,13 @@ class BufferAssigner { const std::vector& colocated_set, std::vector* colocated_buffer_sets); - // Conceptually the same as AddSetToColocatedBufferSets, but specific to the - // colocated buffers for while instructions. - void AddWhileSetToColocatedBufferSets( - const std::vector& colocated_set, - const LogicalBuffer* while_init_buffer, - const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo, - const HloComputation& computation, const BufferLiveness& buffer_liveness, - const LogicalBuffer::SizeFunction& buffer_size, - std::vector* colocated_buffer_sets); + // Given a list of colocated buffer sets (each colocated buffer set represents + // the logical buffers that would be assigned to the same physical buffer), + // try to merge the sets if the buffers can be shared. Returns the merged set. + std::vector MergeColocatedBufferSets( + const std::vector& colocated_buffer_sets, + const BufferLiveness& buffer_liveness, + const LogicalBuffer::SizeFunction& buffer_size); // Split a set of buffers into several sets, each of which contains buffers // colored with the same color. diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 6fc9d783f1b34de8c0f93c6aa342591891d08eaf..cd73654b8f666c4b96c000235cc3ad2cd0a46c17 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -614,7 +614,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { BufferAllocation map_buffer = GetAssignedOutputAllocation(*buffers, map); EXPECT_NE(param0_buffer.index(), map_buffer.index()); - // The final computation node of the map is an add of an f32 parm and a + // The final computation node of the map is an add of an f32 param and a // constant. EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode()); const BufferAllocation& inner_add_buffer = @@ -1587,6 +1587,117 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); } +// Tests that the colocated buffers for while instructions are properly assigned +// during buffer assignment such that the result tuple elements are not assigned +// to the same buffer. +// +// %infeed --> %while.0 --> %while.1 --+ +// +-- %tuple +// %zero --> %add --> %while.2 --+ +// +// Execution Order: +// %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple +// +// The HLO computation used in this test requires specific ordering to expose +// the bug (b/72496031). During buffer assignment, the visitation order of +// colocated buffers is %while.2 -> while.0 -> while.1, and the buffer +// assignment was coalescing the colocated buffers for all 3 while instructions, +// therefore assigning the same buffer to the two result tuple elements. +TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + + // Builds a condition computation: x -> x < 4 + auto build_cond = [&]() { + auto builder = HloComputation::Builder("cond"); + auto const4 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(4))); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4)); + return builder.Build(); + }; + + // Builds a body computation: x -> x + 9 + auto build_body = [&]() { + auto builder = HloComputation::Builder("body"); + auto const9 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(9))); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9)); + return builder.Build(); + }; + + // Build the entry computation as described in the comment above. + auto module = xla::MakeUnique(TestName()); + auto builder = HloComputation::Builder("entry"); + + auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); + auto cond0 = module->AddEmbeddedComputation(build_cond()); + auto body0 = module->AddEmbeddedComputation(build_body()); + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(r0s32, cond0, body0, infeed)); + + auto cond1 = module->AddEmbeddedComputation(build_cond()); + auto body1 = module->AddEmbeddedComputation(build_body()); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(r0s32, cond1, body1, while0)); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero)); + auto cond2 = module->AddEmbeddedComputation(build_cond()); + auto body2 = module->AddEmbeddedComputation(build_body()); + auto while2 = builder.AddInstruction( + HloInstruction::CreateWhile(r0s32, cond2, body2, add)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({while2, while1})); + module->AddEntryComputation(builder.Build()); + + // Run CopyInsertion and check if the graph constructed above doesn't need + // any copies inserted for BufferAssignment to run. + int64 instruction_count = module->instruction_count(); + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_EQ(instruction_count, module->instruction_count()); + + // Create a sequential order among all the instructions in the entry + // computation, since the issue this test stresses depends on the order the + // nodes are traversed during BufferAssignment. + SequentialHloOrdering::HloModuleSequence sequence; + sequence[module->entry_computation()] = {infeed, while0, while1, zero, + add, while2, tuple}; + TF_ASSERT_OK_AND_ASSIGN( + auto assignment, + BufferAssigner::Run( + module.get(), + xla::MakeUnique(module.get(), sequence), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; })); + + // The result tuple elements must be assigned with different buffers. + TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); + TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1})); + EXPECT_NE(slice0, slice1); + + // while0 and while1 result buffers must be equal to slice1. + TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, + assignment->GetUniqueSlice(while0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while1, + assignment->GetUniqueSlice(while1, {})); + EXPECT_EQ(slice1, slice_while0); + EXPECT_EQ(slice1, slice_while1); + + // while2 result buffer must be equal to slice0. + TF_ASSERT_OK_AND_ASSIGN(auto slice_while2, + assignment->GetUniqueSlice(while2, {})); + EXPECT_EQ(slice0, slice_while2); +} + TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index e7749252ce44f0daf7016f72d80401695eaaacb9..37982aaef9eddd64ef6b57ad5a9cf8dd6a565097 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -117,11 +117,12 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // If the root instruction aliases the buffer 'a', the live range of 'a' is // until the end of the computation and can never be strictly before another - // buffer. This is needed to prevent the root instruction's buffers from - // being reused by later instructions even when the root is not the last - // instruction in the schedule. + // buffer defined in the same computation. This is needed to prevent the + // root instruction's buffers from being reused by later instructions even + // when the root is not the last instruction in the schedule. if (alias.instruction()->parent()->root_instruction() == - alias.instruction()) { + alias.instruction() && + alias.instruction()->parent() == b.instruction()->parent()) { return false; } } diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index b9306a8bb09dc4541014716bb0c5e73e3c93ec85..dab73596e1639eed62151197048ee8d29570b20a 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -101,7 +101,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options)); + &execution_options, *user_computation)); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, computation_tracker_.BuildHloModule( diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index fc67330f5cbdbcb0d1a259d284599916a908d1fe..74fd24edf88d44b2dfdc87556b0af43987e69e08 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -72,8 +72,18 @@ class AotCompilationOptions { // Returns the ID of the platform to which these options apply. virtual perftools::gputools::Platform::Id PlatformId() const = 0; + // Optional allocator that may be used for allocating temp space on the device + // during compilation. + DeviceMemoryAllocator* device_allocator() const { return device_allocator_; } + void set_device_allocator(DeviceMemoryAllocator* device_allocator) { + device_allocator_ = device_allocator; + } + protected: AotCompilationOptions() = default; + + private: + DeviceMemoryAllocator* device_allocator_ = nullptr; }; // Abstract compiler interface that is subclassed for compilation on a @@ -99,9 +109,16 @@ class Compiler { // Runs Hlo passes to optimize the given Hlo module, returns the optimized // module. + // + // If device_allocator is not null, the compiler may use it to allocate temp + // space on the device for use during compilation. For example, the compiler + // may allocate buffers on the device and then run variants of a given + // algorithm over those buffers, to see which variant is fastest. Any space + // allocated should be deallocated before this function returns. virtual StatusOr> RunHloPasses( std::unique_ptr module, - perftools::gputools::StreamExecutor* executor) = 0; + perftools::gputools::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) = 0; // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are @@ -112,21 +129,27 @@ class Compiler { // The compiler may optionally specialize to the individual device // (not just type of device) indicated by the executor. // + // device_allocator is optional; see RunHloPasses. + // // Use the overload below to compile computations that run in parallel. virtual StatusOr> RunBackend( std::unique_ptr module, - perftools::gputools::StreamExecutor* executor) = 0; + perftools::gputools::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) = 0; // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. // + // device_allocator is optional; see RunHloPasses. + // // TODO(b/68666782): Remove this method after adding support for multiple // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( std::vector> modules, std::vector> - stream_exec) = 0; + stream_exec, + DeviceMemoryAllocator* device_allocator) = 0; // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index cd983bc03e993caed883916de01d75dffdbc4bab..cc195879a6bb490a9b49ad962aa9326cb51d9b0a 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -729,7 +729,8 @@ class CopyRemover { // has a different operand (the operand of the elided copy). for (const HloUse* copy_use : copy_value_node->uses) { operand_node->uses.push_back(copy_use); - if (copy_use->instruction->opcode() == HloOpcode::kCopy) { + if (copy_use->instruction->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, copy_use->instruction)) { copy_map_.at(copy_use->instruction).src = operand_node; } } @@ -1155,7 +1156,7 @@ bool IsWhileBody(const HloComputation* computation, HloModule* module) { std::unique_ptr call_graph = CallGraph::Build(module); TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(module)); + HloDataflowAnalysis::Run(*module)); bool changed = false; diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 128ee726ea6e4a8b63727fdc9762d865cee1c985..153f062d015e49db11c4c9ae0a2a61e76c020f02 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1724,8 +1724,58 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { } } +std::unique_ptr MakeBenchmarkWhileBody( + const int num_tuple_inputs) { + auto builder = HloComputation::Builder("benchmark_loop_body"); + const Shape element_shape = ShapeUtil::MakeShape(F32, {}); + std::vector input_shape(num_tuple_inputs, element_shape); + const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); + std::vector gte_nodes(num_tuple_inputs); + for (int i = 0; i < num_tuple_inputs; ++i) { + gte_nodes[i] = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, i)); + } + builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes)); + return builder.Build(); +} + +void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { + tensorflow::testing::StopTiming(); + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + CopyInsertion copy_insertion; + const Shape element_shape = ShapeUtil::MakeShape(F32, {}); + std::vector tuple_params(num_tuple_inputs); + for (int i = 0; i < num_iters; ++i) { + auto builder = HloComputation::Builder("BM_ParallelWhiles"); + HloModule module("BM_ManyElementTuple", VersionedComputationHandle(), + config); + for (int j = 0; j < num_tuple_inputs; ++j) { + tuple_params[j] = builder.AddInstruction( + HloInstruction::CreateParameter(j, element_shape, "")); + } + HloInstruction* init = + builder.AddInstruction(HloInstruction::CreateTuple(tuple_params)); + HloComputation* condition = + module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); + HloComputation* body = + module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs)); + HloInstruction* xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(init->shape(), condition, body, init)); + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(F32, {}), xla_while, 0)); + module.AddEntryComputation(builder.Build()); + tensorflow::testing::StartTiming(); + ASSERT_IS_OK(copy_insertion.Run(&module).status()); + tensorflow::testing::StopTiming(); + } +} + BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); +BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288); TEST_F(CopyInsertionTest, SimpleControlFlowTest) { const string& hlo_string = R"( diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2f0259163120dd5d62a5d1289deada8dc59c2c6c..c13a0b1cdf0b5be0b69db98b2b9587f30ca4c304 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -159,9 +159,6 @@ cc_library( deps = [ ":compiler_functor", ":cpu_runtime", - ":cpu_runtime_avx", - ":cpu_runtime_neon", - ":cpu_runtime_sse4_1", ":custom_call_target_registry", ":disassembler", ":external_constant_pool", @@ -408,9 +405,6 @@ cc_library( hdrs = ["compiler_functor.h"], deps = [ ":cpu_runtime", - ":cpu_runtime_avx", - ":cpu_runtime_neon", - ":cpu_runtime_sse4_1", ":disassembler", ":llvm_ir_runtime", "//tensorflow/compiler/xla:statusor", @@ -430,43 +424,6 @@ cc_library( ], ) -cc_library( - name = "cpu_runtime_sse4_1", - srcs = ["cpu_runtime_sse4_1.cc"], - hdrs = ["cpu_runtime_sse4_1.h"], - copts = ["-DEIGEN_AVOID_STL_ARRAY"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], -) - -cc_library( - name = "cpu_runtime_avx", - srcs = ["cpu_runtime_avx.cc"], - hdrs = ["cpu_runtime_avx.h"], - copts = ["-DEIGEN_AVOID_STL_ARRAY"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], -) - -cc_library( - name = "cpu_runtime_neon", - srcs = ["cpu_runtime_neon.cc"], - hdrs = ["cpu_runtime_neon.h"], - # runtime_copts() enables -mfpu=neon - copts = ["-DEIGEN_AVOID_STL_ARRAY"] + runtime_copts(), - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], -) - cc_library( name = "cpu_runtime", srcs = [ @@ -497,6 +454,7 @@ cc_library( "llvm_ir_runtime.h", ], deps = [ + ":vector_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@llvm//:core", @@ -852,6 +810,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "@llvm//:core", + "@llvm//:support", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 04b4a8c5c80eeefdbe10001ba5c462affbc9b21d..ed290fcdf8bb69f1bbad57fa5a0926376bc9405a 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -37,9 +37,6 @@ limitations under the License. #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -50,15 +47,6 @@ limitations under the License. namespace xla { namespace cpu { -/* static */ CompilerFunctor::VectorIntrinsics -CompilerFunctor::AllIntrinsics() { - VectorIntrinsics intrinsics; - intrinsics.sse_intrinsics = true; - intrinsics.avx_intrinsics = true; - intrinsics.neon_intrinsics = true; - return intrinsics; -} - /* Create filtered versions of the LLVM Pass Managers to filter out some of the expensive passes. Profiling: @@ -192,89 +180,28 @@ operator()(llvm::Module& module) const { std::move(object_file), std::move(memory_buffer)); } -namespace { -// Returns the set of vectorized library functions supported for the target. -std::vector VectorFunctionsForTargetLibraryInfoImpl( - llvm::Triple::ArchType arch, llvm::StringRef feature_string, - CompilerFunctor::VectorIntrinsics const& available_intrinsics) { - std::vector vector_functions; - - const llvm::VecDesc four_wide_vector_functions_neon[] = { - {"expf", runtime::kExpV4F32NEONSymbolName, 4}, - {"llvm.exp.f32", runtime::kExpV4F32NEONSymbolName, 4}, - - {"logf", runtime::kLogV4F32NEONSymbolName, 4}, - {"llvm.log.f32", runtime::kLogV4F32NEONSymbolName, 4}, - }; - - const llvm::VecDesc four_wide_vector_functions_sse[] = { - {"expf", runtime::kExpV4F32SSESymbolName, 4}, - {"llvm.exp.f32", runtime::kExpV4F32SSESymbolName, 4}, - - {"logf", runtime::kLogV4F32SSESymbolName, 4}, - {"llvm.log.f32", runtime::kLogV4F32SSESymbolName, 4}, - }; - - const llvm::VecDesc eight_wide_vector_functions_avx[] = { - {"expf", runtime::kExpV8F32AVXSymbolName, 8}, - {"llvm.exp.f32", runtime::kExpV8F32AVXSymbolName, 8}, - - {"logf", runtime::kLogV8F32AVXSymbolName, 8}, - {"llvm.log.f32", runtime::kLogV8F32AVXSymbolName, 8}, - }; - - // These functions are generated by XLA as LLVM IR, so they're always - // available. - const llvm::VecDesc ir_vector_functions[] = { +static std::vector VectorFunctionsForTargetLibraryInfoImpl() { + std::vector result = { {"tanhf", runtime::kTanhV4F32SymbolName, 4}, {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4}, {"tanhf", runtime::kTanhV8F32SymbolName, 8}, {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName, 8}, - }; - llvm::SmallVector features; - feature_string.split(features, ',', -1, /*KeepEmpty=*/false); - auto has_feature = [&features](const llvm::StringRef feature) { - return std::find(features.begin(), features.end(), feature) != - features.end(); - }; + {"expf", runtime::kExpV4F32SymbolName, 4}, + {"llvm.exp.f32", runtime::kExpV4F32SymbolName, 4}, - switch (arch) { - case llvm::Triple::x86: - case llvm::Triple::x86_64: { - if (has_feature("+sse4.1") && available_intrinsics.sse_intrinsics) { - vector_functions.insert(vector_functions.end(), - std::begin(four_wide_vector_functions_sse), - std::end(four_wide_vector_functions_sse)); - } - if (has_feature("+avx") && available_intrinsics.avx_intrinsics) { - vector_functions.insert(vector_functions.end(), - std::begin(eight_wide_vector_functions_avx), - std::end(eight_wide_vector_functions_avx)); - } - break; - } - case llvm::Triple::arm: - case llvm::Triple::aarch64: { - if (has_feature("+neon") && available_intrinsics.neon_intrinsics) { - vector_functions.insert(vector_functions.end(), - std::begin(four_wide_vector_functions_neon), - std::end(four_wide_vector_functions_neon)); - } - break; - } - default: - break; - } + {"expf", runtime::kExpV8F32SymbolName, 8}, + {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8}, - vector_functions.insert(vector_functions.end(), - std::begin(ir_vector_functions), - std::end(ir_vector_functions)); + {"logf", runtime::kLogV4F32SymbolName, 4}, + {"llvm.log.f32", runtime::kLogV4F32SymbolName, 4}, - return vector_functions; + {"logf", runtime::kLogV8F32SymbolName, 8}, + {"llvm.log.f32", runtime::kLogV8F32SymbolName, 8}, + }; + return result; } -} // namespace void CompilerFunctor::AddTargetInfoPasses( llvm::legacy::PassManagerBase* passes) const { @@ -282,9 +209,7 @@ void CompilerFunctor::AddTargetInfoPasses( auto target_library_info_impl = MakeUnique(target_triple); target_library_info_impl->addVectorizableFunctions( - VectorFunctionsForTargetLibraryInfoImpl( - target_triple.getArch(), target_machine_->getTargetFeatureString(), - available_intrinsics_)); + VectorFunctionsForTargetLibraryInfoImpl()); passes->add( new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl)); passes->add(createTargetTransformInfoWrapperPass( diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index 8cdd049e7b773bdc455db627ff1749997d621ee4..1a8283a702223a7414c1ffcd99c1ac42c04ac068 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -31,21 +31,10 @@ namespace cpu { // Orc JIT compile layer. class CompilerFunctor { public: - // Describes the set of vector intrinsics available to the generated code. - struct VectorIntrinsics { - bool sse_intrinsics; - bool avx_intrinsics; - bool neon_intrinsics; - }; - - // Returns a VectorIntrinsics where all intrinsics are available. - static VectorIntrinsics AllIntrinsics(); - explicit CompilerFunctor( llvm::TargetMachine* target_machine, const Disassembler* disassembler, int opt_level, bool optimize_for_size, bool enable_fast_math, bool disable_expensive_passes, - const VectorIntrinsics& available_intrinsics, LLVMCompiler::ModuleHook pre_optimization_hook = nullptr, LLVMCompiler::ModuleHook post_optimization_hook = nullptr) : target_machine_(target_machine), @@ -54,7 +43,6 @@ class CompilerFunctor { optimize_for_size_(optimize_for_size), enable_fast_math_(enable_fast_math), disable_expensive_passes_(disable_expensive_passes), - available_intrinsics_(available_intrinsics), pre_optimization_hook_(pre_optimization_hook), post_optimization_hook_(post_optimization_hook) {} @@ -78,7 +66,6 @@ class CompilerFunctor { const bool optimize_for_size_; const bool enable_fast_math_; const bool disable_expensive_passes_; - const VectorIntrinsics available_intrinsics_; LLVMCompiler::ModuleHook pre_optimization_hook_; LLVMCompiler::ModuleHook post_optimization_hook_; }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 33af77e1a81411ff5e1543d594b6078ed8e7fd1e..f9cc9651846cca7bd6ab7e9e61590cec4e2400da 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -437,7 +437,8 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) { StatusOr> CpuCompiler::RunHloPasses( std::unique_ptr module, - perftools::gputools::StreamExecutor* /*stream_exec*/) { + perftools::gputools::StreamExecutor* /*stream_exec*/, + DeviceMemoryAllocator* /*device_allocator*/) { VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -450,7 +451,8 @@ StatusOr> CpuCompiler::RunHloPasses( StatusOr> CpuCompiler::RunBackend( std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec) { + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* /*device_allocator*/) { const string timer_message = "Compiling [" + module->name() + "] for CPU using JIT"; XLA_SCOPED_LOGGING_TIMER(timer_message); @@ -517,8 +519,8 @@ StatusOr> CpuCompiler::RunBackend( // ownership is std::moved. const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - const string xla_dump_hlo_proto_to = - module->config().debug_options().xla_dump_hlo_proto_to(); + const string xla_dump_optimized_hlo_proto_to = + module->config().debug_options().xla_dump_optimized_hlo_proto_to(); if (options::CpuParallelBackendRequested(module->config())) { VLOG(1) << "Using parallel cpu backend"; @@ -538,10 +540,10 @@ StatusOr> CpuCompiler::RunBackend( // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!xla_dump_hlo_proto_to.empty()) { + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } // If we are using the parallel CPU backend, we need to create map from @@ -647,10 +649,10 @@ StatusOr> CpuCompiler::RunBackend( // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!xla_dump_hlo_proto_to.empty()) { + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } // Each computation is a single function. Emit all embedded computations @@ -826,12 +828,12 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - const string xla_dump_hlo_proto_to = - module->config().debug_options().xla_dump_hlo_proto_to(); - if (!xla_dump_hlo_proto_to.empty()) { + const string xla_dump_optimized_hlo_proto_to = + module->config().debug_options().xla_dump_optimized_hlo_proto_to(); + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } IrEmitter ir_emitter(*module, *assignment, &llvm_module, @@ -886,8 +888,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, options::OptimizeForSizeRequested(module->config()), module->config().debug_options().xla_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), - CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook, - post_optimization_ir_dump_hook); + pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); llvm::object::OwningBinary object_file = compiler_functor(llvm_module); llvm::StringRef object_file_data_ref = object_file.getBinary()->getData(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index ebed7058d8f7968c6e03ef90d0da6b2325037eb0..3498139ab95d21383c6dc008ae5614b7bfe91148 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -118,11 +118,13 @@ class CpuCompiler : public LLVMCompiler { StatusOr> RunHloPasses( std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::vector> modules, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 802d0a6fb46890b31d14b1fbf3b2e7d6520caccc..c053703c3524a47ee1de9681c1b986edbf109430 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -63,7 +63,7 @@ CpuExecutable::CpuExecutable( assignment_(std::move(assignment)) { // Resolve symbols in the constructor rather than at execution time to avoid // races because FindSymbol is not thread safe. - llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name); + llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry_function_name); // We expect to find the symbol provided with entry_function_name; otherwise // this is an internal error. CHECK(sym) << "Symbol " << entry_function_name << " not found."; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 1ef45dbec39a0880ebb123ba3fcd1fd6c89eb39a..40ace963270e8cead47cc731cc326351178dff7d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -35,6 +35,8 @@ extern const char* const kEigenMatMulF32SymbolName = "__xla_cpu_runtime_EigenMatMulF32"; extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; +extern const char* const kEigenConvF16SymbolName = + "__xla_cpu_runtime_EigenConvF16"; extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; @@ -42,6 +44,8 @@ extern const char* const kEigenSingleThreadedMatMulF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; extern const char* const kEigenSingleThreadedMatMulF64SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF64"; +extern const char* const kEigenSingleThreadedConvF16SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedConvF16"; extern const char* const kEigenSingleThreadedConvF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedConvF32"; extern const char* const kAcquireInfeedBufferForDequeueSymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 3e1f08071119c938619d02777513e5b834077118..2141dfe1cedd6f9674acc348152574b4fd30895b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -43,10 +43,12 @@ namespace runtime { // because it is a symbol in the cpu_runtime library. extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; +extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; +extern const char* const kEigenSingleThreadedConvF16SymbolName; extern const char* const kEigenSingleThreadedConvF32SymbolName; extern const char* const kAcquireInfeedBufferForDequeueSymbolName; extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc deleted file mode 100644 index b1c1142e8d988be2ca00809b4be505466071c72f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/Eigen/Core" - -#ifdef TF_XLA_HAS_AVX -xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX( - xla::cpu::runtime::V8F32AVX x) { - return Eigen::internal::pexp(x); -} - -xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX( - xla::cpu::runtime::V8F32AVX x) { - return Eigen::internal::plog(x); -} -#endif // TF_XLA_HAS_AVX - -namespace xla { -namespace cpu { -namespace runtime { - -const char *const kExpV8F32AVXSymbolName = "__xla_cpu_runtime_ExpV8F32AVX"; -const char *const kLogV8F32AVXSymbolName = "__xla_cpu_runtime_LogV8F32AVX"; - -} // namespace runtime -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h deleted file mode 100644 index e5c782f93f54dc9f8f76fce7e4735a60e8847583..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This header declares functions which may be called by the generated code on -// the CPU. Calls to these functions must be resolved explicitly in the JIT in -// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context -// which is used to cache expensive state and resources utilized by the -// aforementioned functions. - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ - -#include "tensorflow/core/platform/macros.h" - -#if defined(__AVX__) -#include -#define TF_XLA_HAS_AVX -#endif - -namespace xla { -namespace cpu { -namespace runtime { - -extern const char *const kExpV8F32AVXSymbolName; -extern const char *const kLogV8F32AVXSymbolName; - -#ifdef TF_XLA_HAS_AVX -typedef __m256 V8F32AVX; -#endif -} // namespace runtime -} // namespace cpu -} // namespace xla - -extern "C" { - -#ifdef TF_XLA_HAS_AVX -// The following functions are vectorized versions of a selection of libm -// library functions. -// References to these functions are created by the LLVM vectorizer. -xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX( - xla::cpu::runtime::V8F32AVX x); - -xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX( - xla::cpu::runtime::V8F32AVX x); -#endif -} - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc deleted file mode 100644 index 8099b722f10ecb83f7cf6c58ba2abb783478b97f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/Eigen/Core" - -#ifdef TF_XLA_HAS_NEON - -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON( - xla::cpu::runtime::V4F32NEON x) { - return Eigen::internal::pexp(x); -} - -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON( - xla::cpu::runtime::V4F32NEON x) { - Eigen::internal::Packet4f p = x; - return Eigen::internal::plog(p); -} - -#endif // TF_XLA_HAS_NEON - -namespace xla { -namespace cpu { -namespace runtime { - -const char *const kExpV4F32NEONSymbolName = "__xla_cpu_runtime_ExpV4F32NEON"; -const char *const kLogV4F32NEONSymbolName = "__xla_cpu_runtime_LogV4F32NEON"; - -} // namespace runtime -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h deleted file mode 100644 index 2f5d1a872aaf3868d6d27f88a4f05c778d45660f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ - -// This header declares functions which may be called by the generated code on -// the CPU. Calls to these functions must be resolved explicitly in the JIT in -// xla::cpu::SimpleResolver. - -#include "tensorflow/core/platform/macros.h" - -#ifdef __ARM_NEON__ -// For the other runtimes (AVX, SSE4.1) we define the vector type directly using -// __attribute__((__vector_size__(*))). Unfortunately, the typedef for the ARM -// NEON SIMD types is not portable, so the type has to come from -#include -#define TF_XLA_HAS_NEON -#endif // __ARM_NEON__ - -namespace xla { -namespace cpu { -namespace runtime { - -extern const char *const kExpV4F32NEONSymbolName; -extern const char *const kLogV4F32NEONSymbolName; - -#ifdef TF_XLA_HAS_NEON -typedef float32x4_t V4F32NEON; -#endif // TF_XLA_HAS_NEON - -} // namespace runtime -} // namespace cpu -} // namespace xla - -extern "C" { - -#ifdef TF_XLA_HAS_NEON -// The following functions are vectorized versions of a selection of libm -// library functions. -// References to these functions are created by the LLVM vectorizer. -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON( - xla::cpu::runtime::V4F32NEON x); - -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON( - xla::cpu::runtime::V4F32NEON x); -#endif // TF_XLA_HAS_NEON -} - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc deleted file mode 100644 index d8ecf231cc8c859ac88e1ef1478f7107cd86a052..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/Eigen/Core" - -#ifdef TF_XLA_HAS_SSE4_1 - -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE( - xla::cpu::runtime::V4F32SSE x) { - Eigen::internal::Packet4f p = x; - return Eigen::internal::pexp(p); -} - -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( - xla::cpu::runtime::V4F32SSE x) { - Eigen::internal::Packet4f p = x; - return Eigen::internal::plog(p); -} - -#endif // TF_XLA_HAS_SSE4_1 - -namespace xla { -namespace cpu { -namespace runtime { - -const char *const kExpV4F32SSESymbolName = "__xla_cpu_runtime_ExpV4F32SSE"; -const char *const kLogV4F32SSESymbolName = "__xla_cpu_runtime_LogV4F32SSE"; - -} // namespace runtime -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h deleted file mode 100644 index aeb1eda23f76a6b5cb520b6673e0a011fa1130c7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This header declares functions which may be called by the generated code on -// the CPU. Calls to these functions must be resolved explicitly in the JIT in -// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context -// which is used to cache expensive state and resources utilized by the -// aforementioned functions. - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ - -#include "tensorflow/core/platform/macros.h" - -// MSVC does not have __SSE4_1__ macro. Eigen enables EIGEN_VECTORIZE_SSE4_1 -// when __AVX__ is defined, we should do the same. -#if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__)) -#include -#define TF_XLA_HAS_SSE4_1 -#endif - -namespace xla { -namespace cpu { -namespace runtime { - -extern const char *const kExpV4F32SSESymbolName; -extern const char *const kLogV4F32SSESymbolName; - -#ifdef TF_XLA_HAS_SSE4_1 -typedef __m128 V4F32SSE; -#endif - -} // namespace runtime -} // namespace cpu -} // namespace xla - -extern "C" { - -#ifdef TF_XLA_HAS_SSE4_1 -// The following functions are vectorized versions of a selection of libm -// library functions. -// References to these functions are created by the LLVM vectorizer. -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE( - xla::cpu::runtime::V4F32SSE x); - -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( - xla::cpu::runtime::V4F32SSE x); -#endif -} - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index c9fc586b9a4c06eb9e1f111d8f9bd2f717990aab..cfe7c9c3af0be109ac8a86753e880e2bcbceba41 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -549,7 +549,7 @@ DotOpEmitter::DotOpEmitter( const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); - TF_RET_CHECK(F32 == type || F64 == type || C64 == type); + TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, lhs_array, rhs_array, addend_array, executable_run_options_value, ir_builder, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 71e81331897a8bb82438dd5160d2964cb88fd31f..4dffaee87f6b33933b58c8c58478eec918569197 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -479,7 +479,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { Status IrEmitter::HandleSort(HloInstruction* sort) { // TODO(b/26783907): Implement sort on CPU. - return Unimplemented("Sort is not supported on CPU (b/26783907)."); + return Unimplemented("Sort is not implemented on CPU."); } Status IrEmitter::HandleTuple(HloInstruction* tuple) { @@ -522,7 +522,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { return Unimplemented( - "Dilation for reduce-window not implemented on CPU. See b/31410564."); + "Dilation for ReduceWindow is not implemented on CPU."); } // The called computation should have been emitted previously. @@ -625,8 +625,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // TODO(b/31410564): Implement dilation for select-and-scatter. if (window_util::HasDilation(window)) { return Unimplemented( - "Dilation for select-and-scatter not implemented on CPU. " - "See b/31410564."); + "Dilation for SelectAndScatter is not implemented on CPU. "); } // The select and scatter computations should have been emitted previously. @@ -802,7 +801,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { auto rhs = dot->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32, F64, C64})); + /*supported_types=*/{F16, F32, F64, C64})); const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); if (dnums.lhs_batch_dimensions_size() > 0 || dnums.rhs_batch_dimensions_size() > 0) { @@ -850,7 +849,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { const auto& window = convolution->window(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32, C64})); + /*supported_types=*/{F16, F32, C64})); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); @@ -929,25 +928,30 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { int64 rhs_col_dilation = one_dim_convolution ? 1 : window.dimensions(1).window_dilation(); - // Args have been computed, make the call. - llvm::Type* float_ptr_type = ir_builder_.getFloatTy()->getPointerTo(); + PrimitiveType primitive_type = lhs->shape().element_type(); + llvm::Type* ir_ptr_type = primitive_type == F16 + ? ir_builder_.getHalfTy()->getPointerTo() + : ir_builder_.getFloatTy()->getPointerTo(); llvm::Type* int64_type = ir_builder_.getInt64Ty(); llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); llvm::FunctionType* conv_type = llvm::FunctionType::get( ir_builder_.getVoidTy(), - {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type, - int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type}, + {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type, + int64_type, int64_type, int64_type, int64_type, int64_type, + int64_type, int64_type, int64_type, int64_type, int64_type, + int64_type, int64_type, int64_type, int64_type, int64_type, + int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); bool multi_threaded_eigen = hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); const char* fn_name = - (multi_threaded_eigen - ? runtime::kEigenConvF32SymbolName - : runtime::kEigenSingleThreadedConvF32SymbolName); + primitive_type == F16 + ? (multi_threaded_eigen + ? runtime::kEigenConvF16SymbolName + : runtime::kEigenSingleThreadedConvF16SymbolName) + : (multi_threaded_eigen + ? runtime::kEigenConvF32SymbolName + : runtime::kEigenSingleThreadedConvF32SymbolName); llvm::Function* conv_func = llvm::cast( module_->getOrInsertFunction(fn_name, conv_type)); conv_func->setCallingConv(llvm::CallingConv::C); @@ -957,9 +961,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { conv_func, { GetExecutableRunOptionsArgument(), ir_builder_.CreateBitCast( - GetEmittedValueFor(convolution), float_ptr_type), - ir_builder_.CreateBitCast(lhs_address, float_ptr_type), - ir_builder_.CreateBitCast(rhs_address, float_ptr_type), + GetEmittedValueFor(convolution), ir_ptr_type), + ir_builder_.CreateBitCast(lhs_address, ir_ptr_type), + ir_builder_.CreateBitCast(rhs_address, ir_ptr_type), ir_builder_.getInt64(input_batch), ir_builder_.getInt64(input_rows), ir_builder_.getInt64(input_cols), @@ -1196,8 +1200,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { } // TODO(b/33011107): Support cross replica sum on CPU. - return Unimplemented( - "Cross replica sum is not implemented on CPU. See b/33011107."); + return Unimplemented("CrossReplicaSum is not implemented on CPU."); } // Fills up the free variables in 'index_with_free_var' with values from @@ -1334,7 +1337,7 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( if (ShapeUtil::ElementIsComplex(root_shape)) { // TODO(b/65408531): Complex add could by done via bitcast to // Complex multiply would be more challenging. We could perhaps use a - // strided load to get all reals in a vector, all imags in a vector, or use + // strided load to get all reals in a vector, all images in a vector, or use // CreateShuffleVector on a bitcast to float x [2N]. *failure_reason = "complex values not supported"; return nullptr; @@ -1811,12 +1814,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { Status IrEmitter::HandleSend(HloInstruction* send) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Send is not implemented on CPU. See b/33942983."); + return Unimplemented("Send is not implemented on CPU."); } Status IrEmitter::HandleSendDone(HloInstruction* send_done) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Send-done is not implemented on CPU. See b/33942983."); + return Unimplemented("Send-done is not implemented on CPU."); } Status IrEmitter::HandleSlice(HloInstruction* slice) { @@ -1981,12 +1984,12 @@ Status IrEmitter::HandleDynamicUpdateSlice( Status IrEmitter::HandleRecv(HloInstruction* recv) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Recv is not implemented on CPU. See b/33942983."); + return Unimplemented("Recv is not implemented on CPU."); } Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Recv-done is not implemented on CPU. See b/33942983."); + return Unimplemented("Recv-done is not implemented on CPU."); } Status IrEmitter::HandlePad(HloInstruction* pad) { @@ -1995,10 +1998,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { for (auto& padding_dimension : pad->padding_config().dimensions()) { if (padding_dimension.edge_padding_low() < 0 || padding_dimension.edge_padding_high() < 0) { - return Unimplemented( - "Negative padding not supported in the CPU backend (b/34628603); " - "this should have been eliminated at the HLO level: %s", - pad->padding_config().ShortDebugString().c_str()); + return InternalErrorStrCat( + "Encountered negative padding in IrEmitter on CPU. " + "This should have been eliminated at the HLO level. ", + pad->ToString()); } } diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index ca8c290dd1c4959e42026c3917d37f8fc95a1011..2d6f2f3818a7bd4424aaa7d918ca86abef15c0e9 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -209,9 +209,9 @@ std::vector GetArrayFunctionCallArguments( parameter_addresses[i], ir_builder->getInt8PtrTy(), AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, "_address_as_i8ptr"))); - llvm::Value* slot_in_param_adresses = ir_builder->CreateInBoundsGEP( + llvm::Value* slot_in_param_addresses = ir_builder->CreateInBoundsGEP( parameter_addresses_buffer, {ir_builder->getInt64(i)}); - ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_adresses); + ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); } const auto to_int8_ptr = [=](llvm::Value* ptr) { diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 0336fa61312e5cd626ae38ddd29875bff256212a..2e5cc96098241415b82f225afc81981f3e1069e0 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -20,6 +20,8 @@ limitations under the License. #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -28,6 +30,10 @@ namespace runtime { const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32"; const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32"; +const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32"; +const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32"; +const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX"; +const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX"; namespace { llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, @@ -42,27 +48,23 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, } llvm::LLVMContext* context = &module->getContext(); - llvm::Type* float_type = llvm::Type::getFloatTy(*context); - llvm::VectorType* vector_type = - llvm::VectorType::get(float_type, vector_width); llvm::BasicBlock* vector_tanh_body = llvm::BasicBlock::Create(*context, "body", vector_tanh_function); llvm::IRBuilder<> ir_builder(vector_tanh_body); - llvm::FastMathFlags fast_math_flags; fast_math_flags.setFast(); ir_builder.setFastMathFlags(fast_math_flags); + VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "tanh_f32"); + llvm::Value* input = &*vector_tanh_function->arg_begin(); - CHECK_EQ(input->getType(), vector_type); + CHECK_EQ(input->getType(), vsl.vector_type()); // This implements the same rational interpolant as implemented in Eigen3. - llvm::Value* input_clamped = llvm_ir::EmitFloatMin( - llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(vector_type, -9.0), - &ir_builder), - llvm::ConstantFP::get(vector_type, 9.0), &ir_builder); + llvm::Value* input_clamped = + vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0)); std::array numerator_coeffs{ -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, @@ -73,31 +75,230 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, 4.89352518554385e-03f}; - llvm::Value* input_squared = - ir_builder.CreateFMul(input_clamped, input_clamped); - llvm::Value* numerator = - llvm::ConstantFP::get(vector_type, numerator_coeffs[0]); + llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped); + llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0])); for (int i = 1; i < numerator_coeffs.size(); i++) { - numerator = ir_builder.CreateFAdd( - ir_builder.CreateFMul(input_squared, numerator), - llvm::ConstantFP::get(vector_type, numerator_coeffs[i])); + numerator = + vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i])); } - numerator = ir_builder.CreateFMul(input_clamped, numerator); - llvm::Value* denominator = - llvm::ConstantFP::get(vector_type, denominator_coeffs[0]); + numerator = vsl.Mul(input_clamped, numerator); + + llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0])); for (int i = 1; i < denominator_coeffs.size(); i++) { - denominator = ir_builder.CreateFAdd( - ir_builder.CreateFMul(input_squared, denominator), - llvm::ConstantFP::get(vector_type, denominator_coeffs[i])); + denominator = vsl.MulAdd(input_squared, denominator, + GetIeeeF32(denominator_coeffs[i])); } - llvm::Value* result = ir_builder.CreateFDiv(numerator, denominator); + llvm::Value* result = vsl.Div(numerator, denominator); ir_builder.CreateRet(result); DCHECK(!llvm::verifyFunction(*vector_tanh_function)); return vector_tanh_function; } + +llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, + llvm::StringRef function_name, + int vector_width, + bool enable_fast_math) { + llvm::Function* vector_exp_function = module->getFunction(function_name); + if (vector_exp_function == nullptr) { + // If the function declaration is not present in the module, there can't be + // any calls to resolve. Don't emit the function in this case. + return nullptr; + } + + llvm::LLVMContext* context = &module->getContext(); + + llvm::BasicBlock* vector_exp_body = + llvm::BasicBlock::Create(*context, "body", vector_exp_function); + + llvm::IRBuilder<> ir_builder(vector_exp_body); + llvm::FastMathFlags fast_math_flags; + fast_math_flags.setFast(); + ir_builder.setFastMathFlags(fast_math_flags); + + VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "exp_f32"); + + // This implements the same polynomial approximation as implemented in Eigen3. + + const llvm::APFloat half = GetIeeeF32(0.5); + const llvm::APFloat one = GetIeeeF32(1.0); + + const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950); + const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949); + + const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341); + const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375); + const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4); + + const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4); + const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3); + const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3); + const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2); + const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1); + const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1); + + llvm::Value* input = &*vector_exp_function->arg_begin(); + llvm::Value* input_clamped = + vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi); + llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half)); + llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx); + llvm::Value* z = vsl.Mul(cephes_exp_C2, fx); + llvm::Value* x = vsl.Sub(input_clamped, tmp); + x = vsl.Sub(x, z); + z = vsl.Mul(x, x); + + llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1); + y = vsl.MulAdd(y, x, cephes_exp_p2); + y = vsl.MulAdd(y, x, cephes_exp_p3); + y = vsl.MulAdd(y, x, cephes_exp_p4); + y = vsl.MulAdd(y, x, cephes_exp_p5); + y = vsl.MulAdd(y, z, x); + y = vsl.Add(one, y); + + // VectorSupportLibrary (intentionally) can't juggle more than one type at a + // time so drop down to IRBuilder for this bit. + llvm::Value* vector_constant_0x7f = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); + llvm::Value* vector_constant_23 = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); + llvm::Type* i32_vector_type = + llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); + // fx is clamped so we don't have to worry about it being out of range for + // i32. + llvm::Value* emm0 = ir_builder.CreateFPToSI(fx, i32_vector_type); + emm0 = ir_builder.CreateAdd(emm0, vector_constant_0x7f); + emm0 = ir_builder.CreateShl(emm0, vector_constant_23); + llvm::Value* emm0_f32 = ir_builder.CreateBitCast(emm0, vsl.vector_type()); + + llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input); + + ir_builder.CreateRet(result); + + DCHECK(!llvm::verifyFunction(*vector_exp_function)); + return vector_exp_function; +} + +llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, + llvm::StringRef function_name, + int vector_width, + bool enable_fast_math) { + llvm::Function* vector_log_function = module->getFunction(function_name); + if (vector_log_function == nullptr) { + // If the function declaration is not present in the module, there can't be + // any calls to resolve. Don't emit the function in this case. + return nullptr; + } + + llvm::LLVMContext* context = &module->getContext(); + + llvm::BasicBlock* vector_log_body = + llvm::BasicBlock::Create(*context, "body", vector_log_function); + + llvm::IRBuilder<> ir_builder(vector_log_body); + llvm::FastMathFlags fast_math_flags; + fast_math_flags.setFast(); + ir_builder.setFastMathFlags(fast_math_flags); + + llvm::Value* input = &*vector_log_function->arg_begin(); + VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32"); + + const llvm::APFloat half = GetIeeeF32(0.5); + const llvm::APFloat one = GetIeeeF32(1.0); + + // This implements the same polynomial approximation as implemented in Eigen3. + // Returns NaN for x < 0, -INF for x = 0 + const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524); + const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2); + const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1); + const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1); + const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1); + const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1); + const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1); + const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1); + const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1); + const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1); + const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4); + const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375); + + // The smallest non denormalized float number. + const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000); + const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000); + const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000); + + // invalid_mask is set if x is negative or NaN (and therefore output + // must be NaN). + llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector()); + llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); + + // Cut off denormalized stuff. + input = vsl.Max(min_norm_pos, input); + + // VectorSupportLibrary (intentionally) can't juggle more than one type at a + // time so drop down to IRBuilder for this bit. + llvm::Value* vector_constant_0x7f = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); + llvm::Value* vector_constant_23 = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); + llvm::Type* i32_vector_type = + llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); + + llvm::Value* emm0 = ir_builder.CreateLShr( + ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23); + + // Keep only the fractional part. + input = vsl.FloatAnd(input, inv_mant_mask); + input = vsl.FloatOr(input, half); + + emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f); + llvm::Value* e = + vsl.Add(one, ir_builder.CreateSIToFP(emm0, vsl.vector_type())); + + // part2: + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF); + llvm::Value* tmp = vsl.FloatAnd(input, mask); + input = vsl.Sub(input, one); + e = vsl.Sub(e, vsl.FloatAnd(mask, one)); + input = vsl.Add(input, tmp); + + llvm::Value* x2 = vsl.Mul(input, input); + llvm::Value* x3 = vsl.Mul(x2, input); + + llvm::Value *y, *y1, *y2; + y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1); + y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4); + y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7); + y = vsl.MulAdd(y, input, cephes_log_p2); + y1 = vsl.MulAdd(y1, input, cephes_log_p5); + y2 = vsl.MulAdd(y2, input, cephes_log_p8); + y = vsl.MulAdd(y, x3, y1); + y = vsl.MulAdd(y, x3, y2); + y = vsl.Mul(y, x3); + + y1 = vsl.Mul(cephes_log_q1, e); + tmp = vsl.Mul(half, x2); + y = vsl.Add(y, y1); + input = vsl.Sub(input, tmp); + y2 = vsl.Mul(cephes_log_q2, e); + input = vsl.Add(input, y); + input = vsl.Add(input, y2); + + // Negative arg will be NAN, 0 will be -INF. + llvm::Value* or_lhs = + vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask)); + llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf); + llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs); + + ir_builder.CreateRet(result); + + DCHECK(!llvm::verifyFunction(*vector_log_function)); + return vector_log_function; +} } // namespace void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { @@ -108,11 +309,28 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName, /*vector_width=*/8, enable_fast_math); + auto* exp_v4f32 = + EmitVectorF32ExpIfNeeded(module, kExpV4F32SymbolName, + /*vector_width=*/4, enable_fast_math); + auto* exp_v8f32 = + EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName, + /*vector_width=*/8, enable_fast_math); + + auto* log_v4f32 = + EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName, + /*vector_width=*/4, enable_fast_math); + auto* log_v8f32 = + EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName, + /*vector_width=*/8, enable_fast_math); + // Gather all the call sites, force inline them and then delete the vector // function bodies. + // + // TODO(b/73081976): Should we avoid inlining these intrinsics in some cases? std::vector calls_to_inline; - for (auto* function : {tanh_v4f32, tanh_v8f32}) { + for (auto* function : + {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { if (function != nullptr) { for (auto* user : function->users()) { calls_to_inline.push_back(llvm::cast(user)); @@ -125,7 +343,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); } - for (auto* function : {tanh_v4f32, tanh_v8f32}) { + for (auto* function : + {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { if (function != nullptr) { function->eraseFromParent(); } diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h index 7f31fb98b0d03c16ef40bff9822227e01f6be46b..5553972677512617ccb6ac4f57a4d33400b664e3 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTINE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTINE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTIME_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTIME_H_ #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -25,6 +25,10 @@ namespace runtime { extern const char* const kTanhV4F32SymbolName; extern const char* const kTanhV8F32SymbolName; +extern const char* const kExpV4F32SymbolName; +extern const char* const kExpV8F32SymbolName; +extern const char* const kLogV4F32SymbolName; +extern const char* const kLogV8F32SymbolName; // The following CPU runtime functions have LLVM-IR only implementations: // @@ -40,4 +44,4 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math); } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTINE_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTIME_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index cd997f07890cdc1d9a546ede58cc1d992b6416ae..07a9f0efcb64db4b2ff0c6518d4b48eee9a505e0 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -394,7 +394,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( for (auto& entry : *function_names_) { tensorflow::mutex_lock lock(jit_mutex_); HloInstruction* instruction = entry.first; - llvm::JITSymbol sym = jit_->FindSymbol(entry.second); + llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry.second); TF_RET_CHECK(sym); InsertOrDie( &functions, instruction, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc index c2f64eb27a554d17ebe2a94dba334fe378bd7254..3905e7ff2a14d25813e345399e692f9e0f4bd0af 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc @@ -34,7 +34,26 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF32( int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); - tensorflow::xla::EigenConvF32Impl( + tensorflow::xla::EigenConvImpl( + *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, + input_rows, input_cols, input_channels, kernel_rows, kernel_cols, + kernel_channels, kernel_filters, output_rows, output_cols, row_stride, + col_stride, padding_top, padding_bottom, padding_left, padding_right, + lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols, + int64 input_channels, int64 kernel_rows, int64 kernel_cols, + int64 kernel_channels, int64 kernel_filters, int64 output_rows, + int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top, + int64 padding_bottom, int64 padding_left, int64 padding_right, + int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation, + int64 rhs_col_dilation) { + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + tensorflow::xla::EigenConvImpl( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h index 05ae094691fd9a7ca83b902145c0750fafdc529a..39e20ed45639040110b99ddb52eb6f6dab26dfaa 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h @@ -34,6 +34,20 @@ extern void __xla_cpu_runtime_EigenConvF32( tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation); +extern void __xla_cpu_runtime_EigenConvF16( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, + tensorflow::int64 input_batch, tensorflow::int64 input_rows, + tensorflow::int64 input_cols, tensorflow::int64 input_channels, + tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols, + tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters, + tensorflow::int64 output_rows, tensorflow::int64 output_cols, + tensorflow::int64 row_stride, tensorflow::int64 col_stride, + tensorflow::int64 padding_top, tensorflow::int64 padding_bottom, + tensorflow::int64 padding_left, tensorflow::int64 padding_right, + tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation, + tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation); + } // extern "C" #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h index 02f45fee0f1b8cd1125ec6a97f01e0028137bb69..85af63bb032ce33bdd188d6e5bcd78a726d5d9fa 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h @@ -24,26 +24,27 @@ limitations under the License. namespace tensorflow { namespace xla { -template -void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs, - float* rhs, int64 input_batch, int64 input_rows, - int64 input_cols, int64 input_channels, int64 kernel_rows, - int64 kernel_cols, int64 kernel_channels, - int64 kernel_filters, int64 output_rows, - int64 output_cols, int64 row_stride, int64 col_stride, - int64 padding_top, int64 padding_bottom, - int64 padding_left, int64 padding_right, - int64 lhs_row_dilation, int64 lhs_col_dilation, - int64 rhs_row_dilation, int64 rhs_col_dilation) { - const Eigen::TensorMap, +template +void EigenConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs, + ScalarType* rhs, int64 input_batch, int64 input_rows, + int64 input_cols, int64 input_channels, int64 kernel_rows, + int64 kernel_cols, int64 kernel_channels, + int64 kernel_filters, int64 output_rows, int64 output_cols, + int64 row_stride, int64 col_stride, int64 padding_top, + int64 padding_bottom, int64 padding_left, + int64 padding_right, int64 lhs_row_dilation, + int64 lhs_col_dilation, int64 rhs_row_dilation, + int64 rhs_col_dilation) { + const Eigen::TensorMap, Eigen::Aligned> input(lhs, input_batch, input_rows, input_cols, input_channels); - const Eigen::TensorMap, + const Eigen::TensorMap, Eigen::Aligned> kernel(rhs, kernel_rows, kernel_cols, kernel_channels, kernel_filters); - Eigen::TensorMap, Eigen::Aligned> + Eigen::TensorMap, + Eigen::Aligned> output(out, input_batch, output_rows, output_cols, kernel_filters); Eigen::array, 1> contract_dims; @@ -75,7 +76,7 @@ void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs, row_stride, rhs_col_dilation, rhs_row_dilation, lhs_col_dilation, lhs_row_dilation, padding_left, padding_right, padding_top, - padding_bottom, 0.0f) + padding_bottom, static_cast(0.0f)) .reshape(pre_contract_dims) .contract(kernel.reshape(kernel_dims), contract_dims) .reshape(post_contract_dims); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc index d0b0e11ac0f9fd06e384c2bb5e6296edd0825f5c..5afccc6a86e2df468e3e3e874cf0f4d4e1342a88 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc @@ -21,6 +21,24 @@ limitations under the License. using tensorflow::int64; +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedConvF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols, + int64 input_channels, int64 kernel_rows, int64 kernel_cols, + int64 kernel_channels, int64 kernel_filters, int64 output_rows, + int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top, + int64 padding_bottom, int64 padding_left, int64 padding_right, + int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation, + int64 rhs_col_dilation) { + tensorflow::xla::EigenConvImpl( + Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, + input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, + kernel_filters, output_rows, output_cols, row_stride, col_stride, + padding_top, padding_bottom, padding_left, padding_right, + lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation); +} + TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedConvF32( const void* run_options_ptr, float* out, float* lhs, float* rhs, @@ -30,7 +48,7 @@ __xla_cpu_runtime_EigenSingleThreadedConvF32( int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom, int64 padding_left, int64 padding_right, int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { - tensorflow::xla::EigenConvF32Impl( + tensorflow::xla::EigenConvImpl( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, col_stride, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h index 8ae1a42149bde26ca2f510ad47e76ae47f34a977..f216bd0152aa93b8753d881938c63a9cabea899b 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h @@ -20,6 +20,20 @@ limitations under the License. extern "C" { +extern void __xla_cpu_runtime_EigenSingleThreadedConvF16( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, + tensorflow::int64 input_batch, tensorflow::int64 input_rows, + tensorflow::int64 input_cols, tensorflow::int64 input_channels, + tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols, + tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters, + tensorflow::int64 output_rows, tensorflow::int64 output_cols, + tensorflow::int64 row_stride, tensorflow::int64 col_stride, + tensorflow::int64 padding_top, tensorflow::int64 padding_bottom, + tensorflow::int64 padding_left, tensorflow::int64 padding_right, + tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation, + tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation); + extern void __xla_cpu_runtime_EigenSingleThreadedConvF32( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, float* lhs, float* rhs, tensorflow::int64 input_batch, diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index de5e9b411905a37a7db7d05f51cca2802c1526ed..aa8d4ad9dc51b2c1f500898f8bbd2c548f710643 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -21,15 +21,13 @@ limitations under the License. #include #include "llvm/ExecutionEngine/ExecutionEngine.h" +#include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Mangler.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Host.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" @@ -46,36 +44,6 @@ namespace xla { namespace cpu { namespace { -// A simple SymbolResolver that delegates to the host dynamic linker. -class SimpleResolver : public llvm::LegacyJITSymbolResolver { - public: - explicit SimpleResolver(ExternalConstantPool* external_constant_pool) - : external_constant_pool_(external_constant_pool) {} - - llvm::JITSymbol findSymbol(const std::string& name) override { - if (const uint8* from_constant_pool = - external_constant_pool_->Find(string(name))) { - return llvm::JITEvaluatedSymbol( - reinterpret_cast(from_constant_pool), - llvm::JITSymbolFlags::None); - } - - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); - if (func_addr == nullptr) { - return nullptr; - } - llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), - llvm::JITSymbolFlags::None); - return symbol_info; - } - llvm::JITSymbol findSymbolInLogicalDylib(const std::string& name) override { - return nullptr; - } - - private: - ExternalConstantPool* external_constant_pool_; -}; - llvm::SmallVector DetectMachineAttributes() { llvm::SmallVector result; llvm::StringMap host_features; @@ -100,27 +68,6 @@ llvm::StringRef GetHostCpuName() { cpu_name.consume_back("-avx512"); return cpu_name; } - -CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { - CompilerFunctor::VectorIntrinsics intrinsics; -#ifdef TF_XLA_HAS_SSE4_1 - intrinsics.sse_intrinsics = true; -#else - intrinsics.sse_intrinsics = false; -#endif -#ifdef TF_XLA_HAS_AVX - intrinsics.avx_intrinsics = true; -#else - intrinsics.avx_intrinsics = false; -#endif -#ifdef TF_XLA_HAS_NEON - intrinsics.neon_intrinsics = true; -#else - intrinsics.neon_intrinsics = false; -#endif - return intrinsics; -} - } // namespace SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, @@ -139,49 +86,71 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), - object_layer_([] { - return std::make_shared( - orc_jit_memory_mapper::GetInstance()); - }), - compile_layer_( - object_layer_, - CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, - optimize_for_size, enable_fast_math, - disable_expensive_passes, GetAvailableIntrinsics(), - std::move(pre_optimization_hook), - std::move(post_optimization_hook))) { + execution_session_(string_pool_), + symbol_resolver_(llvm::orc::createLegacyLookupResolver( + [this](const std::string& name) -> llvm::JITSymbol { + return this->ResolveRuntimeSymbol(name); + }, + [](llvm::Error Err) { + cantFail(std::move(Err), "lookupFlags failed"); + })), + object_layer_(execution_session_, + [this](llvm::orc::VModuleKey) { + llvm::orc::RTDyldObjectLinkingLayer::Resources result; + result.MemMgr = + std::make_shared( + orc_jit_memory_mapper::GetInstance()); + result.Resolver = symbol_resolver_; + return result; + }), + compile_layer_(object_layer_, + CompilerFunctor(target_machine_.get(), &disassembler_, + opt_level, optimize_for_size, + enable_fast_math, disable_expensive_passes, + std::move(pre_optimization_hook), + std::move(post_optimization_hook))) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() << " features: " << target_machine_->getTargetFeatureString().str(); } -SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule( - std::unique_ptr module) { - auto handle = cantFail(compile_layer_.addModule( - std::move(module), MakeUnique(external_constant_pool()))); - module_handles_.push_back(handle); - return handle; +llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { + if (const uint8* from_constant_pool = + external_constant_pool_.Find(string(name))) { + return llvm::JITEvaluatedSymbol( + reinterpret_cast(from_constant_pool), + llvm::JITSymbolFlags::None); + } + + void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + if (func_addr == nullptr) { + return nullptr; + } + llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), + llvm::JITSymbolFlags::None); + return symbol_info; } -void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::ModuleHandleT handle) { - module_handles_.erase( - std::remove(module_handles_.begin(), module_handles_.end(), handle), - module_handles_.end()); - cantFail(compile_layer_.removeModule(handle)); +SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule( + std::unique_ptr module) { + auto key = execution_session_.allocateVModule(); + cantFail(compile_layer_.addModule(key, std::move(module))); + module_keys_.push_back(key); + return key; } -llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string& name) { - std::string mangled_name; - { - llvm::raw_string_ostream mangled_name_stream(mangled_name); - llvm::Mangler::getNameWithPrefix(mangled_name_stream, name, data_layout_); - } +void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) { + module_keys_.erase(std::remove(module_keys_.begin(), module_keys_.end(), key), + module_keys_.end()); + cantFail(compile_layer_.removeModule(key)); +} +llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) { // Resolve symbol from last module to first, allowing later redefinitions of // symbols shadow earlier ones. - for (auto& handle : - llvm::make_range(module_handles_.rbegin(), module_handles_.rend())) { + for (auto& key : + llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) { if (auto symbol = - compile_layer_.findSymbolIn(handle, mangled_name, + compile_layer_.findSymbolIn(key, name, /*ExportedSymbolsOnly=*/true)) { return symbol; } @@ -208,25 +177,15 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); -#ifdef TF_XLA_HAS_NEON - REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON); - REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON); -#endif -#ifdef TF_XLA_HAS_SSE4_1 - REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE); - REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE); -#endif -#ifdef TF_XLA_HAS_AVX - REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX); - REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX); -#endif REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); @@ -273,15 +232,15 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(ilogb, int (*)(double)); REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int)); REGISTER_LIBM_SYMBOL(lgamma, double (*)(double)); - REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); - REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); + REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); // NOLINT(runtime/int) + REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); // NOLINT(runtime/int) REGISTER_LIBM_SYMBOL(log, double (*)(double)); REGISTER_LIBM_SYMBOL(log10, double (*)(double)); REGISTER_LIBM_SYMBOL(log1p, double (*)(double)); REGISTER_LIBM_SYMBOL(log2, double (*)(double)); REGISTER_LIBM_SYMBOL(logb, double (*)(double)); - REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); - REGISTER_LIBM_SYMBOL(lround, long (*)(double)); + REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); // NOLINT(runtime/int) + REGISTER_LIBM_SYMBOL(lround, long (*)(double)); // NOLINT(runtime/int) REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*)); REGISTER_LIBM_SYMBOL(nan, double (*)(const char*)); REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double)); @@ -292,7 +251,8 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*)); REGISTER_LIBM_SYMBOL(rint, double (*)(double)); REGISTER_LIBM_SYMBOL(round, double (*)(double)); - REGISTER_LIBM_SYMBOL(scalbln, double (*)(double, long)); + REGISTER_LIBM_SYMBOL(scalbln, + double (*)(double, long)); // NOLINT(runtime/int) REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int)); REGISTER_LIBM_SYMBOL(sin, double (*)(double)); #ifdef __APPLE__ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index ded01e9e4d7442296f7406dd035e6ab385458238..d0011e0a185cd0284d2f9334594f6e06d9284be7 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -21,8 +21,10 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" +#include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/SymbolStringPool.h" #include "llvm/IR/Module.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" @@ -48,7 +50,7 @@ class SimpleOrcJIT { std::function( llvm::Module&)>; using CompileLayerT = llvm::orc::IRCompileLayer; - using ModuleHandleT = CompileLayerT::ModuleHandleT; + using VModuleKeyT = llvm::orc::VModuleKey; // Create a new JIT, targeting the host architecture. // The |target_options| parameter allows customization of certain code @@ -78,16 +80,16 @@ class SimpleOrcJIT { return target_machine_->getTargetTriple(); } - // Add a module to the JIT. Returns an opaque handle that can be used to later + // Add a module to the JIT. Returns an opaque key that can be used to later // remove this module. - ModuleHandleT AddModule(std::unique_ptr module); + VModuleKeyT AddModule(std::unique_ptr module); // Remove a module from the JIT and free the memory associated with it. - void RemoveModule(ModuleHandleT handle); + void RemoveModule(VModuleKeyT key); // Get the runtime address of the compiled symbol whose name is given. Returns // nullptr if the symbol cannot be found. - llvm::JITSymbol FindSymbol(const std::string& name); + llvm::JITSymbol FindCompiledSymbol(const std::string& name); llvm::TargetMachine* target_machine() const { return target_machine_.get(); } @@ -96,10 +98,15 @@ class SimpleOrcJIT { } private: - std::vector module_handles_; + llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); + + std::vector module_keys_; std::unique_ptr target_machine_; const Disassembler disassembler_; const llvm::DataLayout data_layout_; + llvm::orc::SymbolStringPool string_pool_; + llvm::orc::ExecutionSession execution_session_; + std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; CompileLayerT compile_layer_; ExternalConstantPool external_constant_pool_; diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 128b465be239130918687d8e2ba0458684086ee1..150db1cb6edec1af6724a8bca6a5f6272f1a7416 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -35,8 +36,27 @@ VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); } +static string TypeToString(llvm::Type* type) { + std::string o; + llvm::raw_string_ostream ostream(o); + type->print(ostream); + return ostream.str(); +} + +void VectorSupportLibrary::AssertCorrectTypes( + std::initializer_list values) { + for (llvm::Value* v : values) { + llvm::Type* type = v->getType(); + if (type != scalar_type() && type != vector_type()) { + LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or " + << TypeToString(vector_type()) << " but got " + << TypeToString(type); + } + } +} + llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { - CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + AssertCorrectTypes({lhs, rhs}); return MulInternal(lhs, rhs); } @@ -50,10 +70,128 @@ llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs, } llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { - CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + AssertCorrectTypes({lhs, rhs}); return AddInternal(lhs, rhs); } +llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return ir_builder()->CreateFSub(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + if (scalar_type_->isFloatingPointTy()) { + return llvm_ir::EmitFloatMax(lhs, rhs, ir_builder_); + } else { + LOG(FATAL) << "Max for integers is unimplemented"; + } +} + +llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) { + AssertCorrectTypes({a}); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a}, + {a->getType()}, ir_builder()); +} + +llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFDiv(lhs, rhs, name()); + } else { + LOG(FATAL) << "Division for integers is unimplemented"; + } +} + +llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, + const llvm::APFloat& low, + const llvm::APFloat& high) { + AssertCorrectTypes({a}); + llvm::Type* type = a->getType(); + CHECK(low.compare(high) == llvm::APFloat::cmpLessThan); + CHECK(scalar_type_->isFloatingPointTy()); + return llvm_ir::EmitFloatMin( + llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_), + GetConstantFloat(type, high), ir_builder_); +} + +llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name())); +} + +llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name())); +} + +llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name())); +} + +llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) { + bool is_vector = llvm::isa(i1->getType()); + llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector); + return ir_builder()->CreateBitCast( + ir_builder()->CreateSExt(i1, integer_type, name()), + is_vector ? vector_type() : scalar_type(), name()); +} + +llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { + CHECK(scalar_type()->isFloatingPointTy()); + const llvm::DataLayout& data_layout = + ir_builder()->GetInsertBlock()->getModule()->getDataLayout(); + int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type()); + llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits); + if (vector) { + return llvm::VectorType::get(scalar_int_type, vector_size()); + } else { + return scalar_int_type; + } +} + +llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) { + CHECK_EQ(x->getType(), scalar_type()); + return ir_builder()->CreateVectorSplat(vector_size(), x, name()); +} + +llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + llvm::Type* int_type = + IntegerTypeForFloatSize(lhs->getType() == vector_type()); + return ir_builder()->CreateBitCast( + ir_builder()->CreateAnd( + ir_builder()->CreateBitCast(lhs, int_type, name()), + ir_builder()->CreateBitCast(rhs, int_type, name()), name()), + vector_type()); +} + +llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) { + AssertCorrectTypes({lhs}); + llvm::Type* int_type = + IntegerTypeForFloatSize(lhs->getType() == vector_type()); + return ir_builder()->CreateBitCast( + ir_builder()->CreateNot( + ir_builder()->CreateBitCast(lhs, int_type, name()), name()), + vector_type()); +} + +llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + llvm::Type* int_type = + IntegerTypeForFloatSize(lhs->getType() == vector_type()); + return ir_builder()->CreateBitCast( + ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()), + ir_builder()->CreateBitCast(rhs, int_type, name()), + name()), + vector_type(), name()); +} + llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, llvm::Value* rhs) { if (scalar_type_->isFloatingPointTy()) { @@ -93,6 +231,7 @@ llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) { void VectorSupportLibrary::StoreVector(llvm::Value* value, llvm::Value* pointer) { + AssertCorrectTypes({value}); if (pointer->getType() != vector_pointer_type()) { pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type()); } @@ -102,6 +241,7 @@ void VectorSupportLibrary::StoreVector(llvm::Value* value, void VectorSupportLibrary::StoreScalar(llvm::Value* value, llvm::Value* pointer) { + AssertCorrectTypes({value}); if (pointer->getType() != scalar_pointer_type()) { pointer = ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 8fbac2a6670f8ef18c00877a1566bd4ab896a7c8..6479bf76aab581ae3ec2923d98dab53720cab203 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -26,6 +26,16 @@ limitations under the License. namespace xla { namespace cpu { + +// Simple wrappers around llvm::APFloat::APFloat to make the calling code more +// obvious. + +inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); } +inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) { + return llvm::APFloat(llvm::APFloat::IEEEsingle(), + llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value)); +} + // A thin wrapper around llvm_util.h to make code generating vector math flow // more readable. class VectorSupportLibrary { @@ -41,16 +51,96 @@ class VectorSupportLibrary { llvm::Value* Mul(int64 lhs, llvm::Value* rhs) { return Mul(ir_builder()->getInt64(lhs), rhs); } + llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) { + return Mul(GetConstantFloat(rhs->getType(), lhs), rhs); + } + + // If your call resolved to these then you probably wanted the versions taking + // APFloat. + llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete; + llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete; llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* Add(int64 lhs, llvm::Value* rhs) { return Add(ir_builder()->getInt64(lhs), rhs); } + llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) { + return Add(GetConstantFloat(rhs->getType(), lhs), rhs); + } + + // If your call resolved to these then you probably wanted the versions taking + // APFloat. + llvm::Value* Add(double lhs, llvm::Value* rhs) = delete; + llvm::Value* Add(float lhs, llvm::Value* rhs) = delete; + + llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) { + return Sub(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) { + return Max(GetConstantFloat(rhs->getType(), lhs), rhs); + } + llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) { return Add(c, Mul(a, b)); } + llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) { + return Add(GetConstantFloat(vector_type(), c), Mul(a, b)); + } + + llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b, + const llvm::APFloat& c) { + return Add(GetConstantFloat(a->getType(), c), + Mul(a, GetConstantFloat(a->getType(), b))); + } + + llvm::Value* Floor(llvm::Value* a); + + llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low, + const llvm::APFloat& high); + llvm::Value* SplatFloat(const llvm::APFloat& d) { + return GetConstantFloat(vector_type(), d); + } + + // These compare instructions return a floating point typed mask instead of an + // i1. For instance, on a vector typed input, lanes where the predicate is + // true get a float with all ones and other lanes get a float with all zeros. + // This is slightly odd from the perspective of LLVM's type system, but it + // makes kernel IR generation code written using VectorSupportLibrary (its + // raison d'etre) less cluttered. + + llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + + // These boolean operations operate on the bitwise values of the floating + // point inputs. They return a (vector of) float(s) but like in the mask + // generating predicates above this type system oddity makes the kernel IR + // generation code less cluttered. + llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + llvm::Value* FloatNot(llvm::Value* lhs); + llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) { + return FloatAnd(FloatNot(lhs), rhs); + } + + llvm::Value* BroadcastScalar(llvm::Value* x); + llvm::Value* BroadcastScalar(const llvm::APFloat& d) { + return BroadcastScalar(GetConstantFloat(scalar_type(), d)); + } + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, llvm::Value* offset_elements); llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, @@ -144,6 +234,11 @@ class VectorSupportLibrary { llvm::Value* AddReduce(llvm::Value* vector); + // Checks that each value in `values` is either of type scalar_type() or + // vector_type(). This LOG(FATAL)'s so it should only be called in cases + // where a mismatching type is a programmer bug. + void AssertCorrectTypes(std::initializer_list values); + // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. The // resulting IR for an 8-float wide vector is expected to lower to a single // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in @@ -163,6 +258,16 @@ class VectorSupportLibrary { std::vector ComputeAvxOptimizedHorizontalSums( std::vector vectors, llvm::Value* init_values); + llvm::Type* IntegerTypeForFloatSize(bool vector); + llvm::Value* I1ToFloat(llvm::Value* i1); + llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) { + llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f); + if (llvm::isa(type)) { + return llvm::ConstantVector::getSplat(vector_size(), scalar_value); + } + return scalar_value; + } + int64 vector_size_; PrimitiveType primitive_type_; llvm::IRBuilder<>* ir_builder_; diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index 2e4b0a5230516b5308aeed892de9a49565a09f2e..78e7aa48accdbb51a8477455f5f9c004828c068f 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -24,7 +24,7 @@ limitations under the License. namespace xla { StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( - perftools::gputools::Platform* platform, + const perftools::gputools::Platform* platform, tensorflow::gtl::ArraySlice stream_executors) : DeviceMemoryAllocator(platform), diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index 00caefab667cba6abfef200050ca18f229fc0320..39dfad84c1c1c1c461c24de555ecd919cea47d83 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -33,7 +33,7 @@ class DeviceMemoryAllocator { public: // Parameter platform indicates which platform the allocator allocates memory // on. Must be non-null. - explicit DeviceMemoryAllocator(perftools::gputools::Platform* platform) + explicit DeviceMemoryAllocator(const perftools::gputools::Platform* platform) : platform_(platform) {} virtual ~DeviceMemoryAllocator() {} @@ -49,14 +49,14 @@ class DeviceMemoryAllocator { int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) = 0; // Return the platform that the allocator allocates memory on. - perftools::gputools::Platform* platform() const { return platform_; } + const perftools::gputools::Platform* platform() const { return platform_; } // Can we call Deallocate() as soon as a computation has been scheduled on // a stream, or do we have to wait for the computation to complete first? virtual bool AllowsAsynchronousDeallocation() const = 0; protected: - perftools::gputools::Platform* platform_; + const perftools::gputools::Platform* platform_; }; // Default memory allocator for a platform which uses @@ -64,7 +64,7 @@ class DeviceMemoryAllocator { class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { public: StreamExecutorMemoryAllocator( - perftools::gputools::Platform* platform, + const perftools::gputools::Platform* platform, tensorflow::gtl::ArraySlice stream_executors); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index a803b3171f9afa6297553c5507c4f9aa45e420ab..56723e765048698baedc50ae7b189d0287ee56b8 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -190,6 +190,7 @@ class DfsHloVisitorBase { virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; + virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; virtual Status HandleRng(HloInstructionPtr hlo) = 0; virtual Status HandleReverse(HloInstructionPtr hlo) = 0; virtual Status HandleSort(HloInstructionPtr hlo) = 0; @@ -213,6 +214,7 @@ class DfsHloVisitorBase { virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; virtual Status HandleWhile(HloInstructionPtr hlo) = 0; virtual Status HandleConditional(HloInstructionPtr hlo) = 0; + virtual Status HandleGather(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 170adb3d241b3648bc53f96dde9866f0b794f80a..ecda5288ee17a3856ce95f0caa327c3524fd180b 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -103,6 +103,9 @@ class DfsHloVisitorWithDefaultBase Status HandleOutfeed(HloInstructionPtr outfeed) override { return DefaultAction(outfeed); } + Status HandleHostCompute(HloInstructionPtr host_compute) override { + return DefaultAction(host_compute); + } Status HandleReverse(HloInstructionPtr reverse) override { return DefaultAction(reverse); } @@ -185,6 +188,9 @@ class DfsHloVisitorWithDefaultBase Status HandleSendDone(HloInstructionPtr send_done) override { return DefaultAction(send_done); } + Status HandleGather(HloInstructionPtr gather) override { + return DefaultAction(gather); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 9780bac16ec17eed2c1df64f01bcb753e26b46f0..4468adbadbf823f1420a8b665a26f66cb7d36b43 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -428,7 +428,7 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm::Intrinsic::round, {operand_value}, {operand_value->getType()}, ir_builder_); case HloOpcode::kSign: { - // TODO(b/32151903): Ensure consistent sign behavior for -0.0 + // TODO(b/32151903): Ensure consistent sign behavior for -0.0. auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero); @@ -870,7 +870,10 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Value* x) const { if (prim_type != F32) { - return Unimplemented("inverse erf only implemented for F32 (b/34339814)"); + // TODO(b/34339814): Implement inverse erf for F64. + return Unimplemented( + "Inverse erf is only implemented for element " + "type F32."); } auto getFloat = [&](const float f) { return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f); @@ -1040,17 +1043,9 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, lhs_value, rhs_value, ir_builder_); case HloOpcode::kMinimum: - return ir_builder_->CreateSelect( - ir_builder_->CreateICmp( - is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, - lhs_value, rhs_value), - lhs_value, rhs_value); + return EmitIntegralMin(lhs_value, rhs_value, is_signed); case HloOpcode::kMaximum: - return ir_builder_->CreateSelect( - ir_builder_->CreateICmp( - is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, - lhs_value, rhs_value), - lhs_value, rhs_value); + return EmitIntegralMax(lhs_value, rhs_value, is_signed); case HloOpcode::kAnd: return ir_builder_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: @@ -1067,6 +1062,26 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( } } +llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, + llvm::Value* rhs_value, + bool is_signed) const { + return ir_builder_->CreateSelect( + ir_builder_->CreateICmp( + is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, + lhs_value, rhs_value), + lhs_value, rhs_value); +} + +llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, + llvm::Value* rhs_value, + bool is_signed) const { + return ir_builder_->CreateSelect( + ir_builder_->CreateICmp( + is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, + lhs_value, rhs_value), + lhs_value, rhs_value); +} + llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, int64 operand_no) const { @@ -1363,7 +1378,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * max_value, operand_to_generator.at(hlo->operand(2))( ElementwiseSourceIndex(index, *hlo, 2))); - return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + PrimitiveType prim_type = hlo->shape().element_type(); + if (primitive_util::IsFloatingPointType(prim_type)) { + return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + } else if (primitive_util::IsIntegralType(prim_type)) { + bool is_signed = primitive_util::IsSignedIntegralType(prim_type); + return EmitIntegralMin( + max_value, EmitIntegralMax(min_value, arg_value, is_signed), + is_signed); + } else { + return Unimplemented("Clamp unimplemented for %s", + PrimitiveType_Name(prim_type).c_str()); + } }; case HloOpcode::kReducePrecision: return [this, hlo, &operand_to_generator]( diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 1a48eb5fcb960b60d524ea56a43e15269576db76..c516a826d9e382bc738e54635426db639d17108c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -86,6 +86,12 @@ class ElementalIrEmitter { virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value) const; + llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, + bool is_signed) const; + + llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, + bool is_signed) const; + virtual StatusOr EmitErfInv(PrimitiveType prim_type, llvm::Value* value) const; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index df5e2e35f802b476f4d9fef2cd4816089663686f..9da4fb97fa27a238fead74985cb481a9be1f4a65 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -129,8 +129,11 @@ cc_library( hdrs = [ "ir_emitter.h", "ir_emitter_context.h", + "ir_emitter_nested.h", + "ir_emitter_unnested.h", ], deps = [ + ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", ":gpu_executable", @@ -228,6 +231,7 @@ cc_library( cc_library( name = "gpu_executable", srcs = [ + "conditional_thunk.cc", "convolution_thunk.cc", "copy_thunk.cc", "cudnn_batchnorm_thunk.cc", @@ -243,6 +247,7 @@ cc_library( "while_thunk.cc", ], hdrs = [ + "conditional_thunk.h", "convolution_thunk.h", "copy_thunk.h", "cudnn_batchnorm_thunk.h", @@ -260,6 +265,7 @@ cc_library( ], deps = [ ":buffer_allocations", + ":cudnn_convolution_runner", ":infeed_manager", ":ir_emission_utils", ":partition_assignment", @@ -307,9 +313,41 @@ cc_library( ) cc_library( - name = "convolution_folding", - srcs = ["convolution_folding.cc"], - hdrs = ["convolution_folding.h"], + name = "cudnn_convolution_algorithm_picker", + srcs = ["cudnn_convolution_algorithm_picker.cc"], + hdrs = ["cudnn_convolution_algorithm_picker.h"], + deps = [ + ":cudnn_convolution_runner", + ":gpu_executable", + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "cudnn_convolution_runner", + srcs = ["cudnn_convolution_runner.cc"], + hdrs = ["cudnn_convolution_runner.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "cudnn_convolution_rewriter", + srcs = ["cudnn_convolution_rewriter.cc"], + hdrs = ["cudnn_convolution_rewriter.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", @@ -323,15 +361,18 @@ cc_library( ) tf_cc_test( - name = "convolution_folding_test", - srcs = ["convolution_folding_test.cc"], + name = "cudnn_convolution_rewriter_test", + srcs = ["cudnn_convolution_rewriter_test.cc"], deps = [ - ":convolution_folding", + ":cudnn_convolution_rewriter", + ":ir_emission_utils", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], ) @@ -444,7 +485,8 @@ cc_library( srcs = ["gpu_compiler.cc"], hdrs = ["gpu_compiler.h"], deps = [ - ":convolution_folding", + ":cudnn_convolution_algorithm_picker", + ":cudnn_convolution_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", @@ -512,7 +554,6 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", - "@llvm//:core", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index ed78fef4113bd9f7048ca3c8c2d4e38c5ec4762a..2029c303d47e9a62135b003c3bd9be6f8b3438d4 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -98,6 +98,14 @@ StatusOr> BufferAllocations::Builder::Build( } } + if (VLOG_IS_ON(2)) { + for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { + const auto& buf = buffer_allocations->buffers_[i]; + VLOG(2) << "Buffer " << i << " -> " << buf.opaque() << " (" << buf.size() + << "B)"; + } + } + return std::move(buffer_allocations); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..790ca535b11ee47724ef6227de40726d940d6153 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace gpu { + +ConditionalThunk::ConditionalThunk( + const BufferAllocation::Slice& predicate_buffer_index, + const BufferAllocation::Slice& true_operand_buffer_index, + const BufferAllocation::Slice& false_operand_buffer_index, + ThunkSequence true_thunk_sequence, ThunkSequence false_thunk_sequence, + const HloInstruction* hlo) + : Thunk(Kind::kConditional, hlo), + predicate_buffer_index_(predicate_buffer_index), + true_operand_buffer_index_(true_operand_buffer_index), + false_operand_buffer_index_(false_operand_buffer_index), + true_thunk_(std::move(true_thunk_sequence), hlo), + false_thunk_(std::move(false_thunk_sequence), hlo) {} + +Status ConditionalThunk::Initialize(const GpuExecutable& executable) { + TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable)); + TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable)); + return Status::OK(); +} + +Status ConditionalThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { + // Copy the predicate value from device. + bool predicate; + perftools::gputools::DeviceMemoryBase predicate_address = + buffer_allocations.GetDeviceAddress(predicate_buffer_index_); + stream->ThenMemcpy(&predicate, predicate_address, sizeof(bool)); + + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("Failed to retrieve predicate value on stream %p: %s.", + stream, block_status.error_message().c_str()); + } + + // Execute the true or the false computation depending on the value of the + // predicate. + if (predicate) { + TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream)); + } else { + TF_RETURN_IF_ERROR( + false_thunk_.ExecuteOnStream(buffer_allocations, stream)); + } + + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..7725c46a3b4b51af34a4dd977885353ff32c21f6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ + +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// ConditionalThunk implements the conditional instruction on GPU by reading the +// predicate of the conditional and executing the true or the false computation +// depending on the value of the predicate. +// +// ConditionalThunk assumes that the buffers of the conditional result and the +// result of the true and false computations share the same allocation. Also, +// the buffers of the true operand of the conditional and that of the parameter +// instruction of the true computation share the same allocation. Similarly, the +// buffers of the false operand and that of the parameter instruction of the +// false computation share the same allocation. +class ConditionalThunk : public Thunk { + public: + ConditionalThunk(const BufferAllocation::Slice& predicate_buffer_index, + const BufferAllocation::Slice& true_operand_buffer_index, + const BufferAllocation::Slice& false_operand_buffer_index, + ThunkSequence true_thunk_sequence, + ThunkSequence false_thunk_sequence, + const HloInstruction* hlo); + + ConditionalThunk(const ConditionalThunk&) = delete; + ConditionalThunk& operator=(const ConditionalThunk&) = delete; + + Status Initialize(const GpuExecutable& executable) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + BufferAllocation::Slice predicate_buffer_index_; + BufferAllocation::Slice true_operand_buffer_index_; + BufferAllocation::Slice false_operand_buffer_index_; + SequentialThunk true_thunk_; + SequentialThunk false_thunk_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 899cc5c83b99f1bb6154f883ca17871863e1f457..461747b699b542ae0c8735aea34cc9e57c1fb387 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -36,366 +37,70 @@ using se::dnn::DataLayout; using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; -ConvolveScratchAllocator::ConvolveScratchAllocator( - int device_ordinal, DeviceMemoryAllocator* memory_allocator) - : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} - -ConvolveScratchAllocator::~ConvolveScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - -int64 ConvolveScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { - constexpr int64 kConvolveScratchSize = 1LL << 32; // 4GB by default. - return kConvolveScratchSize; -} - -se::port::StatusOr> -ConvolveScratchAllocator::AllocateBytes(se::Stream* stream, int64 byte_size) { - CHECK_GE(byte_size, 0) << "byte_size must be positive."; - if (byte_size > GetMemoryLimitInBytes(stream)) { - return se::port::Status( - se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", - byte_size, GetMemoryLimitInBytes(stream))); - } - - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Failed to allocate %lld bytes on device %d.", - byte_size, device_ordinal_)); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); - total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); -} - -string ConvolutionKindToString( - ConvolutionThunk::ConvolutionKind convolution_kind) { - switch (convolution_kind) { - case ConvolutionThunk::ConvolutionKind::kForward: - return "forward"; - case ConvolutionThunk::ConvolutionKind::kBackwardFilter: - return "backward_filter"; - case ConvolutionThunk::ConvolutionKind::kBackwardInput: - return "backward_input"; - } - return "unknown convolution kind"; -} - ConvolutionThunk::ConvolutionThunk( - ConvolutionKind convolution_kind, - const BufferAllocation::Slice& input_buffer, + CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, const Shape& input_shape, + const BufferAllocation::Slice& output_buffer, + const BufferAllocation::Slice& tuple_result_buffer, + const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, const HloInstruction* hlo) + const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, + bool tensor_ops_enabled, const HloInstruction* hlo) : Thunk(Kind::kConvolution, hlo), convolution_kind_(convolution_kind), input_buffer_(input_buffer), filter_buffer_(filter_buffer), output_buffer_(output_buffer), + tuple_result_buffer_(tuple_result_buffer), + scratch_buffer_(scratch_buffer), input_shape_(input_shape), filter_shape_(filter_shape), output_shape_(output_shape), window_(window), - dim_nums_(dim_nums) {} - -tensorflow::Status ConvolutionThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { - VLOG(3) << "Convolution kind: " << ConvolutionKindToString(convolution_kind_); - VLOG(3) << "input shape: { " << input_shape_.ShortDebugString() << " }"; - VLOG(3) << "filter shape: { " << filter_shape_.ShortDebugString() << " }"; - VLOG(3) << "Output shape: { " << output_shape_.ShortDebugString() << " }"; - VLOG(3) << "Dim nums: { " << dim_nums_.ShortDebugString() << " }"; - VLOG(3) << "Window: { " << window_.ShortDebugString() << " }"; - - const int num_dimensions = window_.dimensions_size(); - CHECK_LE(num_dimensions, 3); - // cuDNN does not support 1D convolutions. We therefore express 1D - // convolutions as 2D convolutions where the first spatial dimension is 1. - // This matches the behavior of TF (see definition of conv1d in - // tensorflow/python/ops/nn_ops.py). - const int effective_num_dimensions = std::max(2, num_dimensions); - - CHECK_EQ(F32, output_shape_.element_type()); - CHECK_EQ(num_dimensions, dim_nums_.input_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dim_nums_.kernel_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dim_nums_.output_spatial_dimensions_size()); - for (const WindowDimension& dim : window_.dimensions()) { - CHECK_EQ(dim.padding_low(), dim.padding_high()); - } - - // cuDNN's convolution APIs support the BDYX layout for activations/output and - // the OIYX layout for weights. - BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(DataLayout::kBatchDepthYX) - .set_feature_map_count( - input_shape_.dimensions(dim_nums_.input_feature_dimension())) - .set_count(input_shape_.dimensions(dim_nums_.input_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - // Note that the dimensions are reversed. The same holds below. - input_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - input_shape_.dimensions(dim_nums_.input_spatial_dimensions(dim))); - } - - FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(FilterLayout::kOutputInputYX) - .set_input_feature_map_count( - filter_shape_.dimensions(dim_nums_.kernel_input_feature_dimension())) - .set_output_feature_map_count(filter_shape_.dimensions( - dim_nums_.kernel_output_feature_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - filter_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(dim))); - } - - ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); - for (int dim = 0; dim < num_dimensions; ++dim) { - convolution_descriptor - .set_zero_padding( - static_cast(effective_num_dimensions - dim - 1), - window_.dimensions(dim).padding_low()) - .set_filter_stride( - static_cast(effective_num_dimensions - dim - 1), - window_.dimensions(dim).stride()); - } - - BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(DataLayout::kBatchDepthYX) - .set_feature_map_count( - output_shape_.dimensions(dim_nums_.output_feature_dimension())) - .set_count(output_shape_.dimensions(dim_nums_.output_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - output_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - output_shape_.dimensions(dim_nums_.output_spatial_dimensions(dim))); - } - - // Add a singleton dimension in the 1D convolution case. - if (num_dimensions == 1) { - input_descriptor.set_spatial_dim(static_cast(0), 1); - output_descriptor.set_spatial_dim(static_cast(0), 1); - filter_descriptor.set_spatial_dim(static_cast(0), 1); - convolution_descriptor - .set_zero_padding(static_cast(0), 0) - .set_filter_stride(static_cast(0), 1); - } - - se::DeviceMemory input_data( - buffer_allocations.GetDeviceAddress(input_buffer_)); - se::DeviceMemory filter_data( - buffer_allocations.GetDeviceAddress(filter_buffer_)); - se::DeviceMemory output_data( - buffer_allocations.GetDeviceAddress(output_buffer_)); - return ConvolveWithTune(input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, buffer_allocations, stream); -} - -tensorflow::Status ConvolutionThunk::Convolve( - const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, - const FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const ConvolutionDescriptor& convolution_descriptor, - const se::dnn::AlgorithmConfig& algorithm_config, se::Stream* stream, - ConvolveScratchAllocator* scratch_allocator, - se::dnn::ProfileResult* profile_result) { - bool launch_ok; - switch (convolution_kind_) { - case ConvolutionKind::kBackwardFilter: - launch_ok = - stream - ->ThenConvolveBackwardFilterWithAlgorithm( - input_descriptor, input_data, output_descriptor, output_data, - convolution_descriptor, filter_descriptor, &filter_data, - scratch_allocator, algorithm_config, profile_result) - .ok(); - break; - case ConvolutionKind::kBackwardInput: - launch_ok = stream - ->ThenConvolveBackwardDataWithAlgorithm( - filter_descriptor, filter_data, output_descriptor, - output_data, convolution_descriptor, input_descriptor, - &input_data, scratch_allocator, algorithm_config, - profile_result) - .ok(); - break; - case ConvolutionKind::kForward: - launch_ok = - stream - ->ThenConvolveWithAlgorithm( - input_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, &output_data, - scratch_allocator, algorithm_config, profile_result) - .ok(); - break; - } - if (launch_ok) { - return tensorflow::Status::OK(); - } - return InternalError( - "Unable to launch convolution for thunk %p with type %s and algorithm " - "(%lld, %lld)", - this, ConvolutionKindToString(convolution_kind_).c_str(), - algorithm_config.algorithm().algo_id(), - algorithm_config.algorithm_no_scratch().algo_id()); -} + dim_nums_(dim_nums), + algorithm_(algorithm), + tensor_ops_enabled_(tensor_ops_enabled) {} -std::vector ConvolutionThunk::GetAlgorithms( - bool with_winograd_nonfused, se::StreamExecutor* stream_exec) const { - std::vector algorithms; - switch (convolution_kind_) { - case ConvolutionKind::kBackwardFilter: - CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - with_winograd_nonfused, &algorithms)); - break; - case ConvolutionKind::kBackwardInput: - CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - with_winograd_nonfused, &algorithms)); - break; - case ConvolutionKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, - &algorithms)); - break; - } - return algorithms; -} - -static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) { - if (algo.tensor_ops_enabled()) { - return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); - } - return tensorflow::strings::StrCat(algo.algo_id()); -} - -// Determines whether we can safely perform a winograd non-fused convolution for -// the given input and output descriptors. This works around b/68264959, an -// integer overflow in cuDNNv5 and cuDNNv6. -static bool ShouldIncludeWinogradNonfusedAlgo( - const BatchDescriptor& input_descriptor, - const BatchDescriptor& output_descriptor) { - int64 batch = input_descriptor.count(); - int64 in_depths = input_descriptor.feature_map_count(); - int64 in_rows = input_descriptor.height(); - int64 in_cols = input_descriptor.width(); - int64 out_depths = output_descriptor.feature_map_count(); - - int64 total_size = 16 * std::ceil(batch / 16.0) * - std::max(in_depths, out_depths) * in_cols * in_rows * - sizeof(float); - int64 threshold = 1L << 31; - - return total_size < threshold; -} - -tensorflow::Status ConvolutionThunk::ConvolveWithTune( - const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, - const FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const ConvolutionDescriptor& convolution_descriptor, +Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { - // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. - if (!best_algorithm_.has_value()) { - best_algorithm_.emplace(); - - // Auto-tuning either is disabled or only happens in the first run of this - // function. - VLOG(2) << "Profiling for best convolution algorithm used for " - "ConvolutionThunk: " - << this; - - bool with_winograd_nonfused = - ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor); - - se::dnn::ProfileResult best_result; - se::dnn::ProfileResult best_result_without_scratch; - std::vector algorithms = - GetAlgorithms(with_winograd_nonfused, stream->parent()); - for (auto algorithm : algorithms) { - ConvolveScratchAllocator scratch_allocator( - buffer_allocations.device_ordinal(), - buffer_allocations.memory_allocator()); - se::dnn::ProfileResult profile_result; - VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk: " << this; - bool launch_ok = - Convolve(input_descriptor, input_data, filter_descriptor, filter_data, - output_descriptor, output_data, convolution_descriptor, - se::dnn::AlgorithmConfig(algorithm, algorithm), stream, - &scratch_allocator, &profile_result) - .ok(); - if (launch_ok && profile_result.is_valid()) { - VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk " << this << " succeeded, taking " - << profile_result.elapsed_time_in_ms() - << "ms. (Best result: " << best_result.elapsed_time_in_ms() - << "ms)"; - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalAllocatedBytes() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_without_scratch.elapsed_time_in_ms()) { - best_result_without_scratch = profile_result; - } - } else { - VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk " << this << " failed."; - } - } - - if (best_result.is_valid()) { - best_algorithm_->set_algorithm(best_result.algorithm()); - } else { - LOG(ERROR) << "No convolution algorithm works with profiling. Fall back " - "to the default algorithm."; - best_algorithm_->set_algorithm(AlgorithmDesc()); + se::DeviceMemoryBase input_data = + buffer_allocations.GetDeviceAddress(input_buffer_); + se::DeviceMemoryBase filter_data = + buffer_allocations.GetDeviceAddress(filter_buffer_); + se::DeviceMemoryBase output_data = + buffer_allocations.GetDeviceAddress(output_buffer_); + se::DeviceMemoryBase scratch = + buffer_allocations.GetDeviceAddress(scratch_buffer_); + + se::dnn::AlgorithmConfig algorithm_config( + se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); + + TF_RETURN_IF_ERROR(RunCudnnConvolution( + convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, + filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, + stream)); + + // Figure out which of output/input/filter is the result produced by + // this op, and write the result tuple. + void* result_ptr = [&] { + switch (convolution_kind_) { + case CudnnConvKind::kForward: + return output_data.opaque(); + case CudnnConvKind::kBackwardInput: + return input_data.opaque(); + case CudnnConvKind::kBackwardFilter: + return filter_data.opaque(); } + }(); + void* ptrs[] = {result_ptr, scratch.opaque()}; + se::DeviceMemory tuple_addr( + buffer_allocations.GetDeviceAddress(tuple_result_buffer_)); + stream->ThenMemcpyH2D(ptrs, &tuple_addr); - if (best_result_without_scratch.is_valid()) { - best_algorithm_->set_algorithm_no_scratch( - best_result_without_scratch.algorithm()); - } else { - LOG(ERROR) << "No convolution algorithm without scratch works with " - "profiling. Fall back " - "to the default algorithm."; - best_algorithm_->set_algorithm_no_scratch(AlgorithmDesc()); - } - } - - { - VLOG(2) << "Using convolution algorithm (" - << AlgorithmToString(best_algorithm_->algorithm()) << ", " - << AlgorithmToString(best_algorithm_->algorithm_no_scratch()) - << ") for ConvolutionThunk: " << this; - ConvolveScratchAllocator scratch_allocator( - buffer_allocations.device_ordinal(), - buffer_allocations.memory_allocator()); - return Convolve(input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, *best_algorithm_, stream, - &scratch_allocator, nullptr); + if (!stream->ok()) { + return InternalError("ConvolutionThunk::ExecuteOnStream failed."); } + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 46c94d0bf1e486fb91e63109efb8e4ba778c4120..900d9cb6243088b56a1825fb3ab8c06cf8d74726 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,106 +31,47 @@ limitations under the License. namespace xla { namespace gpu { -// A one-time scratch allocator for forward and backward convolution. The -// scratch buffers allocated are released on destruction. -// -// Not thread-safe. -class ConvolveScratchAllocator : public perftools::gputools::ScratchAllocator { - public: - ConvolveScratchAllocator(int device_ordinal, - DeviceMemoryAllocator* memory_allocator); - - ~ConvolveScratchAllocator() override; - - int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override; - - int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - - perftools::gputools::port::StatusOr> - AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override; - - private: - const int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; - int64 total_allocated_bytes_ = 0; -}; - // This class stores everything that StreamExecutor needs to launch a BNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. class ConvolutionThunk : public Thunk { public: - // ConvolutionThunk performs one of the following types of convolution. - enum class ConvolutionKind { - kBackwardFilter, // Backward convolution for filter. - kBackwardInput, // Backward convolution for input. - kForward, // Forward convolution. - }; - - // Constructs a thunk for launching a DNN convolution. + // Constructs a thunk for launching a DNN convolution. When run, it will + // write a tuple (result, scratch_memory) into `tuple_result_buffer`. + // + // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that + // we should use the default (i.e. baseline) cudnn algorithm. + // + // Note that "output" here doesn't refer to the output from running this + // thunk, but rather to the "output" of a hypothetical forward convolution + // that corresponds to this input+filter+output triple. That is, the result + // generated by this thunk is "output" for forward convs, "input" for + // backward-input convs, and "filter" for backward-filter convs. + // // Semantics of null hlo_instruction argument are as in Thunk. - ConvolutionThunk(ConvolutionKind convolution_kind, + ConvolutionThunk(CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& filter_buffer, const BufferAllocation::Slice& output_buffer, + const BufferAllocation::Slice& tuple_result_buffer, + const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, - const HloInstruction* hlo); + const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, + bool tensor_ops_enabled, const HloInstruction* hlo); ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; - // Does the convolution for the thunk on "stream". Auto-tuning happens on the - // first run of this function. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; - - // Returns true if the next run of ExecuteOnStream will do autotuning. If so, - // we want the GPU to be quiescent during autotuning, so as not to introduce - // noise in our results. - bool ShouldHaltAllActivityBeforeRunning( - perftools::gputools::Stream*) override { - return !best_algorithm_.has_value(); - } - - // Return true if scratch memory is needed to execute the thunk, that is - // either the best algorithm hasn't been chosen or the best algorithm is not - // the same as the no-scratch algorithm. This is because that the execution - // of the thunk is asynchronous, and the scratch allocator goes out of - // scope before the thunk finishes execution. Returning true tells the stream - // executor to make future thunks wait for this thunk to avoid reusing the - // deallocated scratch memory until this thunk is done with it. - bool ShouldBlockFutureThunks() { - if (!best_algorithm_.has_value()) { - return true; - } - - const perftools::gputools::dnn::AlgorithmDesc& best_alg = - best_algorithm_->algorithm(); - const perftools::gputools::dnn::AlgorithmDesc& no_scratch_best_alg = - best_algorithm_->algorithm_no_scratch(); - return (!best_alg.is_default() || !no_scratch_best_alg.is_default() || - !(best_alg == no_scratch_best_alg)); - } + // Does the convolution for the thunk on "stream". + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: - tensorflow::Status ConvolveWithTune( - const perftools::gputools::dnn::BatchDescriptor& input_descriptor, - perftools::gputools::DeviceMemory input_data, - const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, - perftools::gputools::DeviceMemory filter_data, - const perftools::gputools::dnn::BatchDescriptor& output_descriptor, - perftools::gputools::DeviceMemory output_data, - const perftools::gputools::dnn::ConvolutionDescriptor& - convolution_descriptor, - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream); + class ScratchAllocator; - tensorflow::Status Convolve( + Status Convolve( const perftools::gputools::dnn::BatchDescriptor& input_descriptor, perftools::gputools::DeviceMemory input_data, const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, @@ -139,40 +81,27 @@ class ConvolutionThunk : public Thunk { const perftools::gputools::dnn::ConvolutionDescriptor& convolution_descriptor, const perftools::gputools::dnn::AlgorithmConfig& algorithm_config, - perftools::gputools::Stream* stream, - ConvolveScratchAllocator* scratch_allocator, + perftools::gputools::Stream* stream, ScratchAllocator* scratch_allocator, perftools::gputools::dnn::ProfileResult* profile_result); - // Returns the convolve algorithms that can be used for this ConvolutionThunk. - std::vector GetAlgorithms( - bool with_winograd_nonfused, - perftools::gputools::StreamExecutor* stream_exec) const; - - // Fastest cuDNN convolution algorithm for this thunk learned from - // auto-tuning. If auto-tuning is disabled or failed, best_algorithm_ is set - // to the default value, indicating cuDNN's convolution will choose the best - // algorithm from some heuristics based on its parameters. - tensorflow::gtl::optional - best_algorithm_; - - const ConvolutionKind convolution_kind_; + const CudnnConvKind convolution_kind_; const BufferAllocation::Slice input_buffer_; const BufferAllocation::Slice filter_buffer_; const BufferAllocation::Slice output_buffer_; + const BufferAllocation::Slice tuple_result_buffer_; + const BufferAllocation::Slice scratch_buffer_; const Shape input_shape_; const Shape filter_shape_; const Shape output_shape_; const Window window_; - const ConvolutionDimensionNumbers dim_nums_; + int64 algorithm_; + bool tensor_ops_enabled_; }; -string ConvolutionKindToString( - ConvolutionThunk::ConvolutionKind convolution_kind); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc new file mode 100644 index 0000000000000000000000000000000000000000..1792893ae401bf16d2dd9e861607e8f3821a505e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -0,0 +1,369 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gpu { +namespace { + +namespace se = perftools::gputools; + +using se::DeviceMemoryBase; +using se::dnn::AlgorithmConfig; +using se::dnn::AlgorithmDesc; +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +class ScratchAllocator : public se::ScratchAllocator { + public: + ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} + + ~ScratchAllocator() override; + + int64 GetMemoryLimitInBytes(se::Stream* stream) override { + return 1LL << 32; // 4GB. TODO(jlebar): Tune this? + } + int64 TotalAllocatedBytes() { return total_allocated_bytes_; } + + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override; + + private: + const int device_ordinal_; + DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; + int64 total_allocated_bytes_ = 0; +}; + +ScratchAllocator::~ScratchAllocator() { + for (auto& allocated_buffer : allocated_buffers_) { + if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) + .ok()) { + // The program can still continue with failed deallocation. + LOG(ERROR) << "Failed to deallocate the allocated buffer: " + << allocated_buffer.opaque(); + } + } +} + +se::port::StatusOr> ScratchAllocator::AllocateBytes( + se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + auto status_or_memory = + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false); + if (!status_or_memory.ok()) { + return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Failed to allocate %lld bytes on device %d.", + byte_size, device_ordinal_)); + } + se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); + allocated_buffers_.push_back(allocated_buffer); + total_allocated_bytes_ += byte_size; + return se::DeviceMemory(allocated_buffer); +} + +// Determines whether we can safely perform a winograd non-fused convolution for +// the given input and output shapes. This works around b/68264959, an integer +// overflow in cuDNNv5 and cuDNNv6. +// +// TODO(jlebar): We shouldn't need this check for cuDNNv7. +bool ShouldIncludeWinogradNonfusedAlgo( + const Shape& input_shape, const Shape& output_shape, + const ConvolutionDimensionNumbers& dnums) { + int64 batch = input_shape.dimensions(dnums.input_batch_dimension()); + int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension()); + int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0)); + int64 in_cols = + dnums.input_spatial_dimensions_size() == 1 + ? 1 + : input_shape.dimensions(dnums.input_spatial_dimensions(1)); + int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension()); + + int64 total_size = CeilOfRatio(batch, int64{16}) * + std::max(in_depths, out_depths) * in_cols * in_rows * + sizeof(float); + + const int64 threshold = 1L << 31; + return total_size < threshold; +} + +std::vector GetAlgorithms(CudnnConvKind kind, + bool with_winograd_nonfused, + se::StreamExecutor* stream_exec_) { + std::vector algorithms; + switch (kind) { + case CudnnConvKind::kBackwardFilter: + CHECK(stream_exec_->GetConvolveBackwardFilterAlgorithms( + with_winograd_nonfused, &algorithms)); + break; + case CudnnConvKind::kBackwardInput: + CHECK(stream_exec_->GetConvolveBackwardDataAlgorithms( + with_winograd_nonfused, &algorithms)); + break; + case CudnnConvKind::kForward: + CHECK(stream_exec_->GetConvolveAlgorithms(with_winograd_nonfused, + &algorithms)); + break; + } + + return algorithms; +} + +string AlgorithmToString(const AlgorithmDesc& algo) { + if (algo.tensor_ops_enabled()) { + return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + } + return tensorflow::strings::StrCat(algo.algo_id()); +} + +string NumBytesToString(int64 bytes) { + return tensorflow::strings::StrCat( + tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); +} + +} // anonymous namespace + +// We could have caching here so that we don't redo this work for two identical +// convolutions. Unfortunately our cache key would have to be a tuple +// containing the protos passed to this function, and we have no utility for +// hashing protos. We could write our own hash functions, but they'd silently +// break if we ever added a field to one of the protos. Perhaps we could hack +// using the binary-encoded proto as the hash key, on the assumption that two +// protos being binary-equal is a sufficient, if not necessary, condition for +// proper equality. But that would still leave us open to having unnecessary +// cache misses and doing extra work. Overall, caching doesn't seem worth the +// trouble, but we may want to revisit this if we ever find a model where +// caching would speed up compilation a lot. +optional> +CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, const Window& window, + const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + // Create a stream for us to do our work on. + se::Stream stream{stream_exec_}; + stream.Init(); + const auto device_ordinal = stream_exec_->device_ordinal(); + + // allocator either points to this->allocator_ or, if that's null, to a + // StreamExecutorMemoryAllocator for stream_exec_. + DeviceMemoryAllocator* allocator; + optional se_allocator; + if (allocator_ != nullptr) { + allocator = allocator_; + } else { + se_allocator.emplace( + stream_exec_->platform(), + tensorflow::gtl::ArraySlice({stream_exec_})); + allocator = &*se_allocator; + } + + // Allocate space for the input, filter, and output of the convolution. We + // use a ScratchAllocator for this instead of calling allocator_ directly so + // that our allocations don't leak. + // + // We don't put any data in these buffers, because (in theory, anyway) the + // speed of a conv isn't affected by the data being convolved. + ScratchAllocator input_output_allocator(device_ordinal, allocator); + se::port::StatusOr input_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(input_shape)); + se::port::StatusOr filter_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(filter_shape)); + se::port::StatusOr output_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(output_shape)); + if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) { + LOG(WARNING) + << "Couldn't allocate space for input/filter/output of convolution " + << instr->ToString() << ". Falling back to default algorithm."; + return nullopt; + } + + const bool use_winograd_nonfused = + ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums); + se::dnn::ProfileResult best_result; + int64 best_result_bytes_used = 0; + + for (const AlgorithmDesc& alg : + GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { + ScratchAllocator scratch_allocator(device_ordinal, allocator); + se::dnn::ProfileResult profile_result; + VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " + << instr->ToString(); + + bool launch_ok = RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + input_buf.ValueOrDie(), filter_buf.ValueOrDie(), + output_buf.ValueOrDie(), &scratch_allocator, window, + dnums, AlgorithmConfig(alg), &stream, &profile_result) + .ok(); + + if (launch_ok && profile_result.is_valid()) { + int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); + VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) + << " succeeded, taking " << profile_result.elapsed_time_in_ms() + << "ms and using " << NumBytesToString(scratch_bytes_used) + << " of scratch (Best result: " + << best_result.elapsed_time_in_ms() << "ms, " + << NumBytesToString(best_result_bytes_used) << " of scratch)"; + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + best_result_bytes_used = scratch_bytes_used; + } + } else { + VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " failed."; + } + } + if (best_result.is_valid()) { + VLOG(2) << "Best algorithm for " << instr->ToString() << ": " + << AlgorithmToString(best_result.algorithm()) << ", takes " + << best_result.elapsed_time_in_ms() << "ms, and uses " + << best_result_bytes_used << "B of scratch memory."; + return std::make_tuple(best_result.algorithm().algo_id(), + best_result.algorithm().tensor_ops_enabled(), + best_result_bytes_used); + } + + LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString() + << " failed. Falling back to default algorithm."; + return nullopt; +} + +StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( + HloInstruction* instr) { + CHECK(IsCustomCallToDnnConvolution(*instr)); + + const auto& call_target = instr->custom_call_target(); + const auto& lhs_shape = instr->operand(0)->shape(); + const auto& rhs_shape = instr->operand(1)->shape(); + const auto& conv_result_shape = instr->shape().tuple_shapes(0); + optional> alg_scratch_and_tc; + if (call_target == kCudnnConvForwardCallTarget) { + alg_scratch_and_tc = PickBestAlgorithm( + CudnnConvKind::kForward, /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, + instr->window(), instr->convolution_dimension_numbers(), instr); + } else if (call_target == kCudnnConvBackwardInputCallTarget) { + alg_scratch_and_tc = PickBestAlgorithm( + CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, + /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), + instr->convolution_dimension_numbers(), instr); + } else if (call_target == kCudnnConvBackwardFilterCallTarget) { + alg_scratch_and_tc = PickBestAlgorithm( + CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, + /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, + instr->window(), instr->convolution_dimension_numbers(), instr); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instr->ToString(); + } + + if (!alg_scratch_and_tc.has_value()) { + return false; + } + + int64 algorithm; + bool tensor_ops_enabled; + int64 scratch_bytes; + + std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc; + + VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " + << NumBytesToString(scratch_bytes) + << " of scratch memory: " << instr->ToString() + << " tensor_ops_enabled: " << tensor_ops_enabled; + + // Replace instr with a new CustomCall which has the correct algorithm, and + // whose output shape has the appropriate amount of scratch memory. + HloComputation* computation = instr->parent(); + Shape new_call_shape = + ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), + ShapeUtil::MakeShape(U8, {scratch_bytes})}); + HloInstruction* algorithm_hlo = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); + HloInstruction* tensor_ops_enabled_hlo = + computation->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(tensor_ops_enabled))); + + HloInstruction* new_call = + computation->AddInstruction(HloInstruction::CreateCustomCall( + new_call_shape, + {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo, + tensor_ops_enabled_hlo}, + instr->custom_call_target())); + new_call->set_window(instr->window()); + new_call->set_convolution_dimension_numbers( + instr->convolution_dimension_numbers()); + + // Repackage new_call so it has the same shape as the original call, namely + // (conv_result, u8[0]). + HloInstruction* new_tuple = + computation->AddInstruction(HloInstruction::CreateTuple( + {computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_call_shape.tuple_shapes(0), new_call, 0)), + computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({})))})); + + TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple)); + return true; +} + +StatusOr CudnnConvolutionAlgorithmPicker::RunOnComputation( + HloComputation* computation) { + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + + bool changed = false; + for (auto* instr : convs) { + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr)); + changed |= result; + } + return changed; +} + +StatusOr CudnnConvolutionAlgorithmPicker::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h new file mode 100644 index 0000000000000000000000000000000000000000..516210ec2e500cf03774d27408300ac3346e7b4f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for +// each and adding explicit scratch space to the CustomCalls. +class CudnnConvolutionAlgorithmPicker : public HloPassInterface { + public: + // If the `allocator` parameter is not null, we will use it to allocate temp + // memory while timing the various convolution algorithms. If it's null, + // we'll use the default allocator on the StreamExecutor. + CudnnConvolutionAlgorithmPicker( + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator) + : stream_exec_(stream_exec), allocator_(allocator) {} + + tensorflow::StringPiece name() const override { + return "cudnn-convolution-algorithm-picker"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr RunOnComputation(HloComputation* computation); + StatusOr RunOnInstruction(HloInstruction* instr); + tensorflow::gtl::optional> PickBestAlgorithm( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, const Window& window, + const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); + + perftools::gputools::StreamExecutor* stream_exec_; // never null + DeviceMemoryAllocator* allocator_; // may be null +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc similarity index 83% rename from tensorflow/compiler/xla/service/gpu/convolution_folding.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index b0626ca3bc9f843e513d4727932f0e2d5fa37748..e0c73aa73acb7f3313eb54fb07390cb76590433e 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" #include #include @@ -33,14 +33,32 @@ namespace xla { namespace gpu { namespace { + +bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { + const ConvolutionDimensionNumbers& dnums = + conv->convolution_dimension_numbers(); + if (dnums.input_spatial_dimensions_size() > 3) { + return false; + } + + // CuDNN does not accept zero-element arguments + if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) || + ShapeUtil::HasZeroElements(conv->operand(1)->shape())) { + return false; + } + + if (window_util::HasWindowReversal(conv->window())) { + return false; + } + return true; +} + // Try to match a backward filter pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple, Window, - ConvolutionDimensionNumbers> -MatchBackwardFilter(HloInstruction* conv) { +std::tuple MatchBackwardFilter( + HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, std::vector(), Window(), - ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); // Step 1: match the instruction pattern without considering the paddings and // dimension numbers just yet. We may need some generic pattern matcher // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h @@ -190,18 +208,15 @@ MatchBackwardFilter(HloInstruction* conv) { backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); } - return std::make_tuple(true, std::vector({conv}), - backward_conv_window, backward_conv_dnums); + return std::make_tuple(true, backward_conv_window, backward_conv_dnums); } // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple, Window, - ConvolutionDimensionNumbers> -MatchBackwardInput(HloInstruction* conv) { +std::tuple MatchBackwardInput( + HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, std::vector(), Window(), - ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); @@ -374,58 +389,82 @@ MatchBackwardInput(HloInstruction* conv) { dnums.set_kernel_output_feature_dimension( conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, - std::vector({conv, reverse_filter}), - new_window, dnums); + return std::make_tuple(true, new_window, dnums); } -} // namespace -StatusOr ConvolutionFolding::Run(HloModule* module) { - HloComputation* entry_computation = module->entry_computation(); - std::vector convs; - for (auto* hlo : entry_computation->instructions()) { - if (hlo->opcode() == HloOpcode::kConvolution) { - convs.push_back(hlo); - } - } +// Tries to rewrite a single convolution into a call to cudnn. +StatusOr RunOnInstruction(HloInstruction* conv) { + CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); - bool changed = false; - for (HloInstruction* conv : convs) { + HloInstruction* custom_call = [&]() -> HloInstruction* { bool match; - std::vector hlos_to_fuse; Window window; ConvolutionDimensionNumbers dnums; - std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardFilter(conv); + + std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { - VLOG(2) << "Fuse instructions"; - for (HloInstruction* hlo_to_fuse : hlos_to_fuse) { - VLOG(2) << " " << hlo_to_fuse->ToString(); - } - HloInstruction* backward_convolution = - entry_computation->CreateFusionInstructionForBackwardConvolution( - hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardFilter, - window, dnums); - VLOG(2) << "to backward filter convolution"; - VLOG(2) << " " << backward_convolution->ToString(); - changed = true; - continue; + return CreateCudnnConvBackwardFilter( + conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), + window, dnums); } - std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardInput(conv); + std::tie(match, window, dnums) = MatchBackwardInput(conv); if (match) { - VLOG(2) << "Fuse instructions"; - for (HloInstruction* hlo_to_fuse : hlos_to_fuse) { - VLOG(2) << " " << hlo_to_fuse->ToString(); - } - HloInstruction* backward_convolution = - entry_computation->CreateFusionInstructionForBackwardConvolution( - hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardInput, - window, dnums); - VLOG(2) << "to backward input convolution"; - VLOG(2) << " " << backward_convolution->ToString(); - changed = true; - continue; + // Backward input conv subsumes the conv plus the reverse in operand 1. + HloInstruction* reverse = conv->mutable_operand(1); + CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); + HloInstruction* rhs = reverse->mutable_operand(0); + + return CreateCudnnConvBackwardInput( + conv->shape(), conv->mutable_operand(0), rhs, window, dnums); } + + // If all else fails, try a forward convolution. + if (CanImplementAsCudnnForwardConv(conv)) { + return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0), + conv->mutable_operand(1), conv->window(), + conv->convolution_dimension_numbers()); + } + + return nullptr; + }(); + + if (custom_call == nullptr) { + return false; + } + + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out + // the conv result and replace `conv` with it. + TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( + conv, + HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0))); + return true; +} + +// Rewrites the convolutions in the given computation into calls to cudnn. +// Returns true if it made any changes. +StatusOr RunOnComputation(HloComputation* computation) { + std::vector convs; + for (auto* hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kConvolution) { + convs.push_back(hlo); + } + } + + bool changed = false; + for (HloInstruction* conv : convs) { + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv)); + changed |= result; + } + return changed; +} +} // namespace + +StatusOr CudnnConvolutionRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; } return changed; } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h similarity index 63% rename from tensorflow/compiler/xla/service/gpu/convolution_folding.h rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index f9c898721f8dd6b8b7e74c82bb2085cc437eaad5..0c0578d88840fed1d77f7456c9acef27dec380f5 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -22,10 +22,12 @@ limitations under the License. namespace xla { namespace gpu { -class ConvolutionFolding : public HloPassInterface { +// Rewrites plain convolutions, backwards-filter convolutions, and +// backwards-input convolutions into CustomCall HLOs that call into cuDNN. +class CudnnConvolutionRewriter : public HloPassInterface { public: tensorflow::StringPiece name() const override { - return "convolution-folding"; + return "cudnn-convolution-rewriter"; } StatusOr Run(HloModule* module) override; @@ -34,4 +36,4 @@ class ConvolutionFolding : public HloPassInterface { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc similarity index 82% rename from tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 34e6bdb117d47a3d7e1eb3bae5806e130e94ea79..65588b6aaf24da628ea586eb52c462b78b8daaa7 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,23 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace gpu { +namespace { -class ConvolutionFoldingTest : public HloTestBase { +namespace op = xla::testing::opcode_matchers; + +class CudnnConvolutionRewriterTest : public HloTestBase { public: - ConvolutionFoldingTest() { + CudnnConvolutionRewriterTest() { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -44,7 +50,8 @@ class ConvolutionFoldingTest : public HloTestBase { // the batch and feature dimension in the activations, and treat the batch // dimension in gradients as the input feature dimension in the filter. // - // TODO(jingyue): Add more tests on NCHW input order which TF also supports. + // TODO(jingyue): Add more tests on NCHW input order, which TF also + // supports. tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); @@ -74,9 +81,8 @@ class ConvolutionFoldingTest : public HloTestBase { } protected: - bool FoldConvolution(HloModule* module) { - ConvolutionFolding convolution_folding; - return convolution_folding.Run(module).ValueOrDie(); + bool RunPass(HloModule* module) { + return CudnnConvolutionRewriter().Run(module).ValueOrDie(); } // A convolution window with stride 1 and zero padding. The size fields are @@ -86,7 +92,7 @@ class ConvolutionFoldingTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -108,14 +114,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) { HloComputation::Builder builder(TestName()); HloInstruction* activations = @@ -135,12 +140,17 @@ TEST_F(ConvolutionFoldingTest, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } // Extracted from block35 training. -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardFilterConvolveWithPaddedActivations) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -162,15 +172,15 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } // Extracted from inception v3 training. -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardFilterConvolveWithPaddedGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -192,14 +202,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -221,14 +230,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -272,14 +280,15 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); for (int i = 0; i < 2; ++i) { - const WindowDimension& window_dim = - entry_computation->root_instruction()->window().dimensions(i); + const WindowDimension& window_dim = custom_call->window().dimensions(i); // Low padding of the backward input convolution // = kernel_size - 1 - low padding on gradients. EXPECT_EQ(3, window_dim.padding_low()); @@ -291,7 +300,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { // Convolve([abc], [x], base_dilation=2) // = Convolve([abc], Reverse([x]), base_dilation=2) // = BackwardInputConvolve([abc], [x], stride=2) -TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. HloInstruction* output = @@ -316,17 +325,16 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); } // BackwardInputConvolve([abc], [x], stride=1) is equivalent to // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input // convolution. -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. @@ -347,8 +355,12 @@ TEST_F(ConvolutionFoldingTest, tf_default_dnums_for_backward_input_)); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } // Extracted from Inception V3 training. @@ -365,7 +377,8 @@ TEST_F(ConvolutionFoldingTest, // 20x10x10x192 // // Gradients are padded unevenly. -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardInputConvolveUnevenPaddingOnGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -397,14 +410,14 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); for (int i = 0; i < 2; ++i) { - const WindowDimension& window_dim = - entry_computation->root_instruction()->window().dimensions(i); + const WindowDimension& window_dim = custom_call->window().dimensions(i); EXPECT_EQ(0, window_dim.padding_low()); EXPECT_EQ(0, window_dim.padding_high()); EXPECT_EQ(2, window_dim.stride()); @@ -413,7 +426,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -442,8 +455,12 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { .ValueOrDie())); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } // Extracted from //learning/brain/google/xla/benchmarks/resnet.py @@ -460,7 +477,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { // // We should fuse BC even though padding on activations is uneven, because // PadInsertion will canonicalize the fusion HLO. -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. @@ -493,13 +510,12 @@ TEST_F(ConvolutionFoldingTest, auto module = CreateNewModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - const HloInstruction* backward_conv = entry_computation->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, backward_conv->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - backward_conv->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); const WindowDimension& backward_conv_col_dim = - backward_conv->window().dimensions(1); + entry_computation->root_instruction()->operand(0)->window().dimensions(1); EXPECT_EQ(0, backward_conv_col_dim.padding_low()); EXPECT_EQ(1, backward_conv_col_dim.padding_high()); } @@ -515,7 +531,7 @@ TEST_F(ConvolutionFoldingTest, // // We currently don't fuse BC because PadInsertion doesn't support negative // padding on the gradients of backward convolution (b/32744257). -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveNegativePaddingHighOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. @@ -544,9 +560,14 @@ TEST_F(ConvolutionFoldingTest, .ValueOrDie())); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } +} // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc new file mode 100644 index 0000000000000000000000000000000000000000..e4ae839e1dd4cb3a744a3f6a3329cabdaeb3f38d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -0,0 +1,262 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +namespace se = ::perftools::gputools; + +using se::DeviceMemory; +using se::DeviceMemoryBase; +using se::Stream; +using se::dnn::AlgorithmConfig; +using se::dnn::BatchDescriptor; +using se::dnn::ConvolutionDescriptor; +using se::dnn::DataLayout; +using se::dnn::DimIndex; +using se::dnn::FilterDescriptor; +using se::dnn::FilterLayout; +using se::dnn::ProfileResult; + +// A StreamExecutor ScratchAllocator that wraps a single XLA allocation, +// returning it (in its entirety) the first time Allocate() is called. +class ScratchBufAllocator : public se::ScratchAllocator { + public: + explicit ScratchBufAllocator(se::DeviceMemoryBase scratch) + : scratch_(scratch) {} + + ~ScratchBufAllocator() override = default; + + int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override { + return scratch_.size(); + } + + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override { + if (allocated_) { + return se::port::InternalError( + "Can't allocate twice from a ScratchBufAllocator."); + } + if (byte_size > scratch_.size()) { + return se::port::InternalError(tensorflow::strings::StrCat( + "Can't allocate ", byte_size, + " bytes from a ScratchBufAllocator of size ", scratch_.size())); + } + + allocated_ = true; + return se::DeviceMemory(scratch_); + } + + private: + se::DeviceMemoryBase scratch_; + bool allocated_ = false; +}; + +template +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, DeviceMemory input_buf, + DeviceMemory filter_buf, DeviceMemory output_buf, + se::ScratchAllocator* scratch_allocator, const Window& window, + const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm, + Stream* stream, ProfileResult* profile_result /*= nullptr*/) { + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); + VLOG(3) << "tensor_ops_enabled: " + << algorithm.algorithm().tensor_ops_enabled(); + VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); + VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }"; + VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }"; + VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }"; + VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; + VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; + + const int num_dimensions = window.dimensions_size(); + CHECK_LE(num_dimensions, 3); + // cuDNN does not support 1D convolutions. We therefore express 1D + // convolutions as 2D convolutions where the first spatial dimension is 1. + // This matches the behavior of TF (see definition of conv1d in + // tensorflow/python/ops/nn_ops.py). + const int effective_num_dimensions = std::max(2, num_dimensions); + + if (std::is_same::value) { + CHECK_EQ(F32, output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); + } else if (std::is_same::value) { + CHECK_EQ(F16, output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); + } else { + LOG(FATAL) << ShapeUtil::HumanString(output_shape); + } + + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); + for (const WindowDimension& dim : window.dimensions()) { + CHECK_EQ(dim.padding_low(), dim.padding_high()); + } + + // cuDNN's convolution APIs support the BDYX layout for activations/output and + // the OIYX layout for weights. + BatchDescriptor input_descriptor(effective_num_dimensions); + input_descriptor.set_layout(DataLayout::kBatchDepthYX) + .set_feature_map_count( + input_shape.dimensions(dnums.input_feature_dimension())) + .set_count(input_shape.dimensions(dnums.input_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + // Note that the dimensions are reversed. The same holds below. + input_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + input_shape.dimensions(dnums.input_spatial_dimensions(dim))); + } + + FilterDescriptor filter_descriptor(effective_num_dimensions); + filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + .set_input_feature_map_count( + filter_shape.dimensions(dnums.kernel_input_feature_dimension())) + .set_output_feature_map_count( + filter_shape.dimensions(dnums.kernel_output_feature_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + filter_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim))); + } + + ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); + for (int dim = 0; dim < num_dimensions; ++dim) { + convolution_descriptor + .set_zero_padding( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).padding_low()) + .set_filter_stride( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).stride()); + } + + BatchDescriptor output_descriptor(effective_num_dimensions); + output_descriptor.set_layout(DataLayout::kBatchDepthYX) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_count(output_shape.dimensions(dnums.output_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + output_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + output_shape.dimensions(dnums.output_spatial_dimensions(dim))); + } + + // Add a singleton dimension in the 1D convolution case. + if (num_dimensions == 1) { + input_descriptor.set_spatial_dim(static_cast(0), 1); + output_descriptor.set_spatial_dim(static_cast(0), 1); + filter_descriptor.set_spatial_dim(static_cast(0), 1); + convolution_descriptor.set_zero_padding(static_cast(0), 0) + .set_filter_stride(static_cast(0), 1); + } + + switch (kind) { + case CudnnConvKind::kForward: + stream->ThenConvolveWithAlgorithm( + input_descriptor, input_buf, filter_descriptor, filter_buf, + convolution_descriptor, output_descriptor, &output_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardInput: + stream->ThenConvolveBackwardDataWithAlgorithm( + filter_descriptor, filter_buf, output_descriptor, output_buf, + convolution_descriptor, input_descriptor, &input_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardFilter: + stream->ThenConvolveBackwardFilterWithAlgorithm( + input_descriptor, input_buf, output_descriptor, output_buf, + convolution_descriptor, filter_descriptor, &filter_buf, + scratch_allocator, algorithm, profile_result); + break; + } + + if (!stream->ok()) { + return InternalError( + "Unable to launch convolution with type %s and algorithm (%lld, %lld)", + CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(), + algorithm.algorithm_no_scratch().algo_id()); + } + return Status::OK(); +} + +} // anonymous namespace + +string CudnnConvKindToString(CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kForward: + return "forward"; + case CudnnConvKind::kBackwardFilter: + return "backward_filter"; + case CudnnConvKind::kBackwardInput: + return "backward_input"; + } +} + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, + perftools::gputools::DeviceMemoryBase filter_buf, + perftools::gputools::DeviceMemoryBase output_buf, + perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, + const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result) { + ScratchBufAllocator scratch_allocator(scratch_buf); + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + input_buf, filter_buf, output_buf, + &scratch_allocator, window, dnums, algorithm, + stream, profile_result); +} + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, + perftools::gputools::DeviceMemoryBase filter_buf, + perftools::gputools::DeviceMemoryBase output_buf, + perftools::gputools::ScratchAllocator* scratch_allocator, + const Window& window, const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result) { + PrimitiveType output_primitive_type = output_shape.element_type(); + CHECK(output_primitive_type == F32 || output_primitive_type == F16) + << ShapeUtil::HumanString(output_shape); + if (output_primitive_type == F32) { + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf), se::DeviceMemory(filter_buf), + se::DeviceMemory(output_buf), scratch_allocator, window, dnums, + algorithm, stream, profile_result); + } + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf), + se::DeviceMemory(filter_buf), + se::DeviceMemory(output_buf), + scratch_allocator, window, dnums, algorithm, + stream, profile_result); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..3dbfa2730da359d3c7937140508017c4a7b02d6c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ + +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// This file contains low-level routines for running cudnn convolutions. + +// Different types of convolutions supported by cudnn. +// +// A way to think about these is that a convolution is defined by three arrays +// -- the "input", the "filter", and the "output" -- and given any two of these, +// we can compute the third. For example, a backward-input convolution takes as +// input a filter and an "output" and produces an "input" such that if one were +// to do a forward convolution of "input" using filter, the result would be +// something with the same shape as "output". +// +// This way of thinking is not correct if you look at the values produced. For +// example, a backward-input convolution is not actually the mathematical +// inverse of a forward convolution. But it's right as far as the shapes and +// "connectivity" (i.e. which elements of the input affect which elements of +// the output) are concerned. +enum class CudnnConvKind { + kForward, // input + filter => output + kBackwardInput, // filter + output => input + kBackwardFilter, // input + output => filter +}; + +// Converts a CudnnConvKind value to a string. +string CudnnConvKindToString(CudnnConvKind kind); + +// Calls into cudnn to run the specified convolution. +// +// Note that depending on the value of CudnnConvKind, the result of this call +// may be written into input_buf, filter_buf, or output_buf! +// +// At the moment we only support cudnn convolutions over float and half, and +// convolution with half data type is implemented with cudnn PSEUDO_HALF +// configuration, that is, the input values are half and the internal +// computation type is float. +// +// We provide one overload which takes a scratch buffer, and another which takes +// an allocator which is responsible for allocating the scratch space. In +// theory the second one shouldn't be necessary -- users of this function could +// just ask cudnn how much scratch space it needs for a particular convolution. +// But in practice, StreamExecutor does not expose such an API, and in the name +// of parsimony, perhaps it's better not to add it. Instead, the first time you +// call a convolution, you should call the version that takes a scratch +// allocator and take note of how much memory is used. The next time you call +// the same conv, you can provide an explicitly preallocated scratch buffer of +// that size, if you like. +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, + perftools::gputools::DeviceMemoryBase filter_buf, + perftools::gputools::DeviceMemoryBase output_buf, + perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, + const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, + perftools::gputools::DeviceMemoryBase filter_buf, + perftools::gputools::DeviceMemoryBase output_buf, + perftools::gputools::ScratchAllocator* scratch_allocator, + const Window& window, const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 8e3aebbc12b5e6d746700956b9743bc94db50167..ba482793e7632f0f423cc9da0dd9620bdf29c642 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -137,9 +137,9 @@ StatusOr DoGemmAutotune( // for all algorithms if we're targeting < sm_50. But because we pass a // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. - DCHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, - computation_type, algorithm, stream, - &profile_result)); + CHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, + computation_type, algorithm, stream, + &profile_result)); if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < best_result.elapsed_time_in_ms()) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 0cca3ca0926ad1f9fe21803a771d66ac8b1affaf..28ebd034ee0c89137f4e6eb417d8a37f4a00af7a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -35,8 +35,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" @@ -46,8 +47,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" @@ -127,7 +128,9 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { +tensorflow::Status OptimizeHloModule(HloModule* hlo_module, + se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -143,6 +146,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { // most ops. pipeline.AddPass(BF16, F32); pipeline.AddPass(); + { auto& pass = pipeline.AddPass>("simplification"); @@ -173,7 +177,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { pass.AddPass(); pass.AddPass(); } - pipeline.AddPass(); + pipeline.AddPass( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { @@ -185,6 +189,58 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } + + { + // Convert convolutions into CustomCalls to cudnn, then canonicalize them + // (PadInsertion). + HloPassPipeline pipeline("conv_canonicalization"); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); + pipeline.AddPass(); + + // Choose the fastest algorithm for each conv. + // + // In theory doing this here is way too early: It needs to happen after + // layout assignment, because the layout of the inputs/outputs affects the + // speed of the conv. But currently we only allow only one input/output + // layout when calling cudnn, so there's no ambiguity. + // + // We pick the algorithm at this early stage so we can generate better HLO. + // After CudnnConvolutionRewriter, our convolutions are CustomCalls which + // return a tuple (conv_result, scratch_memory), and the each conv uses 0 + // bytes of scratch: + // + // customcall = (f32[...], f32[0]) + // return gte(customcall, 0) + // + // The algorithm picker then chooses the best algorithm, and potentially + // increases the scratch space. It replaces customcall with new_tuple, + // giving us the following: + // + // new_customcall = (f32[...], f32[N]) + // new_tuple = tuple(gte(new_customcall, 0), constant f32[0]) + // return gte(new_tuple, 0) + // + // The new tuple and gte instructions then be simplified away, because + // nobody is expected to use the scratch value. + // + // However, if we were to run CudnnConvolutionAlgorithmPicker after layout + // assignment, fusion would already have run, and the gte(customcall, 0) + // would probably already be into a fusion node. We can't simplify across + // HloComputation boundaries, so in this case we wouldn't be able to + // simplify away the new_tuple bits. + // + // We'll need to revisit this if we ever allow multiple layouts for the + // inputs/outputs of a cudnn convolution. + pipeline.AddPass(stream_exec, + device_allocator); + // Clean up new_tuple described above. + pipeline.AddPass(); + pipeline.AddPass(); + + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + { HloPassFix fusion("fusion"); fusion.AddInvariantChecker(); @@ -220,9 +276,10 @@ tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); pipeline.AddInvariantChecker(); - pipeline.AddPass(); + pipeline.AddPass( hlo_module->mutable_entry_computation_layout()); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -410,16 +467,19 @@ GpuCompiler::GpuCompiler() .getPointerSize(0 /* default address space */)) {} StatusOr> GpuCompiler::RunHloPasses( - std::unique_ptr module, se::StreamExecutor* /*stream_exec*/) { + std::unique_ptr module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); Tracing::TraceMe annotation("HLO Transforms", module->name(), /*is_expensive=*/true); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get())); + TF_RETURN_IF_ERROR( + OptimizeHloModule(module.get(), stream_exec, device_allocator)); return std::move(module); } StatusOr> GpuCompiler::RunBackend( - std::unique_ptr module, se::StreamExecutor* stream_exec) { + std::unique_ptr module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); TF_RET_CHECK(stream_exec != nullptr); @@ -459,16 +519,17 @@ StatusOr> GpuCompiler::RunBackend( /*color_alignment=*/[](LogicalBuffer::Color) { return kCudaMallocAlignBytes; })); - // BufferAssignment::ToString() includes a header, so no need for us to - // print one ourselves. + // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() + // include headers, so no need for us to print them ourselves. + XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); XLA_VLOG_LINES(2, module->ToString()); - const string xla_dump_hlo_proto_to = - module->config().debug_options().xla_dump_hlo_proto_to(); - if (!xla_dump_hlo_proto_to.empty()) { + const string xla_dump_optimized_hlo_proto_to = + module->config().debug_options().xla_dump_optimized_hlo_proto_to(); + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *buffer_assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index 18e34340205b6f51497e26c45520799d21c55a46..c352d4d8462fadb266c55ad437de998e86a6528e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -51,11 +51,13 @@ class GpuCompiler : public LLVMCompiler { StatusOr> RunHloPasses( std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::vector> module, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index e67087d822e2f3367c48b08be66f5f60791be638..9db85bc788bde46c890a46ce9b0902ddce3f5675 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -36,7 +36,7 @@ namespace gpu { StatusOr GpuCopyInsertion::FindOrInsertCopy( HloInstruction* hlo) { - HloInstruction*& copy = inserted_copies_[hlo]; + HloInstruction*& copy = hlo_to_copy_map_[hlo]; if (copy == nullptr) { TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo)); } @@ -49,7 +49,7 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(module)); + HloDataflowAnalysis::Run(*module)); // Make sure all operands of a library call are in memory instead of constants // in IR. @@ -78,6 +78,12 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } + } else if (IsCustomCallToDnnConvolution(*hlo)) { + // The last two arguments to a CUDNN convolution are two HLO constants for + // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied. + for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } } else if (ImplementedAsLibraryCall(*hlo)) { // For all other library calls, materialize all the operands into memory. for (int64 i = 0; i < hlo->operand_count(); ++i) { @@ -86,27 +92,34 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { } } - // Init values of a while node cannot be constants. Insert copies for any - // constants found at the operand of a while. - tensorflow::gtl::FlatSet copied_constants; + // Init values of while and conditional nodes cannot be constants. Insert + // copies for any constants found at the operands of these nodes. + tensorflow::gtl::FlatSet inserted_copies; for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kWhile) { + if (instruction->opcode() != HloOpcode::kWhile && + instruction->opcode() != HloOpcode::kConditional) { continue; } - for (auto& pair : - dataflow->GetInstructionValueSet(instruction->operand(0))) { - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (value->defining_instruction()->opcode() == - HloOpcode::kConstant && - !ContainsKey(copied_constants, value->defining_instruction())) { - HloInstruction* constant = value->defining_instruction(); - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - FindOrInsertCopy(constant)); - TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); - copied_constants.insert(constant); - changed = true; + for (auto operand : instruction->operands()) { + // Skip the operands that have already been replaced with a copy in a + // previous iteration (which is possible when a constant is used as an + // operand in multiple places). + if (ContainsKey(inserted_copies, operand)) { + continue; + } + for (auto& pair : dataflow->GetInstructionValueSet(operand)) { + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (value->defining_instruction()->IsConstant() && + !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) { + HloInstruction* constant = value->defining_instruction(); + TF_ASSIGN_OR_RETURN(HloInstruction * copy, + FindOrInsertCopy(constant)); + TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); + inserted_copies.insert(copy); + changed = true; + } } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 4d77f337e6eb20f7d79acc0829fde26bbe443f25..0c6f9b511f3aac5f62182273b827adcd068cd633 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -32,13 +32,13 @@ class GpuCopyInsertion : public HloPassInterface { StatusOr Run(HloModule* module) override; protected: - // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making + // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making // duplicate copies. StatusOr FindOrInsertCopy(HloInstruction* hlo); // A map containing all copies inserted to materialize operands of library // calls. The key is the copied instruction and the value is the copy. - tensorflow::gtl::FlatMap inserted_copies_; + tensorflow::gtl::FlatMap hlo_to_copy_map_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index f5d67b9ea9498df3f023ea9a694a63b468c5be18..623d6714de501000e38b7698620925f66425f157 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -262,9 +262,16 @@ StatusOr> GpuExecutable::ExecuteOnStream( ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); if (allocation.is_entry_computation_parameter()) { - auto param_no = allocation.parameter_number(); - buffer_allocations_builder.RegisterBuffer( - i, arguments[param_no]->root_buffer()); + // The caller must give us a buffer for ShapeIndex {} of every parameter. + // It can optionally give us a buffer for other ShapeIndices, but we + // ignore them: Because we can't rely on these sub-buffers' addresses + // being available, our generated code can't use them. Instead, it must + // chase pointers starting at the tuple root. + if (allocation.param_shape_index().empty()) { + auto param_no = allocation.parameter_number(); + buffer_allocations_builder.RegisterBuffer( + i, arguments[param_no]->root_buffer()); + } } } se::StreamExecutor* executor = run_options->stream()->parent(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 58915f1f62f0c0f320443058a798333c498ffe47..89f1e625884568bf7370b3801d851ef4846c2a98 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -28,122 +28,114 @@ limitations under the License. namespace xla { namespace gpu { +// cuDNN convolutions are called with specific layouts on the input, output, +// and filter: +// +// input: DataLayout::kBatchDepthYX +// output: DataLayout::kBatchDepthYX +// filter: FilterLayout::kOutputInputYX +// +// The order dimensions in the constant name is major-to-minor (eg, the +// most-major dimension of the input is batch, most-minor is X). The +// specific dimension numbers these named dimensions correspond to is +// determined by the ConvolutionDimensionNumbers argument. Y is spatial +// dimension 0, and X is spatial dimension 1. +// +// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. +static Status AddBackendConstraintsToDnnConvCustomCall( + HloInstruction* instr, LayoutConstraints* constraints) { + CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); + Shape input_shape; + Shape filter_shape; + Shape output_shape; + const auto& target = instr->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + input_shape = instr->operand(0)->shape(); + filter_shape = instr->operand(1)->shape(); + output_shape = instr->shape().tuple_shapes(0); + } else if (target == kCudnnConvBackwardInputCallTarget) { + input_shape = instr->shape().tuple_shapes(0); + filter_shape = instr->operand(1)->shape(); + output_shape = instr->operand(0)->shape(); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + input_shape = instr->operand(0)->shape(); + filter_shape = instr->shape().tuple_shapes(0); + output_shape = instr->operand(1)->shape(); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << instr->custom_call_target(); + } + + // Construct minor-to-major dimension orders for operands and result. + // cuDNN's convolution APIs support the BDYX layout for activations/output + // and the OIYX layout for weights. + // TODO(b/29399649): Be more flexible about handling layouts of cuDNN + // calls after we switch to cuDNN v5. + const ConvolutionDimensionNumbers& dimension_numbers = + instr->convolution_dimension_numbers(); + std::vector input_layout; + for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0; + --i) { + input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); + } + input_layout.push_back(dimension_numbers.input_feature_dimension()); + input_layout.push_back(dimension_numbers.input_batch_dimension()); + *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); + + std::vector filter_layout; + for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0; + --i) { + filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); + } + filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension()); + filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension()); + *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); + + std::vector output_layout; + for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0; + --i) { + output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); + } + output_layout.push_back(dimension_numbers.output_feature_dimension()); + output_layout.push_back(dimension_numbers.output_batch_dimension()); + *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); + + // The custom call returns a tuple of (actual_result, scratch_buffer); + // call_result_buf is the logical buffer for actual_result, the thing that + // contains the result of the conv call. + TF_ASSIGN_OR_RETURN(const LogicalBuffer* call_result_buf, + constraints->points_to_analysis().GetBufferDefinedAt( + instr, /*index=*/{0})); + + // Set layouts of the instructions' shapes. + if (target == kCudnnConvForwardCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(output_shape.layout(), *call_result_buf)); + } else if (target == kCudnnConvBackwardInputCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(input_shape.layout(), *call_result_buf)); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf)); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << instr->custom_call_target(); + } + return Status::OK(); +} + Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { for (auto* instruction : constraints->computation()->instructions()) { - // cuDNN is called with specific layouts on the input, output, and filter: - // - // input: DataLayout::kBatchDepthYX - // output: DataLayout::kBatchDepthYX - // filter: FilterLayout::kOutputInputYX - // - // The order dimensions in the constant name is major-to-minor (eg, the - // most-major dimension of the input is batch, most-minor is X). The - // specific dimension numbers these named dimensions correspond to is - // determined by the ConvolutionDimensionNumbers argument. Y is spatial - // dimension 0, and X is spatial dimension 1. - // - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. - if (ImplementedAsDnnConvolution(*instruction)) { - HloInstruction* input = nullptr; - HloInstruction* filter = nullptr; - HloInstruction* output = nullptr; - if (instruction->opcode() == HloOpcode::kConvolution) { - input = instruction->mutable_operand(0); - filter = instruction->mutable_operand(1); - output = instruction; - } else { - CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - // filter = BackwardFilterConvolve(input, output) - input = instruction->mutable_operand(0); - filter = instruction; - output = instruction->mutable_operand(1); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - // input = BackwardInputConvolve(output, filter) - input = instruction; - filter = instruction->mutable_operand(1); - output = instruction->mutable_operand(0); - break; - default: - LOG(FATAL) << "Not a convolution-fusion"; - } - } - - // Construct minor-to-major dimension orders for operands and result. - // cuDNN's convolution APIs support the BDYX layout for activations/output - // and the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN - // calls after we switch to cuDNN v5. - const ConvolutionDimensionNumbers& dimension_numbers = - instruction->convolution_dimension_numbers(); - std::vector input_layout; - for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; - i >= 0; --i) { - input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); - } - input_layout.push_back(dimension_numbers.input_feature_dimension()); - input_layout.push_back(dimension_numbers.input_batch_dimension()); - Shape input_shape(input->shape()); - *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); - - std::vector filter_layout; - for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; - i >= 0; --i) { - filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); - } - filter_layout.push_back( - dimension_numbers.kernel_input_feature_dimension()); - filter_layout.push_back( - dimension_numbers.kernel_output_feature_dimension()); - Shape filter_shape(filter->shape()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); - - std::vector output_layout; - for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; - i >= 0; --i) { - output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); - } - output_layout.push_back(dimension_numbers.output_feature_dimension()); - output_layout.push_back(dimension_numbers.output_batch_dimension()); - Shape output_shape(output->shape()); - *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); - - // Set layouts of the instructions' shapes. - if (instruction->opcode() == HloOpcode::kConvolution) { - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, output, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, output, 1)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(output_shape, output)); - } else { - CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - // filter = BackwardFilterConvolve(input, output) - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, filter, 0)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(filter_shape, filter)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(output_shape, filter, 1)); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - // input = BackwardInputConvolve(output, filter) - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(input_shape, input)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(output_shape, input, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, input, 1)); - break; - default: - LOG(FATAL) << "Not a convolution-fusion"; - } - } + if (IsCustomCallToDnnConvolution(*instruction)) { + TF_RETURN_IF_ERROR( + AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); } } return Status::OK(); @@ -151,9 +143,12 @@ Status GpuLayoutAssignment::AddBackendConstraints( bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout( const HloInstruction* instruction) { - // Inputs to cudnn batchnorm custom calls don't need the major-first layout - // (i.e. {n, n-1, ...0}) -- we can handle any layout. - return !IsCustomCallToDnnBatchNorm(*instruction); + // - Inputs to cudnn batchnorm custom calls don't need the major-first layout + // (i.e. {n, n-1, ...0}) -- we can handle any layout. + // - Inputs to cudnn convolution require custom layouts handled in + // AddBackendConstraints. + return !IsCustomCallToDnnBatchNorm(*instruction) && + !IsCustomCallToDnnConvolution(*instruction); } Status GpuLayoutAssignment::PropagateOperandConstraint( diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index c2115c49993ef71c4b6dd584e7e0498807666613..061210352cf12e6802d066d311fd2cb481673f15 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -22,12 +22,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + void HloToIrBindings::EmitBasePointersForHlos( tensorflow::gtl::ArraySlice io_hlos, tensorflow::gtl::ArraySlice non_io_hlos) { @@ -191,7 +196,11 @@ static bool BuffersInvariantWithinConsumer( llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, const HloInstruction& consumer, const ShapeIndex& shape_index) { - llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index), + llvm::Value* base_ptr = GetBasePointer(hlo, shape_index); + CHECK_NE(base_ptr, nullptr) + << "Buffer not assigned for shape_index " << shape_index.ToString() + << " of " << hlo.ToString(); + llvm_ir::IrArray ir_array(base_ptr, ShapeUtil::GetSubshape(hlo.shape(), shape_index)); alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); @@ -223,5 +232,54 @@ void HloToIrBindings::UnbindAllLocalIrValues() { } } +string HloToIrBindings::ToString() const { + string s = StrCat("** HloToIrBindings **\n"); + StrAppend(&s, " is_nested_=", is_nested_, "\n"); + StrAppend(&s, + " temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_), + "\n"); + + if (base_ptrs_.empty()) { + return s; + } + + // Iterate over all computations in the module in topological order, and print + // out the base pointers we have in each computation in topological order. + for (const HloComputation* computation : + base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) { + bool is_first = true; + for (const HloInstruction* instr : + computation->MakeInstructionPostOrder()) { + auto it = base_ptrs_.find(instr); + if (it == base_ptrs_.end()) { + continue; + } + if (is_first) { + StrAppend(&s, " Base pointers for computation ", computation->name(), + ":\n"); + is_first = false; + } + StrAppend(&s, " ", instr->ToString()); + + const ShapeTree& shape_tree = it->second; + if (!ShapeUtil::IsTuple(instr->shape())) { + const llvm::Value* val = shape_tree.begin()->second; + StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n"); + continue; + } + + StrAppend(&s, "\n"); + for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end(); + ++shape_it) { + llvm::Value* val = shape_it->second; + StrAppend(&s, " ", shape_it->first.ToString(), " -> ", + (val != nullptr ? llvm_ir::DumpToString(*val) : "null"), + "\n"); + } + } + } + return s; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 62ae1769a1f2fb3b9acaf35bdf18a793232500b0..3d34311b4368d17cb074aaf33c71fc865e96387e 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -66,13 +66,14 @@ class HloToIrBindings { } llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } + void SetTempBufferBase(llvm::Value* v) { temp_buffer_base_ = v; } // A helper method that returns the base pointer of the IrArray containing the // output of "inst".at the given ShapeIndex. llvm::Value* GetBasePointer(const HloInstruction& hlo, const ShapeIndex& shape_index = {}) const { auto it = base_ptrs_.find(&hlo); - CHECK(it != base_ptrs_.end()); + CHECK(it != base_ptrs_.end()) << hlo.ToString(); return it->second.element(shape_index); } @@ -87,6 +88,8 @@ class HloToIrBindings { const HloInstruction& consumer, const ShapeIndex& shape_index = {}); + string ToString() const; + private: // Emits IR to resolve (possibly) recursive GetTupleElement instructions. llvm::Value* EmitGetTupleElement(const HloInstruction* gte, @@ -111,7 +114,7 @@ class HloToIrBindings { std::unordered_map> base_ptrs_; // The address of the memory block that contains all temporary buffers. - llvm::Value* temp_buffer_base_; + llvm::Value* temp_buffer_base_ = nullptr; llvm_ir::AliasAnalysis alias_analysis_; }; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 1d47ffde4331868cbc8a8afb2d01b11e77a7fab0..2d6dad27a59978da6e4719afc50ebee5e641dde0 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -137,49 +137,6 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { - HloComputation::Builder builder(TestName()); - auto input = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), "input")); - auto filter = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 2}), "filter")); - - Window conv_window; - WindowDimension* conv_window_row = conv_window.add_dimensions(); - conv_window_row->set_size(1); - WindowDimension* conv_window_col = conv_window.add_dimensions(); - conv_window_col->set_size(2); - conv_window_col->set_padding_high(1); - - ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_input_batch_dimension(0); - conv_dnums.set_output_batch_dimension(0); - conv_dnums.set_input_feature_dimension(1); - conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_input_spatial_dimensions(2); - conv_dnums.add_output_spatial_dimensions(2); - conv_dnums.add_input_spatial_dimensions(3); - conv_dnums.add_output_spatial_dimensions(3); - conv_dnums.set_kernel_output_feature_dimension(0); - conv_dnums.set_kernel_input_feature_dimension(1); - conv_dnums.add_kernel_spatial_dimensions(2); - conv_dnums.add_kernel_spatial_dimensions(3); - - auto conv = builder.AddInstruction( - HloInstruction::CreateConvolve(ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), - input, filter, conv_window, conv_dnums)); - auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 1, 1, 1}), conv, {3, 2, 1, 0})); - builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); -} - TEST_F(InstructionFusionTest, GetTupleElementFused) { HloComputation::Builder builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {8}); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 76566a9e3dbbc936ff90fe3f440ede14bf4e5233..2f65edffea81db7dba1f8545f92b27ea622044e7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -90,43 +90,6 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { return false; } -bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { - // We can only do this if the HLO is unnested. - if (hlo.parent() != hlo.GetModule()->entry_computation()) { - return false; - } - - // Forward convolution. - if (hlo.opcode() == HloOpcode::kConvolution) { - const ConvolutionDimensionNumbers& dnums = - hlo.convolution_dimension_numbers(); - if (dnums.input_spatial_dimensions_size() > 3) { - return false; - } - - // CuDNN does not accept zero-element arguments - if (ShapeUtil::HasZeroElements(hlo.operand(0)->shape()) || - ShapeUtil::HasZeroElements(hlo.operand(1)->shape())) { - return false; - } - - if (window_util::HasWindowReversal(hlo.window())) { - return false; - } - - return true; - } - - // Backward convolution. - if (hlo.opcode() == HloOpcode::kFusion && - (hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardFilter || - hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardInput)) { - return true; - } - - return false; -} - const char* const kCudnnBatchNormForwardInferenceCallTarget = "__cudnn$batchNormalizationForwardInference"; const char* const kCudnnBatchNormForwardTrainingCallTarget = @@ -144,9 +107,76 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) { target == kCudnnBatchNormBackwardCallTarget; } +const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward"; +const char* const kCudnnConvBackwardInputCallTarget = + "__cudnn$convBackwardInput"; +const char* const kCudnnConvBackwardFilterCallTarget = + "__cudnn$convBackwardFilter"; + +bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { + return false; + } + const auto& target = hlo.custom_call_target(); + return target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardInputCallTarget || + target == kCudnnConvBackwardFilterCallTarget; +} + bool ImplementedAsLibraryCall(const HloInstruction& hlo) { - return ImplementedAsGemm(hlo) || ImplementedAsDnnConvolution(hlo) || - IsCustomCallToDnnBatchNorm(hlo); + return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) || + IsCustomCallToDnnConvolution(hlo); +} + +static HloInstruction* CreateCudnnConv( + const char* call_target, const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, const Window& window, + const ConvolutionDimensionNumbers& dnums) { + HloComputation* computation = lhs->parent(); + + // This call returns a tuple of (conv_result, scratch_memory), where + // conv_result is the actual result of the convolution, and scratch_memory is + // temporary memory used by cudnn. + // + // At the moment, we don't know how much scratch memory this conv is going to + // use, so we put u8[0] in this place. Later on another pass will choose + // which conv algorithm to use, and at that point we'll modify the shape of + // this second tuple element. + Shape call_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); + + // Our CustomCall takes three arguments: The conv lhs and rhs, and the cudnn + // algorithm to use. It's up to a later pass to choose the algorithm, so to + // indicate that we haven't yet made a choice, we speicfy -1 for that arg. + HloInstruction* negative_one = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-1))); + HloInstruction* custom_call = + computation->AddInstruction(HloInstruction::CreateCustomCall( + call_shape, {lhs, rhs, negative_one}, call_target)); + custom_call->set_window(window); + custom_call->set_convolution_dimension_numbers(dnums); + return custom_call; +} + +HloInstruction* CreateCudnnConvForward( + const Shape& shape, HloInstruction* input, HloInstruction* kernel, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel, + window, dnums); +} + +HloInstruction* CreateCudnnConvBackwardInput( + const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output, + reverse_filter, window, dnums); +} + +HloInstruction* CreateCudnnConvBackwardFilter( + const Shape& shape, HloInstruction* input, HloInstruction* output, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input, + output, window, dnums); } bool IsReductionToVector(const HloInstruction& reduce) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index d24ed9879d084e96862885efaae2f79a256cd71d..59455f389e733fee2d6cace7486f919a0c5e834e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -22,6 +22,9 @@ limitations under the License. #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they +// don't belong in "ir_emission_utils". + namespace xla { namespace gpu { @@ -30,9 +33,6 @@ constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. bool ImplementedAsGemm(const HloInstruction& hlo); -// Returns true if `hlo` will be implemented as a call to cuDNN convolution. -bool ImplementedAsDnnConvolution(const HloInstruction& hlo); - // A call to cuDNN for batch normalization is represented as CustomCall HLO with // a call target equal to one of these strings. // @@ -58,6 +58,61 @@ extern const char* const kCudnnBatchNormBackwardCallTarget; // sequence of generic HLOs or to a cuDNN CustomCall. bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); +// A call to cuDNN for convolution (forward, backward filter, or backward input) +// is represented as a CustomCall HLO with a call target equal to one of these +// strings. +// +// These CustomCalls have window() and convolution_dimension_numbers() set like +// regular convolution ops. They have the same LHS and RHS operands, plus two +// additional constant operands: an int64 operand for the cudnn algorithm and +// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn +// algorithm means that the implementation is free to choose the best algorithm +// it can. +// +// These calls output a tuple (conv_result, scratch_memory), where conv_result +// is the actual result of the convolution, and scratch_memory is temporary +// memory used by cudnn. Callers shouldn't inspect scratch_memory, as its value +// is not well-defined. +// +// CudnnConvolutionRewriter lowers kConvolution HLOs to these custom calls. +// When it does so, it chooses algorithm -1 and 0 bytes of scratch space. Later +// on in the pipeline, CudnnConvolutionAlgorithmChooser chooses an explicit +// algorithm for each conv and sets the amount of scratch space needed. +// +// (Representing the scratch memory as an output may seem strange at first, but +// it's quite sensible, from a certain point of view. The scratch buffer is a +// location in memory that the conv can write into, but which it can't legally +// read from, at least until it's written something first. But that's exactly +// the definition of an output buffer.) +extern const char* const kCudnnConvForwardCallTarget; +extern const char* const kCudnnConvBackwardInputCallTarget; +extern const char* const kCudnnConvBackwardFilterCallTarget; + +// Returns true if `hlo` will be implemented as a call to a cuDNN convolution +// routine. +// +// This returns true if `hlo` is a CustomCall HLO with a call target equal to +// one of the kCudnnConvFoo constants above, but returns *false* for HLOs with a +// kConvolution opcode. +bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); + +// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv. +// Note that these CustomCalls return a tuple (conv_result, scratch_memory). If +// you want just the conv result, you'll need to get-tuple-element the value +// returned by this function. +// +// The created cudnn call will use the default cudnn algorithm and no scratch +// space. +HloInstruction* CreateCudnnConvForward( + const Shape& shape, HloInstruction* input, HloInstruction* kernel, + const Window& window, const ConvolutionDimensionNumbers& dnums); +HloInstruction* CreateCudnnConvBackwardInput( + const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, + const Window& window, const ConvolutionDimensionNumbers& dnums); +HloInstruction* CreateCudnnConvBackwardFilter( + const Shape& shape, HloInstruction* input, HloInstruction* output, + const Window& window, const ConvolutionDimensionNumbers& dnums); + // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. bool ImplementedAsLibraryCall(const HloInstruction& hlo); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 095c3df3bfc75cae999edc7fdd800f6e399546dd..a3df67a87344d6ece2ea9047321ad9542c13f8cf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" @@ -615,8 +617,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { // TODO(b/33011107): Support cross replica sum on GPU. - return Unimplemented( - "Cross replica sum not implemented on GPU. See b/33011107."); + return Unimplemented("CrossReplicaSum is not implemented on GPU."); } Status IrEmitter::HandleParameter(HloInstruction* parameter) { @@ -710,11 +711,13 @@ Status IrEmitter::HandleCustomCall(HloInstruction*) { } Status IrEmitter::HandleInfeed(HloInstruction*) { - return Unimplemented("Infeed is not supported on GPU (b/30467474)."); + // TODO(b/30467474): Implement infeed on GPU. + return Unimplemented("Infeed is not supported on GPU."); } Status IrEmitter::HandleOutfeed(HloInstruction*) { - return Unimplemented("Outfeed is not supported on GPU (b/34359662)."); + // TODO(b/34359662): Implement outfeed on GPU. + return Unimplemented("Outfeed is not supported on GPU."); } Status IrEmitter::HandleRng(HloInstruction* random) { @@ -758,37 +761,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -Status IrEmitter::HandleConditional(HloInstruction* conditional) { - auto pred = conditional->operand(0); - auto true_arg = conditional->operand(1); - auto false_arg = conditional->operand(2); - - llvm::Value* conditional_result = GetBasePointer(*conditional); - - llvm::LoadInst* pred_value = ir_builder_.CreateLoad( - GetBasePointer(*pred), - llvm_ir::AsStringRef(IrName(conditional, "load_predicate_value"))); - llvm::Value* pred_cond = ir_builder_.CreateICmpNE( - pred_value, - llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), - llvm_ir::AsStringRef(IrName(conditional, "boolean_predicate"))); - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - pred_cond, IrName(conditional, "if_then_else"), &ir_builder_); - - SetToFirstInsertPoint(if_data.true_block, &ir_builder_); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *conditional->true_computation(), {GetBasePointer(*true_arg)}, - conditional_result)); - - SetToFirstInsertPoint(if_data.false_block, &ir_builder_); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *conditional->false_computation(), {GetBasePointer(*false_arg)}, - conditional_result)); - - SetToFirstInsertPoint(if_data.after_block, &ir_builder_); - return Status::OK(); -} - llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 39bafaa34656a35f24444dc7f3665c1250833921..b0accc08d479258d65a18202122e4c9e90ff78d0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -13,19 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// An XLA HLO graph may contain multiple computations. These computations -// fall into two types, nested and unnested. We translate each nested -// computation (e.g. the computation operand of a Map operator) to a device -// function. For each unnested computation composed of top-level -// HloInstructions, we generate a CUDA kernel for each HloInstruction. -// -// This file declares classes that translate an XLA HLO graph to LLVM IR for -// GPUs. IrEmitterNested emits LLVM IR for nested computations, and -// IrEmitterUnnested for unnested computations. The logic of emitting LLVM IR -// for each individual HloInstruction is largely the same between these two -// classes. Therefore, we implement the common logic in the Handle* functions in -// the superclass IrEmitter. - #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_ @@ -60,19 +47,28 @@ limitations under the License. namespace xla { namespace gpu { -// This class is the top-level API for the XLA HLO --> LLVM IR compiler. -// It implements the DfsHloVisitor interface and emits an LLVM IR program that -// implements the input HLO graph. +// Abstract base class for translating HLO graphs to LLVM IR for a GPU. +// +// There are two concrete subclasses of IrEmitter: IrEmitterNested and +// IrEmitterUnnested. In the unnested variety, each HLO gets its own kernel +// function, whereas in the nested version the whole computation is emitted as +// one *non-kernel* function. +// +// In XLA, kernel functions never call other kernel functions. This means that +// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use +// an HLO computation as a "subroutine" -- e.g. the HLO computation that +// specifies how to reduce two elements -- then the subroutine computation must +// be emitted using IrEmitterNested. // -// Note: if `T` is a subclass of `IrEmitter` and a handler is not overridden in -// either `IrEmitter` or `T`, the handler in `DfsHloVisitorWithDefault` -// calls `T::DefaultAction`. +// Fusion nodes are a special case. A fusion node is emitted using +// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is +// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR +// generator generator. See comments on that class. class IrEmitter : public DfsHloVisitorWithDefault { public: IrEmitter(const IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; - // The following methods implement the DfsHloVisitorWithDefault interface. Status DefaultAction(HloInstruction* hlo) override; Status HandleConstant(HloInstruction* constant) override; Status HandleBitcast(HloInstruction* bitcast) override; @@ -96,7 +92,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleRng(HloInstruction* random) override; - Status HandleConditional(HloInstruction* conditional) override; Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; @@ -218,197 +213,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { std::map computation_to_ir_function_; }; -// Emits LLVM IR for unnested computations. Each HloInstruction is translated to -// a separate CUDA kernel. These kernels are inserted into the resultant module -// sorted in reverse postorder of the XLA HLO graph. -class IrEmitterUnnested : public IrEmitter { - public: - IrEmitterUnnested(const HloModuleConfig& hlo_module_config, - const HloComputation* hlo_computation, - IrEmitterContext* ir_emitter_context); - IrEmitterUnnested(const IrEmitterUnnested&) = delete; - IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; - - // Transfers the ownship of thunk_sequence_ out. - std::unique_ptr ConsumeThunkSequence() { - return std::move(thunk_sequence_); - } - - Status DefaultAction(HloInstruction* hlo) override; - - // IrEmitterUnnested handles the following instructions differently from - // IrEmitter. - Status HandleCopy(HloInstruction* copy) override; - Status HandleConditional(HloInstruction* conditional) override; - Status HandleConvolution(HloInstruction* convolution) override; - Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleDot(HloInstruction* dot) override; - Status HandleFft(HloInstruction* fft) override; - Status HandleFusion(HloInstruction* fusion) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; - Status HandleReduce(HloInstruction* reduce) override; - Status HandleSelectAndScatter(HloInstruction* instruction) override; - Status HandleTuple(HloInstruction* tuple) override; - Status HandleWhile(HloInstruction* xla_while) override; - Status HandleInfeed(HloInstruction* xla_infeed) override; - Status HandleRng(HloInstruction* random) override; - Status HandleSelect(HloInstruction* select) override; - - Status EmitTargetElementLoop( - const HloInstruction& hlo, - const llvm_ir::ElementGenerator& body_emitter) override; - - // Same as `EmitTargetElementLoop`, but in given `thunk` rather than - // `LastThunk()`. - Status EmitTargetElementLoopInThunk( - const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, - KernelThunk* thunk); - - private: - // Builds the appropriate thunk for the instruction hlo and returns the owning - // pointer to it. The caller needs to make sure `inst` outlives the lifetime - // of the returned Thunk object. - std::unique_ptr BuildThunk(const HloInstruction* hlo); - - // Builds the prototype of the IR kernel for `inst` and adds it to the module. - llvm::Function* BuildKernelPrototype( - const HloInstruction& inst, - tensorflow::gtl::ArraySlice escaped_hlos); - - // Emits the base pointers for `hlo` and its operands. `io_hlos` will store - // all input/output HLOs among `hlo` and its operands. - llvm::Function* EmitBasePointersForHloAndItsOperands( - const HloInstruction& hlo, std::vector* io_hlos); - - // EmitColumnReduction and EmitRowReduction emit code for column and row - // reduction of a matrix and/or 3D tensor. Row and column reduction have - // different memory access pattern, so for performance their implementations - // are significantly different. - // - // Emits code that reduces a matrix of shape [height x width] to a vector of - // [width]. Other parameters have the same meaning as those of - // `EmitReductionToVector`. Note that input shape might not be - // [height x width], but can be bitcast to [height x weight] with "height" - // being the major dimension. - Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); - - // Emits code that reduces a 3D tensor of shape [depth x height x width] to a - // vector of shape [height]. Other parameters have the same meaning as those - // of `EmitReductionToVector`. Note that input shape might not be - // [depth x height x width], but can be bitcast to [depth x height x weight] - // with "depth" being the most major dimension. - Status EmitRowReduction(int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); - - // Emits code that reduces a tensor of arbitrary rank to a scalar. - Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); - - // Figures out whether `reduce` is a row or column reduction, and which - // dimensions to reduce, and calls either `EmitRowReduction` or - // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the - // input array, which is the operand of the Reduce instruction if unfused or - // of the Fusion instruction if fused. `input_gen` and `init_value_gen` - // generate elements of the input and the initial value. Other parameters mean - // the same as for `HandleReduce`. - // - // Prerequisite: `IsReductionToVector(*reduce)` - Status EmitReductionToVector( - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer); - - // Emits code to initialize buffer of `inst` in given `thunk`. - Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); - - // Returns a KernelThunk that invokes the kernel emitted for `inst`. The - // caller needs to make sure `inst` outlives the lifetime of the returned - // Thunk object. - std::unique_ptr BuildKernelThunk(const HloInstruction* inst); - - // Returns a ConvolutionThunk that calls DNN to implement `inst`. - std::unique_ptr BuildConvolutionThunk(const HloInstruction* inst); - - // Returns a FftThunk that calls cuFFT to implement `inst`. - std::unique_ptr BuildFftThunk(const HloInstruction* inst); - - // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs - // to make sure `inst` outlives the lifetime of the returned Thunk object. - std::unique_ptr BuildGemmThunk(const HloInstruction* inst); - - // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. - std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); - - // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`. - std::unique_ptr BuildDeviceToDeviceCopyThunk( - const HloInstruction* inst); - - // Returns an InfeedThunk that performs device-to-device memcpy to implement - // `inst`. - std::unique_ptr BuildInfeedThunk(const HloInstruction* inst); - - // Returns a WhileThunk that invokes thunk sequences for 'condition' and - // 'body' sub-computations of while instruction 'hlo'. - std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); - - // Returns a ForThunk which executes 'loop_limit' invocations of a thunk - // sequence from the 'body' sub-computation of the while instruction 'hlo'. - std::unique_ptr BuildForThunk(const HloInstruction* hlo, - const int64 loop_limit); - - Status Postprocess(HloInstruction* hlo) override; - - // Returns the last generated thunk. - Thunk* LastThunk() const { return thunk_sequence_->back().get(); } - - // The thunk sequence this IrEmitter generates for the input computation. - std::unique_ptr thunk_sequence_; - - // The HloComputation that this IrEmitter emits code for. - const HloComputation* hlo_computation_; -}; - -// Emits LLVM IR for a nested computation to the resultant function. -class IrEmitterNested : public IrEmitter { - public: - // Constructs an LLVM IR emitter for a nested HLO computation. `function` is - // the containing IR function this emitter produces IR to. See - // IrEmitter::IrEmitter for the meanings of other arguments. - IrEmitterNested(const HloModuleConfig& hlo_module_config, - const HloComputation& nested_computation, - IrEmitterContext* ir_emitter_context); - IrEmitterNested(const IrEmitterNested&) = delete; - IrEmitterNested& operator=(const IrEmitterNested&) = delete; - - // Overrides the default empty implementation. Binds the given instruction - // "parameter" with the parameter of the IR function. - Status HandleParameter(HloInstruction* parameter) override; - - llvm::Function* GetEmittedFunction() const { return emitted_function_; } - - Status EmitTargetElementLoop( - const HloInstruction& hlo, - const llvm_ir::ElementGenerator& body_emitter) override; - - private: - llvm::Function* EmitBasePointersForNestedComputation( - const HloComputation& nested_computation, - std::vector* io_hlos); - - llvm::Function* emitted_function_; -}; - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5225ff36ff3a8a1b049479c34aa301de8724f73e..71aada080ae8df70bffce3e1854b5fbd833efd23 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -16,12 +16,13 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" + #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h new file mode 100644 index 0000000000000000000000000000000000000000..ca11cf2c182b0600b931b19d2d7fb3983e36441a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ + +#include "llvm/IR/Function.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" + +namespace xla { +namespace gpu { + +// Emits LLVM IR for a "nested computation" into a non-kernel device function. +// +// This is used to emit code for HloComputations that don't require a separate +// kernel call. For example, IrEmitterNested is used to emit code for a kReduce +// HLO's elementwise reduction computation. Notably, IrEmitterNested is *not* +// used to emit code for fusion nodes -- fusion nodes use FusedIrEmitter, which +// is a different beast altogether. +// +// IrEmitterNested generates a non-kernel function with the following +// parameters: +// +// - N pointers to the buffers of each of the N parameters to the computation, +// - a pointer to the output buffer of the computation, and +// - a pointer to the top-level temp buffer. +// +class IrEmitterNested : public IrEmitter { + public: + // Constructs an LLVM IR emitter for a nested HLO computation. `function` is + // the containing IR function this emitter produces IR to. See + // IrEmitter::IrEmitter for the meanings of other arguments. + IrEmitterNested(const HloModuleConfig& hlo_module_config, + const HloComputation& nested_computation, + IrEmitterContext* ir_emitter_context); + IrEmitterNested(const IrEmitterNested&) = delete; + IrEmitterNested& operator=(const IrEmitterNested&) = delete; + + // Overrides the default empty implementation. Binds the given instruction + // "parameter" with the parameter of the IR function. + Status HandleParameter(HloInstruction* parameter) override; + + llvm::Function* GetEmittedFunction() const { return emitted_function_; } + + Status EmitTargetElementLoop( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& body_emitter) override; + + private: + llvm::Function* EmitBasePointersForNestedComputation( + const HloComputation& nested_computation, + std::vector* io_hlos); + + llvm::Function* emitted_function_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index be35351e8727ce15998460e41f21a53ebe427c3b..30c88c0a5d38f6ea3f94d3b47b7b69c7122bf6ac 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" + #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -28,9 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" @@ -38,7 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" @@ -72,6 +75,10 @@ namespace gpu { namespace { using llvm_ir::IrName; +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; +using tensorflow::strings::StrCat; // If a dimensions is smaller than this, untiled transposition may be more // efficient. @@ -134,6 +141,38 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::MDString::get(llvm_context, "reqntidx"), llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } + +// Tries to get a Slice for the given instruction at the given index, but +// returns nullopt if we might not know the slice's address at runtime without +// dereferencing a containing tuple. +// +// In particular, when XLA accepts a parameter of tuple type, the caller has the +// option of telling XLA what are the values inside of the tuple, or just giving +// XLA a pointer to the top-level tuple and letting us chase the pointers on the +// GPU. We therefore cannot rely having these pointers to parameter sub-buffers +// being present when we run the program. +optional GetKnownAtRuntimeSlice( + const HloInstruction* instr, const ShapeIndex& index, + const BufferAssignment& buffer_assn) { + auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index); + if (!maybe_slice.ok()) { + return nullopt; + } + // BufferAllocation gives a slice and alloc to every buffer accessed by XLA, + // but we don't necessarily know the runtime address of sub-buffers of input + // parameters. + const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie(); + const BufferAllocation* alloc = slice.allocation(); + if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() && + !alloc->param_shape_index().empty()) { + return nullopt; + } + + // Otherwise, we will know the address of this slice at runtime without having + // to dereference a tuple. + return slice; +} + } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, @@ -151,16 +190,20 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { } namespace { -bool ImplementedAsHostToDeviceMemcpy(const HloInstruction& hlo) { - // `hlo` needs to satisfy three conditions to be implemented as a +bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment, + const HloInstruction& hlo) { + // `hlo` needs to satisfy the following conditions to be implemented as a // host-to-device cuMemcpy. // // 1. `hlo` is a kCopy instruction. // 2. `hlo`'s only operand is a kConstant instruction. // 3. `hlo` and its operand have the same shape (thus the same layout too). + // 4. The address of `hlo`'s buffer is known at runtime (without dereferencing + // pointers in a tuple). return hlo.opcode() == HloOpcode::kCopy && hlo.operand(0)->opcode() == HloOpcode::kConstant && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()); + ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && + GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value(); } bool ImplementedAsDeviceToDeviceMemcpy( @@ -174,13 +217,15 @@ bool ImplementedAsDeviceToDeviceMemcpy( // instance) which means the source buffer also resides on the device. return hlo.opcode() == HloOpcode::kCopy && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.HasTopLevelAllocation(hlo.operand(0)); + GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() && + GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment) + .has_value(); } } // namespace llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice escaped_hlos) { + tensorflow::gtl::ArraySlice args) { // Compute the kernel name. The opcode string may contain "-" which cannot be // in a PTX function name, so sanitize the name before uniquifying it. string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( @@ -189,43 +234,32 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( // Create the kernel and add it to the module. llvm::Module* module = ir_emitter_context_->llvm_module(); llvm::LLVMContext& context = module->getContext(); - int num_escaped_hlos = escaped_hlos.size(); llvm::FunctionType* kernel_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(context), - std::vector(num_escaped_hlos + 1, - ir_builder_.getInt8PtrTy()), + std::vector(args.size(), ir_builder_.getInt8PtrTy()), /*isVarArg=*/false); llvm::Function* kernel = llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, kernel_name.c_str(), module); - // Add dereferenceable information to each of the escaped HLO parameters. - for (size_t arg_no = 0; arg_no < escaped_hlos.size(); ++arg_no) { - const HloInstruction* escaped_hlo = escaped_hlos[arg_no]; - const Shape& escaped_hlo_shape = escaped_hlo->shape(); - int64 escaped_hlo_size = llvm_ir::ByteSizeOf( - escaped_hlo_shape, ir_emitter_context_->llvm_module()->getDataLayout()); - kernel->addDereferenceableAttr(arg_no + 1, escaped_hlo_size); - } - - // The last argument is a pointer to the temporary buffer memory block. - // We know that it doesn't alias any of the escaped arguments (the inputs + - // the result). We also know how many bytes can be dereferenced in it. - const llvm::Argument& temp_buffer = *std::prev(kernel->arg_end()); - int64 temp_buffer_arg_no = temp_buffer.getArgNo(); - int64 temp_allocation_total_size = - ir_emitter_context_->buffer_assignment().temp_allocation_total_size(); - if (temp_allocation_total_size != 0) { - kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, - temp_allocation_total_size); - } - kernel->addParamAttr(temp_buffer_arg_no, llvm::Attribute::NoAlias); + // Add dereferenceable and alignment information to each of the kernel's + // parameters. + auto arg_it = kernel->arg_begin(); + for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) { + const BufferAllocation* alloc = args[arg_no]; + llvm::Argument* fn_arg = &*arg_it; + ++arg_it; - // All arguments to a kernel must be aligned to kCudaMallocAlignBytes. - for (int64 i = 0; i < kernel->arg_size(); ++i) { + kernel->addDereferenceableAttr(arg_no + 1, alloc->size()); kernel->addParamAttr( - i, llvm::Attribute::get(context, llvm::Attribute::Alignment, - kCudaMallocAlignBytes)); + arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment, + kCudaMallocAlignBytes)); + + if (alloc->IsPreallocatedTempBuffer()) { + fn_arg->setName("temp_buf"); + } else { + fn_arg->setName(llvm_ir::AsStringRef(StrCat("alloc", alloc->index()))); + } } // TODO(b/65380986): Investigate if adding fast math flags for generated @@ -242,10 +276,9 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( // Update the insert point to the entry basic block. llvm::BasicBlock* entry_bb = - llvm::BasicBlock::Create(context, - "entry", // The name of the basic block. - kernel); // The parent/owner of "entry_bb". - // Emit a "return void" at entry_bb's end, and sets the insert point before + llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel); + + // Emit a "return void" at entry_bb's end, and set the insert point before // that return instruction. ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb)); @@ -272,15 +305,11 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { } Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { - thunk_sequence_->push_back(BuildKernelThunk(conditional)); - return IrEmitter::HandleConditional(conditional); + thunk_sequence_->emplace_back(BuildConditionalThunk(conditional)); + return Status::OK(); } Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { - if (ImplementedAsDnnConvolution(*convolution)) { - thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution)); - return Status::OK(); - } thunk_sequence_->emplace_back(BuildKernelThunk(convolution)); return IrEmitter::HandleConvolution(convolution); } @@ -379,6 +408,76 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { return Status::OK(); } + if (IsCustomCallToDnnConvolution(*custom_call)) { + const auto& assn = ir_emitter_context_->buffer_assignment(); + const auto& lhs_shape = custom_call->operand(0)->shape(); + const auto& rhs_shape = custom_call->operand(1)->shape(); + const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); + auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); + auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); + auto tuple_result_slice = GetAllocationSlice(*custom_call); + auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); + auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); + + const HloInstruction* algorithm_inst = custom_call->operand(2); + CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); + int64 algorithm = algorithm_inst->literal().Get({}); + + const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3); + CHECK(tensor_ops_enabled_inst->IsConstant()) + << tensor_ops_enabled_inst->ToString(); + bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get({}); + + const auto& target = custom_call->custom_call_target(); + std::unique_ptr thunk; + if (target == kCudnnConvForwardCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kForward, + /*input_buffer=*/lhs_slice, + /*filter_buffer=*/rhs_slice, + /*output_buffer=*/conv_result_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/conv_result_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, tensor_ops_enabled, custom_call); + } else if (target == kCudnnConvBackwardInputCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kBackwardInput, + /*input_buffer=*/conv_result_slice, + /*filter_buffer=*/rhs_slice, + /*output_buffer=*/lhs_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/conv_result_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/lhs_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, tensor_ops_enabled, custom_call); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kBackwardFilter, + /*input_buffer=*/lhs_slice, + /*filter_buffer=*/conv_result_slice, + /*output_buffer=*/rhs_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/lhs_shape, + /*filter_shape=*/conv_result_shape, + /*output_shape=*/rhs_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, tensor_ops_enabled, custom_call); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << custom_call->custom_call_target(); + } + + thunk_sequence_->emplace_back(std::move(thunk)); + return Status::OK(); + } + return IrEmitter::HandleCustomCall(custom_call); } @@ -499,10 +598,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); return Status::OK(); } - if (ImplementedAsDnnConvolution(*fusion)) { - thunk_sequence_->emplace_back(BuildConvolutionThunk(fusion)); - return Status::OK(); - } thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); return IrEmitter::HandleFusion(fusion); } @@ -804,7 +899,8 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, } // namespace Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - if (ImplementedAsHostToDeviceMemcpy(*copy)) { + if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(), + *copy)) { thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy)); return Status::OK(); } @@ -1598,24 +1694,24 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { } Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); - bool all_tuple_elements_have_buffer = std::all_of( - operands.begin(), operands.end(), [this](HloInstruction* tuple_element) { + bool all_tuple_elements_have_buffer = + c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation( tuple_element); }); - // Tuples (especially output tuples) can take too many tuple elements, - // causing the kernel emitted exceeds the parameter space limit - // (b/31336476). As an optimization, if all tuple elements have a buffer, we - // collect their buffer addresses in a host array, and then copy that array - // to the tuple's buffer. + // Tuples (especially tuples that are the final result of a computation) can + // be so huge that if we were to emit a kernel that took each tuple element as + // a parameter, we would exceed the max allowable number of parameters to a + // GPU kernel, b/31336476. As an optimization, if all tuple elements have a + // buffer, we collect their buffer addresses in a host array, and then copy + // that array to the tuple's buffer. // // Some tuple elements (e.g. const or bitcast of const) might not have a - // buffer -- their contents are stored in code. In that case, we fall back - // to emitting kernels which have access to their buffer addresses in code. + // buffer -- their contents are stored in code. In that case, we fall back to + // emitting kernels which have access to their buffer addresses in code. if (all_tuple_elements_have_buffer) { std::vector tuple_element_buffers; - for (const HloInstruction* tuple_element : operands) { + for (const HloInstruction* tuple_element : tuple->operands()) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } thunk_sequence_->emplace_back(MakeUnique( @@ -1657,8 +1753,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { return Unimplemented( - "Dilation for select-and-scatter not implemented on GPU. " - "See b/31410564."); + "Dilation for SelectAndScatter not implemented on GPU."); } // kSelectAndScatter is implemented as two kernel launches: the first launch @@ -1867,62 +1962,207 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { return Status::OK(); } -llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands( - const HloInstruction& hlo, std::vector* io_hlos) { - const BufferAssignment& buffer_assignment = - ir_emitter_context_->buffer_assignment(); - // GetTupleElement instructions are implemented by emitting IR that indexes - // and loads the target tuple element pointer from its operand (possibly - // recursively). For this reason, GetTupleElement instructions are associated - // with their operand buffer in 'io_hlos' and 'non_io_hlos' below. - std::vector non_io_hlos; - for (const HloInstruction* operand : hlo.operands()) { - const HloInstruction* to_lookup = operand->LatestNonGteAncestor(); - if (buffer_assignment.HasTopLevelAllocation(to_lookup) && - buffer_assignment.GetUniqueTopLevelSlice(to_lookup) - .ConsumeValueOrDie() - .allocation() - ->IsInputOrOutput()) { - io_hlos->push_back(operand); - } else { - non_io_hlos.push_back(operand); +// Figures out how to access the buffers for all subshapes of hlo's operands and +// for hlo itself (i.e. all the buffers produced by HLO). +// +// Returns a map keyed on the pair {HloInstruction, ShapeIndex}. The value for +// this key is a pair {Slice, ShapeIndex}, where the slice tells you the root +// buffer to look in, and the ShapeIndex describes how to dereference starting +// at that buffer to get to the buffer in question. +// +// For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for +// hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) +// is found at slice[3][4]. That is, slice is a void***, which we dereference +// twice -- first at index 3, and then at index 4 -- to get the address of our +// buffer. +// +// This function conservatively assumes that we'll touch all sub-buffers of +// every operand and of the output. +static std::map, + std::pair> +GetHloBufferSlices(const HloInstruction* hlo, + const BufferAssignment& buffer_assn) { + std::map, + std::pair> + slices; + + // Tries to find a slice plus an array of indices i1, ..., iN such that the + // sub-buffer for instr at index can be found at slice[i1]...[iN]. + auto find_slice_for = [&](const HloInstruction* instr, + const ShapeIndex& index) + -> optional> { + // Simple, common case: Is the buffer for instr known at runtime? If so, + // we're done. + auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn); + if (slice.has_value()) { + return {{*slice, ShapeIndex()}}; } - } - CHECK_NE(HloOpcode::kGetTupleElement, hlo.opcode()); - if (buffer_assignment.HasTopLevelAllocation(&hlo) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo) - .ConsumeValueOrDie() - .allocation() - ->IsInputOrOutput()) { - io_hlos->push_back(&hlo); - } else { - non_io_hlos.push_back(&hlo); + // If we don't know the buffer for instr at index, see if we know the buffer + // for instr at index without its last element. If so, we can dynamically + // find the buffer for instr by dereferencing a pointer in that buffer. + // Continue looking this way until we run out of elements in 'index'. + ShapeIndex new_index = index; + ShapeIndex gte_indices; + while (!new_index.empty()) { + gte_indices.push_front(new_index.back()); + new_index.pop_back(); + auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn); + if (slice.has_value()) { + return {{*slice, gte_indices}}; + } + } + + // If *that* didn't work, check whether instr is a GTE instruction. If it + // is, see if we can get a buffer for its parent, and continue walking up + // parents until we find a defined buffer or we hit something that's not a + // GTE. + const HloInstruction* parent = instr; + while (parent->opcode() == HloOpcode::kGetTupleElement) { + gte_indices.push_front(parent->tuple_index()); + parent = parent->operand(0); + + auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn); + if (slice.has_value()) { + return {{*slice, gte_indices}}; + } + } + + return nullopt; + }; + + // Adds entries for all subshapes of instr to `slices`. + auto add_slices_for = [&](const HloInstruction* instr) { + // GPU constants don't have buffers; don't bother looking for one. + if (instr->IsConstant()) { + return; + } + + ShapeUtil::ForEachSubshape( + instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) { + if (slices.count({instr, index})) { + // HLOs can have duplicate operands; don't bother redoing work. + return; + } + auto maybe_slice = find_slice_for(instr, index); + if (maybe_slice.has_value()) { + slices[{instr, index}] = *maybe_slice; + } else { + VLOG(1) << "Couldn't find buffer for " << instr->ToString() + << " at index " << index.ToString(); + } + }); + }; + + add_slices_for(hlo); + for (const HloInstruction* operand : hlo->operands()) { + // Conservatively assume we'll need the buffers for all subshapes of the + // operand. + add_slices_for(operand); } - llvm::Function* kernel = BuildKernelPrototype(hlo, *io_hlos); - // bindings_ is reused because the bindings of kConstant to their underlying - // llvm::Constant can be shared for all HLOs in this computation. - bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos); - return kernel; + return slices; +} + +Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { + // TODO(b/72710576): Gather is not implemented on GPUs + return Unimplemented("Gather is not implemented on GPUs."); } std::unique_ptr IrEmitterUnnested::BuildKernelThunk( const HloInstruction* inst) { - std::vector io_hlos; - llvm::Function* kernel = - EmitBasePointersForHloAndItsOperands(*inst, &io_hlos); + const BufferAssignment& buffer_assn = + ir_emitter_context_->buffer_assignment(); - // Compute the input buffer indices. - std::vector io_buffers; - io_buffers.reserve(io_hlos.size()); - for (const HloInstruction* io_hlo : io_hlos) { - io_buffers.push_back(GetAllocationSlice(*io_hlo->LatestNonGteAncestor())); + std::map, + std::pair> + hlo_slices = GetHloBufferSlices(inst, buffer_assn); + + // Figure out which buffer allocations need to be passed as arguments to our + // kernel. This is simply all of the allocations referenced in hlo_slices, + // plus the XLA temp buffer (if we have it). We always include the temp + // buffer because even if the kernel itself doesn't use it, a nested + // subcomputation within the kernel (e.g. a kMap's computation) might. + std::unordered_set buffers_needed; + for (const auto& kv : hlo_slices) { + buffers_needed.insert(kv.second.first.allocation()); } + tensorflow::gtl::optional temp_buffer; + for (const BufferAllocation& alloc : buffer_assn.Allocations()) { + if (alloc.IsPreallocatedTempBuffer()) { + if (!temp_buffer.has_value()) { + temp_buffer = &alloc; + } else { + LOG(FATAL) << "Multiple temp buffers found, but only one is allowed!"; + } + } + } + if (temp_buffer.has_value()) { + buffers_needed.insert(*temp_buffer); + } + + // We'll pass a pointer to each of the elements of `buffers` to our kernel, in + // this order. + std::vector buffers(buffers_needed.begin(), + buffers_needed.end()); + std::sort(buffers.begin(), buffers.end(), + [](const BufferAllocation* a, const BufferAllocation* b) { + return a->index() < b->index(); + }); + + llvm::Function* kernel = BuildKernelPrototype(*inst, buffers); - // Create a KernelThunk that launches the kernel that implements "inst". - return MakeUnique(io_buffers, - llvm_ir::AsString(kernel->getName()), inst); + // Build a map from a BufferAllocation to the corresponding argument in our + // kernel. + std::unordered_map kernel_args; + { + auto arg_it = kernel->arg_begin(); + auto buffers_it = buffers.begin(); + for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) { + kernel_args[*buffers_it] = arg_it; + } + } + + // For each buffer our kernel might want to touch, bind it to a value derived + // from our kernel args. + for (const auto& kv : hlo_slices) { + const HloInstruction* instr = kv.first.first; + const ShapeIndex& index = kv.first.second; + const BufferAllocation::Slice& slice = kv.second.first; + const ShapeIndex& gte_index = kv.second.second; + + VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString() + << " is found in slice " << slice.ToString() << " at GTE index " + << gte_index.ToString(); + + llvm::Value* loc = + ir_builder_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), + {ir_builder_.getInt64(slice.offset())}); + + // If gte_index is nonempty, we have to dereference `loc` to get to the + // value we're ultimately interested in. + llvm::Type* int8_double_pointer = + llvm::PointerType::get(ir_builder_.getInt8PtrTy(), /*AddressSpace=*/0); + for (int64 idx : gte_index) { + loc = ir_builder_.CreateBitCast(loc, int8_double_pointer); + loc = ir_builder_.CreateLoad( + ir_builder_.CreateInBoundsGEP(loc, {ir_builder_.getInt64(idx)})); + } + + bindings_.BindHloToIrValue(*instr, loc, index); + } + + // Bind the temp buffer so that nested subcomputations can find it if they + // need. + if (temp_buffer.has_value()) { + bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer)); + } else { + bindings_.SetTempBufferBase( + llvm::ConstantPointerNull::get(ir_builder_.getInt8PtrTy())); + } + + return MakeUnique(buffers, llvm_ir::AsString(kernel->getName()), + inst); } std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( @@ -2011,52 +2251,6 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); } -std::unique_ptr IrEmitterUnnested::BuildConvolutionThunk( - const HloInstruction* inst) { - const HloInstruction* lhs = inst->operand(0); - const HloInstruction* rhs = inst->operand(1); - if (inst->opcode() == HloOpcode::kConvolution) { - // Forward covolution. - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kForward, - /*input_buffer=*/GetAllocationSlice(*lhs), - /*filter_buffer=*/GetAllocationSlice(*rhs), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/lhs->shape(), - /*filter_shape=*/rhs->shape(), - /*output_shape=*/inst->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - } - - // Backward filter convolution, which takes the input (activations) and the - // gradients, and computes the filter. - CHECK_EQ(HloOpcode::kFusion, inst->opcode()); - switch (inst->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kBackwardFilter, - /*input_buffer=*/GetAllocationSlice(*lhs), - /*filter_buffer=*/GetAllocationSlice(*inst), - /*output_buffer=*/GetAllocationSlice(*rhs), - /*input_shape=*/lhs->shape(), - /*filter_shape=*/inst->shape(), - /*output_shape=*/rhs->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - case HloInstruction::FusionKind::kConvBackwardInput: - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kBackwardInput, - /*input_buffer=*/GetAllocationSlice(*inst), - /*filter_buffer=*/GetAllocationSlice(*rhs), - /*output_buffer=*/GetAllocationSlice(*lhs), - /*input_shape=*/inst->shape(), - /*filter_shape=*/rhs->shape(), - /*output_shape=*/lhs->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - default: - LOG(FATAL) << "Not a convolution-fusion"; - } -} - std::unique_ptr IrEmitterUnnested::BuildFftThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); @@ -2102,6 +2296,24 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, namespace { +// Checks that the buffers corresponding to the given two HLOs share the same +// allocation. +Status CheckHloBuffersShareAllocation( + const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index, + const BufferAssignment& buffer_assignment) { + const BufferAllocation::Slice slice_a = + buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie(); + const BufferAllocation::Slice slice_b = + buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie(); + if (slice_a != slice_b) { + return InternalError( + "instruction %s %s does not share allocation with instruction %s %s", + a->ToString().c_str(), slice_a.ToString().c_str(), + b->ToString().c_str(), slice_b.ToString().c_str()); + } + return Status::OK(); +} + // Checks that all buffers used during while loop iteration share the same // buffer allocation. This includes buffers for while result, while init // operand, condition parameter, body parameter and body result. @@ -2111,37 +2323,65 @@ Status CheckWhileBuffersShareAllocation( const BufferAssignment& buffer_assignment) { return ShapeUtil::ForEachSubshapeWithStatus( xla_while->shape(), - [&buffer_assignment, &xla_while](const Shape& /*subshape*/, - const ShapeIndex& index) -> Status { - auto check = [&buffer_assignment](const HloInstruction* a, - const HloInstruction* b, - const ShapeIndex& index) -> Status { - const BufferAllocation::Slice slice_a = - buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie(); - const BufferAllocation::Slice slice_b = - buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie(); - if (slice_a != slice_b) { - return InternalError( - "instruction %s %s does not share allocation with " - "instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); - } - return Status::OK(); - }; + [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { const HloInstruction* condition_parameter = xla_while->while_condition()->parameter_instruction(0); const HloComputation* body = xla_while->while_body(); const HloInstruction* body_parameter = body->parameter_instruction(0); const HloInstruction* body_result = body->root_instruction(); - TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index)); - TF_RETURN_IF_ERROR(check(xla_while, condition_parameter, index)); - TF_RETURN_IF_ERROR(check(xla_while, body_parameter, index)); - TF_RETURN_IF_ERROR(check(xla_while, body_result, index)); + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + xla_while, xla_while->operand(0), index, buffer_assignment)); + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + xla_while, condition_parameter, index, buffer_assignment)); + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + xla_while, body_parameter, index, buffer_assignment)); + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + xla_while, body_result, index, buffer_assignment)); return Status::OK(); }); } +// Checks that the buffers used in a conditional instruction are shared with the +// operands and result as follows: +// * The result buffer of the conditional should share the allocation with the +// result buffers of the true and false computations. +// * The buffer of operand 1 should share the allocation with the buffer of +// the parameter 0 instruction of the true computation. +// * The buffer of operand 2 should share the allocation with the buffer of +// the parameter 0 instruction of the false computation. +Status CheckConditionalBuffersShareAllocation( + const HloInstruction* conditional, + const BufferAssignment& buffer_assignment) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + conditional->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + conditional, conditional->true_computation()->root_instruction(), + index, buffer_assignment)); + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + conditional, conditional->false_computation()->root_instruction(), + index, buffer_assignment)); + return Status::OK(); + })); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + conditional->operand(1)->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { + return CheckHloBuffersShareAllocation( + conditional->operand(1), + conditional->true_computation()->parameter_instruction(0), index, + buffer_assignment); + })); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + conditional->operand(2)->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { + return CheckHloBuffersShareAllocation( + conditional->operand(2), + conditional->false_computation()->parameter_instruction(0), index, + buffer_assignment); + })); + return Status::OK(); +} + } // namespace std::unique_ptr IrEmitterUnnested::BuildWhileThunk( @@ -2184,9 +2424,36 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( ir_emitter_body.ConsumeThunkSequence(), hlo); } +std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( + const HloInstruction* hlo) { + // Check that the buffers used in conditional are shared with the operands and + // result appropriately. + TF_CHECK_OK(CheckConditionalBuffersShareAllocation( + hlo, ir_emitter_context_->buffer_assignment())); + + HloComputation* true_computation = hlo->true_computation(); + IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation, + ir_emitter_context_); + TF_CHECK_OK(true_computation->root_instruction()->Accept(&ir_emitter_true)); + + HloComputation* false_computation = hlo->false_computation(); + IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation, + ir_emitter_context_); + TF_CHECK_OK(false_computation->root_instruction()->Accept(&ir_emitter_false)); + + return MakeUnique( + GetAllocationSlice(*hlo->operand(0)), + GetAllocationSlice(*hlo->operand(1)), + GetAllocationSlice(*hlo->operand(2)), + std::move(*ir_emitter_true.ConsumeThunkSequence()), + std::move(*ir_emitter_false.ConsumeThunkSequence()), hlo); +} + Status IrEmitterUnnested::EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) { + VLOG(3) << bindings_.ToString(); + const Shape& element_shape = hlo.IsMultiOutputFusion() ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h new file mode 100644 index 0000000000000000000000000000000000000000..b83a2337e2decd9d4fba3d40fcf33f131fca8a3c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -0,0 +1,206 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ + +#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" + +namespace xla { +namespace gpu { + +// Emits LLVM IR for an "unnested computation". +// +// An unnested computation is an HloComputation which you run by executing one +// or more kernels for each HloInstruction it contains. Examples of unnested +// computations: +// +// - An HloModule's root computation, +// - The body of an HLO while loop, +// - The true/false computation of an HLO conditional. +// +// Note the opportunity for confusion -- the while loop's computation is nested +// within the root computation, but it's emitted using IrEmitterUnnested! Don't +// think about it too hard. +// +// Examples of things that are not unnested computations: +// +// - The reducer of a kReduce HLO. This is emited using IrEmitterNested. +// - The body of a fusion node. IrEmitterUnenested emits the relevant code +// within a kernel function using FusedIrEmitter. (FusedIrEmitter is not +// really an IrEmitter, but is more an "IR generator generator".) +// +class IrEmitterUnnested : public IrEmitter { + public: + IrEmitterUnnested(const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + IrEmitterContext* ir_emitter_context); + IrEmitterUnnested(const IrEmitterUnnested&) = delete; + IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; + + // Transfers the ownship of thunk_sequence_ out. + std::unique_ptr ConsumeThunkSequence() { + return std::move(thunk_sequence_); + } + + Status DefaultAction(HloInstruction* hlo) override; + + // IrEmitterUnnested handles the following instructions differently from + // IrEmitter. + Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandleConvolution(HloInstruction* convolution) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleFft(HloInstruction* fft) override; + Status HandleFusion(HloInstruction* fusion) override; + Status HandleGather(HloInstruction* gather) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleWhile(HloInstruction* xla_while) override; + Status HandleInfeed(HloInstruction* xla_infeed) override; + Status HandleRng(HloInstruction* random) override; + Status HandleSelect(HloInstruction* select) override; + + Status EmitTargetElementLoop( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& body_emitter) override; + + // Same as `EmitTargetElementLoop`, but in given `thunk` rather than + // `LastThunk()`. + Status EmitTargetElementLoopInThunk( + const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, + KernelThunk* thunk); + + private: + // Builds the appropriate thunk for the instruction hlo and returns the owning + // pointer to it. The caller needs to make sure `inst` outlives the lifetime + // of the returned Thunk object. + std::unique_ptr BuildThunk(const HloInstruction* hlo); + + // Builds the prototype of the IR kernel for `inst` and adds it to the module. + // This kernel takes as arguments pointers to the given buffer allocations. + llvm::Function* BuildKernelPrototype( + const HloInstruction& inst, + tensorflow::gtl::ArraySlice args); + + // EmitColumnReduction and EmitRowReduction emit code for column and row + // reduction of a matrix and/or 3D tensor. Row and column reduction have + // different memory access pattern, so for performance their implementations + // are significantly different. + // + // Emits code that reduces a matrix of shape [height x width] to a vector of + // [width]. Other parameters have the same meaning as those of + // `EmitReductionToVector`. Note that input shape might not be + // [height x width], but can be bitcast to [height x weight] with "height" + // being the major dimension. + Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + HloComputation* reducer); + + // Emits code that reduces a 3D tensor of shape [depth x height x width] to a + // vector of shape [height]. Other parameters have the same meaning as those + // of `EmitReductionToVector`. Note that input shape might not be + // [depth x height x width], but can be bitcast to [depth x height x weight] + // with "depth" being the most major dimension. + Status EmitRowReduction(int64 depth, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + HloComputation* reducer); + + // Emits code that reduces a tensor of arbitrary rank to a scalar. + Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + HloComputation* reducer); + + // Figures out whether `reduce` is a row or column reduction, and which + // dimensions to reduce, and calls either `EmitRowReduction` or + // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the + // input array, which is the operand of the Reduce instruction if unfused or + // of the Fusion instruction if fused. `input_gen` and `init_value_gen` + // generate elements of the input and the initial value. Other parameters mean + // the same as for `HandleReduce`. + // + // Prerequisite: `IsReductionToVector(*reduce)` + Status EmitReductionToVector( + HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reducer); + + // Emits code to initialize buffer of `inst` in given `thunk`. + Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); + + // Returns a KernelThunk that invokes the kernel emitted for `inst`. The + // caller needs to make sure `inst` outlives the lifetime of the returned + // Thunk object. + std::unique_ptr BuildKernelThunk(const HloInstruction* inst); + + // Returns a FftThunk that calls cuFFT to implement `inst`. + std::unique_ptr BuildFftThunk(const HloInstruction* inst); + + // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs + // to make sure `inst` outlives the lifetime of the returned Thunk object. + std::unique_ptr BuildGemmThunk(const HloInstruction* inst); + + // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. + std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); + + // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`. + std::unique_ptr BuildDeviceToDeviceCopyThunk( + const HloInstruction* inst); + + // Returns an InfeedThunk that performs device-to-device memcpy to implement + // `inst`. + std::unique_ptr BuildInfeedThunk(const HloInstruction* inst); + + // Returns a WhileThunk that invokes thunk sequences for 'condition' and + // 'body' sub-computations of while instruction 'hlo'. + std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); + + // Returns a ForThunk which executes 'loop_limit' invocations of a thunk + // sequence from the 'body' sub-computation of the while instruction 'hlo'. + std::unique_ptr BuildForThunk(const HloInstruction* hlo, + const int64 loop_limit); + + // Returns a ConditionalThunk that executes the thunk sequence for + // 'true_computation' or 'false_computation' depending on the value of the + // predicate in the given conditional instruction. + std::unique_ptr BuildConditionalThunk(const HloInstruction* hlo); + + Status Postprocess(HloInstruction* hlo) override; + + // Returns the last generated thunk. + Thunk* LastThunk() const { return thunk_sequence_->back().get(); } + + // The thunk sequence this IrEmitter generates for the input computation. + std::unique_ptr thunk_sequence_; + + // The HloComputation that this IrEmitter emits code for. + const HloComputation* hlo_computation_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 96606993696354f36e143b3b994bbe6afb902df3..c20a781a33fe89af4740ed31dd5bfb1a64473057 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -29,10 +29,10 @@ namespace xla { namespace gpu { KernelThunk::KernelThunk( - tensorflow::gtl::ArraySlice io_buffers, + tensorflow::gtl::ArraySlice args, const string& kernel_name, const HloInstruction* hlo_instruction) : Thunk(Kind::kKernel, hlo_instruction), - io_buffers_(io_buffers.begin(), io_buffers.end()), + args_(args.begin(), args.end()), kernel_name_(kernel_name) {} tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { @@ -42,7 +42,7 @@ tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { return tensorflow::Status::OK(); } - loader_spec_.reset(new se::MultiKernelLoaderSpec(io_buffers_.size() + 1)); + loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); tensorflow::StringPiece ptx = executable.ptx(); // Convert tensorflow::StringPiece to se::port::StringPiece because // StreamExecutor uses the latter. @@ -81,15 +81,16 @@ tensorflow::Status KernelThunk::ExecuteOnStream( kernel = &it->second; } + VLOG(3) << "Launching " << kernel->name(); // Launch the kernel with potentially multiple blocks and threads. static constexpr int kKernelArgsLimit = 1024; auto kernel_args = MakeUnique>(); - for (const BufferAllocation::Slice io_buffer : io_buffers_) { - kernel_args->add_device_memory_argument( - buffer_allocations.GetDeviceAddress(io_buffer)); + for (const BufferAllocation* arg : args_) { + const auto& buf = buffer_allocations.GetDeviceAddress(arg->index()); + kernel_args->add_device_memory_argument(buf); + VLOG(3) << " Arg: alloc #" << arg->index() << ": " << buf.opaque() << " (" + << buf.size() << "B)"; } - kernel_args->add_device_memory_argument( - buffer_allocations.GetTempBufferBase()); if (!stream->parent()->Launch( stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 350b5aaf360b0dad7f7b04d73f4c32bad55d3ce9..9ae455e2fcc253a7a08ff95764721048a16b0bf7 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -46,7 +46,7 @@ class KernelThunk : public Thunk { // Constructs a thunk for the given kernel. // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. - KernelThunk(tensorflow::gtl::ArraySlice io_buffers, + KernelThunk(tensorflow::gtl::ArraySlice args, const string& kernel_name, const HloInstruction* hlo_instruction); KernelThunk(const KernelThunk&) = delete; KernelThunk& operator=(const KernelThunk&) = delete; @@ -63,8 +63,8 @@ class KernelThunk : public Thunk { perftools::gputools::Stream* stream) override; private: - // The indices of the input/output buffers. - const std::vector io_buffers_; + // Buffers passed to the kernel as arguments. + const std::vector args_; // Entry kernel name for the computation. const string kernel_name_; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index cfabae791d26d0eb49826085ad7ad166a19109a1..defd281d74bd38f7da3f268e0f55970fc1af8263 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -252,7 +252,7 @@ void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) { LOG(FATAL) << "opening bitcode file for writing: " << error_code.message(); } - llvm::WriteBitcodeToFile(&module, outfile.os()); + llvm::WriteBitcodeToFile(module, outfile.os()); outfile.keep(); } diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 2923a79af0a559b08a2126162130a83801d024f8..25846dc6cd4633c7becb6e62d6bc9585348a6eac 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -27,7 +27,7 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { - CHECK_EQ(HloOpcode::kConvolution, conv.opcode()); + CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget); return window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); @@ -47,6 +47,12 @@ HloInstruction* MaybePaddedAndSlicedInput( window_util::HasBaseDilation(conv_window)) { // If padding is uneven or has dilation, we insert a kPad instruction that // applies positive padding and dilation. + // + // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of + // moving all the padding into an explicit pad op, we should keep as much + // padding inside of cudnn as possible, on the assumption that padding + // within cudnn is basically free, whereas a kPad's cost increases as the + // amount of padding increases. PaddingConfig padding_config = MakeNoPaddingConfig(input->shape().dimensions_size()); for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { @@ -167,14 +173,17 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { dim->set_window_dilation(1); } + // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract + // out the shape of conv_result. + Shape old_conv_shape = conv->shape().tuple_shapes(0); + VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = HloInstruction::CreateConvolve( - conv->shape(), new_input, new_kernel, new_conv_window, - conv->convolution_dimension_numbers()); + auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel, + new_conv_window, + conv->convolution_dimension_numbers()); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); - TF_CHECK_OK( - conv->parent()->ReplaceWithNewInstruction(conv, std::move(new_conv))); + TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); return true; } @@ -190,6 +199,8 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) { bool PadInsertion::CanonicalizeBackwardFilterConvolution( HloInstruction* backward_conv) { + CHECK_EQ(backward_conv->custom_call_target(), + kCudnnConvBackwardFilterCallTarget); if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; } @@ -202,15 +213,11 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // ABCD0 = Pad(ABCD, padding_high=1) // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) // We choose the lesser of padding_low and padding_high as the new padding. - HloInstruction* forward_conv = backward_conv->fused_expression_root(); HloInstruction* input = backward_conv->mutable_operand(0); - Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); // input_padding_config is the config of the kPad to be inserted. PaddingConfig input_padding_config = MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); - ConvolutionDimensionNumbers forward_conv_dnums = - forward_conv->convolution_dimension_numbers(); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { @@ -222,11 +229,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // cuDNN convolution (which doesn't support negative padding) to fail. return false; } - // If the backward convolution has uneven padding on the activations, we - // move some padding on the larger end to "internal" padding, so that the - // backward convolution produces larger weight gradients which get sliced - // later. Therefore, the amount of new padding (low or high) is the minimum - // of the amount of old padding low and old padding high. + // Compute the new, even padding for the backward conv operation. int64 new_conv_padding = std::min(padding_low, padding_high); int64 dim = backward_conv_dnums.input_spatial_dimensions(i); input_padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -237,14 +240,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Since we move some padding from the backward convolution to the kPad, we // need to accordingly reduce the padding amount of the backward convolution // and its inner forward convolution. - IncreasePaddingLowBy(-(padding_low - new_conv_padding), - new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(-(padding_high - new_conv_padding), - new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingLowBy(-(padding_low - new_conv_padding), - new_forward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(-(padding_high - new_conv_padding), - new_forward_conv_window.mutable_dimensions(i)); + auto* new_dim = new_backward_conv_window.mutable_dimensions(i); + new_dim->set_padding_low(new_conv_padding); + new_dim->set_padding_high(new_conv_padding); } // Create a new backward convolution replacing the old one. @@ -260,19 +258,12 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( .ConsumeValueOrDie(), input, padding, input_padding_config)); - HloInstruction* new_forward_conv = - computation->AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - padded_input->shape(), output->shape(), new_forward_conv_window, - forward_conv_dnums) - .ConsumeValueOrDie(), - padded_input, output, new_forward_conv_window, forward_conv_dnums)); - - // Fuse the new forward convolution to the new backward convolution. - HloInstruction* new_backward_conv = - computation->CreateFusionInstructionForBackwardConvolution( - {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter, - new_backward_conv_window, backward_conv_dnums); + // The shape of the backward_conv CustomCall is a tuple (conv_result, + // scratch_buffer). Extract out the shape of conv_result. + Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); + HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( + backward_conv_shape, padded_input, output, new_backward_conv_window, + backward_conv_dnums); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -289,14 +280,15 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return false; } - HloInstruction* forward_conv = backward_conv->fused_expression_root(); - HloInstruction* reverse_filter = forward_conv->mutable_operand(1); - Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); - ConvolutionDimensionNumbers forward_conv_dnums = - forward_conv->convolution_dimension_numbers(); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); + + // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory). + // Get the shape of conv_result. + Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); + + Shape new_backward_conv_shape = backward_conv_shape; for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64 padding_low = backward_conv->window().dimensions(i).padding_low(); int64 padding_high = backward_conv->window().dimensions(i).padding_high(); @@ -315,41 +307,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( // where the amount of padding low is larger, we can canonicalize it to // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1)) // [A] = Slice([B A]) - // For consistency, we need to increase the low padding of the inner - // convolution by 1 as well because the input is larger now. if (padding_low > padding_high) { IncreasePaddingLowBy(padding_high - padding_low, new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingLowBy(padding_low - padding_high, - new_forward_conv_window.mutable_dimensions(i)); } else if (padding_low < padding_high) { IncreasePaddingHighBy(padding_low - padding_high, new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(padding_high - padding_low, - new_forward_conv_window.mutable_dimensions(i)); } + // Decreasing the padding by X *increases* the size of our output by X. + int64 dim = backward_conv_dnums.output_spatial_dimensions(i); + new_backward_conv_shape.set_dimensions( + dim, new_backward_conv_shape.dimensions(dim) + + std::abs(padding_low - padding_high)); } // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(0); HloInstruction* filter = backward_conv->mutable_operand(1); - HloInstruction* new_reverse_filter = - computation->AddInstruction(HloInstruction::CreateReverse( - filter->shape(), filter, reverse_filter->dimensions())); - HloInstruction* new_forward_conv = - computation->AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - output->shape(), new_reverse_filter->shape(), - new_forward_conv_window, forward_conv_dnums) - .ConsumeValueOrDie(), - output, new_reverse_filter, new_forward_conv_window, - forward_conv_dnums)); + + HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( + new_backward_conv_shape, output, filter, new_backward_conv_window, + backward_conv_dnums); + + // The CustomCall created above returns a tuple (conv_result, scratch_memory). + // Extract out the two elements. HloInstruction* new_backward_conv = - computation->CreateFusionInstructionForBackwardConvolution( - {new_forward_conv, new_reverse_filter}, - HloInstruction::FusionKind::kConvBackwardInput, - new_backward_conv_window, backward_conv_dnums); + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_backward_conv_shape, new_backward_conv_call, 0)); + HloInstruction* new_backward_conv_scratch = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_backward_conv_call->shape().tuple_shapes(1), + new_backward_conv_call, 1)); // Slice the new backward convolution. // @@ -377,22 +366,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( } // Replace the old backward convolution with the slice. - CHECK(ShapeUtil::Compatible( + Shape slice_shape = ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices, limit_indices, strides) - .ConsumeValueOrDie(), - backward_conv->shape())); + .ConsumeValueOrDie(); + CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape)) + << ShapeUtil::HumanString(slice_shape) << " vs " + << ShapeUtil::HumanString(backward_conv_shape); - auto slice = - HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv, - start_indices, limit_indices, strides); + HloInstruction* slice = computation->AddInstruction( + HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv, + start_indices, limit_indices, strides)); + HloInstruction* new_tuple = computation->AddInstruction( + HloInstruction::CreateTuple({slice, new_backward_conv_scratch})); VLOG(1) << "Canonicalizing backward input conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " - << slice->ToString(); + << new_tuple->ToString(); - TF_CHECK_OK( - computation->ReplaceWithNewInstruction(backward_conv, std::move(slice))); + TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple)); return true; } @@ -400,18 +392,17 @@ StatusOr PadInsertion::Run(HloModule* module) { bool changed = false; for (HloInstruction* instruction : module->entry_computation()->MakeInstructionPostOrder()) { - if (instruction->opcode() == HloOpcode::kConvolution) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (instruction->opcode() == HloOpcode::kFusion) { - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - changed |= CanonicalizeBackwardFilterConvolution(instruction); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - changed |= CanonicalizeBackwardInputConvolution(instruction); - break; - default: - break; + if (IsCustomCallToDnnConvolution(*instruction)) { + const auto& target = instruction->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + changed |= CanonicalizeBackwardFilterConvolution(instruction); + } else if (target == kCudnnConvBackwardInputCallTarget) { + changed |= CanonicalizeBackwardInputConvolution(instruction); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instruction->ToString(); } } } diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 934e7e1919f08a16daf09ec634e2f9dc0c7cc723..8ed63a854a74fc06c3c389f40fe1f5970885deac 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -42,6 +42,11 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder); + // Constructs a loop emitter for a loop that generates on element of each of N + // arrays on each iteration. + // + // This is used in multi-output fusion. target_element_generator should + // produce a struct with N elements, one for each of target_arrays. ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 625c3f8bea418b7942145a05ba42b9ea9b14543b..2c3032d79be221e8cacb178ffb1817459b603cc0 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -41,6 +41,7 @@ class GpuExecutable; class Thunk { public: enum class Kind { + kConditional, kConvolution, kCopy, kCudnnBatchNormBackward, diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 34e2f7ee206c6a74073d8f4e867e862feb4aff49..a2d13c013c56059148ccd04dba2137a5b2badc42 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -64,10 +64,8 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, - const FlatSet* buffers_to_assign) { - HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign, - &module_sequence); + const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); const HloComputation* entry_computation = module.entry_computation(); const std::vector& instruction_sequence = FindOrDie(module_sequence, entry_computation); @@ -81,9 +79,8 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, - const FlatSet* buffers_to_assign) { - HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign, + const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + HeapSimulator heap(std::move(algorithm), size_fn, options, /*module_sequence=*/nullptr); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); @@ -199,15 +196,17 @@ Status HeapSimulator::RunComputation( // We can only share with the operand buffer if it is about to be freed; // we must be the last user of the buffer. bool shared = false; - for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { - if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && - buffer->instruction()->opcode() != HloOpcode::kCopy && - CanShareOperandBufferWithUser( - operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { - ShareBuffer(buffer, operand_buffer, instruction); - shared = true; - break; + if (options_.may_reuse_operand_buffers) { + for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { + if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && + buffer->instruction()->opcode() != HloOpcode::kCopy && + CanShareOperandBufferWithUser( + operand_buffer->instruction(), operand_buffer->index(), + buffer->instruction(), buffer->index(), points_to_analysis)) { + ShareBuffer(buffer, operand_buffer, instruction); + shared = true; + break; + } } } @@ -226,6 +225,7 @@ Status HeapSimulator::RunComputation( // sub-computations will never be run concurrently. if (module_sequence_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { @@ -266,13 +266,12 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, - const FlatSet* buffers_to_assign, + const LogicalBuffer::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), - buffers_to_assign_(buffers_to_assign), + options_(options), module_sequence_(module_sequence) { debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); } @@ -280,13 +279,16 @@ HeapSimulator::HeapSimulator( HeapSimulator::~HeapSimulator() {} bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { - // Buffers for constants are ignored, as with BufferAssigner. Also ignore - // buffers that we're not meant to assign. + // Buffers for constants are ignored unless the alloc_constants option is + // set. Also ignore buffers that we're not meant to assign. // // TODO(b/32248867): For consistency, constants should get allocations. - return buffer->instruction()->opcode() == HloOpcode::kConstant || - (buffers_to_assign_ != nullptr && - buffers_to_assign_->count(buffer) == 0); + if (!options_.alloc_constants && + buffer->instruction()->opcode() == HloOpcode::kConstant) { + return true; + } + return options_.buffers_to_assign != nullptr && + options_.buffers_to_assign->count(buffer) == 0; } // Alloc always calls the underlying heap algorithm. @@ -400,8 +402,8 @@ HeapSimulator::Result HeapSimulator::Finish() { } // If we were told to assign specific buffers, make sure we've assigned // exactly that many buffers. - if (buffers_to_assign_ != nullptr) { - CHECK_EQ(buffers_to_assign_->size(), result.chunk_map.size()); + if (options_.buffers_to_assign != nullptr) { + CHECK_EQ(options_.buffers_to_assign->size(), result.chunk_map.size()); } } diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 88a8698d16132372fc8f4e87eba3b99125aab876..636f19dd39f09721bd82fc4b44785f196f281ad7 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -67,6 +67,23 @@ class HeapSimulator { HeapSimulatorTrace debug_trace; }; + // The different options to be passed to the Run() APIs. + struct Options { + Options() + : may_reuse_operand_buffers(true), + alloc_constants(false), + buffers_to_assign(nullptr) {} + + // Whether a buffer about to be Free()-ed, can be recycled for a new born + // one, hence collapsing Free()+Alloc() calls (default true). + bool may_reuse_operand_buffers; + // Whether to issue Alloc() and Free() calls for constants (default false). + bool alloc_constants; + // If 'buffers_to_assign' is provided, only those buffers are assigned + // offsets, otherwise all buffers defined by the instructions are assigned. + const tensorflow::gtl::FlatSet* buffers_to_assign; + }; + // Run the heap simulation with the given algorithm, assuming the given // module_sequence, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid @@ -76,15 +93,12 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - // If 'buffers_to_assign' is provided, only those buffers are assigned - // offsets, otherwise all buffers defined by the instructions are assigned. static StatusOr Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, - const tensorflow::gtl::FlatSet* buffers_to_assign = - nullptr); + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions @@ -96,8 +110,7 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, - const tensorflow::gtl::FlatSet* buffers_to_assign = - nullptr); + const Options& options = Options()); private: // If 'module_sequence' is non-null, it is used to find kCall and kWhile @@ -105,8 +118,7 @@ class HeapSimulator { // be run recursively. I.e. the simulation is run over the whole module. HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, - const tensorflow::gtl::FlatSet* buffers_to_assign, + const LogicalBuffer::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence); ~HeapSimulator(); @@ -130,7 +142,7 @@ class HeapSimulator { const std::unique_ptr no_fragmentation_stats_; const std::unique_ptr algorithm_; const LogicalBuffer::SizeFunction size_fn_; - const tensorflow::gtl::FlatSet* buffers_to_assign_; + const Options options_; const SequentialHloOrdering::HloModuleSequence* module_sequence_; // In addition to Alloc and Free, the heap simulator exposes a concept of diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 0e9a852788e978f79fa6f6c802f855a4c476583f..a43785b4a9701369ae315f67d4d64d03dc6c081d 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -129,6 +129,10 @@ message HloInstructionProto { // FFT length. repeated int64 fft_length = 32; + + // Gather dimension numbers. + xla.GatherDimensionNumbers gather_dimension_numbers = 33; + repeated int64 gather_window_bounds = 34; } // Serialization of HloComputation. @@ -200,6 +204,7 @@ message BufferAllocationProto { bool is_reusable = 4; bool is_entry_computation_parameter = 5; int64 parameter_number = 6; + repeated int64 parameter_shape_index = 10; bool maybe_live_out = 7; int64 color = 8; repeated Assigned assigned = 9; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 6d2a3aa5b531650a658502531e050702ffbd3760..30e32a46d7dd0923f738939c33407ac7484b5bbe 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -419,7 +419,7 @@ StatusOr> HloAliasAnalysis::Run( auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); TF_ASSIGN_OR_RETURN( alias_analysis->dataflow_analysis_, - HloDataflowAnalysis::Run(module, /*ssa_form=*/true, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, /*bitcast_defines_value=*/false)); BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index a63affa06caf75f1ccab084bd114e39ba7c91a38..21e6b2ca730f6347af902097e6496826b861e8a3 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -461,20 +461,6 @@ HloInstruction* HloComputation::CreateFusionInstruction( return fusion_instruction; } -HloInstruction* HloComputation::CreateFusionInstructionForBackwardConvolution( - tensorflow::gtl::ArraySlice instructions_to_fuse, - HloInstruction::FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums) { - CHECK(HloInstruction::FusionKind::kConvBackwardFilter == fusion_kind || - HloInstruction::FusionKind::kConvBackwardInput == fusion_kind); - HloInstruction* root = instructions_to_fuse.front(); - HloInstruction* fusion_instruction = - AddInstruction(HloInstruction::CreateFusionForBackwardConvolution( - root->shape(), fusion_kind, window, conv_dnums, root)); - FuseInstructionsInto(instructions_to_fuse, fusion_instruction); - return fusion_instruction; -} - StatusOr HloComputation::DeepCopyHelper( HloInstruction* instruction, const ShapeTree* indices_to_copy, ShapeTree* copies_added, ShapeIndex* index) { @@ -523,13 +509,14 @@ StatusOr HloComputation::DeepCopyInstruction( "Can't deep copy instruction %s: instruction is not in computation %s", instruction->name().c_str(), name().c_str()); } - if (indices_to_copy != nullptr && !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { return FailedPrecondition( "Can't deep copy instruction %s: given shape tree of indices to copy " - "has incompatible shape", - instruction->name().c_str()); + "has incompatible shapes: %s vs. %s", + instruction->name().c_str(), + ShapeUtil::HumanString(instruction->shape()).c_str(), + ShapeUtil::HumanString(indices_to_copy->shape()).c_str()); } ShapeIndex index; @@ -577,8 +564,11 @@ Status HloComputation::ReplaceWithNewInstruction( Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, HloInstruction* new_instruction) { - TF_RET_CHECK(ShapeUtil::Compatible(old_instruction->shape(), - new_instruction->shape())); + TF_RET_CHECK( + ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape())) + << ShapeUtil::HumanString(old_instruction->shape()) << " vs " + << ShapeUtil::HumanString(new_instruction->shape()); + VLOG(10) << "transformed " << old_instruction->ToString() << " to " << new_instruction->ToString(); // Try to add metadata for HLO instructions that are created to replace diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 6436815f910405477ec21a33dec75ef71df08602..39d864efcb70382b6f8e631d7e6e452ea6410104 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -77,6 +77,14 @@ class HloComputation { return last_added_instruction_; } + Status ForEachInstruction( + const std::function& func) const { + for (const auto& instruction : instructions_) { + TF_RETURN_IF_ERROR(func(instruction.get())); + } + return Status::OK(); + } + private: const string name_; HloInstruction* last_added_instruction_; @@ -224,15 +232,6 @@ class HloComputation { tensorflow::gtl::ArraySlice instructions_to_fuse, HloInstruction::FusionKind fusion_kind); - // Creates a fusion instruction that represents a backward convolution. This - // is similar to CreateFusionInstruction but takes window and conv_dnums which - // indicate the window and convolution dimension numbers of the backward - // convolution. - HloInstruction* CreateFusionInstructionForBackwardConvolution( - tensorflow::gtl::ArraySlice instructions_to_fuse, - HloInstruction::FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums); - // Create a deep copy of the given instruction and return the instruction // producing the copied result. All instructions performing the copy are added // to the computation. For array-shaped values, this method trivially returns diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index cd54eb74d18d0be714b5b56fc8ae0dfa55ff31a0..4ec2ef27bf59b0c877ec38e55ef5c12debeec227 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -229,6 +229,10 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, @@ -469,7 +473,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) { } Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { - return Unimplemented("Custom-call is not implemented for HLO cost analysis."); + // We can't do anything sane with CustomCalls, since we don't know what they + // do, and returning an error status will stop iteration over this + // computation, which is probably also not what we want. So just punt and + // return OK. This will cause all of the properties to be reported as 0, + // which is fine. + current_should_compute_bottleneck_time_ = false; + return Status::OK(); } Status HloCostAnalysis::HandleSort(const HloInstruction* sort) { @@ -523,6 +533,11 @@ Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { return Status::OK(); } +Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { + // Gather does not issue any flops. + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index e5783539e5436f09fa58bf7889118380ee90fea0..d17678d20f2a23fd98d18b77d5fb25853901a789 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; + Status HandleHostCompute(const HloInstruction* host_compute) override; Status HandleRng(const HloInstruction* random) override; Status HandleReverse(const HloInstruction* reverse) override; Status HandleSort(const HloInstruction* sort) override; @@ -99,6 +100,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; + Status HandleGather(const HloInstruction* gather) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 7feda2b3b040de1f0a14303ce1adcd21c6624c8b..279edd4ba8772a9c576f76f554de8ec68631b953 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -119,9 +119,8 @@ StatusOr HloCSE::Run(HloModule* module) { equivalent_instructions; for (HloInstruction* user : operand->users()) { if (user != instruction && - user->Identical(*instruction, eq_instructions, eq_computations) && - (!is_layout_sensitive_ || - ShapeUtil::Equal(user->shape(), instruction->shape()))) { + user->Identical(*instruction, eq_instructions, eq_computations, + is_layout_sensitive_)) { equivalent_instructions.push_back(user); } } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index d25fc5d7418ae40c7167f88d6172906482a58925..934e43ba4879628362009267c671ec4cb0d79c52 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -38,12 +38,12 @@ namespace xla { using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form, +HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value) : module_(module), ssa_form_(ssa_form), bitcast_defines_value_(bitcast_defines_value), - call_graph_(CallGraph::Build(module)) {} + call_graph_(CallGraph::Build(&module)) {} bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index) const { @@ -115,9 +115,9 @@ void HloDataflowAnalysis::DeleteMarkedValues() { } string HloDataflowAnalysis::ToString() const { - string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n"); + string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n"); StrAppend(&out, " Instruction value sets:\n"); - for (const HloComputation* computation : module_->computations()) { + for (const HloComputation* computation : module_.computations()) { for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { @@ -585,16 +585,23 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( void HloDataflowAnalysis::Propagate() { std::queue worklist; + tensorflow::gtl::FlatSet workset; + auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { + if (workset.insert(instruction).second) { + worklist.push(instruction); + } + }; - for (HloComputation* computation : module_->computations()) { + for (HloComputation* computation : module_.computations()) { for (HloInstruction* instruction : computation->instructions()) { - worklist.push(instruction); + add_to_worklist(instruction); } } while (!worklist.empty()) { HloInstruction* instruction = worklist.front(); worklist.pop(); + workset.erase(workset.find(instruction)); VLOG(3) << "Worklist top: " << instruction->name(); VLOG(3) << ToString(); @@ -608,9 +615,10 @@ void HloDataflowAnalysis::Propagate() { VLOG(4) << "New value set for " << instruction->name() << ": " << GetInstructionValueSet(instruction); - // Instruction value was updated. Add users to work list. + // Instruction value was updated. Add users to work list if we haven't + // already. for (HloInstruction* user : instruction->users()) { - worklist.push(user); + add_to_worklist(user); // If user sequentially calls a computation, then the respective // parameter(s) of the computation need to be updated. @@ -625,10 +633,10 @@ void HloDataflowAnalysis::Propagate() { // Note that the same instruction can be used in both operand 1 and // operand 2. if (user->operand(1) == instruction) { - worklist.push(user->true_computation()->parameter_instruction(0)); + add_to_worklist(user->true_computation()->parameter_instruction(0)); } if (user->operand(2) == instruction) { - worklist.push(user->false_computation()->parameter_instruction(0)); + add_to_worklist(user->false_computation()->parameter_instruction(0)); } } else { for (HloComputation* called_computation : user->called_computations()) { @@ -636,7 +644,7 @@ void HloDataflowAnalysis::Propagate() { call_graph_->GetNode(called_computation); if (call_graph_node.context() == CallContext::kSequential) { for (int64 operand_number : user->OperandIndices(instruction)) { - worklist.push( + add_to_worklist( called_computation->parameter_instruction(operand_number)); } } @@ -652,13 +660,13 @@ void HloDataflowAnalysis::Propagate() { for (const CallSite& callsite : call_graph_node.caller_callsites()) { if ((callsite.instruction()->opcode() == HloOpcode::kCall) || (callsite.instruction()->opcode() == HloOpcode::kConditional)) { - worklist.push(callsite.instruction()); + add_to_worklist(callsite.instruction()); } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { // Add the while itself, and the body and condition parameters. - worklist.push(callsite.instruction()); - worklist.push( + add_to_worklist(callsite.instruction()); + add_to_worklist( callsite.instruction()->while_body()->parameter_instruction(0)); - worklist.push( + add_to_worklist( callsite.instruction()->while_condition()->parameter_instruction( 0)); } @@ -678,7 +686,7 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( } Status HloDataflowAnalysis::InitializeInstructionValueSets() { - for (const HloComputation* computation : module_->computations()) { + for (const HloComputation* computation : module_.computations()) { const CallGraphNode& call_graph_node = call_graph_->GetNode(computation); for (HloInstruction* instruction : computation->instructions()) { // Create an empty shape tree. @@ -779,9 +787,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { /* static */ StatusOr> HloDataflowAnalysis::Run( - HloModule* module, bool ssa_form, bool bitcast_defines_value) { - VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name(); - XLA_VLOG_LINES(2, module->ToString()); + const HloModule& module, bool ssa_form, bool bitcast_defines_value) { + VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); + XLA_VLOG_LINES(2, module.ToString()); auto dataflow_analysis = WrapUnique( new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); @@ -798,7 +806,7 @@ StatusOr> HloDataflowAnalysis::Run( // lookup is faster. std::vector> value_positions( dataflow_analysis->next_value_id_); - for (const HloComputation* computation : module->computations()) { + for (const HloComputation* computation : module.computations()) { for (HloInstruction* instruction : computation->instructions()) { for (const auto& pair : dataflow_analysis->GetInstructionValueSet(instruction)) { @@ -850,7 +858,7 @@ Status HloDataflowAnalysis::Verify() const { // For each value in each value set, verify that the value set's position // appears in the value's positions(). - for (const auto& computation : module_->computations()) { + for (const auto& computation : module_.computations()) { for (const auto& instruction : computation->instructions()) { for (const auto& pair : GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 89d318188f0855c7924836a51cfe98d531e08cb4..7b8a74b096ff48733717e78ada5bb56a28caed72 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -60,7 +60,7 @@ class HloDataflowAnalysis { // a new HLO value in the analysis. If false then Bitcast forwards the // value of its operand. static StatusOr> Run( - HloModule* module, bool ssa_form = false, + const HloModule& module, bool ssa_form = false, bool bitcast_defines_value = false); // Returns true if 'instruction' defines an HLO value at the given shape index @@ -119,7 +119,7 @@ class HloDataflowAnalysis { string ToString() const; protected: - HloDataflowAnalysis(HloModule* module, bool ssa_form, + HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false); // Returns a new HloValue defined at the given instruction and shape index. @@ -180,7 +180,7 @@ class HloDataflowAnalysis { // Verify various invariants of the dataflow analysis. Status Verify() const; - HloModule* const module_; + const HloModule& module_; const bool ssa_form_; const bool bitcast_defines_value_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index e714b2567fd1b3eab607a19f0bb7e3288150dc64..7bf3a1a06045c79621d75b653bf42220705a69d4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -50,7 +50,7 @@ class HloDataflowAnalysisTest : public HloTestBase, bool bitcast_defines_value = false) { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis"); analysis_ = - HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value) + HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); return *analysis_; } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e3f5c17e35f5294e204993af9396dec326a779cd..296f010a920a801ef0a4dc5e40bf0dbc07898196 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -34,12 +34,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -793,6 +792,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[shr], ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + // If shift amount is greater than the number of bits, then return 0. + if (rhs_elem >= sizeof(UnsignedT) * CHAR_BIT) { + return static_cast(0); + } return static_cast(static_cast(lhs_elem) >> rhs_elem); })); @@ -1027,55 +1030,119 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { CHECK(ShapeUtil::IsArray(lhs->shape())); CHECK(ShapeUtil::IsArray(rhs->shape())); - // Dot only supports operands of rank 1 and 2. - const auto dot_rank = ShapeUtil::Rank(dot->shape()); + const auto& dnums = dot->dot_dimension_numbers(); + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); - CHECK(lhs_rank > 0 && lhs_rank <= 2); - CHECK(rhs_rank > 0 && rhs_rank <= 2); - CHECK_EQ(dot_rank, lhs_rank + rhs_rank - 2); CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); - // Check contracted dimensions are the same. - // - // Determine the index of the contracted dimensions for input tensors. - // dimensions -1 of lhs and dimension 0 of rhs are contracted. - const int64 lhs_contracted_dimension = - ShapeUtil::GetDimensionNumber(lhs->shape(), -1); - const int64 rhs_contracted_dimension = 0; - CHECK_EQ(lhs->shape().dimensions(lhs_contracted_dimension), - rhs->shape().dimensions(rhs_contracted_dimension)) + // There must be 1 and only 1 Contracting dimension for lhs and rhs. + CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); + const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + // Contracted dimension sizes must be the same. + CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), + rhs->shape().dimensions(rhs_contracting_dimension)) << "lhs contracted dimension: " - << lhs->shape().dimensions(lhs_contracted_dimension) + << lhs->shape().dimensions(lhs_contracting_dimension) << " rhs contracted dimension: " - << rhs->shape().dimensions(rhs_contracted_dimension); + << rhs->shape().dimensions(rhs_contracting_dimension); const int64 contracted_dimension_size = - lhs->shape().dimensions(lhs_contracted_dimension); + lhs->shape().dimensions(lhs_contracting_dimension); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); auto result = Literal::CreateFromShape(dot->shape()); + + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + + std::vector lhs_non_contracting_dims; + for (int64 i = 0; i < lhs_rank; i++) { + if (i != lhs_contracting_dimension) { + lhs_non_contracting_dims.push_back(i); + } + } + + std::vector rhs_non_batch_non_contracting_dims; + tensorflow::gtl::FlatSet batch_dims_set( + dnums.rhs_batch_dimensions().begin(), + dnums.rhs_batch_dimensions().end()); + for (int64 i = 0; i < rhs_rank; i++) { + if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { + rhs_non_batch_non_contracting_dims.push_back(i); + } + } + + const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); + const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); + + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + [&](tensorflow::gtl::ArraySlice result_index) { ElementwiseT result_val = static_cast(0); - std::vector lhs_index(lhs_rank, 0); - std::vector rhs_index(rhs_rank, 0); - // Set index for non-contracted dimension for lhs and rhs. - if (lhs_rank > 1) { - lhs_index[0] = multi_index[0]; + // Find the corresponding non-contracting indices for lhs and rhs. + // + // For `result_index`, its batch dimension, if exists, will be at the + // same dimension as the batch dimension of lhs and rhs. More + // specifically: + // - For lhs, the non-contracting dimensions, including the batch + // dimension have the same index as the `result_index`. + // - For rhs, the batch dimension is set seperately from other + // non-contracting dimensions, since these other non-contracting + // dimensions in rhs follow the non-contracting dimensions of lhs in + // the resulting index. + // + // As an example, for a resulting index: + // result_index [result_batch, result_x, result_y] + // the effecting lhs and rhs indices are: + // lhs [result_batch, lhs_non_contracting_dim, contracting_dim + // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] + // `result_x` is only affected by the lhs_non_contracting_dim and + // likewise `result_y` only depends on rhs_non_contracting_dim. + // + // so we can look up the lhs and rhs indices by: + // + // lhs: + // batch index is the same as `result_batch`. + // non-contracting dimension is the same as + // result_index[lhs_non_contracting_dim] + // rhs: + // batch index: the same as `result_batch`. + // non-contracting dimension index: *not* the same as + // result_index[rhs_non_contractng_dim], since the + // non-contracting dimensions of lhs are included in the + // result_index first. Instead, the non_contracting_dim of rhs must + // be calculated as following: + // lhs_non_contracting_dimensions_size + + // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 + // + // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is + // the index offset to the result_index that only depends on + // the non_batch and non-contracting dimensions of rhs. -1 at the + // end translates size to index. + for (auto i : lhs_non_contracting_dims) { + lhs_index[i] = result_index[i]; + } + for (auto i : dnums.rhs_batch_dimensions()) { + rhs_index[i] = result_index[i]; } - if (rhs_rank > 1) { - rhs_index[1] = multi_index[multi_index.size() - 1]; + for (auto i : rhs_non_batch_non_contracting_dims) { + const int64 rhs_non_batch_non_contracting_dim = + lhs_non_contracting_size + (i - batch_dim_size) - 1; + rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; } // Accumulates resulting product along the contracted dimension. for (int64 i = 0; i < contracted_dimension_size; ++i) { - lhs_index[lhs_contracted_dimension] = i; - rhs_index[rhs_contracted_dimension] = i; + lhs_index[lhs_contracting_dimension] = i; + rhs_index[rhs_contracting_dimension] = i; result_val += static_cast(lhs_literal.Get(lhs_index)) * @@ -1338,6 +1405,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break; } + case F16: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], + MapImpl(map)); + break; + } case F32: { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break; @@ -1706,6 +1778,115 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleCos(cos); } + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[reduce_precision], + ElementWiseUnaryOp(reduce_precision, [reduce_precision]( + ElementwiseT elem) { + uint32_t value_as_int = tensorflow::bit_cast(elem); + const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); + const uint32_t exponent_bits = reduce_precision->exponent_bits(); + + // Code is based on the CPU/GPU implementation in LLVM-emitting code. + // + // Bits in float type: + // mantissa : bits [0:22] + // exponent : bits [23:30] + // sign : bits [31] + if (mantissa_bits < 23) { + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. + // This is equal to a base value of 0111... plus one bit if the last + // remaining mantissa bit is 1. + const uint32_t base_rounding_bias = + (last_mantissa_bit_mask >> 1) - 1; + const uint32_t x_last_mantissa_bit = + (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); + const uint32_t x_rounding_bias = + x_last_mantissa_bit + base_rounding_bias; + + // Add rounding bias, and mask out truncated bits. Note that the + // case where adding the rounding bias overflows into the exponent + // bits is correct; the non-masked mantissa bits will all be zero, + // and the exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + value_as_int = value_as_int + x_rounding_bias; + value_as_int = value_as_int & truncation_mask; + } + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the + // most- significant bit -- is equal to 1.0f for all exponent sizes. + // Adding 2^(n-1)-1 to this gives us the highest non-infinite + // exponent for a bit- size of n, and subtracting 2^(n-1)-1 from + // this gives us the lowest' exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n + // is (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = + (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; + const bool x_overflows = x_exponent > (reduced_max_exponent << 23); + const bool x_underflows = + x_exponent <= (reduced_min_exponent << 23); + + // Compute appropriately-signed values of zero and infinity. + const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; + const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; + + // Force to zero or infinity if overflow or underflow. (Note that + // this truncates all denormal values to zero, rather than rounding + // them.) + value_as_int = x_overflows ? x_signed_inf : value_as_int; + value_as_int = x_underflows ? x_signed_zero : value_as_int; + } + + float reduced_result = tensorflow::bit_cast(value_as_int); + if (std::isnan(elem)) { + reduced_result = mantissa_bits > 0 + ? elem + : std::numeric_limits::infinity(); + } + return reduced_result; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Double not supported for reduce precision"); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Unsupported type for reduce precision"); + } + + Status HandleReducePrecision(HloInstruction* reduce_precision) override { + return HandleReducePrecision(reduce_precision); + } + private: template StatusOr> DynamicSlice( @@ -1867,9 +2048,7 @@ HloEvaluator::HloEvaluator() { }); typed_visitors_[S32] = MakeUnique>(this); typed_visitors_[S64] = MakeUnique>(this); - typed_visitors_[F16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: F16."); - }); + typed_visitors_[F16] = MakeUnique>(this); typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); typed_visitors_[C64] = MakeUnique>(this); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index f7c6435002d278d93cc0814041a7e055e5573e3e..2861fec39ef0c92fdfbcee04584f9bd36d3cb4d8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -940,6 +940,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kConcatenate: case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: case HloOpcode::kPad: case HloOpcode::kReshape: case HloOpcode::kReverse: @@ -988,6 +989,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kHostCompute: case HloOpcode::kWhile: return kDarkGreen; case HloOpcode::kConstant: @@ -1063,14 +1065,19 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { // node -- there the shape and layout is present in the output node. if (instr->opcode() != HloOpcode::kFusion || !ShouldShowFusionSubcomputation(instr)) { - string instr_shape = ShapeUtil::HumanString(instr->shape()); - - // Show layout of non-tuple shapes with more than one dimension. - if (LayoutUtil::HasLayout(instr->shape()) && - instr->shape().dimensions_size() > 1 && - !ShapeUtil::IsTuple(instr->shape())) { - StrAppend(&instr_shape, "{", - Join(LayoutUtil::MinorToMajor(instr->shape()), ","), "}"); + // Show layout of instructions with more than one dimension. Don't show + // layout on tuples or tensors with just one dimension (which only have one + // possible layout) to avoid visual noise. + bool shape_is_multidim = false; + ShapeUtil::ForEachSubshape(instr->shape(), + [&](const Shape& s, const ShapeIndex&) { + shape_is_multidim |= s.dimensions_size() > 1; + }); + string instr_shape; + if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) { + instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape()); + } else { + instr_shape = ShapeUtil::HumanString(instr->shape()); } // Some instructions have giant tuples as their shapes, so truncate the @@ -1421,9 +1428,11 @@ void DumpText(const HloModule& module, const string& label, string MaybeDumpHloModule(const HloModule& module, const string& label, const HloExecutionProfile* profile) { - VLOG(2) << "MaybeDumpHloModule called on module " << module.name(); - string graph_url; const DebugOptions& debug_options = module.config().debug_options(); + VLOG(2) << "MaybeDumpHloModule called on module " << module.name() + << " with generate_hlo_graph regex \"" + << debug_options.xla_generate_hlo_graph() << "\""; + string graph_url; if (!debug_options.xla_generate_hlo_graph().empty() && RE2::PartialMatch(module.name(), debug_options.xla_generate_hlo_graph())) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a889c35aeb297bd118c40ced2dd9539957dce67a..b7dd055d7cd78eb759a2b24bcbbbc948159f9425 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -763,16 +763,13 @@ HloInstruction::CreateBroadcastSequence( return instruction; } -// We put the fusion kind into the instruction's name for transpose-dot and -// backward-conv fusions, since those fusions are really just describing a type -// of dot/conv rather than generating a novel computation. +// We put the fusion kind into the instruction's name for transpose-dot fusions, +// since those fusions are really just describing a type of dot rather than +// generating a novel computation. static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { switch (fusion_kind) { case HloInstruction::FusionKind::kTransposeDot: return "dot_fusion"; - case HloInstruction::FusionKind::kConvBackwardInput: - case HloInstruction::FusionKind::kConvBackwardFilter: - return "conv_fusion"; default: return "fusion"; } @@ -804,16 +801,20 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { return instruction; } -/* static */ std::unique_ptr -HloInstruction::CreateFusionForBackwardConvolution( - const Shape& shape, FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* fused_root) { - std::unique_ptr fusion = - CreateFusion(shape, fusion_kind, fused_root); - fusion->window_ = MakeUnique(window); - fusion->convolution_dimension_numbers_ = - MakeUnique(conv_dnums); - return fusion; +HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { + CHECK_EQ(opcode(), HloOpcode::kFusion); + CHECK_EQ(operand_count(), + fused_instructions_computation()->parameter_instructions().size()); + const int64 param_no = operand_count(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. + string param_name = StrCat(new_operand->name(), ".param_", param_no); + HloInstruction* fused_parameter = + fused_instructions_computation()->AddParameter( + HloInstruction::CreateParameter(param_no, new_operand->shape(), + param_name)); + AppendOperand(new_operand); + return fused_parameter; } void HloInstruction::MergeFusionInstruction( @@ -1008,13 +1009,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // Clone's operand was not already an operand of the fusion // instruction. Add it as an operand and add a corresponding fused // parameter instruction. - int64 param_no = fused_parameters.size(); - // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. - string param_name = StrCat(operand->name(), ".param_", param_no); - fused_param = fused_instructions_computation()->AddParameter( - CreateParameter(param_no, operand->shape(), param_name)); - AppendOperand(operand); + fused_param = AddFusionOperand(operand); } TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); } @@ -1099,6 +1094,7 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: + case HloOpcode::kHostCompute: return true; default: { // Check if any of the called computations has a side effect. @@ -1136,6 +1132,19 @@ bool HloInstruction::HasSideEffect() const { return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateHostCompute( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { + std::unique_ptr instruction = + WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->channel_name_ = channel_name.ToString(); + instruction->cost_estimate_ns_ = cost_estimate_ns; + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateTuple( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; @@ -1146,6 +1155,38 @@ bool HloInstruction::HasSideEffect() const { return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements); } +/* static */ std::unique_ptr HloInstruction::CreateGather( + const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + std::unique_ptr instruction = + WrapUnique(new HloInstruction(HloOpcode::kGather, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(gather_indices); + instruction->gather_dimension_numbers_ = + MakeUnique(gather_dim_numbers); + c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_)); + return instruction; +} + +/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice output_window_dims, + tensorflow::gtl::ArraySlice elided_window_dims, + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims) { + GatherDimensionNumbers gather_dim_numbers; + for (int64 output_window_dim : output_window_dims) { + gather_dim_numbers.add_output_window_dims(output_window_dim); + } + for (int64 elided_window_dim : elided_window_dims) { + gather_dim_numbers.add_elided_window_dims(elided_window_dim); + } + for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { + gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + } + + return gather_dim_numbers; +} + std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, @@ -1227,6 +1268,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCustomCall: clone = CreateCustomCall(shape, new_operands, custom_call_target_); break; + case HloOpcode::kHostCompute: + clone = CreateHostCompute(shape, new_operands, channel_name_, + cost_estimate_ns_); + break; case HloOpcode::kConcatenate: clone = CreateConcatenate(shape, new_operands, dimensions(0)); break; @@ -1376,12 +1421,19 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kRecv: CHECK_EQ(new_operands.size(), 0); - clone = CreateRecv(shape, channel_id()); + // The shape is a tuple, but CreateRecv() wants the raw data shape. + clone = + CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); break; case HloOpcode::kRecvDone: CHECK_EQ(new_operands.size(), 1); clone = CreateRecvDone(new_operands[0]); break; + case HloOpcode::kGather: + CHECK_EQ(new_operands.size(), 2); + clone = CreateGather(shape, new_operands[0], new_operands[1], + *gather_dimension_numbers_, gather_window_bounds_); + break; case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } @@ -1627,7 +1679,8 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations) const { + eq_computations, + const std::function& eq_shapes) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and @@ -1675,8 +1728,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kTuple: return true; - // These opcodes have complex or special behavior so just return false. case HloOpcode::kFusion: + return fusion_kind() == other.fusion_kind() && + eq_computations(fused_instructions_computation(), + other.fused_instructions_computation()); + + // These opcodes have complex or special behavior so just return false. case HloOpcode::kRng: case HloOpcode::kTrace: case HloOpcode::kWhile: @@ -1686,7 +1743,7 @@ bool HloInstruction::IdenticalSlowPath( return parameter_number() == other.parameter_number() && // Check the shape too because `this` and `other` may be in // different HloComputations. - ShapeUtil::Compatible(shape(), other.shape()); + eq_shapes(shape(), other.shape()); case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: @@ -1720,6 +1777,11 @@ bool HloInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(dot_dimension_numbers(), other.dot_dimension_numbers()); + case HloOpcode::kGather: + return protobuf_util::ProtobufEquals(gather_dimension_numbers(), + other.gather_dimension_numbers()) && + gather_window_bounds() == other.gather_window_bounds(); + // FFT has various types & lengths. case HloOpcode::kFft: return fft_type() == other.fft_type() && @@ -1742,18 +1804,18 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals(window(), other.window()); case HloOpcode::kReshape: - return ShapeUtil::Compatible(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); // Transpose result is determined by the final shape and the permutation. case HloOpcode::kTranspose: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dimensions() == other.dimensions(); // Remaining instructions with special values. case HloOpcode::kBitcast: - return ShapeUtil::Equal(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); case HloOpcode::kBroadcast: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dimensions() == other.dimensions(); case HloOpcode::kConcatenate: return dimensions() == other.dimensions(); @@ -1767,10 +1829,10 @@ bool HloInstruction::IdenticalSlowPath( slice_limits_ == other.slice_limits_ && slice_strides_ == other.slice_strides_; case HloOpcode::kDynamicSlice: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dynamic_slice_sizes_ == other.dynamic_slice_sizes_; case HloOpcode::kDynamicUpdateSlice: - return ShapeUtil::Compatible(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); @@ -1790,6 +1852,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kRecvDone: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kHostCompute: return false; } } @@ -1815,7 +1878,8 @@ void HloInstruction::RemoveUser(HloInstruction* user) { Status HloInstruction::ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer) { - TF_RET_CHECK(ShapeUtil::Compatible(shape(), new_producer->shape())) + TF_RET_CHECK( + ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape())) << "this shape: " << ShapeUtil::HumanString(shape()) << ", replacement shape: " << ShapeUtil::HumanString(new_producer->shape()); @@ -1838,8 +1902,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); - TF_RET_CHECK( - ShapeUtil::Compatible(old_operand->shape(), new_operand->shape())) + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), + new_operand->shape())) << old_operand->shape().ShortDebugString() << " is not compatible with " << new_operand->shape().ShortDebugString(); operands_[operand_num] = new_operand; @@ -2149,6 +2213,11 @@ std::vector HloInstruction::ExtraAttributesToString( if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } + if (gather_dimension_numbers_ != nullptr) { + extra.push_back(GatherDimensionNumbersToString()); + extra.push_back( + StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); + } if (opcode() == HloOpcode::kFft) { extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); @@ -2280,6 +2349,14 @@ HloInstructionProto HloInstruction::ToProto() const { if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } + if (gather_dimension_numbers_ != nullptr) { + *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; + } + if (opcode() == HloOpcode::kGather) { + for (int64 bound : gather_window_bounds()) { + proto.add_gather_window_bounds(bound); + } + } for (int i = 0; i < slice_starts_.size(); ++i) { auto* slice_dimension = proto.add_slice_dimensions(); slice_dimension->set_start(slice_starts_[i]); @@ -2318,7 +2395,7 @@ string HloInstruction::ToCategory() const { return "data formatting"; } - auto conv_category = [&] { + if (opcode() == HloOpcode::kConvolution) { string category = "convolution"; if (window_util::HasBaseDilation(window())) { category += " base-dilated"; @@ -2327,10 +2404,6 @@ string HloInstruction::ToCategory() const { category += " window-dilated"; } return category; - }; - - if (opcode() == HloOpcode::kConvolution) { - return conv_category(); } // Give transpose-dot and backwards-conv fusions the categories "dot" and @@ -2348,9 +2421,6 @@ string HloInstruction::ToCategory() const { return "output fusion"; case FusionKind::kTransposeDot: return "dot"; - case FusionKind::kConvBackwardFilter: - case FusionKind::kConvBackwardInput: - return conv_category(); case FusionKind::kCustom: return "custom fusion"; } @@ -2581,6 +2651,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleInfeed(this); case HloOpcode::kOutfeed: return visitor->HandleOutfeed(this); + case HloOpcode::kHostCompute: + return visitor->HandleHostCompute(this); case HloOpcode::kRng: return visitor->HandleRng(this); case HloOpcode::kWhile: @@ -2601,6 +2673,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSend(this); case HloOpcode::kSendDone: return visitor->HandleSendDone(this); + case HloOpcode::kGather: + return visitor->HandleGather(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -3125,10 +3199,6 @@ string ToString(HloInstruction::FusionKind kind) { return "kOutput"; case HloInstruction::FusionKind::kTransposeDot: return "kTransposeDot"; - case HloInstruction::FusionKind::kConvBackwardFilter: - return "kConvBackwardFilter"; - case HloInstruction::FusionKind::kConvBackwardInput: - return "kConvBackwardInput"; case HloInstruction::FusionKind::kCustom: return "kCustom"; } @@ -3148,12 +3218,6 @@ StatusOr StringToFusionKind( if (kind_name == "kTransposeDot") { return HloInstruction::FusionKind::kTransposeDot; } - if (kind_name == "kConvBackwardFilter") { - return HloInstruction::FusionKind::kConvBackwardFilter; - } - if (kind_name == "kConvBackwardInput") { - return HloInstruction::FusionKind::kConvBackwardInput; - } if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } @@ -3261,7 +3325,13 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { result += "_"; append_dims(rhs_dims, operand(1)->shape()); result += "->"; - append_dims(output_dims, shape()); + + // A convolution can be represented as a kConvolution HLO or as a CustomCall + // that returns a tuple, the first element of which is the result of the + // convolution. + Shape this_shape = + ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape(); + append_dims(output_dims, this_shape); return result; } @@ -3288,6 +3358,23 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +string HloInstruction::GatherDimensionNumbersToString() const { + CHECK_NE(gather_dimension_numbers_.get(), nullptr); + string output_window_dims = + StrCat("output_window_dims={", + Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); + string elided_window_dims = + StrCat("elided_window_dims={", + Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); + string gather_dims_to_operand_dims = StrCat( + "gather_dims_to_operand_dims={", + Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + + return Join>( + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims}, + ", "); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5e89dc79bea81e650331e320f7836fdde90b2a53..c4fe132d1d52d6071914869cd50a035ace3389b2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -162,17 +162,14 @@ class HloPrintOptions { class HloInstruction { public: enum class FusionKind { - kLoop, // Fused into a loop. - kInput, // Op's input is fused into the op itself. - kOutput, // Op's output is fused into the op itself. - // REQUIRES: At least one operand buffer must be able - // to alias the output buffer. - kTransposeDot, // Fused into a dot with transposed operands. - kConvBackwardFilter, // Fused into a backward filter convolution. - kConvBackwardInput, // Fused into a backward input convolution. - - kCustom, // Custom category for backend-specific fusions that - // do not match any of the more specific ones. + kLoop, // Fused into a loop. + kInput, // Op's input is fused into the op itself. + kOutput, // Op's output is fused into the op itself. + // REQUIRES: At least one operand buffer must be able + // to alias the output buffer. + kTransposeDot, // Fused into a dot with transposed operands. + kCustom, // Custom category for backend-specific fusions that + // do not match any of the more specific ones. }; ~HloInstruction(); @@ -454,6 +451,12 @@ class HloInstruction { HloInstruction* true_computation_arg, HloComputation* true_computation, HloInstruction* false_computation_arg, HloComputation* false_computation); + static std::unique_ptr CreateGather( + const Shape& shape, HloInstruction* operand, + HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -466,14 +469,6 @@ class HloInstruction { tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation); - // Creates a fusion instruction that represents backward convolution. This is - // similar to CreateFusion, but with extra arguments indicating the window and - // dimemsion mapping of the backward convolution. - static std::unique_ptr CreateFusionForBackwardConvolution( - const Shape& shape, FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums, - HloInstruction* fused_root); - // Creates a call instruction that applies the given computation on the given // operands. "shape" is the resultant shape. static std::unique_ptr CreateCall( @@ -486,6 +481,12 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target); + // Creates a HostCompute instruction, which records host-side control and + // data dependencies for use in instruction scheduling. + static std::unique_ptr CreateHostCompute( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. static std::unique_ptr CreateTuple( @@ -497,6 +498,12 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates an instance of GatherDimensionNumbers. + static GatherDimensionNumbers MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice output_window_dims, + tensorflow::gtl::ArraySlice elided_window_dims, + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims); + // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -565,27 +572,33 @@ class HloInstruction { } // Returns true if "other" performs the same computation as this instruction. - // Layout of the instructions' output array is not considered. bool Identical( const HloInstruction& other, const std::function& eq_operands = std::equal_to(), const std::function& - eq_computations = std::equal_to()) const { + eq_computations = std::equal_to(), + bool layout_sensitive = true) const { // An instruction is always identical to itself. if (this == &other) { return true; } - // Identical instruction must have the same opcode and identical operands. - // In general, there is no need to check shape because shape is inferred - // from the shape of the operands. + // Identical instruction must have the same opcode, shape, and identical + // operands. if (opcode() != other.opcode()) { return false; } + using EqShapeFuncType = bool (*)(const Shape&, const Shape&); + EqShapeFuncType eq_shapes = + layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; + if (!eq_shapes(shape(), other.shape())) { + return false; + } if (operands().size() != other.operands().size()) { return false; } + // Use an explicit loop rather than ContainerEquals, because copying around // std::functions may be too expensive in some cases. for (size_t i = 0; i < operands().size(); ++i) { @@ -594,7 +607,7 @@ class HloInstruction { } } - return IdenticalSlowPath(other, eq_computations); + return IdenticalSlowPath(other, eq_computations, eq_shapes); } // Returns whether the instruction has a constant operand. @@ -772,6 +785,10 @@ class HloInstruction { // // (We express the default options using an overload rather than a default // param because gdb ignores default params, but does resolve overloads.) + // + // TODO(b/73348663): Make ToString() adaptive to the size of the string by + // default, backing off on providing full information for very large strings, + // or provide a different name for a ToString-like function that does that. string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const; @@ -885,8 +902,8 @@ class HloInstruction { // Returns true if this instruction is a fusion instruction that generates // multiple outputs. const bool IsMultiOutputFusion() const { - return (opcode() == HloOpcode::kFusion && - fused_expression_root()->opcode() == HloOpcode::kTuple); + return opcode() == HloOpcode::kFusion && + fused_expression_root()->opcode() == HloOpcode::kTuple; } FusionKind fusion_kind() const { @@ -919,6 +936,9 @@ class HloInstruction { // Return true if this operator has a sharding assigned. bool has_sharding() const { return sharding_ != nullptr; } + // Adds a new operand the fusion instruction. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + // Merges the fused instructions from 'instruction_to_merge' into the // fused instruction set of 'this', updating operands as necessary. // @@ -1052,13 +1072,23 @@ class HloInstruction { return *padding_config_; } - // Returns data on the dimension numbers used for a convolution - // operation. + // Returns data on the dimension numbers used for a convolution operation, + // which may be a kConvolution instruction or a kCustomCall that implements a + // convolution. const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { CHECK(convolution_dimension_numbers_ != nullptr); return *convolution_dimension_numbers_; } + // Sets the convolution dimension numbers on this instruction. In general you + // shouldn't need to call this; instead, specify the convolution dimension + // numbers when you create the instruction. + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = + MakeUnique(dnums); + } + FftType fft_type() const { CHECK_EQ(HloOpcode::kFft, opcode_); return fft_type_; @@ -1081,6 +1111,19 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. string DotDimensionNumbersToString() const; + const GatherDimensionNumbers& gather_dimension_numbers() const { + CHECK(gather_dimension_numbers_ != nullptr); + return *gather_dimension_numbers_; + } + + tensorflow::gtl::ArraySlice gather_window_bounds() const { + CHECK_EQ(opcode(), HloOpcode::kGather); + return gather_window_bounds_; + } + + // Returns the dump string of the gather dimension numbers. + string GatherDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -1233,10 +1276,14 @@ class HloInstruction { class FusionReusesParamElements; // See comments on Identical(). + // eq_shapes() is used to check shapes for equality, and would normally be + // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on + // whether we want a layout-sensitive check or not. bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations) const; + eq_computations, + const std::function& eq_shapes) const; // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( @@ -1341,6 +1388,9 @@ class HloInstruction { // Describes the dimension numbers used for a dot. std::unique_ptr dot_dimension_numbers_; + std::unique_ptr gather_dimension_numbers_; + std::vector gather_window_bounds_; + // Describes FFT type for an FFT instruction. FftType fft_type_ = FftType::FFT; @@ -1379,6 +1429,12 @@ class HloInstruction { // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; + // Name to use for host send/recv channels, only present for kHostCompute. + string channel_name_; + + // Estimate of the duration of a host computation in nanoseconds. + int64 cost_estimate_ns_; + // Computations called by this instruction. std::vector called_computations_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 1038ab555567aa654342d59e02efaf844f2b95ba..32d3ed272bd6b239918076999ecae6c1b3ded2fd 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -825,17 +825,42 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { EXPECT_THAT(c1->users(), ElementsAre(fusion)); } -// Convenience function for comparing two HloInstructions inside of -// std::unique_ptrs. -static bool Identical(std::unique_ptr instruction1, - std::unique_ptr instruction2) { +// Convenience function for comparing two HloInstructions. +static bool Identical(const HloInstruction& instruction1, + const HloInstruction& instruction2) { // Verify Identical is reflexive for both instructions. - EXPECT_TRUE(instruction1->Identical(*instruction1)); - EXPECT_TRUE(instruction2->Identical(*instruction2)); + EXPECT_TRUE(instruction1.Identical(instruction1)); + EXPECT_TRUE(instruction2.Identical(instruction2)); - bool is_equal = instruction1->Identical(*instruction2); + bool is_equal = instruction1.Identical(instruction2); // Verify Identical is symmetric. - EXPECT_EQ(is_equal, instruction2->Identical(*instruction1)); + EXPECT_EQ(is_equal, instruction2.Identical(instruction1)); + return is_equal; +} + +// Convenience function for comparing two HloInstructions for structural +// equality. +static bool StructuralEqual(const HloInstruction& instruction1, + const HloInstruction& instruction2) { + auto eq_operand_shapes = [](const HloInstruction* a, + const HloInstruction* b) { + return ShapeUtil::Equal(a->shape(), b->shape()); + }; + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + + // Verify Identical is reflexive for both instructions. + EXPECT_TRUE( + instruction1.Identical(instruction1, eq_operand_shapes, eq_computations)); + EXPECT_TRUE( + instruction2.Identical(instruction2, eq_operand_shapes, eq_computations)); + + bool is_equal = + instruction1.Identical(instruction2, eq_operand_shapes, eq_computations); + // Verify Identical is symmetric. + EXPECT_EQ(is_equal, instruction2.Identical(instruction1, eq_operand_shapes, + eq_computations)); return is_equal; } @@ -858,42 +883,42 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { // Operations which only depend on their operands and opcode. EXPECT_TRUE( - Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), - HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1))); + Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), + *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1))); EXPECT_FALSE( - Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), - HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2))); + Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), + *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2))); EXPECT_FALSE( - Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), - HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1))); + Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), + *HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1))); // Tuples. - EXPECT_TRUE(Identical(HloInstruction::CreateTuple({op1, op2}), - HloInstruction::CreateTuple({op1, op2}))); - EXPECT_FALSE(Identical(HloInstruction::CreateTuple({op1, op2}), - HloInstruction::CreateTuple({op2, op1}))); + EXPECT_TRUE(Identical(*HloInstruction::CreateTuple({op1, op2}), + *HloInstruction::CreateTuple({op1, op2}))); + EXPECT_FALSE(Identical(*HloInstruction::CreateTuple({op1, op2}), + *HloInstruction::CreateTuple({op2, op1}))); // Broadcasts. - EXPECT_TRUE(Identical(HloInstruction::CreateBroadcast(shape, op1, {0, 1}), - HloInstruction::CreateBroadcast(shape, op1, {0, 1}))); - EXPECT_FALSE(Identical(HloInstruction::CreateBroadcast(shape, op1, {0, 1}), - HloInstruction::CreateBroadcast(shape, op1, {1, 0}))); + EXPECT_TRUE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}), + *HloInstruction::CreateBroadcast(shape, op1, {0, 1}))); + EXPECT_FALSE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}), + *HloInstruction::CreateBroadcast(shape, op1, {1, 0}))); Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42}); Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123}); EXPECT_FALSE( - Identical(HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}), - HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1}))); + Identical(*HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}), + *HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1}))); // Binary operands. EXPECT_TRUE(Identical( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2))); + *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), + *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2))); EXPECT_FALSE(Identical( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), - HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1))); + *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), + *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1))); EXPECT_FALSE(Identical( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), - HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2))); + *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), + *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2))); } TEST_F(HloInstructionTest, FunctionVisitor) { @@ -1089,6 +1114,70 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { ShapeUtil::Equal(root->operand(1)->shape(), root2->operand(1)->shape())); EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->operand(0)->shape(), root2->operand(1)->operand(0)->shape())); + EXPECT_TRUE(StructuralEqual(*fusion, *fusion2)); +} + +TEST_F(HloInstructionTest, FusionEquality) { + HloModule module(TestName()); + HloComputation::Builder builder(TestName()); + + // Create two fusion instructions containing a single unary operation. + auto parameter = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter)); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter)); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {exp}, HloInstruction::FusionKind::kLoop); + auto* fusion2 = computation->CreateFusionInstruction( + {neg}, HloInstruction::FusionKind::kLoop); + EXPECT_FALSE(StructuralEqual(*fusion, *fusion2)); + + auto clone = fusion->Clone(); + EXPECT_TRUE(StructuralEqual(*fusion, *clone)); +} + +TEST_F(HloInstructionTest, NestedFusionEquality) { + HloModule module(TestName()); + HloComputation::Builder builder(TestName()); + + // Build a nested fusion computation. + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + auto b_t = builder.AddInstruction( + HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kSubtract, dot, add_operand)); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub)); + auto computation = module.AddEntryComputation(builder.Build()); + + auto nested_fusion = computation->CreateFusionInstruction( + {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); + + auto fusion = computation->CreateFusionInstruction( + {add, nested_fusion}, HloInstruction::FusionKind::kOutput); + auto fusion2 = computation->CreateFusionInstruction( + {sub, nested_fusion}, HloInstruction::FusionKind::kOutput); + auto clone = fusion->Clone(); + EXPECT_TRUE(StructuralEqual(*fusion, *clone)); + EXPECT_FALSE(StructuralEqual(*fusion, *fusion2)); } TEST_F(HloInstructionTest, CloneSuffixNames) { @@ -1182,5 +1271,40 @@ TEST_F(HloInstructionTest, Stringification) { "true_computation=%TransposeDot, false_computation=%TransposeDot"); } +TEST_F(HloInstructionTest, StringifyGather) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape gather_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + Shape gather_result_shape = + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Gather"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* gather_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, gather_indices_tensor_shape, "gather_indices")); + + HloInstruction* gather_instruction = + builder.AddInstruction(HloInstruction::CreateGather( + gather_result_shape, input, gather_indices, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(gather_instruction->ToString(), + "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " + "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " + "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " + "gather_dims_to_operand_dims={0,1,2,3,4}, " + "window_bounds={30,29,28,27,26}"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 4255d6086625dfb9a045e4431e968a5ee0106ac7..bc74c4bc10cad20eab20b5caf8550b17048a5276 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -102,6 +102,36 @@ bool HloGetTupleElementMatcher::MatchAndExplain( return true; } +void HloCustomCallMatcher::DescribeTo(std::ostream* os) const { + HloMatcher::DescribeTo(os); + *os << " with call target that "; + call_target_matcher_.DescribeTo(os); +} + +bool HloCustomCallMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + ::testing::StringMatchResultListener sub_listener; + bool result = ExplainMatchResult( + call_target_matcher_, instruction->custom_call_target(), &sub_listener); + if (sub_listener.str().empty()) { + sub_listener << " that "; + + std::stringstream desc_stream; + if (result) { + call_target_matcher_.DescribeTo(&desc_stream); + } else { + call_target_matcher_.DescribeNegationTo(&desc_stream); + } + sub_listener << desc_stream.str(); + } + *listener << "custom-call with call target" << sub_listener.str(); + return result; +} + } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 9206cdac05fbc1d6051617ab4b0f3016f19e3c90..103f04a2cb7a1a5ae877d8bf259692f7cbed3408 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -56,8 +56,8 @@ class HloParameterMatcher : public HloMatcher { // index to match. class HloGetTupleElementMatcher : public HloMatcher { public: - explicit HloGetTupleElementMatcher( - ::testing::Matcher operand, int64 tuple_index) + HloGetTupleElementMatcher(::testing::Matcher operand, + int64 tuple_index) : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}), tuple_index_(tuple_index) {} @@ -68,6 +68,24 @@ class HloGetTupleElementMatcher : public HloMatcher { int64 tuple_index_; }; +// Custom matcher for custom-call instructions, which accepts a matcher for its +// call target. +class HloCustomCallMatcher : public HloMatcher { + public: + HloCustomCallMatcher( + ::testing::Matcher call_target_matcher, + std::vector<::testing::Matcher> operands) + : HloMatcher(HloOpcode::kCustomCall, operands), + call_target_matcher_(call_target_matcher) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + ::testing::Matcher call_target_matcher_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -94,7 +112,6 @@ HLO_MATCHER(Convert); HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); -HLO_MATCHER(CustomCall); HLO_MATCHER(Divide); HLO_MATCHER(Dot); HLO_MATCHER(DynamicSlice); @@ -184,6 +201,36 @@ inline ::testing::Matcher GetTupleElement() { new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {})); } +// - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call +// target T and the given operands. +// +// - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the +// given operands. +// +// - CustomCall() matches any CustomCall HLO at all. +template +inline ::testing::Matcher CustomCall( + ::testing::Matcher call_target_matcher, M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher( + call_target_matcher, {operands...})); +} +// This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to +// ::testing::Matcher. In that case, we want to prefer the overload +// above. +template >::value, + void>::type*> +inline ::testing::Matcher CustomCall( + FirstM operands_first, M... operands_rest) { + return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( + HloOpcode::kCustomCall, {operands_first, operands_rest...})); +} +inline ::testing::Matcher CustomCall() { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {})); +} + #undef HLO_MATCHER } // namespace opcode_matchers diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 1465d1cacdc971a04c620bc48bed33239a67a955..1c21703a45e11914854153bc14fabd85e9ea57f2 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -23,6 +23,12 @@ using ::testing::Eq; namespace xla { namespace { +string DescribeHloMatcher(const ::testing::Matcher& m) { + std::stringstream ss; + m.DescribeTo(&ss); + return ss.str(); +} + template string Explain(const T& t, const M& m) { ::testing::StringMatchResultListener listener; @@ -67,5 +73,32 @@ TEST(HloMatchersTest, Test) { "add")); } +TEST(HloMatchersTest, CustomCallMatcher) { + auto c1 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); + auto c2 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); + auto call = HloInstruction::CreateCustomCall( + ShapeUtil::MakeShape(F32, {1}), {c1.get(), c2.get()}, "foo_target"); + + EXPECT_THAT(call.get(), op::CustomCall()); + EXPECT_THAT(call.get(), op::CustomCall(c1.get(), c2.get())); + EXPECT_THAT(call.get(), op::CustomCall("foo_target")); + EXPECT_THAT(call.get(), op::CustomCall("foo_target", c1.get(), c2.get())); + EXPECT_THAT(call.get(), op::CustomCall(::testing::StartsWith("foo"))); + EXPECT_THAT(call.get(), + op::CustomCall(::testing::Not(::testing::StartsWith("bar")))); + + // Wrong number of operands. + EXPECT_THAT(call.get(), ::testing::Not(op::CustomCall(c1.get()))); + + // Call target does not match. + EXPECT_THAT(call.get(), + ::testing::Not(op::CustomCall(::testing::StartsWith("bar")))); + + EXPECT_THAT(Explain(call.get(), op::CustomCall("bar")), + R"(custom-call with call target that isn't equal to "bar")"); + EXPECT_THAT(DescribeHloMatcher(op::CustomCall("foo_target")), + R"(custom-call with call target that is equal to "foo_target")"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 99d8dd04e5279e0e8a977370beedc4448dc6dc4b..cb2fe9f874012a51e1e6cbd1dd086dbb26994bde 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -38,12 +38,16 @@ HloModule::HloModule(const string& name, : name_(NameUniquer::GetSanitizedName(name)), config_(config), has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle) {} + entry_computation_handle_(entry_computation_handle), + unique_id_(next_unique_module_id_++) {} HloModule::HloModule(const string& name) - : name_(NameUniquer::GetSanitizedName(name)) {} + : name_(NameUniquer::GetSanitizedName(name)), + unique_id_(next_unique_module_id_++) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) - : name_(NameUniquer::GetSanitizedName(name)), config_(config) {} + : name_(NameUniquer::GetSanitizedName(name)), + config_(config), + unique_id_(next_unique_module_id_++) {} HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -141,6 +145,21 @@ void HloModule::ReplaceComputations( } break; } + case HloOpcode::kConditional: { + HloComputation* new_true_computation = + tensorflow::gtl::FindWithDefault( + replacements, instruction->true_computation(), nullptr); + if (new_true_computation != nullptr) { + instruction->set_true_computation(new_true_computation); + } + HloComputation* new_false_computation = + tensorflow::gtl::FindWithDefault( + replacements, instruction->false_computation(), nullptr); + if (new_false_computation != nullptr) { + instruction->set_false_computation(new_false_computation); + } + break; + } case HloOpcode::kSelectAndScatter: { HloComputation* new_select = tensorflow::gtl::FindWithDefault( replacements, instruction->select(), nullptr); @@ -559,9 +578,23 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { return module; } +HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) { + HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this)); + TF_CHECK_OK( + clone->root_instruction()->Accept([this](HloInstruction* instruction) { + instruction->ReplaceCalledComputations([this](HloComputation* callee) { + return DeepCloneComputation(callee); + }); + return Status::OK(); + })); + return clone; +} + uint64 HloModule::RandomNew64() const { tensorflow::mutex_lock l(rng_mutex_); return rng_(); } +/* static */ std::atomic HloModule::next_unique_module_id_(0); + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index e377654d024819d00f73f43a70d363bd902dc981..06d92f94fd6f62162b22575e9cc341f2906cd0db 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ +#include #include #include #include @@ -84,6 +85,10 @@ class HloModule { // Returns a deep copy of this module including all computations. std::unique_ptr Clone(const string& suffix = "clone") const; + // Performs a deep clone of the computation, by recursively cloning all + // the called computations as well. + HloComputation* DeepCloneComputation(HloComputation* computation); + // Return a pointer to the entry computation of the module.. const HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); @@ -201,6 +206,10 @@ class HloModule { // this point are guaranteed to be in the range [0..NumUniqueInstructionIds()) int NumUniqueInstructionIds() const { return next_unique_id_; } + // Returns an id that is unique to this module across all modules created over + // the lifetime of this process. + int unique_id() const { return unique_id_; } + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -227,6 +236,11 @@ class HloModule { NameUniquer computation_name_uniquer_{/*separator=*/"."}; NameUniquer instruction_name_uniquer_{/*separator=*/"."}; int next_unique_id_ = 0; + + // Used to keep track of the next unique module id that should be assigned. + static std::atomic next_unique_module_id_; + // A unique id to label modules with. + int unique_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index cd51fa4e8549daba3e953eece50cb3538f627b89..7f28a804bfec9c2f1bbb5fa08f7dd4e68be14d35 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -188,6 +188,12 @@ TEST_F(HloModuleTest, LargeConstantToString) { module->ToString(HloPrintOptions().set_print_large_constants(true))); } +TEST_F(HloModuleTest, UniqueModuleId) { + auto module_a = CreateNewModule(); + auto module_b = CreateNewModule(); + EXPECT_NE(module_a->unique_id(), module_b->unique_id()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 3d64523a79fc50638fdf378b5d521a5cd4482b90..af24604c39b554f146793594958f373999844b4c 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -76,9 +76,11 @@ namespace xla { V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ + V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ + V(kHostCompute, "host-compute") \ V(kImag, "imag") \ V(kInfeed, "infeed") \ V(kIsFinite, "is-finite") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 68e3c9618c1fe9daacb0aee3ee98862c8b9e4bc4..1b24d8da9e832e6847cb6f405e15af3c455f695a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -186,6 +186,22 @@ bool HloOrdering::UseIsBeforeValueDefinition( } } + if (use.instruction->opcode() == HloOpcode::kConditional) { + const HloInstruction* conditional = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + conditional->true_computation())) { + VLOG(4) << " use is conditional " << use.instruction->name() + << " and def is in TRUE computation"; + return true; + } + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + conditional->false_computation())) { + VLOG(4) << " use is conditional " << use.instruction->name() + << " and def is in FALSE computation"; + return true; + } + } + VLOG(4) << " use is not before value"; return false; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index aba66114de649ce7667ae77174e9c4073b010b90..a989fce63234cb860d08c48b02462e96bec879bc 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -262,8 +262,8 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { scalar_shape, HloOpcode::kAdd, constant, xla_while)); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); // Init value is defined before the while, but live range is not before the diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 53bd46a641afcba1b9551895955742e74a9f374b..5120775737bfa32bbb656421216f2b3fbef590ea 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -32,12 +33,28 @@ using ::tensorflow::strings::StrCat; namespace xla { namespace { -void DumpModule(const HloModule& module, - const string& message) { +void DumpModuleGraph(const HloModule& module, const string& message) { hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; XLA_VLOG_LINES(3, module.ToString()); } + +void DumpModuleProto(const HloModule& module, const string& dump_to, + const string& pipeline_name, const string& pass_name) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static auto* const module_id_to_pass_number = + new tensorflow::gtl::FlatMap(); + + tensorflow::mutex_lock lock(mu); + const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; + + const string mod_name = SanitizeFileName(tensorflow::strings::Printf( + "module_%04d.%04lld.%s.after_%s", module.unique_id(), pass_number, + pipeline_name.c_str(), pass_name.c_str())); + + TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module), + dump_to, mod_name)); +} } // namespace StatusOr HloPassPipeline::Run(HloModule* module) { @@ -78,6 +95,13 @@ StatusOr HloPassPipeline::Run(HloModule* module) { string message; TF_RETURN_IF_ERROR( run_invariant_checkers(StrCat("before running pipeline: ", name()))); + const string xla_dump_per_pass_hlo_proto_to = + module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); + if (!xla_dump_per_pass_hlo_proto_to.empty()) { + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, name().ToString(), + "pipeline_start"); + } + for (auto& pass : passes_) { if (disabled_passes.count(pass->name().ToString()) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() @@ -90,17 +114,21 @@ StatusOr HloPassPipeline::Run(HloModule* module) { // Emit label containing: "after foo-pass, before bar-pass". message.clear(); StrAppend(&message, prefix, ", before ", pass->name()); - DumpModule(*module, message); + DumpModuleGraph(*module, message); TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); TF_RETURN_IF_ERROR( run_invariant_checkers(StrCat("after running pass: ", pass->name()))); + if (!xla_dump_per_pass_hlo_proto_to.empty()) { + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, + name().ToString(), pass->name().ToString()); + } changed |= changed_this_pass; prefix.clear(); StrAppend(&prefix, name(), ": after ", pass->name()); } - DumpModule(*module, prefix + ", pipeline end"); + DumpModuleGraph(*module, prefix + ", pipeline end"); return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index c6b4dc0368d92fd477decdfb38045f74f8696803..98b8d34be1f331aaeac94e952deeae1e76379861 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -60,6 +60,7 @@ bool IsRematerializable(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kConstant: + case HloOpcode::kConditional: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kParameter: diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 204a8bf748685af71ac82be0d102cf7f76c7b38f..41b079eb799d06321a31f7d7ae0630dc8d58c46b 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -47,22 +47,11 @@ HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, return tools::Parse(hlo_string, config); } -/*static*/ StatusOr> -HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, - const DebugOptions& debug_options) { - HloProto proto; - - const Status s = - tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &proto); - - if (!s.ok()) { - const Status s2 = - tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto); - if (!s2.ok()) { - return Status(s2.code(), s.error_message() + "\n" + s2.error_message()); - } - } +namespace { +// Creates an HloModule from the given proto. +StatusOr> HloProtoToModule( + const HloProto& proto, const DebugOptions& debug_options) { TF_ASSIGN_OR_RETURN( HloModuleConfig config, HloModule::CreateModuleConfigFromProto(proto.hlo_module())); @@ -72,9 +61,29 @@ HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, return std::move(module); } +} // namespace + /*static*/ StatusOr> -HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename, +HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename, const DebugOptions& debug_options) { + HloProto proto; + TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), + filename, &proto)); + return HloProtoToModule(proto, debug_options); +} + +/*static*/ StatusOr> +HloRunner::ReadModuleFromTextProtoFile(const std::string& filename, + const DebugOptions& debug_options) { + HloProto proto; + TF_RETURN_IF_ERROR( + tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto)); + return HloProtoToModule(proto, debug_options); +} + +/*static*/ StatusOr> +HloRunner::ReadModuleFromHloTextFile(const std::string& filename, + const DebugOptions& debug_options) { string hlo_string; TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), filename, &hlo_string)); @@ -83,19 +92,6 @@ HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename, return tools::Parse(hlo_string, config); } -/*static*/ StatusOr> HloRunner::ReadModule( - const std::string& filename, const DebugOptions& debug_options) { - auto module = HloRunner::ReadModuleFromHloProtoFile(filename, debug_options); - if (module.ok()) { - return module; - } - const std::string e = module.status().error_message(); - module = HloRunner::ReadModuleFromHloTextDumpFile(filename, debug_options); - return module.ok() ? std::move(module) - : Status(module.status().code(), - e + "\n" + module.status().error_message()); -} - // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. struct HloRunner::EigenThreadPoolWrapper { @@ -121,12 +117,14 @@ StatusOr> HloRunner::ExecuteInternal( if (run_hlo_passes) { TF_ASSIGN_OR_RETURN( module, backend().compiler()->RunHloPasses( - std::move(module), backend().default_stream_executor())); + std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr)); } TF_ASSIGN_OR_RETURN( std::unique_ptr executable, backend().compiler()->RunBackend(std::move(module), - backend().default_stream_executor())); + backend().default_stream_executor(), + /*device_allocator=*/nullptr)); se::Stream stream(backend().default_stream_executor()); stream.Init(); diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index d4b221fb52dff64dda264a931df6fd19b86e5260..cbaebc68bee708090b8ccb2eae19b556c4d6d453 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -52,21 +52,15 @@ class HloRunner { const DebugOptions& debug_options); // Reads the proto file in xla.HloProto format, creates and returns the - // HloModule. Will try to parse the filename as binary proto, then try as - // text proto if that fails. - static StatusOr> ReadModuleFromHloProtoFile( + // HloModule. + static StatusOr> ReadModuleFromBinaryProtoFile( + const std::string& filename, const DebugOptions& debug_options); + static StatusOr> ReadModuleFromTextProtoFile( const std::string& filename, const DebugOptions& debug_options); // Reads the hlo text dump file in HloModule::ToString format, creates and // returns the HloModule. - static StatusOr> ReadModuleFromHloTextDumpFile( - const std::string& filename, const DebugOptions& debug_options); - - // Tries to parse the filename specified first as binary proto format, then - // as a textual proto format, then textual IR, then gives up if both fail. - // ReadModuleFromHloProtoFile or ReadModuleFromHloTextDumpFile should be used - // explicitly when you know the format, this if you don't. - static StatusOr> ReadModule( + static StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); // Executes the given module with given literals as input and returns the diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 2594c29efd717b3bead34d326c28c7efdf093c50..f6e33403f538bd8492b04c34d46a458f7f06cc06 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include #include #include @@ -100,7 +101,7 @@ class ListScheduler { // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. for (auto* instruction : computation.instructions()) { - std::unordered_set instr_uses; + tensorflow::gtl::FlatSet instr_uses; for (auto* operand : instruction->operands()) { for (const LogicalBuffer* buffer : points_to_analysis.GetBuffersDefinedByInstruction(operand)) { @@ -150,7 +151,7 @@ class ListScheduler { int64 bytes_defined; // For each buffer B used by this instruction, we keep a pair (B, U), where - // U is the number of uses of B that have not yet been scheduled. This pair + // U is the number of uses of B that have not yet been scheduled. This pair // is a pointer into the unscheduled_use_count_ map, so it gets updated for // free when we update counts in the map. std::vector*> @@ -205,7 +206,8 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. - std::unordered_map unscheduled_pred_count; + tensorflow::gtl::FlatMap + unscheduled_pred_count; for (auto* instruction : computation_.instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. @@ -217,39 +219,48 @@ class ListScheduler { } } - std::list ready_list; + // Use a multimap to sort ReadyListEntry according to their priority. + std::multimap ready_queue; + + // Map of ready instructions to their iterators in ready_queue. + tensorflow::gtl::FlatMap::iterator> + ready_instructions; + + auto add_to_ready_queue = [&](HloInstruction* inst) { + auto entry = MakeReadyListEntry(inst); + auto it = ready_queue.emplace(GetPriority(entry), std::move(entry)); + ready_instructions[inst] = it; + }; + for (auto* instruction : computation_.instructions()) { // Instruction with no operands or control predecessors will // not be in the map. if (unscheduled_pred_count.count(instruction) == 0) { - ready_list.push_back(MakeReadyListEntry(instruction)); + add_to_ready_queue(instruction); } } - while (!ready_list.empty()) { - // Select the highest priority HLO instruction from the ready list. - auto best_it = ready_list.begin(); - Priority best_priority = GetPriority(*best_it); - for (auto ready_it = std::next(ready_list.begin()); - ready_it != ready_list.end(); ++ready_it) { - Priority priority = GetPriority(*ready_it); - if (priority > best_priority) { - best_it = ready_it; - best_priority = priority; - } - } - + while (!ready_queue.empty()) { // Remove the selected instruction from the ready list and add it to the // schedule. - const HloInstruction* best = best_it->instruction; - ready_list.erase(best_it); + auto best_it = ready_queue.end(); + --best_it; + const HloInstruction* best = best_it->second.instruction; + ready_queue.erase(best_it); + ready_instructions.erase(best); schedule.push_back(best); scheduled_instructions_.insert(best); + bool adjust_ready_queue = false; // Update the unscheduled uses of the logical buffers. for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { - CHECK_GT(unscheduled_use_count_.at(buffer), 0); - --unscheduled_use_count_[buffer]; + int64& count = unscheduled_use_count_[buffer]; + CHECK_GT(count, 0); + --count; + if (count == 1) { + adjust_ready_queue = true; + } } // Add new instructions to ready list. @@ -257,7 +268,7 @@ class ListScheduler { int64 pred_count = --unscheduled_pred_count.at(inst); CHECK_GE(pred_count, 0); if (pred_count == 0) { - ready_list.push_back(MakeReadyListEntry(inst)); + add_to_ready_queue(inst); } }; // TODO(b/34466113): Replace this and above with successors() or @@ -268,6 +279,31 @@ class ListScheduler { for (HloInstruction* succ : best->control_successors()) { update_pred_count(succ); } + // The unscheduled use count for a buffer has changed to 1, so the + // priorities of some ready instructions may go up. We update them in the + // ready queue, so that they can appear earlier. + if (adjust_ready_queue) { + for (HloInstruction* operand : best->operands()) { + for (HloInstruction* operand_user : operand->users()) { + auto ready_instructions_it = ready_instructions.find(operand_user); + if (ready_instructions_it == ready_instructions.end()) { + continue; + } + auto ready_queue_it = ready_instructions_it->second; + auto& entry = ready_queue_it->second; + Priority new_priority = GetPriority(entry); + if (new_priority == ready_queue_it->first) { + continue; + } + // Create a new entry in ready_queue, then update + // ready_instructions[operand_user] to refer to the new entry. + ready_instructions_it->second = + ready_queue.emplace(new_priority, std::move(entry)); + // Remove the old entry in ready_queue. + ready_queue.erase(ready_queue_it); + } + } + } } CHECK_EQ(schedule.size(), computation_.instruction_count()); CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); @@ -280,15 +316,17 @@ class ListScheduler { const LogicalBuffer::SizeFunction& size_function_; // A map containing the LogicalBuffers that each instruction uses. - std::unordered_map> + tensorflow::gtl::FlatMap> buffer_uses_; // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. We rely on iterator stability in this map. + // LogicalBuffer. We rely on iterator stability in this map, and that the map + // entries are std::pair's. std::unordered_map unscheduled_use_count_; // Set of instructions which have been scheduled. - std::unordered_set scheduled_instructions_; + tensorflow::gtl::FlatSet scheduled_instructions_; }; int64 SumLogicalBufferSizes( diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 6e46f945e0a2d776ab557c10fedf9b5eb393f3c2..b1fd068115e1d104a11d880675ef84e07d6d5602 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -123,6 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { return CheckShape(outfeed, ShapeUtil::MakeNil()); } +Status ShapeVerifier::HandleHostCompute(HloInstruction*) { + return tensorflow::Status::OK(); +} + Status ShapeVerifier::HandleRng(HloInstruction*) { return tensorflow::Status::OK(); } @@ -164,6 +170,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { // HLO broadcast has no exact analog at the proto level so there is no // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); + // Check for mixed precision. + TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape())); TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == broadcast->dimensions().size()); for (int64 operand_dimension = 0; @@ -178,6 +186,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + // Check for mixed precision. + TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == ShapeUtil::ElementsIn(reshape->operand(0)->shape())); return tensorflow::Status::OK(); @@ -359,13 +369,130 @@ Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { batch_norm_grad->feature_index())); } +namespace { + +// Checks that the instruction does not have mixed precision floating point +// inputs. +Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { + switch (instruction->opcode()) { + // White list the following opcodes for mixed-precision check, because they + // involve data pass through or grouping via tuples, where the precisions + // of buffers can be different. + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kCustomCall: + case HloOpcode::kFusion: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReducePrecision: + case HloOpcode::kSelect: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + break; + default: { + PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID; + for (auto operand : instruction->operands()) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + operand->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::ElementIsFloating(subshape)) { + return Status::OK(); + } + if (fp_type == PRIMITIVE_TYPE_INVALID) { + fp_type = subshape.element_type(); + } else if (fp_type != subshape.element_type()) { + return FailedPrecondition( + "Seen floating point types of different precisions in " + "%s, but mixed precision is disallowed.", + instruction->ToString().c_str()); + } + return Status::OK(); + })); + } + } + } + return Status::OK(); +} + +} // namespace + +Status ShapeVerifier::HandleGather(HloInstruction* gather) { + return CheckShape( + gather, + ShapeInference::InferGatherShape( + gather->operand(0)->shape(), gather->operand(1)->shape(), + gather->gather_dimension_numbers(), gather->gather_window_bounds())); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, - const Shape& expected_shape) { - if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) { + const Shape& inferred_shape) { + // If allow_mixed_precision_ is false, check if there are operands with + // different precisions. We need this check because ShapeInference allows + // mixed precision inputs. + if (!allow_mixed_precision_) { + TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction)); + } + + // Check if the output shape matches the expected shape. + bool compatible; + // We treat BF16 and F32 as compatible types if mixed precision is allowed, + // but only when the instruction defines the BF16/F32 buffer. + switch (instruction->opcode()) { + case HloOpcode::kSelect: + if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) { + // Select only defines the top-level buffer, which in this case is the + // tuple, so we cannot allow mixed precision. + compatible = + ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } else { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision( + instruction->shape(), inferred_shape); + } + break; + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed + // precision is disallowed. + case HloOpcode::kConstant: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCustomCall: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kWhile: + // The above opcodes should match the expected shapes exactly. + compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); + break; + default: + if (allow_mixed_precision_) { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision( + instruction->shape(), inferred_shape); + } else { + compatible = + ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } + } + if (!compatible) { return InvalidArgument( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", - ShapeUtil::HumanString(expected_shape).c_str(), + ShapeUtil::HumanString(inferred_shape).c_str(), ShapeUtil::HumanString(instruction->shape()).c_str(), instruction->ToString().c_str()); } @@ -373,14 +500,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, - const StatusOr& expected_shape_status) { - if (!expected_shape_status.ok()) { - Status s = expected_shape_status.status(); + const StatusOr& inferred_shape_status) { + if (!inferred_shape_status.ok()) { + Status s = inferred_shape_status.status(); tensorflow::errors::AppendToMessage(&s, ", for instruction ", instruction->ToString()); return s; } - return CheckShape(instruction, expected_shape_status.ValueOrDie()); + return CheckShape(instruction, inferred_shape_status.ValueOrDie()); } Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { @@ -687,7 +814,8 @@ StatusOr HloVerifier::Run(HloModule* module) { instructions[instruction->name()] = instruction; } - TF_RETURN_IF_ERROR(computation->Accept(shape_verifier_.get())); + std::unique_ptr shape_verifier = shape_verifier_factory_(); + TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); } return false; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 5a1d864e03d436bb29f7c98b9a373a19abc28a7e..1dd7ec3c51e18dcfe89bd478de87798ba3858119 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -27,6 +27,10 @@ namespace xla { // TODO(b/26024837): Check output shape for all instruction types. class ShapeVerifier : public DfsHloVisitor { public: + explicit ShapeVerifier() : allow_mixed_precision_(false) {} + explicit ShapeVerifier(bool allow_mixed_precision) + : allow_mixed_precision_(allow_mixed_precision) {} + Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; Status HandleClamp(HloInstruction* clamp) override; @@ -56,6 +60,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFusion(HloInstruction*) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction*) override; + Status HandleHostCompute(HloInstruction*) override; Status HandleSlice(HloInstruction* slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( @@ -75,20 +80,21 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleBatchNormInference( HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + Status HandleGather(HloInstruction* gather) override; Status FinishVisit(HloInstruction*) override { return tensorflow::Status::OK(); } protected: - // Check the instruction's shape against the given expected shape and return - // an appropriate error if there is a mismatch. + // Check the instruction's shape against the shape given by ShapeInference + // and return an appropriate error if there is a mismatch. Status CheckShape(const HloInstruction* instruction, - const Shape& expected_shape); + const Shape& inferred_shape); // Overload which takes a StatusOr to reduce boilerplate in the caller. Status CheckShape(const HloInstruction* instruction, - const StatusOr& expected_shape_status); + const StatusOr& inferred_shape_status); // Check a unary (binary, etc) instruction's shape against the inferred shape. Status CheckUnaryShape(const HloInstruction* instruction); @@ -99,17 +105,34 @@ class ShapeVerifier : public DfsHloVisitor { // Checks if the given two instructions shares the same channel id. Status CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2); + + private: + // Whether the inputs and output of an instruction can contain both F32s and + // BF16s. Tuples that include both F32s and BF16s are allowed regardless of + // this flag. + bool allow_mixed_precision_; }; // HLO pass that verifies invariants of HLO instructions for each computation in // the module. class HloVerifier : public HloPassInterface { public: + using ShapeVerifierFactory = std::function()>; + // Uses standard shape inference. - explicit HloVerifier() : shape_verifier_(MakeUnique()) {} + explicit HloVerifier() + : shape_verifier_factory_( + [] { return MakeUnique(false); }) {} + + explicit HloVerifier(bool allow_mixed_precision) + : shape_verifier_factory_([allow_mixed_precision] { + return MakeUnique(allow_mixed_precision); + }) {} + // Uses custom shape verification. - explicit HloVerifier(std::unique_ptr shape_verifier) - : shape_verifier_(std::move(shape_verifier)) {} + explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) + : shape_verifier_factory_(std::move(shape_verifier_factory)) {} + ~HloVerifier() override = default; tensorflow::StringPiece name() const override { return "verifier"; } @@ -121,8 +144,11 @@ class HloVerifier : public HloPassInterface { // CHECKs various invariants of a fusion instruction. Status CheckFusionInstruction(HloInstruction* fusion) const; - // Verifies shapes match inferred expectations. - std::unique_ptr shape_verifier_; + // Creates a ShapeVerifier that checks that shapes match inferred + // expectations. This is a factory function because ShapeVerifier, Note that + // ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object + // for each run of the verifier. + ShapeVerifierFactory shape_verifier_factory_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 2a3b55decc5289e7e576d3c5897b333c0b1bc922..c92db0be14dceb32ea86521dcc99b8f63738e4a5 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -97,5 +97,31 @@ TEST_F(HloVerifierTest, DifferentOperandParents) { HasSubstr("is in a different computation")); } +TEST_F(HloVerifierTest, ResetsShapeVerifierState) { + HloComputation::Builder builder(TestName()); + Shape s1 = ShapeUtil::MakeShape(F32, {1}); + Shape s2 = ShapeUtil::MakeShape(F32, {2}); + + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "param")); + + // Create an add instruction with the incorrect shape. + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(s2, HloOpcode::kAdd, param, param)); + + // In order to trigger the bug we're checking for, the instruction with the + // bad shape can't be the root of the computation. + builder.AddInstruction( + HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + // Run the verifier twice. It should fail both times, because it shouldn't + // carry state in its DFS visitor between runs. + EXPECT_FALSE(verifier().Run(module.get()).status().ok()); + EXPECT_FALSE(verifier().Run(module.get()).status().ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc new file mode 100644 index 0000000000000000000000000000000000000000..ada21345014dac70d61129aaf7bbc7466a7db914 --- /dev/null +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +namespace { + +// Visitor for removing implicit broadcasts. +class ImplicitBroadcastVisitor : public DfsHloVisitorWithDefault { + public: + Status DefaultAction(HloInstruction* hlo_instruction) override { + return Status::OK(); + } + + Status HandleElementwiseBinary(HloInstruction* hlo) override { + return ReplaceImplicitBroadcastOperands(hlo); + } + + Status HandleClamp(HloInstruction* hlo) override { + // Clamp is the only element-wise ternary operation. + return ReplaceImplicitBroadcastOperands(hlo); + } + + // Returns whether any modification has been made to any visited instruction. + bool changed() const { return changed_; } + + private: + // Iterates through the operands of 'hlo' and replace any operands which are + // implicitly broadcast with the equivalent sequence of broadcast and reshape + // instructions. An operand is considered to be implicitly broadcast if the + // operand shape does have the same dimensions as the shape of 'hlo'. + Status ReplaceImplicitBroadcastOperands(HloInstruction* hlo) { + auto fadd = [hlo](std::unique_ptr x) { + return hlo->parent()->AddInstruction(std::move(x)); + }; + std::vector operands; + bool operands_changed = false; + for (int i = 0; i < hlo->operand_count(); ++i) { + HloInstruction* operand = hlo->mutable_operand(i); + if (!ShapeUtil::SameDimensions(hlo->shape(), operand->shape())) { + HloInstruction* new_operand = hlo->parent()->AddInstruction( + HloInstruction::CreateBroadcastSequence(hlo->shape(), operand, + fadd)); + operands.push_back(new_operand); + operands_changed = true; + } else { + operands.push_back(operand); + } + } + if (operands_changed) { + // Create a new HLO instruction because the HloInstruction::Replace* + // methods check that the shape does not change with the replacement. + HloInstruction* new_hlo = hlo->parent()->AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), operands)); + TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo)); + changed_ = true; + } + return Status::OK(); + } + + bool changed_ = false; +}; + +} // namespace + +StatusOr ImplicitBroadcastRemover::Run(HloModule* module) { + VLOG(1) << "Removing implicit broadcast from module " << module->name(); + XLA_VLOG_LINES(2, + "Before removing implicit broadcasts:\n" + module->ToString()); + + ImplicitBroadcastVisitor visitor; + for (HloComputation* computation : module->computations()) { + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + } + + if (visitor.changed()) { + // HLO instructions with implicitly broadcast operands are cloned and left + // for dead. Remove them. + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + + XLA_VLOG_LINES(2, + "After removing implicit broadcasts:\n" + module->ToString()); + + return visitor.changed(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h new file mode 100644 index 0000000000000000000000000000000000000000..aa325dc8a353c5bfbfded0c2774c66bfcc71c9cb --- /dev/null +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Pass which replaces all implicit broadcasts with their equivalent sequence of +// explicit broadcast and reshape instructions. +class ImplicitBroadcastRemover : public HloPassInterface { + public: + ImplicitBroadcastRemover() {} + ~ImplicitBroadcastRemover() override {} + + tensorflow::StringPiece name() const override { + return "implicit-broadcast-remover"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8c7b38dd1bf73e0be7b669d7215812aaef1cee17 --- /dev/null +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -0,0 +1,176 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { + protected: + ImplicitBroadcastRemover remover_; +}; + +TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_FALSE(remover_.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Parameter(), op::Parameter())); +} + +TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "scalar_param")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kPower, param0, param1)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + + EXPECT_FALSE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); + + EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + + EXPECT_THAT(root, op::Power(op::Broadcast(op::Parameter()), op::Parameter())); + + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); +} + +TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 4, 1}), "p1")); + builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, param0, param1)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Subtract(op::Parameter(), + op::Broadcast(op::Reshape(op::Parameter())))); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); +} + +TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {1, 4, 1}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "scalar_param")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, param0, param1)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, + op::Subtract(op::Broadcast(op::Parameter()), op::Parameter())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); +} + +TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6, 8}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 4, 1, 8}), "p0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 6, 8}), "p1")); + auto param2 = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {2, 1, 6, 8}), "p2")); + builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, + param0, param1, param2)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Reshape(op::Parameter())), + op::Broadcast(op::Reshape(op::Parameter())), + op::Broadcast(op::Reshape(op::Parameter())))); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(2)->shape())); +} + +TEST_F(ImplicitBroadcastRemoverTest, + TernaryScalarAndDegenerateDimensionBroadcast) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 4, 6}), "p1")); + auto param2 = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "p2")); + builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, + param0, param1, param2)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Parameter()), + op::Broadcast(op::Reshape(op::Parameter())), + op::Parameter())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(2)->shape())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 90e1f0acdc4cdeda280dabaab2df66b181d0f407..f494748e17fc2d0de74dec67f7414d4791f76a07 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -102,6 +102,8 @@ namespace xla { case HloOpcode::kExp: case HloOpcode::kFft: case HloOpcode::kFusion: + case HloOpcode::kGather: + case HloOpcode::kHostCompute: case HloOpcode::kLog: case HloOpcode::kMap: case HloOpcode::kParameter: diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index dc63a2224d659fa427d4d1a30c5dc0f94d643b36..9171e859c6f84ceef9664aa1eb90a07c87dfab40 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -44,41 +44,26 @@ namespace interpreter { namespace se = ::perftools::gputools; namespace sep = ::perftools::gputools::interpreter; -/* - * Run optimization passes on the module. The graph is transformed by - * each pass in the optimization pipeline. The service subdirectory - * contains useful optimization passes. - */ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(false); - - pipeline.AddPass>( - false, [](const Shape&, const Shape&) { return false; }); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(true); + pipeline.AddPass( hlo_module->mutable_entry_computation_layout()); - pipeline.AddPass(); - pipeline.AddPass(); return pipeline.Run(hlo_module).status(); } StatusOr> InterpreterCompiler::RunHloPasses( - std::unique_ptr hlo_module, - se::StreamExecutor* /*stream_exec*/) { + std::unique_ptr hlo_module, se::StreamExecutor* /*stream_exec*/, + DeviceMemoryAllocator* /*device_allocator*/) { VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); return std::move(hlo_module); } StatusOr> InterpreterCompiler::RunBackend( - std::unique_ptr hlo_module, se::StreamExecutor* stream_exec) { + std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* /*device_allocator*/) { TF_RET_CHECK(stream_exec != nullptr); VLOG(1) << "Run backend " << hlo_module->name(); @@ -96,7 +81,8 @@ StatusOr> InterpreterCompiler::RunBackend( StatusOr>> InterpreterCompiler::Compile( std::vector> /*hlo_modules*/, - std::vector> /*stream_execs*/) { + std::vector> /*stream_execs*/, + DeviceMemoryAllocator* /*device_allocator*/) { return tensorflow::errors::Unimplemented( "Compilation of multiple HLO modules is not supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index 278cf5184227ae25518b1d46c0e16e4cce7bd1a8..c8660c04d86a82e7dfcfd1658310c2a0e4fa0083 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -45,16 +45,19 @@ class InterpreterCompiler : public Compiler { StatusOr> RunHloPasses( std::unique_ptr hlo_module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr hlo_module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> Compile( std::vector> hlo_modules, std::vector> - stream_exec) override; + stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::vector> hlo_modules, diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index bbea6bee5659c73cc71f45ed5e6bbd51df26c050..0668f66051ce96292c3c85bac7e649d89914106c 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -53,6 +53,83 @@ limitations under the License. namespace xla { +// For now moving only one API here, but we should have a single top level +// anonymous namespace, instead of three or four spread all over this file. +namespace { + +// Creates and returns a copy of the given instruction with a different +// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple +// instruction producing the copy is returned. +StatusOr CreateCopyWithNewLayout( + const Shape& shape_with_layout, HloInstruction* instruction) { + TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); + DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape())) + << ShapeUtil::HumanString(shape_with_layout) << " " + << ShapeUtil::HumanString(instruction->shape()) + << " instruction: " << instruction->ToString(); + + if (ShapeUtil::IsTuple(instruction->shape())) { + // Deep-copy tuples. + std::vector element_copies; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); + ++i) { + HloInstruction* gte = instruction->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, + i)); + + // Recurse to copy each elements. + TF_ASSIGN_OR_RETURN( + HloInstruction * element_copy, + CreateCopyWithNewLayout( + ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); + element_copies.push_back(element_copy); + } + // Gather element copies into a tuple with a new Tuple instruction. + HloInstruction* tuple_copy = instruction->parent()->AddInstruction( + HloInstruction::CreateTuple(element_copies)); + LayoutUtil::ClearLayout(tuple_copy->mutable_shape()); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + shape_with_layout, tuple_copy->mutable_shape())); + return tuple_copy; + } else if (ShapeUtil::IsArray(instruction->shape())) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kCopy, instruction)); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + shape_with_layout, copy->mutable_shape())); + + return copy; + } else { + return FailedPrecondition( + "Can only copy array and tuple shaped instructions"); + } +} + +// Creates a copy of the given operand if the operand's layout does not match +// the given layout. This copy replaces the use in the given instruction. Tuple +// operands will be deep-copied. +Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no) { + HloInstruction* operand = instruction->mutable_operand(operand_no); + TF_RET_CHECK(operand_layout.LayoutIsSet()); + TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); + + if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + // Operand layout already matches our constraint. Nothing to do. + return Status::OK(); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, + CreateCopyWithNewLayout(operand_layout.shape(), operand)); + + return instruction->ReplaceOperandWith(operand_no, operand_copy); +} + +} // namespace + std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint) { out << constraint.ToString(); @@ -61,8 +138,8 @@ std::ostream& operator<<(std::ostream& out, BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory) - : LayoutConstraint(mandatory), layout_(layout), buffer_(&buffer) { + bool mandatory, bool dfs) + : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) { CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok()); } @@ -74,8 +151,8 @@ string BufferLayoutConstraint::ToString() const { OperandLayoutConstraint::OperandLayoutConstraint( const ShapeLayout& shape_layout, const HloInstruction* instruction, - int64 operand_no, bool mandatory) - : LayoutConstraint(mandatory), + int64 operand_no, bool mandatory, bool dfs) + : LayoutConstraint(mandatory, dfs), shape_layout_(shape_layout), instruction_(instruction), operand_no_(operand_no) { @@ -134,7 +211,7 @@ bool LayoutConstraints::OperandBufferForwarded( Status LayoutConstraints::SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory) { + bool mandatory, bool dfs) { VLOG(3) << "SetBufferLayout : " << buffer << " : " << LayoutUtil::HumanString(layout); @@ -171,10 +248,11 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, if (!overwrite) { iter = buffer_constraints_ .insert(std::make_pair( - &buffer, BufferLayoutConstraint(layout, buffer, mandatory))) + &buffer, + BufferLayoutConstraint(layout, buffer, mandatory, dfs))) .first; } else { - iter->second = BufferLayoutConstraint(layout, buffer, /*mandatory=*/true); + iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } added_constraints_.push_back(&iter->second); @@ -188,7 +266,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, const HloInstruction* instruction, - int64 operand_no, bool mandatory) { + int64 operand_no, bool mandatory, + bool dfs) { VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand " << operand_no << " : " << ShapeUtil::HumanStringWithLayout(shape_with_layout); @@ -226,12 +305,12 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, if (iter == operand_constraints_.end()) { auto pair = std::make_pair( key, OperandLayoutConstraint(ShapeLayout(shape_with_layout), - instruction, operand_no, mandatory)); + instruction, operand_no, mandatory, dfs)); iter = operand_constraints_.insert(pair).first; } else { iter->second = OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction, - operand_no, /*mandatory=*/true); + operand_no, mandatory, dfs); } added_constraints_.push_back(&iter->second); @@ -240,16 +319,17 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, Status LayoutConstraints::SetArrayOperandLayout( const Layout& layout, const HloInstruction* instruction, int64 operand_no, - bool mandatory) { + bool mandatory, bool dfs) { const HloInstruction* operand = instruction->operand(operand_no); TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); Shape shape(operand->shape()); *shape.mutable_layout() = layout; TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); - return SetOperandLayout(shape, instruction, operand_no, mandatory); + return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs); } -Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) { +Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, + bool dfs) { VLOG(3) << "SetResultLayout : " << ShapeUtil::HumanStringWithLayout(shape_with_layout); @@ -267,14 +347,15 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) { } result_constraint_.reset( - new ResultLayoutConstraint(ShapeLayout(shape_with_layout))); + new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs)); added_constraints_.push_back(result_constraint_.get()); return Status::OK(); } Status LayoutConstraints::SetInstructionLayout( - const Shape& shape_with_layout, const HloInstruction* instruction) { + const Shape& shape_with_layout, const HloInstruction* instruction, + bool mandatory, bool dfs) { VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", " << ShapeUtil::HumanStringWithLayout(shape_with_layout); @@ -290,8 +371,8 @@ Status LayoutConstraints::SetInstructionLayout( // instruction. return ShapeUtil::ForEachSubshapeWithStatus( shape_with_layout, - [this, instruction](const Shape& subshape, - const ShapeIndex& index) -> Status { + [this, instruction, mandatory](const Shape& subshape, + const ShapeIndex& index) -> Status { // The precondition for this method is that the instruction defines all // buffers in its output. auto buffers = @@ -300,7 +381,7 @@ Status LayoutConstraints::SetInstructionLayout( CHECK_EQ(buffers[0]->instruction(), instruction); if (ShapeUtil::IsArray(subshape)) { - return SetBufferLayout(subshape.layout(), *buffers[0]); + return SetBufferLayout(subshape.layout(), *buffers[0], mandatory); } else { return Status::OK(); } @@ -394,8 +475,7 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain the input to the Outfeed instruction to be the expected // layout of the Outfeed. TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - instruction->outfeed_shape(), instruction, 0, - /*mandatory=*/true)); + instruction->outfeed_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kParameter) { // Parameter layouts must match the respective layout in // ComputationLayout. @@ -434,8 +514,8 @@ Status LayoutAssignment::AddMandatoryConstraints( {0})); Shape new_shape = channel_constraints->LayoutShapeForChannel( recv_buffer_shape, instruction->channel_id()); - TF_RETURN_IF_ERROR(constraints->SetBufferLayout( - new_shape.layout(), *buffer, /*mandatory=*/true)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(new_shape.layout(), *buffer)); } } } @@ -457,7 +537,7 @@ Status LayoutAssignment::AddMandatoryConstraints( for (int64 i = 0; i < instruction->operand_count(); ++i) { TF_RETURN_IF_ERROR(constraints->SetOperandLayout( called_computation_layout.parameter_layout(i).shape(), instruction, - i, /*mandatory=*/true)); + i)); } } else if (instruction->opcode() == HloOpcode::kWhile) { // Layout of input and output of kWhile instruction must be equal and must @@ -508,7 +588,36 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( body_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - body_layout.result_shape(), instruction, 0, + body_layout.result_shape(), instruction, 0)); + } else if (instruction->opcode() == HloOpcode::kConditional) { + // The layout of the true and false computations must match, and must + // be the layout of the kConditional instruction. + TF_RET_CHECK(instruction->operand_count() == 3); + + HloComputation* true_computation = instruction->true_computation(); + HloComputation* false_computation = instruction->false_computation(); + const HloInstruction* true_operand = instruction->operand(1); + const HloInstruction* false_operand = instruction->operand(2); + + TF_RET_CHECK(true_computation->num_parameters() == 1); + TF_RET_CHECK(false_computation->num_parameters() == 1); + ComputationLayout& true_computation_layout = + FindOrDie(computation_layouts_, true_computation); + ComputationLayout& false_computation_layout = + FindOrDie(computation_layouts_, false_computation); + + DCHECK(ShapeUtil::Compatible(true_operand->shape(), + true_computation_layout.parameter_shape(0))); + DCHECK(ShapeUtil::Compatible( + false_operand->shape(), false_computation_layout.parameter_shape(0))); + + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + true_computation_layout.result_shape(), instruction)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + true_computation_layout.parameter_shape(0), instruction, 1, + /*mandatory=*/true)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + false_computation_layout.parameter_shape(0), instruction, 2, /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { if (!CustomCallRequiresMajorFirstLayout(instruction)) { @@ -533,7 +642,7 @@ Status LayoutAssignment::AddMandatoryConstraints( operand_shape.element_type(), AsInt64Slice(operand_shape.dimensions())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction, i, /*mandatory=*/true)); + row_major_operand_shape, instruction, i)); } } } @@ -596,6 +705,33 @@ Status CheckWhileLayout(HloInstruction* while_inst, return Status::OK(); } +Status CheckConditionalLayout( + HloInstruction* instruction, + const ComputationLayout& true_computation_layout, + const ComputationLayout& false_computation_layout) { + HloComputation* true_computation = instruction->true_computation(); + HloComputation* false_computation = instruction->false_computation(); + const HloInstruction* true_operand = instruction->operand(1); + const HloInstruction* false_operand = instruction->operand(2); + + TF_RET_CHECK(true_computation_layout.result_layout() == + false_computation_layout.result_layout()); + TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape( + instruction->shape())); + TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape( + true_computation->root_instruction()->shape())); + TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape( + instruction->shape())); + TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape( + false_computation->root_instruction()->shape())); + TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape( + true_operand->shape())); + TF_RET_CHECK( + false_computation_layout.parameter_layout(0).MatchesLayoutInShape( + false_operand->shape())); + return Status::OK(); +} + // Fusion parameters must match the layout of the fusion instructions operands, // and the root of the fusion expression must match the layout of the fusion // instruction. @@ -708,6 +844,13 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, instruction->while_condition()), FindOrDie(computation_layouts_, instruction->while_body()))); break; + case HloOpcode::kConditional: + TF_RETURN_IF_ERROR(CheckConditionalLayout( + instruction, + FindOrDie(computation_layouts_, instruction->true_computation()), + FindOrDie(computation_layouts_, + instruction->false_computation()))); + break; default: break; } @@ -907,7 +1050,11 @@ Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) { auto add_new_constraints_to_worklist = [constraints, &worklist]() { // Add constraints to the front of the deque for DFS ordering. for (auto* constraint : constraints->ConsumeAddedConstraints()) { - worklist.push_front(constraint); + if (constraint->dfs()) { + worklist.push_front(constraint); + } else { + worklist.push_back(constraint); + } } }; add_new_constraints_to_worklist(); @@ -1159,84 +1306,14 @@ StatusOr InferArrayLayout( return *first_buffer_layout; } -// Creates and returns a copy of the given instruction with a different -// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple -// instruction producing the copy is returned. -StatusOr CreateCopyWithNewLayout( - const Shape& shape_with_layout, HloInstruction* instruction) { - TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); - DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape())) - << ShapeUtil::HumanString(shape_with_layout) << " " - << ShapeUtil::HumanString(instruction->shape()) - << " instruction: " << instruction->ToString(); - - if (ShapeUtil::IsTuple(instruction->shape())) { - // Deep-copy tuples. - std::vector element_copies; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); - ++i) { - HloInstruction* gte = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - - // Recurse to copy each elements. - TF_ASSIGN_OR_RETURN( - HloInstruction * element_copy, - CreateCopyWithNewLayout( - ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); - element_copies.push_back(element_copy); - } - // Gather element copies into a tuple with a new Tuple instruction. - HloInstruction* tuple_copy = instruction->parent()->AddInstruction( - HloInstruction::CreateTuple(element_copies)); - LayoutUtil::ClearLayout(tuple_copy->mutable_shape()); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - shape_with_layout, tuple_copy->mutable_shape())); - return tuple_copy; - } else if (ShapeUtil::IsArray(instruction->shape())) { - HloInstruction* copy = - instruction->parent()->AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kCopy, instruction)); - LayoutUtil::ClearLayout(copy->mutable_shape()); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - shape_with_layout, copy->mutable_shape())); - - return copy; - } else { - return FailedPrecondition( - "Can only copy array and tuple shaped instructions"); - } -} - -// Creates a copy of the given operand if the operand's layout does not match -// the given layout. This copy replaces the use in the given instruction. Tuple -// operands will be deep-copied. -Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no) { - HloInstruction* operand = instruction->mutable_operand(operand_no); - TF_RET_CHECK(operand_layout.LayoutIsSet()); - TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); - - if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { - // Operand layout already matches our constraint. Nothing to do. - return Status::OK(); - } - - TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, - CreateCopyWithNewLayout(operand_layout.shape(), operand)); - - return instruction->ReplaceOperandWith(operand_no, operand_copy); -} - // For fusion instructions, set the layout of each fused parameter instruction // to match the layout of its corresponding fusion instruction operand. Also, // set the layout of the fused root to match the layout of the fusion // instruction itself. Status SetFusionLayouts(HloInstruction* fusion) { TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion); - for (auto* fused_instruction : fusion->fused_instructions()) { + for (auto* fused_instruction : + fusion->fused_instructions_computation()->MakeInstructionPostOrder()) { if (fused_instruction->opcode() == HloOpcode::kParameter) { const HloInstruction* fusion_operand = fusion->operand(fused_instruction->parameter_number()); @@ -1251,11 +1328,22 @@ Status SetFusionLayouts(HloInstruction* fusion) { ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape())); TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( fusion->shape(), fused_instruction->mutable_shape())); - } else if (fused_instruction->opcode() != HloOpcode::kConstant && - fused_instruction->opcode() != HloOpcode::kGetTupleElement && - fused_instruction->opcode() != HloOpcode::kInfeed) { - // Internal fused instructions with the exception of constants - // and infeed need no layout. + } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) { + // A GTE inherits its layout from its operand (which should ultimately be + // a parameter). + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + fused_instruction->operand(0)->shape().tuple_shapes( + fused_instruction->tuple_index()), + fused_instruction->mutable_shape())); + } else if (fused_instruction->opcode() == HloOpcode::kConstant) { + // Give constants the layout of their literal. + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + fused_instruction->literal().shape(), + fused_instruction->mutable_shape())); + } else if (fused_instruction->opcode() == HloOpcode::kInfeed) { + // Nop; leave the infeed layout alone. + } else { + // Other instructions don't have layouts inside of fusion nodes. LayoutUtil::ClearLayout(fused_instruction->mutable_shape()); } } @@ -1367,20 +1455,6 @@ Status LayoutAssignment::RunOnComputation( << ")"; VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); - // Clear existing layouts of the instructions. All layouts must be assigned by - // the LayoutAssignment pass, except for Infeed, Outfeed, Parameters and the - // computation result. The latter two are specified in computation_layout, so - // we only need to keep the existing layouts for Infeed and Outfeed. Clearing - // the layouts here avoids hiding potential bugs in the layout assignment pass - // that may accidently use the existing layout. - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kInfeed || - instruction->opcode() == HloOpcode::kOutfeed) { - continue; - } - LayoutUtil::ClearLayout(instruction->mutable_shape()); - } - // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); @@ -1392,7 +1466,7 @@ Status LayoutAssignment::RunOnComputation( // Add any backend-specific constraints. TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints)); - // Propagates layouts from an HLO to its neighbors. + // Propagates layouts from mandatory and backend constraints. TF_RETURN_IF_ERROR(PropagateConstraints(&constraints)); // While any unconstrained buffers remain, pick an arbitrary buffer, give it a @@ -1457,13 +1531,27 @@ StatusOr LayoutAssignment::Run(HloModule* module) { // Assign layouts to computations in an order such that a callee computation // is handled before its caller computation. This ensures that the layout of // all callers of a computation will agree. + std::list computation_post_order = + module->MakeComputationPostOrder(); for (auto* computation : module->MakeComputationPostOrder()) { + if (computation->IsFusionComputation()) { + continue; + } + // Clear existing layouts of the instructions. All layouts must be assigned + // by the LayoutAssignment pass, except for those on infeeds, parameters, + // and the computation result. The latter two are specified in + // computation_layout, so we only need to keep the existing layouts for + // infeeds. Clearing the layouts here avoids hiding potential bugs in the + // layout assignment pass that may accidently use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kInfeed) { + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + } if (computation == module->entry_computation()) { TF_RETURN_IF_ERROR(RunOnComputation( *entry_computation_layout_, *points_to_analysis, module->entry_computation(), channel_layout_constraints_)); - } else if (computation->IsFusionComputation()) { - continue; } else { ComputationLayout computation_layout(computation->ComputeProgramShape()); // Setting all embedded computations to the default layout is potentially diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 6bfae2998609c0482b91368f1891ce1e8e43fa23..29018584487cabfd740d7914625c2a50f552d6ff 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -46,7 +46,8 @@ namespace xla { // gathered together in LayoutConstraints object. class LayoutConstraint { public: - LayoutConstraint(bool mandatory) : mandatory_(mandatory) {} + LayoutConstraint(bool mandatory, bool dfs) + : mandatory_(mandatory), dfs_(dfs) {} virtual ~LayoutConstraint() = default; virtual string ToString() const = 0; @@ -54,8 +55,12 @@ class LayoutConstraint { // True if this constraint cannot be overwritten by a different constraint. bool mandatory() const { return mandatory_; } + // When true, propagate in DFS. When false, constraint will propagate in BFS. + bool dfs() const { return dfs_; } + private: bool mandatory_; + bool dfs_; }; std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); @@ -65,7 +70,7 @@ std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); class BufferLayoutConstraint : public LayoutConstraint { public: BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory); + bool mandatory, bool dfs); const LogicalBuffer& buffer() const { return *buffer_; } const Layout& layout() const { return layout_; } @@ -86,7 +91,7 @@ class OperandLayoutConstraint : public LayoutConstraint { public: OperandLayoutConstraint(const ShapeLayout& shape_layout, const HloInstruction* instruction, int64 operand_no, - bool mandatory); + bool mandatory, bool dfs); const ShapeLayout& shape_layout() const { return shape_layout_; } const HloInstruction* instruction() const { return instruction_; } @@ -106,8 +111,10 @@ class OperandLayoutConstraint : public LayoutConstraint { // Constraint on the layout of the result of the entry computation. class ResultLayoutConstraint : public LayoutConstraint { public: - explicit ResultLayoutConstraint(const ShapeLayout& shape_layout) - : LayoutConstraint(/*mandatory=*/true), shape_layout_(shape_layout) {} + explicit ResultLayoutConstraint(const ShapeLayout& shape_layout, + bool dfs = false) + : LayoutConstraint(/*mandatory=*/true, dfs), + shape_layout_(shape_layout) {} const ShapeLayout& shape_layout() const { return shape_layout_; } string ToString() const override; @@ -157,23 +164,25 @@ class LayoutConstraints { // operand of the instruction, or the layout of the result of the computation, // respectively. Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory = true); + bool mandatory = true, bool dfs = true); Status SetOperandLayout(const Shape& shape_with_layout, const HloInstruction* instruction, int64 operand_no, - bool mandatory = true); - Status SetResultLayout(const Shape& shape_with_layout); + bool mandatory = true, bool dfs = true); + Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true); // Convenience wrapper around SetOperandLayout for setting the layout of a // operand using a Layout object. The operand must be array-shaped. Status SetArrayOperandLayout(const Layout& layout, const HloInstruction* instruction, - int64 operand_no, bool mandatory = true); + int64 operand_no, bool mandatory = true, + bool dfs = true); // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers // created by the instruction to the layouts in the given shape. The // instruction must define every logical buffer in its output. Status SetInstructionLayout(const Shape& shape_with_layout, - const HloInstruction* instruction); + const HloInstruction* instruction, + bool mandatory = true, bool dfs = true); // Returns true if any buffer in the given operand is forwarded to the output // of the given instruction. For example, the Tuple instruction forwards the diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index d51c0d1dfb727801d6d2a8328eba60838373479f..dd0fba2758f0d77e72bc55138df229b24c026677 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -35,9 +35,11 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -587,5 +589,137 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } + +// A GTE inside of a fusion node inherits the layout of its operand (which +// should, if we keep following operands, eventually be a parameter). +TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { + const char* module_str = R"( + HloModule test_module + + fused_computation { + fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0) + gte0 = f32[2,2,2] get-tuple-element(fparam), index=0 + gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1 + gte1a = f32[2,2,2] get-tuple-element(gte1), index=0 + gte1b = f32[2,2,2] get-tuple-element(gte1), index=1 + add = f32[2,2,2] add(gte1a, gte1b) + ROOT fresult = f32[2,2,2] add(gte0, add) + } + + ENTRY entry_computation { + param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0) + ROOT fusion = + f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation + } + )"; + + auto module = tools::Parse(module_str).ValueOrDie(); + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + Shape param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}), + })}); + TF_ASSERT_OK( + computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( + param_shape)); + computation_layout.mutable_result_layout()->ResetLayout( + LayoutUtil::MakeLayout({2, 1, 0})); + AssignLayouts(module.get(), &computation_layout); + + HloComputation* fused_computation = *std::find_if( + module->computations().begin(), module->computations().end(), + [](const HloComputation* c) { return c->name() == "fused_computation"; }); + + auto fused_instr = [&](const string& name) { + auto it = std::find_if( + fused_computation->instructions().begin(), + fused_computation->instructions().end(), + [&](const HloInstruction* i) { return i->name() == name; }); + CHECK(it != fused_computation->instructions().end()); + return *it; + }; + + EXPECT_THAT(fused_instr("gte0")->shape().layout().minor_to_major(), + ElementsAre(0, 1, 2)); + EXPECT_THAT( + fused_instr("gte1")->shape().tuple_shapes(0).layout().minor_to_major(), + ElementsAre(1, 2, 0)); + EXPECT_THAT( + fused_instr("gte1")->shape().tuple_shapes(1).layout().minor_to_major(), + ElementsAre(2, 0, 1)); + EXPECT_THAT(fused_instr("gte1a")->shape().layout().minor_to_major(), + ElementsAre(1, 2, 0)); + EXPECT_THAT(fused_instr("gte1b")->shape().layout().minor_to_major(), + ElementsAre(2, 0, 1)); + EXPECT_THAT(fused_instr("fresult")->shape().layout().minor_to_major(), + ElementsAre(2, 1, 0)); +} + +TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { + auto builder = HloComputation::Builder(TestName()); + auto module = CreateNewModule(); + Shape shape = ShapeUtil::MakeShape(F32, {128, 8}); + Shape tshape = ShapeUtil::MakeTupleShape({shape, shape}); + Shape result_tshape = ShapeUtil::MakeTupleShape({shape}); + + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + auto pred = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(PRED, {}), "param2")); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); + + auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch"); + { + auto param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, tshape, "param")); + auto gte0 = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, param, 0)); + auto gte1 = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, param, 1)); + auto add = true_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1)); + true_builder.AddInstruction(HloInstruction::CreateTuple({add})); + } + HloComputation* true_computation = + module->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch"); + { + Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1}); + false_builder.AddInstruction( + HloInstruction::CreateParameter(0, tshape, "param")); + // Using infeed as layout assignment does not mess up with it. + auto infeed = + false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, "")); + false_builder.AddInstruction(HloInstruction::CreateTuple({infeed})); + } + HloComputation* false_computation = + module->AddEmbeddedComputation(false_builder.Build()); + builder.AddInstruction(HloInstruction::CreateConditional( + result_tshape, pred, tuple, true_computation, tuple, false_computation)); + + HloComputation* computation = module->AddEntryComputation(builder.Build()); + ComputationLayout computation_layout(computation->ComputeProgramShape()); + + AssignLayouts(module.get(), &computation_layout); + + const HloInstruction* true_root = true_computation->root_instruction(); + const HloInstruction* false_root = false_computation->root_instruction(); + EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple); + EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple); + + const HloInstruction* true_result = true_root->operand(0); + const HloInstruction* false_result = false_root->operand(0); + EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(), + false_result->shape().layout())); + EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index 2c2a02f6375343d67dfb155bbb03729ff6e490d2..f8b309488eeb5391b1cad5db760934ec1f7e3521 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -35,8 +35,7 @@ class PointsToAnalysisTestBase : public HloTestBase { CHECK_NOTNULL(module_.get()); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - dataflow_analysis_ = - HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie(); + dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); } void BuildModuleAndRunAnalysis(std::unique_ptr computation) { diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index 34f3419269abbc73cd0ddb13c723a8da38ab19ff..911b243fe28a5baf8a4b8ed752b892265f5388ac 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -14,12 +14,29 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/core/platform/denormal.h" + +#ifdef __FAST_MATH__ +#error "Don't build XLA with -ffast-math" +#endif namespace xla { StatusOr>> LLVMCompiler::Compile( std::vector> modules, - std::vector> - stream_execs) { + std::vector> stream_execs, + DeviceMemoryAllocator* device_allocator) { + // Tensorflow tries to enable the following behaviors in all its threads: + // + // - Denormals are zero (DAZ): roughly, operations treat denormal floats as + // zero. + // - Flush denormals to zero (FTZ): roughly, operations produce zero instead + // of denormal floats. + // + // In theory enabling these shouldn't matter since the compiler should ideally + // not leak its environment into generated code, but we turn off DAZ and FTZ + // to get some defense-in-depth. + tensorflow::port::ScopedDontFlushDenormal dont_flush_denormals; + std::vector> result; for (size_t i = 0; i < modules.size(); i++) { if (stream_execs[i].size() != 1) { @@ -27,10 +44,12 @@ StatusOr>> LLVMCompiler::Compile( "Model partitioning not implemented for the CPU/GPU compilers!"); } - TF_ASSIGN_OR_RETURN( - modules[i], RunHloPasses(std::move(modules[i]), stream_execs[i][0])); + TF_ASSIGN_OR_RETURN(modules[i], + RunHloPasses(std::move(modules[i]), stream_execs[i][0], + device_allocator)); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - RunBackend(std::move(modules[i]), stream_execs[i][0])); + RunBackend(std::move(modules[i]), stream_execs[i][0], + device_allocator)); result.push_back(std::move(executable)); } diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index c5393cef4f961c5d04c32d0d4291732b8ec702f1..d74e81bb7f622ac5e89203a3d02ca5ad839da07e 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -60,17 +60,20 @@ class LLVMCompiler : public Compiler { // Bring in // StatusOr> RunBackend( // std::unique_ptr module, - // perftools::gputools::StreamExecutor* stream_exec) + // perftools::gputools::StreamExecutor* stream_exec, + // DeviceMemoryAllocator* device_allocator) // StatusOr> RunHloPasses( // std::unique_ptr module, - // perftools::gputools::StreamExecutor* stream_exec) + // perftools::gputools::StreamExecutor* stream_exec, + // DeviceMemoryAllocator* device_allocator) using Compiler::RunBackend; using Compiler::RunHloPasses; StatusOr>> Compile( std::vector> modules, std::vector> - stream_execs) override; + stream_execs, + DeviceMemoryAllocator* device_allocator) override; protected: ModuleHook user_pre_optimization_hook_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index ffc78bd5cfac3df1001d8125327607c85169ae92..37261ed1e665ebed9685751161a412ad114a9e96 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -54,6 +54,7 @@ cc_library( "@llvm//:core", "@llvm//:support", "@llvm//:target", + "@llvm//:transform_utils", ], ) diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 9ad7cd82cb8ca862fd7acec3dfb12c9fd61f6e27..b3b6026ef17daa184c0a015fdea618597ef068b3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -32,8 +32,23 @@ limitations under the License. namespace xla { -// Unlike IrEmitter, this creates host functions which emit IR to generate the -// output element at the given index. It is used to generate fused operations. +// FusedIrEmitter is used to generate code for fusion nodes. +// +// Unlike IrEmitter and its ilk, which directly create LLVM IR in an LLVM +// Module, FusedIrEmitter is better understood as "IR generator generator". +// FusedIrEmitter recursively creates a generator (a host function) which the +// compiler can invoke at a later time. Invoking the generator emits LLVM IR +// that, when run, produces the value at a particular index of the output. +// +// After building this generator, the compiler creates a loop (or its moral +// equivalent, e.g. a GPU kernel) and calls the generator from within the loop. +// This generates code that produces each element of the output. +// +// This class handles both vanilla fusion and multi-output fusion. In the MOF +// case, the fusion node ends with a kTuple instruction, and the generator +// created produces an LLVM struct with N elements, one for each element of the +// arrays in the tuple. It follows that the arrays in the tuple must have the +// same length. class FusedIrEmitter : public DfsHloVisitorWithDefault { public: using Generator = llvm_ir::ElementGenerator; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 8d1e6338e189a055ac20f09961a783b52600866d..5c1866311d1ae1e0c33ab061ee326d86d647a908 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -20,9 +20,11 @@ limitations under the License. #include #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/Target/TargetOptions.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -61,6 +63,16 @@ llvm::StringRef AsStringRef(tensorflow::StringPiece str) { return llvm::StringRef(str.data(), str.size()); } +std::unique_ptr DropConstantInitializers( + const llvm::Module& module) { + std::unique_ptr cloned_module = CloneModule(module); + for (llvm::GlobalVariable& global_var : cloned_module->globals()) { + global_var.setInitializer(nullptr); + global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage); + } + return cloned_module; +} + string DumpModuleToString(const llvm::Module& module) { std::string buffer_string; llvm::raw_string_ostream ostream(buffer_string); @@ -672,6 +684,19 @@ static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { return uniquer->GetUniqueName(prefix); } +static Status CreateAndWriteStringToFile(const string& directory_name, + const string& file_name, + const string& text) { + std::unique_ptr f; + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->RecursivelyCreateDir(directory_name)); + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->NewWritableFile(file_name, &f)); + TF_RETURN_IF_ERROR(f->Append(text)); + TF_RETURN_IF_ERROR(f->Close()); + return Status::OK(); +} + Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized) { @@ -686,13 +711,17 @@ Status DumpIRToDirectory(const string& directory_name, directory_name, tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); - std::unique_ptr f; - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->RecursivelyCreateDir(directory_name)); - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->NewWritableFile(ir_file_name, &f)); - TF_RETURN_IF_ERROR(f->Append(DumpModuleToString(llvm_module))); - return f->Close(); + // For some models the embedded constants can be huge, so also dump the module + // with the constants stripped to get IR that is easier to manipulate. + string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( + directory_name, + tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll")); + + TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( + directory_name, ir_file_name, DumpModuleToString(llvm_module))); + return CreateAndWriteStringToFile( + directory_name, ir_no_constant_initializers_file_name, + DumpModuleToString(*DropConstantInitializers(llvm_module))); } llvm::Function* CreateFunction(llvm::FunctionType* function_type, diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index a5f7c850c33757fe8d48567ade35544d81224e46..b6b918ec78a27b90325f72eea14b97f9aee43c54 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -51,37 +51,40 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, shape_(target_array.GetShape()), ir_builder_(ir_builder) {} +static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( + const ElementGenerator& target_element_generator, + const std::vector& target_arrays, llvm::IRBuilder<>* ir_builder) { + return [=](const llvm_ir::IrArray::Index array_index) { + TF_ASSIGN_OR_RETURN(llvm::Value * target_element, + target_element_generator(array_index)); + CHECK(target_element->getType()->isStructTy()) + << "This BodyEmitter is for multi-output fusion, but target element " + "generator does not produce values of struct type."; + CHECK_EQ(target_element->getType()->getStructNumElements(), + target_arrays.size()); + + for (int64 i = 0; i < target_arrays.size(); ++i) { + target_arrays[i].EmitWriteArrayElement( + array_index, ir_builder->CreateExtractValue(target_element, i), + ir_builder); + } + return Status::OK(); + }; +} + LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, llvm::IRBuilder<>* ir_builder) - : body_emitter_([=](const llvm_ir::IrArray::Index array_index) - -> ::tensorflow::Status { - // Convert target_element_generator to a BodyEmitter. - TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - target_element_generator(array_index)); - if (target_arrays.size() == 1) { - target_arrays[0].EmitWriteArrayElement(array_index, target_element, - ir_builder); - return tensorflow::Status::OK(); - } - - for (int64 i = 0; i < target_arrays.size(); ++i) { - target_arrays[i].EmitWriteArrayElement( - array_index, ir_builder_->CreateExtractValue(target_element, i), - ir_builder); - } - return tensorflow::Status::OK(); - }), + : body_emitter_(MakeBodyEmitterForMultiOutputFusion( + target_element_generator, + std::vector(target_arrays.begin(), target_arrays.end()), + ir_builder)), + shape_(target_arrays[0].GetShape()), ir_builder_(ir_builder) { - if (target_arrays.size() > 1) { - // The sanity check for multiple outputs. - shape_ = target_arrays[0].GetShape(); - for (int64 i = 1; i < target_arrays.size(); ++i) { - const Shape& element_shape = target_arrays[i].GetShape(); - CHECK(ShapeUtil::SameDimensions(shape_, element_shape)); - } - } else { - shape_ = target_arrays[0].GetShape(); + // Sanity check: In multi-output fusion, all shapes produced must have the + // same dimensions. + for (const IrArray& array : target_arrays) { + CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 1ef1dc246442041698d96f6aff48794c8788f1d1..0fc528439a0d5bf8382dfcf2d8b3051f8900bf1d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -47,10 +47,16 @@ class LoopEmitter { // element of the given target array. LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder); - // Same as previous method except emits multiple targets in an array. + + // Constructs a LoopEmitter that emits one element into each of N separate + // arrays on each iteration of the loop. + // + // This is used for multi-output fusion. target_element_generator must + // produce an LLVM struct with N elements. LoopEmitter(const ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, llvm::IRBuilder<>* ir_builder); + LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; virtual ~LoopEmitter() = default; diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 2194d24257d0ccd04f3c9625412116eba01acd8c..07f989d4faea199e812e54d2ae74d3ff9e7fa19a 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -71,7 +72,7 @@ LocalService::LocalService(const ServiceOptions& options, StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const Shape* result_layout, int device_ordinal) { + const ExecutableBuildOptions& build_options) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(computation)); VersionedComputationHandle versioned_handle = @@ -112,14 +113,19 @@ StatusOr> LocalService::CompileExecutable( ShapeUtil::HumanString(argument_shape).c_str()); } } - if (result_layout != nullptr) { - TF_RETURN_IF_ERROR( - ValidateResultShapeWithLayout(*result_layout, program_shape->result())); + if (build_options.result_layout() != nullptr) { + TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( + *build_options.result_layout(), program_shape->result())); } ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (result_layout != nullptr) { - *execution_options.mutable_shape_with_output_layout() = *result_layout; + if (build_options.generate_hlo_graph().has_value()) { + execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( + build_options.generate_hlo_graph().value()); + } + if (build_options.result_layout() != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *build_options.result_layout(); } else { *execution_options.mutable_shape_with_output_layout() = program_shape->result(); @@ -128,13 +134,16 @@ StatusOr> LocalService::CompileExecutable( } TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, &execution_options)); + CreateModuleConfig(*program_shape, argument_layouts, &execution_options, + *user_computation)); - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - execute_backend_->stream_executor(device_ordinal)); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + execute_backend_->stream_executor(build_options.device_ordinal())); return BuildExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), executor); + execute_backend_.get(), executor, + build_options.device_allocator()); } StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index acbc7268252881958190f416ab936d64430166e1..15e120685e1be9190d49fdaf5ed6706bdf991a6c 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -41,11 +42,13 @@ class LocalService : public Service { // Builds an Executable with the given argument layouts and options. If // result_layout is non-null, then the executable is compiled to produce a - // result of the given layout. + // result of the given layout. If device_allocator is non-null, then the + // compiler may use it to allocate temp space on the device. The compiler is + // responsible for freeing any memory it allocates this way. StatusOr> CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const Shape* result_layout, int device_ordinal); + const ExecutableBuildOptions& options); // Returns the device ordinal that corresponds to the given replica number. // diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 926ebbe3140d631a3fb03f41c687ae72c58706f5..e278eab69088d3031b1d951734b7dcad6f8afc77 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -37,12 +37,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -56,6 +58,7 @@ namespace se = ::perftools::gputools; using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrCat; +using ::xla::source_map_util::InvalidParameterArgument; namespace xla { @@ -261,7 +264,8 @@ StatusOr> Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options) { + const ExecutionOptions* execution_options, + const UserComputation& user_computation) { auto config = MakeUnique(program_shape); auto* computation_layout = config->mutable_entry_computation_layout(); @@ -275,8 +279,10 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { - return InvalidArgument( - "computation expects parameter %d to have shape %s, given shape %s", + return InvalidParameterArgument( + *user_computation.ParameterMetadata(i).value(), + "Argument does not match shape of computation parameter %d: want %s, " + "got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } @@ -318,19 +324,22 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options) { + const ExecutionOptions& execution_options, + const UserComputation& user_computation) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); } - return CreateModuleConfig(program_shape, argument_shapes, &execution_options); + return CreateModuleConfig(program_shape, argument_shapes, &execution_options, + user_computation); } StatusOr>> Service::BuildExecutables( std::vector versioned_handles, std::vector> module_configs, Backend* backend, - std::vector> executors) { + std::vector> executors, + DeviceMemoryAllocator* device_allocator) { VLOG(1) << Printf("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. @@ -376,7 +385,8 @@ StatusOr>> Service::BuildExecutables( TF_ASSIGN_OR_RETURN( std::vector> executables, - backend->compiler()->Compile(std::move(modules), std::move(executors))); + backend->compiler()->Compile(std::move(modules), std::move(executors), + device_allocator)); for (size_t i = 0; i < versioned_handles.size(); ++i) { if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { @@ -389,8 +399,8 @@ StatusOr>> Service::BuildExecutables( StatusOr> Service::BuildExecutable( const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, - Backend* backend, se::StreamExecutor* executor) { + std::unique_ptr module_config, Backend* backend, + se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, versioned_handle.ToString().c_str()); @@ -423,11 +433,12 @@ StatusOr> Service::BuildExecutable( TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); TF_ASSIGN_OR_RETURN( - module, backend->compiler()->RunHloPasses(std::move(module), executor)); + module, backend->compiler()->RunHloPasses(std::move(module), executor, + device_allocator)); - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - backend->compiler()->RunBackend(std::move(module), executor)); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + backend->compiler()->RunBackend( + std::move(module), executor, device_allocator)); if (!other_directory_path.empty()) { executable->set_session_module(std::move(session_module)); @@ -438,9 +449,9 @@ StatusOr> Service::BuildExecutable( StatusOr> Service::BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, - Backend* backend, perftools::gputools::StreamExecutor* executor, - ExecutionProfile* profile) { + std::unique_ptr module_config, Backend* backend, + perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile, + DeviceMemoryAllocator* device_allocator) { std::shared_ptr executable = compilation_cache_.LookUp(versioned_handle, *module_config); @@ -462,7 +473,7 @@ StatusOr> Service::BuildAndCacheExecutable( TF_ASSIGN_OR_RETURN( std::unique_ptr executable_unique_ptr, BuildExecutable(versioned_handle, std::move(module_config), backend, - executor)); + executor, device_allocator)); if (profile != nullptr) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -742,9 +753,10 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Create an HloModuleConfig object for the computation, given the shape of // the program and the argument allocations. - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, - request.execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arguments, + request.execution_options(), *user_computation)); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -763,10 +775,14 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Build the user computations into HloModules and compile to generate the // executables. + // + // TODO(jlebar): There's currently no way to pass a device allocator to + // ExecuteParallel, so we have to pass a null device_allocator below. TF_ASSIGN_OR_RETURN( std::vector> executables, BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), all_executors)); + execute_backend_.get(), all_executors, + /*device_allocator=*/nullptr)); std::vector executable_ptrs; executable_ptrs.reserve(executables.size()); for (const auto& executable : executables) { @@ -852,7 +868,8 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, arg->execution_options())); + CreateModuleConfig(*program_shape, arguments, arg->execution_options(), + *user_computation)); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -916,7 +933,8 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, arg->execution_options())); + CreateModuleConfig(*program_shape, arguments, arg->execution_options(), + *user_computation)); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -1236,7 +1254,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, } TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options)); + CreateModuleConfig(program_shape, {}, execution_options, + *user_computation)); // Exclude dead parameter instructions for the purpose of computing constants. TF_ASSIGN_OR_RETURN( @@ -1427,6 +1446,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { case OpRequest::kFftRequest: handle_status = computation->AddFftInstruction(arg->fft_request()); break; + case OpRequest::kGatherRequest: + handle_status = computation->AddGatherInstruction(arg->gather_request()); + break; case OpRequest::kGetTupleElementRequest: handle_status = computation->AddGetTupleElementInstruction( arg->get_tuple_element_request()); @@ -1435,9 +1457,13 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddInfeedInstruction(arg->infeed_request()); break; case OpRequest::kOutfeedRequest: - TF_RETURN_IF_ERROR( - computation->AddOutfeedInstruction(arg->outfeed_request())); - return tensorflow::Status::OK(); + handle_status = + computation->AddOutfeedInstruction(arg->outfeed_request()); + break; + case OpRequest::kHostComputeRequest: + handle_status = + computation->AddHostComputeInstruction(arg->host_compute_request()); + break; case OpRequest::kMapRequest: { TF_ASSIGN_OR_RETURN( UserComputation * to_apply, @@ -1601,14 +1627,14 @@ StatusOr> Service::Replicas( } Status Service::MaybeDumpHloModule(const HloModule& module) const { - const string xla_dump_prepass_hlo_proto_to = - module.config().debug_options().xla_dump_prepass_hlo_proto_to(); - if (xla_dump_prepass_hlo_proto_to.empty()) { + const string xla_dump_unoptimized_hlo_proto_to = + module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); + if (xla_dump_unoptimized_hlo_proto_to.empty()) { return Status::OK(); } HloProto proto = MakeHloProto(module); return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_prepass_hlo_proto_to, module.name()); + proto, xla_dump_unoptimized_hlo_proto_to, module.name()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 0a7d0b3a7d25a1b046852c87d8463d0169080a5e..6ce241971156599aaa25aea1b0caac0e1bd5379c 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -251,7 +251,8 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options); + const ExecutionOptions& execution_options, + const UserComputation& user_computation); protected: friend class LocalExecutable; @@ -275,13 +276,19 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options); + const ExecutionOptions* execution_options, + const UserComputation& user_computation); // Builds an Executable for the given parameters. + // + // If device_allocator is not null, the compiler may use it to allocate temp + // buffers, which the compiler is responsible for freeing. The allocator + // given here need not match the allocator used when running the executable. StatusOr> BuildExecutable( const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, - Backend* backend, perftools::gputools::StreamExecutor* executor); + std::unique_ptr module_config, Backend* backend, + perftools::gputools::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator = nullptr); // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. @@ -289,16 +296,17 @@ class Service : public ServiceInterface { std::vector versioned_handles, std::vector> module_configs, Backend* backend, - std::vector> executors); + std::vector> executors, + DeviceMemoryAllocator* device_allocator); // Similar to BuildExecutable, but look in the compilation cache for the // executable first. If the executable is not in the cache, it is built and // inserted into the cache. StatusOr> BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, - Backend* backend, perftools::gputools::StreamExecutor* executor, - ExecutionProfile* profile); + std::unique_ptr module_config, Backend* backend, + perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile, + DeviceMemoryAllocator* device_allocator = nullptr); // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index a6d6c8b27f81045a4bee09e056c5c8f8e8a330c7..c9692757b27980b10a5ca562223c3d0f6462d820 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -37,6 +37,9 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" +using tensorflow::str_util::Join; +using tensorflow::strings::Printf; + namespace xla { namespace { @@ -206,7 +209,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, } // Check that init_value's shape is suitable for reducer_shape. - if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, + init_value_shape)) { return InvalidArgument( "Reduction function's accumulator shape differs from the " "init_value shape: %s vs %s", @@ -217,8 +221,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, // Check that the inputs can be passed in as the second argument. const Shape& input_element_shape = ShapeUtil::MakeShape(input_element_type, {}); - if (!ShapeUtil::Compatible(input_element_shape, - reducer_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape, + reducer_shape.parameters(1))) { return InvalidArgument( "Reduction function's second parameter shape differs from the " "input type element type: %s vs %s", @@ -228,7 +232,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, // Currently the accumulator and inputs must be the same type, // though that restriction could be relaxed. - if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, + reducer_shape.parameters(1))) { return InvalidArgument( "Reduction function's second parameter shape currently must " "match the result shape. Got %s vs %s", @@ -391,11 +396,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, dimension); } const Shape* arg_shape = nullptr; + PrimitiveType element_type = PRIMITIVE_TYPE_INVALID; for (const Shape* shape : arg_shapes) { TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(*shape, "operand of concatenation")); if (!arg_shape) { arg_shape = shape; + element_type = arg_shape->element_type(); continue; } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { @@ -406,7 +413,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape).c_str()); } - if (arg_shape->element_type() != shape->element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( "cannot concatenate arrays with different element types: %s vs %s", PrimitiveType_Name(arg_shape->element_type()).c_str(), @@ -428,6 +435,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(*shape).c_str(), dimension); } } + element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape); } std::vector new_dimensions(arg_shape->dimensions().begin(), @@ -435,7 +443,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, for (size_t i = 1; i < arg_shapes.size(); ++i) { new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension); } - return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions); + return ShapeUtil::MakeShape(element_type, new_dimensions); } /* static */ StatusOr ShapeInference::InferConvertShape( @@ -533,7 +541,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), padding_config.ShortDebugString().c_str()); } - if (operand_shape.element_type() != padding_value_shape.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, + padding_value_shape)) { return InvalidArgument( "the element types of the operands to pad do not match"); } @@ -545,7 +554,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, std::max(operand_shape.dimensions(i) - 1, 0LL) * padding_config.dimensions(i).interior_padding(); } - return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions); + return ShapeUtil::MakeShape( + ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), + dimensions); } // Current DotDimensionNumbers Requirements: @@ -670,7 +681,7 @@ Status ValidateDotDimensionNumbers( }; // Check if both element types are the same. - if (lhs.element_type() != rhs.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return fail("element types do not match"); } @@ -733,7 +744,8 @@ Status ValidateDotDimensionNumbers( dimensions.push_back(rhs.dimensions(i)); } } - Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions); + Shape result = ShapeUtil::MakeShape( + ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); @@ -764,7 +776,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::HumanString(rhs).c_str()); } } - return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions); + return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), + output_dimensions); } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( @@ -826,6 +839,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // specified in broadcast_dimensions are then changed to match the // corresponding dimension size in smaller_shape. Shape output_shape(larger_shape); + output_shape.set_element_type( + ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape)); for (int i = 0; i < smaller_shape.dimensions_size(); ++i) { int64 dimension_to_match = broadcast_dimensions.at(i); @@ -875,7 +890,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); - if (!ShapeUtil::SameElementType(lhs, rhs)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "binary op %s with different element types: %s and %s", BinaryOperation_Name(operation).c_str(), @@ -894,10 +909,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } } - if (ShapeUtil::Compatible(lhs, rhs)) { + if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) { // If the shapes are the same other than layout, the output shape is the // same (elementwise op). - return lhs; + return ShapeUtil::ChangeElementType( + lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs)); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { @@ -934,7 +950,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", BinaryOperation_Name(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), - tensorflow::str_util::Join(broadcast_dimensions, ", ").c_str()); + Join(broadcast_dimensions, ", ").c_str()); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ -970,7 +986,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions)); - if (lhs.element_type() == F32) { + if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); } else { return Unimplemented("complex component type not supported"); @@ -1075,12 +1091,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map")); - if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) { + if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { continue; } if (!ShapeUtil::IsTuple(*arg_shapes[i]) && !ShapeUtil::IsTuple(*arg_shape) && - ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) { + ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], + *arg_shape)) { if (ShapeUtil::IsScalar(*arg_shapes[i])) { continue; } @@ -1097,7 +1114,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Map operation requires all operands to have the same shape; got: " "%s", - tensorflow::str_util::Join(pieces, ", ").c_str()); + Join(pieces, ", ").c_str()); } // Check that dimensions.size == arg_shape.dimensions_size() (we currently @@ -1114,7 +1131,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (dimensions[i] != i) { return InvalidArgument( "Map requires monotonically increasing dimension numbers, found: %s ", - tensorflow::str_util::Join(dimensions, ", ").c_str()); + Join(dimensions, ", ").c_str()); } } @@ -1145,7 +1162,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( i, ShapeUtil::HumanString(parameter_shape).c_str()); } - if (parameter_shape.element_type() != arg_shape->element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, + *arg_shape)) { return InvalidArgument( "mapped computation's parameter type has to match argument element " "type; got parameter %d shape: %s, argument shape: %s", @@ -1218,7 +1236,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " @@ -1227,7 +1246,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " @@ -1326,7 +1346,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1336,7 +1357,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1346,7 +1368,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1356,7 +1379,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1478,7 +1502,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(output_grad_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(output_grad_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " @@ -1487,7 +1512,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " @@ -1496,7 +1522,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " @@ -1505,7 +1532,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(var_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " @@ -1566,7 +1594,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution")); - if (!ShapeUtil::SameElementType(lhs, rhs)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s", ShapeUtil::HumanString(lhs).c_str(), @@ -1711,8 +1739,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( dimensions[dnums.output_spatial_dimensions(i)] = window_output_shape.dimensions(i); } - - return ShapeUtil::MakeShape(lhs.element_type(), dimensions); + return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), + dimensions); } /* static */ StatusOr ShapeInference::InferFftShape( @@ -1874,16 +1902,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } const Shape& operand_element_shape = ShapeUtil::MakeShape(operand_shape.element_type(), {}); - if (!ShapeUtil::Compatible(operand_element_shape, - select_shape.parameters(0))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, + select_shape.parameters(0))) { return InvalidArgument( "select function's first parameter shape currently must " "match the operand element shape. Got %s vs %s", ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), ShapeUtil::HumanString(operand_element_shape).c_str()); } - if (!ShapeUtil::Compatible(operand_element_shape, - select_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, + select_shape.parameters(1))) { return InvalidArgument( "select function's second parameter shape currently must " "match the operand element shape. Got %s vs %s", @@ -1900,7 +1928,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( InferWindowOutputShape(operand_shape, window, operand_shape.element_type(), /*allow_negative_padding=*/false)); - if (!ShapeUtil::Compatible(source_shape, window_result_shape)) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape, + window_result_shape)) { return InvalidArgument( "source shape does not match the shape of window-reduced operand: " "source(%s), window-reduced operand(%s)", @@ -1914,21 +1943,28 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& arg, tensorflow::gtl::ArraySlice starts, tensorflow::gtl::ArraySlice limits, tensorflow::gtl::ArraySlice strides) { + auto error = [&](const string& message) { + return InvalidArgument( + "%s in slice operation; argument shape: %s; starts: {%s}; limits: " + "{%s}; strides: {%s}", + message.c_str(), ShapeUtil::HumanString(arg).c_str(), + Join(starts, ",").c_str(), Join(limits, ",").c_str(), + Join(strides, ",").c_str()); + }; TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s starts={%s} limits={%s}", - ShapeUtil::HumanString(arg).c_str(), - tensorflow::str_util::Join(starts, ", ").c_str(), - tensorflow::str_util::Join(limits, ", ").c_str()); + ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), + Join(limits, ", ").c_str()); if (starts.size() != limits.size()) { - return InvalidArgument("slice start and limit sizes differ: %zu vs %zu", - starts.size(), limits.size()); + return error(Printf("slice start and limit sizes differ: %zu vs %zu", + starts.size(), limits.size())); } if (starts.size() != strides.size()) { - return InvalidArgument("slice start and strides sizes differ: %zu vs %zu", - starts.size(), strides.size()); + return error(Printf("slice start and strides sizes differ: %zu vs %zu", + starts.size(), strides.size())); } if (starts.size() != ShapeUtil::Rank(arg)) { @@ -1947,20 +1983,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( start_index); } if (limit_index > arg.dimensions(dimension)) { - return InvalidArgument( - "limit index (%lld) must be less than or equal to dimension " - "size (%lld)", - limit_index, arg.dimensions(dimension)); + return error( + Printf("limit index (%lld) must be less than or equal to dimension " + "size (%lld)", + limit_index, arg.dimensions(dimension))); } VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension, start_index); VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, limit_index); if (start_index > limit_index) { - return InvalidArgument( - "limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive stride", - limit_index, start_index); + return error( + Printf("limit index (%lld) must be greater or equal to " + "start index (%lld) in slice with positive stride", + limit_index, start_index)); } if (stride <= 0) { return InvalidArgument("stride (%lld) must be positive", stride); @@ -1983,7 +2019,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", ShapeUtil::HumanString(operand_shape).c_str(), ShapeUtil::HumanString(start_indices_shape).c_str(), - tensorflow::str_util::Join(slice_sizes, ", ").c_str()); + Join(slice_sizes, ", ").c_str()); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( @@ -2076,7 +2112,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } - if (operand_shape.element_type() != update_shape.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, + update_shape)) { return InvalidArgument( "dynamic update slice update element type does not match argument. " "operand.element_type: %s vs update.element_type: %s", @@ -2280,8 +2317,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Reshape dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", - tensorflow::str_util::Join(dimensions, ",").c_str(), - ShapeUtil::HumanString(operand).c_str()); + Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str()); } return inferred_shape; @@ -2313,24 +2349,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); - if (!ShapeUtil::SameElementType(min, operand) || - !ShapeUtil::SameElementType(max, operand)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || + !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("clamp op with different operand types: %s, %s, %s", ShapeUtil::HumanString(min).c_str(), ShapeUtil::HumanString(operand).c_str(), ShapeUtil::HumanString(max).c_str()); } - if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) && - (ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) { + if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || + ShapeUtil::IsScalar(min)) && + (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) || + ShapeUtil::IsScalar(max)))) { return operand; } if (ShapeUtil::IsScalar(operand)) { - if (ShapeUtil::Compatible(min, max)) { - return min; + if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) { + return ShapeUtil::ChangeElementType(min, operand.element_type()); } else if (ShapeUtil::IsScalar(min)) { - return max; + return ShapeUtil::ChangeElementType(max, operand.element_type()); } else if (ShapeUtil::IsScalar(max)) { - return min; + return ShapeUtil::ChangeElementType(min, operand.element_type()); } } return Unimplemented( @@ -2343,7 +2381,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // broadcast from all operands, not just the predicate. /* static */ StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { - if (!ShapeUtil::Compatible(on_true, on_false)) { + bool compatible; + if (ShapeUtil::IsTuple(on_true)) { + // Select only defines the top-level buffer, so if it's a tuple, the two + // input must match exactly. + compatible = ShapeUtil::Compatible(on_true, on_false); + } else { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false); + } + if (!compatible) { return InvalidArgument( "operands to select must be the same shape; got %s and %s", ShapeUtil::HumanString(on_true).c_str(), @@ -2358,7 +2404,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. - return on_true; + return ShapeUtil::ChangeElementType( + on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); } else { return Unimplemented( "select operation with non-scalar predicate with dimensionality " @@ -2373,8 +2420,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { string computation_signature = ShapeUtil::HumanString(to_apply); - string argument_shapes = tensorflow::str_util::Join( - arg_shapes, ", ", [](string* out, const Shape* shape) { + string argument_shapes = + Join(arg_shapes, ", ", [](string* out, const Shape* shape) { tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape)); }); return InvalidArgument( @@ -2401,4 +2448,197 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return to_apply.result(); } +static Status ValidateGatherDimensionNumbers( + const Shape& input_shape, + tensorflow::gtl::ArraySlice gather_indices_shape, + const GatherDimensionNumbers& dim_numbers) { + if (!c_is_sorted(dim_numbers.output_window_dims())) { + return InvalidArgument( + "Output window dimensions in gather op must be ascending; got: %s", + Join(dim_numbers.output_window_dims(), ", ").c_str()); + } + + if (c_adjacent_find(dim_numbers.output_window_dims()) != + dim_numbers.output_window_dims().end()) { + return InvalidArgument( + "Output window dimensions in gather op must not repeat; got: %s", + Join(dim_numbers.output_window_dims(), ", ").c_str()); + } + + const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); + const int64 output_shape_rank = + output_window_dim_count + gather_indices_shape.size(); + + for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { + int64 window_index = dim_numbers.output_window_dims(i); + if (window_index < 0 || window_index >= output_shape_rank) { + return InvalidArgument( + "Window index %d in gather op is out of bounds; got %lld, but should " + "have been in" + "[0,%lld)", + i, window_index, output_shape_rank); + } + } + + if (dim_numbers.gather_dims_to_operand_dims_size() != + gather_indices_shape.back()) { + return InvalidArgument( + "There must be exactly as many elements in gather_dims_to_operand_dims " + "as there are elements in the last dimension of %%gather_indices; got: " + "%d, expected %lld", + dim_numbers.gather_dims_to_operand_dims_size(), + gather_indices_shape.back()); + } + + for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { + int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i); + if (gather_dim_to_input_dim < 0 || + gather_dim_to_input_dim >= input_shape.dimensions_size()) { + return InvalidArgument( + "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " + "got: %d->%lld", + input_shape.dimensions_size(), i, gather_dim_to_input_dim); + } + } + + std::vector sorted_gather_dims_to_operand_dims( + dim_numbers.gather_dims_to_operand_dims().begin(), + dim_numbers.gather_dims_to_operand_dims().end()); + + c_sort(sorted_gather_dims_to_operand_dims); + + if (c_adjacent_find(sorted_gather_dims_to_operand_dims) != + sorted_gather_dims_to_operand_dims.end()) { + return InvalidArgument( + "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " + "got: %s", + Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); + } + + for (int64 elided_dim : dim_numbers.elided_window_dims()) { + if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { + return InvalidArgument( + "Invalid elided_window_dims set in gather op; valid range is [0, " + "%d), got: %lld", + input_shape.dimensions_size(), elided_dim); + } + } + + if (!c_is_sorted(dim_numbers.elided_window_dims())) { + return InvalidArgument( + "elided_window_dims in gather op must be sorted; got: %s", + Join(dim_numbers.elided_window_dims(), ", ").c_str()); + } + + if (c_adjacent_find(dim_numbers.elided_window_dims()) != + dim_numbers.elided_window_dims().end()) { + return InvalidArgument( + "Repeated dimensions not allowed in elided_window_dims in gather op; " + "got: %s", + Join(dim_numbers.elided_window_dims(), ", ").c_str()); + } + + return Status::OK(); +} + +/*static*/ StatusOr ShapeInference::InferGatherShape( + const Shape& input_shape, const Shape& gather_indices_shape, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + gather_indices_shape, "gather indices operand of gather op")); + + if (gather_indices_shape.dimensions_size() < 1) { + return InvalidArgument( + "Gather indices parameter must at least of rank 1; got %s", + ShapeUtil::HumanString(gather_indices_shape).c_str()); + } + + if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + return InvalidArgument( + "Gather indices parameter must be an integral tensor; got %s", + ShapeUtil::HumanString(gather_indices_shape).c_str()); + } + + std::vector expanded_gather_indices_shape; + // We implicitly reshape gather indices of shape P[N] to P[N,1]. + expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); + c_copy(gather_indices_shape.dimensions(), + std::back_inserter(expanded_gather_indices_shape)); + if (expanded_gather_indices_shape.size() == 1) { + expanded_gather_indices_shape.push_back(1); + } + + TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( + input_shape, expanded_gather_indices_shape, gather_dim_numbers)); + + if (window_bounds.size() != input_shape.dimensions_size()) { + return InvalidArgument( + "Gather op must have one window bound for every input dimension; got: " + "len(window_bounds)=%lu, input_shape.rank=%d", + window_bounds.size(), input_shape.dimensions_size()); + } + + if (window_bounds.size() != + gather_dim_numbers.output_window_dims_size() + + gather_dim_numbers.elided_window_dims_size()) { + return InvalidArgument( + "All components of the window index in a gather op must either be a " + "output window index or explicitly elided; got len(window_bounds)=%lu, " + "output_window_bounds=%s, elided_window_bounds=%s", + window_bounds.size(), + Join(gather_dim_numbers.output_window_dims(), ",").c_str(), + Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); + } + + for (int i = 0; i < window_bounds.size(); i++) { + int64 window_bound = window_bounds[i]; + int64 corresponding_input_bound = input_shape.dimensions(i); + if (window_bound < 0 || window_bound > corresponding_input_bound) { + return InvalidArgument( + "Window bound at index %d in gather op is out of range, must be " + "within " + "[0, %lld), got %lld", + i, corresponding_input_bound + 1, window_bound); + } + } + + for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) { + if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { + return InvalidArgument( + "Gather op can only elide window indices with bound 1, but bound is " + "%lld for index %lld at position %d", + window_bounds[gather_dim_numbers.elided_window_dims(i)], + gather_dim_numbers.elided_window_dims(i), i); + } + } + + int64 result_rank = gather_dim_numbers.output_window_dims_size() + + (expanded_gather_indices_shape.size() - 1); + int64 window_dims_seen = 0; + int64 gather_dims_seen = 0; + std::vector output_dim_bounds; + output_dim_bounds.reserve(result_rank); + for (int64 i = 0; i < result_rank; i++) { + int64 current_bound; + bool is_window_index = + c_binary_search(gather_dim_numbers.output_window_dims(), i); + if (is_window_index) { + while (c_binary_search(gather_dim_numbers.elided_window_dims(), + window_dims_seen)) { + window_dims_seen++; + } + current_bound = window_bounds[window_dims_seen++]; + } else { + current_bound = expanded_gather_indices_shape[gather_dims_seen++]; + } + + output_dim_bounds.push_back(current_bound); + } + + return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index b39151ebbc19f5d0b702a80da5069f58c8dfb07d..0d3045213db2230da3e18ffcb1a9923250560b64 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -37,6 +37,11 @@ namespace xla { // the expected result type for computations that are built up via the API -- // the shape that results from an operation is inferred. Some methods have // overloads for inferring shape at the HLO level. +// +// TODO(b/73352135): Shape inference does not issue very good error messages, in +// part because HloInstruction::ToString() is not available since shape +// inference runs before the HloInstruction object is created. We need a +// solution for this. class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the @@ -248,6 +253,14 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers); + // Helper that infers the shape of the tensor produced by a gather operation + // with the given input shape, gather indices shape and gather dimension + // numbers. + static StatusOr InferGatherShape( + const Shape& input_shape, const Shape& gather_indices_shape, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 99d87f3b550ae72befe254f23fad080dd210aaf4..7eb120843fd841d841048eeaefd895fde96d133c 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -18,15 +18,16 @@ limitations under the License. #include #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { +using ::tensorflow::gtl::ArraySlice; using ::testing::ContainsRegex; using ::testing::HasSubstr; @@ -1512,5 +1513,356 @@ TEST_F(ShapeInferenceTest, Conditional) { "must have the same shape")); } +TEST_F(ShapeInferenceTest, BadSlice) { + auto arg = ShapeUtil::MakeShape(F32, {4}); + StatusOr statusor = + ShapeInference::InferSliceShape(arg, {0}, {5}, {1}); + ASSERT_FALSE(statusor.ok()); + + LOG(INFO) << statusor.status(); + + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("less than or equal to dimension size")) + << statusor.status(); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("argument shape")) + << statusor.status(); +} + +class GatherShapeInferenceTest : public ShapeInferenceTest { + protected: + const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32}); + const Shape s64_4d_tensor_10_9_8_7_1_ = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}); + const Shape s64_4d_tensor_10_9_8_7_5_ = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + const Shape f32_5d_tensor_50_49_48_47_46_ = + ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); +}; + +TEST_F(GatherShapeInferenceTest, TensorFlowGather) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}), + /*window_bounds=*/{64, 1})); + EXPECT_TRUE( + ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{1}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}), + /*window_bounds=*/{1, 48})); + EXPECT_TRUE( + ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}), + /*window_bounds=*/{1, 48})); + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26})); + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { + StatusOr statusor = ShapeInference::InferGatherShape( + tuple_shape_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected non-tuple argument for input")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { + StatusOr statusor = ShapeInference::InferGatherShape( + s64_vector_32_, tuple_shape_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected non-tuple argument for gather indices")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) { + StatusOr statusor = ShapeInference::InferGatherShape( + s64_vector_32_, s32_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather indices parameter must at least of rank 1")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { + StatusOr statusor = ShapeInference::InferGatherShape( + s64_vector_32_, vector_32_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather indices parameter must be an integral tensor")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_NonAscendingWindowIndices) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 8, 7}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Output window dimensions in gather op must be ascending")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedWindowIndices) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 7}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Output window dimensions in gather op must not repeat")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowIndexOutOfBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 99, 100, 101}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window index 2 in gather op is out of bounds")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingElidedWindowDims) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{4}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("All components of the window index in a gather op must either " + "be a output window index or explicitly elided")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{0, 1, 2, 3, 19}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid elided_window_dims set in gather op; valid " + "range is [0, 5), got: 19")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{0, 1, 2, 3, 3}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions not allowed in elided_window_dims in gather op")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "There must be exactly as many elements in " + "gather_dims_to_operand_dims " + "as there are elements in the last dimension of %gather_indices")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is " + "[0, 5), got: 4->7")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions are not allowed in gather_dims_to_operand_dims")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{2, 1}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{1, 1, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("elided_window_dims in gather op must be sorted")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7}, + /*elided_window_dims=*/{2}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 1, 300, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window bound at index 3 in gather op is out of range, " + "must be within [0, 48), got 300")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Gather op must have one window bound for every input dimension")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 26, 20}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather op can only elide window indices with bound 1, " + "but bound is 29 for index 1 at position 0")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index c679d401c3691b14a43ce77cbe953cd4c64a9e92..6e9986165f7eaf71a964b42b734a5ae5db5e45d7 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -41,7 +41,32 @@ ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, on_device_shape_(on_device_shape), platform_(platform), device_ordinal_(device_ordinal), - buffers_(on_device_shape) {} + buffers_(&on_device_shape_) {} + +ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) + : on_host_shape_(std::move(s.on_host_shape_)), + on_device_shape_(std::move(s.on_device_shape_)), + platform_(s.platform_), + device_ordinal_(s.device_ordinal_), + buffers_(std::move(s.buffers_)) { + // s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_ + // into buffers_, we also need to update this pointer so that buffers_ doesn't + // point into s. + buffers_.replace_shape_ptr(&on_device_shape_); +} + +ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { + on_host_shape_ = std::move(s.on_host_shape_); + on_device_shape_ = std::move(s.on_device_shape_); + platform_ = s.platform_; + device_ordinal_ = s.device_ordinal_; + buffers_ = std::move(s.buffers_); + // buffers_ has a pointer to its on_device_shape_. When we move s.buffers_ + // into buffers_, we also need to update this pointer so that buffers_ doesn't + // point into s. + buffers_.replace_shape_ptr(&on_device_shape_); + return *this; +} void ShapedBuffer::clear() { for (auto& pair : buffers_) { @@ -99,6 +124,10 @@ ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape, device_ordinal), allocator_(allocator) {} +ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer, + DeviceMemoryAllocator* allocator) + : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {} + ScopedShapedBuffer::~ScopedShapedBuffer() { // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what @@ -116,12 +145,8 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { } std::unique_ptr ScopedShapedBuffer::release() { - auto shaped_buffer = MakeUnique( - on_host_shape(), on_device_shape(), platform(), device_ordinal()); - - shaped_buffer->buffers() = buffers(); - clear(); - + auto shaped_buffer = MakeUnique(std::move(*this)); + buffers_ = ShapeTree(); return shaped_buffer; } diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index d397e47d2ca734458c7dc99baa5c81b16d0fd72b..b816df8385ef65b0b69ede1d6e65a1991b4bd7c6 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -87,18 +87,24 @@ class ShapedBuffer { string ToString() const; + ShapedBuffer(ShapedBuffer&& s); + ShapedBuffer& operator=(ShapedBuffer&&); + protected: + ShapedBuffer(const ShapedBuffer&) = delete; + ShapedBuffer& operator=(const ShapedBuffer&) = delete; + // The shape of the data when represented on the host. - const Shape on_host_shape_; + Shape on_host_shape_; // The shape of the data on the device. - const Shape on_device_shape_; + Shape on_device_shape_; // The platform the memory is allocated on. const perftools::gputools::Platform* platform_; // The device the memory is allocated on. - const int device_ordinal_; + int device_ordinal_; // The tree of device buffers. Its shape is on_device_shape(). ShapeTree buffers_; @@ -121,14 +127,20 @@ class ScopedShapedBuffer : public ShapedBuffer { ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, DeviceMemoryAllocator* allocator, int device_ordinal); + // Create a ScopedShapedBuffer by taking over the memory from the incoming + // ShapedBuffer. + ScopedShapedBuffer(ShapedBuffer shaped_buffer, + DeviceMemoryAllocator* allocator); + // Return the allocator used to allocate the device memory held in this // ScopedShapedBuffer. DeviceMemoryAllocator* memory_allocator() const { return allocator_; } - // Release all device memory owned by this ScopedShapedBuffer and return the - // device memory pointers in the form of a ShapedBuffer. Device memory - // pointers in this ScopedShapedBuffer object are set to null. This method is - // analogous to std::unique_ptr::release(). + // Release all device memory owned by this ScopedShapedBuffer and + // return the device memory pointers in the form of a + // ShapedBuffer. The returned ShapedBuffer takes over the memory + // from the ScopedShapedBuffer. The resulting ScopedShapedBuffer can + // only be destroyed. std::unique_ptr release(); // All buffers in the shape are deallocated on destruction. diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cbaac7b3760717bcacb57adc8782a5755c0aa6d --- /dev/null +++ b/tensorflow/compiler/xla/service/source_map_util.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/source_map_util.h" + +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace source_map_util { +namespace { + +Status InvalidParameterArgumentV(const OpMetadata& op_metadata, + const char* format, va_list args) { + string message; + tensorflow::strings::Appendv(&message, format, args); + if (!op_metadata.source_file().empty()) { + tensorflow::strings::Appendf(&message, " (%s:%d)", + op_metadata.source_file().c_str(), + op_metadata.source_line()); + } + return InvalidArgument("%s", message.c_str()); +} + +} // namespace + +Status InvalidParameterArgument(const OpMetadata& op_metadata, + const char* format, ...) { + va_list args; + va_start(args, format); + Status result = InvalidParameterArgumentV(op_metadata, format, args); + va_end(args); + return result; +} + +Status InvalidParameterArgument(Executable* executable, int parameter_number, + const char* format, ...) { + va_list args; + va_start(args, format); + if (executable != nullptr && executable->has_module()) { + const HloModule& module = executable->module(); + const HloComputation& computation = *module.entry_computation(); + HloInstruction* param = computation.parameter_instruction(parameter_number); + const OpMetadata& metadata = param->metadata(); + Status result = InvalidParameterArgumentV(metadata, format, args); + va_end(args); + return result; + } + Status result = InvalidArgumentV(format, args); + va_end(args); + return result; +} + +} // namespace source_map_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h new file mode 100644 index 0000000000000000000000000000000000000000..a776d745f4e56ca4f3d2480740259832bbc85011 --- /dev/null +++ b/tensorflow/compiler/xla/service/source_map_util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ + +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { +namespace source_map_util { + +// Creates an INVALID_ARUGMENT status with the given format string. +// +// Also, attempts to extract the OpMetadata for parameter_number on executable +// and append it to the status message for source mapping to user code. +// +// executable may be nullptr, but parameter_number should not be out of bounds +// or a CHECK-failure may occur. +Status InvalidParameterArgument(Executable* executable, int parameter_number, + const char* format, ...) + TF_PRINTF_ATTRIBUTE(3, 4); + +// As above, but takes the parameter metadata directly instead of extracting it +// from the executable. +Status InvalidParameterArgument(const OpMetadata& op_metadata, + const char* format, ...) + TF_PRINTF_ATTRIBUTE(2, 3); + +} // namespace source_map_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 2ea6507900e712200ce43e9b63577a4967381fdf..4a55e4095aa92cbdcd1bcb585dc851b2c5e9a32c 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -315,6 +315,36 @@ StatusOr UserComputation::AddConstantInstruction( return handle; } +StatusOr UserComputation::AddGatherInstruction( + const GatherRequest& gather_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* input_request, + LookUpRequest(gather_request.input())); + TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request, + LookUpRequest(gather_request.gather_indices())); + + TF_ASSIGN_OR_RETURN( + Shape shape, + ShapeInference::InferGatherShape( + input_request->output_shape(), gather_indices_request->output_shape(), + gather_request.dimension_numbers(), + AsInt64Slice(gather_request.window_bounds()))); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_gather_request() = gather_request; + + VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << gather_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddGetTupleElementInstruction( const GetTupleElementRequest& get_tuple_element_request) { tensorflow::mutex_lock lock(mutex_); @@ -1185,7 +1215,7 @@ StatusOr UserComputation::AddInfeedInstruction( return handle; } -Status UserComputation::AddOutfeedInstruction( +StatusOr UserComputation::AddOutfeedInstruction( const OutfeedRequest& outfeed_request) { tensorflow::mutex_lock lock(mutex_); @@ -1197,8 +1227,6 @@ Status UserComputation::AddOutfeedInstruction( // Verify that operand is valid. TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); - // No handle is returned, but a handle must be assigned to this instruction - // for computation versioning. ComputationDataHandle handle = CreateComputationDataHandle(); OperationRequest& request = (*session_computation_.mutable_requests())[handle.handle()]; @@ -1209,7 +1237,7 @@ Status UserComputation::AddOutfeedInstruction( VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() << "), data handle " << handle.handle() << ": " << outfeed_request.ShortDebugString(); - return Status::OK(); + return handle; } StatusOr UserComputation::AddCallInstruction( @@ -1278,6 +1306,28 @@ StatusOr UserComputation::AddCustomCallInstruction( return handle; } +StatusOr UserComputation::AddHostComputeInstruction( + const HostComputeRequest& host_compute_request) { + tensorflow::mutex_lock lock(mutex_); + + for (const ComputationDataHandle& handle : host_compute_request.operands()) { + TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); + } + + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = host_compute_request.shape(); + *request.mutable_request()->mutable_host_compute_request() = + host_compute_request; + + VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << host_compute_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddDotInstruction( const DotRequest& dot_request) { tensorflow::mutex_lock lock(mutex_); @@ -1715,6 +1765,11 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kHostComputeRequest: { + *is_functional = false; + break; + } + case OpRequest::kCallRequest: { const CallRequest& call_request = request.request().call_request(); for (const ComputationDataHandle& handle : call_request.operands()) { @@ -1993,12 +2048,25 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kGatherRequest: { + PureFunctionalVisitor(session_computation, + request.request().gather_request().input(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + request.request().gather_request().gather_indices(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; default: LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); } + if (!*is_functional) { + VLOG(1) << "Non-functional: " << request.request().DebugString(); + } visited->insert(handle.handle()); } @@ -2642,6 +2710,15 @@ static void ForEachOperand( break; } + case OpRequest::kHostComputeRequest: { + const HostComputeRequest& hc_request = + request.request().host_compute_request(); + for (const ComputationDataHandle& operand : hc_request.operands()) { + apply(operand); + } + break; + } + case OpRequest::kDotRequest: { const DotRequest& dot_request = request.request().dot_request(); apply(dot_request.rhs()); @@ -2683,6 +2760,13 @@ static void ForEachOperand( break; } + case OpRequest::kGatherRequest: { + const GatherRequest& gather_request = request.request().gather_request(); + apply(gather_request.input()); + apply(gather_request.gather_indices()); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; @@ -3298,6 +3382,22 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kHostComputeRequest: { + const HostComputeRequest& host_compute_request = + request.request().host_compute_request(); + std::vector operands; + for (const ComputationDataHandle& operand : + host_compute_request.operands()) { + operands.push_back(lookup_instruction(operand)); + } + auto output_shape = host_compute_request.shape(); + auto channel_name = host_compute_request.channel_name(); + auto cost_estimate_ns = host_compute_request.cost_estimate_ns(); + hlo_instruction = add_instruction(HloInstruction::CreateHostCompute( + output_shape, operands, channel_name, cost_estimate_ns)); + break; + } + case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); @@ -3400,6 +3500,20 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kGatherRequest: { + const GatherRequest& gather_request = request.request().gather_request(); + HloInstruction* input_operand = + lookup_instruction(gather_request.input()); + HloInstruction* gather_indices_operand = + lookup_instruction(gather_request.gather_indices()); + std::vector window_bounds; + c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds)); + hlo_instruction = add_instruction(HloInstruction::CreateGather( + request.output_shape(), input_operand, gather_indices_operand, + gather_request.dimension_numbers(), window_bounds)); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 4f92e58877a1d06728fdd250744ca2ce7b57d9ad..fd5a2ace9bacf66727dc91b6d96305424771a99b 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -146,7 +146,12 @@ class UserComputation { const InfeedRequest& infeed_request); // Enqueues an outfeed instruction onto this user computation. - Status AddOutfeedInstruction(const OutfeedRequest& outfeed_request); + StatusOr AddOutfeedInstruction( + const OutfeedRequest& outfeed_request); + + // Enqueues a host compute instruction onto this user computation. + StatusOr AddHostComputeInstruction( + const HostComputeRequest& host_compute_request); // Enqueues a call instruction onto this user computation. StatusOr AddCallInstruction( @@ -237,6 +242,10 @@ class UserComputation { StatusOr AddRecvInstruction( const RecvRequest& recv_request); + // Enqueues a Gather instruction onto this user computation. + StatusOr AddGatherInstruction( + const GatherRequest& gather_request); + // Returns the user-provided name of this user computation, which is provided // via the XLA computation-building API. const string& name() const { return name_; } diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index ca02115863e6906ef709ba63259024877e0dcef4..2fa163953f638c0038e9f6bb11ce2a3742e0558c 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -67,7 +67,8 @@ TEST_F(UserComputationTest, SimpleComputation) { *outfeed_request.mutable_operand() = constant_handle; *outfeed_request.mutable_shape() = kVectorShape; outfeed_request.set_outfeed_config("abc"); - TF_ASSERT_OK(computation.AddOutfeedInstruction(outfeed_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, + computation.AddOutfeedInstruction(outfeed_request)); auto hlo_resolver = [](const VersionedComputationHandle& handle) { return nullptr; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 87a7f86f4ec9844de3e350d7774093dd6248dd83..981de9b2200a9ae8938db21299580f510834d2f0 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -564,9 +564,11 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // // This is not a fundamental limitation. The control operands can be moved // onto the new HLOs after simplification, and any side-effecting ops inside - // the loop aren't removed, just cloned and added back to the loop. - // Nevertheless our infrastructure sees loop simplification as removal of - // these nodes and currently doesn't allow it. + // the loop aren't removed, just cloned and added back to the loop. But + // moving an op out of the loop also removes implicit control dependencies + // between the op and the ops outside the loop, so we'd have to add those back + // for things like infeed/outfeed. It gets complicated. So for now we just + // avoid it. if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Not attempting to remove while loop it is not removable: " << while_op->ToShortString(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index d752619bd65751779c24f061e44e206d66b01465..280f02e88675381bd75108bfae0dd22c462ba718 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -143,6 +143,18 @@ class ShapeTree { // Return the shape represented with this ShapeTree. const Shape& shape() const { return *shape_; } + // Replaces *only* the underlying shape of this ShapeTree. The caller must own + // the Shape object and hence shape_storage_ is not updated. + // + // Only safe to use this if the ShapeTree was constructed with 'explicit + // ShapeTree(const Shape* shape)' or is moved from one such ShapeTree. The + // caller must ensure that the input shape is consistent with the underlying + // tree. + void replace_shape_ptr(const Shape* shape) { + CHECK(shape_storage_.get() == nullptr); + shape_ = shape; + } + // Returns true if the node at the given index is a leaf node (an array // shape). bool IsLeaf(const ShapeIndex& index) const { diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index cba73322fa924785fbc73a4e931b5f27227d89b9..604e0173e789348923316174873f58058eaf2815 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -475,8 +475,6 @@ StatusOr StringToPrimitiveType(const string& name) { if (LayoutUtil::HasLayout(shape)) { tensorflow::strings::StrAppend(&result, LayoutUtil::HumanString(shape.layout())); - } else { - tensorflow::strings::StrAppend(&result, "{no layout}"); } } return result; @@ -632,6 +630,19 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return SameDimensions(lhs, rhs); } +/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, + const Shape& rhs) { + if (lhs.element_type() == TUPLE) { + return rhs.element_type() == TUPLE && + ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringFpPrecision); + } + if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { + return CompatibleIgnoringElementType(lhs, rhs); + } + return false; +} + /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, int64 dimension_number) { return shape.dimensions(GetDimensionNumber(shape, dimension_number)); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 453d4ec04726a4dd3851b8becb439bb7506e4ca9..19b1aa93bd373ebd5f502d0dca56c9b31ab4fd7f 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -62,6 +63,9 @@ class ShapeIndex { void push_back(int64 value) { indices_.push_back(value); } void pop_back() { indices_.pop_back(); } + // push_front is O(n^2), but shapes don't usually have a ton of dimensions. + void push_front(int64 value) { indices_.insert(indices_.begin(), value); } + std::vector::const_iterator begin() const { return indices_.begin(); } std::vector::const_iterator end() const { return indices_.end(); } std::vector::iterator begin() { return indices_.begin(); } @@ -211,6 +215,31 @@ class ShapeUtil { return lhs.element_type() == rhs.element_type(); } + // As SameElementType, but allows floating point types to have different + // precisions. + static bool SameElementTypeIgnoringFpPrecision(const Shape& a, + const Shape& b) { + if (ElementIsFloating(a) && ElementIsFloating(b)) { + return true; + } + return ShapeUtil::SameElementType(a, b); + } + + // Returns the higher-precision element type if a and b are both floating + // point types; otherwise, checks that that they have the same element type + // and returns it. + static PrimitiveType HigherPrecisionElementType(const Shape& a, + const Shape& b) { + if (SameElementType(a, b)) { + return a.element_type(); + } + CHECK(SameElementTypeIgnoringFpPrecision(a, b)); + return primitive_util::BitWidth(a.element_type()) < + primitive_util::BitWidth(b.element_type()) + ? b.element_type() + : a.element_type(); + } + // Returns true if the rank, dimension sizes, and element type are // identical. Layout is ignored. Tuple elements are compared recursively for // compatibility. @@ -221,6 +250,10 @@ class ShapeUtil { // compatibility. static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs); + // As Compatible, but allow one of lhs and rhs to be BF16 while the other + // being F32. Tuple elements are compared recursively for compatibility. + static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 81ba7afb95265398e830e26122cd0056a32daee3..4db97d45b20b86dc60531845c6e28a223203ff7f 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -170,6 +170,18 @@ TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) { EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2)); } +TEST(ShapeUtilTest, CompatibleIgnoringFpPrecision) { + Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2}); + Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); + ASSERT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); +} + +TEST(ShapeUtilTest, IncompatibleIgnoringFpPrecision) { + Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2}); + Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2}); + ASSERT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); +} + TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2}); @@ -184,6 +196,14 @@ TEST(ShapeUtilTest, CompatibleTuples) { EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2)); } +TEST(ShapeUtilTest, CompatibleTuplesIgnoringFpPrecision) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(BF16, {3, 2}), ShapeUtil::MakeShape(F32, {4, 5})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})}); + EXPECT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2)); +} + TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); @@ -193,6 +213,14 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); } +TEST(ShapeUtilTest, IncompatibleTuplesIgnoringFpPrecision) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(BF16, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})}); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2)); +} + TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 3afd52b6b2573aaecb125ad6e5bd05b41a1fbc68..8339d08ef4d7455f9739b80074ab0405a404e8e8 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -351,6 +351,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:platform_util", @@ -574,9 +575,31 @@ xla_test( ], ) +xla_test( + name = "exhaustive_f32_elementwise_op_test", + srcs = ["exhaustive_f32_elementwise_op_test.cc"], + backends = [ + "cpu", + "gpu", + ], + shard_count = 48, + tags = [ + "enormous", + "manual", + "notap", + ], + deps = [ + ":client_library_test_base", + ":literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + xla_test( name = "reduce_precision_test", srcs = ["reduce_precision_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -599,6 +622,9 @@ xla_test( xla_test( name = "dot_operation_test", srcs = ["dot_operation_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -621,6 +647,9 @@ xla_test( xla_test( name = "dot_operation_runtime_test", srcs = ["dot_operation_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -848,7 +877,8 @@ xla_test( name = "half_test", srcs = ["half_test.cc"], backends = [ - "cpu", + # TODO(b/72509305): Flaky (fails with SEGV) as of 2018-01-25 + # "cpu", "gpu", ], deps = [ @@ -1034,7 +1064,10 @@ xla_test( name = "select_and_scatter_test", timeout = "long", srcs = ["select_and_scatter_test.cc"], - tags = ["enable_for_xla_interpreter"], + tags = [ + "enable_for_xla_interpreter", + "optonly", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -1560,6 +1593,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1586,6 +1620,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 56fc21d019bb823f8f4631420a15fd607ef46a9a..7e9005001db34d403ea923eb9c152d114bf32803 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -1879,20 +1879,73 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { auto min_scalar = builder.ConstantR0(0.0f); auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); - auto arg_scalar = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto max_scalar = builder.ConstantR0(3.0f); auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. auto clamp = builder.Add( builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_scalar, max_vector), - builder.Clamp(min_scalar, arg_scalar, max_vector))); + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); - ComputeAndCompareR1(&builder, {8.0f, 4.5f, 2.0f, 6.5f, 15.0f}, {}, + ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { + ComputationBuilder builder(client_, TestName()); + auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0, -5}); + auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4, 10}); + auto max_vector = builder.ConstantR1({3, 0, 25, 5, 123, -1}); + auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + + ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { + ComputationBuilder builder(client_, TestName()); + auto min_scalar = builder.ConstantR0(0); + auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0}); + auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4}); + auto max_scalar = builder.ConstantR0(3); + auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + // Perform clamp with broadcasted scalar and vector. + auto clamp = builder.Add( + builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); + + ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { + ComputationBuilder builder(client_, TestName()); + auto min_vector = builder.ConstantR1({1, 2, 1, 2, 0, ~0u - 4}); + auto arg_vector = builder.ConstantR1({2, 10, 5, 1, 4, 10}); + auto max_vector = builder.ConstantR1({3, 5, 25, 5, 123, ~0u}); + auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + + ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { + ComputationBuilder builder(client_, TestName()); + auto min_scalar = builder.ConstantR0(0); + auto min_vector = builder.ConstantR1({1, 0, 1, 2, 0}); + auto arg_vector = builder.ConstantR1({2, 10, 0, 1, 4}); + auto max_scalar = builder.ConstantR0(3); + auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + // Perform clamp with broadcasted scalar and vector. + auto clamp = builder.Add( + builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); + + ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { ComputationBuilder builder(client_, TestName()); @@ -1995,47 +2048,117 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that // the input tensor is large enough to exercise the vectorized tanh - // implementation. - ComputationBuilder builder(client_, TestName()); - auto input_literal = Literal::CreateR2( - {{1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80}, - {-0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25}, - {0.41, 0.65, -1.08, 0.32, -1.45, -0.77, -1.09, 0.91}, - {-1.03, -0.30, -1.11, -1.17, 1.50, -0.85, 0.04, 1.02}, - {0.34, -0.61, 0.41, 0.07, -0.02, 1.42, -0.62, 0.81}, - {0.08, 0.81, -0.30, 1.17, -0.65, -0.44, 0.92, 1.26}, - {-1.29, 1.35, 0.08, -1.24, -0.92, 0.49, 1.17, -0.45}, - {-1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05}}); - auto input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + // implementation on XLA CPU. + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateR1( + {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, + -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, + -1.45, -0.77, -1.09, 0.91, -1.03, -0.30, -1.11, -1.17, 1.50, -0.85, + 0.04, 1.02, 0.34, -0.61, 0.41, 0.07, -0.02, 1.42, -0.62, 0.81, + 0.08, 0.81, -0.30, 1.17, -0.65, -0.44, 0.92, 1.26, -1.29, 1.35, + 0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31, + -0.79, 1.41, 1.21, 1.05}); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + client_->TransferToServer(*input_literal)); auto input = builder.Parameter(0, input_literal->shape(), "input"); builder.Tanh(input); - ComputeAndCompareR2( + ComputeAndCompareR1( &builder, - {{0.77009583, -0.30665702, 0.69070244, 0.71401149, 0.84400684, - -0.71985596, -0.45764771, 0.66664988}, - {-0.58278900, 0.16050975, -0.06770509, 0.36843640, -0.38476998, - 0.04018109, 0.87562293, 0.84788644}, - {0.38603750, 0.57294142, -0.79140943, 0.31032649, -0.89590985, - -0.64770776, -0.79625875, 0.72234446}, - {-0.77389336, -0.28871772, -0.80428445, -0.82541436, 0.90456349, - -0.68856895, 0.03877772, 0.76877952}, - {0.32561871, -0.54546672, 0.39072621, 0.07273290, -0.01924866, - 0.88924897, -0.55283129, 0.67183107}, - {0.08006320, 0.66944766, -0.29068485, 0.82573754, -0.57170743, - -0.41581789, 0.72739530, 0.85025692}, - {-0.85931867, 0.87357593, 0.07782833, -0.84597743, -0.72748238, - 0.45396307, 0.82449573, -0.42462519}, - {-0.86363792, -0.89368379, -0.12621804, -0.86445558, -0.65565848, - 0.88789743, 0.83566397, 0.78287679}}, + {0.77009583, -0.30665702, 0.69070244, 0.71401149, 0.84400684, + -0.71985596, -0.45764771, 0.66664988, -0.58278900, 0.16050975, + -0.06770509, 0.36843640, -0.38476998, 0.04018109, 0.87562293, + 0.84788644, 0.38603750, 0.57294142, -0.79140943, 0.31032649, + -0.89590985, -0.64770776, -0.79625875, 0.72234446, -0.77389336, + -0.28871772, -0.80428445, -0.82541436, 0.90456349, -0.68856895, + 0.03877772, 0.76877952, 0.32561871, -0.54546672, 0.39072621, + 0.07273290, -0.01924866, 0.88924897, -0.55283129, 0.67183107, + 0.08006320, 0.66944766, -0.29068485, 0.82573754, -0.57170743, + -0.41581789, 0.72739530, 0.85025692, -0.85931867, 0.87357593, + 0.07782833, -0.84597743, -0.72748238, 0.45396307, 0.82449573, + -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558, + -0.65565848, 0.88789743, 0.83566397, 0.78287679}, {input_data.get()}, // The error spec is unusually high here to account for the fact that we // use a rational interpolant to approximate tanh. ErrorSpec(0.004, 0.004)); } +XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { + // The input tensor is large enough to exercise the vectorized exp + // implementation on XLA CPU. + ComputationBuilder builder(client_, TestName()); + + // Just to help make sense of the scales here -- exp(89) saturates float32 and + // exp(-10) is smaller than our error spec. + std::unique_ptr input_literal = Literal::CreateR1( + {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, + -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, + -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, + -16.3, -15.2, -14.1, -13.0, -11.9, -10.8, -9.7, -8.6, -7.5, + -6.4, -5.3, -4.2, -3.1, -2.0, -0.9, 0.2, 1.3, 2.4, + 3.5, 4.6, 5.7, 6.8, 7.9, 9.0, 10.1, 11.2, 12.3, + 13.4, 14.5, 15.6, 16.7, 17.8, 18.9, 20.0, 21.1, 22.2, + 23.3, 24.4, 25.5, 26.6, 27.7, 28.8, 29.9, 31.0, 32.1, + 68.4, 69.5, 70.6, 71.7, 72.8, 73.9, 75.0, 76.1, 77.2, + 78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3, + 86.4, 86.5, 87.6, 87.7, 87.8, 87.9}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + + auto input = builder.Parameter(0, input_literal->shape(), "input"); + builder.Exp(input); + + std::vector expected_result; + int64 input_size = input_literal->shape().dimensions(0); + expected_result.reserve(input_size); + for (int64 i = 0; i < input_size; i++) { + expected_result.push_back(std::exp(input_literal->Get({i}))); + } + + ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { + // The input tensor is large enough to exercise the vectorized exp + // implementation on XLA CPU. + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr input_literal = Literal::CreateR1( + {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, + -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, + 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, + 1.74e+04, 1.89e+05, 1.9e+05, 1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07, + 1.66e+07, 1e+07, 1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09, + 1.44e+10, 1.5e+10, 1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12, + 1.4e+12, 1.03e+13, 1.6e+13, 1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15, + 1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17, + 2e+18, 1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20, + 1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21, 1.35e+22, 1.84e+22, 1.02e+22, + 1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25, + 1.62e+25, 1.2e+26, 1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28, + 1.5e+28, 1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30, 1.81e+30, 1.34e+30, + 1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33, + 1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + + auto input = builder.Parameter(0, input_literal->shape(), "input"); + builder.Log(input); + + std::vector expected_result; + int64 input_size = input_literal->shape().dimensions(0); + expected_result.reserve(input_size); + for (int64 i = 0; i < input_size; i++) { + expected_result.push_back(std::log(input_literal->Get({i}))); + } + + ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, + error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // a ------ (add) --------- (add) // / / diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 627a9c3e7d9f6eb8d360228362ea5adf12c6c798..3f6fd7c65d3360a622dbf754833009fb20410535 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -62,6 +62,10 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { auto ax = builder.Mul(alpha, x); auto axpy = builder.Add(ax, y); + TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); + + EXPECT_EQ("() -> f32[10]", ShapeUtil::HumanString(shape)); + std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 659660d91e519b428d28ced8591d05b4e4d45f53..f594cc10ac6496f710d03f0b0b134e6dd3b6d38f 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -104,7 +104,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), - ContainsRegex("expects parameter 0")); + ContainsRegex( + "Argument does not match shape of computation parameter 0")); // Shape mismatch in parameter 1 (rank) status = client_->Execute(computation, {f32_data.get(), f32_data.get()}, @@ -112,7 +113,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), - ContainsRegex("expects parameter 1")); + ContainsRegex( + "Argument does not match shape of computation parameter 1")); // Shape mismatch in parameter 1 (element type) status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}, @@ -120,7 +122,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), - ContainsRegex("expects parameter 1")); + ContainsRegex( + "Argument does not match shape of computation parameter 1")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index e472408dcf7ed5fec74e886fd0092ce47ee2e7eb..022641394f113ef28e7c53058385d77572822213 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -21,9 +21,11 @@ StatusOr> CodegenTestBase::CompileToExecutable( std::unique_ptr hlo_module) { TF_ASSIGN_OR_RETURN(hlo_module, backend().compiler()->RunHloPasses( std::move(hlo_module), - backend().default_stream_executor())); + backend().default_stream_executor(), + /*device_allocator=*/nullptr)); return backend().compiler()->RunBackend(std::move(hlo_module), - backend().default_stream_executor()); + backend().default_stream_executor(), + /*device_allocator=*/nullptr); } StatusOr> diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index 0016b6cc614469d7ac9b40b740d163a7a4f32abf..bc821674820fb128823786d7149037fc59b22ab6 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -355,8 +355,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { } // Test true and false computations that return a tuple of arrays. -// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend. -XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) { +XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { ComputationBuilder builder(client_, TestName()); auto pred = builder.ConstantR0(true); auto operands = builder.Tuple({builder.ConstantR1({12.2f, 15.8f}), @@ -373,9 +372,7 @@ XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) { // Test true and false computations that return a tuple of a predicate, a // scalar, and an array. -// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend. -XLA_TEST_F(ConditionalOpTest, - DISABLED_ON_GPU(ReturnTupleofPredicateScalarArray)) { +XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { ComputationBuilder true_builder(client_, TestName() + ".true"); { true_builder.Parameter(0, empty_tuple_, "tuple"); @@ -413,8 +410,7 @@ XLA_TEST_F(ConditionalOpTest, } // Test true and false computations that return a nested tuple. -// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend. -XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnNestedTuple)) { +XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { ComputationBuilder true_builder(client_, TestName() + ".true"); { true_builder.Parameter(0, empty_tuple_, "tuple"); @@ -532,6 +528,32 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } +XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { + ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); + auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); + auto pred_cond = inner_builder.GetTupleElement(param0, 0); + auto true_operand = inner_builder.GetTupleElement(param0, 1); + auto false_operand = inner_builder.GetTupleElement(param0, 2); + inner_builder.Conditional(pred_cond, true_operand, + CreateR0CeilComputation(), false_operand, + CreateR0FloorComputation()); + } + auto inner_builder_result = inner_builder.Build(); + EXPECT_IS_OK(inner_builder_result.status()); + + ComputationBuilder builder(client_, TestName()); + auto pred2 = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(1.1f); + auto operand2 = builder.ConstantR0(12.2f); + auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); + builder.Call(inner_builder_result.ConsumeValueOrDie(), {tuple_operand}); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + // Test a mismatch in the shape of the true operand and true computation. XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 0ceb9aff378ae8aa8098be9360310b1d78d31ab2..1385b437fc47fe5289c401581fab8b5278872382 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -53,157 +53,200 @@ class ConvolutionTest : public ClientLibraryTestBase { #endif }; -XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) { - const int kInputActivationSizeY = 3; - const int kInputActivationSizeX = 3; - const int kInputActivationSizeZ = 256; - const int kKernelSizeX = 2; - const int kKernelSizeY = 2; - const int kOutputActivationSizeZ = 256; - const int kMiniBatchSize = 4; - auto alhs = - MakeUnique>(kMiniBatchSize, kInputActivationSizeZ, - kInputActivationSizeY, kInputActivationSizeX); - alhs->FillWithMultiples(1.0f); - ASSERT_EQ(3, alhs->width()); - ASSERT_EQ(3, alhs->height()); - - auto arhs = - MakeUnique>(kOutputActivationSizeZ, kInputActivationSizeZ, - kKernelSizeY, kKernelSizeX); - Array2D rhs_raster({ - {1.0f, 0.0f}, // row 0 - {0.0f, 0.0f}, // row 1 - }); - arhs->FillWithYX(rhs_raster); - ASSERT_EQ(2, arhs->width()); - ASSERT_EQ(2, arhs->height()); +// TODO(b/72509305): Enable half data type tests for CPU +#if (XLA_TEST_BACKEND_GPU) +using TestTypes = ::testing::Types; +#else +using TestTypes = ::testing::Types; +#endif - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR4FromArray4D(*alhs); - auto rhs = builder.ConstantR4FromArray4D(*arhs); - auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); +template +Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice dimensions); - ComputeAndCompare(&builder, conv, {}, error_spec_); +template <> +Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice dimensions) { + return ShapeUtil::MakeShape(F32, dimensions); } -TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); +template <> +Shape MakeShapeWrapper( + tensorflow::gtl::ArraySlice dimensions) { + return ShapeUtil::MakeShape(F16, dimensions); +} - Array4D input_data(1, 1, 1, 2); - input_data.FillWithYX(Array2D({ - {1, 2}, - })); - Array4D filter_data(1, 1, 1, 2); - filter_data.FillWithYX(Array2D({ - {5, 6}, - })); +template +class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { + public: + void RunTest() { + const int kInputActivationSizeY = 3; + const int kInputActivationSizeX = 3; + const int kInputActivationSizeZ = 256; + const int kKernelSizeX = 2; + const int kKernelSizeY = 2; + const int kOutputActivationSizeZ = 256; + const int kMiniBatchSize = 4; + auto alhs = + MakeUnique>(kMiniBatchSize, kInputActivationSizeZ, + kInputActivationSizeY, kInputActivationSizeX); + alhs->FillWithMultiples(static_cast(1.0f)); + ASSERT_EQ(3, alhs->width()); + ASSERT_EQ(3, alhs->height()); + + auto arhs = + MakeUnique>(kOutputActivationSizeZ, kInputActivationSizeZ, + kKernelSizeY, kKernelSizeX); + Array2D rhs_raster({ + {1.0f, 0.0f}, // row 0 + {0.0f, 0.0f}, // row 1 + }); + arhs->FillWithYX(rhs_raster); + ASSERT_EQ(2, arhs->width()); + ASSERT_EQ(2, arhs->height()); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR4FromArray4D(*alhs); + auto rhs = builder.ConstantR4FromArray4D(*arhs); + auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + + ComputeAndCompare(&builder, conv, {}, error_spec_); + } +}; - ComputeAndCompare(&builder, conv, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, - error_spec_); +TYPED_TEST_CASE(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, TestTypes); +XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, Types) { + this->RunTest(); } +template +class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + Shape input_shape = MakeShapeWrapper({1, 1, 1, 2}); + Shape filter_shape = MakeShapeWrapper({1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 1, 2); + input_data.FillWithYX(Array2D({ + {1.0f, 2.0f}, + })); + Array4D filter_data(1, 1, 1, 2); + filter_data.FillWithYX(Array2D({ + {5.0f, 6.0f}, + })); + + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve_1x1x1x2_1x1x1x2_Valid, TestTypes); +TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid, Types) { this->RunTest(); } + // Tests valid padding for 2D convolution in raster space. -TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); +template +class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); + Shape filter_shape = MakeShapeWrapper({1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 4, 4); + input_data.FillWithYX(Array2D({ + {1.0f, 2.0f, 3.0f, 4.0f}, + {5.0f, 6.0f, 7.0f, 8.0f}, + {9.0f, 10.0f, 11.0f, 12.0f}, + {13.0f, 14.0f, 15.0f, 16.0f}, + })); + Array4D filter_data(1, 1, 2, 2); + filter_data.FillWithYX(Array2D({ + {5.0f, 6.0f}, + {7.0f, 8.0f}, + })); + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, + error_spec_); + } +}; - Array4D input_data(1, 1, 4, 4); - // clang-format off - input_data.FillWithYX(Array2D({ - {1, 2, 3, 4 }, - {5, 6, 7, 8 }, - {9, 10, 11, 12}, - {13, 14, 15, 16}, - })); - // clang-format on - Array4D filter_data(1, 1, 2, 2); - // clang-format off - filter_data.FillWithYX(Array2D({ - {5, 6}, - {7, 8}, - })); - // clang-format on - ComputeAndCompare(&builder, conv, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, - error_spec_); -} +TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Valid, TestTypes); +TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid, Types) { this->RunTest(); } // Tests same padding for 2D convolution in raster space. -TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); - - Array4D input_data(1, 1, 4, 4); - // clang-format off - input_data.FillWithYX(Array2D({ - {1, 2, 3, 4 }, - {5, 6, 7, 8 }, - {9, 10, 11, 12}, - {13, 14, 15, 16}, - })); - // clang-format on - Array4D filter_data(1, 1, 2, 2); - // clang-format off - filter_data.FillWithYX(Array2D({ - {5, 6}, - {7, 8}, - })); - // clang-format on - ComputeAndCompare(&builder, conv, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, - error_spec_); -} +template +class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); + Shape filter_shape = MakeShapeWrapper({1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + + Array4D input_data(1, 1, 4, 4); + input_data.FillWithYX(Array2D({ + {1.0f, 2.0f, 3.0f, 4.0f}, + {5.0f, 6.0f, 7.0f, 8.0f}, + {9.0f, 10.0f, 11.0f, 12.0f}, + {13.0f, 14.0f, 15.0f, 16.0f}, + })); + Array4D filter_data(1, 1, 2, 2); + filter_data.FillWithYX(Array2D({ + {5.0f, 6.0f}, + {7.0f, 8.0f}, + })); + + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Same, TestTypes); +TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same, Types) { this->RunTest(); } // Tests same padding for 2D convolution in raster space with an odd sized // kernel. -TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); - - Array4D input_data(1, 1, 4, 4); - // clang-format off - input_data.FillWithYX(Array2D({ - {1, 2, 3, 4 }, - {5, 6, 7, 8 }, - {9, 10, 11, 12}, - {13, 14, 15, 16}, - })); - // clang-format on - Array4D filter_data(1, 1, 3, 3); - // clang-format off - filter_data.FillWithYX(Array2D({ - { 5, 6, 7}, - { 8, 9, 10}, - {11, 12, 13}, - })); - // clang-format on - ComputeAndCompare(&builder, conv, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, - error_spec_); -} +template +class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); + Shape filter_shape = MakeShapeWrapper({1, 1, 3, 3}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + + Array4D input_data(1, 1, 4, 4); + input_data.FillWithYX(Array2D({{1.0f, 2.0f, 3.0f, 4.0f}, + {5.0f, 6.0f, 7.0f, 8.0f}, + {9.0f, 10.0f, 11.0f, 12.0f}, + {13.0f, 14.0f, 15.0f, 16.0f}})); + Array4D filter_data(1, 1, 3, 3); + filter_data.FillWithYX(Array2D( + {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); + // clang-format on + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes); +TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { ComputationBuilder builder(client_, TestName()); @@ -232,36 +275,44 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { error_spec_); } -XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithRHSDilation) { - ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( - input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, - /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2}, - /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); +template +class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + { + Shape input_shape = MakeShapeWrapper({1, 2, 5}); + Shape filter_shape = MakeShapeWrapper({1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + // Convolution dimensions are bf0_oi0->bo0. + builder.ConvGeneralDilated( + input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, + /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2}, + /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); + } + + Array3D input( + {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}}); + Array3D filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}}); + + Array3D expected({{{570.0f, 670.0f, 770.0f}}}); + + auto input_literal = + client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR3(&builder, expected, + {input_literal.get(), filter_literal.get()}, + error_spec_); } +}; // namespace - Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); - Array3D filter({{{10, 20}, {30, 40}}}); - - Array3D expected({{{570, 670, 770}}}); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); -} +TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes); +TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { ComputationBuilder builder(client_, TestName()); @@ -325,36 +376,45 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { error_spec_); } -XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithPadding) { - ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( - input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}}, - /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1}, - /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); +template +class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + { + Shape input_shape = MakeShapeWrapper({1, 2, 5}); + Shape filter_shape = MakeShapeWrapper({1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + // Convolution dimensions are bf0_oi0->bo0. + builder.ConvGeneralDilated( + input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}}, + /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1}, + /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); + } + + Array3D input( + {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}}); + Array3D filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}}); + + Array3D expected( + {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); + + auto input_literal = + client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR3(&builder, expected, + {input_literal.get(), filter_literal.get()}, + error_spec_); } +}; - Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); - Array3D filter({{{10, 20}, {30, 40}}}); - - Array3D expected({{{0, 260, 510, 610, 710, 810, 350, 0}}}); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); -} +TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes); +TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { ComputationBuilder builder(client_, TestName()); @@ -389,12 +449,12 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { } std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); - std::iota(input_elems.begin(), input_elems.end(), 1.0f); + iota(input_elems.begin(), input_elems.end(), 1.0f); auto input_r1 = Literal::CreateR1(input_elems); auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); - std::iota(filter_elems.begin(), filter_elems.end(), 1.0f); + iota(filter_elems.begin(), filter_elems.end(), 1.0f); auto filter_r1 = Literal::CreateR1(filter_elems); auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); @@ -412,56 +472,73 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { error_spec_); } -XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { - ComputationBuilder builder(client_, TestName()); - std::vector input_dims = {1, 3, 3, 5}; - std::vector filter_dims = {3, 3, 5, 3}; - Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); - Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims); - { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - - // Tensorflow dimension numbers for 2D convolution. - ConvolutionDimensionNumbers dnums; - dnums.set_input_batch_dimension(0); - dnums.set_output_batch_dimension(0); - dnums.add_input_spatial_dimensions(1); - dnums.add_output_spatial_dimensions(1); - dnums.add_input_spatial_dimensions(2); - dnums.add_output_spatial_dimensions(2); - dnums.set_input_feature_dimension(3); - dnums.set_output_feature_dimension(3); - dnums.add_kernel_spatial_dimensions(0); - dnums.add_kernel_spatial_dimensions(1); - dnums.set_kernel_input_feature_dimension(2); - dnums.set_kernel_output_feature_dimension(3); +// std::iota doesn't work when init_value has a type Eigen::half in some build +// servers. The error message is missing the operator ++. +template +void iota_int_init_value(std::vector& values, int init_value) { + std::for_each(values.begin(), values.end(), + [&](T& value) { value = static_cast(init_value++); }); +} - builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, - dnums); +template +class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + std::vector input_dims = {1, 3, 3, 5}; + std::vector filter_dims = {3, 3, 5, 3}; + Shape input_shape = MakeShapeWrapper(input_dims); + Shape filter_shape = MakeShapeWrapper(filter_dims); + { + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, + dnums); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = Literal::CreateR1(input_elems); + auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = Literal::CreateR1( + {static_cast(92115), static_cast(93150), static_cast(94185)}); + auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, *expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); } +}; - std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); - std::iota(input_elems.begin(), input_elems.end(), 1.0f); - auto input_r1 = Literal::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); - - std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); - std::iota(filter_elems.begin(), filter_elems.end(), 1.0f); - auto filter_r1 = Literal::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - - auto expected_r1 = Literal::CreateR1({92115, 93150, 94185}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); - - auto input_literal = client_->TransferToServer(*input_r4).ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); - - ComputeAndCompareLiteral(&builder, *expected_r4, - {input_literal.get(), filter_literal.get()}, - error_spec_); -} +TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x5_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x5_Valid, Types) { this->RunTest(); } // Test fixture to run convolution tests with and without convolution // canonicalization enabled. @@ -519,67 +596,117 @@ struct Convolve1DTestParam { int64 num_windows; }; -class Convolve1D1WindowTest +class Convolve1D1WindowTestBase : public ConvolutionTest, - public ::testing::WithParamInterface {}; - -XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) { - ComputationBuilder builder(client_, TestName()); - int64 input_feature = GetParam().input_feature; - int64 output_feature = GetParam().output_feature; - int64 batch = GetParam().batch; - int64 num_windows = GetParam().num_windows; - int64 window_size = GetParam().window_size; - std::vector input_dims = {batch, window_size + num_windows - 1, - input_feature}; - std::vector filter_dims = {window_size, input_feature, output_feature}; - Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); - Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims); - { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - - // Tensorflow dimension numbers for 1D convolution. - ConvolutionDimensionNumbers dnums; - dnums.set_input_batch_dimension(0); - dnums.set_output_batch_dimension(0); - dnums.add_input_spatial_dimensions(1); - dnums.add_output_spatial_dimensions(1); - dnums.set_input_feature_dimension(2); - dnums.set_output_feature_dimension(2); - dnums.add_kernel_spatial_dimensions(0); - dnums.set_kernel_input_feature_dimension(1); - dnums.set_kernel_output_feature_dimension(2); - - builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, - dnums); + public ::testing::WithParamInterface { + protected: + template + void TestImpl() { + ComputationBuilder builder(client_, TestName()); + int64 input_feature = GetParam().input_feature; + int64 output_feature = GetParam().output_feature; + int64 batch = GetParam().batch; + int64 num_windows = GetParam().num_windows; + int64 window_size = GetParam().window_size; + std::vector input_dims = {batch, window_size + num_windows - 1, + input_feature}; + std::vector filter_dims = {window_size, input_feature, + output_feature}; + Shape input_shape = MakeShapeWrapper(input_dims); + Shape filter_shape = MakeShapeWrapper(filter_dims); + { + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 1D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.set_input_feature_dimension(2); + dnums.set_output_feature_dimension(2); + dnums.add_kernel_spatial_dimensions(0); + dnums.set_kernel_input_feature_dimension(1); + dnums.set_kernel_output_feature_dimension(2); + + builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, + dnums); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1.0f)); + auto input_r1 = Literal::CreateR1(input_elems); + auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(1.0f)); + + auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector expect_elems(batch * output_feature * num_windows, + static_cast(window_size * input_feature)); + auto expected_r1 = Literal::CreateR1(expect_elems); + auto expected_r3 = + expected_r1->Reshape({batch, num_windows, output_feature}) + .ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(*input_r3).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, *expected_r3, + {input_literal.get(), filter_literal.get()}, + error_spec_); } +}; - std::vector input_elems(ShapeUtil::ElementsIn(input_shape), 1.0); - auto input_r1 = Literal::CreateR1(input_elems); - auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); +class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {}; - std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), 1.0); +XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl(); } - auto filter_r1 = Literal::CreateR1(filter_elems); - auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); +INSTANTIATE_TEST_CASE_P( + Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat, + ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2}, + Convolve1DTestParam{160, 1, 1, 5, 1}, + Convolve1DTestParam{24, 1, 1, 20, 1}, + Convolve1DTestParam{30, 1, 1, 20, 1}, + Convolve1DTestParam{23, 1, 1, 20, 20}, + Convolve1DTestParam{25, 1, 1, 20, 1}, + Convolve1DTestParam{24, 1, 1, 10, 5}, + Convolve1DTestParam{160, 1, 1, 10, 1}, + Convolve1DTestParam{255, 1, 1, 3, 1}, + Convolve1DTestParam{130, 1, 1, 1, 3}, + Convolve1DTestParam{64, 1, 1, 1, 1}, + Convolve1DTestParam{128, 1, 1, 1, 1}, + Convolve1DTestParam{139, 1, 1, 128, 1}, + Convolve1DTestParam{1, 10, 10, 1, 10}, + Convolve1DTestParam{1, 10, 130, 1, 2}, + Convolve1DTestParam{1, 10, 130, 1, 1}, + Convolve1DTestParam{1, 64, 64, 1, 10}, + Convolve1DTestParam{1, 65, 65, 1, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{128, 128, 128, 128, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{2, 2, 2, 2, 1}, + Convolve1DTestParam{161, 1, 1, 10, 1}, + Convolve1DTestParam{900, 1, 1, 10, 1}, + Convolve1DTestParam{640, 3, 3, 128, 1}) - std::vector expect_elems(batch * output_feature * num_windows, - window_size * input_feature); - auto expected_r1 = Literal::CreateR1(expect_elems); - auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature}) - .ConsumeValueOrDie(); +); - auto input_literal = client_->TransferToServer(*input_r3).ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r3, - {input_literal.get(), filter_literal.get()}, - error_spec_); +#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU) +class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {}; + +// TODO(b/72509305): Enable half data type tests for CPU. +XLA_TEST_P(Convolve1D1WindowTestHalf, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(Convolve1D1Window))) { + TestImpl(); } INSTANTIATE_TEST_CASE_P( - Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTest, + Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf, ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2}, Convolve1DTestParam{160, 1, 1, 5, 1}, Convolve1DTestParam{24, 1, 1, 20, 1}, @@ -592,7 +719,11 @@ INSTANTIATE_TEST_CASE_P( Convolve1DTestParam{130, 1, 1, 1, 3}, Convolve1DTestParam{64, 1, 1, 1, 1}, Convolve1DTestParam{128, 1, 1, 1, 1}, + // TODO(b/72566306): the following three tests fail on CPU + // backend due to result miscompare. Convolve1DTestParam{139, 1, 1, 128, 1}, + Convolve1DTestParam{640, 3, 3, 128, 1}, + Convolve1DTestParam{900, 1, 1, 10, 1}, Convolve1DTestParam{1, 10, 10, 1, 10}, Convolve1DTestParam{1, 10, 130, 1, 2}, Convolve1DTestParam{1, 10, 130, 1, 1}, @@ -602,11 +733,10 @@ INSTANTIATE_TEST_CASE_P( Convolve1DTestParam{128, 128, 128, 128, 1}, Convolve1DTestParam{1, 128, 128, 1, 1}, Convolve1DTestParam{2, 2, 2, 2, 1}, - Convolve1DTestParam{161, 1, 1, 10, 1}, - Convolve1DTestParam{900, 1, 1, 10, 1}, - Convolve1DTestParam{640, 3, 3, 128, 1}) + Convolve1DTestParam{161, 1, 1, 10, 1}) ); +#endif TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index cc683701e6305510d202721fe645310f1009081c..6b0c04c2c083bbfce267dd92d24ef15c06186d26 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -520,9 +520,39 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) { ComputeAndCompareR4( &builder, - /*expected=*/{{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}}, - {{{42900, 79200}, {429, 792}}, - {{250800, 299200}, {2508, 2992}}}}, + /*expected=*/ + {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}}, + {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}}, + {x_data.get(), y_data.get()}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, GeneralMatMul) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x"); + auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + + auto out = builder.DotGeneral(x, y, dnums); + + auto x_data = client_ + ->TransferToServer(*Literal::CreateR3( + {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}})) + .ConsumeValueOrDie(); + + auto y_data = client_ + ->TransferToServer(*Literal::CreateR3( + {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}})) + .ConsumeValueOrDie(); + + ComputeAndCompareR3( + &builder, + /*expected=*/ + {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}, {x_data.get(), y_data.get()}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6fe7737de7af349dca2931b52d62dbc03b14e0b3 --- /dev/null +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/casts.h" + +namespace xla { +namespace { +class ExhaustiveF32ElementwiseOpTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface> { + protected: + ErrorSpec error_spec_{0.0001, 0.0001, /*relaxed_nans=*/true}; + + template + void ExhaustivelyTestF32Op(EnqueueOpTy enqueue_op, + float (*evaluate_op)(float), + std::pair known_incorrect_range) { + int64 begin, end; + std::tie(begin, end) = GetParam(); + int64 input_size = end - begin; + LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; + + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr input_literal = + Literal::CreateFromDimensions(F32, {input_size}); + for (int64 i = begin; i < end; i++) { + if (i >= known_incorrect_range.first && + i < known_incorrect_range.second) { + // If the operation is known to be buggy on a specific input clamp that + // input to 0 under the assumption that the op is at least correct on 0. + input_literal->Set({i - begin}, 0.0f); + } else { + input_literal->Set({i - begin}, tensorflow::bit_cast(i)); + } + } + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + + auto input = builder.Parameter(0, input_literal->shape(), "input"); + enqueue_op(&builder, input); + + std::vector expected_result; + expected_result.reserve(input_size); + for (int64 i = 0; i < input_size; i++) { + expected_result.push_back(evaluate_op(input_literal->Get({i}))); + } + + ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, + error_spec_); + } +}; + +XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { +#ifdef XLA_TEST_BACKEND_CPU + // TODO(b/73141998): The vectorized Log implementation gives results outside + // our error spec in this range (these numbers are bitwise representations of + // floats expressed as a zero extended int64): + std::pair known_incorrect_range = {1, 8315654}; +#else + std::pair known_incorrect_range = {0, 0}; +#endif + + ExhaustivelyTestF32Op( + [](ComputationBuilder* builder, const ComputationDataHandle& input) { + builder->Log(input); + }, + std::log, known_incorrect_range); +} + +XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { +#ifdef XLA_TEST_BACKEND_CPU + // TODO(b/73142289): The vectorized Exp implementation gives results outside + // our error spec in this range (these numbers are bitwise representations of + // floats expressed as a zero extended int64): + std::pair known_incorrect_range = {1107296256 + 11583654, + 1107296256 + 11629080}; +#else + std::pair known_incorrect_range = {0, 0}; +#endif + + ExhaustivelyTestF32Op( + [](ComputationBuilder* builder, const ComputationDataHandle& input) { + builder->Exp(input); + }, + std::exp, known_incorrect_range); +} + +XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { + ExhaustivelyTestF32Op( + [](ComputationBuilder* builder, const ComputationDataHandle& input) { + builder->Tanh(input); + }, + std::tanh, /*known_incorrect_range=*/{0, 0}); +} + +std::vector> CreateExhaustiveParameters() { + // We break up the 2^32-element space into small'ish chunks to keep peak + // memory usage low. + std::vector> result; + const int64 step = 1 << 25; + for (int64 i = 0; i < (1l << 32); i += step) { + result.push_back({i, i + step}); + } + return result; +} + +INSTANTIATE_TEST_CASE_P(ExhaustiveF32ElementwiseOpTestInstance, + ExhaustiveF32ElementwiseOpTest, + ::testing::ValuesIn(CreateExhaustiveParameters())); +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 7c1a993b478a0e0878e85c0e4192da053e33619f..9f5806c5e16c30cf198027cffab5f78c315cb957 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -230,7 +230,7 @@ template const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = - HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); if (!module_or_status.ok()) { return ::testing::AssertionFailure() << "failed reading hlo module from file"; @@ -258,7 +258,7 @@ template const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = - HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); if (!module_or_status.ok()) { return ::testing::AssertionFailure() << "failed reading hlo module from file"; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index f8205de702fb3534dcd7dbdce6ee0cbfb11d6ee4..5aa71a9261dbd414d1499f15c9b83cd63b634b49 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -355,9 +355,9 @@ class NearComparator { // temporary files on failure. Returns true if literals match. bool ExpectNear(const Literal& expected, const Literal& actual) { VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, expected.ToString()); + XLA_VLOG_LINES(1, TruncateHugeLiteral(expected)); VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, actual.ToString()); + XLA_VLOG_LINES(1, TruncateHugeLiteral(actual)); // If the shapes mismatch, we simply fail the expectation instead of // printing out data, as it's a type error rather than a value error. @@ -376,7 +376,12 @@ class NearComparator { abs_expected_miscompare_sum_ = 0.0; max_rel_err_ = 0.0; max_abs_err_ = 0.0; + first_linear_index_ = -1; + last_linear_index_ = -1; + max_rel_linear_index_ = -1; + max_abs_linear_index_ = -1; miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED)); + miscompares_.PopulateWithValue(false); multi_index_.resize(expected.shape().dimensions_size(), 0); switch (expected.shape().element_type()) { @@ -404,21 +409,33 @@ class NearComparator { if (num_miscompares_ > 0) { if (!VLOG_IS_ON(1)) { LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) - << " " << expected.ToString(); + << " " << TruncateHugeLiteral(expected); LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) - << " " << actual.ToString(); + << " " << TruncateHugeLiteral(actual); + LOG(INFO) << "Dumping literals to temp files..."; + WriteLiteralToTempFile(expected, "expected"); + WriteLiteralToTempFile(actual, "actual"); + WriteLiteralToTempFile(miscompares_, "miscompares"); } EXPECT_TRUE(num_miscompares_ == 0) << "\nmax relative mismatch at index " - << LiteralTestUtil::MultiIndexAsString(max_rel_multi_index_) + << LiteralTestUtil::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex( + actual.shape(), max_rel_linear_index_)) << "\nmaximum relative error " << max_rel_err_ << "\nmax absolute mismatch at index " - << LiteralTestUtil::MultiIndexAsString(max_abs_multi_index_) + << LiteralTestUtil::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex( + actual.shape(), max_abs_linear_index_)) << "\nmaximum absolute error " << max_abs_err_ << "\nfirst mismatch at index " - << LiteralTestUtil::MultiIndexAsString(first_multi_index_) + << LiteralTestUtil::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex( + actual.shape(), first_linear_index_)) << "\nlast mismatch at index " - << LiteralTestUtil::MultiIndexAsString(last_multi_index_) + << LiteralTestUtil::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex( + actual.shape(), last_linear_index_)) << "\ntotal absolute error " << abs_diff_sum_ << "\ntotal absolute error of miscompares " << abs_diff_miscompare_sum_ << "\ntotal relative error " @@ -426,18 +443,18 @@ class NearComparator { << "\ntotal relative error of miscompares " << (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_) << "\nfailure count " << num_miscompares_; - - WriteLiteralToTempFile(expected, "expected"); - WriteLiteralToTempFile(actual, "actual"); - WriteLiteralToTempFile(miscompares_, "miscompares"); } return num_miscompares_ == 0; } private: template - bool NanMismatch(NativeT lhs, NativeT rhs) { - return std::isnan(lhs) != std::isnan(rhs); + bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { + if (relaxed_nans) { + return !std::isnan(expected) && std::isnan(actual); + } else { + return std::isnan(expected) != std::isnan(actual); + } } template @@ -457,57 +474,94 @@ class NearComparator { return true; } - float abs_diff = std::abs(actual - expected); - float rel_err = abs_diff / std::abs(expected); + const float abs_diff = std::abs(actual - expected); + const float rel_err = abs_diff / std::abs(expected); + const bool nan_mismatch = + NanMismatch(expected, actual, error_.relaxed_nans); + const bool mismatch = + (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel)); + return !mismatch; + } + + // Assumes that expected vs actual fail ExpectValuesNear. + template + void UpdateAndLogMiscompares(const NativeT expected, const NativeT actual, + const Shape& shape, const int64 linear_index) { + const float abs_diff = std::abs(actual - expected); + const float rel_err = abs_diff / std::abs(expected); abs_diff_sum_ += abs_diff; abs_expected_sum_ += std::abs(expected); - if (rel_err > max_rel_err_) { + if (rel_err > max_rel_err_ || std::isnan(rel_err)) { max_rel_err_ = rel_err; - max_rel_multi_index_ = multi_index_; + max_rel_linear_index_ = linear_index; } - if (abs_diff > max_abs_err_) { + if (abs_diff > max_abs_err_ || std::isnan(abs_diff)) { max_abs_err_ = abs_diff; - max_abs_multi_index_ = multi_index_; + max_abs_linear_index_ = linear_index; } - VLOG(10) << tensorflow::strings::Printf( - "index %s abs_diff %f rel_err %f", - LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff, - rel_err); - bool nan_mismatch = NanMismatch(expected, actual); - bool mismatch = - (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel)); - if (mismatch) { - abs_diff_miscompare_sum_ += abs_diff; - abs_expected_miscompare_sum_ += std::abs(expected); - const int64 kMaxFailures = 2; - if (num_miscompares_ < kMaxFailures) { - ::testing::Message msg; - msg << "mismatch at index " - << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff " - << abs_diff << " rel err " << rel_err << " failure #" - << num_miscompares_; - ExpectNear(expected, actual, msg); - } else if (num_miscompares_ == kMaxFailures) { - LOG(ERROR) - << "reached max 'loud' failure count; silently proceeding..."; - } - if (num_miscompares_ == 0) { - first_multi_index_ = multi_index_; - } - num_miscompares_++; - last_multi_index_ = multi_index_; + if (VLOG_IS_ON(10)) { + VLOG(10) << tensorflow::strings::Printf( + "index %s abs_diff %f rel_err %f", + LiteralTestUtil::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex(shape, + linear_index)) + .c_str(), + abs_diff, rel_err); } - return !mismatch; + abs_diff_miscompare_sum_ += abs_diff; + abs_expected_miscompare_sum_ += std::abs(expected); + const int64 kMaxFailures = 2; + if (num_miscompares_ < kMaxFailures) { + const auto multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index); + ::testing::Message msg; + msg << "mismatch at index " + << LiteralTestUtil::MultiIndexAsString(multi_index) << " abs diff " + << abs_diff << " rel err " << rel_err << " failure #" + << num_miscompares_; + ExpectNear(expected, actual, msg); + } else if (num_miscompares_ == kMaxFailures) { + LOG(ERROR) << "reached max 'loud' failure count; silently proceeding..."; + } + if (num_miscompares_ == 0) { + first_linear_index_ = linear_index; + } + num_miscompares_++; + last_linear_index_ = linear_index; + miscompares_.data()[linear_index] = true; } // Recursive function which compares the two given literals elementwise. template void ExpectLiteralsNear(const Literal& expected, const Literal& actual, int64 dimension) { + // Fast path optimization for the case were layouts match. + if (LayoutUtil::Equal(actual.shape().layout(), expected.shape().layout())) { + tensorflow::gtl::ArraySlice expected_data = + expected.data(); + tensorflow::gtl::ArraySlice actual_data = + actual.data(); + const int64 len = expected_data.size(); + for (int64 i = 0; i < len; ++i) { + const bool near = ExpectValuesNear(expected_data[i], actual_data[i]); + if (!near) { + UpdateAndLogMiscompares(expected_data[i], actual_data[i], + actual.shape(), i); + } + } + return; + } + if (dimension == expected.shape().dimensions_size()) { bool near = ExpectValuesNear(expected.Get(multi_index_), actual.Get(multi_index_)); - miscompares_.Set(multi_index_, !near); + if (!near) { + UpdateAndLogMiscompares( + expected.Get(multi_index_), + actual.Get(multi_index_), actual.shape(), + IndexUtil::MultidimensionalIndexToLinearIndex(actual.shape(), + multi_index_)); + } } else { for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { multi_index_[dimension] = i; @@ -528,6 +582,32 @@ class NearComparator { LOG(ERROR) << "wrote to " << name << " file: " << filename; } + // Gets the total element count. For tuples, this is not the count of tuple + // elements, but the sum of elements of each tuple element. + int64 RecursiveElementCount(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); + int64 total = 0; + for (int64 i = 0; i < tuple_elements; ++i) { + total += + RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); + } + return total; + } else { + return ShapeUtil::ElementsIn(shape); + } + } + + // Calling ToString on a literal with over 100 million elements takes around + // 3 minutes. The utility of printing a literal with >1000 elements is + // questionable, especially when writing the Literal proto to disk is orders + // of magnitude faster. + string TruncateHugeLiteral(const Literal& literal) { + return RecursiveElementCount(literal.shape()) < 1000 + ? literal.ToString() + : "[TRUNCATED, Literal with more than 1000 values]"; + } + ErrorSpec error_; // Number of element miscomparisons encountered so far. @@ -548,16 +628,18 @@ class NearComparator { double abs_expected_miscompare_sum_; float max_rel_err_; float max_abs_err_; - std::vector first_multi_index_; - std::vector last_multi_index_; - std::vector max_rel_multi_index_; - std::vector max_abs_multi_index_; + int64 first_linear_index_; + int64 last_linear_index_; + int64 max_rel_linear_index_; + int64 max_abs_linear_index_; }; template <> -bool NearComparator::NanMismatch(complex64 lhs, complex64 rhs) { - return std::isnan(lhs.real()) != std::isnan(rhs.real()) || - std::isnan(lhs.imag()) != std::isnan(rhs.imag()); +bool NearComparator::NanMismatch(complex64 expected, + complex64 actual, + bool relaxed_nans) { + return NanMismatch(expected.real(), actual.real(), relaxed_nans) || + NanMismatch(expected.imag(), actual.imag(), relaxed_nans); } template <> @@ -584,6 +666,23 @@ bool NearComparator::ExpectValuesNear(half expected, half actual) { static_cast(std::move(actual))); } +template <> +void NearComparator::UpdateAndLogMiscompares( + const bfloat16 expected, const bfloat16 actual, const Shape& shape, + const int64 linear_index) { + UpdateAndLogMiscompares(static_cast(expected), + static_cast(actual), shape, linear_index); +} + +template <> +void NearComparator::UpdateAndLogMiscompares(half expected, half actual, + const Shape& shape, + const int64 linear_index) { + UpdateAndLogMiscompares(static_cast(std::move(expected)), + static_cast(std::move(actual)), shape, + linear_index); +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 9b0724262d51ec7964a918bb8eb8716308662b96..7b757a4bd7e7592583b7596b4305ddb7e6c52d75 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -40,10 +40,16 @@ namespace xla { // Structure describing permissible absolute and relative error bounds. struct ErrorSpec { - explicit ErrorSpec(float aabs, float arel = 0) : abs(aabs), rel(arel) {} + explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) + : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} float abs; // Absolute error bound. float rel; // Relative error bound. + + // If relaxed_nans is true then any result is valid if we are expecting NaNs. + // In effect, this allows the tested operation to produce incorrect results + // for inputs outside its mathematical domain. + bool relaxed_nans; }; // Utility class for making expectations/assertions related to XLA literals. diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index e477784557a3b9340cff644a3695485389d8cc22..3a421f8458268a14dcdd84889bcae4990c095ea4 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -97,5 +97,29 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { } } +TEST(LiteralTestUtilTest, NearComparatorR1) { + auto a = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); +} + +TEST(LiteralTestUtilTest, NearComparatorR1Nan) { + auto a = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + auto b = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); +} + +TEST(LiteralTestUtil, NearComparatorDifferentLengths) { + auto a = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); + EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index b5b95967ff9162301a092f3a57996e0f3f78658f..7e92439c494b677f718a63c71c20828d65bebef4 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -74,7 +74,8 @@ class LLVMCompilerTest : public ::testing::Test { ASSERT_TRUE(compiler ->RunBackend(std::move(hlo_module), - backend_->default_stream_executor()) + backend_->default_stream_executor(), + /*device_allocator=*/nullptr) .ok()); // Test that hooks were called. @@ -98,7 +99,8 @@ class LLVMCompilerTest : public ::testing::Test { executors.push_back({backend_->default_stream_executor()}); executors.push_back({backend_->default_stream_executor()}); - EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors))); + EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors), + /*device_allocator=*/nullptr)); } private: diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 6e6cb7ff1e2ac74dc54f14d8811c9a5d3662bbd2..0a603f4954badd12adf3144320789a5edd0d9c6c 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -35,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -176,5 +178,38 @@ XLA_TEST_F(MultiOutputFusionTest, 2DFusionSize129) { RunTest2D(true, 129); } XLA_TEST_F(MultiOutputFusionTest, DiffentTypesNoFusion) { RunTest1D(false, 8); } XLA_TEST_F(MultiOutputFusionTest, DiffentTypesFusion) { RunTest1D(true, 8); } +XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { + const char* testcase = R"( + HloModule m + + fused_computation { + x.param_0 = (((s32[]), f32[]), (f32[], s32[])) parameter(0) + gte.3 = ((s32[]), f32[]) get-tuple-element(x.param_0), index=0 + gte.2 = (s32[]) get-tuple-element(gte.3), index=0 + gte.4 = s32[] get-tuple-element(gte.2), index=0 + copy = s32[] copy(gte.4) + ROOT tuple = (s32[]) tuple(copy) + } + + ENTRY thing.v3 { + x = (((s32[]), f32[]), (f32[], s32[])) parameter(0) + ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation + } + )"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::MakeTupleOwned( + Literal::MakeTupleOwned( + Literal::MakeTupleOwned(Literal::CreateR0(42)), + Literal::CreateR0(1.0)), + Literal::MakeTupleOwned(Literal::CreateR0(3.0), + Literal::CreateR0(4))); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index 3fd83a4c3b104831f03366339fb7b8b5d816a3f7..8cef8dd34dc7b16b1e58ded67d6b6a4ba79f20db 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -33,6 +33,14 @@ limitations under the License. namespace xla { namespace { +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +// Tests both F32 and BF16. +static std::array use_bfloat16_params{false, true}; +#else +// Only tests F32. +static std::array use_bfloat16_params{false}; +#endif + class PadTest : public ClientLibraryTestBase { protected: PadTest() { @@ -61,8 +69,22 @@ class PadTest : public ClientLibraryTestBase { PaddingConfig r4_padding_on_dim0_dim1_; }; +class PadTestFloat : public PadTest, + public ::testing::WithParamInterface { + protected: + PadTestFloat() { set_use_bfloat16(GetParam()); } + + ErrorSpec DefaultErrorSpec() const { + if (use_bfloat16()) { + return ErrorSpec(1e-3, 1e-3); + } else { + return ErrorSpec(1e-5, 1e-5); + } + } +}; + // Tests a Pad() with a zero-element input and output. -XLA_TEST_F(PadTest, Pad1DS0ToS0Array) { +XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { ComputationBuilder b(client_, TestName()); // Set up the padding configuration {low: 0, high: 0, interior: 0}. PaddingConfig padding_config; @@ -71,12 +93,13 @@ XLA_TEST_F(PadTest, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - b.Pad(b.ConstantR1({}), b.ConstantR0(0.1), padding_config); - ComputeAndCompareR1(&b, {}, {}, ErrorSpec(0.0001)); + b.Pad(AddParam(*Literal::CreateR1({}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); + ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } // Tests a Pad() with a zero-element input but a non-zero-element output. -XLA_TEST_F(PadTest, Pad1DS0ToS5Array) { +XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { ComputationBuilder b(client_, TestName()); // Set up the padding configuration {low: 3, high: 0, interior: 1}. PaddingConfig padding_config; @@ -85,12 +108,13 @@ XLA_TEST_F(PadTest, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - b.Pad(b.ConstantR1({}), b.ConstantR0(0.1), padding_config); + b.Pad(AddParam(*Literal::CreateR1({}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, - ErrorSpec(0.0001)); + DefaultErrorSpec()); } -XLA_TEST_F(PadTest, Pad1DS3Array) { +XLA_TEST_P(PadTestFloat, Pad1DS3Array) { ComputationBuilder b(client_, TestName()); // Set up the padding configuration {low: 3, high: 0, interior: 1}. PaddingConfig padding_config; @@ -99,21 +123,21 @@ XLA_TEST_F(PadTest, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - b.Pad(b.ConstantR1({1, 2, 3}), b.ConstantR0(0.1), - padding_config); + b.Pad(AddParam(*Literal::CreateR1({1, 2, 3}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); - ComputeAndCompareR1(&b, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, Pad4D_2x0x3x2_FloatArray) { +XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { ComputationBuilder b(client_, TestName()); - b.Pad(b.ConstantR4FromArray4D(Array4D(2, 0, 3, 2)), - b.ConstantR0(1.5), r4_padding_on_dim0_dim1_); + b.Pad(AddParam(Array4D(2, 0, 3, 2), &b), + AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, - ErrorSpec(0.0001)); + DefaultErrorSpec()); } -TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) { +TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { ComputationBuilder b(client_, TestName()); auto input = MakeUnique>(1, 1, 3, 2); Array2D input_xy({ @@ -123,7 +147,7 @@ TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - b.Pad(b.ConstantR4FromArray4D(*input), b.ConstantR0(1.5), + b.Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); @@ -134,15 +158,15 @@ TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) { (*expected)(1, 0, 1, 1) = 4.0f; (*expected)(1, 0, 2, 0) = 5.0f; (*expected)(1, 0, 2, 1) = 6.0f; - ComputeAndCompareR4(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR4(&b, *expected, {}, DefaultErrorSpec()); } -TEST_F(PadTest, Pad4DFloatArrayWithInteriorPadding) { +TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { ComputationBuilder b(client_, TestName()); const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - b.Pad(b.ConstantR4FromArray4D(input), b.ConstantR0(pad_value), + b.Pad(AddParam(input, &b), AddParam(*Literal::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(8, 5, 1, 1); @@ -156,7 +180,7 @@ TEST_F(PadTest, Pad4DFloatArrayWithInteriorPadding) { ComputeAndCompareR4(&b, *expected, {}, ErrorSpec(0.0001)); } -TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { +TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { ComputationBuilder b(client_, TestName()); PaddingConfig padding_config; @@ -184,7 +208,8 @@ TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { auto input = Literal::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); + b.Pad(AddParam(*input, &b), + AddParam(*Literal::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -197,7 +222,7 @@ TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { ComputeAndCompareR4(&b, expected_array, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(PadTest, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { +XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { ComputationBuilder b(client_, TestName()); PaddingConfig padding_config; @@ -229,7 +254,8 @@ XLA_TEST_F(PadTest, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { auto input = Literal::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); + b.Pad(AddParam(*input, &b), + AddParam(*Literal::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -249,7 +275,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { }); input->FillWithYX(input_xy); - b.Pad(b.ConstantR4FromArray4D(*input), b.ConstantR0(35), + b.Pad(AddParam(*input, &b), b.ConstantR0(35), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); @@ -277,8 +303,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { auto ones = MakeUnique>(2, 3, 3, 2); zeros->Fill(0); ones->Fill(1); - b.Select(padded, b.ConstantR4FromArray4D(*ones), - b.ConstantR4FromArray4D(*zeros)); + b.Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(0); @@ -291,10 +316,12 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { ComputeAndCompareR4(&b, *expected, {}); } -XLA_TEST_F(PadTest, Large2DPad) { +XLA_TEST_P(PadTestFloat, Large2DPad) { ComputationBuilder b(client_, TestName()); - auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {4, 4}), "input"); + auto ones = MakeUnique>(4, 4); + ones->Fill(1.0f); + auto input = AddParam(*ones, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -302,25 +329,22 @@ XLA_TEST_F(PadTest, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - auto padded = b.Pad(input, b.ConstantR0(0.0f), padding_config); - - auto ones = MakeUnique>(4, 4); - ones->Fill(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*ones); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); - ComputeAndCompareR2(&b, *expected, {input_data.get()}); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, AllTypes2DPad) { +XLA_TEST_P(PadTestFloat, AllTypes2DPad) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 35; constexpr int64 in_cols = 35; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(0.0f); + auto input = AddParam(*operand, &b); + PaddingConfig padding_config = MakeNoPaddingConfig(2); padding_config.mutable_dimensions(0)->set_edge_padding_low(7); padding_config.mutable_dimensions(0)->set_edge_padding_high(5); @@ -328,20 +352,14 @@ XLA_TEST_F(PadTest, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - auto padded = b.Pad(input, b.ConstantR0(3.14f), padding_config); - - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(0.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec{0.0001}); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, High2DPad) { +XLA_TEST_P(PadTestFloat, High2DPad) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 129; @@ -349,8 +367,9 @@ XLA_TEST_F(PadTest, High2DPad) { constexpr int64 low_padding = 0; int64 high_padding[2] = {5, 7}; constexpr int64 interior_padding = 0; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(1.0f); + auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low(low_padding); @@ -359,20 +378,15 @@ XLA_TEST_F(PadTest, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - auto padded = b.Pad(input, b.ConstantR0(2.718f), padding_config); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), + padding_config); - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, NegativePadding2D) { +XLA_TEST_P(PadTestFloat, NegativePadding2D) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 129; @@ -380,8 +394,9 @@ XLA_TEST_F(PadTest, NegativePadding2D) { int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {-3, 4}; constexpr int64 interior_padding = 0; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(1.0f); + auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -391,20 +406,15 @@ XLA_TEST_F(PadTest, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - auto padded = b.Pad(input, b.ConstantR0(2.718f), padding_config); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), + padding_config); - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { +XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 8; @@ -412,8 +422,9 @@ XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { int64 low_padding[2] = {4, -1}; int64 high_padding[2] = {-2, -4}; int64 interior_padding[2] = {1, 2}; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(1.0f); + auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -423,44 +434,40 @@ XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - auto padded = b.Pad(input, b.ConstantR0(2.718f), padding_config); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), + padding_config); - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } // Regression test for b/31827337. -XLA_TEST_F(PadTest, ReducePad) { +XLA_TEST_P(PadTestFloat, ReducePad) { ComputationBuilder b(client_, TestName()); - auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "input"); + auto ones = MakeUnique>(2, 2, 2, 2); + ones->Fill(1.0); + auto input = AddParam(*ones, &b); - Computation add_f32 = CreateScalarAddComputation(F32, &b); - auto reduce = b.Reduce(input, b.ConstantR0(0.0), add_f32, {0}); + Computation add = CreateScalarAddComputation(FloatType(), &b); + auto reduce = + b.Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - auto pad = b.Pad(reduce, b.ConstantR0(0.0), padding_config); - - auto ones = MakeUnique>(2, 2, 2, 2); - ones->Fill(1.0); - auto input_literal = Literal::CreateR4FromArray4D(*ones); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + auto padded = b.Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), + padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, {{2.0, 2.0}, {2.0, 2.0}}, {{0.0, 0.0}, {0.0, 0.0}}}); - ComputeAndCompareR3(&b, expected, {input_data.get()}); + ComputeAndCompareR3(&b, expected, {}, DefaultErrorSpec()); } +INSTANTIATE_TEST_CASE_P(PadTestFloatInstantiation, PadTestFloat, + ::testing::ValuesIn(use_bfloat16_params)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 4756ba096896806ece8fe35d18c4eaef041b8830..dc7ce3253cee255a7949326fa5b49fc8917432b8 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -249,7 +249,9 @@ INSTANTIATE_TEST_CASE_P(ReducePrecisionAccuracyTest, // ReducePrecisionInsertion passes. class ReducePrecisionInsertionTest : public ClientLibraryTestBase {}; -XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionBeforeFusion) { +// The interpreter has no fusion pass, so skip this test. +XLA_TEST_F(ReducePrecisionInsertionTest, + DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); @@ -276,7 +278,9 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionBeforeFusion) { ComputeAndCompareR1(&builder, {0.0f}, {a_data.get()}); } -XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedAfterFusion) { +// The interpreter has no fusion pass, so skip this test. +XLA_TEST_F(ReducePrecisionInsertionTest, + DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); @@ -300,7 +304,9 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedAfterFusion) { ComputeAndCompareR1(&builder, {-1.00001f}, {a_data.get()}); } -XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionAddedAfterFusion) { +// The interpreter has no fusion pass, so skip this test. +XLA_TEST_F(ReducePrecisionInsertionTest, + DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); @@ -322,7 +328,9 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionAddedAfterFusion) { ComputeAndCompareR1(&builder, {-1.0f}, {a_data.get()}); } -XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedFusionContains) { +// The interpreter has no fusion pass, so skip this test. +XLA_TEST_F(ReducePrecisionInsertionTest, + DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); @@ -345,7 +353,9 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedFusionContains) { ComputeAndCompareR1(&builder, {-1.00001f}, {a_data.get()}); } -XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionAddedFusionContains) { +// The interpreter has no fusion pass, so skip this test. +XLA_TEST_F(ReducePrecisionInsertionTest, + DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index a766fa2db0e193c52171490981855843ab3ee158..50d7b5074d201d2292cf90224ef4cd37efdbb8d3 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -494,6 +494,26 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { ErrorSpec(0.01, 1e-4)); } +// Test that algebraic simplifier does not incorrectly fold a transpose into a +// reduction operation. +XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { + ComputationBuilder builder(client_, TestName()); + Computation add_f32 = CreateScalarAddComputation(F32, &builder); + const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50}); + ComputationDataHandle input = builder.Parameter(0, input_shape, "input"); + ComputationDataHandle zero = builder.ConstantR0(0.0); + ComputationDataHandle transpose = + builder.Transpose(input, /*permutation=*/{1, 0, 2}); + ComputationDataHandle reduce = + builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + MakeFakeLiteral(input_shape)); + + ComputeAndCompare(&builder, reduce, {std::move(*input_data)}, + ErrorSpec(0.01, 1e-4)); +} + XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { const int64 rows = 111, cols = 50; diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 73b37e201afa13546179e2ce7a76d3f7967de524..b11b64e40a582150d6adf29e915cd70b4bcb982b 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -272,7 +272,7 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { builder_.ReduceWindow( input, - CreateConstantFromLiteral(*Literal::CreateR0(3.0f), &builder_), + CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_), reduce_fn, /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); @@ -282,7 +282,7 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { }; auto expected = - ReferenceUtil::ReduceWindow4DGeneric(input_array, 3.0f, reduce_func, + ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func, /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); @@ -800,6 +800,14 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { /*pad_high=*/{1, 1, 0, 0}, /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2}, + /*window_bounds=*/{1, 1, 4, 1}, + /*strides=*/{1, 1, 4, 1}, + /*pad_low=*/{0, 0, 1, 0}, + /*pad_high=*/{0, 0, 2, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kMax}, }; INSTANTIATE_TEST_CASE_P( @@ -1016,37 +1024,39 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ::testing::tuple> { protected: R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } -}; - -TEST_P(R2ReduceWindowTest, Add) { - ComputationBuilder b(client_, TestName()); - const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); - const float kInitValue = 0.0f; - Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); - std::unique_ptr input_literal = - Literal::CreateR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + void DoIt() { + ComputationBuilder b(client_, TestName()); + const auto& param = ::testing::get<0>(GetParam()); + CHECK(param.reducer == kAdd); - ComputationDataHandle parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); - auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindow(/*operand=*/parameter, - /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + const float kInitValue = 0.0f; + Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); + std::unique_ptr input_literal = + Literal::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); - auto expected = ReferenceUtil::ReduceWindow2DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); + + auto expected = ReferenceUtil::ReduceWindow2DAdd( + /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/param.padding); + + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); + } +}; - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), - {input_arg.get()}, DefaultErrorSpec()); -} +TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); } INSTANTIATE_TEST_CASE_P( R2ReduceWindowTestInstantiation, R2ReduceWindowTest, @@ -1054,6 +1064,26 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(use_bfloat16_params)), R2ReduceWindowTestDataToString); +class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {}; + +// TODO(b/72234705): Fix the test cases failed on CPU and GPU. +XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) { + DoIt(); +} + +const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = { + {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128}, + /*strides=*/{1, 1}, /*layout=*/{1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, +}; + +INSTANTIATE_TEST_CASE_P( + R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test, + ::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test), + ::testing::ValuesIn(use_bfloat16_params)), + R2ReduceWindowTestDataToString); + struct R1ReduceWindowTestData { int64 base_bounds[1]; int64 window_bounds[1]; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index debf2d2d317fe64ca1ef86cb1f2978e76af1b55d..4da6ee91607941b395b00befc98a10e7c17746ed 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -737,7 +737,61 @@ XLA_TEST_F(ScalarComputationsTest, PowScalar) { ComputeAndCompareR0(&builder, 8.0, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarHigh) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(-1), // The lower bound. + builder.ConstantR0(5), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 3, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(-1), // The lower bound. + builder.ConstantR0(2), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 2, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(-1), // The lower bound. + builder.ConstantR0(-5), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, -1, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(1), // The lower bound. + builder.ConstantR0(5), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 3, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(1), // The lower bound. + builder.ConstantR0(2), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 2, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(1), // The lower bound. + builder.ConstantR0(0), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 1, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(5.0f), // The operand to be clamped. @@ -746,7 +800,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHigh) { ComputeAndCompareR0(&builder, 3.0, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddle) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(2.5f), // The operand to be clamped. @@ -755,7 +809,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddle) { ComputeAndCompareR0(&builder, 2.5, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarLow) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(-5.0f), // The operand to be clamped. @@ -852,5 +906,12 @@ XLA_TEST_F(ScalarComputationsTest, SqrtF320) { ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); } +XLA_TEST_F(ScalarComputationsTest, RoundScalar) { + ComputationBuilder builder(client_, TestName()); + builder.Round(builder.ConstantR0(1.4f)); + + ComputeAndCompareR0(&builder, 1.0f, {}, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index cc4eaf62f50d1fa622c705fab810fe1e1b0fbf08..e2d406f66d94f8ec76faa5b7d2d2e84dcaf6db57 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -161,4 +161,31 @@ string PrependDisabledIfIndicated(const string& test_case_name, #define XLA_TEST_P(test_case_name, test_name) \ XLA_TEST_P_IMPL_(test_case_name, test_name) + +// This is identical to the TEST_F macro from "gtest", but it potentially +// disables the test based on an external manifest file, DISABLED_MANIFEST. +#define XLA_TYPED_TEST(CaseName, TestName) \ + template \ + class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \ + : public CaseName { \ + private: \ + typedef CaseName TestFixture; \ + typedef gtest_TypeParam_ TypeParam; \ + virtual void TestBody(); \ + }; \ + bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \ + ::testing::internal::TypeParameterizedTest< \ + CaseName, \ + ::testing::internal::TemplateSel, \ + GTEST_TYPE_PARAMS_(CaseName)>:: \ + Register( \ + "", ::testing::internal::CodeLocation(__FILE__, __LINE__), \ + #CaseName, \ + ::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \ + 0); \ + template \ + void GTEST_TEST_CLASS_NAME_(CaseName, \ + TestName)::TestBody() + #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 0e90a323583de7336556c203a4b46fc14b53454d..0bc7df2a65b44a76f877b6513e6bf93b99fbc1a3 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -24,51 +24,127 @@ namespace xla { namespace { template -void PopulateWithRandomFloatingPointData(Literal* literal) { +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - std::minstd_rand0 engine; - // Create uniform numbers between 1 and 1.125 ot avoid creating denormal + // Create uniform numbers between 1 and 1.125 to avoid creating denormal // numbers. std::uniform_real_distribution generator(1.0f, 1.125f); + const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; TF_CHECK_OK(literal->Populate( [&](tensorflow::gtl::ArraySlice indices) { - // Generate a random uniforma number from -0.0625 and 0.0625 and bias it - // with a position dependent number with mean 0.037109375. These number + // Generate a random uniform number from -0.0625 and 0.0625 and bias it + // with a position dependent number with mean 0.037109375. These number // should allow for long chains of accumulation without being too close - // to zero or to large to accumulate all numbers accurately. - return (generator(engine) - 1.0625) + - static_cast(Product(indices) % 113 - 47) / - static_cast(256.0f); + // to zero or too large to accumulate all numbers accurately. Only do + // this for large literals where the number of elements is much greater + // than 47 otherwise only negative values are produced. + // + // The value is positionally biased using a product of the indices. Add + // one to each index value to avoid collapsing to zero if any of the + // indices are zero. + int64 index_product = 1; + for (int64 i : indices) { + index_product *= (1 + i); + } + const int64 negative_bias = should_index_bias ? 47 : 0; + FloatT index_bias = + static_cast(index_product % 113 - negative_bias) / + static_cast(256.0f); + return (generator(*engine) - 1.0625) + index_bias; })); } // The standard library does not have a case for bfloat16, unsurprisingly, so we // handle that one specially. template <> -void PopulateWithRandomFloatingPointData(Literal* literal) { +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), BF16); - std::minstd_rand0 engine; std::uniform_real_distribution generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate( [&](tensorflow::gtl::ArraySlice /*indices*/) { - return static_cast(generator(engine)); + return static_cast(generator(*engine)); })); } template -void PopulateWithRandomIntegralData(Literal* literal) { +void PopulateWithRandomIntegralData(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - std::minstd_rand0 engine; std::uniform_int_distribution generator( std::numeric_limits::lowest(), std::numeric_limits::max()); TF_CHECK_OK(literal->Populate( [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); + return generator(*engine); })); } +// Similar to MakeFakeLiteral but takes a random number generator engine to +// enable reusing the engine across randomly generated literals. +StatusOr> MakeFakeLiteralInternal( + const Shape& shape, std::minstd_rand0* engine) { + if (ShapeUtil::IsTuple(shape)) { + std::vector> elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr element, + MakeFakeLiteralInternal(element_shape, engine)); + elements.push_back(std::move(element)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } + std::unique_ptr literal = Literal::CreateFromShape(shape); + switch (shape.element_type()) { + case BF16: + PopulateWithRandomFloatingPointData(literal.get(), engine); + break; + case F32: + PopulateWithRandomFloatingPointData(literal.get(), engine); + break; + case F64: + PopulateWithRandomFloatingPointData(literal.get(), engine); + break; + case S8: + PopulateWithRandomIntegralData(literal.get(), engine); + break; + case U8: + PopulateWithRandomIntegralData(literal.get(), engine); + break; + case S16: + PopulateWithRandomIntegralData(literal.get(), engine); + break; + case U16: + PopulateWithRandomIntegralData(literal.get(), engine); + break; + case S32: + PopulateWithRandomIntegralData(literal.get(), engine); + break; + case U32: + PopulateWithRandomIntegralData(literal.get(), engine); + break; + case S64: + PopulateWithRandomIntegralData(literal.get(), engine); + break; + case U64: + PopulateWithRandomIntegralData(literal.get(), engine); + break; + case PRED: { + std::uniform_int_distribution generator(0, 1); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(*engine); + })); + break; + } + default: + return Unimplemented("Unsupported type for fake literal generation: %s", + ShapeUtil::HumanString(shape).c_str()); + } + return std::move(literal); +} + // Matches binary addition computations. bool LooksLikeSum(const HloComputation& computation) { const HloInstruction* const root = computation.root_instruction(); @@ -95,15 +171,15 @@ bool NeedsZeroInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. std::unique_ptr MakeRandomNonwrappingSliceIndex( - const Shape& input_shape, const Shape& slice_shape) { + const Shape& input_shape, const Shape& slice_shape, + std::minstd_rand0* engine) { const int64 rank = ShapeUtil::Rank(input_shape); std::vector start_indices(rank); - std::minstd_rand0 engine; for (int i = 0; i < rank; ++i) { const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - ShapeUtil::GetDimension(slice_shape, i); std::uniform_int_distribution generator(0, upper_bound); - start_indices[i] = generator(engine); + start_indices[i] = generator(*engine); } return Literal::CreateR1(start_indices); } @@ -150,7 +226,7 @@ std::vector FindConstrainedUses( // zero in the case of init_values for reductions). StatusOr> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice constrained_uses, - const HloInstruction& param) { + const HloInstruction& param, std::minstd_rand0* engine) { HloInstruction* needs_index = nullptr; HloInstruction* needs_zero = nullptr; for (HloInstruction* use : constrained_uses) { @@ -185,93 +261,39 @@ StatusOr> CreateLiteralForConstrainedUses( } if (needs_index != nullptr) { return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), - needs_index->shape()); + needs_index->shape(), engine); } else if (needs_zero != nullptr) { return Literal::CreateFromShape(param.shape()); } else { - return MakeFakeLiteral(param.shape()); + return MakeFakeLiteralInternal(param.shape(), engine); } } // Given a module entry parameter, use the dataflow analysis to see if a // special case literal must be created, or if we can generate fake data. StatusOr> MakeConstrainedArgument( - const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + const HloDataflowAnalysis& dataflow, const HloInstruction& param, + std::minstd_rand0* engine) { const auto constrained_uses = FindConstrainedUses(dataflow, param); - return CreateLiteralForConstrainedUses(constrained_uses, param); + return CreateLiteralForConstrainedUses(constrained_uses, param, engine); } } // namespace StatusOr> MakeFakeLiteral(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { - std::vector> elements; - for (const Shape& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr element, - MakeFakeLiteral(element_shape)); - elements.push_back(std::move(element)); - } - return Literal::MakeTupleOwned(std::move(elements)); - } - std::unique_ptr literal = Literal::CreateFromShape(shape); - switch (shape.element_type()) { - case BF16: - PopulateWithRandomFloatingPointData(literal.get()); - break; - case F32: - PopulateWithRandomFloatingPointData(literal.get()); - break; - case F64: - PopulateWithRandomFloatingPointData(literal.get()); - break; - case S8: - PopulateWithRandomIntegralData(literal.get()); - break; - case U8: - PopulateWithRandomIntegralData(literal.get()); - break; - case S16: - PopulateWithRandomIntegralData(literal.get()); - break; - case U16: - PopulateWithRandomIntegralData(literal.get()); - break; - case S32: - PopulateWithRandomIntegralData(literal.get()); - break; - case U32: - PopulateWithRandomIntegralData(literal.get()); - break; - case S64: - PopulateWithRandomIntegralData(literal.get()); - break; - case U64: - PopulateWithRandomIntegralData(literal.get()); - break; - case PRED: { - std::uniform_int_distribution generator(0, 1); - std::minstd_rand0 engine; - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - default: - return Unimplemented("Unsupported type for fake literal generation: %s", - ShapeUtil::HumanString(shape).c_str()); - } - return std::move(literal); + std::minstd_rand0 engine; + return MakeFakeLiteralInternal(shape, &engine); } StatusOr>> MakeFakeArguments( HloModule* const module) { - TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); + std::minstd_rand0 engine; std::vector> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN(arguments[i], - MakeConstrainedArgument(*dataflow, *params[i])); + TF_ASSIGN_OR_RETURN( + arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); } return std::move(arguments); } diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index a8bca70d85ddf168bc441231d6f43bead019b10a..2029312f94a14bc81706368b9ecfc2727fd9fe4c 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -194,8 +194,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } -// TODO(b/68395210): GPU does not tolerate ambiguous top-level buffers. -XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) { +XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { ComputationBuilder b(client_, TestName()); ComputationDataHandle v1, v2; diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 1d2f436194a921c8d1b23732e2b4be11b59ac043..9ad2a1985331b80625dd0687ea052300bc99e440 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -19,12 +19,14 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -32,6 +34,7 @@ limitations under the License. namespace xla { namespace { namespace se = ::perftools::gputools; +namespace gtl = ::tensorflow::gtl; class HloProfileTest : public ClientLibraryTestBase {}; @@ -43,39 +46,74 @@ struct ParsedProfileOutputLine { string trops; string bytes_per_sec; string bytes_per_cycle; - string name; + string opcode; }; -StatusOr ParseProfileOutputLine(const string& line, - bool expect_flops, - bool expect_trops) { +::testing::AssertionResult HasFlops( + const ParsedProfileOutputLine& parsed_line) { + if (RE2::FullMatch(parsed_line.flops, "[0-9.TGMk]+FLOP/s")) { + return ::testing::AssertionSuccess() + << "'flops' field present in " << parsed_line.opcode << ": '" + << parsed_line.flops << "'"; + } + + return ::testing::AssertionFailure() + << "'flops' field absent in " << parsed_line.opcode << ": '" + << parsed_line.flops << "'"; +} + +::testing::AssertionResult HasTrops( + const ParsedProfileOutputLine& parsed_line) { + if (RE2::FullMatch(parsed_line.trops, "[0-9.TGMk]+TROP/s")) { + return ::testing::AssertionSuccess() + << "'trops' field present in " << parsed_line.opcode << ": '" + << parsed_line.trops << "'"; + } + + return ::testing::AssertionFailure() + << "'trops' field absent in " << parsed_line.opcode << ": '" + << parsed_line.trops << "'"; +} + +Status ParseOneProfileOutputLine( + const string& line, bool expect_hlo, + gtl::FlatMap* parsed_results) { string separator = "[^:]*:: +"; string match_percentage = "\\d+\\.\\d\\d%"; string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; string match_usecs = "([0-9.]+) usec"; - string match_flops = expect_flops ? "([0-9.TGMk]+)FLOP/s" : "()"; - string match_trops = expect_trops ? "([0-9.TGMk]+)TROP/s" : "()"; + string match_flops = "([^ ]+)"; + string match_trops = "([^ ]+)"; string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; + + // The underlined part is what we're trying to match with match_opcode: + // + // %dot33 = f32[256,256]{1,0} dot(...) + // ^^^ + + string match_opcode = + expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])"; string regexp_pattern = tensorflow::strings::StrCat( " +", match_cycles, separator, match_usecs, separator, match_flops, separator, match_trops, separator, match_bytes_per_sec, separator, - match_bytes_per_cycle, separator, "(.*)"); + match_bytes_per_cycle, separator, match_opcode); - RE2 pattern(regexp_pattern); ParsedProfileOutputLine parsed_line; bool matched = RE2::FullMatch( - line, pattern, &parsed_line.cycles, &parsed_line.cycles_percentage, + line, regexp_pattern, &parsed_line.cycles, &parsed_line.cycles_percentage, &parsed_line.usec, &parsed_line.flops, &parsed_line.trops, &parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle, - &parsed_line.name); + &parsed_line.opcode); if (!matched) { return tensorflow::errors::InvalidArgument( "Input did not match regexp. Input: ", line, ", Regexp: ", regexp_pattern); } - return parsed_line; + InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); + + return Status::OK(); } // Returns void so that we can ASSERT. @@ -148,7 +186,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { ClientLibrary::GetOrCreateLocalClient(platform)); ComputationBuilder builder(client, TestName()); - auto result = builder.Tanh(builder.Dot( + auto result = builder.Tanh(builder.Add( builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); @@ -161,31 +199,43 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { std::vector profile_output_lines = tensorflow::str_util::Split(profile_output, '\n'); - TF_ASSERT_OK_AND_ASSIGN( - ParsedProfileOutputLine total_profile, - ParseProfileOutputLine(profile_output_lines[1], /*expect_flops=*/true, - /*expect_trops=*/true)); + gtl::FlatMap parsed_profile_lines; - TF_ASSERT_OK_AND_ASSIGN( - ParsedProfileOutputLine dot_profile, - ParseProfileOutputLine(profile_output_lines[2], /*expect_flops=*/true, - /*expect_trops=*/false)); + TF_ASSERT_OK(ParseOneProfileOutputLine( + profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines)); - TF_ASSERT_OK_AND_ASSIGN( - ParsedProfileOutputLine tanh_profile, - ParseProfileOutputLine(profile_output_lines[3], /*expect_flops=*/false, - /*expect_trops=*/true)); + TF_ASSERT_OK(ParseOneProfileOutputLine( + profile_output_lines[2], /*expect_hlo=*/true, &parsed_profile_lines)); + + TF_ASSERT_OK(ParseOneProfileOutputLine( + profile_output_lines[3], /*expect_hlo=*/true, &parsed_profile_lines)); + + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile, + MaybeFind(parsed_profile_lines, "[total]")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile, + MaybeFind(parsed_profile_lines, "add")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine tanh_profile, + MaybeFind(parsed_profile_lines, "tanh")); EXPECT_GT(total_profile.cycles, 0); EXPECT_EQ(total_profile.cycles_percentage, "100.00%"); + EXPECT_TRUE(HasFlops(total_profile)); + EXPECT_TRUE(HasTrops(total_profile)); + EXPECT_GT(total_profile.cycles, dot_profile.cycles); EXPECT_NE(dot_profile.cycles_percentage, "0.00%"); EXPECT_NE(dot_profile.cycles_percentage, "100.00%"); + EXPECT_TRUE(HasFlops(dot_profile)); + EXPECT_FALSE(HasTrops(dot_profile)); + EXPECT_GT(total_profile.cycles, tanh_profile.cycles); EXPECT_NE(tanh_profile.cycles_percentage, "0.00%"); EXPECT_NE(tanh_profile.cycles_percentage, "100.00%"); + + EXPECT_FALSE(HasFlops(tanh_profile)); + EXPECT_TRUE(HasTrops(tanh_profile)); } // TODO(b/71364943): This test exposes a bug in the parallel CPU backend. @@ -220,7 +270,7 @@ XLA_TEST_F(HloProfileTest, auto matrix = builder.GetTupleElement(state, 1); auto next_iteration = builder.Add(builder.GetTupleElement(state, 0), builder.ConstantR0(1)); - builder.Tuple({next_iteration, builder.Dot(matrix, matrix)}); + builder.Tuple({next_iteration, builder.Add(matrix, matrix)}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } @@ -249,20 +299,23 @@ XLA_TEST_F(HloProfileTest, ASSERT_NE(while_body_profile_start, profile_output_lines.end()); - TF_ASSERT_OK_AND_ASSIGN( - ParsedProfileOutputLine total_while_body_profile, - ParseProfileOutputLine(*std::next(while_body_profile_start, 1), - /*expect_flops=*/false, - /*expect_trops=*/false)); + gtl::FlatMap parsed_profile_lines; - TF_ASSERT_OK_AND_ASSIGN( - ParsedProfileOutputLine dot_profile, - ParseProfileOutputLine(*std::next(while_body_profile_start, 2), - /*expect_flops=*/false, - /*expect_trops=*/false)); + TF_ASSERT_OK( + ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1), + /*expect_hlo=*/false, &parsed_profile_lines)); + + TF_ASSERT_OK( + ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2), + /*expect_hlo=*/true, &parsed_profile_lines)); + + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile, + MaybeFind(parsed_profile_lines, "[total]")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile, + MaybeFind(parsed_profile_lines, "add")); EXPECT_GT(total_while_body_profile.cycles, 0); - EXPECT_EQ(total_while_body_profile.name, "[total]"); + EXPECT_EQ(total_while_body_profile.opcode, "[total]"); EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%"); EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 5ede37b8737bd4fa6235464ddeb6382af17c8a80..b82f1c81c84b487c1661af5267b9123da97bb107 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -85,10 +85,12 @@ void RealMain(tensorflow::gtl::ArraySlice args) { for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } + ExecutableBuildOptions build_options; + build_options.set_device_ordinal(0); + build_options.set_result_layout(program_shape->result()); StatusOr> executable = local_service->CompileExecutable(computation.handle(), layouts, - &program_shape->result(), - /*device_ordinal=*/0); + build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 24417a0cb8212e59cc0af53bd5bb21afcf3e134b..05c0fdf97d27c09eb2bbb0f265b5b2a5982ca7b1 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -60,10 +60,13 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } + + ExecutableBuildOptions build_options; + build_options.set_device_ordinal(0); + build_options.set_result_layout(program_shape->result()); StatusOr> executable = local_service->CompileExecutable(computation.handle(), layouts, - &program_shape->result(), - /*device_ordinal=*/0); + build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc index 4e02e17db65c0a4220672733be8319e1a0cc4f0f..8460ae3e4991ee091af72d2553a8491f627c722e 100644 --- a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc +++ b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc @@ -19,7 +19,7 @@ limitations under the License. // // Reads one serilized Hlo module, convert it into JSON format and dump into // some output directory. some_binaray_proto is obtained by serializing Hlo -// module to disk using --xla_dump_hlo_proto_to debug optoin. +// module to disk using --xla_dump_optimized_hlo_proto_to debug option. #include #include diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 42e7f91f26f3454b247d95d328c3422c44131c43..cd2b843ad36013ae83818ecbc184fb823093f037 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -220,10 +220,13 @@ class HloParser { bool AddComputation(const string& name, HloComputation* computation, LocTy name_loc); - // The map from the instruction name to the instruction. This does not own the - // instructions. - std::unordered_map instruction_pool_; - std::unordered_map computation_pool_; + // The map from the instruction/computation name to the + // instruction/computation itself and it's location. This does not own the + // pointers. + std::unordered_map> + instruction_pool_; + std::unordered_map> + computation_pool_; HloLexer lexer_; std::unique_ptr module_; @@ -340,15 +343,16 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - HloInstruction* root = - tensorflow::gtl::FindPtrOrNull(instruction_pool_, root_name); + std::pair* root_node = + tensorflow::gtl::FindOrNull(instruction_pool_, root_name); // This means some instruction was marked as ROOT but we didn't find it in the // pool, which should not happen. - if (!root_name.empty() && root == nullptr) { + if (!root_name.empty() && root_node == nullptr) { LOG(FATAL) << "instruction " << root_name << " was marked as ROOT but the parser has not seen it before"; } + HloInstruction* root = root_node == nullptr ? nullptr : root_node->first; // Now root can be either an existing instruction or a nullptr. If it's a // nullptr, the implementation of Builder will set the last instruction as // root instruction. @@ -990,6 +994,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands, *custom_call_target)); break; } + case HloOpcode::kHostCompute: { + optional channel_name; + optional cost_estimate_ns; + attrs["channel_name"] = {/*required=*/true, AttrTy::kString, + &channel_name}; + attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, + &cost_estimate_ns}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateHostCompute( + shape, operands, *channel_name, *cost_estimate_ns)); + break; + } case HloOpcode::kDot: { optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { @@ -1031,6 +1049,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); break; } + case HloOpcode::kGather: + // TODO(b/72710576): HLO parsing is not implemented for Gather. + return TokenError("HLO parsing is not implemented for Gather"); case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); @@ -1229,13 +1250,13 @@ bool HloParser::ParseInstructionNames( if (!ParseName(&name)) { return Error(loc, "expects a instruction name"); } - HloInstruction* instr = - tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + std::pair* instr = + tensorflow::gtl::FindOrNull(instruction_pool_, name); if (!instr) { return TokenError( Printf("instruction '%s' is not defined", name.c_str())); } - instructions->push_back(instr); + instructions->push_back(instr->first); } while (EatIfPresent(TokKind::kComma)); return ParseToken(TokKind::kRbrace, @@ -1705,12 +1726,12 @@ bool HloParser::ParseOperands(std::vector* operands) { if (!ParseName(&name)) { return false; } - HloInstruction* instruction = - tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + std::pair* instruction = + tensorflow::gtl::FindOrNull(instruction_pool_, name); if (!instruction) { return Error(loc, StrCat("instruction does not exist: ", name)); } - operands->push_back(instruction); + operands->push_back(instruction->first); } while (EatIfPresent(TokKind::kComma)); } return ParseToken(TokKind::kRparen, "expects ')' at the end of operands"); @@ -1957,10 +1978,12 @@ bool HloParser::ParseComputationName(HloComputation** value) { if (!ParseName(&name)) { return Error(loc, "expects computation name"); } - *value = tensorflow::gtl::FindPtrOrNull(computation_pool_, name); - if (*value == nullptr) { + std::pair* computation = + tensorflow::gtl::FindOrNull(computation_pool_, name); + if (computation == nullptr) { return Error(loc, StrCat("computation does not exist: ", name)); } + *value = computation->first; return true; } @@ -2173,7 +2196,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // // {[2:3:4], [5:6:7], [8:9]} // -// The the parsed result will be: +// The parsed result will be: // // {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}} // @@ -2576,18 +2599,22 @@ bool HloParser::EatIfPresent(TokKind kind) { bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, LocTy name_loc) { - auto result = instruction_pool_.insert({name, instruction}); + auto result = instruction_pool_.insert({name, {instruction, name_loc}}); if (!result.second) { - return Error(name_loc, StrCat("instruction already exists: ", name)); + Error(name_loc, StrCat("instruction already exists: ", name)); + return Error(/*loc=*/result.first->second.second, + "instruction previously defined here"); } return true; } bool HloParser::AddComputation(const string& name, HloComputation* computation, LocTy name_loc) { - auto result = computation_pool_.insert({name, computation}); + auto result = computation_pool_.insert({name, {computation, name_loc}}); if (!result.second) { - return Error(name_loc, StrCat("computation already exists: ", name)); + Error(name_loc, StrCat("computation already exists: ", name)); + return Error(/*loc=*/result.first->second.second, + "computation previously defined here"); } return true; } diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index dd76d8d0fee7cdfa22829fe92ff889e44157216e..b8c6b59204f897c7dc07b846370b5b776a19a808 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -1275,6 +1275,35 @@ ENTRY consts { "one computation should have only one ROOT"); } +TEST_F(HloParserTest, InstructionExists) { + const string original = R"(HloModule comp_exists +c1 { + instr = f32[1]{0} constant({12345}) +} +c2 { + instr = f32[1]{0} constant({67890}) +})"; + + ExpectHasSubstr(Parse(original).status().error_message(), + R"(was parsing 3:3: error: instruction previously defined here + instr = f32[1]{0} constant({12345}) + ^)"); +} + +TEST_F(HloParserTest, ComputationExists) { + const string original = R"(HloModule comp_exists +comp { + const1 = f32[1]{0} constant({12345}) +} +comp { + const2 = f32[1]{0} constant({67890}) +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + R"(was parsing 2:1: error: computation previously defined here +comp { +^)"); +} + } // namespace } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index fe5d29a6b655a89d559eb1214c2b8dd54d34094c..1f0c626bbb2d64ef4e67c9ec51485ae96ae73d04 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -30,9 +30,7 @@ limitations under the License. #include "tensorflow/core/platform/stacktrace.h" namespace xla { -namespace { -// Logs the provided status message with a backtrace. Status WithLogBacktrace(const Status& status) { CHECK(!status.ok()); VLOG(1) << status.ToString(); @@ -40,8 +38,6 @@ Status WithLogBacktrace(const Status& status) { return status; } -} // namespace - ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled) : enabled(enabled), label(label) { if (enabled) { @@ -74,13 +70,18 @@ Status AppendStatus(Status prior, tensorflow::StringPiece context) { // Implementation note: we can't common these out (without using macros) because // they all need to va_start/va_end their varargs in their frame. -Status InvalidArgument(const char* format, ...) { +Status InvalidArgumentV(const char* format, va_list args) { string message; + tensorflow::strings::Appendv(&message, format, args); + return WithLogBacktrace(tensorflow::errors::InvalidArgument(message)); +} + +Status InvalidArgument(const char* format, ...) { va_list args; va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); + Status result = InvalidArgumentV(format, args); va_end(args); - return WithLogBacktrace(tensorflow::errors::InvalidArgument(message)); + return result; } Status Unimplemented(const char* format, ...) { @@ -338,7 +339,7 @@ std::vector> CommonFactors( string SanitizeFileName(string file_name) { for (char& c : file_name) { - if (c == '/' || c == '\\' || c == '[' || c == ']') { + if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') { c = '_'; } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 1d7dd344493f91d84714c72783c95a49ad72ad1c..46ec7af54290f40dfac1e4627801eab4dabb8aa5 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -40,6 +40,13 @@ limitations under the License. namespace xla { +// Logs the provided status message with a backtrace. +// +// For use by Status-factories, logs a backtrace at the point where the status +// is created, such that we can use --vmodule=util=1 to see all status +// creation backtraces. +Status WithLogBacktrace(const Status& status); + // Ranks greater than 8 are very rare, so use InlinedVector to store // the bounds and indices. And for the rare cases of ranks greater than 8, // the InlinedVector will just behave like an std::vector<> and allocate the @@ -207,6 +214,27 @@ Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); +// Passed-varargs variant of the InvalidArgument factory above. +Status InvalidArgumentV(const char* format, va_list args); + +template +Status UnimplementedStrCat(Args&&... concat) { + return Unimplemented( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + +template +Status InternalErrorStrCat(Args&&... concat) { + return InternalError( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + +template +Status ResourceExhaustedStrCat(Args&&... concat) { + return ResourceExhausted( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + // Splits the lines of the original, replaces leading whitespace with the prefix // given by "indentation", and returns the string joined by newlines again. As a // side effect, any additional trailing whitespace is removed. @@ -332,7 +360,7 @@ T CeilOfRatio(T dividend, T divisor) { } // Rounds the value up to a multiple of the divisor by first calling CeilOfRatio -// then multiplying by the divisor. For example: RoundUpToMultiple(13, 8) => 16 +// then multiplying by the divisor. For example: RoundUpToNearest(13, 8) => 16 template T RoundUpToNearest(T value, T divisor) { return CeilOfRatio(value, divisor) * divisor; @@ -340,7 +368,7 @@ T RoundUpToNearest(T value, T divisor) { // Rounds the value down to a multiple of the divisor by first calling // FloorOfRatio then multiplying by the divisor. For example: -// RoundUpToMultiple(13, 8) => 8 +// RoundDownToNearest(13, 8) => 8 template T RoundDownToNearest(T value, T divisor) { return FloorOfRatio(value, divisor) * divisor; @@ -420,11 +448,38 @@ OutputIterator c_copy_if(InputContainer input_container, output_iterator, predicate); } +template +OutputIterator c_copy(InputContainer input_container, + OutputIterator output_iterator) { + return std::copy(std::begin(input_container), std::end(input_container), + output_iterator); +} + +template +void c_sort(InputContainer& input_container) { + std::sort(std::begin(input_container), std::end(input_container)); +} + template void c_sort(InputContainer& input_container, Comparator comparator) { - std::sort(input_container.begin(), input_container.end(), comparator); + std::sort(std::begin(input_container), std::end(input_container), comparator); } +template +bool c_binary_search(Sequence& sequence, T&& value) { + return std::binary_search(std::begin(sequence), std::end(sequence), + std::forward(value)); +} + +template +bool c_is_sorted(const C& c) { + return std::is_sorted(std::begin(c), std::end(c)); +} + +template +auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { + return std::adjacent_find(std::begin(c), std::end(c)); +} } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 55f42ed3a454baa3f8b6adf60a78582488733e9b..93284b80f9e1f82c4b18dc7388754d5c01a7740c 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -32,6 +32,8 @@ Window MakeWindow(tensorflow::gtl::ArraySlice sizes) { auto* dimension = window.add_dimensions(); dimension->set_size(size); dimension->set_stride(1); + dimension->set_base_dilation(1); + dimension->set_window_dilation(1); } return window; } diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index e1ed08c8480fa73e9c5ff914bb9f5e38f1ce96e9..56162ab44e2e0e3e4478fe631888f243332dc1d8 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -82,8 +82,9 @@ message DebugOptions { // Dump all HLO modules as text into the provided directory path. string xla_generate_hlo_text_to = 7; - // Dump compilation artifacts in binary proto into this directory. - string xla_dump_hlo_proto_to = 8; + // Dump Hlo after all hlo passes are executed as proto binary into this + // directory. + string xla_dump_optimized_hlo_proto_to = 8; // Instrument the computation to collect per-HLO cycle counts. bool xla_hlo_profile = 9; @@ -179,9 +180,13 @@ message DebugOptions { // ops. bool xla_gpu_use_cudnn_batchnorm = 94; - // Dump compilation artifacts, before hlo passes are executed, in binary proto - // into this directory. - string xla_dump_prepass_hlo_proto_to = 95; + // Dump HLO before any hlo passes are executed as proto binary into this + // directory. + string xla_dump_unoptimized_hlo_proto_to = 95; + + // Dump HLO after each pass as an HloProto in binary file format into this + // directory. + string xla_dump_per_pass_hlo_proto_to = 96; // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 3aea0217539b89b5d60ecfaf2605eee4b69af728..28620c3b86349281573eaf57d2838bee1488d838 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -393,6 +393,33 @@ message Window { repeated WindowDimension dimensions = 1; } +// Describes the dimension numbers for a gather operation. +// +// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for +// more details. +message GatherDimensionNumbers { + // "Window indices" is a term for a set of indices that index into the + // interior of a dynamic-slice from the input tensor, the starting indices for + // which were computed from output_gather_dims (see the operation semantic for + // how this is defined) and the gather_indices tensor. + // + // The window indices for a specific output index Out is computed as: + // + // i = 0 + // for (k : [0, input_tensor_shape.rank)) + // window_indices[k] = + // if k in elided_window_dims + // then 0 + // else Out[output_window_dims[i++]] + repeated int64 output_window_dims = 1; + repeated int64 elided_window_dims = 2; + + // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It + // transforms the gather index looked up from the gather_indices tensor into + // the starting index in the input space. + repeated int64 gather_dims_to_operand_dims = 3; +} + // Operation requests that are all collected as a tagged union with a oneof // field in OpRequest. @@ -519,6 +546,20 @@ message CustomCallRequest { Shape shape = 4; } +message HostComputeRequest { + // Operand to the HostCompute. Supports tuple. + repeated ComputationDataHandle operands = 1; + + // Name used to identify HostSend/Recv channels. + string channel_name = 2; + + // Cost estimate in nanoseconds. + int64 cost_estimate_ns = 3; + + // The shape of any data returned by host. + Shape shape = 4; +} + message DotDimensionNumbers { // The dimension numbers that represent the 'lhs' contracting dimensions. repeated int64 lhs_contracting_dimensions = 1; @@ -880,6 +921,13 @@ message RecvRequest { ChannelHandle channel_handle = 2; } +message GatherRequest { + ComputationDataHandle input = 1; + ComputationDataHandle gather_indices = 2; + GatherDimensionNumbers dimension_numbers = 3; + repeated int64 window_bounds = 4; +} + message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -957,7 +1005,9 @@ message OpRequest { FftRequest fft_request = 41; ConvertRequest bitcast_convert_request = 42; ConditionalRequest conditional_request = 44; - // Next: 45 + HostComputeRequest host_compute_request = 45; + GatherRequest gather_request = 46; + // Next: 47 } } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index f1e54432faa3c59ada0d89c472bcdcc28f6d0970..bab37e8906e5c648acdc1556da7e5f4601776ff5 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -7,6 +7,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") py_library( name = "contrib_py", @@ -24,6 +25,7 @@ py_library( "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/boosted_trees:init_py", "//tensorflow/contrib/cloud:cloud_py", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_ops_py", "//tensorflow/contrib/compiler:compiler_py", @@ -36,6 +38,7 @@ py_library( "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/estimator:estimator_py", "//tensorflow/contrib/factorization:factorization_py", + "//tensorflow/contrib/feature_column:feature_column_py", "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/fused_conv:fused_conv_py", @@ -48,6 +51,7 @@ py_library( "//tensorflow/contrib/image:single_image_random_dot_stereograms_py", "//tensorflow/contrib/input_pipeline:input_pipeline_py", "//tensorflow/contrib/integrate:integrate_py", + "//tensorflow/contrib/kafka", "//tensorflow/contrib/keras", "//tensorflow/contrib/kernel_methods", "//tensorflow/contrib/kfac", @@ -68,7 +72,6 @@ py_library( "//tensorflow/contrib/metrics:metrics_py", "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", - "//tensorflow/contrib/ndlstm", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", @@ -76,6 +79,7 @@ py_library( "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", + "//tensorflow/contrib/py2tf", "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", @@ -104,7 +108,9 @@ py_library( "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", - ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]), + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ + "//tensorflow/contrib/tensorrt:init_py", + ]), ) cc_library( @@ -114,6 +120,7 @@ cc_library( "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", "//tensorflow/contrib/coder:all_kernels", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels", + "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", @@ -136,9 +143,11 @@ cc_library( "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", "//tensorflow/contrib/coder:all_ops", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_ops_op_lib", + "//tensorflow/contrib/data:dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", + "//tensorflow/contrib/kafka:kafka_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 8f6a3cb1ca4544cae6f42fd1727d509af9fc0233..4f6f539027b040de7554d09fe9118ff97aa006f8 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -33,6 +33,7 @@ from tensorflow.contrib import deprecated from tensorflow.contrib import distributions from tensorflow.contrib import estimator from tensorflow.contrib import factorization +from tensorflow.contrib import feature_column from tensorflow.contrib import framework from tensorflow.contrib import gan from tensorflow.contrib import graph_editor @@ -83,7 +84,6 @@ from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager from tensorflow.contrib.lite.python import lite -from tensorflow.contrib.ndlstm import python as ndlstm from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 28f60b34996945d573facc665c01d0bc10cf5cd1..6658f0d9c13f6db17b25354cde2593d57f104f17 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -48,7 +48,7 @@ def _flatten_tensors(tensors): if shape.ndims is None: raise ValueError("At least one of the tensors in 'tensors' must have " "statically known rank.") - if len(shape) > 1: + if len(shape) != 1: reshaped = [] for t in tensors: with ops.colocate_with(t): @@ -289,7 +289,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks, chunks_by_dev) if pad_len > 0: output_tensors = _strip_padding(output_tensors, pad_len) - if len(shape) > 1: + if len(shape) != 1: output_tensors = _reshape_tensors(output_tensors, shape) return output_tensors @@ -466,7 +466,7 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None): if un_op: reduced_shards = [un_op(t) for t in reduced_shards] output_tensors = _build_recursive_hd_scatter(reduced_shards, devices) - if len(shape) > 1: + if len(shape) != 1: output_tensors = _reshape_tensors(output_tensors, shape) return output_tensors @@ -578,7 +578,7 @@ def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None): reduced_shards = _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op) output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices) - if len(shape) > 1: + if len(shape) != 1: output_tensors = _reshape_tensors(output_tensors, shape) return output_tensors @@ -752,13 +752,13 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f): dst_tensors.append(array_ops.identity(broadcast_src)) down_values[w] = dst_tensors output_tensors = [v for sublist in down_values for v in sublist] - if len(shape) > 1: + if len(shape) != 1: output_tensors = _reshape_tensors(output_tensors, shape) return output_tensors def _reduce_non_singleton(input_tensors, red_f, un_op): - """If input_tenors has more than one element apply red_f, else apply un_op.""" + """If input_tensors has more than one element apply red_f, else apply un_op.""" if len(input_tensors) > 1: return red_f(input_tensors) else: @@ -831,7 +831,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): for w in range(0, num_workers): output_tensors += _build_shuffle_scatter( [level_2_output[w]], per_worker_devices[w]) - if len(shape) > 1: + if len(shape) != 1: output_tensors = _reshape_tensors(output_tensors, shape) return output_tensors diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py index 0802b2736909c2a6f075ea2eac6d4dd3ab2918d8..47bab0a3670a90644972b2c961954a3036b8ecba 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py @@ -119,7 +119,7 @@ class AllReduceTest(test_util.TensorFlowTestCase): def _buildInitialVars(self, shape, dev_list): values = [] num_devices = len(dev_list) - dim = np.prod(shape) + dim = np.prod(shape) if shape else 1 for d in range(0, num_devices): with ops.device(dev_list[d]): npt = np.zeros(shape).astype(np.float32) @@ -164,6 +164,7 @@ class AllReduceTest(test_util.TensorFlowTestCase): (num_workers, num_gpus, shape, subdiv, elapsed)) def testRingAllReduce(self): + self._testRingAllReduce(1, 2, [], 1) self._testRingAllReduce(1, 2, [8], 1) self._testRingAllReduce(1, 2, [4, 4], 1) self._testRingAllReduce(6, 1, [8], 1) @@ -192,6 +193,7 @@ class AllReduceTest(test_util.TensorFlowTestCase): "elapsed=%f" % (num_workers, num_gpus, shape, elapsed)) def testShuffleAllReduce(self): + self._testShuffleAllReduce(1, 2, [], 1) self._testShuffleAllReduce(1, 2, [8], 1) self._testShuffleAllReduce(1, 2, [4, 4], 1) self._testShuffleAllReduce(1, 8, [32], 1) diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md index b8d73bf24ce60e0b3850d4f39ac9e6d6c2194a02..db37bcf73d144eb81c32a461a276d10be7e2d193 100644 --- a/tensorflow/contrib/android/README.md +++ b/tensorflow/contrib/android/README.md @@ -81,6 +81,11 @@ For documentation on building a self-contained AAR file with cmake, see [tensorflow/contrib/android/cmake](cmake). +### Makefile + +For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../makefile/README.md) + + ## AssetManagerFileSystem This directory also contains a TensorFlow filesystem supporting the Android diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index dc5b9fb88742d78d0f40207b589e29451a6358dd..abddadac5bcace9b1f992b69bdcc69c24b29cd13 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -194,6 +194,11 @@ public class TensorFlowInferenceInterface { * @param outputNames A list of output nodes which should be filled by the inference pass. */ public void run(String[] outputNames, boolean enableStats) { + run(outputNames, enableStats, new String[] {}); + } + + /** An overloaded version of runInference that allows supplying targetNodeNames as well */ + public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) { // Release any Tensors from the previous run calls. closeFetches(); @@ -204,6 +209,11 @@ public class TensorFlowInferenceInterface { runner.fetch(tid.name, tid.outputIndex); } + // Add targets. + for (String t : targetNodeNames) { + runner.addTarget(t); + } + // Run the session. try { if (enableStats) { diff --git a/tensorflow/contrib/android/jni/run_stats_jni.cc b/tensorflow/contrib/android/jni/run_stats_jni.cc index 119fa9cd2c378d2ba2383ea8b0e09e1b6083d84e..707853b59befc2625145ad96952fbf9f66d62b43 100644 --- a/tensorflow/contrib/android/jni/run_stats_jni.cc +++ b/tensorflow/contrib/android/jni/run_stats_jni.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/stat_summarizer.h" -using tensorflow::StatSummarizer; using tensorflow::RunMetadata; +using tensorflow::StatSummarizer; namespace { StatSummarizer* requireHandle(JNIEnv* env, jlong handle) { diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 4e0b3f9af989c414ad88c510c1bfd180dbadd5ea..921d6917a4e478c3e60771fdc3ae99febc33d2e3 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -53,10 +53,13 @@ def _UnbatchGrad(op, grad): # pylint: disable=invalid-name ] -def batch_function(num_batch_threads, max_batch_size, batch_timeout_micros, +def batch_function(num_batch_threads, + max_batch_size, + batch_timeout_micros, allowed_batch_sizes=None, grad_timeout_micros=60 * 1000 * 1000, - unbatch_timeout_micros=60 * 1000 * 1000): + unbatch_timeout_micros=60 * 1000 * 1000, + max_enqueued_batches=10): """Batches the computation done by the decorated function. So, for example, in the following code @@ -94,6 +97,7 @@ def batch_function(num_batch_threads, max_batch_size, batch_timeout_micros, documentation of the unbatch op for more details. Defaults to 60s. unbatch_timeout_micros: The timeout to use for unbatching. See the documentation of the unbatch op for more details. Defaults to 60s. + max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. Returns: The decorated function will return the unbatched computation output Tensors. @@ -111,6 +115,7 @@ def batch_function(num_batch_threads, max_batch_size, batch_timeout_micros, num_batch_threads=num_batch_threads, max_batch_size=max_batch_size, batch_timeout_micros=batch_timeout_micros, + max_enqueued_batches=max_enqueued_batches, allowed_batch_sizes=allowed_batch_sizes, grad_timeout_micros=grad_timeout_micros, shared_name=name) diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 11c3c037c4e8b4ba41eae60d28d6aac49f1488f2..74712aeb67c3f0a31def78f25a0298f9c02c9590 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -137,6 +137,26 @@ cuda_py_test( ], ) +cuda_py_test( + name = "mcmc_diagnostics_test", + size = "small", + srcs = ["python/kernel_tests/mcmc_diagnostics_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/python:spectral_ops_test_util", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", + ], +) + cuda_py_test( name = "monte_carlo_test", size = "small", @@ -175,6 +195,7 @@ cuda_py_test( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = ["no_mac"], # b/73192243 ) cuda_py_test( @@ -217,6 +238,24 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:random_seed", ], + tags = ["notsan"], +) + +cuda_py_test( + name = "variable_utils_test", + size = "small", + srcs = ["python/kernel_tests/variable_utils_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], ) cuda_py_test( diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 95b9452b1ada60c44672f37800ced2133d2bd8b2..528c4fbacd06c7b0defa0e32bd24a98b2bc07b64 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -26,9 +26,11 @@ from tensorflow.contrib.bayesflow.python.ops import custom_grad from tensorflow.contrib.bayesflow.python.ops import halton_sequence from tensorflow.contrib.bayesflow.python.ops import hmc from tensorflow.contrib.bayesflow.python.ops import layers +from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo from tensorflow.contrib.bayesflow.python.ops import optimizers +from tensorflow.contrib.bayesflow.python.ops import variable_utils # pylint: enable=unused-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented @@ -42,10 +44,12 @@ _allowed_symbols = [ 'hmc', 'layers', 'metropolis_hastings', + 'mcmc_diagnostics', 'monte_carlo', 'optimizers', 'special_math', 'stochastic_variables', + 'variable_utils', 'variational_inference', ] diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py index cbc66b6dc13db62c25952de6b6c13b2fdfe27f12..5bd834e56245ab4d874544cfd014fe59ae521ea8 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -18,30 +18,40 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + import numpy as np -from scipy import special from scipy import stats from tensorflow.contrib.bayesflow.python.ops import hmc +from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change +from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import gen_linalg_ops +from tensorflow.python.ops import gradients_impl as gradients_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import gamma as gamma_lib +from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.platform import tf_logging as logging_ops + + +def _reduce_variance(x, axis=None, keepdims=False): + sample_mean = math_ops.reduce_mean(x, axis, keepdims=True) + return math_ops.reduce_mean( + math_ops.squared_difference(x, sample_mean), axis, keepdims) -# TODO(b/66964210): Test float16. class HMCTest(test.TestCase): def setUp(self): self._shape_param = 5. self._rate_param = 10. - self._expected_x = (special.digamma(self._shape_param) - - np.log(self._rate_param)) - self._expected_exp_x = self._shape_param / self._rate_param random_seed.set_random_seed(10003) np.random.seed(10003) @@ -63,63 +73,46 @@ class HMCTest(test.TestCase): self._rate_param * math_ops.exp(x), event_dims) - def _log_gamma_log_prob_grad(self, x, event_dims=()): - """Computes log-pdf and gradient of a log-gamma random variable. - - Args: - x: Value of the random variable. - event_dims: Dimensions not to treat as independent. Default is (), - i.e., all dimensions are independent. - - Returns: - log_prob: The log-pdf up to a normalizing constant. - grad: The gradient of the log-pdf with respect to x. - """ - return (math_ops.reduce_sum(self._shape_param * x - - self._rate_param * math_ops.exp(x), - event_dims), - self._shape_param - self._rate_param * math_ops.exp(x)) - - def _n_event_dims(self, x_shape, event_dims): - return np.prod([int(x_shape[i]) for i in event_dims]) - - def _integrator_conserves_energy(self, x, event_dims, sess, + def _integrator_conserves_energy(self, x, independent_chain_ndims, sess, feed_dict=None): - def potential_and_grad(x): - log_prob, grad = self._log_gamma_log_prob_grad(x, event_dims) - return -log_prob, -grad - - step_size = array_ops.placeholder(np.float32, [], name='step_size') - hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps') + step_size = array_ops.placeholder(np.float32, [], name="step_size") + hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps") if feed_dict is None: feed_dict = {} feed_dict[hmc_lf_steps] = 1000 - m = random_ops.random_normal(array_ops.shape(x)) - potential_0, grad_0 = potential_and_grad(x) - old_energy = potential_0 + 0.5 * math_ops.reduce_sum(m * m, - event_dims) - - _, new_m, potential_1, _ = ( - hmc.leapfrog_integrator(step_size, hmc_lf_steps, x, - m, potential_and_grad, grad_0)) + event_dims = math_ops.range(independent_chain_ndims, + array_ops.rank(x)) - new_energy = potential_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, + m = random_ops.random_normal(array_ops.shape(x)) + log_prob_0 = self._log_gamma_log_prob(x, event_dims) + grad_0 = gradients_ops.gradients(log_prob_0, x) + old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims) + + new_m, _, log_prob_1, _ = _leapfrog_integrator( + current_momentums=[m], + target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims), + current_state_parts=[x], + step_sizes=[step_size], + num_leapfrog_steps=hmc_lf_steps, + current_target_log_prob=log_prob_0, + current_grads_target_log_prob=grad_0) + new_m = new_m[0] + + new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, event_dims) x_shape = sess.run(x, feed_dict).shape - n_event_dims = self._n_event_dims(x_shape, event_dims) - feed_dict[step_size] = 0.1 / n_event_dims - old_energy_val, new_energy_val = sess.run([old_energy, new_energy], - feed_dict) - logging.vlog(1, 'average energy change: {}'.format( - abs(old_energy_val - new_energy_val).mean())) - - self.assertAllEqual(np.ones_like(new_energy_val, dtype=np.bool), - abs(old_energy_val - new_energy_val) < 1.) - - def _integrator_conserves_energy_wrapper(self, event_dims): + event_size = np.prod(x_shape[independent_chain_ndims:]) + feed_dict[step_size] = 0.1 / event_size + old_energy_, new_energy_ = sess.run([old_energy, new_energy], + feed_dict) + logging_ops.vlog(1, "average energy relative change: {}".format( + (1. - new_energy_ / old_energy_).mean())) + self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02) + + def _integrator_conserves_energy_wrapper(self, independent_chain_ndims): """Tests the long-term energy conservation of the leapfrog integrator. The leapfrog integrator is symplectic, so for sufficiently small step @@ -127,135 +120,310 @@ class HMCTest(test.TestCase): the energy of the system blowing up or collapsing. Args: - event_dims: A tuple of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. + independent_chain_ndims: Python `int` scalar representing the number of + dims associated with independent chains. """ - with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name='x_ph') - - feed_dict = {x_ph: np.zeros([50, 10, 2])} - self._integrator_conserves_energy(x_ph, event_dims, sess, feed_dict) + with self.test_session(graph=ops.Graph()) as sess: + x_ph = array_ops.placeholder(np.float32, name="x_ph") + feed_dict = {x_ph: np.random.rand(50, 10, 2)} + self._integrator_conserves_energy(x_ph, independent_chain_ndims, + sess, feed_dict) def testIntegratorEnergyConservationNullShape(self): - self._integrator_conserves_energy_wrapper([]) + self._integrator_conserves_energy_wrapper(0) def testIntegratorEnergyConservation1(self): - self._integrator_conserves_energy_wrapper([1]) + self._integrator_conserves_energy_wrapper(1) def testIntegratorEnergyConservation2(self): - self._integrator_conserves_energy_wrapper([2]) - - def testIntegratorEnergyConservation12(self): - self._integrator_conserves_energy_wrapper([1, 2]) - - def testIntegratorEnergyConservation012(self): - self._integrator_conserves_energy_wrapper([0, 1, 2]) - - def _chain_gets_correct_expectations(self, x, event_dims, sess, - feed_dict=None): + self._integrator_conserves_energy_wrapper(2) + + def testIntegratorEnergyConservation3(self): + self._integrator_conserves_energy_wrapper(3) + + def testSampleChainSeedReproducibleWorksCorrectly(self): + with self.test_session(graph=ops.Graph()) as sess: + num_results = 10 + independent_chain_ndims = 1 + + def log_gamma_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, + array_ops.rank(x)) + return self._log_gamma_log_prob(x, event_dims) + + kwargs = dict( + target_log_prob_fn=log_gamma_log_prob, + current_state=np.random.rand(4, 3, 2), + step_size=0.1, + num_leapfrog_steps=2, + num_burnin_steps=150, + seed=52, + ) + + samples0, kernel_results0 = hmc.sample_chain( + **dict(list(kwargs.items()) + list(dict( + num_results=2 * num_results, + num_steps_between_results=0).items()))) + + samples1, kernel_results1 = hmc.sample_chain( + **dict(list(kwargs.items()) + list(dict( + num_results=num_results, + num_steps_between_results=1).items()))) + + [ + samples0_, + samples1_, + target_log_prob0_, + target_log_prob1_, + ] = sess.run([ + samples0, + samples1, + kernel_results0.current_target_log_prob, + kernel_results1.current_target_log_prob, + ]) + self.assertAllClose(samples0_[::2], samples1_, + atol=1e-5, rtol=1e-5) + self.assertAllClose(target_log_prob0_[::2], target_log_prob1_, + atol=1e-5, rtol=1e-5) + + def _chain_gets_correct_expectations(self, x, independent_chain_ndims, + sess, feed_dict=None): + counter = collections.Counter() def log_gamma_log_prob(x): + counter["target_calls"] += 1 + event_dims = math_ops.range(independent_chain_ndims, + array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) - step_size = array_ops.placeholder(np.float32, [], name='step_size') - hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps') - hmc_n_steps = array_ops.placeholder(np.int32, [], name='hmc_n_steps') + num_results = array_ops.placeholder( + np.int32, [], name="num_results") + step_size = array_ops.placeholder( + np.float32, [], name="step_size") + num_leapfrog_steps = array_ops.placeholder( + np.int32, [], name="num_leapfrog_steps") if feed_dict is None: feed_dict = {} - feed_dict.update({step_size: 0.1, - hmc_lf_steps: 2, - hmc_n_steps: 300}) - - sample_chain, acceptance_prob_chain = hmc.chain([hmc_n_steps], - step_size, - hmc_lf_steps, - x, log_gamma_log_prob, - event_dims) - - acceptance_probs, samples = sess.run([acceptance_prob_chain, sample_chain], - feed_dict) - samples = samples[feed_dict[hmc_n_steps] // 2:] - expected_x_est = samples.mean() - expected_exp_x_est = np.exp(samples).mean() - - logging.vlog(1, 'True E[x, exp(x)]: {}\t{}'.format( - self._expected_x, self._expected_exp_x)) - logging.vlog(1, 'Estimated E[x, exp(x)]: {}\t{}'.format( - expected_x_est, expected_exp_x_est)) - self.assertNear(expected_x_est, self._expected_x, 2e-2) - self.assertNear(expected_exp_x_est, self._expected_exp_x, 2e-2) - self.assertTrue((acceptance_probs > 0.5).all()) - self.assertTrue((acceptance_probs <= 1.0).all()) - - def _chain_gets_correct_expectations_wrapper(self, event_dims): - with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name='x_ph') - - feed_dict = {x_ph: np.zeros([50, 10, 2])} - self._chain_gets_correct_expectations(x_ph, event_dims, sess, - feed_dict) + feed_dict.update({num_results: 150, + step_size: 0.05, + num_leapfrog_steps: 2}) + + samples, kernel_results = hmc.sample_chain( + num_results=num_results, + target_log_prob_fn=log_gamma_log_prob, + current_state=x, + step_size=step_size, + num_leapfrog_steps=num_leapfrog_steps, + num_burnin_steps=150, + seed=42) + + self.assertAllEqual(dict(target_calls=2), counter) + + expected_x = (math_ops.digamma(self._shape_param) + - np.log(self._rate_param)) + + expected_exp_x = self._shape_param / self._rate_param + + acceptance_probs_, samples_, expected_x_ = sess.run( + [kernel_results.acceptance_probs, samples, expected_x], + feed_dict) + + actual_x = samples_.mean() + actual_exp_x = np.exp(samples_).mean() + + logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format( + expected_x_, expected_exp_x)) + logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format( + actual_x, actual_exp_x)) + self.assertNear(actual_x, expected_x_, 2e-2) + self.assertNear(actual_exp_x, expected_exp_x, 2e-2) + self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), + acceptance_probs_ > 0.5) + self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), + acceptance_probs_ <= 1.) + + def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims): + with self.test_session(graph=ops.Graph()) as sess: + x_ph = array_ops.placeholder(np.float32, name="x_ph") + feed_dict = {x_ph: np.random.rand(50, 10, 2)} + self._chain_gets_correct_expectations(x_ph, independent_chain_ndims, + sess, feed_dict) def testHMCChainExpectationsNullShape(self): - self._chain_gets_correct_expectations_wrapper([]) + self._chain_gets_correct_expectations_wrapper(0) def testHMCChainExpectations1(self): - self._chain_gets_correct_expectations_wrapper([1]) + self._chain_gets_correct_expectations_wrapper(1) def testHMCChainExpectations2(self): - self._chain_gets_correct_expectations_wrapper([2]) + self._chain_gets_correct_expectations_wrapper(2) - def testHMCChainExpectations12(self): - self._chain_gets_correct_expectations_wrapper([1, 2]) - - def _kernel_leaves_target_invariant(self, initial_draws, event_dims, + def testKernelResultsUsingTruncatedDistribution(self): + def log_prob(x): + return array_ops.where( + x >= 0., + -x - x**2, # Non-constant gradient. + array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) + # This log_prob has the property that it is likely to attract + # the HMC flow toward, and below, zero...but for x <=0, + # log_prob(x) = -inf, which should result in rejection, as well + # as a non-finite log_prob. Thus, this distribution gives us an opportunity + # to test out the kernel results ability to correctly capture rejections due + # to finite AND non-finite reasons. + # Why use a non-constant gradient? This ensures the leapfrog integrator + # will not be exact. + + num_results = 1000 + # Large step size, will give rejections due to integration error in addition + # to rejection due to going into a region of log_prob = -inf. + step_size = 0.1 + num_leapfrog_steps = 5 + num_chains = 2 + + with self.test_session(graph=ops.Graph()) as sess: + + # Start multiple independent chains. + initial_state = ops.convert_to_tensor([0.1] * num_chains) + + states, kernel_results = hmc.sample_chain( + num_results=num_results, + target_log_prob_fn=log_prob, + current_state=initial_state, + step_size=step_size, + num_leapfrog_steps=num_leapfrog_steps, + seed=42) + + states_, kernel_results_ = sess.run([states, kernel_results]) + pstates_ = kernel_results_.proposed_state + + neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob) + + # First: Test that the mathematical properties of the above log prob + # function in conjunction with HMC show up as expected in kernel_results_. + + # We better have log_prob = -inf some of the time. + self.assertLess(0, neg_inf_mask.sum()) + # We better have some rejections due to something other than -inf. + self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum()) + # We better have been accepted a decent amount, even near the end of the + # chain, or else this HMC run just got stuck at some point. + self.assertLess( + 0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean()) + # We better not have any NaNs in proposed state or log_prob. + # We may have some NaN in grads, which involve multiplication/addition due + # to gradient rules. This is the known "NaN grad issue with tf.where." + self.assertAllEqual(np.zeros_like(states_), + np.isnan(kernel_results_.proposed_target_log_prob)) + self.assertAllEqual(np.zeros_like(states_), + np.isnan(states_)) + # We better not have any +inf in states, grads, or log_prob. + self.assertAllEqual(np.zeros_like(states_), + np.isposinf(kernel_results_.proposed_target_log_prob)) + self.assertAllEqual( + np.zeros_like(states_), + np.isposinf(kernel_results_.proposed_grads_target_log_prob[0])) + self.assertAllEqual(np.zeros_like(states_), + np.isposinf(states_)) + + # Second: Test that kernel_results is congruent with itself and + # acceptance/rejection of states. + + # Proposed state is negative iff proposed target log prob is -inf. + np.testing.assert_array_less(pstates_[neg_inf_mask], 0.) + np.testing.assert_array_less(0., pstates_[~neg_inf_mask]) + + # Acceptance probs are zero whenever proposed state is negative. + self.assertAllEqual( + np.zeros_like(pstates_[neg_inf_mask]), + kernel_results_.acceptance_probs[neg_inf_mask]) + + # The move is accepted ==> state = proposed state. + self.assertAllEqual( + states_[kernel_results_.is_accepted], + pstates_[kernel_results_.is_accepted], + ) + # The move was rejected <==> state[t] == state[t - 1]. + for t in range(1, num_results): + for i in range(num_chains): + if kernel_results_.is_accepted[t, i]: + self.assertNotEqual(states_[t, i], states_[t - 1, i]) + else: + self.assertEqual(states_[t, i], states_[t - 1, i]) + + def _kernel_leaves_target_invariant(self, initial_draws, + independent_chain_ndims, sess, feed_dict=None): def log_gamma_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) def fake_log_prob(x): """Cooled version of the target distribution.""" return 1.1 * log_gamma_log_prob(x) - step_size = array_ops.placeholder(np.float32, [], name='step_size') + step_size = array_ops.placeholder(np.float32, [], name="step_size") if feed_dict is None: feed_dict = {} feed_dict[step_size] = 0.4 - sample, acceptance_probs, _, _ = hmc.kernel(step_size, 5, initial_draws, - log_gamma_log_prob, event_dims) - bad_sample, bad_acceptance_probs, _, _ = hmc.kernel( - step_size, 5, initial_draws, fake_log_prob, event_dims) - (acceptance_probs_val, bad_acceptance_probs_val, initial_draws_val, - updated_draws_val, fake_draws_val) = sess.run([acceptance_probs, - bad_acceptance_probs, - initial_draws, sample, - bad_sample], feed_dict) + sample, kernel_results = hmc.kernel( + target_log_prob_fn=log_gamma_log_prob, + current_state=initial_draws, + step_size=step_size, + num_leapfrog_steps=5, + seed=43) + + bad_sample, bad_kernel_results = hmc.kernel( + target_log_prob_fn=fake_log_prob, + current_state=initial_draws, + step_size=step_size, + num_leapfrog_steps=5, + seed=44) + + [ + acceptance_probs_, + bad_acceptance_probs_, + initial_draws_, + updated_draws_, + fake_draws_, + ] = sess.run([ + kernel_results.acceptance_probs, + bad_kernel_results.acceptance_probs, + initial_draws, + sample, + bad_sample, + ], feed_dict) + # Confirm step size is small enough that we usually accept. - self.assertGreater(acceptance_probs_val.mean(), 0.5) - self.assertGreater(bad_acceptance_probs_val.mean(), 0.5) + self.assertGreater(acceptance_probs_.mean(), 0.5) + self.assertGreater(bad_acceptance_probs_.mean(), 0.5) + # Confirm step size is large enough that we sometimes reject. - self.assertLess(acceptance_probs_val.mean(), 0.99) - self.assertLess(bad_acceptance_probs_val.mean(), 0.99) - _, ks_p_value_true = stats.ks_2samp(initial_draws_val.flatten(), - updated_draws_val.flatten()) - _, ks_p_value_fake = stats.ks_2samp(initial_draws_val.flatten(), - fake_draws_val.flatten()) - logging.vlog(1, 'acceptance rate for true target: {}'.format( - acceptance_probs_val.mean())) - logging.vlog(1, 'acceptance rate for fake target: {}'.format( - bad_acceptance_probs_val.mean())) - logging.vlog(1, 'K-S p-value for true target: {}'.format(ks_p_value_true)) - logging.vlog(1, 'K-S p-value for fake target: {}'.format(ks_p_value_fake)) + self.assertLess(acceptance_probs_.mean(), 0.99) + self.assertLess(bad_acceptance_probs_.mean(), 0.99) + + _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(), + updated_draws_.flatten()) + _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(), + fake_draws_.flatten()) + + logging_ops.vlog(1, "acceptance rate for true target: {}".format( + acceptance_probs_.mean())) + logging_ops.vlog(1, "acceptance rate for fake target: {}".format( + bad_acceptance_probs_.mean())) + logging_ops.vlog(1, "K-S p-value for true target: {}".format( + ks_p_value_true)) + logging_ops.vlog(1, "K-S p-value for fake target: {}".format( + ks_p_value_fake)) # Make sure that the MCMC update hasn't changed the empirical CDF much. self.assertGreater(ks_p_value_true, 1e-3) # Confirm that targeting the wrong distribution does # significantly change the empirical CDF. self.assertLess(ks_p_value_fake, 1e-6) - def _kernel_leaves_target_invariant_wrapper(self, event_dims): + def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims): """Tests that the kernel leaves the target distribution invariant. Draws some independent samples from the target distribution, @@ -267,86 +435,160 @@ class HMCTest(test.TestCase): does change the target distribution. (And that we can detect that.) Args: - event_dims: A tuple of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. + independent_chain_ndims: Python `int` scalar representing the number of + dims associated with independent chains. """ - with self.test_session() as sess: + with self.test_session(graph=ops.Graph()) as sess: initial_draws = np.log(np.random.gamma(self._shape_param, size=[50000, 2, 2])) initial_draws -= np.log(self._rate_param) - x_ph = array_ops.placeholder(np.float32, name='x_ph') + x_ph = array_ops.placeholder(np.float32, name="x_ph") feed_dict = {x_ph: initial_draws} - self._kernel_leaves_target_invariant(x_ph, event_dims, sess, - feed_dict) - - def testKernelLeavesTargetInvariantNullShape(self): - self._kernel_leaves_target_invariant_wrapper([]) + self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims, + sess, feed_dict) def testKernelLeavesTargetInvariant1(self): - self._kernel_leaves_target_invariant_wrapper([1]) + self._kernel_leaves_target_invariant_wrapper(1) def testKernelLeavesTargetInvariant2(self): - self._kernel_leaves_target_invariant_wrapper([2]) + self._kernel_leaves_target_invariant_wrapper(2) + + def testKernelLeavesTargetInvariant3(self): + self._kernel_leaves_target_invariant_wrapper(3) - def testKernelLeavesTargetInvariant12(self): - self._kernel_leaves_target_invariant_wrapper([1, 2]) + def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims, + sess, feed_dict=None): + counter = collections.Counter() - def _ais_gets_correct_log_normalizer(self, init, event_dims, sess, - feed_dict=None): def proposal_log_prob(x): - return math_ops.reduce_sum(-0.5 * x * x - 0.5 * np.log(2*np.pi), - event_dims) + counter["proposal_calls"] += 1 + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) + return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), + axis=event_dims) def target_log_prob(x): + counter["target_calls"] += 1 + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) if feed_dict is None: feed_dict = {} - w, _, _ = hmc.ais_chain(200, 0.5, 2, init, target_log_prob, - proposal_log_prob, event_dims) + num_steps = 200 + + _, ais_weights, _ = hmc.sample_annealed_importance_chain( + proposal_log_prob_fn=proposal_log_prob, + num_steps=num_steps, + target_log_prob_fn=target_log_prob, + step_size=0.5, + current_state=init, + num_leapfrog_steps=2, + seed=45) + + # We have three calls because the calculation of `ais_weights` entails + # another call to the `convex_combined_log_prob_fn`. We could refactor + # things to avoid this, if needed (eg, b/72994218). + self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter) + + event_shape = array_ops.shape(init)[independent_chain_ndims:] + event_size = math_ops.reduce_prod(event_shape) + + log_true_normalizer = ( + -self._shape_param * math_ops.log(self._rate_param) + + math_ops.lgamma(self._shape_param)) + log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype) + + log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights) + - np.log(num_steps)) + + ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer) + ais_weights_size = array_ops.size(ais_weights) + standard_error = math_ops.sqrt( + _reduce_variance(ratio_estimate_true) + / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype)) + + [ + ratio_estimate_true_, + log_true_normalizer_, + log_estimated_normalizer_, + standard_error_, + ais_weights_size_, + event_size_, + ] = sess.run([ + ratio_estimate_true, + log_true_normalizer, + log_estimated_normalizer, + standard_error, + ais_weights_size, + event_size, + ], feed_dict) + + logging_ops.vlog(1, " log_true_normalizer: {}\n" + " log_estimated_normalizer: {}\n" + " ais_weights_size: {}\n" + " event_size: {}\n".format( + log_true_normalizer_, + log_estimated_normalizer_, + ais_weights_size_, + event_size_)) + self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_) + + def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims): + """Tests that AIS yields reasonable estimates of normalizers.""" + with self.test_session(graph=ops.Graph()) as sess: + x_ph = array_ops.placeholder(np.float32, name="x_ph") + initial_draws = np.random.normal(size=[30, 2, 1]) + self._ais_gets_correct_log_normalizer( + x_ph, + independent_chain_ndims, + sess, + feed_dict={x_ph: initial_draws}) - w_val = sess.run(w, feed_dict) - init_shape = sess.run(init, feed_dict).shape - normalizer_multiplier = np.prod([init_shape[i] for i in event_dims]) + def testAIS1(self): + self._ais_gets_correct_log_normalizer_wrapper(1) - true_normalizer = -self._shape_param * np.log(self._rate_param) - true_normalizer += special.gammaln(self._shape_param) - true_normalizer *= normalizer_multiplier + def testAIS2(self): + self._ais_gets_correct_log_normalizer_wrapper(2) - n_weights = np.prod(w_val.shape) - normalized_w = np.exp(w_val - true_normalizer) - standard_error = np.std(normalized_w) / np.sqrt(n_weights) - logging.vlog(1, 'True normalizer {}, estimated {}, n_weights {}'.format( - true_normalizer, np.log(normalized_w.mean()) + true_normalizer, - n_weights)) - self.assertNear(normalized_w.mean(), 1.0, 4.0 * standard_error) + def testAIS3(self): + self._ais_gets_correct_log_normalizer_wrapper(3) - def _ais_gets_correct_log_normalizer_wrapper(self, event_dims): - """Tests that AIS yields reasonable estimates of normalizers.""" - with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name='x_ph') + def testSampleAIChainSeedReproducibleWorksCorrectly(self): + with self.test_session(graph=ops.Graph()) as sess: + independent_chain_ndims = 1 + x = np.random.rand(4, 3, 2) - initial_draws = np.random.normal(size=[30, 2, 1]) - feed_dict = {x_ph: initial_draws} + def proposal_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) + return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), + axis=event_dims) - self._ais_gets_correct_log_normalizer(x_ph, event_dims, sess, - feed_dict) + def target_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) + return self._log_gamma_log_prob(x, event_dims) - def testAISNullShape(self): - self._ais_gets_correct_log_normalizer_wrapper([]) + ais_kwargs = dict( + proposal_log_prob_fn=proposal_log_prob, + num_steps=200, + target_log_prob_fn=target_log_prob, + step_size=0.5, + current_state=x, + num_leapfrog_steps=2, + seed=53) - def testAIS1(self): - self._ais_gets_correct_log_normalizer_wrapper([1]) + _, ais_weights0, _ = hmc.sample_annealed_importance_chain( + **ais_kwargs) - def testAIS2(self): - self._ais_gets_correct_log_normalizer_wrapper([2]) + _, ais_weights1, _ = hmc.sample_annealed_importance_chain( + **ais_kwargs) + + [ais_weights0_, ais_weights1_] = sess.run([ + ais_weights0, ais_weights1]) - def testAIS12(self): - self._ais_gets_correct_log_normalizer_wrapper([1, 2]) + self.assertAllClose(ais_weights0_, ais_weights1_, + atol=1e-5, rtol=1e-5) def testNanRejection(self): """Tests that an update that yields NaN potentials gets rejected. @@ -359,86 +601,263 @@ class HMCTest(test.TestCase): """ def _unbounded_exponential_log_prob(x): """An exponential distribution with log-likelihood NaN for x < 0.""" - per_element_potentials = array_ops.where(x < 0, - np.nan * array_ops.ones_like(x), - -x) + per_element_potentials = array_ops.where( + x < 0., + array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)), + -x) return math_ops.reduce_sum(per_element_potentials) - with self.test_session() as sess: + with self.test_session(graph=ops.Graph()) as sess: initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, acceptance_probs, _, _ = hmc.kernel( - 2., 5, initial_x, _unbounded_exponential_log_prob, [0]) - initial_x_val, updated_x_val, acceptance_probs_val = sess.run( - [initial_x, updated_x, acceptance_probs]) - - logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) - logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) - logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) - - self.assertAllEqual(initial_x_val, updated_x_val) - self.assertEqual(acceptance_probs_val, 0.) + updated_x, kernel_results = hmc.kernel( + target_log_prob_fn=_unbounded_exponential_log_prob, + current_state=initial_x, + step_size=2., + num_leapfrog_steps=5, + seed=46) + initial_x_, updated_x_, acceptance_probs_ = sess.run( + [initial_x, updated_x, kernel_results.acceptance_probs]) + + logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) + logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) + logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) + + self.assertAllEqual(initial_x_, updated_x_) + self.assertEqual(acceptance_probs_, 0.) def testNanFromGradsDontPropagate(self): """Test that update with NaN gradients does not cause NaN in results.""" def _nan_log_prob_with_nan_gradient(x): return np.nan * math_ops.reduce_sum(x) - with self.test_session() as sess: + with self.test_session(graph=ops.Graph()) as sess: initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel( - 2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0]) - initial_x_val, updated_x_val, acceptance_probs_val = sess.run( - [initial_x, updated_x, acceptance_probs]) - - logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) - logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) - logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) - - self.assertAllEqual(initial_x_val, updated_x_val) - self.assertEqual(acceptance_probs_val, 0.) + updated_x, kernel_results = hmc.kernel( + target_log_prob_fn=_nan_log_prob_with_nan_gradient, + current_state=initial_x, + step_size=2., + num_leapfrog_steps=5, + seed=47) + initial_x_, updated_x_, acceptance_probs_ = sess.run( + [initial_x, updated_x, kernel_results.acceptance_probs]) + + logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) + logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) + logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) + + self.assertAllEqual(initial_x_, updated_x_) + self.assertEqual(acceptance_probs_, 0.) self.assertAllFinite( - gradients_impl.gradients(updated_x, initial_x)[0].eval()) - self.assertTrue( - gradients_impl.gradients(new_grad, initial_x)[0] is None) + gradients_ops.gradients(updated_x, initial_x)[0].eval()) + self.assertAllEqual([True], [g is None for g in gradients_ops.gradients( + kernel_results.proposed_grads_target_log_prob, initial_x)]) + self.assertAllEqual([False], [g is None for g in gradients_ops.gradients( + kernel_results.proposed_grads_target_log_prob, + kernel_results.proposed_state)]) # Gradients of the acceptance probs and new log prob are not finite. - _ = new_log_prob # Prevent unused arg error. # self.assertAllFinite( - # gradients_impl.gradients(acceptance_probs, initial_x)[0].eval()) + # gradients_ops.gradients(acceptance_probs, initial_x)[0].eval()) # self.assertAllFinite( - # gradients_impl.gradients(new_log_prob, initial_x)[0].eval()) + # gradients_ops.gradients(new_log_prob, initial_x)[0].eval()) + + def _testChainWorksDtype(self, dtype): + with self.test_session(graph=ops.Graph()) as sess: + states, kernel_results = hmc.sample_chain( + num_results=10, + target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1), + current_state=np.zeros(5).astype(dtype), + step_size=0.01, + num_leapfrog_steps=10, + seed=48) + states_, acceptance_probs_ = sess.run( + [states, kernel_results.acceptance_probs]) + self.assertEqual(dtype, states_.dtype) + self.assertEqual(dtype, acceptance_probs_.dtype) def testChainWorksIn64Bit(self): - def log_prob(x): - return - math_ops.reduce_sum(x * x, axis=-1) - states, acceptance_probs = hmc.chain( - n_iterations=10, - step_size=np.float64(0.01), - n_leapfrog_steps=10, - initial_x=np.zeros(5).astype(np.float64), - target_log_prob_fn=log_prob, - event_dims=[-1]) - with self.test_session() as sess: - states_, acceptance_probs_ = sess.run([states, acceptance_probs]) - self.assertEqual(np.float64, states_.dtype) - self.assertEqual(np.float64, acceptance_probs_.dtype) + self._testChainWorksDtype(np.float64) def testChainWorksIn16Bit(self): - def log_prob(x): - return - math_ops.reduce_sum(x * x, axis=-1) - states, acceptance_probs = hmc.chain( - n_iterations=10, - step_size=np.float16(0.01), - n_leapfrog_steps=10, - initial_x=np.zeros(5).astype(np.float16), - target_log_prob_fn=log_prob, - event_dims=[-1]) - with self.test_session() as sess: - states_, acceptance_probs_ = sess.run([states, acceptance_probs]) - self.assertEqual(np.float16, states_.dtype) - self.assertEqual(np.float16, acceptance_probs_.dtype) - - -if __name__ == '__main__': + self._testChainWorksDtype(np.float16) + + def testChainWorksCorrelatedMultivariate(self): + dtype = np.float32 + true_mean = dtype([0, 0]) + true_cov = dtype([[1, 0.5], + [0.5, 1]]) + num_results = 2000 + counter = collections.Counter() + with self.test_session(graph=ops.Graph()) as sess: + def target_log_prob(x, y): + counter["target_calls"] += 1 + # Corresponds to unnormalized MVN. + # z = matmul(inv(chol(true_cov)), [x, y] - true_mean) + z = array_ops.stack([x, y], axis=-1) - true_mean + z = array_ops.squeeze( + gen_linalg_ops.matrix_triangular_solve( + np.linalg.cholesky(true_cov), + z[..., array_ops.newaxis]), + axis=-1) + return -0.5 * math_ops.reduce_sum(z**2., axis=-1) + states, _ = hmc.sample_chain( + num_results=num_results, + target_log_prob_fn=target_log_prob, + current_state=[dtype(-2), dtype(2)], + step_size=[0.5, 0.5], + num_leapfrog_steps=2, + num_burnin_steps=200, + num_steps_between_results=1, + seed=54) + self.assertAllEqual(dict(target_calls=2), counter) + states = array_ops.stack(states, axis=-1) + self.assertEqual(num_results, states.shape[0].value) + sample_mean = math_ops.reduce_mean(states, axis=0) + x = states - sample_mean + sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results) + [sample_mean_, sample_cov_] = sess.run([ + sample_mean, sample_cov]) + self.assertAllClose(true_mean, sample_mean_, + atol=0.05, rtol=0.) + self.assertAllClose(true_cov, sample_cov_, + atol=0., rtol=0.1) + + +class _EnergyComputationTest(object): + + def testHandlesNanFromPotential(self): + with self.test_session(graph=ops.Graph()) as sess: + x = [1, np.inf, -np.inf, np.nan] + target_log_prob, proposed_target_log_prob = [ + self.dtype(x.flatten()) for x in np.meshgrid(x, x)] + num_chains = len(target_log_prob) + dummy_momentums = [-1, 1] + momentums = [self.dtype([dummy_momentums] * num_chains)] + proposed_momentums = [self.dtype([dummy_momentums] * num_chains)] + + target_log_prob = ops.convert_to_tensor(target_log_prob) + momentums = [ops.convert_to_tensor(momentums[0])] + proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) + proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] + + energy = _compute_energy_change( + target_log_prob, + momentums, + proposed_target_log_prob, + proposed_momentums, + independent_chain_ndims=1) + grads = gradients_ops.gradients(energy, momentums) + + [actual_energy, grads_] = sess.run([energy, grads]) + + # Ensure energy is `inf` (note: that's positive inf) in weird cases and + # finite otherwise. + expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) + self.assertAllEqual(expected_energy, actual_energy) + + # Ensure gradient is finite. + self.assertAllEqual(np.ones_like(grads_).astype(np.bool), + np.isfinite(grads_)) + + def testHandlesNanFromKinetic(self): + with self.test_session(graph=ops.Graph()) as sess: + x = [1, np.inf, -np.inf, np.nan] + momentums, proposed_momentums = [ + [np.reshape(self.dtype(x), [-1, 1])] + for x in np.meshgrid(x, x)] + num_chains = len(momentums[0]) + target_log_prob = np.ones(num_chains, self.dtype) + proposed_target_log_prob = np.ones(num_chains, self.dtype) + + target_log_prob = ops.convert_to_tensor(target_log_prob) + momentums = [ops.convert_to_tensor(momentums[0])] + proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) + proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] + + energy = _compute_energy_change( + target_log_prob, + momentums, + proposed_target_log_prob, + proposed_momentums, + independent_chain_ndims=1) + grads = gradients_ops.gradients(energy, momentums) + + [actual_energy, grads_] = sess.run([energy, grads]) + + # Ensure energy is `inf` (note: that's positive inf) in weird cases and + # finite otherwise. + expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) + self.assertAllEqual(expected_energy, actual_energy) + + # Ensure gradient is finite. + g = grads_[0].reshape([len(x), len(x)])[:, 0] + self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g)) + + # The remaining gradients are nan because the momentum was itself nan or + # inf. + g = grads_[0].reshape([len(x), len(x)])[:, 1:] + self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g)) + + +class EnergyComputationTest16(test.TestCase, _EnergyComputationTest): + dtype = np.float16 + + +class EnergyComputationTest32(test.TestCase, _EnergyComputationTest): + dtype = np.float32 + + +class EnergyComputationTest64(test.TestCase, _EnergyComputationTest): + dtype = np.float64 + + +class _HMCHandlesLists(object): + + def testStateParts(self): + with self.test_session(graph=ops.Graph()) as sess: + dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1)) + dist_y = independent_lib.Independent( + gamma_lib.Gamma(concentration=self.dtype([1, 2]), + rate=self.dtype([0.5, 0.75])), + reinterpreted_batch_ndims=1) + def target_log_prob(x, y): + return dist_x.log_prob(x) + dist_y.log_prob(y) + x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)] + samples, _ = hmc.sample_chain( + num_results=int(2e3), + target_log_prob_fn=target_log_prob, + current_state=x0, + step_size=0.85, + num_leapfrog_steps=3, + num_burnin_steps=int(250), + seed=49) + actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples] + actual_vars = [_reduce_variance(s, axis=0) for s in samples] + expected_means = [dist_x.mean(), dist_y.mean()] + expected_vars = [dist_x.variance(), dist_y.variance()] + [ + actual_means_, + actual_vars_, + expected_means_, + expected_vars_, + ] = sess.run([ + actual_means, + actual_vars, + expected_means, + expected_vars, + ]) + self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16) + self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25) + + +class HMCHandlesLists32(_HMCHandlesLists, test.TestCase): + dtype = np.float32 + + +class HMCHandlesLists64(_HMCHandlesLists, test.TestCase): + dtype = np.float64 + + +if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py new file mode 100644 index 0000000000000000000000000000000000000000..52e36e135d95c1ec919c710f35d59073c2134d05 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py @@ -0,0 +1,445 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MCMC diagnostic utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import spectral_ops_test_util +from tensorflow.python.platform import test + +rng = np.random.RandomState(42) + + +class _EffectiveSampleSizeTest(object): + + @property + def use_static_shape(self): + raise NotImplementedError( + "Subclass failed to implement `use_static_shape`.") + + def _check_versus_expected_effective_sample_size(self, + x_, + expected_ess, + sess, + atol=1e-2, + rtol=1e-2, + filter_threshold=None, + filter_beyond_lag=None): + x = array_ops.placeholder_with_default( + input=x_, shape=x_.shape if self.use_static_shape else None) + ess = mcmc_diagnostics.effective_sample_size( + x, + filter_threshold=filter_threshold, + filter_beyond_lag=filter_beyond_lag) + if self.use_static_shape: + self.assertAllEqual(x.shape[1:], ess.shape) + + ess_ = sess.run(ess) + + self.assertAllClose( + np.ones_like(ess_) * expected_ess, ess_, atol=atol, rtol=rtol) + + def testIidRank1NormalHasFullEssMaxLags10(self): + # With a length 5000 iid normal sequence, and filter_beyond_lag = 10, we + # should have a good estimate of ESS, and it should be close to the full + # sequence length of 5000. + # The choice of filter_beyond_lag = 10 is a short cutoff, reasonable only + # since we know the correlation length should be zero right away. + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=rng.randn(5000).astype(np.float32), + expected_ess=5000, + sess=sess, + filter_beyond_lag=10, + filter_threshold=None, + rtol=0.3) + + def testIidRank2NormalHasFullEssMaxLags10(self): + # See similar test for Rank1Normal for reasoning. + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=rng.randn(5000, 2).astype(np.float32), + expected_ess=5000, + sess=sess, + filter_beyond_lag=10, + filter_threshold=None, + rtol=0.3) + + def testIidRank1NormalHasFullEssMaxLagThresholdZero(self): + # With a length 5000 iid normal sequence, and filter_threshold = 0, + # we should have a super-duper estimate of ESS, and it should be very close + # to the full sequence length of 5000. + # The choice of filter_beyond_lag = 0 means we cutoff as soon as the + # auto-corris below zero. This should happen very quickly, due to the fact + # that the theoretical auto-corr is [1, 0, 0,...] + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=rng.randn(5000).astype(np.float32), + expected_ess=5000, + sess=sess, + filter_beyond_lag=None, + filter_threshold=0., + rtol=0.1) + + def testIidRank2NormalHasFullEssMaxLagThresholdZero(self): + # See similar test for Rank1Normal for reasoning. + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=rng.randn(5000, 2).astype(np.float32), + expected_ess=5000, + sess=sess, + filter_beyond_lag=None, + filter_threshold=0., + rtol=0.1) + + def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLags50(self): + # Create x_, such that + # x_[i] = iid_x_[0], i = 0,...,9 + # x_[i] = iid_x_[1], i = 10,..., 19, + # and so on. + iid_x_ = rng.randn(5000, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=x_, + expected_ess=50000 // 10, + sess=sess, + filter_beyond_lag=50, + filter_threshold=None, + rtol=0.2) + + def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLagsThresholdZero( + self): + # Create x_, such that + # x_[i] = iid_x_[0], i = 0,...,9 + # x_[i] = iid_x_[1], i = 10,..., 19, + # and so on. + iid_x_ = rng.randn(5000, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=x_, + expected_ess=50000 // 10, + sess=sess, + filter_beyond_lag=None, + filter_threshold=0., + rtol=0.1) + + def testListArgs(self): + # x_ has correlation length 10 ==> ESS = N / 10 + # y_ has correlation length 1 ==> ESS = N + iid_x_ = rng.randn(5000, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) + y_ = rng.randn(50000).astype(np.float32) + states = [x_, x_, y_, y_] + filter_threshold = [0., None, 0., None] + filter_beyond_lag = [None, 5, None, 5] + + # See other tests for reasoning on tolerance. + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + ess = mcmc_diagnostics.effective_sample_size( + states, + filter_threshold=filter_threshold, + filter_beyond_lag=filter_beyond_lag) + ess_ = sess.run(ess) + self.assertAllEqual(4, len(ess_)) + + self.assertAllClose(50000 // 10, ess_[0], rtol=0.3) + self.assertAllClose(50000 // 10, ess_[1], rtol=0.3) + self.assertAllClose(50000, ess_[2], rtol=0.1) + self.assertAllClose(50000, ess_[3], rtol=0.1) + + def testMaxLagsThresholdLessThanNeg1SameAsNone(self): + # Setting both means we filter out items R_k from the auto-correlation + # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. + + # x_ has correlation length 10. + iid_x_ = rng.randn(500, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + x = array_ops.placeholder_with_default( + input=x_, shape=x_.shape if self.use_static_shape else None) + + ess_none_none = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=None, filter_beyond_lag=None) + ess_none_200 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=None, filter_beyond_lag=200) + ess_neg2_200 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=-2., filter_beyond_lag=200) + ess_neg2_none = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=-2., filter_beyond_lag=None) + ess_none_none_, ess_none_200_, ess_neg2_200_, ess_neg2_none_ = sess.run( + [ess_none_none, ess_none_200, ess_neg2_200, ess_neg2_none]) + + # filter_threshold=-2 <==> filter_threshold=None. + self.assertAllClose(ess_none_none_, ess_neg2_none_) + self.assertAllClose(ess_none_200_, ess_neg2_200_) + + def testMaxLagsArgsAddInAnOrManner(self): + # Setting both means we filter out items R_k from the auto-correlation + # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. + + # x_ has correlation length 10. + iid_x_ = rng.randn(500, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + x = array_ops.placeholder_with_default( + input=x_, shape=x_.shape if self.use_static_shape else None) + + ess_1_9 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=1., filter_beyond_lag=9) + ess_1_none = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=1., filter_beyond_lag=None) + ess_none_9 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=1., filter_beyond_lag=9) + ess_1_9_, ess_1_none_, ess_none_9_ = sess.run( + [ess_1_9, ess_1_none, ess_none_9]) + + # Since R_k = 1 for k < 10, and R_k < 1 for k >= 10, + # filter_threshold = 1 <==> filter_beyond_lag = 9. + self.assertAllClose(ess_1_9_, ess_1_none_) + self.assertAllClose(ess_1_9_, ess_none_9_) + + +class EffectiveSampleSizeStaticTest(test.TestCase, _EffectiveSampleSizeTest): + + @property + def use_static_shape(self): + return True + + +class EffectiveSampleSizeDynamicTest(test.TestCase, _EffectiveSampleSizeTest): + + @property + def use_static_shape(self): + return False + + +class _PotentialScaleReductionTest(object): + + @property + def use_static_shape(self): + raise NotImplementedError( + "Subclass failed to impliment `use_static_shape`.") + + def testListOfStatesWhereFirstPassesSecondFails(self): + """Simple test showing API with two states. Read first!.""" + n_samples = 1000 + + # state_0 is two scalar chains taken from iid Normal(0, 1). Will pass. + state_0 = rng.randn(n_samples, 2) + + # state_1 is three 4-variate chains taken from Normal(0, 1) that have been + # shifted. Since every chain is shifted, they are not the same, and the + # test should fail. + offset = np.array([1., -1., 2.]).reshape(3, 1) + state_1 = rng.randn(n_samples, 3, 4) + offset + + rhat = mcmc_diagnostics.potential_scale_reduction( + chains_states=[state_0, state_1], independent_chain_ndims=1) + + self.assertIsInstance(rhat, list) + with self.test_session() as sess: + rhat_0_, rhat_1_ = sess.run(rhat) + + # r_hat_0 should be close to 1, meaning test is passed. + self.assertAllEqual((), rhat_0_.shape) + self.assertAllClose(1., rhat_0_, rtol=0.02) + + # r_hat_1 should be greater than 1.2, meaning test has failed. + self.assertAllEqual((4,), rhat_1_.shape) + self.assertAllEqual(np.ones_like(rhat_1_).astype(bool), rhat_1_ > 1.2) + + def check_results(self, state_, independent_chain_shape, should_pass): + sample_ndims = 1 + independent_chain_ndims = len(independent_chain_shape) + with self.test_session(): + state = array_ops.placeholder_with_default( + input=state_, shape=state_.shape if self.use_static_shape else None) + + rhat = mcmc_diagnostics.potential_scale_reduction( + state, independent_chain_ndims=independent_chain_ndims) + + if self.use_static_shape: + self.assertAllEqual( + state_.shape[sample_ndims + independent_chain_ndims:], rhat.shape) + + rhat_ = rhat.eval() + if should_pass: + self.assertAllClose(np.ones_like(rhat_), rhat_, atol=0, rtol=0.02) + else: + self.assertAllEqual(np.ones_like(rhat_).astype(bool), rhat_ > 1.2) + + def iid_normal_chains_should_pass_wrapper(self, + sample_shape, + independent_chain_shape, + other_shape, + dtype=np.float32): + """Check results with iid normal chains.""" + + state_shape = sample_shape + independent_chain_shape + other_shape + state_ = rng.randn(*state_shape).astype(dtype) + + # The "other" dimensions do not have to be identical, just independent, so + # force them to not be identical. + if other_shape: + state_ *= rng.rand(*other_shape).astype(dtype) + + self.check_results(state_, independent_chain_shape, should_pass=True) + + def testPassingIIDNdimsAreIndependentOneOtherZero(self): + self.iid_normal_chains_should_pass_wrapper( + sample_shape=[10000], independent_chain_shape=[4], other_shape=[]) + + def testPassingIIDNdimsAreIndependentOneOtherOne(self): + self.iid_normal_chains_should_pass_wrapper( + sample_shape=[10000], independent_chain_shape=[3], other_shape=[7]) + + def testPassingIIDNdimsAreIndependentOneOtherTwo(self): + self.iid_normal_chains_should_pass_wrapper( + sample_shape=[10000], independent_chain_shape=[2], other_shape=[5, 7]) + + def testPassingIIDNdimsAreIndependentTwoOtherTwo64Bit(self): + self.iid_normal_chains_should_pass_wrapper( + sample_shape=[10000], + independent_chain_shape=[2, 3], + other_shape=[5, 7], + dtype=np.float64) + + def offset_normal_chains_should_fail_wrapper( + self, sample_shape, independent_chain_shape, other_shape): + """Check results with normal chains that are offset from each other.""" + + state_shape = sample_shape + independent_chain_shape + other_shape + state_ = rng.randn(*state_shape) + + # Add a significant offset to the different (formerly iid) chains. + offset = np.linspace( + 0, 2, num=np.prod(independent_chain_shape)).reshape([1] * len( + sample_shape) + independent_chain_shape + [1] * len(other_shape)) + state_ += offset + + self.check_results(state_, independent_chain_shape, should_pass=False) + + def testFailingOffsetNdimsAreSampleOneIndependentOneOtherOne(self): + self.offset_normal_chains_should_fail_wrapper( + sample_shape=[10000], independent_chain_shape=[2], other_shape=[5]) + + +class PotentialScaleReductionStaticTest(test.TestCase, + _PotentialScaleReductionTest): + + @property + def use_static_shape(self): + return True + + def testIndependentNdimsLessThanOneRaises(self): + with self.assertRaisesRegexp(ValueError, "independent_chain_ndims"): + mcmc_diagnostics.potential_scale_reduction( + rng.rand(2, 3, 4), independent_chain_ndims=0) + + +class PotentialScaleReductionDynamicTest(test.TestCase, + _PotentialScaleReductionTest): + + @property + def use_static_shape(self): + return False + + +class _ReduceVarianceTest(object): + + @property + def use_static_shape(self): + raise NotImplementedError( + "Subclass failed to impliment `use_static_shape`.") + + def check_versus_numpy(self, x_, axis, biased, keepdims): + with self.test_session(): + x_ = np.asarray(x_) + x = array_ops.placeholder_with_default( + input=x_, shape=x_.shape if self.use_static_shape else None) + var = mcmc_diagnostics._reduce_variance( + x, axis=axis, biased=biased, keepdims=keepdims) + np_var = np.var(x_, axis=axis, ddof=0 if biased else 1, keepdims=keepdims) + + if self.use_static_shape: + self.assertAllEqual(np_var.shape, var.shape) + + var_ = var.eval() + # We will mask below, which changes shape, so check shape explicitly here. + self.assertAllEqual(np_var.shape, var_.shape) + + # We get NaN when we divide by zero due to the size being the same as ddof + nan_mask = np.isnan(np_var) + if nan_mask.any(): + self.assertTrue(np.isnan(var_[nan_mask]).all()) + self.assertAllClose(np_var[~nan_mask], var_[~nan_mask], atol=0, rtol=0.02) + + def testScalarBiasedTrue(self): + self.check_versus_numpy(x_=-1.234, axis=None, biased=True, keepdims=False) + + def testScalarBiasedFalse(self): + # This should result in NaN. + self.check_versus_numpy(x_=-1.234, axis=None, biased=False, keepdims=False) + + def testShape2x3x4AxisNoneBiasedFalseKeepdimsFalse(self): + self.check_versus_numpy( + x_=rng.randn(2, 3, 4), axis=None, biased=True, keepdims=False) + + def testShape2x3x4Axis1BiasedFalseKeepdimsTrue(self): + self.check_versus_numpy( + x_=rng.randn(2, 3, 4), axis=1, biased=True, keepdims=True) + + def testShape2x3x4x5Axis13BiasedFalseKeepdimsTrue(self): + self.check_versus_numpy( + x_=rng.randn(2, 3, 4, 5), axis=1, biased=True, keepdims=True) + + def testShape2x3x4x5Axis13BiasedFalseKeepdimsFalse(self): + self.check_versus_numpy( + x_=rng.randn(2, 3, 4, 5), axis=1, biased=False, keepdims=False) + + +class ReduceVarianceTestStaticShape(test.TestCase, _ReduceVarianceTest): + + @property + def use_static_shape(self): + return True + + +class ReduceVarianceTestDynamicShape(test.TestCase, _ReduceVarianceTest): + + @property + def use_static_shape(self): + return False + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f978cf86417dc5ff5412a3eee584330a266e0964 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py @@ -0,0 +1,135 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utility functions related to managing `tf.Variable`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import variable_utils + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope as varscope_ops +from tensorflow.python.ops import variables as variables_ops +from tensorflow.python.platform import test + + +def test_fn(x): + x = ops.convert_to_tensor(x, name="x") + dtype = x.dtype.as_numpy_dtype + s = x.shape.as_list() + z = varscope_ops.get_variable( + name="z", + dtype=dtype, + initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)) + y = varscope_ops.get_variable( + name="y", + dtype=dtype, + initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)**2) + return x + y + z + + +class _WrapCallableTest(object): + + def testDefaultArgsWorkCorrectly(self): + with self.test_session(): + x = constant_op.constant(self.dtype([0.1, 0.2])) + wrapped_fn, vars_args = variable_utils.externalize_variables_as_args( + test_fn, [x]) + + varscope_ops.get_variable_scope().reuse_variables() + + result = wrapped_fn(self.dtype(2), [3, 4, 5], 0.5) + + y_actual = varscope_ops.get_variable("y", dtype=self.dtype) + z_actual = varscope_ops.get_variable("z", dtype=self.dtype) + + variables_ops.global_variables_initializer().run() + result_ = result.eval() + + self.assertEqual(self.dtype, result_.dtype) + self.assertAllEqual([5.5, 6.5, 7.5], result_) + self.assertAllEqual([y_actual, z_actual], vars_args) + + def testNonDefaultArgsWorkCorrectly(self): + with self.test_session(): + x = constant_op.constant(self.dtype([0.1, 0.2])) + + _ = test_fn(self.dtype([0., 0.])) # Needed to create vars. + varscope_ops.get_variable_scope().reuse_variables() + + y_actual = varscope_ops.get_variable("y", dtype=self.dtype) + + wrapped_fn, vars_args = variable_utils.externalize_variables_as_args( + test_fn, [x], possible_ancestor_vars=[y_actual]) + + result = wrapped_fn(self.dtype([2, 3]), 0.5) # x, y + + variables_ops.global_variables_initializer().run() + result_ = result.eval() + + self.assertEqual(self.dtype, result_.dtype) + self.assertAllEqual([2.5, 4.5], result_) + self.assertAllEqual([y_actual], vars_args) + + def testWarnings(self): + with self.test_session(): + x = constant_op.constant(self.dtype([0.1, 0.2])) + wrapped_fn, _ = variable_utils.externalize_variables_as_args( + test_fn, [x], possible_ancestor_vars=[]) + varscope_ops.get_variable_scope().reuse_variables() + with warnings.catch_warnings(record=True) as w: + wrapped_fn(self.dtype(2)) + w = sorted(w, key=lambda w: str(w.message)) + self.assertEqual(2, len(w)) + self.assertRegexpMatches( + str(w[0].message), + r"Variable .* 'y:0' .* not found in bypass dict.") + self.assertRegexpMatches( + str(w[1].message), + r"Variable .* 'z:0' .* not found in bypass dict.") + + def testExceptions(self): + with self.test_session(): + x = constant_op.constant(self.dtype([0.1, 0.2])) + wrapped_fn, _ = variable_utils.externalize_variables_as_args( + test_fn, + [x], + possible_ancestor_vars=[], + assert_variable_override=True) + varscope_ops.get_variable_scope().reuse_variables() + with self.assertRaisesRegexp(ValueError, r"not found"): + wrapped_fn(self.dtype(2)) + + +class WrapCallableTest16(test.TestCase, _WrapCallableTest): + dtype = np.float16 + + +class WrapCallableTest32(test.TestCase, _WrapCallableTest): + dtype = np.float32 + + +class WrapCallableTest64(test.TestCase, _WrapCallableTest): + dtype = np.float64 + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py index fdc12e3b21466a2c552124d6c6a339a0c25f9f46..d44fe6529a7ff0da0c6747e193fdb98a272a8da3 100644 --- a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py @@ -31,8 +31,7 @@ __all__ = [ ] -def custom_gradient(fx, gx, x, axis=(), - fx_gx_manually_stopped=False, +def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False, name=None): """Enables specifying a custom gradient. @@ -43,7 +42,8 @@ def custom_gradient(fx, gx, x, axis=(), h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x)) ``` - is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] = stop_gradient(g(x)).` + is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] = + stop_gradient(g(x)).` In addition to scalar-domain/scalar-range functions, this function also supports tensor-domain/scalar-range functions. However, in the latter case it diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py index 977d42fc16bb91777a76c45ac24f3c5dc587f5fe..7fd5652c5c3e085b23c05baef6e3a42b7a42e08f 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. -""" +"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.""" from __future__ import absolute_import from __future__ import division @@ -24,11 +23,9 @@ from tensorflow.contrib.bayesflow.python.ops.hmc_impl import * # pylint: disabl from tensorflow.python.util import all_util _allowed_symbols = [ - 'chain', - 'kernel', - 'leapfrog_integrator', - 'leapfrog_step', - 'ais_chain' + "sample_chain", + "sample_annealed_importance_chain", + "kernel", ] all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py index 5685a942e98800a39ec718adc67bcfd43aeafd52..f724910c59315867a42a56fab3deb36f5d3adb7a 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -14,17 +14,16 @@ # ============================================================================== """Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. -@@chain -@@update -@@leapfrog_integrator -@@leapfrog_step -@@ais_chain +@@sample_chain +@@sample_annealed_importance_chain +@@kernel """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import numpy as np from tensorflow.python.framework import dtypes @@ -32,168 +31,326 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import gradients_impl as gradients_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.ops.distributions import util as distributions_util __all__ = [ - 'chain', - 'kernel', - 'leapfrog_integrator', - 'leapfrog_step', - 'ais_chain' + "sample_chain", + "sample_annealed_importance_chain", + "kernel", ] -def _make_potential_and_grad(target_log_prob_fn): - def potential_and_grad(x): - log_prob_result = -target_log_prob_fn(x) - grad_result = gradients_impl.gradients(math_ops.reduce_sum(log_prob_result), - x)[0] - return log_prob_result, grad_result - return potential_and_grad - - -def chain(n_iterations, step_size, n_leapfrog_steps, initial_x, - target_log_prob_fn, event_dims=(), name=None): +KernelResults = collections.namedtuple( + "KernelResults", + [ + "acceptance_probs", + "current_grads_target_log_prob", # "Current result" means "accepted". + "current_target_log_prob", # "Current result" means "accepted". + "energy_change", + "is_accepted", + "proposed_grads_target_log_prob", + "proposed_state", + "proposed_target_log_prob", + "random_positive", + ]) + + +def _make_dummy_kernel_results( + dummy_state, + dummy_target_log_prob, + dummy_grads_target_log_prob): + return KernelResults( + acceptance_probs=dummy_target_log_prob, + current_grads_target_log_prob=dummy_grads_target_log_prob, + current_target_log_prob=dummy_target_log_prob, + energy_change=dummy_target_log_prob, + is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool), + proposed_grads_target_log_prob=dummy_grads_target_log_prob, + proposed_state=dummy_state, + proposed_target_log_prob=dummy_target_log_prob, + random_positive=dummy_target_log_prob, + ) + + +def sample_chain( + num_results, + target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + num_burnin_steps=0, + num_steps_between_results=0, + seed=None, + current_target_log_prob=None, + current_grads_target_log_prob=None, + name=None): """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains. - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) - algorithm that takes a series of gradient-informed steps to produce - a Metropolis proposal. This function samples from an HMC Markov - chain whose initial state is `initial_x` and whose stationary - distribution has log-density `target_log_prob_fn()`. - - This function can update multiple chains in parallel. It assumes - that all dimensions of `initial_x` not specified in `event_dims` are - independent, and should therefore be updated independently. The - output of `target_log_prob_fn()` should sum log-probabilities across - all event dimensions. Slices along dimensions not in `event_dims` - may have different target distributions; this is up to + Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm + that takes a series of gradient-informed steps to produce a Metropolis + proposal. This function samples from an HMC Markov chain at `current_state` + and whose stationary distribution has log-unnormalized-density `target_log_prob_fn()`. - This function basically just wraps `hmc.kernel()` in a tf.scan() loop. + This function samples from multiple chains in parallel. It assumes that the + the leftmost dimensions of (each) `current_state` (part) index an independent + chain. The function `target_log_prob_fn()` sums log-probabilities across + event dimensions (i.e., current state (part) rightmost dimensions). Each + element of the output of `target_log_prob_fn()` represents the (possibly + unnormalized) log-probability of the joint distribution over (all) the current + state (parts). - Args: - n_iterations: Integer number of Markov chain updates to run. - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `initial_x`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. - When possible, it's often helpful to match per-variable step - sizes to the standard deviations of the target distribution in - each variable. - n_leapfrog_steps: Integer number of steps to run the leapfrog - integrator for. Total progress per HMC step is roughly - proportional to step_size * n_leapfrog_steps. - initial_x: Tensor of initial state(s) of the Markov chain(s). - target_log_prob_fn: Python callable which takes an argument like `initial_x` - and returns its (possibly unnormalized) log-density under the target - distribution. - event_dims: List of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. - name: Python `str` name prefixed to Ops created by this function. + The `current_state` can be represented as a single `Tensor` or a `list` of + `Tensors` which collectively represent the current state. When specifying a + `list`, one must also specify a list of `step_size`s. - Returns: - acceptance_probs: Tensor with the acceptance probabilities for each - iteration. Has shape matching `target_log_prob_fn(initial_x)`. - chain_states: Tensor with the state of the Markov chain at each iteration. - Has shape `[n_iterations, initial_x.shape[0],...,initial_x.shape[-1]`. + Note: `target_log_prob_fn` is called exactly twice. + + Only one out of every `num_steps_between_samples + 1` steps is included in the + returned results. This "thinning" comes at a cost of reduced statistical + power, while reducing memory requirements and autocorrelation. For more + discussion see [1]. + + [1]: "Statistically efficient thinning of a Markov chain sampler." + Art B. Owen. April 2017. + http://statweb.stanford.edu/~owen/reports/bestthinning.pdf #### Examples: - ```python - # Sampling from a standard normal (note `log_joint()` is unnormalized): - def log_joint(x): - return tf.reduce_sum(-0.5 * tf.square(x)) - chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint, - event_dims=[0]) - # Discard first half of chain as warmup/burn-in - warmed_up = chain[500:] - mean_est = tf.reduce_mean(warmed_up, 0) - var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) - ``` + ##### Sample from a diagonal-variance Gaussian. ```python - # Sampling from a diagonal-variance Gaussian: - variances = tf.linspace(1., 3., 10) - def log_joint(x): - return tf.reduce_sum(-0.5 / variances * tf.square(x)) - chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint, - event_dims=[0]) - # Discard first half of chain as warmup/burn-in - warmed_up = chain[500:] - mean_est = tf.reduce_mean(warmed_up, 0) - var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) + tfd = tf.contrib.distributions + + def make_likelihood(true_variances): + return tfd.MultivariateNormalDiag( + scale_diag=tf.sqrt(true_variances)) + + dims = 10 + dtype = np.float32 + true_variances = tf.linspace(dtype(1), dtype(3), dims) + likelihood = make_likelihood(true_variances) + + states, kernel_results = hmc.sample_chain( + num_results=1000, + target_log_prob_fn=likelihood.log_prob, + current_state=tf.zeros(dims), + step_size=0.5, + num_leapfrog_steps=2, + num_burnin_steps=500) + + # Compute sample stats. + sample_mean = tf.reduce_mean(states, axis=0) + sample_var = tf.reduce_mean( + tf.squared_difference(states, sample_mean), + axis=0) ``` - ```python - # Sampling from factor-analysis posteriors with known factors W: - # mu[i, j] ~ Normal(0, 1) - # x[i] ~ Normal(matmul(mu[i], W), I) - def log_joint(mu, x, W): - prior = -0.5 * tf.reduce_sum(tf.square(mu), 1) - x_mean = tf.matmul(mu, W) - likelihood = -0.5 * tf.reduce_sum(tf.square(x - x_mean), 1) - return prior + likelihood - chain, acceptance_probs = hmc.chain(1000, 0.1, 2, - tf.zeros([x.shape[0], W.shape[0]]), - lambda mu: log_joint(mu, x, W), - event_dims=[1]) - # Discard first half of chain as warmup/burn-in - warmed_up = chain[500:] - mean_est = tf.reduce_mean(warmed_up, 0) - var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) + ##### Sampling from factor-analysis posteriors with known factors. + + I.e., + + ```none + for i=1..n: + w[i] ~ Normal(0, eye(d)) # prior + x[i] ~ Normal(loc=matmul(w[i], F)) # likelihood ``` + where `F` denotes factors. + ```python - # Sampling from the posterior of a Bayesian regression model.: - - # Run 100 chains in parallel, each with a different initialization. - initial_beta = tf.random_normal([100, x.shape[1]]) - chain, acceptance_probs = hmc.chain(1000, 0.1, 10, initial_beta, - log_joint_partial, event_dims=[1]) - # Discard first halves of chains as warmup/burn-in - warmed_up = chain[500:] - # Averaging across samples within a chain and across chains - mean_est = tf.reduce_mean(warmed_up, [0, 1]) - var_est = tf.reduce_mean(tf.square(warmed_up), [0, 1]) - tf.square(mean_est) + tfd = tf.contrib.distributions + + def make_prior(dims, dtype): + return tfd.MultivariateNormalDiag( + loc=tf.zeros(dims, dtype)) + + def make_likelihood(weights, factors): + return tfd.MultivariateNormalDiag( + loc=tf.tensordot(weights, factors, axes=[[0], [-1]])) + + # Setup data. + num_weights = 10 + num_factors = 4 + num_chains = 100 + dtype = np.float32 + + prior = make_prior(num_weights, dtype) + weights = prior.sample(num_chains) + factors = np.random.randn(num_factors, num_weights).astype(dtype) + x = make_likelihood(weights, factors).sample(num_chains) + + def target_log_prob(w): + # Target joint is: `f(w) = p(w, x | factors)`. + return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x) + + # Get `num_results` samples from `num_chains` independent chains. + chains_states, kernels_results = hmc.sample_chain( + num_results=1000, + target_log_prob_fn=target_log_prob, + current_state=tf.zeros([num_chains, dims], dtype), + step_size=0.1, + num_leapfrog_steps=2, + num_burnin_steps=500) + + # Compute sample stats. + sample_mean = tf.reduce_mean(chains_states, axis=[0, 1]) + sample_var = tf.reduce_mean( + tf.squared_difference(chains_states, sample_mean), + axis=[0, 1]) ``` + + Args: + num_results: Integer number of Markov chain draws. + target_log_prob_fn: Python callable which takes an argument like + `current_state` (or `*current_state` if it's a list) and returns its + (possibly unnormalized) log-density under the target distribution. + current_state: `Tensor` or Python `list` of `Tensor`s representing the + current state(s) of the Markov chain(s). The first `r` dimensions index + independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. + step_size: `Tensor` or Python `list` of `Tensor`s representing the step size + for the leapfrog integrator. Must broadcast with the shape of + `current_state`. Larger step sizes lead to faster progress, but too-large + step sizes make rejection exponentially more likely. When possible, it's + often helpful to match per-variable step sizes to the standard deviations + of the target distribution in each variable. + num_leapfrog_steps: Integer number of steps to run the leapfrog integrator + for. Total progress per HMC step is roughly proportional to `step_size * + num_leapfrog_steps`. + num_burnin_steps: Integer number of chain steps to take before starting to + collect results. + Default value: 0 (i.e., no burn-in). + num_steps_between_results: Integer number of chain steps between collecting + a result. Only one out of every `num_steps_between_samples + 1` steps is + included in the returned results. This "thinning" comes at a cost of + reduced statistical power, while reducing memory requirements and + autocorrelation. For more discussion see [1]. + Default value: 0 (i.e., no subsampling). + seed: Python integer to seed the random number generator. + current_target_log_prob: (Optional) `Tensor` representing the value of + `target_log_prob_fn` at the `current_state`. The only reason to specify + this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). + current_grads_target_log_prob: (Optional) Python list of `Tensor`s + representing gradient of `target_log_prob` at the `current_state` and wrt + the `current_state`. Must have same shape as `current_state`. The only + reason to specify this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). + name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "hmc_sample_chain"). + + Returns: + accepted_states: Tensor or Python list of `Tensor`s representing the + state(s) of the Markov chain(s) at each result step. Has same shape as + input `current_state` but with a prepended `num_results`-size dimension. + kernel_results: `collections.namedtuple` of internal calculations used to + advance the chain. """ - with ops.name_scope(name, 'hmc_chain', [n_iterations, step_size, - n_leapfrog_steps, initial_x]): - initial_x = ops.convert_to_tensor(initial_x, name='initial_x') - non_event_shape = array_ops.shape(target_log_prob_fn(initial_x)) - - def body(a, _): - updated_x, acceptance_probs, log_prob, grad = kernel( - step_size, n_leapfrog_steps, a[0], target_log_prob_fn, event_dims, - a[2], a[3]) - return updated_x, acceptance_probs, log_prob, grad - - potential_and_grad = _make_potential_and_grad(target_log_prob_fn) - potential, grad = potential_and_grad(initial_x) - return functional_ops.scan( - body, array_ops.zeros(n_iterations, dtype=initial_x.dtype), - (initial_x, - array_ops.zeros(non_event_shape, dtype=initial_x.dtype), - -potential, -grad))[:2] - - -def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, - target_log_prob_fn, proposal_log_prob_fn, event_dims=(), - name=None): + with ops.name_scope( + name, "hmc_sample_chain", + [num_results, current_state, step_size, num_leapfrog_steps, + num_burnin_steps, num_steps_between_results, seed, + current_target_log_prob, current_grads_target_log_prob]): + with ops.name_scope("initialize"): + [ + current_state, + step_size, + current_target_log_prob, + current_grads_target_log_prob, + ] = _prepare_args( + target_log_prob_fn, + current_state, + step_size, + current_target_log_prob, + current_grads_target_log_prob) + num_results = ops.convert_to_tensor( + num_results, + dtype=dtypes.int32, + name="num_results") + num_leapfrog_steps = ops.convert_to_tensor( + num_leapfrog_steps, + dtype=dtypes.int32, + name="num_leapfrog_steps") + num_burnin_steps = ops.convert_to_tensor( + num_burnin_steps, + dtype=dtypes.int32, + name="num_burnin_steps") + num_steps_between_results = ops.convert_to_tensor( + num_steps_between_results, + dtype=dtypes.int32, + name="num_steps_between_results") + + def _run_chain(num_steps, current_state, kernel_results): + """Runs the chain(s) for `num_steps`.""" + def _loop_body(iter_, current_state, kernel_results): + return [iter_ + 1] + list(kernel( + target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + seed, + kernel_results.current_target_log_prob, + kernel_results.current_grads_target_log_prob)) + while_loop_kwargs = dict( + cond=lambda iter_, *args: iter_ < num_steps, + body=_loop_body, + loop_vars=[ + np.int32(0), + current_state, + kernel_results, + ], + ) + if seed is not None: + while_loop_kwargs["parallel_iterations"] = 1 + return control_flow_ops.while_loop( + **while_loop_kwargs)[1:] # Lop-off "iter_". + + def _scan_body(args_list, iter_): + """Closure which implements `tf.scan` body.""" + current_state, kernel_results = args_list + return _run_chain( + 1 + array_ops.where(math_ops.equal(iter_, 0), + num_burnin_steps, + num_steps_between_results), + current_state, + kernel_results) + + scan_kwargs = dict( + fn=_scan_body, + elems=math_ops.range(num_results), # iter_: used to choose burnin. + initializer=[ + current_state, + _make_dummy_kernel_results( + current_state, + current_target_log_prob, + current_grads_target_log_prob), + ]) + if seed is not None: + scan_kwargs["parallel_iterations"] = 1 + return functional_ops.scan(**scan_kwargs) + + +def sample_annealed_importance_chain( + proposal_log_prob_fn, + num_steps, + target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + seed=None, + name=None): """Runs annealed importance sampling (AIS) to estimate normalizing constants. - This routine uses Hamiltonian Monte Carlo to sample from a series of + This function uses Hamiltonian Monte Carlo to sample from a series of distributions that slowly interpolates between an initial "proposal" - distribution + distribution: `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)` - and the target distribution + and the target distribution: `exp(target_log_prob_fn(x) - target_log_normalizer)`, @@ -202,113 +359,203 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, normalizing constants of the initial distribution and the target distribution: - E[exp(w)] = exp(target_log_normalizer - proposal_log_normalizer). + `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`. - Args: - n_iterations: Integer number of Markov chain updates to run. More - iterations means more expense, but smoother annealing between q - and p, which in turn means exponentially lower variance for the - normalizing constant estimator. - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `initial_x`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. - When possible, it's often helpful to match per-variable step - sizes to the standard deviations of the target distribution in - each variable. - n_leapfrog_steps: Integer number of steps to run the leapfrog - integrator for. Total progress per HMC step is roughly - proportional to step_size * n_leapfrog_steps. - initial_x: Tensor of initial state(s) of the Markov chain(s). Must - be a sample from q, or results will be incorrect. - target_log_prob_fn: Python callable which takes an argument like `initial_x` - and returns its (possibly unnormalized) log-density under the target - distribution. - proposal_log_prob_fn: Python callable that returns the log density of the - initial distribution. - event_dims: List of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - ais_weights: Tensor with the estimated weight(s). Has shape matching - `target_log_prob_fn(initial_x)`. - chain_states: Tensor with the state(s) of the Markov chain(s) the final - iteration. Has shape matching `initial_x`. - acceptance_probs: Tensor with the acceptance probabilities for the final - iteration. Has shape matching `target_log_prob_fn(initial_x)`. + Note: `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly three + times (although this may be reduced to two times, in the future). #### Examples: + ##### Estimate the normalizing constant of a log-gamma distribution. + ```python - # Estimating the normalizing constant of a log-gamma distribution: - def proposal_log_prob(x): - # Standard normal log-probability. This is properly normalized. - return tf.reduce_sum(-0.5 * tf.square(x) - 0.5 * np.log(2 * np.pi), 1) - def target_log_prob(x): - # Unnormalized log-gamma(2, 3) distribution. - # True normalizer is (lgamma(2) - 2 * log(3)) * x.shape[1] - return tf.reduce_sum(2. * x - 3. * tf.exp(x), 1) + tfd = tf.contrib.distributions + # Run 100 AIS chains in parallel - initial_x = tf.random_normal([100, 20]) - w, _, _ = hmc.ais_chain(1000, 0.2, 2, initial_x, target_log_prob, - proposal_log_prob, event_dims=[1]) - log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100) + num_chains = 100 + dims = 20 + dtype = np.float32 + + proposal = tfd.MultivatiateNormalDiag( + loc=tf.zeros([dims], dtype=dtype)) + + target = tfd.TransformedDistribution( + distribution=tfd.Gamma(concentration=dtype(2), + rate=dtype(3)), + bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()), + event_shape=[dims]) + + chains_state, ais_weights, kernels_results = ( + hmc.sample_annealed_importance_chain( + proposal_log_prob_fn=proposal.log_prob, + num_steps=1000, + target_log_prob_fn=target.log_prob, + step_size=0.2, + current_state=proposal.sample(num_chains), + num_leapfrog_steps=2)) + + log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights) + - np.log(num_chains)) + log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.) ``` + ##### Estimate marginal likelihood of a Bayesian regression model. + ```python - # Estimating the marginal likelihood of a Bayesian regression model: - base_measure = -0.5 * np.log(2 * np.pi) - def proposal_log_prob(x): - # Standard normal log-probability. This is properly normalized. - return tf.reduce_sum(-0.5 * tf.square(x) + base_measure, 1) - def regression_log_joint(beta, x, y): - # This function returns a vector whose ith element is log p(beta[i], y | x). - # Each row of beta corresponds to the state of an independent Markov chain. - log_prior = tf.reduce_sum(-0.5 * tf.square(beta) + base_measure, 1) - means = tf.matmul(beta, x, transpose_b=True) - log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means) + - base_measure, 1) - return log_prior + log_likelihood - def log_joint_partial(beta): - return regression_log_joint(beta, x, y) + tfd = tf.contrib.distributions + + def make_prior(dims, dtype): + return tfd.MultivariateNormalDiag( + loc=tf.zeros(dims, dtype)) + + def make_likelihood(weights, x): + return tfd.MultivariateNormalDiag( + loc=tf.tensordot(weights, x, axes=[[0], [-1]])) + # Run 100 AIS chains in parallel - initial_beta = tf.random_normal([100, x.shape[1]]) - w, beta_samples, _ = hmc.ais_chain(1000, 0.1, 2, initial_beta, - log_joint_partial, proposal_log_prob, - event_dims=[1]) - log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100) + num_chains = 100 + dims = 10 + dtype = np.float32 + + # Make training data. + x = np.random.randn(num_chains, dims).astype(dtype) + true_weights = np.random.randn(dims).astype(dtype) + y = np.dot(x, true_weights) + np.random.randn(num_chains) + + # Setup model. + prior = make_prior(dims, dtype) + def target_log_prob_fn(weights): + return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y) + + proposal = tfd.MultivariateNormalDiag( + loc=tf.zeros(dims, dtype)) + + weight_samples, ais_weights, kernel_results = ( + hmc.sample_annealed_importance_chain( + num_steps=1000, + proposal_log_prob_fn=proposal.log_prob, + target_log_prob_fn=target_log_prob_fn + current_state=tf.zeros([num_chains, dims], dtype), + step_size=0.1, + num_leapfrog_steps=2)) + log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights) + - np.log(num_chains)) ``` + + Args: + proposal_log_prob_fn: Python callable that returns the log density of the + initial distribution. + num_steps: Integer number of Markov chain updates to run. More + iterations means more expense, but smoother annealing between q + and p, which in turn means exponentially lower variance for the + normalizing constant estimator. + target_log_prob_fn: Python callable which takes an argument like + `current_state` (or `*current_state` if it's a list) and returns its + (possibly unnormalized) log-density under the target distribution. + current_state: `Tensor` or Python `list` of `Tensor`s representing the + current state(s) of the Markov chain(s). The first `r` dimensions index + independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. + step_size: `Tensor` or Python `list` of `Tensor`s representing the step size + for the leapfrog integrator. Must broadcast with the shape of + `current_state`. Larger step sizes lead to faster progress, but too-large + step sizes make rejection exponentially more likely. When possible, it's + often helpful to match per-variable step sizes to the standard deviations + of the target distribution in each variable. + num_leapfrog_steps: Integer number of steps to run the leapfrog integrator + for. Total progress per HMC step is roughly proportional to `step_size * + num_leapfrog_steps`. + seed: Python integer to seed the random number generator. + name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "hmc_sample_annealed_importance_chain"). + + Returns: + accepted_state: `Tensor` or Python list of `Tensor`s representing the + state(s) of the Markov chain(s) at the final iteration. Has same shape as + input `current_state`. + ais_weights: Tensor with the estimated weight(s). Has shape matching + `target_log_prob_fn(current_state)`. + kernel_results: `collections.namedtuple` of internal calculations used to + advance the chain. """ - with ops.name_scope(name, 'hmc_ais_chain', - [n_iterations, step_size, n_leapfrog_steps, initial_x]): - non_event_shape = array_ops.shape(target_log_prob_fn(initial_x)) - - beta_series = math_ops.linspace(0., 1., n_iterations+1)[1:] - def _body(a, beta): # pylint: disable=missing-docstring - def log_prob_beta(x): - return ((1 - beta) * proposal_log_prob_fn(x) + - beta * target_log_prob_fn(x)) - last_x = a[0] - w = a[2] - w += (1. / n_iterations) * (target_log_prob_fn(last_x) - - proposal_log_prob_fn(last_x)) - # TODO(b/66917083): There's an opportunity for gradient reuse here. - updated_x, acceptance_probs, _, _ = kernel(step_size, n_leapfrog_steps, - last_x, log_prob_beta, - event_dims) - return updated_x, acceptance_probs, w - - x, acceptance_probs, w = functional_ops.scan( - _body, beta_series, - (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype), - array_ops.zeros(non_event_shape, dtype=initial_x.dtype))) - return w[-1], x[-1], acceptance_probs[-1] - - -def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), - x_log_prob=None, x_grad=None, name=None): + def make_convex_combined_log_prob_fn(iter_): + def _fn(*args): + p = proposal_log_prob_fn(*args) + t = target_log_prob_fn(*args) + dtype = p.dtype.base_dtype + beta = (math_ops.cast(iter_ + 1, dtype) + / math_ops.cast(num_steps, dtype)) + return (1. - beta) * p + beta * t + return _fn + + with ops.name_scope( + name, "hmc_sample_annealed_importance_chain", + [num_steps, current_state, step_size, num_leapfrog_steps, seed]): + with ops.name_scope("initialize"): + [ + current_state, + step_size, + current_log_prob, + current_grads_log_prob, + ] = _prepare_args( + make_convex_combined_log_prob_fn(iter_=0), + current_state, + step_size, + description="convex_combined_log_prob") + num_steps = ops.convert_to_tensor( + num_steps, + dtype=dtypes.int32, + name="num_steps") + num_leapfrog_steps = ops.convert_to_tensor( + num_leapfrog_steps, + dtype=dtypes.int32, + name="num_leapfrog_steps") + def _loop_body(iter_, ais_weights, current_state, kernel_results): + """Closure which implements `tf.while_loop` body.""" + current_state_parts = (list(current_state) + if _is_list_like(current_state) + else [current_state]) + # TODO(b/72994218): Consider refactoring things to avoid this unecessary + # call. + ais_weights += ((target_log_prob_fn(*current_state_parts) + - proposal_log_prob_fn(*current_state_parts)) + / math_ops.cast(num_steps, ais_weights.dtype)) + return [iter_ + 1, ais_weights] + list(kernel( + make_convex_combined_log_prob_fn(iter_), + current_state, + step_size, + num_leapfrog_steps, + seed, + kernel_results.current_target_log_prob, + kernel_results.current_grads_target_log_prob)) + + while_loop_kwargs = dict( + cond=lambda iter_, *args: iter_ < num_steps, + body=_loop_body, + loop_vars=[ + np.int32(0), # iter_ + array_ops.zeros_like(current_log_prob), # ais_weights + current_state, + _make_dummy_kernel_results(current_state, + current_log_prob, + current_grads_log_prob), + ]) + if seed is not None: + while_loop_kwargs["parallel_iterations"] = 1 + + [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop( + **while_loop_kwargs)[1:] # Lop-off "iter_". + + return [current_state, ais_weights, kernel_results] + + +def kernel(target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + seed=None, + current_target_log_prob=None, + current_grads_target_log_prob=None, + name=None): """Runs one iteration of Hamiltonian Monte Carlo. Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) @@ -316,334 +563,623 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), a Metropolis proposal. This function applies one step of HMC to randomly update the variable `x`. - This function can update multiple chains in parallel. It assumes - that all dimensions of `x` not specified in `event_dims` are - independent, and should therefore be updated independently. The - output of `target_log_prob_fn()` should sum log-probabilities across - all event dimensions. Slices along dimensions not in `event_dims` - may have different target distributions; for example, if - `event_dims == (1,)`, then `x[0, :]` could have a different target - distribution from x[1, :]. This is up to `target_log_prob_fn()`. - - Args: - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `x`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. - When possible, it's often helpful to match per-variable step - sizes to the standard deviations of the target distribution in - each variable. - n_leapfrog_steps: Integer number of steps to run the leapfrog - integrator for. Total progress per HMC step is roughly - proportional to step_size * n_leapfrog_steps. - x: Tensor containing the value(s) of the random variable(s) to update. - target_log_prob_fn: Python callable which takes an argument like `initial_x` - and returns its (possibly unnormalized) log-density under the target - distribution. - event_dims: List of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. - x_log_prob (optional): Tensor containing the cached output of a previous - call to `target_log_prob_fn()` evaluated at `x` (such as that provided by - a previous call to `kernel()`). Providing `x_log_prob` and - `x_grad` saves one gradient computation per call to `kernel()`. - x_grad (optional): Tensor containing the cached gradient of - `target_log_prob_fn()` evaluated at `x` (such as that provided by - a previous call to `kernel()`). Providing `x_log_prob` and - `x_grad` saves one gradient computation per call to `kernel()`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - updated_x: The updated variable(s) x. Has shape matching `initial_x`. - acceptance_probs: Tensor with the acceptance probabilities for the final - iteration. This is useful for diagnosing step size problems etc. Has - shape matching `target_log_prob_fn(initial_x)`. - new_log_prob: The value of `target_log_prob_fn()` evaluated at `updated_x`. - new_grad: The value of the gradient of `target_log_prob_fn()` evaluated at - `updated_x`. + This function can update multiple chains in parallel. It assumes that all + leftmost dimensions of `current_state` index independent chain states (and are + therefore updated independently). The output of `target_log_prob_fn()` should + sum log-probabilities across all event dimensions. Slices along the rightmost + dimensions may have different target distributions; for example, + `current_state[0, :]` could have a different target distribution from + `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of + independent chains is `tf.size(target_log_prob_fn(*current_state))`.) #### Examples: + ##### Simple chain with warm-up. + ```python + tfd = tf.contrib.distributions + # Tuning acceptance rates: + dtype = np.float32 target_accept_rate = 0.631 - def target_log_prob(x): - # Standard normal - return tf.reduce_sum(-0.5 * tf.square(x)) - initial_x = tf.zeros([10]) - initial_log_prob = target_log_prob(initial_x) - initial_grad = tf.gradients(initial_log_prob, initial_x)[0] - # Algorithm state - x = tf.Variable(initial_x, name='x') - step_size = tf.Variable(1., name='step_size') - last_log_prob = tf.Variable(initial_log_prob, name='last_log_prob') - last_grad = tf.Variable(initial_grad, name='last_grad') - # Compute updates - new_x, acceptance_prob, log_prob, grad = hmc.kernel(step_size, 3, x, - target_log_prob, - event_dims=[0], - x_log_prob=last_log_prob) - x_update = tf.assign(x, new_x) - log_prob_update = tf.assign(last_log_prob, log_prob) - grad_update = tf.assign(last_grad, grad) - step_size_update = tf.assign(step_size, - tf.where(acceptance_prob > target_accept_rate, - step_size * 1.01, step_size / 1.01)) - adaptive_updates = [x_update, log_prob_update, grad_update, step_size_update] - sampling_updates = [x_update, log_prob_update, grad_update] - - sess = tf.Session() - sess.run(tf.global_variables_initializer()) + num_warmup_iter = 500 + num_chain_iter = 500 + + x = tf.get_variable(name="x", initializer=dtype(1)) + step_size = tf.get_variable(name="step_size", initializer=dtype(1)) + + target = tfd.Normal(loc=dtype(0), scale=dtype(1)) + + new_x, other_results = hmc.kernel( + target_log_prob_fn=target.log_prob, + current_state=x, + step_size=step_size, + num_leapfrog_steps=3)[:4] + + x_update = x.assign(new_x) + + step_size_update = step_size.assign_add( + step_size * tf.where( + other_results.acceptance_probs > target_accept_rate, + 0.01, -0.01)) + + warmup = tf.group([x_update, step_size_update]) + + tf.global_variables_initializer().run() + + sess.graph.finalize() # No more graph building. + # Warm up the sampler and adapt the step size - for i in xrange(500): - sess.run(adaptive_updates) - # Collect samples without adapting step size - samples = np.zeros([500, 10]) - for i in xrange(500): - x_val, _ = sess.run([new_x, sampling_updates]) - samples[i] = x_val - ``` + for _ in xrange(num_warmup_iter): + sess.run(warmup) - ```python - # Empirical-Bayes estimation of a hyperparameter by MCMC-EM: - - # Problem setup - N = 150 - D = 10 - x = np.random.randn(N, D).astype(np.float32) - true_sigma = 0.5 - true_beta = true_sigma * np.random.randn(D).astype(np.float32) - y = x.dot(true_beta) + np.random.randn(N).astype(np.float32) - - def log_prior(beta, log_sigma): - return tf.reduce_sum(-0.5 / tf.exp(2 * log_sigma) * tf.square(beta) - - log_sigma) - def regression_log_joint(beta, log_sigma, x, y): - # This function returns log p(beta | log_sigma) + log p(y | x, beta). - means = tf.matmul(tf.expand_dims(beta, 0), x, transpose_b=True) - means = tf.squeeze(means) - log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means)) - return log_prior(beta, log_sigma) + log_likelihood - def log_joint_partial(beta): - return regression_log_joint(beta, log_sigma, x, y) - # Our estimate of log(sigma) - log_sigma = tf.Variable(0., name='log_sigma') - # The state of the Markov chain - beta = tf.Variable(tf.random_normal([x.shape[1]]), name='beta') - new_beta, _, _, _ = hmc.kernel(0.1, 5, beta, log_joint_partial, - event_dims=[0]) - beta_update = tf.assign(beta, new_beta) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - with tf.control_dependencies([beta_update]): - log_sigma_update = optimizer.minimize(-log_prior(beta, log_sigma), - var_list=[log_sigma]) - - sess = tf.Session() - sess.run(tf.global_variables_initializer()) - log_sigma_history = np.zeros(1000) - for i in xrange(1000): - log_sigma_val, _ = sess.run([log_sigma, log_sigma_update]) - log_sigma_history[i] = log_sigma_val - # Should converge to something close to true_sigma - plt.plot(np.exp(log_sigma_history)) + # Collect samples without adapting step size + samples = np.zeros([num_chain_iter]) + for i in xrange(num_chain_iter): + _, x_, target_log_prob_, grad_ = sess.run([ + x_update, + x, + other_results.target_log_prob, + other_results.grads_target_log_prob]) + samples[i] = x_ + + print(samples.mean(), samples.std()) ``` - """ - with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]): - potential_and_grad = _make_potential_and_grad(target_log_prob_fn) - x = ops.convert_to_tensor(x, name='x') - - x_shape = array_ops.shape(x) - m = random_ops.random_normal(x_shape, dtype=x.dtype) - - kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims) - - if (x_log_prob is not None) and (x_grad is not None): - log_potential_0, grad_0 = -x_log_prob, -x_grad # pylint: disable=invalid-unary-operand-type - else: - if x_log_prob is not None: - logging.warn('x_log_prob was provided, but x_grad was not,' - ' so x_log_prob was not used.') - if x_grad is not None: - logging.warn('x_grad was provided, but x_log_prob was not,' - ' so x_grad was not used.') - log_potential_0, grad_0 = potential_and_grad(x) - - new_x, new_m, log_potential_1, grad_1 = leapfrog_integrator( - step_size, n_leapfrog_steps, x, m, potential_and_grad, grad_0) - - kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims) - - energy_change = log_potential_1 - log_potential_0 + kinetic_1 - kinetic_0 - # Treat NaN as infinite energy (and therefore guaranteed rejection). - energy_change = array_ops.where( - math_ops.is_nan(energy_change), - array_ops.fill(array_ops.shape(energy_change), - energy_change.dtype.as_numpy_dtype(np.inf)), - energy_change) - acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.)) - accepted = ( - random_ops.random_uniform( - array_ops.shape(acceptance_probs), dtype=x.dtype) - < acceptance_probs) - new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0) - - # TODO(b/65738010): This should work, but it doesn't for now. - # reduced_shape = math_ops.reduced_shape(x_shape, event_dims) - reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims, - keep_dims=True)) - accepted = array_ops.reshape(accepted, reduced_shape) - accepted = math_ops.logical_or( - accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool)) - new_x = array_ops.where(accepted, new_x, x) - new_grad = -array_ops.where(accepted, grad_1, grad_0) - - # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect - # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This - # should be fixed. - return new_x, acceptance_probs, new_log_prob, new_grad - - -def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum, - potential_and_grad, initial_grad, name=None): - """Applies `n_steps` steps of the leapfrog integrator. - - This just wraps `leapfrog_step()` in a `tf.while_loop()`, reusing - gradient computations where possible. - Args: - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `initial_position`. Larger step sizes lead to faster progress, but - too-large step sizes lead to larger discretization error and - worse energy conservation. - n_steps: Number of steps to run the leapfrog integrator. - initial_position: Tensor containing the value(s) of the position variable(s) - to update. - initial_momentum: Tensor containing the value(s) of the momentum variable(s) - to update. - potential_and_grad: Python callable that takes a position tensor like - `initial_position` and returns the potential energy and its gradient at - that position. - initial_grad: Tensor with the value of the gradient of the potential energy - at `initial_position`. - name: Python `str` name prefixed to Ops created by this function. + ##### Sample from more complicated posterior. - Returns: - updated_position: Updated value of the position. - updated_momentum: Updated value of the momentum. - new_potential: Potential energy of the new position. Has shape matching - `potential_and_grad(initial_position)`. - new_grad: Gradient from potential_and_grad() evaluated at the new position. - Has shape matching `initial_position`. + I.e., - Example: Simple quadratic potential. + ```none + W ~ MVN(loc=0, scale=sigma * eye(dims)) + for i=1...num_samples: + X[i] ~ MVN(loc=0, scale=eye(dims)) + eps[i] ~ Normal(loc=0, scale=1) + Y[i] = X[i].T * W + eps[i] + ``` ```python - def potential_and_grad(position): - return tf.reduce_sum(0.5 * tf.square(position)), position - position = tf.placeholder(np.float32) - momentum = tf.placeholder(np.float32) - potential, grad = potential_and_grad(position) - new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_integrator( - 0.1, 3, position, momentum, potential_and_grad, grad) - - sess = tf.Session() - position_val = np.random.randn(10) - momentum_val = np.random.randn(10) - potential_val, grad_val = sess.run([potential, grad], - {position: position_val}) - positions = np.zeros([100, 10]) - for i in xrange(100): - position_val, momentum_val, potential_val, grad_val = sess.run( - [new_position, new_momentum, new_potential, new_grad], - {position: position_val, momentum: momentum_val}) - positions[i] = position_val - # Should trace out sinusoidal dynamics. - plt.plot(positions[:, 0]) - ``` - """ - def leapfrog_wrapper(step_size, x, m, grad, l): - x, m, _, grad = leapfrog_step(step_size, x, m, potential_and_grad, grad) - return step_size, x, m, grad, l + 1 + tfd = tf.contrib.distributions + + def make_training_data(num_samples, dims, sigma): + dt = np.asarray(sigma).dtype + zeros = tf.zeros(dims, dtype=dt) + x = tfd.MultivariateNormalDiag( + loc=zeros).sample(num_samples, seed=1) + w = tfd.MultivariateNormalDiag( + loc=zeros, + scale_identity_multiplier=sigma).sample(seed=2) + noise = tfd.Normal( + loc=dt(0), + scale=dt(1)).sample(num_samples, seed=3) + y = tf.tensordot(x, w, axes=[[1], [0]]) + noise + return y, x, w + + def make_prior(sigma, dims): + # p(w | sigma) + return tfd.MultivariateNormalDiag( + loc=tf.zeros([dims], dtype=sigma.dtype), + scale_identity_multiplier=sigma) + + def make_likelihood(x, w): + # p(y | x, w) + return tfd.MultivariateNormalDiag( + loc=tf.tensordot(x, w, axes=[[1], [0]])) + + # Setup assumptions. + dtype = np.float32 + num_samples = 150 + dims = 10 + num_iters = int(5e3) + + true_sigma = dtype(0.5) + y, x, true_weights = make_training_data(num_samples, dims, true_sigma) + + # Estimate of `log(true_sigma)`. + log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0)) + sigma = tf.exp(log_sigma) + + # State of the Markov chain. + weights = tf.get_variable( + name="weights", + initializer=np.random.randn(dims).astype(dtype)) + + prior = make_prior(sigma, dims) + + def joint_log_prob_fn(w): + # f(w) = log p(w, y | x) + return prior.log_prob(w) + make_likelihood(x, w).log_prob(y) + + weights_update = weights.assign( + hmc.kernel(target_log_prob_fn=joint_log_prob, + current_state=weights, + step_size=0.1, + num_leapfrog_steps=5)[0]) + + with tf.control_dependencies([weights_update]): + loss = -prior.log_prob(weights) + + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) + log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma]) + + sess.graph.finalize() # No more graph building. - def counter_fn(a, b, c, d, counter): # pylint: disable=unused-argument - return counter < n_steps + tf.global_variables_initializer().run() - with ops.name_scope(name, 'leapfrog_integrator', - [step_size, n_steps, initial_position, initial_momentum, - initial_grad]): - _, new_x, new_m, new_grad, _ = control_flow_ops.while_loop( - counter_fn, leapfrog_wrapper, [step_size, initial_position, - initial_momentum, initial_grad, - array_ops.constant(0)], back_prop=False) - # We're counting on the runtime to eliminate this redundant computation. - new_potential, new_grad = potential_and_grad(new_x) - return new_x, new_m, new_potential, new_grad + sigma_history = np.zeros(num_iters, dtype) + weights_history = np.zeros([num_iters, dims], dtype) + for i in xrange(num_iters): + _, sigma_, weights_, _ = sess.run([log_sigma_update, sigma, weights]) + weights_history[i, :] = weights_ + sigma_history[i] = sigma_ -def leapfrog_step(step_size, position, momentum, potential_and_grad, grad, - name=None): - """Applies one step of the leapfrog integrator. + true_weights_ = sess.run(true_weights) - Assumes a simple quadratic kinetic energy function: 0.5 * ||momentum||^2. + # Should converge to something close to true_sigma. + plt.plot(sigma_history); + plt.ylabel("sigma"); + plt.xlabel("iteration"); + ``` Args: - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `position`. Larger step sizes lead to faster progress, but - too-large step sizes lead to larger discretization error and - worse energy conservation. - position: Tensor containing the value(s) of the position variable(s) - to update. - momentum: Tensor containing the value(s) of the momentum variable(s) - to update. - potential_and_grad: Python callable that takes a position tensor like - `position` and returns the potential energy and its gradient at that - position. - grad: Tensor with the value of the gradient of the potential energy - at `position`. + target_log_prob_fn: Python callable which takes an argument like + `current_state` (or `*current_state` if it's a list) and returns its + (possibly unnormalized) log-density under the target distribution. + current_state: `Tensor` or Python `list` of `Tensor`s representing the + current state(s) of the Markov chain(s). The first `r` dimensions index + independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. + step_size: `Tensor` or Python `list` of `Tensor`s representing the step size + for the leapfrog integrator. Must broadcast with the shape of + `current_state`. Larger step sizes lead to faster progress, but too-large + step sizes make rejection exponentially more likely. When possible, it's + often helpful to match per-variable step sizes to the standard deviations + of the target distribution in each variable. + num_leapfrog_steps: Integer number of steps to run the leapfrog integrator + for. Total progress per HMC step is roughly proportional to `step_size * + num_leapfrog_steps`. + seed: Python integer to seed the random number generator. + current_target_log_prob: (Optional) `Tensor` representing the value of + `target_log_prob_fn` at the `current_state`. The only reason to + specify this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). + current_grads_target_log_prob: (Optional) Python list of `Tensor`s + representing gradient of `current_target_log_prob` at the `current_state` + and wrt the `current_state`. Must have same shape as `current_state`. The + only reason to specify this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "hmc_kernel"). Returns: - updated_position: Updated value of the position. - updated_momentum: Updated value of the momentum. - new_potential: Potential energy of the new position. Has shape matching - `potential_and_grad(position)`. - new_grad: Gradient from potential_and_grad() evaluated at the new position. - Has shape matching `position`. + accepted_state: Tensor or Python list of `Tensor`s representing the state(s) + of the Markov chain(s) at each result step. Has same shape as + `current_state`. + kernel_results: `collections.namedtuple` of internal calculations used to + advance the chain. + + Raises: + ValueError: if there isn't one `step_size` or a list with same length as + `current_state`. + """ + with ops.name_scope( + name, "hmc_kernel", + [current_state, step_size, num_leapfrog_steps, seed, + current_target_log_prob, current_grads_target_log_prob]): + with ops.name_scope("initialize"): + [current_state_parts, step_sizes, current_target_log_prob, + current_grads_target_log_prob] = _prepare_args( + target_log_prob_fn, current_state, step_size, + current_target_log_prob, current_grads_target_log_prob, + maybe_expand=True) + independent_chain_ndims = distributions_util.prefer_static_rank( + current_target_log_prob) + current_momentums = [] + for s in current_state_parts: + current_momentums.append(random_ops.random_normal( + shape=array_ops.shape(s), + dtype=s.dtype.base_dtype, + seed=seed)) + seed = distributions_util.gen_new_seed( + seed, salt="hmc_kernel_momentums") + + num_leapfrog_steps = ops.convert_to_tensor( + num_leapfrog_steps, + dtype=dtypes.int32, + name="num_leapfrog_steps") + [ + proposed_momentums, + proposed_state_parts, + proposed_target_log_prob, + proposed_grads_target_log_prob, + ] = _leapfrog_integrator(current_momentums, + target_log_prob_fn, + current_state_parts, + step_sizes, + num_leapfrog_steps, + current_target_log_prob, + current_grads_target_log_prob) + + energy_change = _compute_energy_change(current_target_log_prob, + current_momentums, + proposed_target_log_prob, + proposed_momentums, + independent_chain_ndims) + + # u < exp(min(-energy, 0)), where u~Uniform[0,1) + # ==> -log(u) >= max(e, 0) + # ==> -log(u) >= e + # (Perhaps surprisingly, we don't have a better way to obtain a random + # uniform from positive reals, i.e., `tf.random_uniform(minval=0, + # maxval=np.inf)` won't work.) + random_uniform = random_ops.random_uniform( + shape=array_ops.shape(energy_change), + dtype=energy_change.dtype, + seed=seed) + random_positive = -math_ops.log(random_uniform) + is_accepted = random_positive >= energy_change + + accepted_target_log_prob = array_ops.where(is_accepted, + proposed_target_log_prob, + current_target_log_prob) + + accepted_state_parts = [_choose(is_accepted, + proposed_state_part, + current_state_part, + independent_chain_ndims) + for current_state_part, proposed_state_part + in zip(current_state_parts, proposed_state_parts)] + + accepted_grads_target_log_prob = [ + _choose(is_accepted, + proposed_grad, + grad, + independent_chain_ndims) + for proposed_grad, grad + in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)] + + maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] + return [ + maybe_flatten(accepted_state_parts), + KernelResults( + acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)), + current_grads_target_log_prob=accepted_grads_target_log_prob, + current_target_log_prob=accepted_target_log_prob, + energy_change=energy_change, + is_accepted=is_accepted, + proposed_grads_target_log_prob=proposed_grads_target_log_prob, + proposed_state=maybe_flatten(proposed_state_parts), + proposed_target_log_prob=proposed_target_log_prob, + random_positive=random_positive, + ), + ] + + +def _leapfrog_integrator(current_momentums, + target_log_prob_fn, + current_state_parts, + step_sizes, + num_leapfrog_steps, + current_target_log_prob=None, + current_grads_target_log_prob=None, + name=None): + """Applies `num_leapfrog_steps` of the leapfrog integrator. + + Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`. + + #### Examples: - Example: Simple quadratic potential. + ##### Simple quadratic potential. ```python - def potential_and_grad(position): - # Simple quadratic potential - return tf.reduce_sum(0.5 * tf.square(position)), position + tfd = tf.contrib.distributions + + dims = 10 + num_iter = int(1e3) + dtype = np.float32 + position = tf.placeholder(np.float32) momentum = tf.placeholder(np.float32) - potential, grad = potential_and_grad(position) - new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_step( - 0.1, position, momentum, potential_and_grad, grad) - - sess = tf.Session() - position_val = np.random.randn(10) - momentum_val = np.random.randn(10) - potential_val, grad_val = sess.run([potential, grad], - {position: position_val}) - positions = np.zeros([100, 10]) - for i in xrange(100): - position_val, momentum_val, potential_val, grad_val = sess.run( - [new_position, new_momentum, new_potential, new_grad], - {position: position_val, momentum: momentum_val}) - positions[i] = position_val - # Should trace out sinusoidal dynamics. - plt.plot(positions[:, 0]) + + [ + new_momentums, + new_positions, + ] = hmc._leapfrog_integrator( + current_momentums=[momentum], + target_log_prob_fn=tfd.MultivariateNormalDiag( + loc=tf.zeros(dims, dtype)).log_prob, + current_state_parts=[position], + step_sizes=0.1, + num_leapfrog_steps=3)[:2] + + sess.graph.finalize() # No more graph building. + + momentum_ = np.random.randn(dims).astype(dtype) + position_ = np.random.randn(dims).astype(dtype) + + positions = np.zeros([num_iter, dims], dtype) + for i in xrange(num_iter): + position_, momentum_ = sess.run( + [new_momentums[0], new_position[0]], + feed_dict={position: position_, momentum: momentum_}) + positions[i] = position_ + + plt.plot(positions[:, 0]); # Sinusoidal. ``` + + Args: + current_momentums: Tensor containing the value(s) of the momentum + variable(s) to update. + target_log_prob_fn: Python callable which takes an argument like + `*current_state_parts` and returns its (possibly unnormalized) log-density + under the target distribution. + current_state_parts: Python `list` of `Tensor`s representing the current + state(s) of the Markov chain(s). The first `independent_chain_ndims` of + the `Tensor`(s) index different chains. + step_sizes: Python `list` of `Tensor`s representing the step size for the + leapfrog integrator. Must broadcast with the shape of + `current_state_parts`. Larger step sizes lead to faster progress, but + too-large step sizes make rejection exponentially more likely. When + possible, it's often helpful to match per-variable step sizes to the + standard deviations of the target distribution in each variable. + num_leapfrog_steps: Integer number of steps to run the leapfrog integrator + for. Total progress per HMC step is roughly proportional to `step_size * + num_leapfrog_steps`. + current_target_log_prob: (Optional) `Tensor` representing the value of + `target_log_prob_fn(*current_state_parts)`. The only reason to specify + this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). + current_grads_target_log_prob: (Optional) Python list of `Tensor`s + representing gradient of `target_log_prob_fn(*current_state_parts`) wrt + `current_state_parts`. Must have same shape as `current_state_parts`. The + only reason to specify this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). + name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "hmc_leapfrog_integrator"). + + Returns: + proposed_momentums: Updated value of the momentum. + proposed_state_parts: Tensor or Python list of `Tensor`s representing the + state(s) of the Markov chain(s) at each result step. Has same shape as + input `current_state_parts`. + proposed_target_log_prob: `Tensor` representing the value of + `target_log_prob_fn` at `accepted_state`. + proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt + `accepted_state`. + + Raises: + ValueError: if `len(momentums) != len(state_parts)`. + ValueError: if `len(state_parts) != len(step_sizes)`. + ValueError: if `len(state_parts) != len(grads_target_log_prob)`. + TypeError: if `not target_log_prob.dtype.is_floating`. """ - with ops.name_scope(name, 'leapfrog_step', [step_size, position, momentum, - grad]): - momentum -= 0.5 * step_size * grad - position += step_size * momentum - potential, grad = potential_and_grad(position) - momentum -= 0.5 * step_size * grad - - return position, momentum, potential, grad + def _loop_body(step, + current_momentums, + current_state_parts, + ignore_current_target_log_prob, # pylint: disable=unused-argument + current_grads_target_log_prob): + return [step + 1] + list(_leapfrog_step(current_momentums, + target_log_prob_fn, + current_state_parts, + step_sizes, + current_grads_target_log_prob)) + + with ops.name_scope( + name, "hmc_leapfrog_integrator", + [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps, + current_target_log_prob, current_grads_target_log_prob]): + if len(current_momentums) != len(current_state_parts): + raise ValueError("`momentums` must be in one-to-one correspondence " + "with `state_parts`") + num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps, + name="num_leapfrog_steps") + current_target_log_prob, current_grads_target_log_prob = ( + _maybe_call_fn_and_grads( + target_log_prob_fn, + current_state_parts, + current_target_log_prob, + current_grads_target_log_prob)) + return control_flow_ops.while_loop( + cond=lambda iter_, *args: iter_ < num_leapfrog_steps, + body=_loop_body, + loop_vars=[ + np.int32(0), # iter_ + current_momentums, + current_state_parts, + current_target_log_prob, + current_grads_target_log_prob, + ], + back_prop=False)[1:] # Lop-off "iter_". + + +def _leapfrog_step(current_momentums, + target_log_prob_fn, + current_state_parts, + step_sizes, + current_grads_target_log_prob, + name=None): + """Applies one step of the leapfrog integrator.""" + with ops.name_scope( + name, "_leapfrog_step", + [current_momentums, current_state_parts, step_sizes, + current_grads_target_log_prob]): + proposed_momentums = [m + 0.5 * ss * g for m, ss, g + in zip(current_momentums, + step_sizes, + current_grads_target_log_prob)] + proposed_state_parts = [x + ss * m for x, ss, m + in zip(current_state_parts, + step_sizes, + proposed_momentums)] + proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts) + if not proposed_target_log_prob.dtype.is_floating: + raise TypeError("`target_log_prob_fn` must produce a `Tensor` " + "with `float` `dtype`.") + proposed_grads_target_log_prob = gradients_ops.gradients( + proposed_target_log_prob, proposed_state_parts) + if any(g is None for g in proposed_grads_target_log_prob): + raise ValueError( + "Encountered `None` gradient. Does your target `target_log_prob_fn` " + "access all `tf.Variable`s via `tf.get_variable`?\n" + " current_state_parts: {}\n" + " proposed_state_parts: {}\n" + " proposed_grads_target_log_prob: {}".format( + current_state_parts, + proposed_state_parts, + proposed_grads_target_log_prob)) + proposed_momentums = [m + 0.5 * ss * g for m, ss, g + in zip(proposed_momentums, + step_sizes, + proposed_grads_target_log_prob)] + return [ + proposed_momentums, + proposed_state_parts, + proposed_target_log_prob, + proposed_grads_target_log_prob, + ] + + +def _compute_energy_change(current_target_log_prob, + current_momentums, + proposed_target_log_prob, + proposed_momentums, + independent_chain_ndims, + name=None): + """Helper to `kernel` which computes the energy change.""" + with ops.name_scope( + name, "compute_energy_change", + ([current_target_log_prob, proposed_target_log_prob, + independent_chain_ndims] + + current_momentums + proposed_momentums)): + # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy + # since they're a mouthful and lets us inline more. + lk0, lk1 = [], [] + for current_momentum, proposed_momentum in zip(current_momentums, + proposed_momentums): + axis = math_ops.range(independent_chain_ndims, + array_ops.rank(current_momentum)) + lk0.append(_log_sum_sq(current_momentum, axis)) + lk1.append(_log_sum_sq(proposed_momentum, axis)) + + lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1), + axis=-1) + lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1), + axis=-1) + lp0 = -current_target_log_prob # log_potential + lp1 = -proposed_target_log_prob # proposed_log_potential + x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)], + axis=-1) + + # The sum is NaN if any element is NaN or we see both +Inf and -Inf. + # Thus we will replace such rows with infinite energy change which implies + # rejection. Recall that float-comparisons with NaN are always False. + is_sum_determinate = ( + math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) & + math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1)) + is_sum_determinate = array_ops.tile( + is_sum_determinate[..., array_ops.newaxis], + multiples=array_ops.concat([ + array_ops.ones(array_ops.rank(is_sum_determinate), + dtype=dtypes.int32), + [4], + ], axis=0)) + x = array_ops.where(is_sum_determinate, + x, + array_ops.fill(array_ops.shape(x), + value=x.dtype.as_numpy_dtype(np.inf))) + + return math_ops.reduce_sum(x, axis=-1) + + +def _choose(is_accepted, + accepted, + rejected, + independent_chain_ndims, + name=None): + """Helper to `kernel` which expand_dims `is_accepted` to apply tf.where.""" + def _expand_is_accepted_like(x): + with ops.name_scope("_choose"): + expand_shape = array_ops.concat([ + array_ops.shape(is_accepted), + array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)], + dtype=dtypes.int32), + ], axis=0) + multiples = array_ops.concat([ + array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32), + array_ops.shape(x)[independent_chain_ndims:], + ], axis=0) + m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape), + multiples) + m.set_shape(x.shape) + return m + with ops.name_scope(name, "_choose", values=[ + is_accepted, accepted, rejected, independent_chain_ndims]): + return array_ops.where(_expand_is_accepted_like(accepted), + accepted, + rejected) + + +def _maybe_call_fn_and_grads(fn, + fn_arg_list, + fn_result=None, + grads_fn_result=None, + description="target_log_prob"): + """Helper which computes `fn_result` and `grads` if needed.""" + fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list) + else [fn_arg_list]) + if fn_result is None: + fn_result = fn(*fn_arg_list) + if not fn_result.dtype.is_floating: + raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format( + description)) + if grads_fn_result is None: + grads_fn_result = gradients_ops.gradients( + fn_result, fn_arg_list) + if len(fn_arg_list) != len(grads_fn_result): + raise ValueError("`{}` must be in one-to-one correspondence with " + "`grads_{}`".format(*[description]*2)) + if any(g is None for g in grads_fn_result): + raise ValueError("Encountered `None` gradient.") + return fn_result, grads_fn_result + + +def _prepare_args(target_log_prob_fn, state, step_size, + target_log_prob=None, grads_target_log_prob=None, + maybe_expand=False, description="target_log_prob"): + """Helper which processes input args to meet list-like assumptions.""" + state_parts = list(state) if _is_list_like(state) else [state] + state_parts = [ops.convert_to_tensor(s, name="state") + for s in state_parts] + target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads( + target_log_prob_fn, + state_parts, + target_log_prob, + grads_target_log_prob, + description) + step_sizes = list(step_size) if _is_list_like(step_size) else [step_size] + step_sizes = [ + ops.convert_to_tensor( + s, name="step_size", dtype=target_log_prob.dtype) + for s in step_sizes] + if len(step_sizes) == 1: + step_sizes *= len(state_parts) + if len(state_parts) != len(step_sizes): + raise ValueError("There should be exactly one `step_size` or it should " + "have same length as `current_state`.") + maybe_flatten = lambda x: x if maybe_expand or _is_list_like(state) else x[0] + return [ + maybe_flatten(state_parts), + maybe_flatten(step_sizes), + target_log_prob, + grads_target_log_prob, + ] + + +def _is_list_like(x): + """Helper which returns `True` if input is `list`-like.""" + return isinstance(x, (tuple, list)) + + +def _log_sum_sq(x, axis=None): + """Computes log(sum(x**2)).""" + return math_ops.reduce_logsumexp(2. * math_ops.log(math_ops.abs(x)), axis) diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a645eafc249d1c39e0d4a238ae7ec8755c78d8 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for Markov Chain Monte Carlo (MCMC) sampling.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.mcmc_diagnostics_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "effective_sample_size", + "potential_scale_reduction", +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..0424b6952bc89ce7fe5b00b0135c9a5fe1faa8cf --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py @@ -0,0 +1,400 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for Markov Chain Monte Carlo (MCMC) sampling. + +@@effective_sample_size +@@potential_scale_reduction +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import sample_stats +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + +__all__ = [ + "effective_sample_size", + "potential_scale_reduction", +] + + +def effective_sample_size(states, + filter_threshold=0., + filter_beyond_lag=None, + name=None): + """Estimate a lower bound on effective sample size for each independent chain. + + Roughly speaking, "effective sample size" (ESS) is the size of an iid sample + with the same variance as `state`. + + More precisely, given a stationary sequence of possibly correlated random + variables `X_1, X_2,...,X_N`, each identically distributed ESS is the number + such that + + ```Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.``` + + If the sequence is uncorrelated, `ESS = N`. In general, one should expect + `ESS <= N`, with more highly correlated sequences having smaller `ESS`. + + #### Example of using ESS to estimate standard error. + + ``` + tfd = tf.contrib.distributions + tfb = tf.contrib.bayesflow + + target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.]) + + # Get 1000 states from one chain. + states = tfb.hmc.sample_chain( + num_results=1000, + target_log_prob_fn=target.log_prob, + current_state=tf.constant([0., 0.]), + step_size=0.05, + num_leapfrog_steps=20, + num_burnin_steps=200) + states.shape + ==> (1000, 2) + + ess = effective_sample_size(states) + ==> Shape (2,) Tensor + + mean, variance = tf.nn.moments(states, axis=0) + standard_error = tf.sqrt(variance / ess) + ``` + + Some math shows that, with `R_k` the auto-correlation sequence, + `R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}`, we have + + ```ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]``` + + This function estimates the above by first estimating the auto-correlation. + Since `R_k` must be estimated using only `N - k` samples, it becomes + progressively noisier for larger `k`. For this reason, the summation over + `R_k` should be truncated at some number `filter_beyond_lag < N`. Since many + MCMC methods generate chains where `R_k > 0`, a reasonable critera is to + truncate at the first index where the estimated auto-correlation becomes + negative. + + The arguments `filter_beyond_lag`, `filter_threshold` are filters intended to + remove noisy tail terms from `R_k`. They combine in an "OR" manner meaning + terms are removed if they were to be filtered under the `filter_beyond_lag` OR + `filter_threshold` criteria. + + Args: + states: `Tensor` or list of `Tensor` objects. Dimension zero should index + identically distributed states. + filter_threshold: `Tensor` or list of `Tensor` objects. + Must broadcast with `state`. The auto-correlation sequence is truncated + after the first appearance of a term less than `filter_threshold`. + Setting to `None` means we use no threshold filter. Since `|R_k| <= 1`, + setting to any number less than `-1` has the same effect. + filter_beyond_lag: `Tensor` or list of `Tensor` objects. Must be + `int`-like and scalar valued. The auto-correlation sequence is truncated + to this length. Setting to `None` means we do not filter based on number + of lags. + name: `String` name to prepend to created ops. + + Returns: + ess: `Tensor` or list of `Tensor` objects. The effective sample size of + each component of `states`. Shape will be `states.shape[1:]`. + + Raises: + ValueError: If `states` and `filter_threshold` or `states` and + `filter_beyond_lag` are both lists with different lengths. + """ + states_was_list = _is_list_like(states) + + # Convert all args to lists. + if not states_was_list: + states = [states] + + filter_beyond_lag = _broadcast_maybelist_arg(states, filter_beyond_lag, + "filter_beyond_lag") + filter_threshold = _broadcast_maybelist_arg(states, filter_threshold, + "filter_threshold") + + # Process items, one at a time. + with ops.name_scope(name, "effective_sample_size"): + ess_list = [ + _effective_sample_size_single_state(s, ml, mlt) + for (s, ml, mlt) in zip(states, filter_beyond_lag, filter_threshold) + ] + + if states_was_list: + return ess_list + return ess_list[0] + + +def _effective_sample_size_single_state(states, filter_beyond_lag, + filter_threshold): + """ESS computation for one single Tensor argument.""" + + with ops.name_scope( + "effective_sample_size_single_state", + values=[states, filter_beyond_lag, filter_threshold]): + + states = ops.convert_to_tensor(states, name="states") + dt = states.dtype + + # filter_beyond_lag == None ==> auto_corr is the full sequence. + auto_corr = sample_stats.auto_correlation( + states, axis=0, max_lags=filter_beyond_lag) + if filter_threshold is not None: + filter_threshold = ops.convert_to_tensor( + filter_threshold, dtype=dt, name="filter_threshold") + # Get a binary mask to zero out values of auto_corr below the threshold. + # mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i, + # mask[i, ...] = 0, otherwise. + # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...] + # Building step by step, + # Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2. + # Step 1: mask = [False, False, True, False] + mask = auto_corr < filter_threshold + # Step 2: mask = [0, 0, 1, 1] + mask = math_ops.cast(mask, dtype=dt) + # Step 3: mask = [0, 0, 1, 2] + mask = math_ops.cumsum(mask, axis=0) + # Step 4: mask = [1, 1, 0, 0] + mask = math_ops.maximum(1. - mask, 0.) + auto_corr *= mask + + # With R[k] := auto_corr[k, ...], + # ESS = N / {1 + 2 * Sum_{k=1}^N (N - k) / N * R[k]} + # = N / {-1 + 2 * Sum_{k=0}^N (N - k) / N * R[k]} (since R[0] = 1) + # approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]} + # where M is the filter_beyond_lag truncation point chosen above. + + # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total + # ndims the same as auto_corr + n = _axis_size(states, axis=0) + k = math_ops.range(0., _axis_size(auto_corr, axis=0)) + nk_factor = (n - k) / n + if auto_corr.shape.ndims is not None: + new_shape = [-1] + [1] * (auto_corr.shape.ndims - 1) + else: + new_shape = array_ops.concat( + ([-1], + array_ops.ones([array_ops.rank(auto_corr) - 1], dtype=dtypes.int32)), + axis=0) + nk_factor = array_ops.reshape(nk_factor, new_shape) + + return n / (-1 + 2 * math_ops.reduce_sum(nk_factor * auto_corr, axis=0)) + + +def potential_scale_reduction(chains_states, + independent_chain_ndims=1, + name=None): + """Gelman and Rubin's potential scale reduction factor for chain convergence. + + Given `N > 1` states from each of `C > 1` independent chains, the potential + scale reduction factor, commonly referred to as R-hat, measures convergence of + the chains (to the same target) by testing for equality of means. + Specifically, R-hat measures the degree to which variance (of the means) + between chains exceeds what one would expect if the chains were identically + distributed. See [1], [2]. + + Some guidelines: + + * The initial state of the chains should be drawn from a distribution + overdispersed with respect to the target. + * If all chains converge to the target, then as `N --> infinity`, R-hat --> 1. + Before that, R-hat > 1 (except in pathological cases, e.g. if the chain + paths were identical). + * The above holds for any number of chains `C > 1`. Increasing `C` does + improves effectiveness of the diagnostic. + * Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but of + course this is problem depedendent. See [2]. + * R-hat only measures non-convergence of the mean. If higher moments, or other + statistics are desired, a different diagnostic should be used. See [2]. + + #### Examples + + Diagnosing convergence by monitoring 10 chains that each attempt to + sample from a 2-variate normal. + + ```python + tfd = tf.contrib.distributions + tfb = tf.contrib.bayesflow + + target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.]) + + # Get 10 (2x) overdispersed initial states. + initial_state = target.sample(10) * 2. + ==> (10, 2) + + # Get 1000 samples from the 10 independent chains. + chains_states, _ = tfb.hmc.sample_chain( + num_results=1000, + target_log_prob_fn=target.log_prob, + current_state=initial_state, + step_size=0.05, + num_leapfrog_steps=20, + num_burnin_steps=200) + chains_states.shape + ==> (1000, 10, 2) + + rhat = tfb.mcmc_diagnostics.potential_scale_reduction( + chains_states, independent_chain_ndims=1) + + # The second dimension needed a longer burn-in. + rhat.eval() + ==> [1.05, 1.3] + ``` + + To see why R-hat is reasonable, let `X` be a random variable drawn uniformly + from the combined states (combined over all chains). Then, in the limit + `N, C --> infinity`, with `E`, `Var` denoting expectation and variance, + + ```R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].``` + + Using the law of total variance, the numerator is the variance of the combined + states, and the denominator is the total variance minus the variance of the + the individual chain means. If the chains are all drawing from the same + distribution, they will have the same mean, and thus the ratio should be one. + + [1] "Inference from Iterative Simulation Using Multiple Sequences" + Andrew Gelman and Donald B. Rubin + Statist. Sci. Volume 7, Number 4 (1992), 457-472. + [2] "General Methods for Monitoring Convergence of Iterative Simulations" + Stephen P. Brooks and Andrew Gelman + Journal of Computational and Graphical Statistics, 1998. Vol 7, No. 4. + + Args: + chains_states: `Tensor` or Python `list` of `Tensor`s representing the + state(s) of a Markov Chain at each result step. The `ith` state is + assumed to have shape `[Ni, Ci1, Ci2,...,CiD] + A`. + Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain. + Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent + chains to be tested for convergence to the same target. + The remaining dimensions, `A`, can have any shape (even empty). + independent_chain_ndims: Integer type `Tensor` with value `>= 1` giving the + number of giving the number of dimensions, from `dim = 1` to `dim = D`, + holding independent chain results to be tested for convergence. + name: `String` name to prepend to created ops. Default: + `potential_scale_reduction`. + + Returns: + `Tensor` or Python `list` of `Tensor`s representing the R-hat statistic for + the state(s). Same `dtype` as `state`, and shape equal to + `state.shape[1 + independent_chain_ndims:]`. + + Raises: + ValueError: If `independent_chain_ndims < 1`. + """ + chains_states_was_list = _is_list_like(chains_states) + if not chains_states_was_list: + chains_states = [chains_states] + + # tensor_util.constant_value returns None iff a constant value (as a numpy + # array) is not efficiently computable. Therefore, we try constant_value then + # check for None. + icn_const_ = tensor_util.constant_value( + ops.convert_to_tensor(independent_chain_ndims)) + if icn_const_ is not None: + independent_chain_ndims = icn_const_ + if icn_const_ < 1: + raise ValueError( + "Argument `independent_chain_ndims` must be `>= 1`, found: {}".format( + independent_chain_ndims)) + + with ops.name_scope(name, "potential_scale_reduction"): + rhat_list = [ + _potential_scale_reduction_single_state(s, independent_chain_ndims) + for s in chains_states + ] + + if chains_states_was_list: + return rhat_list + return rhat_list[0] + + +def _potential_scale_reduction_single_state(state, independent_chain_ndims): + """potential_scale_reduction for one single state `Tensor`.""" + with ops.name_scope( + "potential_scale_reduction_single_state", + values=[state, independent_chain_ndims]): + # We assume exactly one leading dimension indexes e.g. correlated samples + # from each Markov chain. + state = ops.convert_to_tensor(state, name="state") + sample_ndims = 1 + + sample_axis = math_ops.range(0, sample_ndims) + chain_axis = math_ops.range(sample_ndims, + sample_ndims + independent_chain_ndims) + sample_and_chain_axis = math_ops.range( + 0, sample_ndims + independent_chain_ndims) + + n = _axis_size(state, sample_axis) + m = _axis_size(state, chain_axis) + + # In the language of [2], + # B / n is the between chain variance, the variance of the chain means. + # W is the within sequence variance, the mean of the chain variances. + b_div_n = _reduce_variance( + math_ops.reduce_mean(state, sample_axis, keepdims=True), + sample_and_chain_axis, + biased=False) + w = math_ops.reduce_mean( + _reduce_variance(state, sample_axis, keepdims=True, biased=True), + sample_and_chain_axis) + + # sigma^2_+ is an estimate of the true variance, which would be unbiased if + # each chain was drawn from the target. c.f. "law of total variance." + sigma_2_plus = w + b_div_n + + return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n) + + +# TODO(b/72873233) Move some variant of this to sample_stats. +def _reduce_variance(x, axis=None, biased=True, keepdims=False): + with ops.name_scope("reduce_variance"): + x = ops.convert_to_tensor(x, name="x") + mean = math_ops.reduce_mean(x, axis=axis, keepdims=True) + biased_var = math_ops.reduce_mean( + math_ops.squared_difference(x, mean), axis=axis, keepdims=keepdims) + if biased: + return biased_var + n = _axis_size(x, axis) + return (n / (n - 1.)) * biased_var + + +def _axis_size(x, axis=None): + """Get number of elements of `x` in `axis`, as type `x.dtype`.""" + if axis is None: + return math_ops.cast(array_ops.size(x), x.dtype) + return math_ops.cast( + math_ops.reduce_prod(array_ops.gather(array_ops.shape(x), axis)), x.dtype) + + +def _is_list_like(x): + """Helper which returns `True` if input is `list`-like.""" + return isinstance(x, (tuple, list)) + + +def _broadcast_maybelist_arg(states, secondary_arg, name): + """Broadcast a listable secondary_arg to that of states.""" + if _is_list_like(secondary_arg): + if len(secondary_arg) != len(states): + raise ValueError("Argument `%s` was a list of different length ({}) than " + "`states` ({})".format(name, len(states))) + else: + secondary_arg = [secondary_arg] * len(states) + + return secondary_arg diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eadf6f4d5fa1c776e2c71c66c4b64b8f5ac98359 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/variable_utils.py @@ -0,0 +1,29 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions related to managing `tf.Variable`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +from tensorflow.contrib.bayesflow.python.ops.variable_utils_impl import * # pylint: disable=wildcard-import,unused-wildcard-import,g-importing-member +from tensorflow.python.util import all_util + +_allowed_symbols = [ + "externalize_variables_as_args", +] + +all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..ca3d75b5bfee093449026c7d1d62e3bdeff6b096 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py @@ -0,0 +1,157 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions related to managing `tf.Variable`s. + +@@externalize_variables_as_args +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +from tensorflow.python.framework import ops +from tensorflow.python.ops import gradients_impl as gradients_ops +from tensorflow.python.ops import variable_scope as varscope_ops +from tensorflow.python.ops import variables as variables_ops + +__all__ = [ + "externalize_variables_as_args", +] + + +# Cause all warnings to always be triggered. +# Not having this means subsequent calls wont trigger the warning. +warnings.simplefilter("always") + + +def externalize_variables_as_args(fn, + fn_args=(), + ancestor_variables=None, + possible_ancestor_vars=None, + assert_variable_override=False, + name=None): + """"Converts variables within a callable into explicit args. + + Makes a new callable from `fn` which has arguments `list(fn_args) + + list(ancestor_variables)`. If `ancestor_variables` is not specified, it is + inferred by checking which of `possible_ancestor_vars` actually influences the + return value of `fn` (concretely, gradient of `fn(*fn_args)` is not `None`). + By default `possible_ancestor_vars` is `tf.trainable_variables() + + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)`. + + #### Examples: + + ```python + num_samples = 2 + num_dims = 1 + dtype = np.float32 + + def foo(x): + x = tf.convert_to_tensor(x, dtype=dtype, name="x") + s = x.shape.as_list() + y = tf.get_variable( + name="y", + dtype=dtype, + initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)) + return x + y + + x = tf.constant(dtype([0.1, 0.2])) + + wrapped_foo, discovered_ancestor_variables = ( + externalize_variables_as_args(foo, [x])) + + new_x = dtype([[1.], [2.]]) + new_y = dtype([[3.], [4.]]) + new_result = wrapped_foo(new_x, new_y) + # ==> [[4.], [6.]] + + discovered_ancestor_variables == [tf.get_variable("y", dtype)] + # ==> [True] + ``` + + Args: + fn: Python callable which returns a `Tensor` and accepts `*fn_args`. + fn_args: Python list of args to `fn`. Represents dummy arguments passed to + `fn` to trace its execution; actual values are unimportant. These args are + only used to construct the output of `fn` and to resolve the ancestor + `tf.Variable`s. + Default value: `()` (i.e., `fn` takes no args). + ancestor_variables: Python list of `tf.Variable`s. When `None` the list is + expanded to non-`None` gradients of `fn(*fn_args)`. By directly providing + the `ancestor_variables` the internal call to `fn` is avoided. + Default value: `None` (i.e., `tf.Variable` dependencies are discovered). + possible_ancestor_vars: Python list of possible `tf.Variable`s which might + be a dependency of computing `fn(*fn_args)`. + Default value: `None` (i.e., expanded as described above). + assert_variable_override: Python `bool` indicating that not finding a + `tf.Variable` in the override list is an exception. + Default value: `False` (i.e., missing a `Variable` triggers a `warning`). + name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "externalize_variables_as_args"). + + Returns: + wrapped_fn: Python callable taking arguments like + `*(list(fn_args) + discovered_ancestor_variables)`. + discovered_ancestor_variables: Python list of `tf.Variable`s known to be a + dependency of `fn(*fn_args)`. + + Raises: + ValueError: if `assert_variable_override` is `True` and `Variable` is + requested but not overridden. + """ + def _make_bypassing_custom_getter_fn(new_var_dict): + """Return dict value rather than what would otherwise be dict key.""" + def _custom_getter(getter, *args, **kwargs): + v = getter(*args, **kwargs) + new_v = new_var_dict.get(v, None) + if new_v is None: + msg = "Variable \"{}\" not found in bypass dict.".format(v) + if assert_variable_override: + raise ValueError(msg) + warnings.warn(msg) + return v + return new_v + return _custom_getter + + with ops.name_scope(name, "externalize_variables_as_args"): + if ancestor_variables is not None and not ancestor_variables: + return fn, () + if ancestor_variables is None: + y = fn(*fn_args) # Side-effect: adds trainable vars. + if possible_ancestor_vars is None: + possible_ancestor_vars = ( + variables_ops.trainable_variables() + + ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + # TODO(b/72873296): Add a dedicated op for identifying ancestors. + ancestors = [v for g, v + in zip(gradients_ops.gradients(y, possible_ancestor_vars), + possible_ancestor_vars) + if g is not None] + ancestor_variables = sorted(ancestors, key=lambda v: v.name) + n = len(fn_args) + def _fn(*args): + with ops.name_scope("wrapped_fn"): + vars_dict = dict( + (k, ops.convert_to_tensor( + v, dtype=k.dtype.base_dtype, name=k.op.name)) + for k, v in zip(ancestor_variables, args[n:])) + with varscope_ops.variable_scope( + varscope_ops.get_variable_scope(), + reuse=True, + custom_getter=_make_bypassing_custom_getter_fn(vars_dict)): + return fn(*args[:n]) + return _fn, ancestor_variables diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index 4b5d5ba0de6c3995ee2da7a44ab0ba099cbf1b35..754b7bc3270d647fc381033b769eadd7b791771e 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -48,8 +48,9 @@ class CreateTreeEnsembleVariableOp : public OpKernel { if (!result->InitFromSerialized(tree_ensemble_config_t->scalar()(), stamp_token)) { result->Unref(); - OP_REQUIRES(context, false, errors::InvalidArgument( - "Unable to parse tree ensemble config.")); + OP_REQUIRES( + context, false, + errors::InvalidArgument("Unable to parse tree ensemble config.")); } // Only create one, if one does not exist already. Report status for all diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index f8086b0c2bb93eae6af0336bbe33fc23f8fcde22..b3fe38614e05801b223f0c96f7a70ce7e432a70b 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -47,8 +47,8 @@ namespace boosted_trees { using boosted_trees::learner::LearnerConfig; using boosted_trees::learner::LearningRateConfig; using boosted_trees::learner::LearningRateDropoutDrivenConfig; -using boosted_trees::models::MultipleAdditiveTrees; using boosted_trees::models::DecisionTreeEnsembleResource; +using boosted_trees::models::MultipleAdditiveTrees; using boosted_trees::utils::DropoutUtils; using boosted_trees::utils::TensorUtils; diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 88f30064076d1b9410665e06ca27e20d14c6dde0..0f4c2298f56be48bb32f52d5d44cff8afe284f1e 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -36,13 +36,12 @@ namespace tensorflow { using ::boosted_trees::QuantileConfig; -using boosted_trees::utils::TensorUtils; using boosted_trees::QuantileStreamResource; +using boosted_trees::utils::TensorUtils; namespace { const char* const kExampleWeightsName = "example_weights"; const char* const kMaxElementsName = "max_elements"; -const char* const kHandleName = "handle"; const char* const kNextStampTokenName = "next_stamp_token"; const char* const kStampTokenName = "stamp_token"; const char* const kAreBucketsReadyName = "are_buckets_ready"; @@ -52,7 +51,6 @@ const char* const kNumSparseFeaturesName = "num_sparse_features"; const char* const kSparseBucketsName = "sparse_buckets"; const char* const kSparseValuesName = "sparse_values"; const char* const kSparseIndicesName = "sparse_indices"; -const char* const kSparseStreamsStateName = "sparse_streams_state"; const char* const kSparseSummariesName = "sparse_summaries"; const char* const kSparseConfigName = "sparse_config"; const char* const kSparseOutputTensorName = "sparse_quantiles"; @@ -60,7 +58,6 @@ const char* const kSparseOutputTensorName = "sparse_quantiles"; const char* const kDenseBucketsName = "dense_buckets"; const char* const kDenseConfigName = "dense_config"; const char* const kDenseOutputTensorName = "dense_quantiles"; -const char* const kDenseStreamsStateName = "dense_streams_state"; const char* const kDenseSummariesName = "dense_summaries"; const char* const kDenseValuesName = "dense_values"; const char* const kNumDenseFeaturesName = "num_dense_features"; @@ -387,7 +384,7 @@ class MakeQuantileSummariesOp : public OpKernel { protobuf::Arena arena; ::boosted_trees::QuantileSummaryState* summary_proto = protobuf::Arena::CreateMessage< - ::boosted_trees::QuantileSummaryState>(&arena); + ::boosted_trees::QuantileSummaryState>(&arena); const auto& summary = stream.GetFinalSummary(); CopySummaryToProto(summary, summary_proto); // Output to tensor. diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 18b4abd654ea3541d646a43ac901aca1a678446f..44a8ffaf4b2f5a9c11b3abc46ce55a18c80ad318 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -34,10 +34,10 @@ namespace tensorflow { +using boosted_trees::learner::LearnerConfig_MultiClassStrategy; using boosted_trees::learner::SplitInfo; using boosted_trees::learner::stochastic::GradientStats; using boosted_trees::learner::stochastic::NodeStats; -using boosted_trees::learner::LearnerConfig_MultiClassStrategy; namespace { const int32 DUMMY_FEATURE_DIMENSION = -1; @@ -47,9 +47,8 @@ class BaseBuildSplitOp : public OpKernel { public: explicit BaseBuildSplitOp(OpKernelConstruction* const context) : OpKernel(context) { - OP_REQUIRES_OK( - context, - context->GetAttr("feature_column_group_id", &feature_column_group_id_)); + OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id", + &feature_column_group_id_)); OP_REQUIRES_OK(context, context->GetAttr("l1_regularization", &l1_regularization_)); OP_REQUIRES_OK(context, diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index a9a229c8ae0c26bba5f0a684dad7e546298577bb..90a0655201f8cb8df6fc6417cb51216dec91b4d7 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -134,10 +134,9 @@ void SerializeScalarAccumulatorToOutput( OpKernelContext* context) { int64 num_slots = accumulator_resource.values().size(); Tensor* partition_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_partition_ids", TensorShape({num_slots}), - &partition_ids_t)); + OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", + TensorShape({num_slots}), + &partition_ids_t)); auto partition_ids = partition_ids_t->vec(); // Feature ids tensor has ids of feature columns and their dimensions. @@ -149,15 +148,14 @@ void SerializeScalarAccumulatorToOutput( Tensor* gradients_t = nullptr; OP_REQUIRES_OK( - context, - context->allocate_output("output_gradients", TensorShape({num_slots}), - &gradients_t)); + context, context->allocate_output( + "output_gradients", TensorShape({num_slots}), &gradients_t)); auto gradients = gradients_t->vec(); Tensor* hessians_t = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output( - "output_hessians", TensorShape({num_slots}), &hessians_t)); + OP_REQUIRES_OK( + context, context->allocate_output("output_hessians", + TensorShape({num_slots}), &hessians_t)); auto hessians = hessians_t->vec(); int i = 0; @@ -177,10 +175,9 @@ void SerializeTensorAccumulatorToOutput( OpKernelContext* context) { int64 num_slots = accumulator_resource.values().size(); Tensor* partition_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_partition_ids", TensorShape({num_slots}), - &partition_ids_t)); + OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", + TensorShape({num_slots}), + &partition_ids_t)); auto partition_ids = partition_ids_t->vec(); Tensor* feature_ids_t = nullptr; @@ -202,9 +199,8 @@ void SerializeTensorAccumulatorToOutput( int64 num_hessian_elements = hessian_shape.num_elements(); hessian_shape.InsertDim(0, num_slots); Tensor* hessians_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_hessians", hessian_shape, &hessians_t)); + OP_REQUIRES_OK(context, context->allocate_output("output_hessians", + hessian_shape, &hessians_t)); auto hessians = hessians_t->flat_outer_dims(); int i = 0; diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc index f867e77d3ef0609774628b2a9c36ca52bcf2a957..8bca132acfde9397942b198db9a8d4c0e4d74897 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc @@ -17,8 +17,8 @@ #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" -using tensorflow::test::AsTensor; using std::vector; +using tensorflow::test::AsTensor; namespace tensorflow { namespace boosted_trees { diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index 1c4181f1b13b01f85833157e554c3b821f96ff90..8ad97fedc923ac50bcaad86e0ba2c2e46df6821b 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -15,9 +15,9 @@ #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ +#include #include #include -#include #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h" #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h" diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc index cbe26ba918d384ad903fb854ca3e88e84d16a923..705b65e9db9f1aed9af1be153240d57e163c2d5b 100644 --- a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc +++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc @@ -22,9 +22,9 @@ namespace tensorflow { namespace boosted_trees { namespace testutil { +using boosted_trees::trees::DenseFloatBinarySplit; using tensorflow::boosted_trees::trees::DecisionTreeConfig; using tensorflow::boosted_trees::trees::TreeNode; -using boosted_trees::trees::DenseFloatBinarySplit; namespace { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc index 9de3e32b097a151b3bd6f5c30df2db0938b65e9c..609519e8b1153a27d987c5f9ca9bfcc9ee6717d6 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc @@ -25,8 +25,8 @@ namespace boosted_trees { namespace utils { namespace { -using test::AsTensor; using errors::InvalidArgument; +using test::AsTensor; class BatchFeaturesTest : public ::testing::Test {}; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc index 38f0151255bbf4fcd87f1d0d76fd111649ee4a12..db34db998a7442c69f2ab468f4557d991429f4ee 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc @@ -23,10 +23,10 @@ #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/platform/logging.h" +using tensorflow::Status; using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; using tensorflow::random::PhiloxRandom; using tensorflow::random::SimplePhilox; -using tensorflow::Status; namespace tensorflow { namespace boosted_trees { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc index ce7632e58987f5890beaded5dd305724f950e1e8..02f972c8e00e8229426ac53d8f20765484787b6e 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc @@ -26,9 +26,9 @@ #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" +using std::unordered_set; using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; using tensorflow::boosted_trees::trees::DecisionTreeEnsembleConfig; -using std::unordered_set; namespace tensorflow { namespace boosted_trees { diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index bb57dcf8ae7475486bcc0fc82460cbbce9a18b68..ae99d53a2cf805d70d60746cd44f73f7fd9dc6e2 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -19,8 +19,8 @@ namespace tensorflow { namespace boosted_trees { -using shape_inference::InferenceContext; using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; using shape_inference::ShapeHandle; REGISTER_RESOURCE_HANDLE_OP(QuantileStreamResource); diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 0d27ddaf3a1d540efee268c2bcca217077ff5871..5d0ebbf73ce1272b51a475f67984db3a181b7130 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -18,9 +18,9 @@ namespace tensorflow { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -using shape_inference::DimensionHandle; REGISTER_OP("BuildDenseInequalitySplits") .Attr("feature_column_group_id: int") diff --git a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc index 0354f7853cbedf22d0a299273b4dbd225b3121ab..179505eef01f79bb149137400468b84285fe478a 100644 --- a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc @@ -19,9 +19,9 @@ namespace tensorflow { namespace boosted_trees { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -using shape_inference::DimensionHandle; REGISTER_RESOURCE_HANDLE_OP(StatsAccumulatorScalarResource); diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index eefa7ef0dccf5e88099974302dd26eebe21b1bd2..81f58de28cbe98bb996c6665114eeb0030ee52f9 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -183,11 +183,10 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): self.assertEqual(num_quantiles + 1, len(buckets)) self.assertAllEqual([2030, 2040, 2050, 2060], buckets) - def _testStreamingQuantileBucketsHelper(self, inputs): + def _testStreamingQuantileBucketsHelper( + self, inputs, num_quantiles=3, expected_buckets=None): """Helper to test quantile buckets on different inputs.""" - # Use 3 quantiles, 4 boundaries for simplicity. - num_quantiles = 3 # set generate_quantiles to True since the test will generate fewer # boundaries otherwise. with self.test_session() as sess: @@ -213,7 +212,10 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): buckets, are_ready_flush = (sess.run( [buckets, are_ready_flush])) self.assertEqual(True, are_ready_flush) + # By default, use 3 quantiles, 4 boundaries for simplicity. self.assertEqual(num_quantiles + 1, len(buckets)) + if expected_buckets: + self.assertAllEqual(buckets, expected_buckets) def testStreamingQuantileBucketsRepeatedSingleValue(self): inputs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @@ -231,6 +233,28 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): inputs = [5] self._testStreamingQuantileBucketsHelper(inputs) + def testStreamingQuantileBucketsEqualDistributionInSequence(self): + # Input pattern is of the form [1, 1, 1, 2, 2, 2, 3, 3, 3, ...] + ones = 100 * [1] + inputs = [] + for i in range(1, 101): + inputs += [i * k for k in ones] + # Expect 100 equally spaced buckets. + expected_buckets = range(1, 101) + self._testStreamingQuantileBucketsHelper( + inputs, num_quantiles=99, expected_buckets=expected_buckets) + + def testStreamingQuantileBucketsEqualDistributionInterleaved(self): + # Input pattern is of the form [1, 2, 3, 1, 2, 3, 1, 2, 3, ...] + sequence = range(1, 101) + inputs = [] + for _ in range(1, 101): + inputs += sequence + # Expect 100 equally spaced buckets. + expected_buckets = range(1, 101) + self._testStreamingQuantileBucketsHelper( + inputs, num_quantiles=99, expected_buckets=expected_buckets) + def testStreamingQuantileBuckets(self): """Sets up the quantile summary op test as follows. diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index b281a4c6d1cab9bfa1dc4018c8f49a16f21f2a36..7a5f329b7ab3216972180ccbb4c85f2537175422 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -81,32 +81,32 @@ def _scheduled_stamp_resource_op_runner(batch, stamp): if not batch: return arg_keys = set(batch[0].args.keys()) - grouped_args = collections.defaultdict(list) + grouped_args = collections.OrderedDict() resource_handles = [] # Check that the set of arguments is the same across all the scheduled ops. for op in batch: if set(op.args.keys()) != arg_keys: raise ValueError("Mismatching arguments: %s, %s.", op.args, arg_keys) for key in arg_keys: - grouped_args[key].append(op.args[key]) + grouped_args.setdefault(key, []).append(op.args[key]) resource_handles.append(op.resource_handle) # Move all the inputs to the op device in one RPC. - grouped_args = { - k: _move_tensors(v, resource_handles[0].device) - for k, v in grouped_args.items() - } + grouped_args = collections.OrderedDict( + (k, _move_tensors(v, resource_handles[0].device)) + for k, v in sorted(grouped_args.items())) with ops.device(resource_handles[0].device): return batch[0].op(resource_handles, stamp, **grouped_args) def run_handler_scheduled_ops(per_handler_ops, stamp, worker_device): """Given a dictionary of ops for each handler, runs them in batch.""" - batched_ops = collections.defaultdict(list) + batched_ops = collections.OrderedDict() # Group the ops by their batching_key. Ops that share the same batching key # can be executed together. - for handler in sorted(per_handler_ops.keys()): + for handler in per_handler_ops.keys(): for op in per_handler_ops[handler]: - batched_ops[(op.batching_key(), op.batch_runner_fn())].append(op) + key = (op.batching_key(), op.batch_runner_fn()) + batched_ops.setdefault(key, []).append(op) op_results = {} for batch in batched_ops.values(): # Run each of the batched ops using its runner. diff --git a/tensorflow/contrib/boosted_trees/python/training/__init__.py b/tensorflow/contrib/boosted_trees/python/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b569ac5fdb60e0907c322ad73aca65645e548d94 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/python/training/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""training module under boosted_trees.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/__init__.py b/tensorflow/contrib/boosted_trees/python/training/functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1750117cd7c311515b4bca6882d55f496daac0e --- /dev/null +++ b/tensorflow/contrib/boosted_trees/python/training/functions/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""functions module under boosted_trees.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index b95956dae2a62b28643cd31815c5f5650eca337b..f0b66dcbbe1c5167b9993e66b30b1dc8a839c380 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import copy from tensorflow.contrib import learn @@ -163,7 +164,7 @@ def extract_features(features, feature_columns): scope = "gbdt" with variable_scope.variable_scope(scope): feature_columns = list(feature_columns) - transformed_features = {} + transformed_features = collections.OrderedDict() for fc in feature_columns: # pylint: disable=protected-access if isinstance(fc, feature_column_lib._EmbeddingColumn): @@ -681,13 +682,13 @@ class GradientBoostedDecisionTreeModel(object): control_flow_ops.no_op)) # Update handler stats. - handler_reads = {} + handler_reads = collections.OrderedDict() for handler in handlers: handler_reads[handler] = handler.scheduled_reads() handler_results = batch_ops_utils.run_handler_scheduled_ops( handler_reads, ensemble_stamp, worker_device) - per_handler_updates = {} + per_handler_updates = collections.OrderedDict() # Two values per handler. First one is if the handler is active for the # current layer. The second one is if the handler is going to be active # for the next layer. diff --git a/tensorflow/contrib/ndlstm/__init__.py b/tensorflow/contrib/boosted_trees/python/utils/__init__.py similarity index 87% rename from tensorflow/contrib/ndlstm/__init__.py rename to tensorflow/contrib/boosted_trees/python/utils/__init__.py index 52e83069cb0c68b510da46149248369dce376647..6ceb150c26552584d631948f5eef2fedfa690894 100644 --- a/tensorflow/contrib/ndlstm/__init__.py +++ b/tensorflow/contrib/boosted_trees/python/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - +"""utils module under boosted_trees.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index 1e8b3ac08a74a94a0e5729e42ace91398a7b5c94..ab7ac2aba605db22a8ed370049b27d55cf1d413a 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -78,7 +78,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): # Calculate softmax probabilities for each class. unnormalized_probs = math_ops.exp(logits) - normalizers = math_ops.reduce_sum(unnormalized_probs, 1, keep_dims=True) + normalizers = math_ops.reduce_sum(unnormalized_probs, 1, keepdims=True) softmax_predictions = math_ops.divide(unnormalized_probs, math_ops.add(normalizers, eps)) @@ -120,7 +120,7 @@ def per_example_squared_loss(labels, weights, predictions): update_op: An update operation to update the loss's internal state. """ unweighted_loss = math_ops.reduce_sum( - math_ops.square(predictions - labels), 1, keep_dims=True) + math_ops.square(predictions - labels), 1, keepdims=True) return unweighted_loss * weights, control_flow_ops.no_op() diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc index deb324634b6edc17c9725996115d80c5bd11cbde..1bfd27305d569668a0bd67d876e59eec082296b3 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" namespace tensorflow { - namespace { constexpr size_t kBufferSize = 1024 * 1024; // In bytes. @@ -40,33 +39,6 @@ Status ParseJson(StringPiece json, Json::Value* result) { return Status::OK(); } -string ColumnTypeToString(BigQueryTableAccessor::ColumnType enum_type) { - switch (enum_type) { - case BigQueryTableAccessor::ColumnType::kRecord: - return "RECORD"; - case BigQueryTableAccessor::ColumnType::kString: - return "STRING"; - case BigQueryTableAccessor::ColumnType::kBytes: - return "BYTES"; - case BigQueryTableAccessor::ColumnType::kInteger: - return "INTEGER"; - case BigQueryTableAccessor::ColumnType::kFloat: - return "FLOAT"; - case BigQueryTableAccessor::ColumnType::kBoolean: - return "BOOLEAN"; - case BigQueryTableAccessor::ColumnType::kTimestamp: - return "TIMESTAMP"; - case BigQueryTableAccessor::ColumnType::kDate: - return "DATE"; - case BigQueryTableAccessor::ColumnType::kTime: - return "TIME"; - case BigQueryTableAccessor::ColumnType::kDatetime: - return "DATETIME"; - case BigQueryTableAccessor::ColumnType::kNone: - return "NONE"; - } -} - Status ParseColumnType(const string& type, BigQueryTableAccessor::ColumnType* enum_type) { if (type == "RECORD") { diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h index 59f23332983e2328286d3b1b8b8c8fa228be991e..fea6b15640ded74432f35112bc5d5d68e641c9dc 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h @@ -399,6 +399,6 @@ const string kTestEmptyRow = R"({ }]}]})"; } // namespace -} // namepsace tensorflow +} // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 15abd2be0385eb776ff4f76484133efb6e34f076..80e18a43a71cc9d6c9e2ccf5836e50c6427a30f6 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -34,6 +34,7 @@ py_library( ":cluster_resolver_py", ":gce_cluster_resolver_py", ":tpu_cluster_resolver_py", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/cluster_resolver/__init__.py b/tensorflow/contrib/cluster_resolver/__init__.py index d17501e87e79158b1602ac6ddecc091bd86f2c2d..b4d8cd4a7cf42e910e7506dbeec8656a2cef62eb 100644 --- a/tensorflow/contrib/cluster_resolver/__init__.py +++ b/tensorflow/contrib/cluster_resolver/__init__.py @@ -26,3 +26,15 @@ from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver # pylint: enable=wildcard-import,unused-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'ClusterResolver', + 'SimpleClusterResolver', + 'UnionClusterResolver', + 'GceClusterResolver', + 'TPUClusterResolver', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 2e75ac226ea74e879edda5e03dff3d53c8a76569..a6a6e642e4e4c721b94821a70d55d6fe931347d6 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -143,7 +143,8 @@ class TPUClusterResolver(ClusterResolver): request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() - instance_url = '%s:%s' % (response['ipAddress'], response['port']) - worker_list.append(instance_url) + if 'health' in response and response['health'] == 'HEALTHY': + instance_url = '%s:%s' % (response['ipAddress'], response['port']) + worker_list.append(instance_url) return ClusterSpec({self._job_name: worker_list}) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 0c4730613af4ad9ca87deb6200ab4bb93d3f6a53..4fd34629cf74f90869c77b8cb098d3c585a49404 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -105,7 +105,8 @@ class TPUClusterResolverTest(test.TestCase): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } @@ -126,7 +127,8 @@ class TPUClusterResolverTest(test.TestCase): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } @@ -147,11 +149,13 @@ class TPUClusterResolverTest(test.TestCase): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' }, 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { 'ipAddress': '10.4.5.6', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } @@ -169,15 +173,54 @@ class TPUClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testHealthyTpuNodeRetrieval(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'health': 'HEALTHY' + }, + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { + 'ipAddress': '10.4.5.6', + 'port': '8470', + }, + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-3': { + 'ipAddress': '10.7.8.9', + 'port': '8470', + 'health': 'UNHEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu_names=['test-tpu-2', 'test-tpu-1', 'test-tpu-3'], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'tpu_worker' + tasks { + key: 0 + value: '10.1.2.3:8470' + } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testGetMasterMultipleEntries(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' }, 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { 'ipAddress': '10.4.5.6', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 817e96f5da0e7512a9fd99cc9a4b4c6025d7dd68..23b31ae1dcc83d8a7152354ac147de9ada320429 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -52,6 +52,7 @@ if (NOT WIN32) # for targets that link ${CMAKE_THREAD_LIBS_INIT}. find_package (Threads) + # Options for linking CUDA/CUDNN libraries option(tensorflow_PATH_STATIC_LIB "Additional library search path for libcudnn_static.a, libnccl_static.a, libculibos.a" /usr/local/cuda/lib64/) option(tensorflow_CUDNN_INCLUDE "cudnn.h header install path" /usr/include/) if (NOT tensorflow_CUDNN_INCLUDE) @@ -73,6 +74,14 @@ if (NOT WIN32) # option's default value is OFF. Fill it with real default values set(tensorflow_CUDA_LIBRARY_PATH /usr/local/cuda/lib64) endif (NOT tensorflow_CUDA_LIBRARY_PATH) + + # Options for linking other libraries + option(systemlib_ZLIB "Use the system installed library as shared objects instead of downloading ZLIB and statically linking to it: ZLIB" OFF) + + option(systemlib_ALL "Turn on every possible systemlib_* options" OFF) + if (systemlib_ALL) + set (systmelib_ZLIB ON) + endif (systemlib_ALL) endif() if (WIN32) @@ -134,6 +143,9 @@ if(WIN32) set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /D_ITERATOR_DEBUG_LEVEL=0") set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /D_ITERATOR_DEBUG_LEVEL=0") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /D_ITERATOR_DEBUG_LEVEL=0") + + # Try to avoid flaky failures due to failed generation of generate.stamp files. + set(CMAKE_SUPPRESS_REGENERATION ON) endif() if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") @@ -185,8 +197,10 @@ if (tensorflow_BUILD_CC_TESTS) include(googletest) endif() +add_definitions(${ADD_CFLAGS}) +link_directories(${ADD_LINK_DIRECTORY}) + set(tensorflow_EXTERNAL_LIBRARIES - ${zlib_STATIC_LIBRARIES} ${gif_STATIC_LIBRARIES} ${png_STATIC_LIBRARIES} ${jpeg_STATIC_LIBRARIES} @@ -200,6 +214,15 @@ set(tensorflow_EXTERNAL_LIBRARIES ${re2_STATIC_LIBRARIES} ${sqlite_STATIC_LIBRARIES} ) + +if (systemlib_ZLIB) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${ZLIB_LIBRARIES}) +else (systemlib_ZLIB) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${zlib_STATIC_LIBRARIES}) +endif (systemlib_ZLIB) + set(tensorflow_EXTERNAL_DEPENDENCIES zlib_copy_headers_to_destination gif_copy_headers_to_destination @@ -283,7 +306,21 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED) + # later command will make use of the value in tensorflow_CUDA_VERSION + find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED EXACT) + + # Test compatibility of compiler on CUDA + try_compile(CUDA_TEST_COMPILE_C + ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda + ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.c + CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) + try_compile(CUDA_TEST_COMPILE_CXX + ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda + ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.cc + CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) + if(NOT (CUDA_TEST_COMPILE_C AND CUDA_TEST_COMPILE_CXX)) + message(FATAL_ERROR "Selected compiler (or version) is not supported for CUDA") + endif() # by default we assume compute cabability 3.5 and 5.2. If you change this change it in # CUDA_NVCC_FLAGS and cuda_config.h below @@ -304,7 +341,8 @@ if (tensorflow_ENABLE_GPU) if(NOT CUDNN_HOME) set(CUDNN_HOME ${CUDA_TOOLKIT_TARGET_DIR}) endif(NOT CUDNN_HOME) - include_directories(${CUDNN_HOME}) + set(CUDNN_INCLUDE "${CUDNN_HOME}/include") + set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib) else (WIN32) @@ -332,10 +370,10 @@ if (tensorflow_ENABLE_GPU) message("culibos-static: ${culibos_STATIC_LIBRARY}") endif (NOT culibos_STATIC_LIBRARY) - include_directories(${CUDNN_INCLUDE}) set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) endif (WIN32) + include_directories(${CUDNN_INCLUDE}) # Remove "." from CUDA version variable. string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) @@ -351,29 +389,22 @@ if (tensorflow_ENABLE_GPU) "#endif // CUDA_CUDA_CONFIG_H_\n" ) - if (WIN32) - # tf assumes in various places header files to be in cuda/include. On windows the cuda sdk - # installs them under cuda/version/include and to avoid that we need to change tf we copy a - # few files to cuda/include - FILE(COPY - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDNN_HOME}/include/cudnn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h - DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include - ) - else(WIN32) - # Linux has slightly differnt install paths than Windows - FILE(COPY - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDNN_INCLUDE}/cudnn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h - DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include - ) - endif(WIN32) + # tf assumes in various places header files to be in cuda/include. On windows the cuda sdk + # installs them under cuda/version/include and to avoid that we need to change tf we copy a + # few files to cuda/include + FILE(COPY + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h + ${CUDNN_INCLUDE}/cudnn.h + DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include + ) include_directories(${tensorflow_source_dir}/third_party/gpus) # add cuda libraries to tensorflow_EXTERNAL_LIBRARIES diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake index 5ad477fdff68feab4adf0c0072c68c8e55390ab8..3c4bb01e24fd121c9d0fc3594cc25de37af0e8a1 100644 --- a/tensorflow/contrib/cmake/external/boringssl.cmake +++ b/tensorflow/contrib/cmake/external/boringssl.cmake @@ -37,6 +37,7 @@ ExternalProject_Add(boringssl GIT_TAG ${boringssl_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" # BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${boringssl_STATIC_LIBRARIES} INSTALL_COMMAND "" CMAKE_CACHE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} diff --git a/tensorflow/contrib/cmake/external/farmhash.cmake b/tensorflow/contrib/cmake/external/farmhash.cmake index 0cd0c1030c73d5218411f281d2b077af217e8275..d51569bc213f2bd354571a00910714e787120951 100644 --- a/tensorflow/contrib/cmake/external/farmhash.cmake +++ b/tensorflow/contrib/cmake/external/farmhash.cmake @@ -33,6 +33,7 @@ if(WIN32) URL_HASH ${farmhash_HASH} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${farmhash_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/farmhash/CMakeLists.txt ${farmhash_BUILD} INSTALL_DIR ${farmhash_INSTALL} CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/fft2d.cmake b/tensorflow/contrib/cmake/external/fft2d.cmake index d3af2a46761c0f7f0b5db134af8400fc93f2f095..a7bc50d5bcd4384d5c943d681fd7cd6fa1ffa796 100644 --- a/tensorflow/contrib/cmake/external/fft2d.cmake +++ b/tensorflow/contrib/cmake/external/fft2d.cmake @@ -29,6 +29,7 @@ if(WIN32) URL_HASH ${fft2d_HASH} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${fft2d_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/fft2d/CMakeLists.txt ${fft2d_BUILD}/src/fft2d/CMakeLists.txt INSTALL_DIR ${fft2d_INSTALL} CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/gif.cmake b/tensorflow/contrib/cmake/external/gif.cmake index 3d53c51fffcec1602a3b5553cdf3b225e3b0ae46..e1f8d13f8ea47b83e4a1840afac7398ef226eb45 100644 --- a/tensorflow/contrib/cmake/external/gif.cmake +++ b/tensorflow/contrib/cmake/external/gif.cmake @@ -33,6 +33,7 @@ if(WIN32) PREFIX gif URL ${gif_URL} URL_HASH ${gif_HASH} + BUILD_BYPRODUCTS ${gif_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_SOURCE_DIR}/patches/gif/CMakeLists.txt ${gif_BUILD} INSTALL_DIR ${gif_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" diff --git a/tensorflow/contrib/cmake/external/googletest.cmake b/tensorflow/contrib/cmake/external/googletest.cmake index d09bb02890f25a0312e62c876c1729e57a059e82..7cc5ae6390934773635cf7a4dff77a3cbfb41ba1 100644 --- a/tensorflow/contrib/cmake/external/googletest.cmake +++ b/tensorflow/contrib/cmake/external/googletest.cmake @@ -20,8 +20,13 @@ set(googletest_BUILD ${CMAKE_CURRENT_BINARY_DIR}/googletest/) set(googletest_TAG ec44c6c1675c25b9827aacd08c02433cccde7780) if(WIN32) - set(googletest_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/$(Configuration)/gtest.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(googletest_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/$(Configuration)/gtest.lib) + else() + set(googletest_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/gtest.lib) + endif() else() set(googletest_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/${CMAKE_BUILD_TYPE}/gtest.a) @@ -33,6 +38,7 @@ ExternalProject_Add(googletest GIT_TAG ${googletest_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${googletest_STATIC_LIBRARIES} #PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_SOURCE_DIR}/patches/grpc/CMakeLists.txt ${GRPC_BUILD} INSTALL_COMMAND "" CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 28adb4fe84423bb5a21c78dac4e757505ce87d1d..a9f43a3ecba4830533efcc13f8c4c1c61fe1ef78 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -20,10 +20,17 @@ set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) set(GRPC_TAG 730b778632e79cc3c96ad237f282d687ee325ce7) if(WIN32) - set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/gpr.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(grpc_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/gpr.lib) + else() + set(grpc_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/grpc++_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/grpc_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib) + endif() else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a @@ -40,6 +47,7 @@ ExternalProject_Add(grpc GIT_TAG ${GRPC_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES} BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin INSTALL_COMMAND "" diff --git a/tensorflow/contrib/cmake/external/highwayhash.cmake b/tensorflow/contrib/cmake/external/highwayhash.cmake index 2c23bef8a331de356c93dbf9d0e91d8bb13bd6c8..a6e8a38d8c2ee3deb5453c264e0c5eb23248301f 100644 --- a/tensorflow/contrib/cmake/external/highwayhash.cmake +++ b/tensorflow/contrib/cmake/external/highwayhash.cmake @@ -42,6 +42,7 @@ ExternalProject_Add(highwayhash GIT_TAG ${highwayhash_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${highwayhash_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/highwayhash/CMakeLists.txt ${highwayhash_BUILD} INSTALL_DIR ${highwayhash_INSTALL} CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/jemalloc.cmake b/tensorflow/contrib/cmake/external/jemalloc.cmake index 198ba13e64e4b6df57c4325a0104b1a6745d173a..afadcc007d66414be3306e91e7186a00b6e587ce 100644 --- a/tensorflow/contrib/cmake/external/jemalloc.cmake +++ b/tensorflow/contrib/cmake/external/jemalloc.cmake @@ -24,8 +24,11 @@ if (WIN32) ${jemalloc_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include/msvc_compat ) - set(jemalloc_ADDITIONAL_CMAKE_OPTIONS -A x64) - set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.lib) + else() + set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/jemalloc.lib) + endif() else() set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.a) endif() @@ -36,12 +39,12 @@ ExternalProject_Add(jemalloc URL_HASH ${jemalloc_HASH} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 - CONFIGURE_COMMAND ${CMAKE_COMMAND} + BUILD_BYPRODUCTS ${jemalloc_STATIC_LIBRARIES} + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target jemalloc + INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step." + CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -Dwith-jemalloc-prefix:STRING=jemalloc_ -Dwithout-export:BOOL=ON - ${jemalloc_ADDITIONAL_CMAKE_OPTIONS} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target jemalloc - INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step." ) diff --git a/tensorflow/contrib/cmake/external/jpeg.cmake b/tensorflow/contrib/cmake/external/jpeg.cmake index d9a165e856c588880ebdf996666d70c9e7f53da8..c1c5842aa4454f1c95ec284392194a89d47ee8d5 100644 --- a/tensorflow/contrib/cmake/external/jpeg.cmake +++ b/tensorflow/contrib/cmake/external/jpeg.cmake @@ -46,6 +46,7 @@ if (WIN32) PREFIX jpeg URL ${jpeg_URL} URL_HASH ${jpeg_HASH} + BUILD_BYPRODUCTS ${jpeg_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/jpeg/CMakeLists.txt ${jpeg_BUILD} INSTALL_DIR ${jpeg_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" diff --git a/tensorflow/contrib/cmake/external/jsoncpp.cmake b/tensorflow/contrib/cmake/external/jsoncpp.cmake index 861201f97edbce2d9d70a833ce5a8cad46f2470a..84c52e3652ff935c287d32c0c80fd407e1213f29 100644 --- a/tensorflow/contrib/cmake/external/jsoncpp.cmake +++ b/tensorflow/contrib/cmake/external/jsoncpp.cmake @@ -23,7 +23,11 @@ set(jsoncpp_LIBRARIES ${jsoncpp_BUILD}/obj/so/libjsoncpp.so) set(jsoncpp_INCLUDES ${jsoncpp_BUILD}) if(WIN32) - set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/$(Configuration)/jsoncpp.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/$(Configuration)/jsoncpp.lib) + else() + set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/jsoncpp.lib) + endif() else() set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/libjsoncpp.a) endif() @@ -40,6 +44,7 @@ ExternalProject_Add(jsoncpp GIT_TAG ${jsoncpp_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${jsoncpp_STATIC_LIBRARIES} INSTALL_COMMAND "" CMAKE_CACHE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} diff --git a/tensorflow/contrib/cmake/external/lmdb.cmake b/tensorflow/contrib/cmake/external/lmdb.cmake index 41b314e2857577581eb27eb6c6480b757d0b436c..ed5ab788acc5625b9c8020fce15f027d98433096 100644 --- a/tensorflow/contrib/cmake/external/lmdb.cmake +++ b/tensorflow/contrib/cmake/external/lmdb.cmake @@ -20,10 +20,17 @@ set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb) set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install) +if(WIN32) + set(lmdb_STATIC_LIBRARIES ${lmdb_INSTALL}/lib/lmdb.lib) +else() + set(lmdb_STATIC_LIBRARIES ${lmdb_INSTALL}/lib/liblmdb.a) +endif() + ExternalProject_Add(lmdb PREFIX lmdb URL ${lmdb_URL} URL_HASH ${lmdb_HASH} + BUILD_BYPRODUCTS ${lmdb_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/lmdb/CMakeLists.txt ${lmdb_BUILD} INSTALL_DIR ${lmdb_INSTALL} @@ -35,12 +42,6 @@ ExternalProject_Add(lmdb -DCMAKE_INSTALL_PREFIX:STRING=${lmdb_INSTALL} ) -if(WIN32) - set(lmdb_STATIC_LIBRARIES ${lmdb_INSTALL}/lib/lmdb.lib) -else() - set(lmdb_STATIC_LIBRARIES ${lmdb_INSTALL}/lib/liblmdb.a) -endif() - set(lmdb_HEADERS "${lmdb_INSTALL}/include/lmdb.h" "${lmdb_INSTALL}/include/midl.h" diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 05080060479b6240edb8ab9f65160b3dd182feb9..f3a37ff5088e3f9e54e38c0edb5777c27b26969f 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -42,6 +42,7 @@ ExternalProject_Add(nsync GIT_TAG ${nsync_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${nsync_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/nsync/CMakeLists.txt ${nsync_BUILD} INSTALL_DIR ${nsync_INSTALL} CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index b277be5690387b06876ca89eb88becbf885486a4..6cd66a65990e7a2b963b52b310061b551752cd4d 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -21,9 +21,19 @@ set(png_BUILD ${CMAKE_BINARY_DIR}/png/src/png) set(png_INSTALL ${CMAKE_BINARY_DIR}/png/install) if(WIN32) - set(png_STATIC_LIBRARIES - debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib - optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(png_STATIC_LIBRARIES + debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib + optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + else() + if(CMAKE_BUILD_TYPE EQUAL Debug) + set(png_STATIC_LIBRARIES + ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib) + else() + set(png_STATIC_LIBRARIES + ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + endif() + endif() else() set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng12.a) endif() @@ -38,6 +48,7 @@ ExternalProject_Add(png DEPENDS zlib URL ${png_URL} URL_HASH ${png_HASH} + BUILD_BYPRODUCTS ${png_STATIC_LIBRARIES} INSTALL_DIR ${png_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index aedb793d2aef4bf6950cd074cd065909667eaf75..aba8a5244e17d717293deec6d9b6e8e725ef010e 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,14 +16,37 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) +set(PROTOBUF_TAG 396336eb961b75f03b25824fe86cf6490fb75e3a) if(WIN32) - set(protobuf_STATIC_LIBRARIES - debug ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/libprotobufd.lib - optimized ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/libprotobuf.lib) - set(PROTOBUF_PROTOC_EXECUTABLE ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/protoc.exe) - set(PROTOBUF_ADDITIONAL_CMAKE_OPTIONS -Dprotobuf_MSVC_STATIC_RUNTIME:BOOL=OFF -A x64) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(protobuf_STATIC_LIBRARIES + debug ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/libprotobufd.lib + optimized ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/libprotobuf.lib) + set(PROTOBUF_PROTOC_EXECUTABLE ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/protoc.exe) + else() + if(CMAKE_BUILD_TYPE EQUAL Debug) + set(protobuf_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/libprotobufd.lib) + else() + set(protobuf_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/libprotobuf.lib) + endif() + set(PROTOBUF_PROTOC_EXECUTABLE ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/protoc.exe) + endif() + + # This section is to make sure CONFIGURE_COMMAND use the same generator settings + set(PROTOBUF_GENERATOR_PLATFORM) + if (CMAKE_GENERATOR_PLATFORM) + set(PROTOBUF_GENERATOR_PLATFORM -A ${CMAKE_GENERATOR_PLATFORM}) + endif() + set(PROTOBUF_GENERATOR_TOOLSET) + if (CMAKE_GENERATOR_TOOLSET) + set(PROTOBUF_GENERATOR_TOOLSET -T ${CMAKE_GENERATOR_TOOLSET}) + endif() + set(PROTOBUF_ADDITIONAL_CMAKE_OPTIONS -Dprotobuf_MSVC_STATIC_RUNTIME:BOOL=OFF + -G${CMAKE_GENERATOR} ${PROTOBUF_GENERATOR_PLATFORM} ${PROTOBUF_GENERATOR_TOOLSET}) + # End of section else() set(protobuf_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/libprotobuf.a) set(PROTOBUF_PROTOC_EXECUTABLE ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/protoc) @@ -36,10 +59,15 @@ ExternalProject_Add(protobuf GIT_TAG ${PROTOBUF_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${PROTOBUF_PROTOC_EXECUTABLE} ${protobuf_STATIC_LIBRARIES} SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf + # SOURCE_SUBDIR cmake/ # Requires CMake 3.7, this will allow removal of CONFIGURE_COMMAND + # CONFIGURE_COMMAND resets some settings made in CMAKE_CACHE_ARGS and the generator used CONFIGURE_COMMAND ${CMAKE_COMMAND} cmake/ - -Dprotobuf_BUILD_TESTS=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -Dprotobuf_BUILD_TESTS:BOOL=OFF -DZLIB_ROOT=${ZLIB_INSTALL} ${PROTOBUF_ADDITIONAL_CMAKE_OPTIONS} INSTALL_COMMAND "" @@ -47,5 +75,7 @@ ExternalProject_Add(protobuf -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -Dprotobuf_BUILD_TESTS:BOOL=OFF + -Dprotobuf_MSVC_STATIC_RUNTIME:BOOL=OFF -DZLIB_ROOT:STRING=${ZLIB_INSTALL} ) diff --git a/tensorflow/contrib/cmake/external/re2.cmake b/tensorflow/contrib/cmake/external/re2.cmake index 371d8447f93735e7af2a5a2b16f128a47b5a082a..c4bc0b1707bf9e86ea41234c8155fd6321c4c33b 100644 --- a/tensorflow/contrib/cmake/external/re2.cmake +++ b/tensorflow/contrib/cmake/external/re2.cmake @@ -21,7 +21,11 @@ set(re2_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/re2/install) set(re2_TAG e7efc48) if(WIN32) - set(re2_STATIC_LIBRARIES ${re2_BUILD}/$(Configuration)/re2.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(re2_STATIC_LIBRARIES ${re2_BUILD}/$(Configuration)/re2.lib) + else() + set(re2_STATIC_LIBRARIES ${re2_BUILD}/re2.lib) + endif() else() set(re2_STATIC_LIBRARIES ${re2_BUILD}/libre2.a) endif() @@ -36,6 +40,7 @@ ExternalProject_Add(re2 GIT_TAG ${re2_TAG} INSTALL_DIR ${re2_INSTALL} BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${re2_STATIC_LIBRARIES} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake index fd57734298affda13fa90f4cff560eeeb08e59ab..f54197643b06781dad35b40f526f28d301047299 100644 --- a/tensorflow/contrib/cmake/external/snappy.cmake +++ b/tensorflow/contrib/cmake/external/snappy.cmake @@ -20,7 +20,11 @@ set(snappy_BUILD ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy) set(snappy_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy) if(WIN32) - set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/$(Configuration)/snappy.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/$(Configuration)/snappy.lib) + else() + set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/snappy.lib) + endif() else() set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/libsnappy.a) endif() @@ -35,6 +39,7 @@ ExternalProject_Add(snappy GIT_TAG ${snappy_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${snappy_STATIC_LIBRARIES} INSTALL_COMMAND "" LOG_DOWNLOAD ON LOG_CONFIGURE ON diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake index 8297c60712c49ed6f47a9750691eee1325a5b55e..57c4ae76517e4d7247093edd5e5bd95a83258d87 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -36,6 +36,7 @@ if (WIN32) PREFIX sqlite URL ${sqlite_URL} URL_HASH ${sqlite_HASH} + BUILD_BYPRODUCTS ${sqlite_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/sqlite/CMakeLists.txt ${sqlite_BUILD} INSTALL_DIR ${sqlite_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake index 5bec14fb00a50f6e6e8c7d8b703bde681e9d02ae..116d42309394b92407cef79c9d3a975f494bc3ff 100644 --- a/tensorflow/contrib/cmake/external/zlib.cmake +++ b/tensorflow/contrib/cmake/external/zlib.cmake @@ -12,50 +12,75 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -include (ExternalProject) - -set(zlib_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/zlib_archive) -set(ZLIB_URL https://github.com/madler/zlib) -set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib) -set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install) -set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d) - -if(WIN32) - set(zlib_STATIC_LIBRARIES - debug ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib - optimized ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib) -else() - set(zlib_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/libz.a) -endif() - -set(ZLIB_HEADERS - "${ZLIB_INSTALL}/include/zconf.h" - "${ZLIB_INSTALL}/include/zlib.h" -) - -ExternalProject_Add(zlib - PREFIX zlib - GIT_REPOSITORY ${ZLIB_URL} - GIT_TAG ${ZLIB_TAG} - INSTALL_DIR ${ZLIB_INSTALL} - BUILD_IN_SOURCE 1 - DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" - CMAKE_CACHE_ARGS - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} - -DCMAKE_BUILD_TYPE:STRING=Release - -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL} -) - -# put zlib includes in the directory where they are expected -add_custom_target(zlib_create_destination_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${zlib_INCLUDE_DIR} - DEPENDS zlib) - -add_custom_target(zlib_copy_headers_to_destination - DEPENDS zlib_create_destination_dir) - -foreach(header_file ${ZLIB_HEADERS}) - add_custom_command(TARGET zlib_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${zlib_INCLUDE_DIR}) -endforeach() +if (systemlib_ZLIB) + find_package(PkgConfig) + pkg_search_module(ZLIB REQUIRED zlib) + set(zlib_INCLUDE_DIR ${ZLIB_INCLUDE_DIRS}) + set(ADD_LINK_DIRECTORY ${ADD_LINK_DIRECTORY} ${ZLIB_LIBRARY_DIRS}) + set(ADD_CFLAGS ${ADD_CFLAGS} ${ZLIB_CFLAGS_OTHER}) + + # To meet DEPENDS zlib from other projects. + # If we hit this line, zlib is already built and installed to the system. + add_custom_target(zlib) + add_custom_target(zlib_copy_headers_to_destination) + +else (systemlib_ZLIB) + include (ExternalProject) + + set(zlib_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/zlib_archive) + set(ZLIB_URL https://github.com/madler/zlib) + set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib) + set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install) + set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d) + + if(WIN32) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(zlib_STATIC_LIBRARIES + debug ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib + optimized ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib) + else() + if(CMAKE_BUILD_TYPE EQUAL Debug) + set(zlib_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib) + else() + set(zlib_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib) + endif() + endif() + else() + set(zlib_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/libz.a) + endif() + + set(ZLIB_HEADERS + "${ZLIB_INSTALL}/include/zconf.h" + "${ZLIB_INSTALL}/include/zlib.h" + ) + + ExternalProject_Add(zlib + PREFIX zlib + GIT_REPOSITORY ${ZLIB_URL} + GIT_TAG ${ZLIB_TAG} + INSTALL_DIR ${ZLIB_INSTALL} + BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${zlib_STATIC_LIBRARIES} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + CMAKE_CACHE_ARGS + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL} + ) + + # put zlib includes in the directory where they are expected + add_custom_target(zlib_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${zlib_INCLUDE_DIR} + DEPENDS zlib) + + add_custom_target(zlib_copy_headers_to_destination + DEPENDS zlib_create_destination_dir) + + foreach(header_file ${ZLIB_HEADERS}) + add_custom_command(TARGET zlib_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${zlib_INCLUDE_DIR}) + endforeach() +endif (systemlib_ZLIB) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 7db454bd83ec7fee463b8cd448f5a5ff4ba73258..bfe53c01b3b5fb9db8a5d8fa280d1d7f98974882 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -6,6 +6,7 @@ tensorflow/core/example tensorflow/core/framework tensorflow/core/lib tensorflow/core/lib/core +tensorflow/core/profiler tensorflow/core/protobuf tensorflow/core/util tensorflow/examples @@ -33,9 +34,11 @@ tensorflow/python/grappler tensorflow/python/keras tensorflow/python/keras/activations tensorflow/python/keras/applications +tensorflow/python/keras/applications/densenet tensorflow/python/keras/applications/inception_resnet_v2 tensorflow/python/keras/applications/inception_v3 tensorflow/python/keras/applications/mobilenet +tensorflow/python/keras/applications/nasnet tensorflow/python/keras/applications/resnet50 tensorflow/python/keras/applications/vgg16 tensorflow/python/keras/applications/vgg19 @@ -172,6 +175,9 @@ tensorflow/contrib/factorization/kernels tensorflow/contrib/factorization/ops tensorflow/contrib/factorization/python tensorflow/contrib/factorization/python/ops +tensorflow/contrib/feature_column +tensorflow/contrib/feature_column/python +tensorflow/contrib/feature_column/python/feature_column tensorflow/contrib/ffmpeg tensorflow/contrib/ffmpeg/default tensorflow/contrib/framework @@ -214,6 +220,8 @@ tensorflow/contrib/input_pipeline/python/ops tensorflow/contrib/integrate tensorflow/contrib/integrate/python tensorflow/contrib/integrate/python/ops +tensorflow/contrib/kafka/python +tensorflow/contrib/kafka/python/ops tensorflow/contrib/keras tensorflow/contrib/keras/api tensorflow/contrib/keras/api/keras @@ -291,7 +299,9 @@ tensorflow/contrib/linear_optimizer/kernels/g3doc tensorflow/contrib/linear_optimizer/python tensorflow/contrib/linear_optimizer/python/ops # TODO(drpngx): Fix failing imports +# tensorflow/contrib/lite # tensorflow/contrib/lite/python +# tensorflow/contrib/lite/toco # tensorflow/contrib/lite/toco/python tensorflow/contrib/lookup tensorflow/contrib/losses @@ -321,8 +331,6 @@ tensorflow/contrib/nccl/kernels tensorflow/contrib/nccl/ops tensorflow/contrib/nccl/python tensorflow/contrib/nccl/python/ops -tensorflow/contrib/ndlstm -tensorflow/contrib/ndlstm/python tensorflow/contrib/nearest_neighbor/kernels tensorflow/contrib/nearest_neighbor/ops tensorflow/contrib/nearest_neighbor/python @@ -354,6 +362,7 @@ tensorflow/contrib/reduce_slice_ops/kernels tensorflow/contrib/reduce_slice_ops/ops tensorflow/contrib/reduce_slice_ops/python tensorflow/contrib/reduce_slice_ops/python/ops +tensorflow/contrib/remote_fused_graph tensorflow/contrib/remote_fused_graph/pylib tensorflow/contrib/remote_fused_graph/pylib/python tensorflow/contrib/remote_fused_graph/pylib/python/ops @@ -403,6 +412,10 @@ tensorflow/contrib/summary tensorflow/contrib/tensorboard tensorflow/contrib/tensorboard/plugins tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensorboard/plugins/trace +# TODO(sami): Add cmake implementations. +# tensorflow/contrib/tensorrt/python +# tensorflow/contrib/tensorrt/python/ops tensorflow/contrib/tensor_forest tensorflow/contrib/tensor_forest/client tensorflow/contrib/tensor_forest/hybrid @@ -413,6 +426,7 @@ tensorflow/contrib/tensor_forest/hybrid/python/layers tensorflow/contrib/tensor_forest/hybrid/python/models tensorflow/contrib/tensor_forest/hybrid/python/ops tensorflow/contrib/tensor_forest/kernels +tensorflow/contrib/tensor_forest/proto tensorflow/contrib/tensor_forest/python tensorflow/contrib/tensor_forest/python/ops tensorflow/contrib/testing @@ -433,6 +447,7 @@ tensorflow/contrib/timeseries/python/timeseries/state_space_models tensorflow/contrib/tpu tensorflow/contrib/tpu/ops tensorflow/contrib/tpu/profiler +tensorflow/contrib/tpu/proto tensorflow/contrib/tpu/python tensorflow/contrib/tpu/python/ops tensorflow/contrib/tpu/python/profiler diff --git a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c new file mode 100644 index 0000000000000000000000000000000000000000..9e355da33a7258119b6086216f5487d7ea94716c --- /dev/null +++ b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is a program to test if compiler is compatible with CUDA. +#define __CUDACC__ +#include "crt/host_config.h" + +int main(void) { + return 0; +} diff --git a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..beb574061bea8d04af8386223749677ae36a5d9b --- /dev/null +++ b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================*/ + +// This is a program to test if compiler is compatible with CUDA. +#define __CUDACC__ +#include "crt/host_config.h" + +int main(void) { + return 0; +} diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index f3cf3e70441de67ef79bc9cedf85549315170c29..f73da0b8ab18af1eca4c2bd577604595f8b8ec6d 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -149,7 +149,11 @@ add_library(tf_cc OBJECT ${tf_cc_srcs}) add_dependencies(tf_cc tf_cc_framework tf_cc_ops) if (WIN32) - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.lib") + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.lib") + else() + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib") + endif() else (WIN32) set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so") endif (WIN32) diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index e4213ea2a47da2a7381cccd0504235ad62018d4e..96ac60d095dbc84470ff1be92f4bf52bb420fc52 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -50,6 +50,12 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.h" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.cc" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.h" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 129c208ecd6b574ed63c2fe378e1a6ebb92de558..a1c320347fe60f87806736befc677541a93e7e93 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -292,6 +292,12 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.h" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.cc" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.h" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc" "${tensorflow_source_dir}/tensorflow/core/util/*.h" diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 6927bf03f08b68a1f13f6a0978af629af45575e8..f219d5eb577afa9edaadca09aef9869c81d2bd87 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -69,8 +69,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 138993db35252d3f1ab6326dff463bdc10cabdb1..59e094812aaf4da2549d96314fc550e5635f9de8 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -30,6 +30,7 @@ set(tf_op_lib_names "list_ops" "lookup_ops" "logging_ops" + "manip_ops" "math_ops" "nn_ops" "no_op" @@ -84,7 +85,7 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/te GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 8862390d2b62f72c11d60f2ae48a845d22363f06..b730ebd3baacafe8ae401e8987104f3062372954 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -307,7 +307,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) # containing the wrappers. add_custom_command( OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION} - COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} + COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} DEPENDS ${tf_python_op_lib_name}_gen_python ) @@ -335,6 +335,7 @@ GENERATE_PYTHON_OP_LIB("list_ops") GENERATE_PYTHON_OP_LIB("logging_ops") GENERATE_PYTHON_OP_LIB("lookup_ops") GENERATE_PYTHON_OP_LIB("nn_ops") +GENERATE_PYTHON_OP_LIB("manip_ops") GENERATE_PYTHON_OP_LIB("parsing_ops") GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" @@ -367,8 +368,8 @@ GENERATE_PYTHON_OP_LIB("contrib_coder_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops" @@ -540,7 +541,11 @@ if(WIN32) ${nsync_STATIC_LIBRARIES} ) - set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow.def") + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow.def") + else() + set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow.def") + endif() set_source_files_properties(${pywrap_tensorflow_deffile} PROPERTIES GENERATED TRUE) add_custom_command(TARGET pywrap_tensorflow_internal_static POST_BUILD @@ -548,6 +553,7 @@ if(WIN32) --input "${pywrap_tensorflow_internal_static_dependencies}" --output "${pywrap_tensorflow_deffile}" --target _pywrap_tensorflow_internal.pyd + BYPRODUCTS ${pywrap_tensorflow_deffile} # Required for Ninja ) endif(WIN32) @@ -701,11 +707,19 @@ add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/testing/python/framework/) if(WIN32) - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + add_custom_command(TARGET tf_python_build_pip_package POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) + else() + add_custom_command(TARGET tf_python_build_pip_package POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.dll + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) + endif() else() add_custom_command(TARGET tf_python_build_pip_package POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 571d2b0decb5e9afcec2314f9837546f0974e90d..6d36d5fc5c2854b2d7d2542a3cb12e033e193b88 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -46,7 +46,11 @@ if(WIN32) $ ) - set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/tensorflow.def") + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/tensorflow.def") + else() + set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/tensorflow.def") + endif() set_source_files_properties(${tensorflow_deffile} PROPERTIES GENERATED TRUE) add_custom_command(TARGET tensorflow_static POST_BUILD diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 2e79eadf7f566690a7742757ceb56e147ebd6ea0..1c4ebd7f0c1113bcd0857fb0858df2248499f920 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -156,6 +156,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/coder/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/feature_column/python/feature_column/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py" "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/*_test.py" @@ -275,8 +276,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" - # Broken tensorboard test due to cmake issues. - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py" # Deadlocks "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. @@ -310,6 +310,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/control_flow_util_test.py" # Flaky replicate_model_fn_test "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py" # b/71901810 + # Broken io_utils_test + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py" # b/72894325 ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/cmake/tf_tools.cmake b/tensorflow/contrib/cmake/tf_tools.cmake index cb58a2e7df85b2f214654eff5547c5788592f208..58c7df95c821b4d1aa2cc63c8aaf4039518b83ca 100644 --- a/tensorflow/contrib/cmake/tf_tools.cmake +++ b/tensorflow/contrib/cmake/tf_tools.cmake @@ -48,9 +48,6 @@ file(GLOB_RECURSE tf_tools_transform_graph_lib_exclude_srcs "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/compare_graphs.cc" "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/summarize_graph_main.cc" "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/transform_graph_main.cc" - "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/quantize_nodes.cc" - "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/quantize_weights.cc" - "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/round_weights.cc" ) list(REMOVE_ITEM tf_tools_transform_graph_lib_srcs ${tf_tools_transform_graph_lib_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index f67698eb99a38eae307b52e55de748a67b798cbd..53c2285699a6ca94e1e6b147080338b507f4d768 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -31,7 +31,7 @@ from __future__ import division from __future__ import print_function import argparse -import io +import codecs import os import re import subprocess @@ -103,7 +103,7 @@ def main(): for lib_path in args.input: proc = subprocess.Popen([DUMPBIN, "/nologo", "/linkermember:1", lib_path], stdout=subprocess.PIPE) - for line in io.TextIOWrapper(proc.stdout, encoding="utf-8"): + for line in codecs.getreader("utf-8")(proc.stdout): cols = line.split() if len(cols) < 2: continue @@ -131,7 +131,7 @@ def main(): # We compare on undname but use the decorated name from candidates. dupes = 0 proc = subprocess.Popen([UNDNAME, tmpfile.name], stdout=subprocess.PIPE) - for idx, line in enumerate(io.TextIOWrapper(proc.stdout, encoding="utf-8")): + for idx, line in enumerate(codecs.getreader("utf-8")(proc.stdout)): decorated = candidates[idx] if decorated in taken: # Symbol is already in output, done. diff --git a/tensorflow/contrib/coder/README.md b/tensorflow/contrib/coder/README.md index e1e867db5aa701eb73ee43a47cd3dcc2dc783a04..c6c379c458893551b765327c0c1cbfff7f24f9c3 100644 --- a/tensorflow/contrib/coder/README.md +++ b/tensorflow/contrib/coder/README.md @@ -30,7 +30,7 @@ following sense: around, - The number of CDF axes does not extend, i.e., `CDF.ndim == data.ndim + 1`. -In the previous example where data has shape (10, 10), the followings are +In the previous example where data has shape (10, 10), the following are acceptable CDF shapes: - (10, 10, 65) diff --git a/tensorflow/contrib/coder/kernels/range_coder.cc b/tensorflow/contrib/coder/kernels/range_coder.cc index f4f076b6c4e0c82cc297266bedc63034d5f5bf8b..21b35155ff317c6afbb1b86745f05385726505b6 100644 --- a/tensorflow/contrib/coder/kernels/range_coder.cc +++ b/tensorflow/contrib/coder/kernels/range_coder.cc @@ -276,7 +276,7 @@ void RangeEncoder::Finalize(string* sink) { } } else if (base_ != 0) { // If base == 0, then pick 0 from [base, base + size) and no zeros are - // explcitly written. + // explicitly written. // // Otherwise, pick (base + (2^16 - base[16:0])), i.e., round up base to the // next multiple of 2^16. As 2^16 < size, this value should be in the diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 2108e42bce4eba1eed158fe85888f1699a69ba7e..29a593f6bcfa05dcafcdb2f94087380ad720dba1 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -169,6 +170,7 @@ class JITTest(test.TestCase): self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s) +@test_util.with_c_api class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): @@ -188,7 +190,7 @@ class CompilationEnabledInGradientTest(test.TestCase): for cg in c_grad_ops: self.assertTrue(cg.get_attr("_XlaCompile")) for ncg in nc_grad_ops: - with self.assertRaisesRegexp(ValueError, "No attr named"): + with self.assertRaisesRegexp(ValueError, "[Nn]o attr named"): ncg.get_attr("_XlaCompile") # d/dx (x ** 4) = 4 * (x ** 3) diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index bae66ffd4289308f2cbfc730ec50d057b13923fb..b806799202bff4f2f6dbf717fbeea74a04b8cd6e 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -35,10 +35,10 @@ from tensorflow.python.ops.variables import Variable from tensorflow.python.client.session import Session from tensorflow.python.framework import ops -__all__ = ["copy_op_to_graph", "copy_variable_to_graph", "get_copied_op"] +__all__ = ['copy_op_to_graph', 'copy_variable_to_graph', 'get_copied_op'] -def copy_variable_to_graph(org_instance, to_graph, scope=""): +def copy_variable_to_graph(org_instance, to_graph, scope=''): """Given a `Variable` instance from one `Graph`, initializes and returns a copy of it from another `Graph`, under the specified scope (default `""`). @@ -56,12 +56,11 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""): """ if not isinstance(org_instance, Variable): - raise TypeError(str(org_instance) + " is not a Variable") + raise TypeError(str(org_instance) + ' is not a Variable') #The name of the new variable - if scope != "": - new_name = (scope + '/' + - org_instance.name[:org_instance.name.index(':')]) + if scope != '': + new_name = (scope + '/' + org_instance.name[:org_instance.name.index(':')]) else: new_name = org_instance.name[:org_instance.name.index(':')] @@ -73,15 +72,15 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""): for name, collection in org_instance.graph._collections.items(): if org_instance in collection: if (name == ops.GraphKeys.GLOBAL_VARIABLES or - name == ops.GraphKeys.TRAINABLE_VARIABLES or - scope == ''): + name == ops.GraphKeys.TRAINABLE_VARIABLES or scope == ''): collections.append(name) else: collections.append(scope + '/' + name) #See if its trainable. - trainable = (org_instance in org_instance.graph.get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES)) + trainable = ( + org_instance in org_instance.graph.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES)) #Get the initial value with org_instance.graph.as_default(): temp_session = Session() @@ -89,17 +88,17 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""): #Initialize the new variable with to_graph.as_default(): - new_var = Variable(init_value, - trainable, - name=new_name, - collections=collections, - validate_shape=False) + new_var = Variable( + init_value, + trainable, + name=new_name, + collections=collections, + validate_shape=False) return new_var -def copy_op_to_graph(org_instance, to_graph, variables, - scope=""): +def copy_op_to_graph(org_instance, to_graph, variables, scope=''): """Returns a copy of an operation from another Graph under a specified scope. Given an `Operation` `org_instance` from one `Graph`, @@ -139,14 +138,12 @@ def copy_op_to_graph(org_instance, to_graph, variables, #If a variable by the new name already exists, return the #correspondng tensor that will act as an input if new_name in copied_variables: - return to_graph.get_tensor_by_name( - copied_variables[new_name].name) + return to_graph.get_tensor_by_name(copied_variables[new_name].name) #If an instance of the same name exists, return appropriately try: - already_present = to_graph.as_graph_element(new_name, - allow_tensor=True, - allow_operation=True) + already_present = to_graph.as_graph_element( + new_name, allow_tensor=True, allow_operation=True) return already_present except: pass @@ -184,20 +181,21 @@ def copy_op_to_graph(org_instance, to_graph, variables, #If it has an original_op parameter, copy it if op._original_op is not None: - new_original_op = copy_op_to_graph(op._original_op, to_graph, - variables, scope) + new_original_op = copy_op_to_graph(op._original_op, to_graph, variables, + scope) else: new_original_op = None #If it has control inputs, call this function recursively on each. - new_control_inputs = [copy_op_to_graph(x, to_graph, variables, - scope) - for x in op.control_inputs] + new_control_inputs = [ + copy_op_to_graph(x, to_graph, variables, scope) + for x in op.control_inputs + ] #If it has inputs, call this function recursively on each. - new_inputs = [copy_op_to_graph(x, to_graph, variables, - scope) - for x in op.inputs] + new_inputs = [ + copy_op_to_graph(x, to_graph, variables, scope) for x in op.inputs + ] #Make a new node_def based on that of the original. #An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it @@ -216,13 +214,8 @@ def copy_op_to_graph(org_instance, to_graph, variables, op_def = deepcopy(op._op_def) #Initialize a new Operation instance - new_op = ops.Operation(new_node_def, - to_graph, - new_inputs, - output_types, - new_control_inputs, - input_types, - new_original_op, + new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types, + new_control_inputs, input_types, new_original_op, op_def) #Use Graph's hidden methods to add the op to_graph._add_op(new_op) # pylint: disable=protected-access @@ -233,10 +226,10 @@ def copy_op_to_graph(org_instance, to_graph, variables, return new_op else: - raise TypeError("Could not copy instance: " + str(org_instance)) + raise TypeError('Could not copy instance: ' + str(org_instance)) -def get_copied_op(org_instance, graph, scope=""): +def get_copied_op(org_instance, graph, scope=''): """Given an `Operation` instance from some `Graph`, returns its namesake from `graph`, under the specified scope (default `""`). @@ -259,5 +252,5 @@ def get_copied_op(org_instance, graph, scope=""): else: new_name = org_instance.name - return graph.as_graph_element(new_name, allow_tensor=True, - allow_operation=True) + return graph.as_graph_element( + new_name, allow_tensor=True, allow_operation=True) diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py index 2798d31229d048561f8ebd9b63d3df94a44c45c7..05744bec4e05405c04b5ec442e72e4495737ab5b 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_test.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py @@ -17,9 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np from tensorflow.contrib.copy_graph.python.util import copy_elements -from tensorflow.contrib.framework.python.framework import tensor_util from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 62708636c6181ca63cddf2b2e7c84d3da740282a..a30bf06396117034f7cfe461d5e365b8a4a38a3f 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -105,8 +105,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, return utils.smart_cond( pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], 1), - fn1=_single_seq_fn, - fn2=_multi_seq_fn) + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) def crf_log_norm(inputs, sequence_lengths, transition_params): @@ -166,8 +166,8 @@ def crf_log_likelihood(inputs, sequence_lengths: A [batch_size] vector of true sequence lengths. transition_params: A [num_tags, num_tags] transition matrix, if available. Returns: - log_likelihood: A scalar containing the log-likelihood of the given sequence - of tag indices. + log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of + each example, given the sequence of tag indices. transition_params: A [num_tags, num_tags] transition matrix. This is either provided by the caller or created in this function. """ @@ -182,7 +182,7 @@ def crf_log_likelihood(inputs, transition_params) log_norm = crf_log_norm(inputs, sequence_lengths, transition_params) - # Normalize the scores to get the log-likelihood. + # Normalize the scores to get the log-likelihood per example. log_likelihood = sequence_scores - log_norm return log_likelihood, transition_params @@ -513,5 +513,5 @@ def crf_decode(potentials, transition_params, sequence_length): return utils.smart_cond( pred=math_ops.equal( potentials.shape[1].value or array_ops.shape(potentials)[1], 1), - fn1=_single_seq_fn, - fn2=_multi_seq_fn) + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc index 9e41e67857101534e8bfef8d5d0b8a45ed8f1f76..1a79bf066c3a27e040099729fb079ee963f59270 100644 --- a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc +++ b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc @@ -251,9 +251,8 @@ REGISTER_OP("CudnnRNNParamsToCanonical") TF_RETURN_IF_ERROR(c->GetAttr("num_params", &num_params)); // Set shape for weight matrices for (int i = 0; i < num_params; i++) { - c->set_output(i, - c->Matrix(InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim)); + c->set_output(i, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); } // Set shape for bias vectors for (int i = 0; i < num_params; i++) { @@ -300,6 +299,7 @@ upcoming training or inferences. num_params: number of parameter sets for all layers. Each layer may contain multiple parameter sets, with each set consisting of a weight matrix and a bias vector. -)doc", kCudnnRNNCommonAttrs)); +)doc", + kCudnnRNNCommonAttrs)); } // namespace tensorflow diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py index 4fc5ff1bd1887c4532e95fcf0e791d72b20471b0..933df6d71dd7c972efe63d54fa7344ecfc39b0a7 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py @@ -20,6 +20,7 @@ from __future__ import print_function import time +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import rnn as contrib_rnn from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.contrib.rnn.python.ops import lstm_ops diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 49d305cb0dd0387c34b7feb79ef631eac9e935cd..9897c31a98e0b335c18a84825fc518ed1fc310a2 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -51,7 +51,11 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import adagrad +from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent +from tensorflow.python.training import momentum +from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib @@ -316,6 +320,55 @@ class CudnnRNNTestBasic(TensorFlowTestCase): self.assertEqual(0, total_sum2_v) self.assertEqual(0, total_sum3_v) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testOptimizersSupport(self): + for opt in ("adagrad", "adam", "rmsprop", "momentum", "sgd"): + self._TestOptimizerSupportHelper(opt) + + def _GetOptimizer(self, opt): + if opt == "adagrad": + return adagrad.AdagradOptimizer(learning_rate=1e-2) + elif opt == "adam": + return adam.AdamOptimizer(learning_rate=1e-2) + elif opt == "rmsprop": + return rmsprop.RMSPropOptimizer(learning_rate=1e-2) + elif opt == "momentum": + return momentum.MomentumOptimizer(learning_rate=1e-2, momentum=0.9) + elif opt == "sgd": + return gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) + else: + raise ValueError("Unsupported optimizer: %s" % opt) + + def _TestOptimizerSupportHelper(self, opt): + num_layers = 4 + num_units = 2 + batch_size = 8 + direction = CUDNN_RNN_UNIDIRECTION + dir_count = 1 + + with ops.Graph().as_default() as g: + kernel_initializer = init_ops.constant_initializer(0.) + bias_initializer = init_ops.constant_initializer(0.) + inputs = random_ops.random_uniform([ + num_layers * dir_count, batch_size, num_units], dtype=dtypes.float32) + + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + outputs, _ = lstm(inputs) + loss = math_ops.reduce_sum(outputs) + optimizer = self._GetOptimizer(opt) + train_op = optimizer.minimize(loss) + + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(train_op) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") def testSaveableGraphDeviceAssignment(self): num_layers = 4 num_units = 2 diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 8ecc003348d70379ee48d050e63e93d0dd38efaa..0458199ff771bc45603106411550a39448e515b8 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -27,13 +27,13 @@ py_library( ) tf_custom_op_library( - name = "_prefetching_ops.so", - srcs = ["ops/prefetching_ops.cc"], - deps = ["//tensorflow/contrib/data/kernels:prefetching_kernels"], + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"], ) tf_gen_op_libs( - op_lib_names = ["prefetching_ops"], + op_lib_names = ["dataset_ops"], ) filegroup( diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index daeb6a610533404044d42033709d644deb481024..fcdccdd26ca1824bf13f1fd0cfd80b20ca8a10c3 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`tf.contrib.data` API for input pipelines. +"""Experimental API for building input pipelines. -This module contains the experimental (less stable) counterpart to the -`tf.data` API. See @{tf.data.Dataset} and @{tf.data.Iterator} for the -stable classes. +This module contains experimental `Dataset` sources and transformations that can +be used in conjunction with the @{tf.data.Dataset} API. Note that the +`tf.contrib.data` API is not subject to the same backwards compatibility +guarantees as `tf.data`, but we will provide deprecation advice in advance of +removing existing functionality. See the @{$datasets$Importing Data} Programmer's Guide for an overview. -@@Dataset @@Counter -@@Iterator -@@TFRecordDataset -@@FixedLengthRecordDataset -@@TextLineDataset @@batch_and_drop_remainder @@dense_to_sparse_batch @@ -58,23 +55,18 @@ from tensorflow.contrib.data.python.ops.batching import map_and_batch from tensorflow.contrib.data.python.ops.batching import padded_batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import unbatch from tensorflow.contrib.data.python.ops.counter import Counter -from tensorflow.contrib.data.python.ops.dataset_ops import Dataset -from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors +from tensorflow.contrib.data.python.ops.get_single_element import get_single_element from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator -from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset from tensorflow.contrib.data.python.ops.readers import read_batch_features from tensorflow.contrib.data.python.ops.readers import SqlDataset -from tensorflow.contrib.data.python.ops.readers import TextLineDataset -from tensorflow.contrib.data.python.ops.readers import TFRecordDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat -from tensorflow.python.data.ops.iterator_ops import Iterator # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 4cb53741ebf8cd0db41b382c878bd2ccd1dcf7f1..56471911c5c0d1c1825955c67997b5bbc0786463 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -17,6 +17,28 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "ignore_errors_dataset_op", + srcs = ["ignore_errors_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +cc_library( + name = "dataset_kernels", + deps = [ + ":ignore_errors_dataset_op", + ":prefetching_kernels", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/core/kernels/data/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc similarity index 98% rename from tensorflow/core/kernels/data/ignore_errors_dataset_op.cc rename to tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index 99df699d719b896df37515fc4147cd48db52a113..bb29df60e8f114aaa50f578c43e73874f72ab0a3 100644 --- a/tensorflow/core/kernels/data/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { diff --git a/tensorflow/contrib/data/ops/prefetching_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc similarity index 86% rename from tensorflow/contrib/data/ops/prefetching_ops.cc rename to tensorflow/contrib/data/ops/dataset_ops.cc index 23cb62b6f0dbfed15667dd00ae0039b33aa944d4..289ffa1d9c29092cdf434e86ed5553ff9644d43e 100644 --- a/tensorflow/contrib/data/ops/prefetching_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -17,6 +17,16 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("IgnoreErrorsDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that contains the elements of `input_dataset` ignoring errors. +)doc"); + REGISTER_OP("FunctionBufferingResource") .Input("string_arg: string") .Input("target_device: string") diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 1cf0202fd88951ffcc611af39fa0915110c4d819..e51d57cc896dc32d8e11912cd89f34a04a858c78 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -52,24 +52,6 @@ py_test( ], ) -py_test( - name = "cache_dataset_op_test", - size = "small", - srcs = ["cache_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", - ], -) - py_test( name = "concatenate_dataset_op_test", size = "small", @@ -126,6 +108,7 @@ py_library( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", @@ -207,73 +190,18 @@ py_test( ) tf_py_test( - name = "iterator_ops_cluster_test", - size = "small", - srcs = ["iterator_ops_cluster_test.py"], - additional_deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:session", - "//tensorflow/python/data/ops:iterator_ops", - ], - grpc_enabled = True, - tags = [ - "no_windows", - "oss_serial", - ], -) - -tf_py_test( - name = "iterator_ops_test", + name = "get_single_element_test", size = "small", - srcs = ["iterator_ops_test.py"], + srcs = ["get_single_element_test.py"], additional_deps = [ "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:io_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:iterator_ops", - ], - grpc_enabled = True, -) - -py_test( - name = "list_files_dataset_op_test", - size = "small", - srcs = ["list_files_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:util", ], ) @@ -425,14 +353,17 @@ py_test( ) py_test( - name = "shard_dataset_op_test", + name = "serialization_integration_test", size = "small", - srcs = ["shard_dataset_op_test.py"], + srcs = ["serialization_integration_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -534,7 +465,7 @@ py_test( "no_oss", # b/68785503 ], deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_py", + "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 015f69c5673f185c53e61a5df2636333699ae203..71dc1c1172c9d515d4c85f85257c952135098329 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -23,283 +23,24 @@ import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -from tensorflow.python.util import compat class BatchDatasetTest(test.TestCase): - def testBatchDataset(self): - """Test an dataset that maps a TF function across its input elements.""" - # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> - # RepeatDataset(count) -> BatchDataset(batch_size). - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - - count = array_ops.placeholder(dtypes.int64, shape=[]) - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count).batch(batch_size).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([[None] + list(c.shape[1:]) for c in components], - [t.shape.as_list() for t in get_next]) - - with self.test_session() as sess: - # Batch of a finite input, where the batch_size divides the - # total number of elements. - sess.run(init_op, feed_dict={count: 28, batch_size: 14}) - num_batches = (28 * 7) // 14 - for i in range(num_batches): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(14): - self.assertAllEqual(component[(i * 14 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Batch of a finite input, where the batch_size does not - # divide the total number of elements. - sess.run(init_op, feed_dict={count: 14, batch_size: 8}) - - # We expect (num_batches - 1) full-sized batches. - num_batches = int(math.ceil((14 * 7) / 8)) - for i in range(num_batches - 1): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(8): - self.assertAllEqual(component[(i * 8 + j) % 7]**2, - result_component[j]) - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range((14 * 7) % 8): - self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Batch of an empty input should fail straight away. - sess.run(init_op, feed_dict={count: 0, batch_size: 8}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Empty batch should be an initialization time error. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) self.assertAllEqual(a.values, b.values) self.assertAllEqual(a.dense_shape, b.dense_shape) - def testBatchSparse(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=[[0]], values=(i * [1]), dense_shape=[1]) - - iterator = dataset_ops.Dataset.range(10).map(_sparse).batch( - 5).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(2): - actual = sess.run(get_next) - expected = sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], - values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], - dense_shape=[5, 1]) - self.assertTrue(sparse_tensor.is_sparse(actual)) - self.assertSparseValuesEqual(actual, expected) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testNestedBatchSparse(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=[[0]], values=(i * [1]), dense_shape=[1]) - - iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(5).batch( - 2).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - actual = sess.run(get_next) - expected = sparse_tensor.SparseTensorValue( - indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], - [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]], - values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], - dense_shape=[2, 5, 1]) - self.assertTrue(sparse_tensor.is_sparse(actual)) - self.assertSparseValuesEqual(actual, expected) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testPaddedBatchDataset(self): - seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) - padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(seq_lens) - .map(lambda x: array_ops.fill([x], x)).padded_batch( - 4, padded_shapes=padded_shape).make_initializable_iterator()) - - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - # Test with random sequence lengths, and max padding. - random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run( - init_op, feed_dict={ - padded_shape: [-1], - seq_lens: random_seq_lens - }) - for i in range(8): - result = sess.run(get_next) - padded_len = np.max(result) - self.assertEqual((4, padded_len), result.shape) - for j in range(4): - seq_len = random_seq_lens[(i * 4) + j] - self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) - self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test with random sequence lengths, and constant padding. - sess.run( - init_op, feed_dict={ - padded_shape: [25], - seq_lens: random_seq_lens - }) - for i in range(8): - result = sess.run(get_next) - self.assertEqual((4, 25), result.shape) - for j in range(4): - seq_len = random_seq_lens[(i * 4) + j] - self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) - self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test correct handling of empty tensors. - sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]}) - result = sess.run(get_next) - self.assertAllEqual([[], [], [], []], result) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test error handling with constant sequence lengths, and - # too-short padding. - sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]}) - with self.assertRaises(errors.DataLossError): - result = sess.run(get_next) - - def testPaddedBatchDatasetNonDefaultPadding(self): - seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) - padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) - - def fill_tuple(x): - filled = array_ops.fill([x], x) - return (filled, string_ops.as_string(filled)) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple) - .padded_batch( - 4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, "")).make_initializable_iterator()) - - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - # Test with random sequence lengths, and max padding. - random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run( - init_op, feed_dict={ - padded_shape: [-1], - seq_lens: random_seq_lens - }) - for i in range(8): - result = sess.run(get_next) - padded_len = np.max(result[0]) - self.assertEqual((4, padded_len), result[0].shape) - self.assertEqual((4, padded_len), result[1].shape) - for j in range(4): - seq_len = random_seq_lens[(i * 4) + j] - self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len) - self.assertAllEqual(result[0][j, seq_len:], - [-1] * (padded_len - seq_len)) - self.assertAllEqual(result[1][j, :seq_len], - [compat.as_bytes(str(seq_len))] * seq_len) - self.assertAllEqual(result[1][j, seq_len:], - [b""] * (padded_len - seq_len)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testPaddedBatchDatasetShapeSpecifications(self): - int_placeholder = array_ops.placeholder(dtypes.int32) - float_placeholder = array_ops.placeholder(dtypes.float32) - string_placeholder = array_ops.placeholder(dtypes.string) - input_dataset = dataset_ops.Dataset.from_tensors( - (int_placeholder, float_placeholder, string_placeholder)) - - # Test different ways of specifying the `padded_shapes` argument. - dynamic_padding_from_tensor_shapes = input_dataset.padded_batch( - 32, - padded_shapes=(tensor_shape.TensorShape([None]), - tensor_shape.TensorShape([None, None]), - tensor_shape.TensorShape([37]))) - dynamic_padding_from_lists = input_dataset.padded_batch( - 32, padded_shapes=([None], [None, None], [37])) - dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch( - 32, padded_shapes=([-1], [-1, -1], [37])) - dynamic_padding_from_tensors = input_dataset.padded_batch( - 32, - padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64), - constant_op.constant([-1, -1], dtype=dtypes.int64), - constant_op.constant([37], dtype=dtypes.int64))) - - for dataset in [ - dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists, - dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors - ]: - self.assertEqual([None, None], dataset.output_shapes[0].as_list()) - self.assertEqual([None, None, None], dataset.output_shapes[1].as_list()) - self.assertEqual([None, 37], dataset.output_shapes[2].as_list()) - - def testPaddedBatchSparseError(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i - - with self.assertRaises(TypeError): - _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10) - def testDenseToSparseBatchDataset(self): components = np.random.randint(12, size=(100,)).astype(np.int32) iterator = ( @@ -744,6 +485,23 @@ class BatchDatasetSerializationTest( lambda: self._build_dataset_dense_to_sparse(diff_comp), num_outputs) + def _sparse(self, i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + def _build_dataset_sparse(self, batch_size=5): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) + + def testSparseCore(self): + self.run_core_tests(self._build_dataset_sparse, + lambda: self._build_dataset_sparse(2), 2) + + def _build_dataset_nested_sparse(self): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) + + def testNestedSparseCore(self): + self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 4d984bb4d76e52c4200ae471550dcf48668c5f89..f1b494e1a620992365ed75613b508e32f94b40a4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -41,8 +41,7 @@ class GroupByWindowTest(test.TestCase): dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x) .apply( grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), - 4)) - .make_initializable_iterator()) + 4)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -53,7 +52,8 @@ class GroupByWindowTest(test.TestCase): while True: result = sess.run(get_next) self.assertTrue( - all(x % 2 == 0 for x in result) or all(x % 2 == 1) + all(x % 2 == 0 + for x in result) or all(x % 2 == 1) for x in result) counts.append(result.shape[0]) @@ -116,8 +116,8 @@ class GroupByWindowTest(test.TestCase): iterator = ( dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply( - grouping.group_by_window(lambda x, _: x % 2, reduce_func, 32)) - .make_initializable_iterator()) + grouping.group_by_window(lambda x, _: x % 2, reduce_func, + 32)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -136,7 +136,8 @@ class GroupByWindowTest(test.TestCase): window.padded_batch( 4, padded_shapes=tensor_shape.TensorShape([None])), window.padded_batch( - 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),)) + 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])), + )) iterator = ( dataset_ops.Dataset.from_tensor_slices(components) @@ -200,9 +201,10 @@ class BucketTest(test.TestCase): # dynamically and does not rely on static shape information about # the arguments. return dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensors(bucket), window.padded_batch( - 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape([None]), - tensor_shape.TensorShape([3]))))) + (dataset_ops.Dataset.from_tensors(bucket), + window.padded_batch( + 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape( + [None]), tensor_shape.TensorShape([3]))))) def testSingleBucket(self): @@ -307,12 +309,13 @@ class BucketTest(test.TestCase): def _dynamic_pad_fn(bucket, window, _): return dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensors(bucket), window.padded_batch( - 32, { - "x": tensor_shape.TensorShape([]), - "y": tensor_shape.TensorShape([None]), - "z": tensor_shape.TensorShape([3]) - }))) + (dataset_ops.Dataset.from_tensors(bucket), + window.padded_batch( + 32, { + "x": tensor_shape.TensorShape([]), + "y": tensor_shape.TensorShape([None]), + "z": tensor_shape.TensorShape([3]) + }))) input_dataset = ( dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn) diff --git a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py deleted file mode 100644 index 9818020680afb9d0f0197d272ec5339c6358db36..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from os import path -import shutil -import tempfile - -import numpy as np - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class FilesystemCacheDatasetTest(test.TestCase): - - def setUp(self): - self.tmp_dir = tempfile.mkdtemp() - self.cache_prefix = path.join(self.tmp_dir, "cache") - - def tearDown(self): - if self.tmp_dir: - shutil.rmtree(self.tmp_dir, ignore_errors=True) - - def testCacheDatasetPassthrough(self): - components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), - np.array([9.0, 10.0, 11.0, 12.0])) - count_placeholder = array_ops.placeholder_with_default( - constant_op.constant(5, dtypes.int64), shape=[]) - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - - repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .repeat(count_placeholder)) - - cache_dataset = repeat_dataset.cache(filename_placeholder) - - self.assertEqual( - tuple([c.shape[1:] for c in components]), cache_dataset.output_shapes) - - # Create initialization ops for iterators without and with - # caching, respectively. - iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types, - cache_dataset.output_shapes) - init_fifo_op = iterator.make_initializer(repeat_dataset) - init_cache_op = iterator.make_initializer(cache_dataset) - - get_next = iterator.get_next() - - with self.test_session() as sess: - # First run without caching to collect the "ground truth". - sess.run(init_fifo_op) - elements = [] - for _ in range(20): - elements.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Assert that the cached dataset has the same elements as the - # "ground truth". - sess.run( - init_cache_op, feed_dict={filename_placeholder: self.cache_prefix}) - cached_elements = [] - for _ in range(20): - cached_elements.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertAllEqual(elements, cached_elements) - - # Re-initialize with an empty upstream (to throw errors.OutOfRangeError - # if we didn't use the cache). - sess.run( - init_cache_op, - feed_dict={ - count_placeholder: 0, - filename_placeholder: self.cache_prefix - }) - replayed_elements = [] - for _ in range(20): - replayed_elements.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(cached_elements, replayed_elements) - - # Re-initialize with an empty upstream and a missing cache file (should - # throw errors.OutOfRangeError immediately). - sess.run( - init_cache_op, - feed_dict={ - count_placeholder: 0, - filename_placeholder: self.cache_prefix + "nonsense" - }) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testConcurrentWriters(self): - components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), - np.array([9.0, 10.0, 11.0, 12.0])) - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - - cache_dataset1 = (dataset_ops.Dataset.from_tensor_slices(components) - .cache(filename_placeholder)) - cache_dataset2 = (dataset_ops.Dataset.from_tensor_slices(components) - .cache(filename_placeholder)) - - iterator1 = cache_dataset1.make_initializable_iterator() - iterator2 = cache_dataset2.make_initializable_iterator() - init_cache_op1 = iterator1.initializer - init_cache_op2 = iterator2.initializer - - get_next1 = iterator1.get_next() - get_next2 = iterator2.get_next() - - with self.test_session() as sess: - sess.run( - init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix}) - sess.run(get_next1) # this should succeed - - sess.run( - init_cache_op2, feed_dict={filename_placeholder: self.cache_prefix}) - with self.assertRaises(errors.AlreadyExistsError): - sess.run(get_next2) - - sess.run(get_next1) # this should continue to succeed - - def testConcurrentReaders(self): - components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), - np.array([9.0, 10.0, 11.0, 12.0])) - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - - cache_dataset1 = (dataset_ops.Dataset.from_tensor_slices(components) - .cache(filename_placeholder)) - cache_dataset2 = (dataset_ops.Dataset.from_tensor_slices(components) - .cache(filename_placeholder)) - - iterator1 = cache_dataset1.make_initializable_iterator() - iterator2 = cache_dataset2.make_initializable_iterator() - init_cache_op1 = iterator1.initializer - init_cache_op2 = iterator2.initializer - - get_next1 = iterator1.get_next() - get_next2 = iterator2.get_next() - - with self.test_session() as sess: - sess.run( - init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix}) - elements = [] - for _ in range(4): - elements.append(sess.run(get_next1)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next1) - - # Re-initialize - sess.run( - init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix}) - sess.run( - init_cache_op2, feed_dict={filename_placeholder: self.cache_prefix}) - - # Reading concurrently should succeed. - elements_itr1 = [] - elements_itr2 = [] - elements_itr2.append(sess.run(get_next2)) - elements_itr1.append(sess.run(get_next1)) - elements_itr2.append(sess.run(get_next2)) - elements_itr1.append(sess.run(get_next1)) - # Intentionally reversing the order - elements_itr1.append(sess.run(get_next1)) - elements_itr2.append(sess.run(get_next2)) - elements_itr1.append(sess.run(get_next1)) - elements_itr2.append(sess.run(get_next2)) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next2) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next1) - - self.assertAllEqual(elements, elements_itr1) - self.assertAllEqual(elements, elements_itr2) - - -class MemoryCacheDatasetTest(test.TestCase): - - def testCacheDatasetPassthrough(self): - repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64)) - dataset = dataset_ops.Dataset.range(3).flat_map( - lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count)) - - cached_dataset = dataset.cache().repeat(2) - uncached_dataset = dataset.repeat(2) - - # Needs to be initializable to capture the variable. - cached_iterator = cached_dataset.make_initializable_iterator() - cached_next = cached_iterator.get_next() - uncached_iterator = uncached_dataset.make_initializable_iterator() - uncached_next = uncached_iterator.get_next() - - with self.test_session() as sess: - - sess.run(repeat_count.initializer) - sess.run(cached_iterator.initializer) - sess.run(uncached_iterator.initializer) - - for i in range(3): - for _ in range(10): - self.assertEqual(sess.run(cached_next), i) - self.assertEqual(sess.run(uncached_next), i) - - sess.run(repeat_count.assign(0)) - - # The uncached iterator should now be empty. - with self.assertRaises(errors.OutOfRangeError): - sess.run(uncached_next) - - # The cached iterator replays from cache. - for i in range(3): - for _ in range(10): - self.assertEqual(sess.run(cached_next), i) - - # The cached iterator should now be empty. - with self.assertRaises(errors.OutOfRangeError): - sess.run(cached_next) - - def testEmptyCacheReading(self): - components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), - np.array([9.0, 10.0, 11.0, 12.0])) - count_placeholder = array_ops.placeholder_with_default( - constant_op.constant(5, dtypes.int64), shape=[]) - - repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .repeat(count_placeholder)) - - cache_dataset = repeat_dataset.cache() - - # Create initialization ops for iterators without and with - # caching, respectively. - iterator = cache_dataset.make_initializable_iterator() - init_cache_op = iterator.initializer - - get_next = iterator.get_next() - - with self.test_session() as sess: - # Initialize with an empty upstream and a missing cache file (should - # throw errors.OutOfRangeError immediately). - sess.run(init_cache_op, feed_dict={count_placeholder: 0}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testConcurrentReaders(self): - count_placeholder = array_ops.placeholder_with_default( - constant_op.constant(5, dtypes.int64), shape=[]) - dataset = dataset_ops.Dataset.range(count_placeholder).cache() - d1 = dataset.map(lambda x: x + 1) - d2 = dataset.map(lambda x: x + 6) - - i1 = d1.make_initializable_iterator() - i2 = d2.make_initializable_iterator() - - with self.test_session() as sess: - sess.run(i1.initializer) - - self.assertEqual(1, sess.run(i1.get_next())) - self.assertEqual(2, sess.run(i1.get_next())) - self.assertEqual(3, sess.run(i1.get_next())) - - sess.run(i2.initializer, feed_dict={count_placeholder: 3}) - - self.assertEqual(6, sess.run(i2.get_next())) - self.assertEqual(7, sess.run(i2.get_next())) - self.assertEqual(4, sess.run(i1.get_next())) # interleave execution - self.assertEqual([8, 5], sess.run([i2.get_next(), i1.get_next()])) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(i1.get_next()) - with self.assertRaises(errors.OutOfRangeError): - sess.run(i2.get_next()) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py index 063c71063601002af8168c4facf4057433061ab7..17f2980157ddd0350dafd1d745cbb9b64e65f7c5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py @@ -20,117 +20,10 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.framework import errors -from tensorflow.python.framework import tensor_shape +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test -class ConcatenateDatasetTest(test.TestCase): - - def testConcatenateDataset(self): - input_components = ( - np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 15), - np.array([37.0, 38.0, 39.0, 40.0])) - to_concatenate_components = ( - np.tile(np.array([[1], [2], [3], [4], [5]]), 20), - np.tile(np.array([[12], [13], [14], [15], [16]]), 15), - np.array([37.0, 38.0, 39.0, 40.0, 41.0])) - - input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) - dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( - to_concatenate_components) - concatenated = input_dataset.concatenate(dataset_to_concatenate) - self.assertEqual(concatenated.output_shapes, (tensor_shape.TensorShape( - [20]), tensor_shape.TensorShape([15]), tensor_shape.TensorShape([]))) - - iterator = concatenated.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(9): - result = sess.run(get_next) - if i < 4: - for component, result_component in zip(input_components, result): - self.assertAllEqual(component[i], result_component) - else: - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testConcatenateDatasetDifferentShape(self): - input_components = ( - np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 4)) - to_concatenate_components = ( - np.tile(np.array([[1], [2], [3], [4], [5]]), 20), - np.tile(np.array([[12], [13], [14], [15], [16]]), 15)) - - input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) - dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( - to_concatenate_components) - concatenated = input_dataset.concatenate(dataset_to_concatenate) - self.assertEqual( - [ts.as_list() - for ts in nest.flatten(concatenated.output_shapes)], [[20], [None]]) - - iterator = concatenated.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(9): - result = sess.run(get_next) - if i < 4: - for component, result_component in zip(input_components, result): - self.assertAllEqual(component[i], result_component) - else: - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testConcatenateDatasetDifferentStructure(self): - input_components = ( - np.tile(np.array([[1], [2], [3], [4]]), 5), - np.tile(np.array([[12], [13], [14], [15]]), 4)) - to_concatenate_components = ( - np.tile(np.array([[1], [2], [3], [4], [5]]), 20), - np.tile(np.array([[12], [13], [14], [15], [16]]), 15), - np.array([37.0, 38.0, 39.0, 40.0, 41.0])) - - input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) - dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( - to_concatenate_components) - - with self.assertRaisesRegexp(ValueError, - "don't have the same number of elements"): - input_dataset.concatenate(dataset_to_concatenate) - - def testConcatenateDatasetDifferentType(self): - input_components = ( - np.tile(np.array([[1], [2], [3], [4]]), 5), - np.tile(np.array([[12], [13], [14], [15]]), 4)) - to_concatenate_components = ( - np.tile(np.array([[1.0], [2.0], [3.0], [4.0]]), 5), - np.tile(np.array([[12], [13], [14], [15]]), 15)) - - input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) - dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( - to_concatenate_components) - - with self.assertRaisesRegexp(TypeError, "have different types"): - input_dataset.concatenate(dataset_to_concatenate) - - class ConcatenateDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index a90ba30e60cef13156719bba24fb553c0acec391..a842502cc6fe3605dde0be5f50cf46e3e37d7ed4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,713 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import threading - import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test class DatasetConstructorTest(test.TestCase): - def testFromTensors(self): - """Test an dataset that represents a single tuple of tensors.""" - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - - iterator = (dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - sess.run(init_op) - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - - def testFromTensorsSparse(self): - """Test an dataset that represents a single tuple of tensors.""" - components = (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 1]]), - values=np.array([-1, 1]), - dense_shape=np.array([2, 2]))) - - iterator = ( - dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual( - [tensor_shape.TensorShape(c.dense_shape) for c in components], - [shape for shape in iterator.output_shapes]) - - with self.test_session() as sess: - sess.run(init_op) - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertSparseValuesEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromTensorsMixed(self): - """Test an dataset that represents a single tuple of tensors.""" - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0), - sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 1]]), - values=np.array([-1, 1]), - dense_shape=np.array([2, 2]))) - - iterator = ( - dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([ - tensor_shape.TensorShape(c.dense_shape) - if sparse_tensor.is_sparse(c) else c.shape for c in components - ], [shape for shape in iterator.output_shapes]) - - with self.test_session() as sess: - sess.run(init_op) - results = sess.run(get_next) - for component, result_component in zip(components, results): - if sparse_tensor.is_sparse(component): - self.assertSparseValuesEqual(component, result_component) - else: - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromTensorSlices(self): - """Test an dataset that represents the slices from a tuple of tensors.""" - components = ( - np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( - np.array([[12], [13], [14], [15]]), 22), - np.array([37.0, 38.0, 39.0, 40.0]) - ) - - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - sess.run(init_op) - for i in range(4): - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component[i], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromTensorSlicesSparse(self): - """Test an dataset that represents the slices from a tuple of tensors.""" - components = (sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 0], [2, 0]]), - values=np.array([0, 0, 0]), - dense_shape=np.array([3, 1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 1], [2, 2]]), - values=np.array([1, 2, 3]), - dense_shape=np.array([3, 3]))) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual( - [tensor_shape.TensorShape(c.dense_shape[1:]) for c in components], - [shape for shape in iterator.output_shapes]) - - with self.test_session() as sess: - sess.run(init_op) - expected = [ - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([1]), - dense_shape=np.array([3]))), - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[1]]), - values=np.array([2]), - dense_shape=np.array([3]))), - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[2]]), - values=np.array([3]), - dense_shape=np.array([3]))), - ] - for i in range(3): - results = sess.run(get_next) - for component, result_component in zip(expected[i], results): - self.assertSparseValuesEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromTensorSlicesMixed(self): - """Test an dataset that represents the slices from a tuple of tensors.""" - components = (np.tile(np.array([[1], [2], [3]]), 20), - np.tile(np.array([[12], [13], [14]]), 22), - np.array([37.0, 38.0, 39.0]), - sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 0], [2, 0]]), - values=np.array([0, 0, 0]), - dense_shape=np.array([3, 1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 1], [2, 2]]), - values=np.array([1, 2, 3]), - dense_shape=np.array([3, 3]))) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([ - tensor_shape.TensorShape(c.dense_shape[1:]) - if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components - ], [shape for shape in iterator.output_shapes]) - - with self.test_session() as sess: - sess.run(init_op) - expected = [ - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([1]), - dense_shape=np.array([3]))), - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[1]]), - values=np.array([2]), - dense_shape=np.array([3]))), - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[2]]), - values=np.array([3]), - dense_shape=np.array([3]))), - ] - for i in range(3): - results = sess.run(get_next) - for component, result_component in zip( - (zip(*components[:3])[i] + expected[i]), results): - if sparse_tensor.is_sparse(component): - self.assertSparseValuesEqual(component, result_component) - else: - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromTensorSlicesWithDict(self): - components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual(dtypes.int32, iterator.output_types["foo"]) - self.assertEqual(dtypes.float32, iterator.output_types["bar"]) - self.assertEqual((), iterator.output_shapes["foo"]) - self.assertEqual((1,), iterator.output_shapes["bar"]) - - with self.test_session() as sess: - sess.run(init_op) - for i in range(3): - results = sess.run(get_next) - self.assertEqual(components["foo"][i], results["foo"]) - self.assertEqual(components["bar"][i], results["bar"]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromSparseTensorSlices(self): - """Test a dataset based on slices of a `tf.SparseTensor`.""" - st = array_ops.sparse_placeholder(dtypes.float64) - iterator = (dataset_ops.Dataset.from_sparse_tensor_slices(st) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) - - with self.test_session() as sess: - slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] - - # Test with sparse tensor in the appropriate order. - indices = np.array( - [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))]) - values = np.array([val for s in slices for val in s]) - dense_shape = np.array([len(slices), max(len(s) for s in slices) + 1]) - sparse_feed = sparse_tensor.SparseTensorValue(indices, values, - dense_shape) - sess.run(init_op, feed_dict={st: sparse_feed}) - for i, s in enumerate(slices): - results = sess.run(get_next) - self.assertAllEqual(s, results.values) - expected_indices = np.array( - [[j] for j in range(len(slices[i]))]).reshape([-1, 1]) - self.assertAllEqual(expected_indices, results.indices) - self.assertAllEqual(dense_shape[1:], results.dense_shape) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test with sparse tensor in the reverse order, which is not - # currently supported. - reverse_order_indices = indices[::-1, :] - reverse_order_values = values[::-1] - sparse_feed = sparse_tensor.SparseTensorValue( - reverse_order_indices, reverse_order_values, dense_shape) - with self.assertRaises(errors.UnimplementedError): - sess.run(init_op, feed_dict={st: sparse_feed}) - - # Test with an empty sparse tensor. - empty_indices = np.empty((0, 4), dtype=np.int64) - empty_values = np.empty((0,), dtype=np.float64) - empty_dense_shape = [0, 4, 37, 9] - sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values, - empty_dense_shape) - sess.run(init_op, feed_dict={st: sparse_feed}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # pylint: disable=g-long-lambda,unnecessary-lambda - def testNestedStructure(self): - components = (np.array([1, 2, 3]), (np.array([4., 5.]), np.array([6., 7.])), - np.array([8, 9, 10])) - - dataset = dataset_ops.Dataset.from_tensors(components) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) - - dataset = dataset.shuffle(10, 10) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) - - dataset = dataset.repeat(-1) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) - - dataset = dataset.filter(lambda x, y, z: True) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) - - dataset = dataset.take(5) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) - - dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1]))) - self.assertEquals(((dtypes.int64, dtypes.int64), - (dtypes.float64, dtypes.float64)), dataset.output_types) - self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes) - - dataset = dataset.flat_map( - lambda x, y: dataset_ops.Dataset.from_tensors(((x[0], x[1]), - (y[0], y[1]))) - ) - self.assertEquals(((dtypes.int64, dtypes.int64), - (dtypes.float64, dtypes.float64)), dataset.output_types) - self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes) - - dataset = dataset.batch(32) - self.assertEquals(((dtypes.int64, dtypes.int64), - (dtypes.float64, dtypes.float64)), dataset.output_types) - self.assertEquals((([None, 3], [None, 3]), ([None, 2], [None, 2])), - nest.pack_sequence_as(dataset.output_shapes, [ - s.as_list() - for s in nest.flatten(dataset.output_shapes) - ])) - - iterator = dataset.make_one_shot_iterator() - (w, x), (y, z) = iterator.get_next() - self.assertEquals(dtypes.int64, w.dtype) - self.assertEquals(dtypes.int64, x.dtype) - self.assertEquals(dtypes.float64, y.dtype) - self.assertEquals(dtypes.float64, z.dtype) - self.assertEquals([None, 3], w.shape.as_list()) - self.assertEquals([None, 3], x.shape.as_list()) - self.assertEquals([None, 2], y.shape.as_list()) - self.assertEquals([None, 2], z.shape.as_list()) - - iterator = dataset.make_initializable_iterator() - (w, x), (y, z) = iterator.get_next() - self.assertEquals(dtypes.int64, w.dtype) - self.assertEquals(dtypes.int64, x.dtype) - self.assertEquals(dtypes.float64, y.dtype) - self.assertEquals(dtypes.float64, z.dtype) - self.assertEquals([None, 3], w.shape.as_list()) - self.assertEquals([None, 3], x.shape.as_list()) - self.assertEquals([None, 2], y.shape.as_list()) - self.assertEquals([None, 2], z.shape.as_list()) - - # Define a separate set of components with matching leading - # dimension for the from-slices constructor. - components_for_slices = (np.array([1, 2, 3]), (np.array( - [4., 5., 6.]), np.array([7., 8., 9.])), np.array([10, 11, 12])) - - dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([], ([], []), []), dataset.output_shapes) - - def testNestedDict(self): - components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]} - dataset = dataset_ops.Dataset.from_tensors(components) - self.assertEquals(dtypes.int32, dataset.output_types["a"]["aa"]) - self.assertEquals(dtypes.float32, dataset.output_types["a"]["ab"]) - self.assertEquals(dtypes.int32, dataset.output_types["b"]) - self.assertEquals([], dataset.output_shapes["a"]["aa"]) - self.assertEquals([2], dataset.output_shapes["a"]["ab"]) - self.assertEquals([3], dataset.output_shapes["b"]) - - def testNonSequenceNestedStructure(self): - components = np.array([1, 2, 3]) - - dataset = dataset_ops.Dataset.from_tensors(components) - self.assertEquals(dtypes.int64, dataset.output_types) - self.assertEquals([3], dataset.output_shapes) - - dataset = dataset.filter( - lambda x: math_ops.reduce_all(math_ops.equal(x, components))) - self.assertEquals(dtypes.int64, dataset.output_types) - self.assertEquals([3], dataset.output_shapes) - - dataset = dataset.map(lambda x: array_ops.stack([x, x])) - self.assertEquals(dtypes.int64, dataset.output_types) - self.assertEquals([2, 3], dataset.output_shapes) - - dataset = dataset.flat_map( - lambda x: dataset_ops.Dataset.from_tensor_slices(x)) - self.assertEquals(dtypes.int64, dataset.output_types) - self.assertEquals([3], dataset.output_shapes) - - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - self.assertEquals(dtypes.int64, get_next.dtype) - self.assertEquals([3], get_next.shape) - - def _testFromGenerator(self, generator, elem_sequence, num_repeats): - iterator = ( - dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64) - .repeat(num_repeats) - .prefetch(5) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - for _ in range(2): # Run twice to test reinitialization. - sess.run(init_op) - for _ in range(num_repeats): - for elem in elem_sequence: - self.assertAllEqual(elem, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats): - iterator = ( - dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64) - .repeat(num_repeats) - .prefetch(5) - .make_one_shot_iterator()) - get_next = iterator.get_next() - - with self.test_session() as sess: - for _ in range(num_repeats): - for elem in elem_sequence: - self.assertAllEqual(elem, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromGeneratorUsingFunction(self): - def generator(): - for i in range(1, 100): - yield [i] * i - elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) - self._testFromGeneratorOneShot(generator, elem_sequence, 1) - self._testFromGeneratorOneShot(generator, elem_sequence, 5) - - def testFromGeneratorUsingList(self): - generator = lambda: [[i] * i for i in range(1, 100)] - elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) - - def testFromGeneratorUsingNdarray(self): - generator = lambda: np.arange(100, dtype=np.int64) - elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) - - def testFromGeneratorUsingGeneratorExpression(self): - # NOTE(mrry): Generator *expressions* are not repeatable (or in - # general reusable), because they eagerly evaluate the `for` - # expression as `iter(range(1, 100))` and discard the means of - # reconstructing `range(1, 100)`. Wrapping the generator - # expression in a `lambda` makes it repeatable. - generator = lambda: ([i] * i for i in range(1, 100)) - elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) - - def testFromMultipleConcurrentGenerators(self): - num_inner_repeats = 5 - num_outer_repeats = 100 - - def generator(): - for i in range(1, 10): - yield ([i] * i, [i, i ** 2, i ** 3]) - input_list = list(generator()) - - # The interleave transformation is essentially a flat map that - # draws from multiple input datasets concurrently (in a cyclic - # fashion). By placing `Datsaet.from_generator()` inside an - # interleave, we test its behavior when multiple iterators are - # active at the same time; by additionally prefetching inside the - # interleave, we create the possibility of parallel (modulo GIL) - # invocations to several iterators created by the same dataset. - def interleave_fn(_): - return (dataset_ops.Dataset.from_generator( - generator, output_types=(dtypes.int64, dtypes.int64), - output_shapes=([None], [3])) - .repeat(num_inner_repeats).prefetch(5)) - - iterator = ( - dataset_ops.Dataset.range(num_outer_repeats) - .interleave(interleave_fn, cycle_length=10, - block_length=len(input_list)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(num_inner_repeats * num_outer_repeats): - for elem in input_list: - val0, val1 = sess.run(get_next) - self.assertAllEqual(elem[0], val0) - self.assertAllEqual(elem[1], val1) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromGeneratorsRunningInParallel(self): - num_parallel_iterators = 3 - - # Define shared state that multiple iterator instances will access to - # demonstrate their concurrent activity. - lock = threading.Lock() - condition = threading.Condition(lock) - next_ticket = [0] # GUARDED_BY(lock) - - def generator(): - # NOTE(mrry): We yield one element before the barrier, because - # the current implementation of `Dataset.interleave()` must - # fetch one element from each incoming dataset to start the - # prefetching. - yield 0 - - # Define a barrier that `num_parallel_iterators` iterators must enter - # before any can proceed. Demonstrates that multiple iterators may be - # active at the same time. - condition.acquire() - ticket = next_ticket[0] - next_ticket[0] += 1 - if ticket == num_parallel_iterators - 1: - # The last iterator to join the barrier notifies the others. - condition.notify_all() - else: - # Wait until the last iterator enters the barrier. - while next_ticket[0] < num_parallel_iterators: - condition.wait() - condition.release() - - yield 1 - - # As in `testFromMultipleConcurrentGenerators()`, we use a combination of - # `Dataset.interleave()` and `Dataset.prefetch()` to cause multiple - # iterators to be active concurrently. - def interleave_fn(_): - return dataset_ops.Dataset.from_generator( - generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2) - - iterator = ( - dataset_ops.Dataset.range(num_parallel_iterators) - .interleave( - interleave_fn, cycle_length=num_parallel_iterators, block_length=1) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for elem in [0, 1]: - for _ in range(num_parallel_iterators): - self.assertAllEqual(elem, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromGeneratorImplicitConversion(self): - def generator(): - yield [1] - yield [2] - yield [3] - - for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]: - iterator = (dataset_ops.Dataset.from_generator( - generator, output_types=dtype, output_shapes=[1]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual(dtype, get_next.dtype) - - with self.test_session() as sess: - sess.run(init_op) - for expected in [[1], [2], [3]]: - next_val = sess.run(get_next) - self.assertEqual(dtype.as_numpy_dtype, next_val.dtype) - self.assertAllEqual(expected, next_val) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromGeneratorTypeError(self): - def generator(): - yield np.array([1, 2, 3], dtype=np.int64) - yield np.array([4, 5, 6], dtype=np.int64) - yield "ERROR" - yield np.array([7, 8, 9], dtype=np.int64) - - iterator = (dataset_ops.Dataset.from_generator( - generator, output_types=dtypes.int64, output_shapes=[3]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - self.assertAllEqual([1, 2, 3], sess.run(get_next)) - self.assertAllEqual([4, 5, 6], sess.run(get_next)) - with self.assertRaisesOpError(r"invalid literal for long\(\)"): - sess.run(get_next) - self.assertAllEqual([7, 8, 9], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromGeneratorShapeError(self): - def generator(): - yield np.array([1, 2, 3], dtype=np.int64) - yield np.array([4, 5, 6], dtype=np.int64) - yield np.array([7, 8, 9, 10], dtype=np.int64) - yield np.array([11, 12, 13], dtype=np.int64) - - iterator = (dataset_ops.Dataset.from_generator( - generator, output_types=dtypes.int64, output_shapes=[3]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - self.assertAllEqual([1, 2, 3], sess.run(get_next)) - self.assertAllEqual([4, 5, 6], sess.run(get_next)) - with self.assertRaisesOpError(r"element of shape \(3,\) was expected"): - sess.run(get_next) - self.assertAllEqual([11, 12, 13], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSplitPipelineFailsWithPlacementError(self): - with session.Session( - target="", - config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: - - dataset = dataset_ops.Dataset.from_tensors(0) - - # Define a pipeline that attempts to use variables on two - # different devices. - # - # Initialize the variables before creating to iterator, to avoid the - # placement algorithm overriding the DT_RESOURCE colocation constraints. - with ops.device("/cpu:0"): - var_0 = resource_variable_ops.ResourceVariable(initial_value=0) - dataset = dataset.map(lambda x: x + var_0.read_value()) - sess.run(var_0.initializer) - - with ops.device("/cpu:1"): - var_1 = resource_variable_ops.ResourceVariable(initial_value=0) - dataset = dataset.map(lambda x: x + var_1.read_value()) - sess.run(var_1.initializer) - - iterator = dataset.make_initializable_iterator() - sess.run(iterator.initializer) - - with self.assertRaisesRegexp( - errors.FailedPreconditionError, - "Error while reading resource variable Variable"): - sess.run(iterator.get_next()) - def testRestructureDataset(self): components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder(dtypes.int32, shape=[None]), diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index 7cde6e05b244773966fd7c1bd4ca1e95abf7fd5e..dbc35097ddda9f0375060d43aeb43efa8107f929 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -24,9 +24,11 @@ import numpy as np from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -34,14 +36,29 @@ from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest +def remove_variants(get_next_op): + # TODO(b/72408568): Remove this once session.run can get + # variant tensors. + """Remove variants from a nest structure, so sess.run will execute.""" + + def _remove_variant(x): + if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant: + return () + else: + return x + + return nest.map_structure(_remove_variant, get_next_op) + + class DatasetSerializationTestBase(test.TestCase): """Base class for testing serializable datasets.""" def tearDown(self): self._delete_ckpt() - # TODO(b/70988345): Support native `tf.SparseTensor` objects and get rid of - # `sparse_tensors` argument. + # TODO(b/72657739): Remove sparse_tensor argument, which is to test the + # (deprecated) saveable `SparseTensorSliceDataset`, once the API + # `from_sparse_tensor_slices()`and related tests are deleted. def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): """Runs the core tests. @@ -233,10 +250,10 @@ class DatasetSerializationTestBase(test.TestCase): saver = self._import_meta_graph() init_op, get_next_op = self._get_iterator_ops_from_collection( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: self._restore(saver, sess) - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) for _ in range(num_outputs): actual.append(sess.run(get_next_op)) if verify_exhausted: @@ -296,6 +313,7 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default() as g: _, get_next_op, saver = self._build_graph( ds_fn2, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: self._restore(saver, sess) for _ in range(num_outputs - break_point): @@ -356,6 +374,7 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default() as g: get_next_op, saver = self._build_empty_graph( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: self._restore(saver, sess) for _ in range(num_outputs - break_point): @@ -389,9 +408,9 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default() as g: init_op, get_next_op, saver = self._build_graph( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) for _ in range(break_point): sess.run(get_next_op) with self.assertRaises(error): @@ -485,20 +504,20 @@ class DatasetSerializationTestBase(test.TestCase): else: init_op, get_next_op, saver = self._build_graph( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) return init_op, get_next_op, saver for i in range(len(break_points) + 1): with ops.Graph().as_default() as g: init_op, get_next_op, saver = get_ops() + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: if ckpt_saved: if init_before_restore: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) self._restore(saver, sess) else: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) start = break_points[i - 1] if i > 0 else 0 end = break_points[i] if i < len(break_points) else num_outputs num_iters = end - start @@ -562,13 +581,16 @@ class DatasetSerializationTestBase(test.TestCase): get_next = sparse_tensor.SparseTensor(*iterator.get_next()) else: get_next = iterator.get_next() - self._add_iterator_ops_to_collection(init_op, get_next, sparse_tensors) + self._add_iterator_ops_to_collection(init_op, get_next, ds_fn, + sparse_tensors) saver = saver_lib.Saver(allow_empty=True) return init_op, get_next, saver def _build_empty_graph(self, ds_fn, sparse_tensors=False): iterator = iterator_ops.Iterator.from_structure( - self._get_output_types(ds_fn), self._get_output_shapes(ds_fn)) + self._get_output_types(ds_fn), + output_shapes=self._get_output_shapes(ds_fn), + output_classes=self._get_output_classes(ds_fn)) saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) if sparse_tensors: @@ -581,12 +603,19 @@ class DatasetSerializationTestBase(test.TestCase): def _add_iterator_ops_to_collection(self, init_op, get_next, + ds_fn, sparse_tensors=False): ops.add_to_collection("iterator_ops", init_op) # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections # do not support tuples we flatten the tensors and restore the shape in # `_get_iterator_ops_from_collection`. - if sparse_tensors: + + # TODO(shivaniagrwal): `output_classes` is a nested structure of classes, + # this base class is specific to current test cases. Update when tests are + # added with `output_classes` as a nested structure with at least one of the + # component being `tf.SparseTensor`. + if (sparse_tensors or + self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): ops.add_to_collection("iterator_ops", get_next.indices) ops.add_to_collection("iterator_ops", get_next.values) ops.add_to_collection("iterator_ops", get_next.dense_shape) @@ -596,7 +625,8 @@ class DatasetSerializationTestBase(test.TestCase): def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): all_ops = ops.get_collection("iterator_ops") - if sparse_tensors: + if (sparse_tensors or + self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): init_op, indices, values, dense_shape = all_ops return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) else: @@ -611,6 +641,10 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default(): return ds_fn().output_shapes + def _get_output_classes(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_classes + def _ckpt_path(self): return os.path.join(self.get_temp_dir(), "iterator") @@ -621,8 +655,14 @@ class DatasetSerializationTestBase(test.TestCase): saver.save(sess, self._ckpt_path()) def _restore(self, saver, sess): + sess.run(lookup_ops.tables_initializer()) saver.restore(sess, self._latest_ckpt()) + def _initialize(self, init_op, sess): + sess.run(variables.global_variables_initializer()) + sess.run(lookup_ops.tables_initializer()) + sess.run(init_op) + def _import_meta_graph(self): meta_file_path = self._ckpt_path() + ".meta" return saver_lib.import_meta_graph(meta_file_path) diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index 5921be2ae89ba1bbbb8d6e3a509cf49c65949544..b572d6ed770fc0fe0f852359baf343c55966eddd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -20,144 +20,12 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class FilterDatasetTest(test.TestCase): - - def testFilterDataset(self): - components = ( - np.arange(7, dtype=np.int64), - np.array([[1, 2, 3]], dtype=np.int64) * np.arange( - 7, dtype=np.int64)[:, np.newaxis], - np.array(37.0, dtype=np.float64) * np.arange(7) - ) - count = array_ops.placeholder(dtypes.int64, shape=[]) - modulus = array_ops.placeholder(dtypes.int64) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count) - .filter(lambda x, _y, _z: math_ops.equal(math_ops.mod(x, modulus), 0)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - # Test that we can dynamically feed a different modulus value for each - # iterator. - def do_test(count_val, modulus_val): - sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val}) - for _ in range(count_val): - for i in [x for x in range(7) if x**2 % modulus_val == 0]: - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - do_test(14, 2) - do_test(4, 18) - - # Test an empty dataset. - do_test(0, 1) - - def testFilterRange(self): - dataset = dataset_ops.Dataset.range(100).filter( - lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2)) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - self.assertEqual(0, sess.run(get_next)) - self.assertEqual(1, sess.run(get_next)) - self.assertEqual(3, sess.run(get_next)) - - def testFilterDict(self): - iterator = (dataset_ops.Dataset.range(10) - .map(lambda x: {"foo": x * 2, "bar": x ** 2}) - .filter(lambda d: math_ops.equal(d["bar"] % 2, 0)) - .map(lambda d: d["foo"] + d["bar"]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - if (i ** 2) % 2 == 0: - self.assertEqual(i * 2 + i ** 2, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testUseStepContainerInFilter(self): - input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) - - # Define a predicate that returns true for the first element of - # the sequence and not the second, and uses `tf.map_fn()`. - def _predicate(xs): - squared_xs = functional_ops.map_fn(lambda x: x * x, xs) - summed = math_ops.reduce_sum(squared_xs) - return math_ops.equal(summed, 1 + 4 + 9) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]]) - .filter(_predicate) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - self.assertAllEqual(input_data[0], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - - def testSparse(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1])), - dense_shape=np.array([1, 1])), i - - def _filter_fn(_, i): - return math_ops.equal(i % 2, 0) - - iterator = ( - dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( - lambda x, i: x).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(5): - actual = sess.run(get_next) - self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) - self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - class FilterDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): @@ -194,6 +62,10 @@ class FilterDatasetSerializationTest( return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( lambda x, i: x) + def testSparseCore(self): + num_outputs = 5 + self.run_core_tests(self._build_sparse_filter, None, num_outputs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py index d4fbaa5cdcdd315aa0524134b48eb0515169722c..f3feecef32e587045be25056815315136a883ca7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py @@ -17,13 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import random - -import numpy as np - from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -34,124 +29,6 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -from tensorflow.python.training import server_lib - - -class FlatMapDatasetTest(test.TestCase): - - # pylint: disable=g-long-lambda - def testFlatMapDataset(self): - repeats = [1, 2, 3, 4, 5, 0, 1] - components = np.array(repeats, dtype=np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .flat_map(lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in repeats: - for _ in range(i): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testNestedFlatMapDataset(self): - repeats = [[1, 2], [3, 4], [5, 0], [1, 7]] - components = np.array(repeats, dtype=np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x) - .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y) - .repeat(y))).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for row in repeats: - for i in row: - for _ in range(i): - self.assertEqual(i, sess.run(get_next)) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSharedResourceNestedFlatMapDataset(self): - repeats = [[1, 2], [3, 4], [5, 0], [1, 7]] - components = np.array(repeats, dtype=np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x) - .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y) - .repeat(y))).make_initializable_iterator( - shared_name="shared_flat_map_iterator")) - init_op = iterator.initializer - get_next = iterator.get_next() - - # Create two concurrent sessions that share the same iterator - # resource on the same server, and verify that a random - # interleaving of `Session.run(get_next)` calls on the two - # sessions yields the expected result. - server = server_lib.Server.create_local_server() - with session.Session(server.target) as sess1: - with session.Session(server.target) as sess2: - for _ in range(3): - sess = random.choice([sess1, sess2]) - sess.run(init_op) - for row in repeats: - for i in row: - for _ in range(i): - sess = random.choice([sess1, sess2]) - self.assertEqual(i, sess.run(get_next)) - - with self.assertRaises(errors.OutOfRangeError): - sess = random.choice([sess1, sess2]) - sess.run(get_next) - - def testMapDict(self): - iterator = (dataset_ops.Dataset.range(10) - .map(lambda x: {"foo": x * 2, "bar": x ** 2}) - .flat_map(lambda d: dataset_ops.Dataset.from_tensors(d["foo"]) - .repeat(d["bar"])) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - for _ in range(i ** 2): - self.assertEqual(i * 2, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - # pylint: enable=g-long-lambda - - def testSparse(self): - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _flat_map_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - iterator = ( - dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - for j in range(2): - expected = [i, 0] if j % 2 == 0 else [0, -i] - self.assertAllEqual(expected, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) class FlatMapDatasetSerializationTest( @@ -225,6 +102,21 @@ class FlatMapDatasetSerializationTest( self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError) + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _flat_map_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_ds(): + return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) + + self.run_core_tests(_build_ds, None, 20) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py new file mode 100644 index 0000000000000000000000000000000000000000..32ea44f7c7ba329dc253bb9fbbcac0a1ed16aec7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -0,0 +1,58 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import get_single_element +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class GetSingleElementTest(test.TestCase): + + def testGetSingleElement(self): + skip_value = array_ops.placeholder(dtypes.int64, shape=[]) + take_value = array_ops.placeholder_with_default( + constant_op.constant(1, dtype=dtypes.int64), shape=[]) + + dataset = (dataset_ops.Dataset.range(100) + .skip(skip_value) + .map(lambda x: x * x) + .take(take_value)) + + element = get_single_element.get_single_element(dataset) + + with self.test_session() as sess: + self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) + self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) + self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset was empty."): + sess.run(element, feed_dict={skip_value: 100}) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset had more than one element."): + sess.run(element, feed_dict={skip_value: 0, take_value: 2}) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index b1937c08f347734d0d6871bd30ed209ff520623a..256ad8d94dc1a7c2b26df3f1ebf8e8e321882c15 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -26,8 +26,8 @@ import numpy as np from six.moves import zip_longest from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor @@ -38,182 +38,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetTest(test.TestCase): - - def _interleave(self, lists, cycle_length, block_length): - # TODO(b/69678297): Consolidate python interleave implementations. - num_open = 0 - - # `all_iterators` acts as a queue of iterators over each element of `lists`. - all_iterators = [iter(l) for l in lists] - - # `open_iterators` are the iterators whose elements are currently being - # interleaved. - open_iterators = [] - for i in range(cycle_length): - if all_iterators: - open_iterators.append(all_iterators.pop(0)) - num_open += 1 - else: - open_iterators.append(None) - - while num_open or all_iterators: - for i in range(cycle_length): - if open_iterators[i] is None: - if all_iterators: - open_iterators[i] = all_iterators.pop(0) - num_open += 1 - else: - continue - for _ in range(block_length): - try: - yield next(open_iterators[i]) - except StopIteration: - open_iterators[i] = None - num_open -= 1 - break - - def testPythonImplementation(self): - input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], - [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] - - # Cycle length 1 acts like `Dataset.flat_map()`. - expected_elements = itertools.chain(*input_lists) - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 1, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1. - expected_elements = [4, 5, 4, 5, 4, 5, 4, - 5, 5, 6, 6, # NOTE(mrry): When we cycle back - # to a list and are already at - # the end of that list, we move - # on to the next element. - 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1 and block length > 1. - expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, - 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 3)): - self.assertEqual(expected, produced) - - # Cycle length > len(input_values). - expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, - 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 7, 2)): - self.assertEqual(expected, produced) - - def testInterleaveDataset(self): - input_values = array_ops.placeholder(dtypes.int64, shape=[None]) - cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) - block_length = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_count = 2 - - dataset = ( - dataset_ops.Dataset.from_tensor_slices(input_values) - .repeat(repeat_count) - .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - next_element = iterator.get_next() - - with self.test_session() as sess: - # Cycle length 1 acts like `Dataset.flat_map()`. - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 1, block_length: 3}) - - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): - self.assertEqual(expected_element, sess.run(next_element)) - - # Cycle length > 1. - # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, - # 6, 5, 6, 5, 6, 5, 6, 5] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 1}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > 1 and block length > 1. - # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, - # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > len(input_values) * repeat_count. - # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, - # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 7, block_length: 2}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Empty input. - sess.run(init_op, feed_dict={input_values: [], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Non-empty input leading to empty output. - sess.run(init_op, feed_dict={input_values: [0, 0, 0], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Mixture of non-empty and empty interleaved datasets. - sess.run(init_op, feed_dict={input_values: [4, 0, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testSparse(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _interleave_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - iterator = ( - dataset_ops.Dataset.range(10).map(_map_fn).interleave( - _interleave_fn, cycle_length=1).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - for j in range(2): - expected = [i, 0] if j % 2 == 0 else [0, -i] - self.assertAllEqual(expected, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - -class InterleaveDatasetSeriazationTest( +class InterleaveDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_iterator_graph(self, input_values, cycle_length, block_length): @@ -252,6 +77,22 @@ class InterleaveDatasetSeriazationTest( None, num_outputs) # pylint: enable=g-long-lambda + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1) + + self.run_core_tests(_build_dataset, None, 20) + class ParallelInterleaveDatasetTest(test.TestCase): diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py deleted file mode 100644 index 02379d064d4ab857ce9c7d13881a3ae37eea0980..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops that need test_util.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.platform import test - - -class IteratorClusterTest(test.TestCase): - - def testRemoteIteratorWithoutRemoteCallFail(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 - worker, _ = test_util.create_local_cluster( - 1, 1, worker_config=worker_config) - - with ops.device("/job:worker/replica:0/task:0/cpu:1"): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - with ops.device("/job:worker/replica:0/task:0/cpu:0"): - remote_it = iterator_ops.Iterator.from_string_handle( - iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes) - get_next_op = remote_it.get_next() - - with session.Session(worker[0].target) as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next_op) - - def _testRemoteIteratorHelper(self, device0, device1, target): - with ops.device(device1): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - @function.Defun(dtypes.string) - def _remote_fn(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, dataset_3.output_types, dataset_3.output_shapes) - return remote_iterator.get_next() - - with ops.device(device0): - target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - remote_op = functional_ops.remote_call( - args=[iterator_3_handle], - Tout=[dtypes.int32], - f=_remote_fn, - target=target_placeholder) - - with session.Session(target) as sess: - elem = sess.run(remote_op, feed_dict={target_placeholder: device1}) - self.assertEqual(elem, [1]) - # Fails when target is cpu:0 where the resource is not located. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(remote_op, feed_dict={target_placeholder: device0}) - elem = sess.run(iterator_3.get_next()) - self.assertEqual(elem, [2]) - elem = sess.run(remote_op, feed_dict={target_placeholder: device1}) - self.assertEqual(elem, [3]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(remote_op, feed_dict={target_placeholder: device1}) - - def testRemoteIteratorUsingRemoteCallOp(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 - worker, _ = test_util.create_local_cluster( - 1, 1, worker_config=worker_config) - - self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0", - "/job:worker/replica:0/task:0/cpu:1", - worker[0].target) - - def testRemoteIteratorUsingRemoteCallOpCrossProcess(self): - workers, _ = test_util.create_local_cluster(2, 1) - - self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0", - "/job:worker/replica:0/task:1/cpu:0", - workers[0].target) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py deleted file mode 100644 index bda9a2a4a37e9c3d35ff99041d1150ffc43f4c43..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import numpy as np - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import readers -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import script_ops -from tensorflow.python.platform import test -from tensorflow.python.training import server_lib - - -class IteratorTest(test.TestCase): - - def testAttemptingGradientsRaiseExceptions(self): - component = constant_op.constant([1]) - side = constant_op.constant(0) - add = lambda x: x + side - dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add) - value = dataset.make_one_shot_iterator().get_next() - with self.assertRaisesRegexp(LookupError, "No gradient defined"): - gradients_impl.gradients(value, component) - with self.assertRaisesRegexp(LookupError, "No gradient defined"): - gradients_impl.gradients(value, side) - with self.assertRaisesRegexp(LookupError, "No gradient defined"): - gradients_impl.gradients(value, [component, side]) - - def testOneShotIterator(self): - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(14).make_one_shot_iterator()) - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOneShotIteratorCaptureByValue(self): - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - tensor_components = tuple([ops.convert_to_tensor(c) for c in components]) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = (dataset_ops.Dataset.from_tensor_slices(tensor_components) - .map(_map_fn).repeat(14).make_one_shot_iterator()) - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOneShotIteratorInsideContainer(self): - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - - def within_container(): - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(_map_fn).repeat(14).make_one_shot_iterator()) - return iterator.get_next() - - server = server_lib.Server.create_local_server() - - # Create two iterators within unique containers, and run them to - # make sure that the resources aren't shared. - # - # The test below would fail if cname were the same across both - # sessions. - for i in range(2): - with session.Session(server.target) as sess: - cname = "iteration%d" % i - with ops.container(cname): - get_next = within_container() - - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOneShotIteratorNonBlocking(self): - dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - # Create a session with a single thread to ensure that the - # one-shot iterator initializer does not deadlock. - config = config_pb2.ConfigProto(inter_op_parallelism_threads=1, - use_per_session_threads=True) - with session.Session(config=config) as sess: - self.assertAllEqual([1, 4, 9], sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Test with multiple threads invoking the one-shot iterator concurrently. - with session.Session(config=config) as sess: - results = [] - def consumer_thread(): - try: - results.append(sess.run(next_element)) - except errors.OutOfRangeError: - results.append(None) - - num_threads = 8 - threads = [ - self.checkedThread(consumer_thread) for _ in range(num_threads)] - for t in threads: - t.start() - for t in threads: - t.join() - - self.assertEqual(num_threads, len(results)) - self.assertEqual(num_threads - 1, - len([None for r in results if r is None])) - self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None]) - - def testOneShotIteratorInitializerFails(self): - # Define a dataset whose initialization will always fail. - dataset = dataset_ops.Dataset.from_tensors( - array_ops.check_numerics( - constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): - sess.run(next_element) - - # Test that subsequent attempts to use the iterator also fail. - with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): - sess.run(next_element) - - with self.test_session() as sess: - def consumer_thread(): - with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): - sess.run(next_element) - - num_threads = 8 - threads = [ - self.checkedThread(consumer_thread) for _ in range(num_threads)] - for t in threads: - t.start() - for t in threads: - t.join() - - def testSimpleSharedResource(self): - components = ( - np.array(1, dtype=np.int64), - np.array([1, 2, 3], dtype=np.int64), - np.array(37.0, dtype=np.float64) - ) - - server = server_lib.Server.create_local_server() - - # Create two non-overlapping sessions that share the same iterator - # resource on the same server, and verify that an action of the - # first session (initializing the iterator) is visible in the - # second session. - with ops.Graph().as_default(): - iterator = (dataset_ops.Dataset.from_tensors(components) - .map(lambda x, y, z: (x, y, z)).make_initializable_iterator( - shared_name="shared_iterator")) - init_op = iterator.initializer - get_next = iterator.get_next() - - with session.Session(server.target) as sess: - sess.run(init_op) - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Re-initialize the iterator in the first session. - sess.run(init_op) - - with ops.Graph().as_default(): - # Re-define the iterator manually, without defining any of the - # functions in this graph, to ensure that we are not - # accidentally redefining functions with the same names in the - # new graph. - iterator = iterator_ops.Iterator.from_structure( - shared_name="shared_iterator", - output_types=(dtypes.int64, dtypes.int64, dtypes.float64), - output_shapes=([], [3], [])) - get_next = iterator.get_next() - - with session.Session(server.target) as sess: - # Use the iterator without re-initializing in the second session. - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testNotInitializedError(self): - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - iterator = (dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.FailedPreconditionError, - "iterator has not been initialized"): - sess.run(get_next) - - def testReinitializableIterator(self): - dataset_3 = dataset_ops.Dataset.from_tensors( - constant_op.constant([1, 2, 3])) - dataset_4 = dataset_ops.Dataset.from_tensors( - constant_op.constant([4, 5, 6, 7])) - iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types, - [None]) - - dataset_3_init_op = iterator.make_initializer(dataset_3) - dataset_4_init_op = iterator.make_initializer(dataset_4) - get_next = iterator.get_next() - - self.assertEqual(dataset_3.output_types, iterator.output_types) - self.assertEqual(dataset_4.output_types, iterator.output_types) - self.assertEqual([None], iterator.output_shapes.as_list()) - - with self.test_session() as sess: - # The iterator is initially uninitialized. - with self.assertRaises(errors.FailedPreconditionError): - sess.run(get_next) - - # Initialize with one dataset. - sess.run(dataset_3_init_op) - self.assertAllEqual([1, 2, 3], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Initialize with a different dataset. - sess.run(dataset_4_init_op) - self.assertAllEqual([4, 5, 6, 7], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Reinitialize with the first dataset. - sess.run(dataset_3_init_op) - self.assertAllEqual([1, 2, 3], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testReinitializableIteratorStaticErrors(self): - # Non-matching structure for types and shapes. - with self.assertRaises(TypeError): - iterator = iterator_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64), [None]) - - # Test validation of dataset argument. - iterator = iterator_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64)) - - # Incompatible structure. - with self.assertRaises(ValueError): - iterator.make_initializer( - dataset_ops.Dataset.from_tensors(((constant_op.constant( - [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float64),)))) - - # Incompatible types. - with self.assertRaises(TypeError): - iterator.make_initializer( - dataset_ops.Dataset.from_tensors((constant_op.constant( - [1, 2, 3], dtype=dtypes.int32), constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float32)))) - - # Incompatible shapes. - iterator = iterator_ops.Iterator.from_structure( - (dtypes.int64, dtypes.float64), ([None], [])) - with self.assertRaises(TypeError): - iterator.make_initializer( - dataset_ops.Dataset.from_tensors((constant_op.constant( - [1, 2, 3], dtype=dtypes.int64), constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float64)))) - - def testIteratorStringHandle(self): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) - - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_4 = dataset_4.make_one_shot_iterator() - - handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - feedable_iterator = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dataset_3.output_types, dataset_3.output_shapes) - next_element = feedable_iterator.get_next() - - self.assertEqual(dataset_3.output_types, feedable_iterator.output_types) - self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) - self.assertEqual([], feedable_iterator.output_shapes) - - with self.test_session() as sess: - iterator_3_handle = sess.run(iterator_3.string_handle()) - iterator_4_handle = sess.run(iterator_4.string_handle()) - - self.assertEqual( - 10, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 1, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 20, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 2, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 30, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 3, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 40, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle}) - - def testIteratorStringHandleError(self): - dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices([1, 2, - 3]).repeat()) - dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])) - - handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - - feedable_int_scalar = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dtypes.int32, []) - feedable_int_vector = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dtypes.int32, [None]) - feedable_int_any = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dtypes.int32) - - with self.test_session() as sess: - handle_int_scalar = sess.run( - dataset_int_scalar.make_one_shot_iterator().string_handle()) - handle_float_vector = sess.run( - dataset_float_vector.make_one_shot_iterator().string_handle()) - - self.assertEqual(1, - sess.run( - feedable_int_scalar.get_next(), - feed_dict={handle_placeholder: handle_int_scalar})) - - self.assertEqual(2, - sess.run( - feedable_int_any.get_next(), - feed_dict={handle_placeholder: handle_int_scalar})) - - with self.assertRaises(errors.InvalidArgumentError): - print(sess.run( - feedable_int_vector.get_next(), - feed_dict={handle_placeholder: handle_int_scalar})) - - with self.assertRaises(errors.InvalidArgumentError): - print(sess.run( - feedable_int_vector.get_next(), - feed_dict={handle_placeholder: handle_float_vector})) - - def testRemoteIteratorUsingRemoteCallOpDirectSession(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 3 - - with ops.device("/job:localhost/replica:0/task:0/cpu:1"): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - @function.Defun(dtypes.string) - def _remote_fn(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, dataset_3.output_types, dataset_3.output_shapes) - return remote_iterator.get_next() - - with ops.device("/job:localhost/replica:0/task:0/cpu:0"): - target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - remote_op = functional_ops.remote_call( - args=[iterator_3_handle], - Tout=[dtypes.int32], - f=_remote_fn, - target=target_placeholder) - - with self.test_session(config=worker_config) as sess: - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - self.assertEqual(elem, [1]) - # Fails when target is cpu:2 where the resource is not located. - with self.assertRaises(errors.InvalidArgumentError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:2" - }) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - self.assertEqual(elem, [2]) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - self.assertEqual(elem, [3]) - with self.assertRaises(errors.OutOfRangeError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - - def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - with ops.device("/job:localhost/replica:0/task:0/cpu:0"): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - def _encode_raw(byte_array): - return bytes(bytearray(byte_array)) - - @function.Defun(dtypes.uint8) - def _remote_fn(h): - handle = script_ops.py_func(_encode_raw, [h], dtypes.string) - remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, dataset_3.output_types, dataset_3.output_shapes) - return remote_iterator.get_next() - - with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): - target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - iterator_3_handle_uint8 = parsing_ops.decode_raw( - bytes=iterator_3_handle, out_type=dtypes.uint8) - remote_op = functional_ops.remote_call( - args=[iterator_3_handle_uint8], - Tout=[dtypes.int32], - f=_remote_fn, - target=target_placeholder) - - with self.test_session() as sess: - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - self.assertEqual(elem, [1]) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - self.assertEqual(elem, [2]) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - self.assertEqual(elem, [3]) - with self.assertRaises(errors.OutOfRangeError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - - def testIncorrectIteratorRestore(self): - - def _path(): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - _path(), parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(_path()), dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def _build_range_dataset_graph(): - start = 1 - stop = 10 - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = _save_op(iterator._iterator_resource) - restore_op = _restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - def _build_reader_dataset_graph(): - filenames = ["test"] # Does not exist but we don't care in this test. - iterator = readers.FixedLengthRecordDataset( - filenames, 1, 0, 0).make_initializable_iterator() - init_op = iterator.initializer - get_next_op = iterator.get_next() - save_op = _save_op(iterator._iterator_resource) - restore_op = _restore_op(iterator._iterator_resource) - return init_op, get_next_op, save_op, restore_op - - # Saving iterator for RangeDataset graph. - with ops.Graph().as_default() as g: - init_op, _, save_op, _ = _build_range_dataset_graph() - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(save_op) - - # Attempt to restore the saved iterator into an IteratorResource of - # incompatible type. An iterator of RangeDataset has output type int64, - # while an iterator of FixedLengthRecordDataset has output type string. - # So an InvalidArgumentError should be raised by - # IteratorResource::set_iterator. - with ops.Graph().as_default() as g: - _, _, _, restore_op = _build_reader_dataset_graph() - with self.test_session(graph=g) as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(restore_op) - - def testToSingleElement(self): - skip_value = array_ops.placeholder(dtypes.int64, shape=[]) - take_value = array_ops.placeholder_with_default( - constant_op.constant(1, dtype=dtypes.int64), shape=[]) - - dataset = (dataset_ops.Dataset.range(100) - .skip(skip_value) - .map(lambda x: x * x) - .take(take_value)) - - element = dataset_ops.get_single_element(dataset) - - with self.test_session() as sess: - self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) - self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) - self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dataset was empty."): - sess.run(element, feed_dict={skip_value: 100}) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dataset had more than one element."): - sess.run(element, feed_dict={skip_value: 0, take_value: 2}) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py deleted file mode 100644 index 27298de65f90c627e5eb638385bfe0478ef74fca..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from os import path -import shutil -import tempfile - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test -from tensorflow.python.util import compat - - -class ListFilesDatasetOpTest(test.TestCase): - - def setUp(self): - self.tmp_dir = tempfile.mkdtemp() - - def tearDown(self): - shutil.rmtree(self.tmp_dir, ignore_errors=True) - - def _touchTempFiles(self, filenames): - for filename in filenames: - open(path.join(self.tmp_dir, filename), 'a').close() - - def testEmptyDirectory(self): - dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) - with self.test_session() as sess: - itr = dataset.make_one_shot_iterator() - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - def testSimpleDirectory(self): - filenames = ['a', 'b', 'c'] - self._touchTempFiles(filenames) - - dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) - with self.test_session() as sess: - itr = dataset.make_one_shot_iterator() - - full_filenames = [] - produced_filenames = [] - for filename in filenames: - full_filenames.append( - compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) - self.assertItemsEqual(full_filenames, produced_filenames) - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - def testEmptyDirectoryInitializer(self): - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - dataset = dataset_ops.Dataset.list_files(filename_placeholder) - - with self.test_session() as sess: - itr = dataset.make_initializable_iterator() - sess.run( - itr.initializer, - feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - def testSimpleDirectoryInitializer(self): - filenames = ['a', 'b', 'c'] - self._touchTempFiles(filenames) - - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - dataset = dataset_ops.Dataset.list_files(filename_placeholder) - - with self.test_session() as sess: - itr = dataset.make_initializable_iterator() - sess.run( - itr.initializer, - feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) - - full_filenames = [] - produced_filenames = [] - for filename in filenames: - full_filenames.append( - compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) - - self.assertItemsEqual(full_filenames, produced_filenames) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - def testFileSuffixes(self): - filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc'] - self._touchTempFiles(filenames) - - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - dataset = dataset_ops.Dataset.list_files(filename_placeholder) - - with self.test_session() as sess: - itr = dataset.make_initializable_iterator() - sess.run( - itr.initializer, - feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')}) - - full_filenames = [] - produced_filenames = [] - for filename in filenames[1:-1]: - full_filenames.append( - compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) - self.assertItemsEqual(full_filenames, produced_filenames) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - def testFileMiddles(self): - filenames = ['a.txt', 'b.py', 'c.pyc'] - self._touchTempFiles(filenames) - - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - dataset = dataset_ops.Dataset.list_files(filename_placeholder) - - with self.test_session() as sess: - itr = dataset.make_initializable_iterator() - sess.run( - itr.initializer, - feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')}) - - full_filenames = [] - produced_filenames = [] - for filename in filenames[1:]: - full_filenames.append( - compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) - - self.assertItemsEqual(full_filenames, produced_filenames) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index dd8247bfd47a9880c7cfe905103702e43b1f2165..8d4042927970cab2f5a518fc0da49b38444dbcdf 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -16,15 +16,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from collections import namedtuple import os -import threading import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops from tensorflow.contrib.data.python.ops import error_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -34,15 +31,9 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import io_ops -from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import script_ops -from tensorflow.python.ops import sparse_ops -from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -50,231 +41,11 @@ from tensorflow.python.util import compat class MapDatasetTest(test.TestCase): - def _buildMapDataset(self, components, count): - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return ( - contrib_dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count)) - - def testMapDataset(self): - """Test an dataset that maps a TF function across its input elements.""" - # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> - # RepeatDataset(count). - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - count = array_ops.placeholder(dtypes.int64, shape=[]) - - dataset = self._buildMapDataset(components, count) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - # Test single-threaded access to the iterator. - sess.run(init_op, feed_dict={count: 14}) - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test multi-threaded access to the same iterator. - sess.run(init_op, feed_dict={count: 18}) - results = [] - def iterator_thread(): - while True: - try: - results.append(sess.run(get_next)) - except errors.OutOfRangeError: - return - threads = [self.checkedThread(target=iterator_thread) for _ in range(8)] - for t in threads: - t.start() - for t in threads: - t.join() - - # `results` will contain the same elements components**2 - # repeated 18 times, but in a non-deterministic order. Sort the - # results, and assert that each element of components**2 is - # produced 18 times. - results.sort(key=lambda x: x[0]) - for i in range(7): - for j in range(18): - for component, result_component in zip(components, - results[i * 18 + j]): - self.assertAllEqual(component[i]**2, result_component) - - def _buildParallelMapDataset(self, components, count, num_threads, - output_buffer_size): - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return (contrib_dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_threads=num_threads, output_buffer_size=output_buffer_size) - .repeat(count)) - - def testParallelMapDataset(self): - """Test an dataset that maps a TF function across its input elements.""" - # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) -> - # RepeatDataset(count). - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - count = array_ops.placeholder(dtypes.int64, shape=[]) - num_threads = array_ops.placeholder(dtypes.int32, shape=[]) - output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) - - dataset = self._buildParallelMapDataset(components, count, num_threads, - output_buffer_size) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - def do_test(num_threads_val, output_buffer_size_val): - # Test single-threaded access to the iterator. - sess.run(init_op, feed_dict={ - count: 14, - num_threads: num_threads_val, - output_buffer_size: output_buffer_size_val}) - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test multi-threaded access to the same iterator. - sess.run(init_op, feed_dict={ - count: 18, - num_threads: num_threads_val, - output_buffer_size: output_buffer_size_val}) - results = [] - def iterator_thread(): - while True: - try: - results.append(sess.run(get_next)) - except errors.OutOfRangeError: - return - threads = [self.checkedThread(target=iterator_thread) - for _ in range(64)] - for t in threads: - t.start() - for t in threads: - t.join() - - # `results` will contain the same elements components**2 - # repeated 18 times, but in a non-deterministic order. Sort the - # results, and assert that each element of components**2 is - # produced 18 times. - results.sort(key=lambda x: x[0]) - for i in range(7): - for j in range(18): - for component, result_component in zip(components, - results[i * 18 + j]): - self.assertAllEqual(component[i]**2, result_component) - - for num_threads_val, output_buffer_size_val in [ - (1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]: - do_test(num_threads_val, output_buffer_size_val) - - def testImplicitDisposeParallelMapDataset(self): - # Tests whether a parallel map dataset will be cleaned up correctly when - # the pipeline does not run it until exhaustion. - # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> - # RepeatDataset(1000). - components = (np.arange(1000), - np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], - np.array(37.0) * np.arange(1000)) - - dataset = self._buildParallelMapDataset(components, 1000, 100, 100) - # NOTE(mrry): Also test that the prefetching thread is cancelled correctly. - dataset = dataset.prefetch(100) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(3): - sess.run(get_next) - - def testParallelMapUnspecifiedOutputSize(self): - components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - - dataset = ( - contrib_dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.check_numerics(x, "message"), num_threads=2)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(3): - sess.run(get_next) - - def testParallelMapError(self): - components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - - dataset = ( - contrib_dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.check_numerics(x, "message"), - num_threads=2, - output_buffer_size=2)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(3): - sess.run(get_next) - # The 4th element is NaN, so `array_ops.check_numerics()` should fail. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testPrefetchError(self): - components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - - dataset = ( - contrib_dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message")).prefetch(2)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(3): - sess.run(get_next) - # The 4th element is NaN, so `array_ops.check_numerics()` should fail. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - def testMapIgnoreError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) dataset = ( - contrib_dataset_ops.Dataset.from_tensor_slices(components) + dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.check_numerics(x, "message")).apply( error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() @@ -292,10 +63,9 @@ class MapDatasetTest(test.TestCase): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) dataset = ( - contrib_dataset_ops.Dataset.from_tensor_slices(components).map( + dataset_ops.Dataset.from_tensor_slices(components).map( lambda x: array_ops.check_numerics(x, "message"), - num_threads=2, - output_buffer_size=2).apply(error_ops.ignore_errors())) + num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -317,8 +87,8 @@ class MapDatasetTest(test.TestCase): write_string_to_file(filename, filename) dataset = ( - contrib_dataset_ops.Dataset.from_tensor_slices(filenames).map( - io_ops.read_file, num_threads=2, output_buffer_size=2).apply( + dataset_ops.Dataset.from_tensor_slices(filenames).map( + io_ops.read_file, num_parallel_calls=2).prefetch(2).apply( error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer @@ -343,350 +113,6 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testCaptureHashTable(self): - # NOTE(mrry): We must use the V2 variants of `HashTable` - # etc. because these produce a `tf.resource`-typed output that is - # compatible with the in-graph function implementation. - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - - input_sentences = contrib_dataset_ops.Dataset.from_tensor_slices( - ["brain brain tank salad surgery", "surgery brain"]) - - iterator = (input_sentences - .map(lambda x: string_ops.string_split([x]).values) - .map(table.lookup) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(table.init) - sess.run(init_op) - - print(sess.run(get_next)) - print(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testCaptureQueue(self): - elements = np.random.randint(100, size=[200]) - queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[]) - enqueue_op = queue.enqueue_many(elements) - close_op = queue.close() - iterator = ( - contrib_dataset_ops.Dataset.from_tensors(0).repeat(-1) - .map(lambda _: queue.dequeue()).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(enqueue_op) - sess.run(close_op) - sess.run(init_op) - for element in elements: - self.assertEqual(element, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testCaptureSameResourceMultipleTimes(self): - elements = np.random.randint(100, size=[200]) - queue = data_flow_ops.FIFOQueue( - 200, dtypes.int64, shapes=[], shared_name="shared_queue") - queue_2 = data_flow_ops.FIFOQueue( - 200, dtypes.int64, shapes=[], shared_name="shared_queue") - - enqueue_op = queue.enqueue_many(elements) - close_op = queue.close() - - iterator = ( - contrib_dataset_ops.Dataset.from_tensors(0).repeat(-1) - .map(lambda _: (queue.dequeue(), queue_2.dequeue())) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(enqueue_op) - sess.run(close_op) - sess.run(init_op) - for i in range(100): - self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]), - sorted(sess.run(get_next))) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testCaptureVariable(self): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - iterator = ( - contrib_dataset_ops.Dataset.from_tensors(0).repeat(10) - .map(lambda _: counter_var.assign_add(1)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(counter_var.initializer) - sess.run(init_op) - for i in range(10): - self.assertEqual(i, sess.run(counter_var)) - self.assertEqual(i + 1, sess.run(get_next)) - self.assertEqual(10, sess.run(counter_var)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(10, sess.run(counter_var)) - - def testCaptureUninitializedVariableError(self): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - iterator = ( - contrib_dataset_ops.Dataset.from_tensors(0).repeat(10) - .map(lambda _: counter_var.assign_add(1)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - with self.assertRaises(errors.NotFoundError): - sess.run(get_next) - - def testSeededStatefulOperatorIsProperlyStateful(self): - iterator = ( - contrib_dataset_ops.Dataset.from_tensors(0).repeat(10) - .map(lambda _: random_ops.random_uniform((), seed=11)).batch(2) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - random_values = [] - with self.assertRaises(errors.OutOfRangeError): - while True: - random_values.extend(sess.run(get_next)) - self.assertEqual(10, len(random_values)) - self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6) - sess.run(init_op) - random_values_2 = [] - with self.assertRaises(errors.OutOfRangeError): - while True: - random_values_2.extend(sess.run(get_next)) - - # Randomness is repeatable given same seed - self.assertAllClose(random_values, random_values_2) - - def testMapDict(self): - iterator = (contrib_dataset_ops.Dataset.range(10) - .map(lambda x: {"foo": x * 2, "bar": x ** 2}) - .map(lambda d: d["foo"] + d["bar"]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - self.assertEqual(i * 2 + i ** 2, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testMapNamedtuple(self, count=10): - # construct dataset of tuples - labels = contrib_dataset_ops.Dataset.range(count) - images = labels.map(lambda l: -l) - dataset_tuple = contrib_dataset_ops.Dataset.zip((labels, images)) - - # convert dataset of tuples to dataset of namedtuples - example = namedtuple("Example", ["label", "image"]) - dataset_namedtuple = dataset_tuple.map(example) - - def preprocess_tuple(label, image): - image = 2 * image - return label, image - - def preprocess_namedtuple(example): - return example._replace(image=2 * example.image) - - # preprocess both datasets - dataset_tuple = dataset_tuple.map(preprocess_tuple) - dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple) - - next_tuple = dataset_tuple.make_one_shot_iterator().get_next() - next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next() - - # make sure both datasets contain the same data - with self.test_session() as sess: - for i in range(count): - tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple]) - self.assertEqual(tuple_, namedtuple_) - self.assertEqual(tuple_, (i, -2 * i)) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_namedtuple) - - def testUseStepContainerInMap(self): - row = np.arange(6) - iterator = ( - contrib_dataset_ops.Dataset.from_tensors(row) - .map(lambda elems: functional_ops.map_fn(lambda x: x * x, elems)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - self.assertAllEqual(row ** 2, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testPrefetch(self): - # We will use this event to test that `_map_py_func()` has been - # invoked a certain number of times (6 times, to be exact) after - # consuming fewer elements from the iterator. - ev = threading.Event() - - set_event_during_invocation = 5 - - def _map_py_func(x): - if x == set_event_during_invocation: - ev.set() - return x * x - - def _map_fn(x): - return script_ops.py_func(_map_py_func, [x], x.dtype) - - buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( - contrib_dataset_ops.Dataset.range(100).map(_map_fn) - .prefetch(buffer_size_placeholder).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - # Simple test that prefetch yields the expected values in the - # expected order. - for buffer_size in [1, 10, 100, 1000]: - sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size}) - for i in range(100): - self.assertEqual(i * i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # We can indirectly observe that varying the buffer size has the - # intended effect by observing when `ev` is set (on the 6th - # invocation of `_map_py_func()`). - # NOTE(mrry): We do not test with `buffer_size == - # set_event_during_invocation`, because we must consume at least - # one element to start the prefetching. - for buffer_size in range(1, set_event_during_invocation): - event_will_be_set_after_consuming = ( - set_event_during_invocation - buffer_size + 1) - - ev.clear() - sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size}) - for i in range(event_will_be_set_after_consuming): - self.assertFalse(ev.is_set()) - self.assertEqual(i * i, sess.run(get_next)) - ev.wait() - for i in range(event_will_be_set_after_consuming, 100): - self.assertEqual(i * i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testReturnList(self): - iterator = ( - contrib_dataset_ops.Dataset.range(10) - .map(lambda x: [x, constant_op.constant(37.0)]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - self.assertEqual((i, 37.0), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testMultiOutputPyFunc(self): - # The `tf.py_func()` op returns a list of tensors for its outputs. - def _map_fn(x_tensor): - def _map_py_func(x): - return x, np.array(37.0, dtype=np.float64) - return script_ops.py_func( - _map_py_func, [x_tensor], [dtypes.int64, dtypes.float64]) - - iterator = ( - contrib_dataset_ops.Dataset.range(10).map(_map_fn) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - self.assertEqual((i, 37.0), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - - def testSparse(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1])), - dense_shape=np.array([1, 1])) - - iterator = ( - contrib_dataset_ops.Dataset.range(10).map(_sparse) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - actual = sess.run(get_next) - self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) - self.assertSparseValuesEqual(actual, _sparse(i)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSparseChain(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1])), - dense_shape=np.array([1, 1])) - - def _check(i): - self.assertTrue(sparse_tensor.is_sparse(i)) - return sparse_ops.sparse_concat(0, [i, i]) - - iterator = ( - contrib_dataset_ops.Dataset.range(10).map(_sparse).map(_check) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - actual = sess.run(get_next) - self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) - self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval()) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - def testCaptureResourceInMapFn(self): def _build_ds(iterator): @@ -695,10 +121,10 @@ class MapDatasetTest(test.TestCase): get_next = iterator.get_next() return x * get_next - return contrib_dataset_ops.Dataset.range(10).map(_map_fn) + return dataset_ops.Dataset.range(10).map(_map_fn) def _build_graph(): - captured_iterator = contrib_dataset_ops.Dataset.range( + captured_iterator = dataset_ops.Dataset.range( 10).make_initializable_iterator() ds = _build_ds(captured_iterator) iterator = ds.make_initializable_iterator() @@ -734,7 +160,7 @@ class MapDatasetSerializationTest( return math_ops.square(x), math_ops.square(y), math_ops.square(z) return ( - contrib_dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(self._num_epochs)) def testSaveRestoreCore(self): @@ -751,7 +177,7 @@ class MapDatasetSerializationTest( return random_ops.random_uniform( (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - return contrib_dataset_ops.Dataset.range(100).map(_map_fn) + return dataset_ops.Dataset.range(100).map(_map_fn) self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) @@ -760,7 +186,7 @@ class MapDatasetSerializationTest( def _build_ds(): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) - return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map( + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( lambda _: counter_var.assign_add(1))) self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) @@ -769,7 +195,7 @@ class MapDatasetSerializationTest( def _build_ds(): constant_var = constant_op.constant(5) - return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map( + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( lambda x: x + constant_var)) self.run_core_tests(_build_ds, None, 10) @@ -783,7 +209,7 @@ class MapDatasetSerializationTest( def defun_fn(x): return constant_op.constant(1000) + math_ops.to_int32(x) - return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn) + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) self.run_core_tests(_build_ds, None, num_outputs) @@ -801,10 +227,25 @@ class MapDatasetSerializationTest( return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn) + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) self.run_core_tests(_build_ds, None, num_outputs) + def testSparseCore(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def _build_ds(num_outputs): + return dataset_ops.Dataset.range(num_outputs).map(_sparse) + + num_outputs = 10 + self.run_core_tests(lambda: _build_ds(num_outputs), + lambda: _build_ds(int(num_outputs / 2)), num_outputs) + class ParallelMapDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): @@ -851,7 +292,8 @@ class ParallelMapDatasetSerializationTest( return random_ops.random_uniform( (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - return contrib_dataset_ops.Dataset.range(100).map(_map_fn) + return dataset_ops.Dataset.range(100).map( + _map_fn, num_parallel_calls=2).prefetch(2) self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) @@ -860,8 +302,9 @@ class ParallelMapDatasetSerializationTest( def _build_ds(): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) - return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1))) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1), + num_parallel_calls=2).prefetch(2)) self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) @@ -869,8 +312,8 @@ class ParallelMapDatasetSerializationTest( def _build_ds(): constant_var = constant_op.constant(5) - return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda x: x + constant_var)) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) self.run_core_tests(_build_ds, None, 10) @@ -883,7 +326,8 @@ class ParallelMapDatasetSerializationTest( def defun_fn(x): return constant_op.constant(1000) + math_ops.to_int32(x) - return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn) + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) self.run_core_tests(_build_ds, None, num_outputs) @@ -901,7 +345,8 @@ class ParallelMapDatasetSerializationTest( return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn) + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) self.run_core_tests(_build_ds, None, num_outputs) @@ -910,7 +355,7 @@ class IgnoreErrorsSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_ds(self, components): - return contrib_dataset_ops.Dataset.from_tensor_slices(components).map( + return dataset_ops.Dataset.from_tensor_slices(components).map( lambda x: array_ops.check_numerics(x, "message")).apply( error_ops.ignore_errors()) diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index a431670829ed1d66f1719985af73eafa1fe45982..80e1cb0041024b68bd5268b5de5d69c88c839896 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -21,14 +21,13 @@ import os from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import counter -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops @@ -38,131 +37,6 @@ from tensorflow.python.platform import test class RangeDatasetTest(test.TestCase): - def testStop(self): - stop = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={stop: 5}) - for i in range(5): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testStartStop(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 2, stop: 5}) - for i in range(2, 5): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testStartStopStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2}) - for i in range(2, 10, 2): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testZeroStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - - with self.test_session() as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0}) - - def testNegativeStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1}) - # This for loop is a no-op but will ensure that the implementation is - # consistent with range if it ever changes. - for i in range(2, 10, -1): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testStopLessThanStart(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 10, stop: 2}) - # This for loop is a no-op but will ensure that the implementation is - # consistent with range if it ever changes. - for i in range(10, 2): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testStopLessThanStartWithPositiveStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2}) - # This for loop is a no-op but will ensure that the implementation is - # consistent with range if it ever changes. - for i in range(10, 2, 2): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testStopLessThanStartWithNegativeStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1}) - for i in range(10, 2, -1): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - def testEnumerateDataset(self): components = (["a", "b"], [1, 2], [37.0, 38]) start = constant_op.constant(20, dtype=dtypes.int64) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 1c42a3d855bc16c21e385d7108c3106884ae4f5e..6efe97444a375febc550ff3a3ea04bcd9330a3a5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -76,101 +77,12 @@ class TextLineDatasetTestBase(test.TestCase): return filenames -class TextLineDatasetTest(TextLineDatasetTestBase): - - def _testTextLineDataset(self, compression_type=None): - test_filenames = self._createFiles( - 2, 5, crlf=True, compression_type=compression_type) - filenames = array_ops.placeholder(dtypes.string, shape=[None]) - num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = readers.TextLineDataset( - filenames, compression_type=compression_type).repeat(num_epochs) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - init_batch_op = iterator.make_initializer(batch_dataset) - get_next = iterator.get_next() - - with self.test_session() as sess: - # Basic test: read from file 0. - sess.run( - init_op, feed_dict={filenames: [test_filenames[0]], - num_epochs: 1}) - for i in range(5): - self.assertEqual(self._lineText(0, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from file 1. - sess.run( - init_op, feed_dict={filenames: [test_filenames[1]], - num_epochs: 1}) - for i in range(5): - self.assertEqual(self._lineText(1, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from both files. - sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1}) - for j in range(2): - for i in range(5): - self.assertEqual(self._lineText(j, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test repeated iteration through both files. - sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10}) - for _ in range(10): - for j in range(2): - for i in range(5): - self.assertEqual(self._lineText(j, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test batched and repeated iteration through both files. - sess.run( - init_batch_op, - feed_dict={filenames: test_filenames, - num_epochs: 10, - batch_size: 5}) - for _ in range(10): - self.assertAllEqual([self._lineText(0, i) for i in range(5)], - sess.run(get_next)) - self.assertAllEqual([self._lineText(1, i) for i in range(5)], - sess.run(get_next)) - - def testTextLineDatasetNoCompression(self): - self._testTextLineDataset() - - def testTextLineDatasetGzipCompression(self): - self._testTextLineDataset(compression_type="GZIP") - - def testTextLineDatasetZlibCompression(self): - self._testTextLineDataset(compression_type="ZLIB") - - def testTextLineDatasetBuffering(self): - test_filenames = self._createFiles(2, 5, crlf=True) - - repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10) - iterator = repeat_dataset.make_one_shot_iterator() - - with self.test_session() as sess: - for j in range(2): - for i in range(5): - self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - class TextLineDatasetSerializationTest( TextLineDatasetTestBase, dataset_serialization_test_base.DatasetSerializationTestBase): def _build_iterator_graph(self, test_filenames, compression_type=None): - return readers.TextLineDataset( + return core_readers.TextLineDataset( test_filenames, compression_type=compression_type, buffer_size=10) def testTextLineCore(self): @@ -217,101 +129,13 @@ class FixedLengthRecordReaderTestBase(test.TestCase): return filenames -class FixedLengthRecordReaderTest(FixedLengthRecordReaderTestBase): - - def testFixedLengthRecordDataset(self): - test_filenames = self._createFiles() - filenames = array_ops.placeholder(dtypes.string, shape=[None]) - num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = (readers.FixedLengthRecordDataset( - filenames, self._record_bytes, self._header_bytes, self._footer_bytes) - .repeat(num_epochs)) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - init_batch_op = iterator.make_initializer(batch_dataset) - get_next = iterator.get_next() - - with self.test_session() as sess: - # Basic test: read from file 0. - sess.run( - init_op, feed_dict={filenames: [test_filenames[0]], - num_epochs: 1}) - for i in range(self._num_records): - self.assertEqual(self._record(0, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from file 1. - sess.run( - init_op, feed_dict={filenames: [test_filenames[1]], - num_epochs: 1}) - for i in range(self._num_records): - self.assertEqual(self._record(1, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from both files. - sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1}) - for j in range(self._num_files): - for i in range(self._num_records): - self.assertEqual(self._record(j, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test repeated iteration through both files. - sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10}) - for _ in range(10): - for j in range(self._num_files): - for i in range(self._num_records): - self.assertEqual(self._record(j, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test batched and repeated iteration through both files. - sess.run( - init_batch_op, - feed_dict={ - filenames: test_filenames, - num_epochs: 10, - batch_size: self._num_records - }) - for _ in range(10): - for j in range(self._num_files): - self.assertAllEqual( - [self._record(j, i) for i in range(self._num_records)], - sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFixedLengthRecordDatasetBuffering(self): - test_filenames = self._createFiles() - dataset = readers.FixedLengthRecordDataset( - test_filenames, - self._record_bytes, - self._header_bytes, - self._footer_bytes, - buffer_size=10) - iterator = dataset.make_one_shot_iterator() - - with self.test_session() as sess: - for j in range(self._num_files): - for i in range(self._num_records): - self.assertEqual(self._record(j, i), sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - class FixedLengthRecordDatasetSerializationTest( FixedLengthRecordReaderTestBase, dataset_serialization_test_base.DatasetSerializationTestBase): def _build_iterator_graph(self, num_epochs, compression_type=None): filenames = self._createFiles() - return readers.FixedLengthRecordDataset( + return core_readers.FixedLengthRecordDataset( filenames, self._record_bytes, self._header_bytes, self._footer_bytes).repeat(num_epochs) @@ -338,9 +162,8 @@ class TFRecordDatasetTestBase(test.TestCase): self.compression_type = array_ops.placeholder_with_default("", shape=[]) self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - repeat_dataset = readers.TFRecordDataset(self.filenames, - self.compression_type).repeat( - self.num_epochs) + repeat_dataset = core_readers.TFRecordDataset( + self.filenames, self.compression_type).repeat(self.num_epochs) batch_dataset = repeat_dataset.batch(self.batch_size) iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) @@ -363,129 +186,6 @@ class TFRecordDatasetTestBase(test.TestCase): return filenames -class TFRecordDatasetTest(TFRecordDatasetTestBase): - - def testReadOneEpoch(self): - with self.test_session() as sess: - # Basic test: read from file 0. - sess.run( - self.init_op, - feed_dict={ - self.filenames: [self.test_filenames[0]], - self.num_epochs: 1 - }) - for i in range(self._num_records): - self.assertAllEqual(self._record(0, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - # Basic test: read from file 1. - sess.run( - self.init_op, - feed_dict={ - self.filenames: [self.test_filenames[1]], - self.num_epochs: 1 - }) - for i in range(self._num_records): - self.assertAllEqual(self._record(1, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - # Basic test: read from both files. - sess.run( - self.init_op, - feed_dict={self.filenames: self.test_filenames, - self.num_epochs: 1}) - for j in range(self._num_files): - for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - def testReadTenEpochs(self): - with self.test_session() as sess: - sess.run( - self.init_op, - feed_dict={self.filenames: self.test_filenames, - self.num_epochs: 10}) - for _ in range(10): - for j in range(self._num_files): - for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - def testReadTenEpochsOfBatches(self): - with self.test_session() as sess: - sess.run( - self.init_batch_op, - feed_dict={ - self.filenames: self.test_filenames, - self.num_epochs: 10, - self.batch_size: self._num_records - }) - for _ in range(10): - for j in range(self._num_files): - values = sess.run(self.get_next) - self.assertAllEqual( - [self._record(j, i) for i in range(self._num_records)], values) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - def testReadZlibFiles(self): - zlib_files = [] - for i, fn in enumerate(self.test_filenames): - with open(fn, "rb") as f: - cdata = zlib.compress(f.read()) - - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) - with open(zfn, "wb") as f: - f.write(cdata) - zlib_files.append(zfn) - - with self.test_session() as sess: - sess.run( - self.init_op, - feed_dict={self.filenames: zlib_files, - self.compression_type: "ZLIB"}) - for j in range(self._num_files): - for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - def testReadGzipFiles(self): - gzip_files = [] - for i, fn in enumerate(self.test_filenames): - with open(fn, "rb") as f: - gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) - with gzip.GzipFile(gzfn, "wb") as gzf: - gzf.write(f.read()) - gzip_files.append(gzfn) - - with self.test_session() as sess: - sess.run( - self.init_op, - feed_dict={self.filenames: gzip_files, - self.compression_type: "GZIP"}) - for j in range(self._num_files): - for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - def testReadWithBuffer(self): - one_mebibyte = 2**20 - d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte) - iterator = d.make_one_shot_iterator() - with self.test_session() as sess: - for j in range(self._num_files): - for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - class TFRecordDatasetSerializationTest( TFRecordDatasetTestBase, dataset_serialization_test_base.DatasetSerializationTestBase): @@ -517,7 +217,7 @@ class TFRecordDatasetSerializationTest( gzip_files.append(gzfn) filenames = gzip_files - return readers.TFRecordDataset( + return core_readers.TFRecordDataset( filenames, compression_type, buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) @@ -575,7 +275,7 @@ class ReadBatchFeaturesTest(test.TestCase): "record": parsing_ops.FixedLenFeature([], dtypes.int64), "keywords": parsing_ops.VarLenFeature(dtypes.string) }, - reader=readers.TFRecordDataset, + reader=core_readers.TFRecordDataset, randomize_input=False, num_epochs=self.num_epochs) @@ -714,12 +414,11 @@ class ReadBatchFeaturesTest(test.TestCase): self._next_actual_batch(sess) def testReadWithEquivalentDataset(self): - # TODO(mrry): Add support for tf.SparseTensor as a Dataset component. features = { "file": parsing_ops.FixedLenFeature([], dtypes.int64), "record": parsing_ops.FixedLenFeature([], dtypes.int64), } - dataset = (readers.TFRecordDataset(self.test_filenames) + dataset = (core_readers.TFRecordDataset(self.test_filenames) .map(lambda x: parsing_ops.parse_single_example(x, features)) .repeat(10).batch(2)) iterator = dataset.make_initializable_iterator() diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 0ac8d7359f7234d98167277724780bf31555e6fb..3c7b46629edb13459766b5ef3f392e8d00ad4db8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -19,8 +19,8 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import resampling +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.ops import string_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py index 1a26da82e533ec01106ea10525c1cd96627c34fb..36ddf3004237ed042f21d691d83eafbaa20621e6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py @@ -20,194 +20,10 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test -class SequenceDatasetTest(test.TestCase): - - def testRepeatTensorDataset(self): - """Test a dataset that repeats its input multiple times.""" - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - # This placeholder can be fed when dataset-definition subgraph - # runs (i.e. `init_op` below) to configure the number of - # repetitions used in a particular iterator. - count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - - iterator = (dataset_ops.Dataset.from_tensors(components) - .repeat(count_placeholder).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - # Test a finite repetition. - sess.run(init_op, feed_dict={count_placeholder: 3}) - for _ in range(3): - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test a different finite repetition. - sess.run(init_op, feed_dict={count_placeholder: 7}) - for _ in range(7): - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test an empty repetition. - sess.run(init_op, feed_dict={count_placeholder: 0}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test an infinite repetition. - # NOTE(mrry): There's not a good way to test that the sequence - # actually is infinite. - sess.run(init_op, feed_dict={count_placeholder: -1}) - for _ in range(17): - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - - def testTakeTensorDataset(self): - components = (np.arange(10),) - count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .take(count_placeholder).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - # Take fewer than input size - sess.run(init_op, feed_dict={count_placeholder: 4}) - for i in range(4): - results = sess.run(get_next) - self.assertAllEqual(results, components[0][i:i+1]) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Take more than input size - sess.run(init_op, feed_dict={count_placeholder: 25}) - for i in range(10): - results = sess.run(get_next) - self.assertAllEqual(results, components[0][i:i+1]) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Take all of input - sess.run(init_op, feed_dict={count_placeholder: -1}) - for i in range(10): - results = sess.run(get_next) - self.assertAllEqual(results, components[0][i:i+1]) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Take nothing - sess.run(init_op, feed_dict={count_placeholder: 0}) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSkipTensorDataset(self): - components = (np.arange(10),) - count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .skip(count_placeholder).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - # Skip fewer than input size, we should skip - # the first 4 elements and then read the rest. - sess.run(init_op, feed_dict={count_placeholder: 4}) - for i in range(4, 10): - results = sess.run(get_next) - self.assertAllEqual(results, components[0][i:i+1]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Skip more than input size: get nothing. - sess.run(init_op, feed_dict={count_placeholder: 25}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Skip exactly input size. - sess.run(init_op, feed_dict={count_placeholder: 10}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Set -1 for 'count': skip the entire dataset. - sess.run(init_op, feed_dict={count_placeholder: -1}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Skip nothing - sess.run(init_op, feed_dict={count_placeholder: 0}) - for i in range(0, 10): - results = sess.run(get_next) - self.assertAllEqual(results, components[0][i:i+1]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRepeatRepeatTensorDataset(self): - """Test the composition of repeat datasets.""" - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - inner_count = array_ops.placeholder(dtypes.int64, shape=[]) - outer_count = array_ops.placeholder(dtypes.int64, shape=[]) - - iterator = (dataset_ops.Dataset.from_tensors(components).repeat(inner_count) - .repeat(outer_count).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14}) - for _ in range(7 * 14): - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRepeatEmptyDataset(self): - """Test that repeating an empty dataset does not hang.""" - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10) - .repeat(-1).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - with self.assertRaisesRegexp( - errors.OutOfRangeError, - "Attempted to repeat an empty dataset infinitely."): - sess.run(get_next) - - class SequenceDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6b74dc3eb80a6168117beed06935737198cecb --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py @@ -0,0 +1,85 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration test for input pipeline serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class MultipleInputPipelinesTest(test.TestCase): + + def _build_input_pipeline(self, name, num_outputs): + with ops.name_scope(name): + ds = dataset_ops.Dataset.range(num_outputs).shuffle( + 10, reshuffle_each_iteration=False).prefetch(10) + iterator = ds.make_initializable_iterator() + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + return iterator.initializer, iterator.get_next() + + def _build_graph(self, num_pipelines, num_outputs): + init_ops = [] + get_next_ops = [] + for i in range(num_pipelines): + name = "input_pipeline_%d" % i + init_op, get_next_op = self._build_input_pipeline(name, num_outputs) + init_ops.append(init_op) + get_next_ops.append(get_next_op) + saver = saver_lib.Saver() + return init_ops, get_next_ops, saver + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def testConcurrentSaves(self): + num_pipelines = 100 + num_outputs = 100 + break_point = 10 + all_outputs = [[] for _ in range(num_pipelines)] + with ops.Graph().as_default() as g: + init_ops, get_next_ops, saver = self._build_graph(num_pipelines, + num_outputs) + with self.test_session(graph=g) as sess: + sess.run(init_ops) + for _ in range(break_point): + output = sess.run(get_next_ops) + for i in range(num_pipelines): + all_outputs[i].append(output[i]) + saver.save(sess, self._ckpt_path()) + + with ops.Graph().as_default() as g: + init_ops, get_next_ops, saver = self._build_graph(num_pipelines, + num_outputs) + with self.test_session(graph=g) as sess: + saver.restore(sess, self._ckpt_path()) + for _ in range(num_outputs - break_point): + output = sess.run(get_next_ops) + for i in range(num_pipelines): + all_outputs[i].append(output[i]) + + for output in all_outputs: + self.assertSequenceEqual(sorted(output), range(num_outputs)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shard_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shard_dataset_op_test.py deleted file mode 100644 index 0b3c32c06eb1d69244c9a02ca4ba571769f13f40..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/shard_dataset_op_test.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.framework import errors -from tensorflow.python.platform import test - - -class ShardDatasetOpTest(test.TestCase): - - def testSimpleCase(self): - dataset = dataset_ops.Dataset.range(10).shard(5, 2) - iterator = dataset.make_one_shot_iterator() - - with self.test_session() as sess: - self.assertEqual(2, sess.run(iterator.get_next())) - self.assertEqual(7, sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testNestedData(self): - dataset_a = dataset_ops.Dataset.range(10) - dataset_b = dataset_ops.Dataset.range(10, 0, -1) - dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2) - iterator = dataset.make_one_shot_iterator() - - with self.test_session() as sess: - self.assertEqual((2, 8), sess.run(iterator.get_next())) - self.assertEqual((7, 3), sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testOffsetZero(self): - dataset = dataset_ops.Dataset.range(10).shard(5, 0) - iterator = dataset.make_one_shot_iterator() - - with self.test_session() as sess: - self.assertEqual(0, sess.run(iterator.get_next())) - self.assertEqual(5, sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testOffsetGreaterNumShards(self): - with self.assertRaises(ValueError): - dataset_ops.Dataset.range(10).shard(5, 7) - - def testNegativeOffset(self): - with self.assertRaises(ValueError): - dataset_ops.Dataset.range(10).shard(5, -3) - - def testNegativeNumShards(self): - with self.assertRaises(ValueError): - dataset_ops.Dataset.range(10).shard(-3, 1) - - def testZeroNumShards(self): - with self.assertRaises(ValueError): - dataset_ops.Dataset.range(10).shard(0, 1) - - def testIteratorEndsBeforeFirstElem(self): - dataset = dataset_ops.Dataset.range(1).shard(5, 2) - iterator = dataset.make_one_shot_iterator() - - with self.test_session() as sess: - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testLargerWorkerPool(self): - dataset = dataset_ops.Dataset.range(10).shard(7, 5) - iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: - self.assertEqual(5, sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testIndexEqualsNumShards(self): - dataset = dataset_ops.Dataset.range(10).shard(5, 4) - iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: - self.assertEqual(4, sess.run(iterator.get_next())) - self.assertEqual(9, sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testIndexEqualsNumShards2(self): - dataset = dataset_ops.Dataset.range(10).shard(4, 3) - iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: - self.assertEqual(3, sess.run(iterator.get_next())) - self.assertEqual(7, sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 45943d56ecb4bc18a6221157d0eeeae4efdf23cc..bcc644c0971854d948025009dc7add2fea214048 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -17,144 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections - import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ShuffleDatasetTest(test.TestCase): - - def testShuffleDataset(self): - components = ( - np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), - np.array([9.0, 10.0, 11.0, 12.0]) - ) - count_placeholder = array_ops.placeholder_with_default( - constant_op.constant(5, dtypes.int64), shape=[]) - buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = ( - contrib_dataset_ops.Dataset.from_tensor_slices(components) - .repeat(count_placeholder)) - - shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder, - seed_placeholder) - - self.assertEqual(tuple([c.shape[1:] for c in components]), - shuffle_dataset.output_shapes) - - # Create initialization ops for iterators without and with - # shuffling, respectively. - iterator = iterator_ops.Iterator.from_structure( - shuffle_dataset.output_types, shuffle_dataset.output_shapes) - init_fifo_op = iterator.make_initializer(repeat_dataset) - init_shuffle_op = iterator.make_initializer(shuffle_dataset) - - get_next = iterator.get_next() - - with self.test_session() as sess: - # First run without shuffling to collect the "ground truth". - sess.run(init_fifo_op) - unshuffled_elements = [] - for _ in range(20): - unshuffled_elements.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Assert that the shuffled dataset has the same elements as the - # "ground truth". - sess.run( - init_shuffle_op, - feed_dict={buffer_size_placeholder: 100, - seed_placeholder: 37}) - shuffled_elements = [] - for _ in range(20): - shuffled_elements.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertAllEqual( - sorted(unshuffled_elements), sorted(shuffled_elements)) - - # Assert that shuffling twice with the same seeds gives the same sequence. - sess.run( - init_shuffle_op, - feed_dict={buffer_size_placeholder: 100, - seed_placeholder: 37}) - reshuffled_elements_same_seed = [] - for _ in range(20): - reshuffled_elements_same_seed.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(shuffled_elements, reshuffled_elements_same_seed) - - # Assert that shuffling twice with a different seed gives a different - # permutation of the same elements. - sess.run( - init_shuffle_op, - feed_dict={buffer_size_placeholder: 100, - seed_placeholder: 1037}) - reshuffled_elements_different_seed = [] - for _ in range(20): - reshuffled_elements_different_seed.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed) - self.assertAllEqual( - sorted(shuffled_elements), sorted(reshuffled_elements_different_seed)) - - # Assert that the shuffled dataset has the same elements as the - # "ground truth" when the buffer size is smaller than the input - # dataset. - sess.run( - init_shuffle_op, - feed_dict={buffer_size_placeholder: 2, - seed_placeholder: 37}) - reshuffled_elements_small_buffer = [] - for _ in range(20): - reshuffled_elements_small_buffer.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertAllEqual( - sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer)) - - # Test the case of shuffling an empty dataset. - sess.run(init_shuffle_op, feed_dict={buffer_size_placeholder: 2, - seed_placeholder: 37, - count_placeholder: 0}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testDefaultArguments(self): - components = [0, 1, 2, 3, 4] - iterator = ( - contrib_dataset_ops.Dataset.from_tensor_slices(components).shuffle(5) - .repeat().make_one_shot_iterator()) - - get_next = iterator.get_next() - - with self.test_session() as sess: - counts = collections.defaultdict(lambda: 0) - for _ in range(10): - for _ in range(5): - counts[sess.run(get_next)] += 1 - - for i in range(5): - self.assertEqual(10, counts[i]) - - class ShuffleDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index efd864f866611bfd3bac1edcf98d84be852410fd..e26cef8ec522c7e69a0c19b2b30a969bbfc0ad78 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os + import sqlite3 from tensorflow.contrib.data.python.ops import readers diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index 55296d5710e7f66408bb7464cf790149d6df9fa1..3c436f7a0b45a13109960e87dd97ca56b10bb871 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py index 5d34b0024c472d0393544ff3dad8acea7964345f..e39fa957f0bbb9d3671274d5f58b993e8399814b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py @@ -20,97 +20,10 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test -class ZipDatasetTest(test.TestCase): - - def testZipDataset(self): - component_placeholders = [ - array_ops.placeholder(dtypes.int64), - array_ops.placeholder(dtypes.int64), - array_ops.placeholder(dtypes.float64) - ] - - datasets = tuple([ - dataset_ops.Dataset.from_tensor_slices(component_placeholder) - for component_placeholder in component_placeholders - ]) - zipped = dataset_ops.Dataset.zip(datasets) - - iterator = zipped.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - equal_length_components = [ - np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 22), - np.array([37.0, 38.0, 39.0, 40.0]) - ] - sess.run(init_op, feed_dict={ph: value for ph, value in zip( - component_placeholders, equal_length_components)}) - for i in range(4): - results = sess.run(get_next) - for component, result_component in zip( - equal_length_components, results): - self.assertAllEqual(component[i], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]] - sess.run(init_op, feed_dict={ph: value for ph, value in zip( - component_placeholders, variable_length_components)}) - for i in range(2): - results = sess.run(get_next) - for component, result_component in zip( - variable_length_components, results): - self.assertAllEqual(component[i], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testNestedZipDataset(self): - component_placeholders = [ - array_ops.placeholder(dtypes.int64, shape=[4, 20]), - array_ops.placeholder(dtypes.int64, shape=[4, 22]), - array_ops.placeholder(dtypes.float64, shape=[4]) - ] - - datasets = [ - dataset_ops.Dataset.from_tensor_slices(component_placeholder) - for component_placeholder in component_placeholders - ] - zipped = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2]))) - - iterator = zipped.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([20], get_next[0].shape) - self.assertEqual([22], get_next[1][0].shape) - self.assertEqual([], get_next[1][1].shape) - - with self.test_session() as sess: - equal_length_components = [ - np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 22), - np.array([37.0, 38.0, 39.0, 40.0]) - ] - sess.run(init_op, feed_dict={ph: value for ph, value in zip( - component_placeholders, equal_length_components)}) - for i in range(4): - result1, (result2, result3) = sess.run(get_next) - self.assertAllEqual(equal_length_components[0][i], result1) - self.assertAllEqual(equal_length_components[1][i], result2) - self.assertAllEqual(equal_length_components[2][i], result3) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - class ZipDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 4349085a10135b4dee842a29916aeb5febe9ddd4..b488357f226d0922bba3799cc1f4b5c75e2e8328 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -15,7 +15,7 @@ py_library( name = "dataset_ops", srcs = [ "counter.py", - "dataset_ops.py", + "get_single_element.py", ], srcs_version = "PY2AND3", deps = [ @@ -109,6 +109,8 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dataset_ops_gen", @@ -130,36 +132,44 @@ py_library( ) tf_gen_op_wrapper_py( - name = "prefetching_ops", - out = "gen_prefetching_ops.py", - deps = ["//tensorflow/contrib/data:prefetching_ops_op_lib"], + name = "gen_dataset_ops", + out = "gen_dataset_ops.py", + deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"], ) tf_kernel_library( - name = "prefetching_ops_kernels", + name = "dataset_ops_kernels", deps = [ - "//tensorflow/contrib/data/kernels:prefetching_kernels", + "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/core:framework", ], alwayslink = 1, ) tf_custom_op_py_library( - name = "prefetching_py", - srcs = ["prefetching_ops.py"], - dso = ["//tensorflow/contrib/data:_prefetching_ops.so"], + name = "contrib_op_loader", + srcs = ["contrib_op_loader.py"], + dso = ["//tensorflow/contrib/data:_dataset_ops.so"], kernels = [ - ":prefetching_ops_kernels", - "//tensorflow/contrib/data:prefetching_ops_op_lib", + ":dataset_ops_kernels", + "//tensorflow/contrib/data:dataset_ops_op_lib", ], srcs_version = "PY2AND3", deps = [ - ":prefetching_ops", + ":gen_dataset_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:platform", ], ) +py_library( + name = "prefetching_ops", + srcs = ["prefetching_ops.py"], + deps = [ + ":contrib_op_loader", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 76c07b2c999e1424e8efe4af515fddee73922c9c..6eb512dec67cb7b9c8c4518d03aee0b436205f9a 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -403,7 +403,7 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the number of batches to create in parallel. On one hand, higher values can help mitigate the effect of stragglers. On the other hand, higher values - can increasing contention if CPU is scarce. + can increase contention if CPU is scarce. Returns: A `Dataset` transformation function, which can be passed to diff --git a/tensorflow/contrib/data/python/ops/contrib_op_loader.py b/tensorflow/contrib/data/python/ops/contrib_op_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8f495a9dc9c82311435e71d2ac9ed35fd9aea794 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/contrib_op_loader.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python helper for loading contrib ops and kernels.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index fafd231061a9108b2585f4fc9256b6f069b7c37a..214641bb9a62e6cbdece07b511864a5cff20944d 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -475,7 +475,6 @@ class Dataset(dataset_ops.Dataset): @deprecation.deprecated_args( None, - "Replace `num_threads=T` with `num_parallel_calls=T`. Replace " "`output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.", "num_threads", "output_buffer_size") def map(self, @@ -483,7 +482,7 @@ class Dataset(dataset_ops.Dataset): num_threads=None, output_buffer_size=None, num_parallel_calls=None): - """Maps `map_func` across this datset. + """Maps `map_func` across this dataset. Args: map_func: A function mapping a nested structure of tensors (having diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index aa629cba479102ee4244884e7c546615b28cf4e5..6c21e489f7c35484ebacd465e3b46d6920df5933 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -17,10 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse -from tensorflow.python.ops import gen_dataset_ops def ignore_errors(): diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py new file mode 100644 index 0000000000000000000000000000000000000000..a817b45b71b608810a9d7536ec123ab84f7cdc3b --- /dev/null +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -0,0 +1,67 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrappers for Datasets and Iterators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.ops import gen_dataset_ops + + +def get_single_element(dataset): + """Returns the single element in `dataset` as a nested structure of tensors. + + This function enables you to use a @{tf.data.Dataset} in a stateless + "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. + This can be useful when your preprocessing transformations are expressed + as a `Dataset`, and you want to use the transformation at serving time. + For example: + + ```python + input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) + + def preprocessing_fn(input_str): + # ... + return image, label + + dataset = (tf.data.Dataset.from_tensor_slices(input_batch) + .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) + .batch(BATCH_SIZE)) + + image_batch, label_batch = tf.contrib.data.get_single_element(dataset) + ``` + + Args: + dataset: A @{tf.data.Dataset} object containing a single element. + + Returns: + A nested structure of @{tf.Tensor} objects, corresponding to the single + element of `dataset`. + + Raises: + TypeError: if `dataset` is not a `tf.data.Dataset` object. + InvalidArgumentError (at runtime): if `dataset` does not contain exactly + one element. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + return nest.pack_sequence_as( + dataset.output_types, + gen_dataset_ops.dataset_to_single_element( + dataset._as_variant_tensor(), # pylint: disable=protected-access + output_types=nest.flatten(dataset.output_types), + output_shapes=nest.flatten(dataset.output_shapes))) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index ef91c56726e969053fdad667dda3e89430045652..67b085002aa7797d858837fea4646fb968ad5d97 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -45,7 +45,7 @@ def group_by_window(key_func, key_func: A function mapping a nested structure of tensors (having shapes and types defined by `self.output_shapes` and `self.output_types`) to a scalar `tf.int64` tensor. - reduce_func: A function mapping a key and a dataset of up to `batch_size` + reduce_func: A function mapping a key and a dataset of up to `window_size` consecutive elements matching that key to another dataset. window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements matching the same key to combine in a single diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index cfe8012b5657995b78d701528ea35cbb3748adb9..96a9e9ed6649444dac5e56d7dd2fcdb62fc56459 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -17,12 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_prefetching_ops -from tensorflow.contrib.util import loader -from tensorflow.python.platform import resource_loader - -_prefetching_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("../../_prefetching_ops.so")) +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops # TODO(rohanj): Add a python class that constructs resource in the __init__ @@ -35,7 +31,7 @@ def function_buffering_resource(string_arg, thread_pool_size=1, container="", name=None): - return gen_prefetching_ops.function_buffering_resource( + return gen_dataset_ops.function_buffering_resource( string_arg=string_arg, target_device=target_device, shared_name=shared_name, @@ -49,7 +45,7 @@ def function_buffering_resource(string_arg, def function_buffering_resource_get_next(function_buffer_resource, output_types, name=None): - return gen_prefetching_ops.function_buffering_resource_get_next( + return gen_dataset_ops.function_buffering_resource_get_next( function_buffer_resource=function_buffer_resource, output_types=output_types, name=name) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 347e5edc7b0d479dfa260e8cec500ffaaba375be..57f30102778f3bac47580f9bdf94e411dfe1b621 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,9 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -27,74 +25,6 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile -from tensorflow.python.util import deprecation - - -class TextLineDataset(contrib_dataset_ops.Dataset): - """A `Dataset` comprising lines from one or more text files.""" - - @deprecation.deprecated(None, "Use `tf.data.TextLineDataset`.") - def __init__(self, filenames, compression_type=None, buffer_size=None): - """Creates a `TextLineDataset`. - - Args: - filenames: A `tf.string` tensor containing one or more filenames. - compression_type: (Optional.) A `tf.string` scalar evaluating to one of - `""` (no compression), `"ZLIB"`, or `"GZIP"`. - buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes - to buffer. A value of 0 results in the default buffering values chosen - based on the compression type. - """ - dataset = readers.TextLineDataset(filenames, compression_type, - buffer_size) - super(TextLineDataset, self).__init__(dataset) - - -class TFRecordDataset(contrib_dataset_ops.Dataset): - """A `Dataset` comprising records from one or more TFRecord files.""" - - @deprecation.deprecated(None, "Use `tf.data.TFRecordDataset`.") - def __init__(self, filenames, compression_type=None, buffer_size=None): - """Creates a `TFRecordDataset`. - - Args: - filenames: A `tf.string` tensor containing one or more filenames. - compression_type: (Optional.) A `tf.string` scalar evaluating to one of - `""` (no compression), `"ZLIB"`, or `"GZIP"`. - buffer_size: (Optional.) A `tf.int64` scalar representing the number of - bytes in the read buffer. 0 means no buffering. - """ - dataset = readers.TFRecordDataset(filenames, compression_type, - buffer_size) - super(TFRecordDataset, self).__init__(dataset) - - -class FixedLengthRecordDataset(contrib_dataset_ops.Dataset): - """A `Dataset` of fixed-length records from one or more binary files.""" - - @deprecation.deprecated(None, "Use `tf.data.FixedLengthRecordDataset`.") - def __init__(self, - filenames, - record_bytes, - header_bytes=None, - footer_bytes=None, - buffer_size=None): - """Creates a `FixedLengthRecordDataset`. - - Args: - filenames: A `tf.string` tensor containing one or more filenames. - record_bytes: A `tf.int64` scalar representing the number of bytes in - each record. - header_bytes: (Optional.) A `tf.int64` scalar representing the number of - bytes to skip at the start of a file. - footer_bytes: (Optional.) A `tf.int64` scalar representing the number of - bytes to ignore at the end of a file. - buffer_size: (Optional.) A `tf.int64` scalar representing the number of - bytes to buffer when reading. - """ - dataset = readers.FixedLengthRecordDataset( - filenames, record_bytes, header_bytes, footer_bytes, buffer_size) - super(FixedLengthRecordDataset, self).__init__(dataset) def read_batch_features(file_pattern, @@ -216,14 +146,7 @@ def _get_file_names(file_pattern, randomize_input): return file_names -class SqlDataset(contrib_dataset_ops.Dataset): - - def __init__(self, driver_name, data_source_name, query, output_types): - dataset = _SqlDataset(driver_name, data_source_name, query, output_types) - super(SqlDataset, self).__init__(dataset) - - -class _SqlDataset(dataset_ops.Dataset): +class SqlDataset(dataset_ops.Dataset): """A `Dataset` consisting of the results from a SQL query.""" def __init__(self, driver_name, data_source_name, query, output_types): @@ -255,7 +178,7 @@ class _SqlDataset(dataset_ops.Dataset): output_types: A tuple of `tf.DType` objects representing the types of the columns returned by `query`. """ - super(_SqlDataset, self).__init__() + super(SqlDataset, self).__init__() self._driver_name = ops.convert_to_tensor( driver_name, dtype=dtypes.string, name="driver_name") self._data_source_name = ops.convert_to_tensor( diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 1dd0729513c0d46db25226178eb17b41efaae0ae..9cd1701c397b5a0bf5cc47c1bcab033704794d80 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops @@ -161,8 +162,10 @@ class _StatsDataset(dataset_ops.Dataset): return self._op_function( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._tag, - output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) @property def output_shapes(self): diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 60a187e541df4a794ae3944c30c427944915f7d0..61c411271d0bb8d7b4cc3b14992b82ec1e5674ed 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -40,6 +40,7 @@ from tensorflow.contrib.distributions.python.ops.geometric import * from tensorflow.contrib.distributions.python.ops.half_normal import * from tensorflow.contrib.distributions.python.ops.independent import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * +from tensorflow.contrib.distributions.python.ops.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.logistic import * from tensorflow.contrib.distributions.python.ops.mixture import * from tensorflow.contrib.distributions.python.ops.mixture_same_family import * @@ -97,7 +98,6 @@ _allowed_symbols = [ 'Autoregressive', 'Binomial', 'Bernoulli', - 'BernoulliWithSigmoidProbs', 'Beta', 'BetaWithSoftplusConcentration', 'Categorical', @@ -115,6 +115,7 @@ _allowed_symbols = [ 'Independent', 'InverseGamma', 'InverseGammaWithSoftplusConcentrationRate', + 'Kumaraswamy', 'Laplace', 'LaplaceWithSoftplusScale', 'Logistic', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index a255d4fc890e67180532e342332a8e3f63a869cd..31d24aa9ea09007b8db40e4869371b1f62639ac7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -23,10 +23,15 @@ import itertools import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import mixture +from tensorflow.contrib.distributions.python.ops import mixture_same_family +from tensorflow.contrib.distributions.python.ops import mvn_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import categorical +from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.linalg import linear_operator_diag import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -395,6 +400,41 @@ class MixtureStddevTest(test.TestCase): self.assertAllClose(actual_devs, expected_devs) +class PadMixtureDimensionsTest(test.TestCase): + + def test_pad_mixture_dimensions_mixture(self): + with self.test_session() as sess: + gm = mixture.Mixture( + cat=categorical.Categorical(probs=[[0.3, 0.7]]), + components=[ + normal.Normal(loc=[-1.0], scale=[1.0]), + normal.Normal(loc=[1.0], scale=[0.5]) + ]) + + x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]]) + x_pad = distribution_util.pad_mixture_dimensions( + x, gm, gm.cat, gm.event_shape.ndims) + x_out, x_pad_out = sess.run([x, x_pad]) + + self.assertAllEqual(x_pad_out.shape, [2, 2]) + self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1])) + + def test_pad_mixture_dimensions_mixture_same_family(self): + with self.test_session() as sess: + gm = mixture_same_family.MixtureSameFamily( + mixture_distribution=categorical.Categorical(probs=[0.3, 0.7]), + components_distribution=mvn_diag.MultivariateNormalDiag( + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5])) + + x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]]) + x_pad = distribution_util.pad_mixture_dimensions( + x, gm, gm.mixture_distribution, gm.event_shape.ndims) + x_out, x_pad_out = sess.run([x, x_pad]) + + self.assertAllEqual(x_pad_out.shape, [2, 2, 1]) + self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1])) + + class _PadTest(object): def testNegAxisCorrectness(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ea3c86b5c0f42b64fc6e4e362cbcc162bccf74a2 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py @@ -0,0 +1,388 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import kumaraswamy as kumaraswamy_lib +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + + +special = try_import("scipy.special") +stats = try_import("scipy.stats") + + +def _kumaraswamy_mode(a, b): + a = np.asarray(a) + b = np.asarray(b) + return ((a - 1) / (a * b - 1))**(1 / a) + + +def _kumaraswamy_moment(a, b, n): + a = np.asarray(a) + b = np.asarray(b) + return b * special.beta(1.0 + n / a, b) + + +def _harmonic_number(b): + b = np.asarray(b) + return special.psi(b + 1) - special.psi(1) + + +def _kumaraswamy_cdf(a, b, x): + a = np.asarray(a) + b = np.asarray(b) + x = np.asarray(x) + return 1 - (1 - x**a)**b + + +def _kumaraswamy_pdf(a, b, x): + a = np.asarray(a) + b = np.asarray(b) + x = np.asarray(x) + return a * b * x ** (a - 1) * (1 - x ** a) ** (b - 1) + + +class KumaraswamyTest(test.TestCase): + + def testSimpleShapes(self): + with self.test_session(): + a = np.random.rand(3) + b = np.random.rand(3) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertAllEqual([], dist.event_shape_tensor().eval()) + self.assertAllEqual([3], dist.batch_shape_tensor().eval()) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) + + def testComplexShapes(self): + with self.test_session(): + a = np.random.rand(3, 2, 2) + b = np.random.rand(3, 2, 2) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertAllEqual([], dist.event_shape_tensor().eval()) + self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval()) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) + + def testComplexShapesBroadcast(self): + with self.test_session(): + a = np.random.rand(3, 2, 2) + b = np.random.rand(2, 2) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertAllEqual([], dist.event_shape_tensor().eval()) + self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval()) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) + + def testAProperty(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual([1, 3], dist.concentration1.get_shape()) + self.assertAllClose(a, dist.concentration1.eval()) + + def testBProperty(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual([1, 3], dist.concentration0.get_shape()) + self.assertAllClose(b, dist.concentration0.eval()) + + def testPdfXProper(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = kumaraswamy_lib.Kumaraswamy(a, b, validate_args=True) + dist.prob([.1, .3, .6]).eval() + dist.prob([.2, .3, .5]).eval() + # Either condition can trigger. + with self.assertRaisesOpError("sample must be positive"): + dist.prob([-1., 0.1, 0.5]).eval() + with self.assertRaisesOpError("sample must be positive"): + dist.prob([0., 0.1, 0.5]).eval() + with self.assertRaisesOpError("sample must be no larger than `1`"): + dist.prob([.1, .2, 1.2]).eval() + + def testPdfTwoBatches(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [.5, .5] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2,), pdf.get_shape()) + + def testPdfTwoBatchesNontrivialX(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [.3, .7] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2,), pdf.get_shape()) + + def testPdfUniformZeroBatch(self): + with self.test_session(): + # This is equivalent to a uniform distribution + a = 1. + b = 1. + x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((5,), pdf.get_shape()) + + def testPdfAStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + a = [[1., 2]] + b = [[1., 2]] + x = [[.5, .5], [.3, .7]] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfAStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [[.5, .5], [.2, .8]] + pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfXStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [[.5, .5]] + pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfXStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [.5, .5] + pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testKumaraswamyMean(self): + with session.Session(): + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.mean().get_shape(), (3,)) + if not stats: + return + expected_mean = _kumaraswamy_moment(a, b, 1) + self.assertAllClose(expected_mean, dist.mean().eval()) + + def testKumaraswamyVariance(self): + with session.Session(): + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.variance().get_shape(), (3,)) + if not stats: + return + expected_variance = _kumaraswamy_moment(a, b, 2) - _kumaraswamy_moment( + a, b, 1)**2 + self.assertAllClose(expected_variance, dist.variance().eval()) + + def testKumaraswamyMode(self): + with session.Session(): + a = np.array([1.1, 2, 3]) + b = np.array([2., 4, 1.2]) + expected_mode = _kumaraswamy_mode(a, b) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.mode().get_shape(), (3,)) + self.assertAllClose(expected_mode, dist.mode().eval()) + + def testKumaraswamyModeInvalid(self): + with session.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + dist.mode().eval() + + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + dist.mode().eval() + + def testKumaraswamyModeEnableAllowNanStats(self): + with session.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=True) + + expected_mode = _kumaraswamy_mode(a, b) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, dist.mode().eval()) + + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=True) + + expected_mode = _kumaraswamy_mode(a, b) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, dist.mode().eval()) + + def testKumaraswamyEntropy(self): + with session.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.entropy().get_shape(), (3,)) + if not stats: + return + expected_entropy = (1 - 1. / a) + ( + 1 - 1. / b) * _harmonic_number(b) + np.log(a * b) + self.assertAllClose(expected_entropy, dist.entropy().eval()) + + def testKumaraswamySample(self): + with self.test_session(): + a = 1. + b = 2. + kumaraswamy = kumaraswamy_lib.Kumaraswamy(a, b) + n = constant_op.constant(100000) + samples = kumaraswamy.sample(n) + sample_values = samples.eval() + self.assertEqual(sample_values.shape, (100000,)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + self.assertLess( + stats.kstest( + # Kumaraswamy is a univariate distribution. + sample_values, + lambda x: _kumaraswamy_cdf(1., 2., x))[0], + 0.01) + # The standard error of the sample mean is 1 / (sqrt(18 * n)) + expected_mean = _kumaraswamy_moment(a, b, 1) + self.assertAllClose(sample_values.mean(axis=0), expected_mean, atol=1e-2) + expected_variance = _kumaraswamy_moment(a, b, 2) - _kumaraswamy_moment( + a, b, 1)**2 + self.assertAllClose( + np.cov(sample_values, rowvar=0), expected_variance, atol=1e-1) + + # Test that sampling with the same seed twice gives the same results. + def testKumaraswamySampleMultipleTimes(self): + with self.test_session(): + a_val = 1. + b_val = 2. + n_val = 100 + + random_seed.set_random_seed(654321) + kumaraswamy1 = kumaraswamy_lib.Kumaraswamy( + concentration1=a_val, concentration0=b_val, name="kumaraswamy1") + samples1 = kumaraswamy1.sample(n_val, seed=123456).eval() + + random_seed.set_random_seed(654321) + kumaraswamy2 = kumaraswamy_lib.Kumaraswamy( + concentration1=a_val, concentration0=b_val, name="kumaraswamy2") + samples2 = kumaraswamy2.sample(n_val, seed=123456).eval() + + self.assertAllClose(samples1, samples2) + + def testKumaraswamySampleMultidimensional(self): + with self.test_session(): + a = np.random.rand(3, 2, 2).astype(np.float32) + b = np.random.rand(3, 2, 2).astype(np.float32) + kumaraswamy = kumaraswamy_lib.Kumaraswamy(a, b) + n = constant_op.constant(100000) + samples = kumaraswamy.sample(n) + sample_values = samples.eval() + self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + self.assertAllClose( + sample_values[:, 1, :].mean(axis=0), + _kumaraswamy_moment(a, b, 1)[1, :], + atol=1e-1) + + def testKumaraswamyCdf(self): + with self.test_session(): + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = kumaraswamy_lib.Kumaraswamy(a, b).cdf(x).eval() + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + if not stats: + return + self.assertAllClose( + _kumaraswamy_cdf(a, b, x), actual, rtol=1e-4, atol=0) + + def testKumaraswamyLogCdf(self): + with self.test_session(): + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = math_ops.exp(kumaraswamy_lib.Kumaraswamy(a, + b).log_cdf(x)).eval() + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + if not stats: + return + self.assertAllClose( + _kumaraswamy_cdf(a, b, x), actual, rtol=1e-4, atol=0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py index 1e514fe0ff21cd53c8c235da417890773db50c37..02064891758a86c5108e11da6a3666f2d5c56c64 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py @@ -107,7 +107,7 @@ def _test_capture_normal_sample_outputs(): ds.Normal._call_sample_n = true_normal_call_sample_n -def make_univariate_mixture(batch_shape, num_components): +def make_univariate_mixture(batch_shape, num_components, use_static_graph): batch_shape = ops.convert_to_tensor(batch_shape, dtypes.int32) logits = random_ops.random_uniform( array_ops.concat((batch_shape, [num_components]), axis=0), @@ -119,11 +119,11 @@ def make_univariate_mixture(batch_shape, num_components): for _ in range(num_components) ] cat = ds.Categorical(logits, dtype=dtypes.int32) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=use_static_graph) def make_multivariate_mixture(batch_shape, num_components, event_shape, - batch_shape_tensor=None): + use_static_graph, batch_shape_tensor=None): if batch_shape_tensor is None: batch_shape_tensor = batch_shape batch_shape_tensor = ops.convert_to_tensor(batch_shape_tensor, dtypes.int32) @@ -145,15 +145,17 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape, loc=loc, scale_diag=scale_diag) components = [create_component() for _ in range(num_components)] cat = ds.Categorical(logits, dtype=dtypes.int32) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=use_static_graph) class MixtureTest(test.TestCase): + use_static_graph = False def testShapes(self): with self.test_session(): for batch_shape in ([], [1], [2, 3, 4]): - dist = make_univariate_mixture(batch_shape, num_components=10) + dist = make_univariate_mixture(batch_shape, num_components=10, + use_static_graph=self.use_static_graph) self.assertAllEqual(batch_shape, dist.batch_shape) self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval()) self.assertAllEqual([], dist.event_shape) @@ -161,7 +163,8 @@ class MixtureTest(test.TestCase): for event_shape in ([1], [2]): dist = make_multivariate_mixture( - batch_shape, num_components=10, event_shape=event_shape) + batch_shape, num_components=10, event_shape=event_shape, + use_static_graph=self.use_static_graph) self.assertAllEqual(batch_shape, dist.batch_shape) self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval()) self.assertAllEqual(event_shape, dist.event_shape) @@ -172,7 +175,8 @@ class MixtureTest(test.TestCase): r"cat.num_classes != len"): ds.Mixture( ds.Categorical([0.1, 0.5]), # 2 classes - [ds.Normal(loc=1.0, scale=2.0)]) + [ds.Normal(loc=1.0, scale=2.0)], + use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch( ValueError, r"\(\) and \(2,\) are not compatible"): # The value error is raised because the batch shapes of the @@ -185,13 +189,15 @@ class MixtureTest(test.TestCase): loc=1.0, scale=2.0), # scalar dist ds.Normal( loc=[1.0, 1.0], scale=[2.0, 2.0]) - ]) + ], + use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"): cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32) ds.Mixture( ds.Categorical(cat_logits), [ds.Normal( - loc=[1.0], scale=[2.0])]) + loc=[1.0], scale=[2.0])], + use_static_graph=self.use_static_graph) def testBrokenShapesDynamic(self): with self.test_session(): @@ -203,29 +209,37 @@ class MixtureTest(test.TestCase): loc=d0_param, scale=d0_param), ds.Normal( loc=d1_param, scale=d1_param) ], - validate_args=True) - with self.assertRaisesOpError(r"batch shape must match"): + validate_args=True, + use_static_graph=self.use_static_graph) + + if self.use_static_graph: + error_string = r"Shapes of all inputs must match" + else: + error_string = r"batch shape must match" + + with self.assertRaisesOpError(error_string): d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: [1.0]}) - with self.assertRaisesOpError(r"batch shape must match"): + with self.assertRaisesOpError(error_string): d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: 1.0}) def testBrokenTypes(self): with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"): - ds.Mixture(None, []) + ds.Mixture(None, [], use_static_graph=self.use_static_graph) cat = ds.Categorical([0.3, 0.2]) # components must be a list of distributions with self.assertRaisesWithPredicateMatch( TypeError, "all .* must be Distribution instances"): - ds.Mixture(cat, [None]) + ds.Mixture(cat, [None], use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"): ds.Mixture( cat, [ ds.Normal(loc=[1.0], scale=[2.0]), ds.Normal(loc=[np.float16(1.0)], scale=[np.float16(2.0)]), - ]) + ], use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"): - ds.Mixture(ds.Categorical([0.3, 0.2]), None) + ds.Mixture(ds.Categorical([0.3, 0.2]), None, + use_static_graph=self.use_static_graph) # TODO(ebrevdo): once distribution Domains have been added, add a # test to ensure that the domains of the distributions in a @@ -235,7 +249,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_univariate_mixture( - batch_shape=batch_shape, num_components=2) + batch_shape=batch_shape, num_components=2, + use_static_graph=self.use_static_graph) mean = dist.mean() self.assertEqual(batch_shape, mean.get_shape()) @@ -256,7 +271,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_multivariate_mixture( - batch_shape=batch_shape, num_components=2, event_shape=(4,)) + batch_shape=batch_shape, num_components=2, event_shape=(4,), + use_static_graph=self.use_static_graph) mean = dist.mean() self.assertEqual(batch_shape + (4,), mean.get_shape()) @@ -283,7 +299,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_univariate_mixture( - batch_shape=batch_shape, num_components=num_components) + batch_shape=batch_shape, num_components=num_components, + use_static_graph=self.use_static_graph) dev = dist.stddev() self.assertEqual(batch_shape, dev.get_shape()) @@ -325,7 +342,8 @@ class MixtureTest(test.TestCase): dist = make_multivariate_mixture( batch_shape=batch_shape, num_components=num_components, - event_shape=(4,)) + event_shape=(4,), + use_static_graph=self.use_static_graph) dev = dist.stddev() self.assertEqual(batch_shape + (4,), dev.get_shape()) @@ -371,7 +389,8 @@ class MixtureTest(test.TestCase): scale=component_devs[0]), ds.Normal(loc=component_means[1], scale=component_devs[1]), - ]) + ], + use_static_graph=self.use_static_graph) mix_dev = mixture_dist.stddev() with self.test_session() as sess: actual_stddev = sess.run(mix_dev) @@ -379,7 +398,8 @@ class MixtureTest(test.TestCase): def testProbScalarUnivariate(self): with self.test_session() as sess: - dist = make_univariate_mixture(batch_shape=[], num_components=2) + dist = make_univariate_mixture(batch_shape=[], num_components=2, + use_static_graph=self.use_static_graph) for x in [ np.array( [1.0, 2.0], dtype=np.float32), np.array( @@ -405,7 +425,8 @@ class MixtureTest(test.TestCase): def testProbScalarMultivariate(self): with self.test_session() as sess: dist = make_multivariate_mixture( - batch_shape=[], num_components=2, event_shape=[3]) + batch_shape=[], num_components=2, event_shape=[3], + use_static_graph=self.use_static_graph) for x in [ np.array( [[-1.0, 0.0, 1.0], [0.5, 1.0, -0.3]], dtype=np.float32), np.array( @@ -432,7 +453,8 @@ class MixtureTest(test.TestCase): def testProbBatchUnivariate(self): with self.test_session() as sess: - dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2) + dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2, + use_static_graph=self.use_static_graph) for x in [ np.random.randn(2, 3).astype(np.float32), @@ -459,7 +481,8 @@ class MixtureTest(test.TestCase): def testProbBatchMultivariate(self): with self.test_session() as sess: dist = make_multivariate_mixture( - batch_shape=[2, 3], num_components=2, event_shape=[4]) + batch_shape=[2, 3], num_components=2, event_shape=[4], + use_static_graph=self.use_static_graph) for x in [ np.random.randn(2, 3, 4).astype(np.float32), @@ -487,7 +510,8 @@ class MixtureTest(test.TestCase): num_components = 3 batch_shape = [] dist = make_univariate_mixture( - batch_shape=batch_shape, num_components=num_components) + batch_shape=batch_shape, num_components=num_components, + use_static_graph=self.use_static_graph) n = 4 with _test_capture_normal_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -502,7 +526,10 @@ class MixtureTest(test.TestCase): which_c = np.where(cat_sample_values == c)[0] size_c = which_c.size # Scalar Batch univariate case: batch_size == 1, rank 1 - which_dist_samples = dist_sample_values[c][:size_c] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c] + else: + which_dist_samples = dist_sample_values[c][:size_c] self.assertAllClose(which_dist_samples, sample_values[which_c]) # Test that sampling with the same seed twice gives the same results. @@ -522,7 +549,8 @@ class MixtureTest(test.TestCase): ] cat = ds.Categorical( logits, dtype=dtypes.int32, name="cat1") - dist1 = ds.Mixture(cat, components, name="mixture1") + dist1 = ds.Mixture(cat, components, name="mixture1", + use_static_graph=self.use_static_graph) samples1 = dist1.sample(n, seed=123456).eval() random_seed.set_random_seed(654321) @@ -532,7 +560,8 @@ class MixtureTest(test.TestCase): ] cat2 = ds.Categorical( logits, dtype=dtypes.int32, name="cat2") - dist2 = ds.Mixture(cat2, components2, name="mixture2") + dist2 = ds.Mixture(cat2, components2, name="mixture2", + use_static_graph=self.use_static_graph) samples2 = dist2.sample(n, seed=123456).eval() self.assertAllClose(samples1, samples2) @@ -541,7 +570,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: num_components = 3 dist = make_multivariate_mixture( - batch_shape=[], num_components=num_components, event_shape=[2]) + batch_shape=[], num_components=num_components, event_shape=[2], + use_static_graph=self.use_static_graph) n = 4 with _test_capture_mvndiag_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -555,14 +585,18 @@ class MixtureTest(test.TestCase): which_c = np.where(cat_sample_values == c)[0] size_c = which_c.size # Scalar Batch multivariate case: batch_size == 1, rank 2 - which_dist_samples = dist_sample_values[c][:size_c, :] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c, :] + else: + which_dist_samples = dist_sample_values[c][:size_c, :] self.assertAllClose(which_dist_samples, sample_values[which_c, :]) def testSampleBatchUnivariate(self): with self.test_session() as sess: num_components = 3 dist = make_univariate_mixture( - batch_shape=[2, 3], num_components=num_components) + batch_shape=[2, 3], num_components=num_components, + use_static_graph=self.use_static_graph) n = 4 with _test_capture_normal_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -576,8 +610,12 @@ class MixtureTest(test.TestCase): which_c_s, which_c_b0, which_c_b1 = np.where(cat_sample_values == c) size_c = which_c_s.size # Batch univariate case: batch_size == [2, 3], rank 3 - which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, - which_c_b1] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c_s, which_c_b0, + which_c_b1] + else: + which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, + which_c_b1] self.assertAllClose(which_dist_samples, sample_values[which_c_s, which_c_b0, which_c_b1]) @@ -594,7 +632,8 @@ class MixtureTest(test.TestCase): dist = make_multivariate_mixture( batch_shape=batch_shape, num_components=num_components, event_shape=[4], - batch_shape_tensor=batch_shape_tensor) + batch_shape_tensor=batch_shape_tensor, + use_static_graph=self.use_static_graph) n = 5 with _test_capture_mvndiag_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -617,8 +656,12 @@ class MixtureTest(test.TestCase): which_c_s, which_c_b0, which_c_b1 = np.where(cat_sample_values == c) size_c = which_c_s.size # Batch univariate case: batch_size == [2, 3], rank 4 (multivariate) - which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, - which_c_b1, :] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c_s, which_c_b0, + which_c_b1, :] + else: + which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, + which_c_b1, :] self.assertAllClose(which_dist_samples, sample_values[which_c_s, which_c_b0, which_c_b1, :]) @@ -632,7 +675,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_multivariate_mixture( - batch_shape=batch_shape, num_components=2, event_shape=(4,)) + batch_shape=batch_shape, num_components=2, event_shape=(4,), + use_static_graph=self.use_static_graph) entropy_lower_bound = dist.entropy_lower_bound() self.assertEqual(batch_shape, entropy_lower_bound.get_shape()) @@ -673,7 +717,8 @@ class MixtureTest(test.TestCase): cat_tf = ds.Categorical(probs=mixture_weights) components_tf = [ds.Normal(loc=mu, scale=sigma) for (mu, sigma) in zip(means, sigmas)] - mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf) + mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf, + use_static_graph=self.use_static_graph) x_tensor = array_ops.placeholder(shape=(), dtype=dtypes.float32) @@ -721,7 +766,8 @@ class MixtureTest(test.TestCase): cat_tf = ds.Categorical(probs=mixture_weights) components_tf = [ds.Normal(loc=mu, scale=sigma) for (mu, sigma) in zip(means, sigmas)] - mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf) + mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf, + use_static_graph=self.use_static_graph) x_tensor = array_ops.placeholder(shape=psize, dtype=dtypes.float32) xs_to_check = [ @@ -760,12 +806,18 @@ class MixtureTest(test.TestCase): gm = ds.Mixture( cat=ds.Categorical(probs=[.3, .7]), components=[ds.Gamma(1., 2.), - ds.Gamma(2., 1.)]) + ds.Gamma(2., 1.)], + use_static_graph=self.use_static_graph) x_ = gm.sample().eval() self.assertAllEqual([], x_.shape) +class MixtureStaticSampleTest(MixtureTest): + use_static_graph = True + + class MixtureBenchmark(test.Benchmark): + use_static_graph = False def _runSamplingBenchmark(self, name, create_distribution, use_gpu, num_components, batch_size, num_features, @@ -811,7 +863,7 @@ class MixtureBenchmark(test.Benchmark): components = list( ds.MultivariateNormalDiag( loc=mu, scale_diag=sigma) for (mu, sigma) in zip(mus, sigmas)) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=self.use_static_graph) for use_gpu in False, True: if use_gpu and not test.is_gpu_available(): @@ -853,7 +905,7 @@ class MixtureBenchmark(test.Benchmark): ds.MultivariateNormalTriL( loc=mu, scale_tril=linalg_ops.cholesky(sigma)) for (mu, sigma) in zip(mus, sigmas)) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=self.use_static_graph) for use_gpu in False, True: if use_gpu and not test.is_gpu_available(): @@ -872,5 +924,9 @@ class MixtureBenchmark(test.Benchmark): sample_size=sample_size) +class MixtureStaticSampleBenchmark(MixtureBenchmark): + use_static_graph = True + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index dc8ae1eed19eda772219287d8661f534ac242d10..5251dbcb5748f75688aa43ce6e4e9dbd76be78bb 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -237,6 +237,11 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): return y event_size = array_ops.shape(x)[-1] + # If the event size is available at graph construction time, we can inform + # the graph compiler of the maximum number of steps. If not, + # static_event_size will be None, and the maximum_iterations argument will + # have no effect. + static_event_size = x.shape.with_rank_at_least(1)[-1].value y0 = array_ops.zeros_like(x, name="y0") # call the template once to ensure creation _ = self._shift_and_log_scale_fn(y0) @@ -258,7 +263,8 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): _, y = control_flow_ops.while_loop( cond=lambda index, _: index < event_size, body=_loop_body, - loop_vars=[0, y0]) + loop_vars=(0, y0), + maximum_iterations=static_event_size) return y def _inverse(self, y): diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index a4d249d41ec9733721a3583d3708e0da56db1733..289e1d50e1146a641c0cc433ece3465aed73b1c2 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import linalg +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -442,6 +443,44 @@ def maybe_check_scalar_distribution( return assertions +def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution, + event_ndims): + """Pad dimensions of event tensors for mixture distributions. + + See `Mixture._sample_n` and `MixtureSameFamily._sample_n` for usage examples. + + Args: + x: event tensor to pad. + mixture_distribution: Base distribution of the mixture. + categorical_distribution: `Categorical` distribution that mixes the base + distribution. + event_ndims: Integer specifying the number of event dimensions in the event + tensor. + + Returns: + A padded version of `x` that can broadcast with `categorical_distribution`. + """ + with ops.name_scope("pad_mix_dims", values=[x]): + def _get_ndims(d): + if d.batch_shape.ndims is not None: + return d.batch_shape.ndims + return array_ops.shape(d.batch_shape_tensor())[0] + dist_batch_ndims = _get_ndims(mixture_distribution) + cat_batch_ndims = _get_ndims(categorical_distribution) + pad_ndims = array_ops.where( + categorical_distribution.is_scalar_batch(), + dist_batch_ndims, + dist_batch_ndims - cat_batch_ndims) + s = array_ops.shape(x) + x = array_ops.reshape(x, shape=array_ops.concat([ + s[:-1], + array_ops.ones([pad_ndims], dtype=dtypes.int32), + s[-1:], + array_ops.ones([event_ndims], dtype=dtypes.int32), + ], axis=0)) + return x + + def static_value(x): """Returns the static value of a `Tensor` or `None`.""" return tensor_util.constant_value(ops.convert_to_tensor(x)) diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py new file mode 100644 index 0000000000000000000000000000000000000000..74d5d8773cf3e69a52554c87d656fea2835c8354 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -0,0 +1,258 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Kumaraswamy distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops.distributions import beta +from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export + +__all__ = [ + "Kumaraswamy", +] + +_kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in +`[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" + + +def _harmonic_number(x): + """Compute the harmonic number from its analytic continuation. + + Derivation from [1] and Euler's constant [2]. + [1] - + https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers + [2] - https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant + + + Args: + x: input float. + + Returns: + z: The analytic continuation of the harmonic number for the input. + + """ + one = array_ops.ones([], dtype=x.dtype) + return math_ops.digamma(x + one) - math_ops.digamma(one) + + +@tf_export("distributions.Kumaraswamy") +class Kumaraswamy(beta.Beta): + """Kumaraswamy distribution. + + The Kumaraswamy distribution is defined over the `(0, 1)` interval using + parameters + `concentration1` (aka "alpha") and `concentration0` (aka "beta"). It has a + shape similar to the Beta distribution, but is reparameterizeable. + + #### Mathematical Details + + The probability density function (pdf) is, + + ```none + pdf(x; alpha, beta) = alpha * beta * x**(alpha - 1) * (1 - x**alpha)**(beta - + 1) + ``` + + where: + + * `concentration1 = alpha`, + * `concentration0 = beta`, + + Distribution parameters are automatically broadcast in all functions; see + examples for details. + + #### Examples + + ```python + # Create a batch of three Kumaraswamy distributions. + alpha = [1, 2, 3] + beta = [1, 2, 3] + dist = Kumaraswamy(alpha, beta) + + dist.sample([4, 5]) # Shape [4, 5, 3] + + # `x` has three batch entries, each with two samples. + x = [[.1, .4, .5], + [.2, .3, .5]] + # Calculate the probability of each pair of samples under the corresponding + # distribution in `dist`. + dist.prob(x) # Shape [2, 3] + ``` + + ```python + # Create batch_shape=[2, 3] via parameter broadcast: + alpha = [[1.], [2]] # Shape [2, 1] + beta = [3., 4, 5] # Shape [3] + dist = Kumaraswamy(alpha, beta) + + # alpha broadcast as: [[1., 1, 1,], + # [2, 2, 2]] + # beta broadcast as: [[3., 4, 5], + # [3, 4, 5]] + # batch_Shape [2, 3] + dist.sample([4, 5]) # Shape [4, 5, 2, 3] + + x = [.2, .3, .5] + # x will be broadcast as [[.2, .3, .5], + # [.2, .3, .5]], + # thus matching batch_shape [2, 3]. + dist.prob(x) # Shape [2, 3] + ``` + + """ + + def __init__(self, + concentration1=None, + concentration0=None, + validate_args=False, + allow_nan_stats=True, + name="Kumaraswamy"): + """Initialize a batch of Kumaraswamy distributions. + + Args: + concentration1: Positive floating-point `Tensor` indicating mean + number of successes; aka "alpha". Implies `self.dtype` and + `self.batch_shape`, i.e., + `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`. + concentration0: Positive floating-point `Tensor` indicating mean + number of failures; aka "beta". Otherwise has same semantics as + `concentration1`. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + super(Kumaraswamy, self).__init__( + concentration1=concentration1, + concentration0=concentration0, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=name) + self._reparameterization_type = distribution.FULLY_REPARAMETERIZED + + def _sample_n(self, n, seed=None): + expanded_concentration1 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration1 + expanded_concentration0 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration0 + shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) + uniform_sample = random_ops.random_uniform( + shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed) + + kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**( + 1. / expanded_concentration1) + return kumaraswamy_sample + + @distribution_util.AppendDocstring(_kumaraswamy_sample_note) + def _log_cdf(self, x): + a = self.concentration1 + b = self.concentration0 + return math_ops.log1p(-(1 - x**a)**b) + + @distribution_util.AppendDocstring(_kumaraswamy_sample_note) + def _cdf(self, x): + a = self.concentration1 + b = self.concentration0 + return 1 - (1 - x**a)**b + + def _survival_function(self, x): + a = self.concentration1 + b = self.concentration0 + return (1 - x**a)**b + + def _log_survival_function(self, x): + a = self.concentration1 + b = self.concentration0 + return b * math_ops.log1p(-x**a) + + def _log_unnormalized_prob(self, x): + x = self._maybe_assert_valid_sample(x) + a = self.concentration1 + b = self.concentration0 + return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a) + + def _log_normalization(self): + a = self.concentration1 + b = self.concentration0 + return -(math_ops.log(a) + math_ops.log(b)) + + def _entropy(self): + a = self.concentration1 + b = self.concentration0 + return (1 - 1. / a) + ( + 1 - 1. / b) * _harmonic_number(b) + math_ops.log(a) + math_ops.log(b) + + def _moment(self, n): + """Compute the n'th (uncentered) moment.""" + expanded_concentration1 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration1 + expanded_concentration0 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration0 + beta_arg0 = 1 + n / expanded_concentration1 + beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1) + log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta( + beta_arg) + return math_ops.exp(log_moment) + + def _mean(self): + return self._moment(1) + + def _variance(self): + # TODO(b/72696533): Investigate a more numerically stable version. + return self._moment(2) - math_ops.square(self._moment(1)) + + @distribution_util.AppendDocstring( + """Note: The mode is undefined when `concentration1 <= 1` or + `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN` + is used for undefined modes. If `self.allow_nan_stats` is `False` an + exception is raised when one or more modes are undefined.""") + def _mode(self): + a = self.concentration1 + b = self.concentration0 + mode = ((a - 1) / (a * b - 1))**(1. / a) + if self.allow_nan_stats: + nan = array_ops.fill( + self.batch_shape_tensor(), + np.array(np.nan, dtype=self.dtype.as_numpy_dtype), + name="nan") + is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.) + return array_ops.where(is_defined, mode, nan) + return control_flow_ops.with_dependencies([ + check_ops.assert_less( + array_ops.ones([], dtype=self.dtype), + self.concentration1, + message="Mode undefined for concentration1 <= 1."), + check_ops.assert_less( + array_ops.ones([], dtype=self.dtype), + self.concentration0, + message="Mode undefined for concentration0 <= 1.") + ], mode) diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index f2d492f5489a197157558ae727416b51db04793e..cef6a143fc615901315a3780bf4ed53b8c7cd177 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -71,6 +71,7 @@ class Mixture(distribution.Distribution): components, validate_args=False, allow_nan_stats=True, + use_static_graph=False, name="Mixture"): """Initialize a Mixture distribution. @@ -96,6 +97,11 @@ class Mixture(distribution.Distribution): exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. If `True`, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. + use_static_graph: Calls to `sample` will not rely on dynamic tensor + indexing, allowing for some static graph compilation optimizations, but + at the expense of sampling all underlying distributions in the mixture. + (Possibly useful when running on TPUs). + Default value: `False` (i.e., use dynamic indexing). name: A name for this distribution (optional). Raises: @@ -178,6 +184,10 @@ class Mixture(distribution.Distribution): self._static_event_shape = static_event_shape self._static_batch_shape = static_batch_shape + self._use_static_graph = use_static_graph + if use_static_graph and static_num_components is None: + raise ValueError("Number of categories must be known statically when " + "`static_sample=True`.") # We let the Mixture distribution access _graph_parents since its arguably # more like a baseclass. graph_parents = self._cat._graph_parents # pylint: disable=protected-access @@ -292,6 +302,31 @@ class Mixture(distribution.Distribution): return mixture_log_cdf def _sample_n(self, n, seed=None): + if self._use_static_graph: + # This sampling approach is almost the same as the approach used by + # `MixtureSameFamily`. The differences are due to having a list of + # `Distribution` objects rather than a single object, and maintaining + # random seed management that is consistent with the non-static code path. + samples = [] + cat_samples = self.cat.sample(n, seed=seed) + for c in range(self.num_components): + seed = distribution_util.gen_new_seed(seed, "mixture") + samples.append(self.components[c].sample(n, seed=seed)) + x = array_ops.stack( + samples, -self._static_event_shape.ndims - 1) # [n, B, k, E] + npdt = x.dtype.as_numpy_dtype + mask = array_ops.one_hot( + indices=cat_samples, # [n, B] + depth=self._num_components, # == k + on_value=np.ones([], dtype=npdt), + off_value=np.zeros([], dtype=npdt)) # [n, B, k] + mask = distribution_utils.pad_mixture_dimensions( + mask, self, self._cat, + self._static_event_shape.ndims) # [n, B, k, [1]*e] + return math_ops.reduce_sum( + x * mask, + axis=-1 - self._static_event_shape.ndims) # [n, B, E] + with ops.control_dependencies(self._assertions): n = ops.convert_to_tensor(n, name="n") static_n = tensor_util.constant_value(n) diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 49afbea7f05136674aa0c1441bd46548b7b55c8f..b93bdc5ab4010663baddda1410b302644853648b 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.framework import dtypes +from tensorflow.contrib.distributions.python.ops import distribution_util as distribution_utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -239,7 +239,9 @@ class MixtureSameFamily(distribution.Distribution): depth=self._num_components, # == k on_value=np.ones([], dtype=npdt), off_value=np.zeros([], dtype=npdt)) # [n, B, k] - mask = self._pad_mix_dims(mask) # [n, B, k, [1]*e] + mask = distribution_utils.pad_mixture_dimensions( + mask, self, self.mixture_distribution, + self._event_shape().ndims) # [n, B, k, [1]*e] return math_ops.reduce_sum( x * mask, axis=-1 - self._event_ndims) # [n, B, E] @@ -254,8 +256,9 @@ class MixtureSameFamily(distribution.Distribution): def _mean(self): with ops.control_dependencies(self._runtime_assertions): - probs = self._pad_mix_dims( - self.mixture_distribution.probs) # [B, k, [1]*e] + probs = distribution_utils.pad_mixture_dimensions( + self.mixture_distribution.probs, self, self.mixture_distribution, + self._event_shape().ndims) # [B, k, [1]*e] return math_ops.reduce_sum( probs * self.components_distribution.mean(), axis=-1 - self._event_ndims) # [B, E] @@ -271,8 +274,9 @@ class MixtureSameFamily(distribution.Distribution): def _variance(self): with ops.control_dependencies(self._runtime_assertions): # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) - probs = self._pad_mix_dims( - self.mixture_distribution.probs) # [B, k, [1]*e] + probs = distribution_utils.pad_mixture_dimensions( + self.mixture_distribution.probs, self, self.mixture_distribution, + self._event_shape().ndims) # [B, k, [1]*e] mean_cond_var = math_ops.reduce_sum( probs * self.components_distribution.variance(), axis=-1 - self._event_ndims) # [B, E] @@ -291,8 +295,12 @@ class MixtureSameFamily(distribution.Distribution): with ops.control_dependencies(self._runtime_assertions): # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) - probs = self._pad_mix_dims(self._pad_mix_dims( - self.mixture_distribution.probs)) # [B, k, 1, 1] + probs = distribution_utils.pad_mixture_dimensions( + distribution_utils.pad_mixture_dimensions( + self.mixture_distribution.probs, self, self.mixture_distribution, + self._event_shape().ndims), + self, self.mixture_distribution, + self._event_shape().ndims) # [B, k, 1, 1] mean_cond_var = math_ops.reduce_sum( probs * self.components_distribution.covariance(), axis=-3) # [B, e, e] @@ -312,27 +320,6 @@ class MixtureSameFamily(distribution.Distribution): shape[:d], [1], shape[d:]], axis=0)) return x - def _pad_mix_dims(self, x): - with ops.name_scope("pad_mix_dims", values=[x]): - def _get_ndims(d): - if d.batch_shape.ndims is not None: - return d.batch_shape.ndims - return array_ops.shape(d.batch_shape_tensor())[0] - dist_batch_ndims = _get_ndims(self) - cat_batch_ndims = _get_ndims(self.mixture_distribution) - pad_ndims = array_ops.where( - self.mixture_distribution.is_scalar_batch(), - dist_batch_ndims, - dist_batch_ndims - cat_batch_ndims) - s = array_ops.shape(x) - x = array_ops.reshape(x, shape=array_ops.concat([ - s[:-1], - array_ops.ones([pad_ndims], dtype=dtypes.int32), - s[-1:], - array_ops.ones([self._event_ndims], dtype=dtypes.int32), - ], axis=0)) - return x - def _outer_squared_difference(x, y): """Convenience function analogous to tf.squared_difference.""" diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index b6becfa9fc93f189a1a7bf7b2a7af8dc1f2e9720..2aa771a71efe52c8d86d459f090ea8ee137c4487 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -278,7 +278,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): * math_ops.log(self.temperature)) # compute the unnormalized density log_softmax = nn_ops.log_softmax(logits_2d - x_2d * self._temperature_2d) - log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keep_dims=False) + log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keepdims=False) # combine unnormalized density with normalization constant log_prob = log_norm_const + log_unnorm_prob # Reshapes log_prob to be consistent with shape of user-supplied logits diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index 09242ee47ddd044dfc99e22d5b7751a989c86485..9d2ca07c3a25fa7acb9b0f5806b763d9a57b51fa 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -41,28 +41,8 @@ support for distributed and multi-GPU training and CPU performance. ## Installation -Since eager execution is not yet part of a TensorFlow release, using it requires -either [building from source](https://www.tensorflow.org/install/install_sources) -or the latest nightly builds. The nightly builds are available as: - -- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and - -- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images. - -For example, to run the latest nightly docker image: - -```sh -# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker -nvidia-docker pull tensorflow/tensorflow:nightly-gpu -nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu - -# If you do not have a GPU, use the CPU-only image -docker pull tensorflow/tensorflow:nightly -docker run -it -p 8888:8888 tensorflow/tensorflow:nightly -``` - -And then visit http://localhost:8888 in your browser for a Jupyter notebook -environment. Try out the notebooks below. +Eager execution is included in TensorFlow versions 1.5 and above. +Installation instructions at https://www.tensorflow.org/install/ ## Documentation diff --git a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto index c962638aa11c06dcd5be6a794314e029ae84e572..024765acb28726fd102dfbf167f4e780072ce6e7 100644 --- a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto +++ b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto @@ -4,9 +4,9 @@ option cc_enable_arenas = true; package tensorflow.contrib.eager; -// Prototype for an addition to BundleHeaderProto which saves extra information -// about the objects which own variables, allowing for more robust checkpoint -// loading into modified programs. +// Prototype format which saves extra information about the objects which own +// variables, allowing for more robust checkpoint loading into modified +// programs. Currently stored in its own entry in a TensorBundle. message CheckpointableObjectGraph { message Object { @@ -14,40 +14,39 @@ message CheckpointableObjectGraph { // An index into `CheckpointableObjectGraph.nodes`, indicating the object // being referenced. int32 node_id = 1; - // A numeric identifier for this object within its parent. - int32 local_uid = 2; - // A user-provided name for the edge. May be blank/omitted, in which case - // there is no explicitly provided local name; fall back on local_uid. - string local_name = 3; + // A user-provided name for the edge. + string local_name = 2; } - message VariableReference { - // A name for the variable which is unique within the object which owns - // it. Does not include a name_scope or variable_scope prefix. - string local_name = 1; - // The full name of the variable. Used to allow name-based loading of - // checkpoints which were saved using an object-based API. + message SerializedTensor { + // A name for the Tensor. Simple variables have only one + // `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may + // be restored on object creation as an optimization. + string name = 1; + // The full name of the variable/tensor, if applicable. Used to allow + // name-based loading of checkpoints which were saved using an + // object-based API. Should match the checkpoint key which would have been + // assigned by tf.train.Saver. string full_name = 2; + // The generated name of the Tensor in the checkpoint. + string checkpoint_key = 3; } message SlotVariableReference { - // An index into `CheckpointableObjectGraph.nodes`, indicating the object - // which created the variable that this variable is slotting for. + // An index into `CheckpointableObjectGraph.nodes`, indicating the + // variable object this slot was created for. int32 original_variable_node_id = 1; - // The local name of the variable being slotted for within the object that - // owns it. - string original_variable_local_name = 2; // The name of the slot (e.g. "m"/"v"). - string slot_name = 3; - // The full name of the slot variable. Used to allow name-based loading of - // checkpoints which were saved using an object-based API. - string full_name = 4; + string slot_name = 2; + // An index into `CheckpointableObjectGraph.nodes`, indicating the + // `Object` with the value of the slot variable. + int32 slot_variable_node_id = 3; } // Objects which this object depends on. repeated ObjectReference children = 1; - // Non-slot variables owned by this object. - repeated VariableReference variables = 2; + // Serialized data specific to this object. + repeated SerializedTensor attributes = 2; // Slot variables owned by this object. repeated SlotVariableReference slot_variables = 3; } diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index e984c63af7ce2b32ab30121bf34bb2de4dfeb218..ad40e55cb48aac08eca7022846a0bd07b8accb3f 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -52,7 +52,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_py", + "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:errors", @@ -220,29 +220,42 @@ py_test( ) py_library( - name = "checkpointable", - srcs = ["checkpointable.py"], + name = "checkpointable_utils", + srcs = ["checkpointable_utils.py"], srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/contrib/eager/proto:checkpointable_object_graph_proto_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:tensor_shape", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", ], ) py_test( - name = "checkpointable_test", - srcs = ["checkpointable_test.py"], + name = "checkpointable_utils_test", + srcs = ["checkpointable_utils_test.py"], srcs_version = "PY2AND3", deps = [ - ":checkpointable", + ":checkpointable_utils", ":network", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", "//tensorflow/python:layers", + "//tensorflow/python:layers_base", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py deleted file mode 100644 index b141ffb2bc03b8e38f8481bc044c3aae7e156c15..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/checkpointable.py +++ /dev/null @@ -1,392 +0,0 @@ -"""An object-local variable management scheme.""" -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import re - -from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 -from tensorflow.python.ops import variable_scope -from tensorflow.python.training import optimizer as optimizer_lib -from tensorflow.python.training import saver as saver_lib - -_CheckpointableReference = collections.namedtuple( - "_CheckpointableReference", - [ - "name", # The local name if explicitly specified, else None. - "local_uid", # 0 for the first dependency, 1 for the next, ... Used for - # routing checkpointed variables to their correct - # Checkpointables when "name" is not set (see docstring of - # `track_checkpointable`). - "ref" # The Checkpointable object being referenced. - ]) - -_OwnedVariable = collections.namedtuple( - "_OwnedVariable", - [ - "name", # The variable's (local) name. - "variable" # The owned variable object. - ]) - -# Validation regular expression for the local names of Checkpointable -# objects. In particular, disallows "/" in names, and reserves -# underscore-prefixed names. -_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.-]*$") - -# Keyword for identifying that the next bit of a checkpoint variable name is a -# slot name. May not be the local name of a checkpointable. Checkpoint names for -# slot variables look like: -# -# /<_OPTIMIZER_SLOTS_NAME>// -# -# Where is a full path from the checkpoint root to the -# variable being slotted for. -_OPTIMIZER_SLOTS_NAME = "_OPTIMIZER_SLOT" - - -class Checkpointable(object): - """Manages variables and dependencies on other objects. - - To make reliable checkpoints, all `Checkpointable`s on which this object - depends must be registered in the constructor using `track_checkpointable` in - a deterministic order, and if possible they should be named. Variables may be - created using `add_variable` outside of the constructor and in any order, but - only these variables will be saved. - """ - - def __init__(self): - # Basically less useful OrderedDicts but without the reference cycles. - # TODO(allenl): Switch these to OrderedDict once TensorFlow supports only - # Python 3.6+. - self._checkpoint_dependencies = [] # A list of _CheckpointableReference - # objects. - self._dependency_names = set() - self._owned_variables = [] # A list of _OwnedVariable objects. - self._owned_variable_names = set() - - def add_variable(self, name, shape, dtype=None, initializer=None, **kwargs): - """Create a new variable object to be saved with this `Checkpointable`. - - If the user has requested that this object or another `Checkpointable` which - depends on this object be restored from a checkpoint (deferred loading - before variable object creation), `initializer` may be ignored and the value - from the checkpoint used instead. - - Args: - name: A name for the variable. Must be unique within this object. - shape: The shape of the variable. - dtype: The data type of the variable. - initializer: The initializer to use. Ignored if deferred loading has been - requested. - **kwargs: Passed to get_variable. - - Returns: - The new variable object. - - Raises: - ValueError: If the variable name is not unique. - """ - if name in self._owned_variable_names: - raise ValueError( - ("A variable named '%s' already exists in this Checkpointable, but " - "Checkpointable.add_variable called to create another with " - "that name. Variable names must be unique within a Checkpointable " - "object.") % (name,)) - if "getter" in kwargs: - # Allow the getter to be overridden, typically because there is a need for - # compatibility with some other variable creation mechanism. This should - # be relatively uncommon in user code. - getter = kwargs.pop("getter") - else: - getter = variable_scope.get_variable - # TODO(allenl): handle deferred loading - new_variable = getter( - name=name, shape=shape, dtype=dtype, initializer=initializer, **kwargs) - self._owned_variables.append( - _OwnedVariable(name=name, variable=new_variable)) - self._owned_variable_names.add(name) - return new_variable - - def track_checkpointable(self, checkpointable, name=None): - """Declare a dependency on another `Checkpointable` object. - - Indicates that checkpoints for this object should include variables from - `checkpointable`. - - Variables in a checkpoint are mapped to `Checkpointable`s based on names if - provided when the checkpoint was written, but otherwise use the order those - `Checkpointable`s were declared as dependencies. Both `name` arguments and - the dependency declaration order should be deterministic. - - There are two sufficient conditions to avoid breaking existing checkpoints - when modifying a class: (1) New dependencies must be declared after existing - dependencies, and (2) dependencies which were previously declared may never - be removed (a trivial placeholder with the same name may be used instead). - - Args: - checkpointable: A `Checkpointable` which this object depends on. - name: A local name for `checkpointable`, used for loading checkpoints into - the correct objects. If provided, it must be unique within this - `Checkpointable`. If None, dependency declaration order is used instead. - - Returns: - `checkpointable`, for convenience when declaring a dependency and - assigning to a member variable in one statement. - - Raises: - RuntimeError: If __init__ was not called. - TypeError: If `checkpointable` does not inherit from `Checkpointable`. - ValueError: For invalid names. - """ - if not hasattr(self, "_checkpoint_dependencies"): - raise RuntimeError("Need to call Checkpointable.__init__ before calling " - "Checkpointable.track_checkpointable().") - if not isinstance(checkpointable, Checkpointable): - raise TypeError( - ("Checkpointable.track_checkpointable() passed type %s, not a " - "Checkpointable.") % (type(checkpointable),)) - if name is not None: - if not _VALID_LOCAL_NAME.match(name): - raise ValueError( - ("Checkpointable names must match the regular expression '%s', but " - "got an invalid name '%s' instead.") % (_VALID_LOCAL_NAME.pattern, - name)) - if name in self._dependency_names: - raise ValueError( - ("Called Checkpointable.track_checkpointable() with name='%s', but " - "a Checkpointable with this name is already declared as a " - "dependency. If provided, names must be unique.") % (name,)) - self._dependency_names.add(name) - self._checkpoint_dependencies.append( - _CheckpointableReference( - name=name, - ref=checkpointable, - # TODO(allenl): Should this be exposed to allow users to stop - # depending on things and still load checkpoints when not using - # names? - local_uid=len(self._checkpoint_dependencies))) - return checkpointable - - @property - def checkpoint_dependencies(self): - """Other `Checkpointable` objects on which this object depends.""" - return self._checkpoint_dependencies - - -def _breadth_first_checkpointable_traversal(root_checkpointable): - """Find shortest paths to all variables owned by dependencies of root.""" - bfs_sorted = [] - root_checkpointable_reference = _CheckpointableReference( - name=None, local_uid=0, ref=root_checkpointable) - to_visit = collections.deque([root_checkpointable_reference]) - path_to_root = {root_checkpointable_reference: ()} - while to_visit: - current_checkpointable = to_visit.popleft() - bfs_sorted.append(current_checkpointable) - for child_checkpointable in ( - current_checkpointable.ref.checkpoint_dependencies): - if child_checkpointable not in path_to_root: - path_to_root[child_checkpointable] = ( - path_to_root[current_checkpointable] + (child_checkpointable,)) - to_visit.append(child_checkpointable) - return bfs_sorted, path_to_root - - -def _object_prefix_from_path(path_to_root): - return "/".join((checkpointable.name if checkpointable.name else "_%d" % ( - checkpointable.local_uid,)) for checkpointable in path_to_root) - - -def _escape_variable_name(variable_name): - # We need to support slashes in variable names for compatibility, since this - # naming scheme is being patched in to things like Layer.add_variable where - # slashes were previously accepted. We also want to use slashes to indicate - # edges traversed to reach the variable, so we escape forward slashes in - # variable names. - return variable_name.replace("_S_", "_S_.").replace(r"/", r"_S__") - - -def _variable_naming_for_object(path_to_root): - """Make a function for naming variables in an object.""" - # Name non-slot variables: - # - # / - # - # is not necessarily unique, but this is fine since we also - # save the graph of `Checkpointable`s with the checkpoint. Even if this path - # no longer exists because of a change in the Python program, we can look up - # the `Checkpointable` which owns the variable in the checkpoint's graph and - # use another path if one still exists. - - object_prefix = _object_prefix_from_path(path_to_root) - if object_prefix: - object_prefix += "/" - - def _name_single_variable(owned_variable): - """Names a variable within an object.""" - return object_prefix + _escape_variable_name(owned_variable.name) - - return _name_single_variable - - -def _slot_variable_naming_for_optimizer(optimizer, path_to_root): - """Make a function for naming slot variables in an optimizer.""" - # Name slot variables: - # - # /<_OPTIMIZER_SLOTS_NAME>// - # - # where is exactly the checkpoint name used for the original - # variable, including the path from the checkpoint root and the local name in - # the object which owns it. Note that we only save slot variables if the - # variable it's slotting for is also being saved. - - optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, - _object_prefix_from_path(path_to_root)) - - def _name_slot_variable(variable_path, slot_name): - """With an optimizer specified, name a slot variable.""" - - if not _VALID_LOCAL_NAME.match(slot_name): - # Slot variable names include the name of the slot. We need to - # validate that part of the name to be sure that the checkpoint name - # is a valid name scope name. - raise ValueError( - ("Could not save slot variables for optimizer %s, because its " - "slot name has invalid characters (got '%s', was expecting it " - "to match the regular expression '%s').") % - (optimizer, slot_name, _VALID_LOCAL_NAME.pattern)) - - return variable_path + optimizer_identifier + slot_name - - return _name_slot_variable - - -def _serialize_non_slot_variables(checkpointable_objects, path_to_root, - object_graph_proto): - """Name non-slot variables and add them to `object_graph_proto`.""" - named_variables = {} - non_slot_variables = [] - checkpoint_node_ids = {} - - for checkpoint_id, checkpointable in enumerate(checkpointable_objects): - checkpoint_node_ids[checkpointable] = checkpoint_id - - for checkpoint_id, checkpointable in enumerate(checkpointable_objects): - naming_scheme = _variable_naming_for_object(path_to_root[checkpointable]) - object_proto = object_graph_proto.nodes.add() - for owned_variable in checkpointable.ref._owned_variables: # pylint: disable=protected-access - variable_name = naming_scheme(owned_variable) - named_variables[variable_name] = owned_variable.variable - non_slot_variables.append(( - variable_name, # The variable's full checkpoint name - owned_variable, # The variable's _OwnedVariable object - checkpoint_id)) # The checkpoint ID of the node which owns this - # variable. - variable_proto = object_proto.variables.add() - variable_proto.local_name = owned_variable.name - # Figure out the name-based Saver's name for this variable. - saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( - [owned_variable.variable], convert_variable_to_tensor=False) - variable_full_name, = saver_dict.keys() - variable_proto.full_name = variable_full_name - - for child in checkpointable.ref.checkpoint_dependencies: - child_proto = object_proto.children.add() - child_proto.node_id = checkpoint_node_ids[child] - child_proto.local_uid = child.local_uid - if child.name is not None: - child_proto.local_name = child.name - return named_variables, non_slot_variables - - -def _serialize_slot_variables(checkpointable_objects, path_to_root, - non_slot_variables, object_graph_proto): - """Name slot variables and add them to `object_graph_proto`.""" - named_slot_variables = {} - for optimizer_checkpoint_id, checkpointable_ref in enumerate( - checkpointable_objects): - if isinstance(checkpointable_ref.ref, optimizer_lib.Optimizer): - optimizer_object_proto = object_graph_proto.nodes[optimizer_checkpoint_id] - naming_scheme = _slot_variable_naming_for_optimizer( - optimizer=checkpointable_ref.ref, - path_to_root=path_to_root[checkpointable_ref]) - slot_names = checkpointable_ref.ref.get_slot_names() - for (variable_path, owned_variable, - original_node_checkpoint_id) in non_slot_variables: - for slot_name in slot_names: - slot_variable = checkpointable_ref.ref.get_slot( - owned_variable.variable, slot_name) - if slot_variable is not None: - checkpoint_name = naming_scheme( - variable_path=variable_path, slot_name=slot_name) - named_slot_variables[checkpoint_name] = slot_variable - slot_variable_proto = optimizer_object_proto.slot_variables.add() - slot_variable_proto.slot_name = slot_name - # Figure out the name-based Saver's name for this variable. - saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( - [slot_variable], convert_variable_to_tensor=False) - slot_variable_full_name, = saver_dict.keys() - slot_variable_proto.full_name = slot_variable_full_name - slot_variable_proto.original_variable_local_name = ( - owned_variable.name) - slot_variable_proto.original_variable_node_id = ( - original_node_checkpoint_id) - return named_slot_variables - - -# TODO(allenl): Convenience utility for saving multiple objects (i.e. construct -# a root Checkpointable if passed a list of Checkpointables). -def _serialize_object_graph(root_checkpointable): - """Determine checkpoint keys for variables and build a serialized graph. - - Non-slot variables are keyed based on a shortest path from the root saveable - to the object which owns the variable (i.e. the one which called - `Checkpointable.add_variable` to create it). - - Slot variables are keyed based on a shortest path to the variable being - slotted for, a shortest path to their optimizer, and the slot name. - - Args: - root_checkpointable: A `Checkpointable` object whose variables (including - the variables of dependencies, recursively) should be saved. - - Returns: - A tuple of (named_variables, object_graph_proto): - named_variables: A dictionary mapping names to variable objects. - object_graph_proto: A CheckpointableObjectGraph protocol buffer containing - the serialized object graph and variable references. - - Raises: - ValueError: If there are invalid characters in an optimizer's slot names. - """ - checkpointable_objects, path_to_root = ( - _breadth_first_checkpointable_traversal(root_checkpointable)) - object_graph_proto = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - - # Gather non-slot variables. - named_variables, non_slot_variables = _serialize_non_slot_variables( - checkpointable_objects, path_to_root, object_graph_proto) - - # Gather slot variables which are associated with variables gathered above. - named_slot_variables = _serialize_slot_variables( - checkpointable_objects, path_to_root, non_slot_variables, - object_graph_proto) - - named_variables.update(named_slot_variables) - return named_variables, object_graph_proto diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py deleted file mode 100644 index ff419614f580d3bace9d99648478cc2204d7801d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/checkpointable_test.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import six - -from tensorflow.contrib.eager.python import checkpointable -from tensorflow.contrib.eager.python import network as network_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.layers import core -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.training import adam -from tensorflow.python.training import training_util - - -class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): - - def __init__(self, *args, **kwargs): - checkpointable.Checkpointable.__init__(self) - core.Dense.__init__(self, *args, **kwargs) - - def add_variable(self, name, shape, **kwargs): - # Calls both Checkpointable.add_variable and Layer.add_variable. Eventually - # Layer.add_variable should inherit from Checkpointable and simply call - # super and then do post-processing. - return checkpointable.Checkpointable.add_variable( - self, - name=name, - shape=shape, - getter=functools.partial(core.Dense.add_variable, self), - **kwargs) - - -# pylint: disable=not-callable -class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): - - def __init__(self): - network_lib.Network.__init__(self) - checkpointable.Checkpointable.__init__(self) - - def track_layer(self, layer, name=None): - self.track_checkpointable(layer, name=name) - return super(CheckpointableNetwork, self).track_layer(layer) - - -class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): - - def __init__(self, *args, **kwargs): - checkpointable.Checkpointable.__init__(self) - adam.AdamOptimizer.__init__(self, *args, **kwargs) - - # NOTE: Copied from Optimizer with modifications to use add_variable - # for non-slot variables. These contortions are necessary to maintain - # checkpoint compatibility with variable.name based saving. - # TODO(allenl): Make this cleaner. - def _create_non_slot_variable(self, initial_value, name, colocate_with): - """Add an extra variable, not associated with a slot.""" - if context.in_graph_mode(): - graph = colocate_with.graph - else: - graph = None - - key = (name, graph) - v = self._non_slot_dict.get(key, None) - if v is None: - with ops.colocate_with(colocate_with): - def _variable_getter(name, shape, dtype, initializer): - del shape, dtype # not used, but there for compatibility - return variable_scope.variable( - name=name, initial_value=initializer, trainable=False) - - initial_value = ops.convert_to_tensor(initial_value) - v = self.add_variable( - name=name, - shape=initial_value.get_shape(), - initializer=initial_value, - getter=_variable_getter) - - self._non_slot_dict[key] = v - - return v - - # TODO(allenl): Override slot variable creation (_get_or_make_slot, - # _get_or_make_slot_with_initializer, _zeros_slot) to allow deferred - # loading. Likely no need to run this through add_variable, since gathering - # slot variables is special cased anyway. - - -class MyNetwork(CheckpointableNetwork): - """A concrete Network for testing.""" - - def __init__(self): - super(MyNetwork, self).__init__() - self._named = self.track_layer( - CheckpointableDenseLayer(1, use_bias=True), name="named_dense") - self._unnamed = self.track_layer( - CheckpointableDenseLayer(1, use_bias=False)) - - def call(self, values): - return self._unnamed(self._named(values)) - - -class Root(checkpointable.Checkpointable): - """A stand-in for a Trainer class.""" - - def __init__(self, optimizer, network): - super(Root, self).__init__() - self.track_checkpointable(optimizer, name="optimizer") - self.track_checkpointable(network, name="network") - self._global_step = None - - @property - def global_step(self): - if self._global_step is None: - # Get the default create_global_step utility to actually call - # self.add_variable, by setting a custom getter. - def _owned_variable_as_custom_getter(getter, *args, **kwargs): - return self.add_variable(*args, getter=getter, **kwargs) - - with variable_scope.variable_scope( - "", custom_getter=_owned_variable_as_custom_getter): - self._global_step = training_util.create_global_step() - return self._global_step - - -class CheckpointNamingTests(test.TestCase): - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testNamingWithOptimizer(self): - input_value = constant_op.constant([[3.]]) - network = MyNetwork() - # A nuisance Network using the same optimizer. Its slot variables should not - # go in the checkpoint, since it is never depended on. - other_network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root_checkpointable = Root(optimizer=optimizer, network=network) - if context.in_eager_mode(): - optimizer.minimize( - lambda: network(input_value), - global_step=root_checkpointable.global_step) - optimizer.minimize( - lambda: other_network(input_value), - global_step=root_checkpointable.global_step) - else: - train_op = optimizer.minimize( - network(input_value), global_step=root_checkpointable.global_step) - optimizer.minimize( - other_network(input_value), - global_step=root_checkpointable.global_step) - self.evaluate(variables.global_variables_initializer()) - self.evaluate(train_op) - named_variables, serialized_graph = checkpointable._serialize_object_graph( - root_checkpointable) - expected_checkpoint_names = ( - # Created in the root node, so no prefix. - "global_step", - # No name provided to track_checkpointable(), so the position (1, after - # the named track_checkpointable() which is 0) is used instead. - "network/_1/kernel", - # track_checkpointable() with a name provided, so that's used - "network/named_dense/kernel", - "network/named_dense/bias", - # The optimizer creates two non-slot variables - "optimizer/beta1_power", - "optimizer/beta2_power", - # Slot variables - "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/m", - "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/v", - "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/m", - "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/v", - "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m", - "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/v", - ) - six.assertCountEqual(self, expected_checkpoint_names, - named_variables.keys()) - # Check that we've mapped to the right variable objects (not exhaustive) - self.assertEqual("global_step:0", named_variables["global_step"].name) - self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0", - named_variables["network/_1/kernel"].name) - self.assertEqual("my_network/checkpointable_dense_layer/kernel:0", - named_variables["network/named_dense/kernel"].name) - self.assertEqual("beta1_power:0", - named_variables["optimizer/beta1_power"].name) - self.assertEqual("beta2_power:0", - named_variables["optimizer/beta2_power"].name) - # Spot check the generated protocol buffers. - self.assertEqual(0, serialized_graph.nodes[0].children[0].local_uid) - self.assertEqual("optimizer", - serialized_graph.nodes[0].children[0].local_name) - optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ - 0].node_id] - self.assertEqual("beta1_power", optimizer_node.variables[0].local_name) - self.assertEqual("beta1_power", optimizer_node.variables[0].full_name) - self.assertEqual( - "kernel", optimizer_node.slot_variables[0].original_variable_local_name) - original_variable_owner = serialized_graph.nodes[ - optimizer_node.slot_variables[0].original_variable_node_id] - self.assertEqual("kernel", original_variable_owner.variables[0].local_name) - self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) - # We strip off the :0 suffix, as variable.name-based saving does. - self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam", - optimizer_node.slot_variables[0].full_name) - self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam:0", - optimizer.get_slot( - var=named_variables["network/named_dense/kernel"], - name="m").name) - - def _get_checkpoint_name(self, name): - root = checkpointable.Checkpointable() - with variable_scope.variable_scope("get_checkpoint_name"): - # Create the variable in a variable scope so that we get more relaxed - # naming rules (variables outside a scope may not start with "_", "/" or - # "-"). Since we don't use the scope part of the name, these cases are - # somewhat annoying. - root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64) - named_variables, _ = checkpointable._serialize_object_graph(root) - checkpoint_name, = named_variables.keys() - with ops.name_scope("root/" + checkpoint_name): - pass # Make sure we can use this as an op name if we prefix it. - return checkpoint_name - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testVariableNameEscaping(self): - self.assertEqual(r"a_S__b_S__c", self._get_checkpoint_name(r"a/b/c")) - self.assertEqual(r"", self._get_checkpoint_name(r"")) - self.assertEqual(r"_S__", self._get_checkpoint_name(r"/")) - self.assertEqual(r"_S___S_._", self._get_checkpoint_name(r"/_S__")) - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testNumberedPath(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() - root.track_checkpointable(leaf) - leaf.add_variable(name="v", shape=[]) - named_variables, _ = checkpointable._serialize_object_graph(root) - variable_name, = named_variables.keys() - self.assertEqual(r"_0/v", variable_name) - - @test_util.run_in_graph_and_eager_modes() - def testLocalNameValidation(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() - with self.assertRaisesRegexp(ValueError, "invalid name"): - # Leading underscores are reserved, which avoids conflicts with - # un-named edges in paths and the optimizer slots identifier. - root.track_checkpointable(leaf, name="_12") - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0506af391cf22ec89d5174f17208c8eb393ddc54 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable_utils.py @@ -0,0 +1,435 @@ +"""Utilities for working with Checkpointable objects.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import checkpointable as core_checkpointable +from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.training import saver as saver_lib + + +_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names. + +# Keyword for identifying that the next bit of a checkpoint variable name is a +# slot name. Checkpoint names for slot variables look like: +# +# /<_OPTIMIZER_SLOTS_NAME>// +# +# Where is a full path from the checkpoint root to the +# variable being slotted for. +_OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT" +# Keyword for separating the path to an object from the name of an +# attribute in checkpoint names. Used like: +# /<_OBJECT_ATTRIBUTES_NAME>/ +_OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES" +# Key where the object graph proto is saved in a TensorBundle +_OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" + + +# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange +# or consolidating the implementation with get_variable. +def _default_getter(name, shape, dtype, initializer=None, + partition_info=None, **kwargs): + """A pared-down version of get_variable which does not reuse variables.""" + dtype = dtypes.as_dtype(dtype) + shape_object = tensor_shape.as_shape(shape) + with ops.init_scope(): + if initializer is None: + initializer, initializing_from_value = ( + variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access + name=name, shape=shape_object, dtype=dtype)) + else: + initializing_from_value = not callable(initializer) + # Same logic as get_variable + variable_dtype = dtype.base_dtype + if initializing_from_value: + if shape is not None: + raise ValueError("If initializer is a constant, do not specify shape.") + initial_value = initializer + else: + # Instantiate initializer if provided initializer is a type object. + if isinstance(initializer, type(init_ops.Initializer)): + initializer = initializer(dtype=dtype) + def initial_value(): + return initializer( + shape_object.as_list(), dtype=dtype, partition_info=partition_info) + return resource_variable_ops.ResourceVariable( + initial_value=initial_value, + name=name, + dtype=variable_dtype, + **kwargs + ) + + +def add_variable(checkpointable, name, shape=None, dtype=dtypes.float32, + initializer=None): + """Add a variable to a Checkpointable with no scope influence.""" + return checkpointable._add_variable_with_custom_getter( # pylint: disable=protected-access + name=name, shape=shape, dtype=dtype, + initializer=initializer, getter=_default_getter) + + +def _breadth_first_checkpointable_traversal(root_checkpointable): + """Find shortest paths to all variables owned by dependencies of root.""" + bfs_sorted = [] + to_visit = collections.deque([root_checkpointable]) + path_to_root = {root_checkpointable: ()} + while to_visit: + current_checkpointable = to_visit.popleft() + current_checkpointable._maybe_initialize_checkpointable() # pylint: disable=protected-access + bfs_sorted.append(current_checkpointable) + for child_checkpointable in ( + current_checkpointable._checkpoint_dependencies): # pylint: disable=protected-access + if child_checkpointable.ref not in path_to_root: + path_to_root[child_checkpointable.ref] = ( + path_to_root[current_checkpointable] + (child_checkpointable,)) + to_visit.append(child_checkpointable.ref) + return bfs_sorted, path_to_root + + +def _escape_local_name(name): + # We need to support slashes in local names for compatibility, since this + # naming scheme is being patched in to things like Layer.add_variable where + # slashes were previously accepted. We also want to use slashes to indicate + # edges traversed to reach the variable, so we escape forward slashes in + # names. + return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR) + .replace(r"/", _ESCAPE_CHAR + "S")) + + +def _object_prefix_from_path(path_to_root): + return "/".join( + (_escape_local_name(checkpointable.name) + for checkpointable in path_to_root)) + + +def _slot_variable_naming_for_optimizer(optimizer_path): + """Make a function for naming slot variables in an optimizer.""" + # Name slot variables: + # + # /<_OPTIMIZER_SLOTS_NAME>// + # + # where is exactly the checkpoint name used for the original + # variable, including the path from the checkpoint root and the local name in + # the object which owns it. Note that we only save slot variables if the + # variable it's slotting for is also being saved. + + optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path) + + def _name_slot_variable(variable_path, slot_name): + """With an optimizer specified, name a slot variable.""" + return (variable_path + + optimizer_identifier + + _escape_local_name(slot_name)) + + return _name_slot_variable + + +def _serialize_slot_variables(checkpointable_objects, node_ids, object_names): + """Gather and name slot variables.""" + non_slot_objects = list(checkpointable_objects) + slot_variables = {} + for checkpointable in non_slot_objects: + if isinstance(checkpointable, optimizer_lib.Optimizer): + naming_scheme = _slot_variable_naming_for_optimizer( + optimizer_path=object_names[checkpointable]) + slot_names = checkpointable.get_slot_names() + for slot_name in slot_names: + for original_variable_node_id, original_variable in enumerate( + non_slot_objects): + try: + slot_variable = checkpointable.get_slot( + original_variable, slot_name) + except AttributeError: + slot_variable = None + if slot_variable is None: + continue + slot_variable._maybe_initialize_checkpointable() # pylint: disable=protected-access + if slot_variable._checkpoint_dependencies: # pylint: disable=protected-access + # TODO(allenl): Gather dependencies of slot variables. + raise NotImplementedError( + "Currently only variables with no dependencies can be saved as " + "slot variables. File a feature request if this limitation " + "bothers you.") + if slot_variable in node_ids: + raise NotImplementedError( + "A slot variable was re-used as a dependency of a " + "Checkpointable object. This is not currently allowed. File a " + "feature request if this limitation bothers you.") + checkpoint_name = naming_scheme( + variable_path=object_names[original_variable], + slot_name=slot_name) + object_names[slot_variable] = checkpoint_name + slot_variable_node_id = len(checkpointable_objects) + node_ids[slot_variable] = slot_variable_node_id + checkpointable_objects.append(slot_variable) + slot_variable_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph + .Object.SlotVariableReference( + slot_name=slot_name, + original_variable_node_id=original_variable_node_id, + slot_variable_node_id=slot_variable_node_id)) + slot_variables.setdefault(checkpointable, []).append( + slot_variable_proto) + return slot_variables + + +def _serialize_checkpointables( + checkpointable_objects, node_ids, object_names, slot_variables): + """Name non-slot `Checkpointable`s and add them to `object_graph_proto`.""" + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + named_saveables = {} + + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + assert node_ids[checkpointable] == checkpoint_id + object_proto = object_graph_proto.nodes.add() + object_proto.slot_variables.extend(slot_variables.get(checkpointable, ())) + object_name = object_names[checkpointable] + for name, saveable in ( + checkpointable._gather_tensors_for_checkpoint().items()): # pylint: disable=protected-access + attribute = object_proto.attributes.add() + attribute.name = name + attribute.checkpoint_key = "%s/%s/%s" % ( + object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) + # Figure out the name-based Saver's name for this variable. + saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [saveable], convert_variable_to_tensor=False) + attribute.full_name, = saver_dict.keys() + named_saveables[attribute.checkpoint_key] = saveable + + for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access + child_proto = object_proto.children.add() + child_proto.node_id = node_ids[child.ref] + child_proto.local_name = child.name + + return named_saveables, object_graph_proto + + +def _serialize_object_graph(root_checkpointable): + """Determine checkpoint keys for variables and build a serialized graph. + + Non-slot variables are keyed based on a shortest path from the root saveable + to the object which owns the variable (i.e. the one which called + `Checkpointable._add_variable` to create it). + + Slot variables are keyed based on a shortest path to the variable being + slotted for, a shortest path to their optimizer, and the slot name. + + Args: + root_checkpointable: A `Checkpointable` object whose variables (including + the variables of dependencies, recursively) should be saved. + + Returns: + A tuple of (named_variables, object_graph_proto): + named_variables: A dictionary mapping names to variable objects. + object_graph_proto: A CheckpointableObjectGraph protocol buffer containing + the serialized object graph and variable references. + + Raises: + ValueError: If there are invalid characters in an optimizer's slot names. + """ + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(root_checkpointable)) + object_names = { + obj: _object_prefix_from_path(path) + for obj, path in path_to_root.items()} + node_ids = {node: node_id for node_id, node + in enumerate(checkpointable_objects)} + slot_variables = _serialize_slot_variables( + checkpointable_objects=checkpointable_objects, + node_ids=node_ids, + object_names=object_names) + return _serialize_checkpointables( + checkpointable_objects=checkpointable_objects, + node_ids=node_ids, + object_names=object_names, + slot_variables=slot_variables) + + +class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): + + def __init__(self, tensor, name): + spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name) + super(_NoRestoreSaveable, self).__init__(tensor, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + return control_flow_ops.no_op() + + +def save(file_prefix, root_checkpointable, checkpoint_number=None, + session=None): + """Save a training checkpoint. + + Args: + file_prefix: A prefix to use for the checkpoint filenames + (/path/to/directory/and_a_prefix). Names are generated based on this + prefix and the global step, if provided. + root_checkpointable: A Checkpointable object to save. The checkpoint + includes variables created by this object and any Checkpointable objects + it depends on. + checkpoint_number: An integer variable or Tensor, used to number + checkpoints. Typically this value is saved along with other variables in + training checkpoints, which will happen automatically if it was created by + `root_checkpointable` or one of its dependencies (via + `Checkpointable._add_variable`). + session: The session to evaluate variables in. Ignored when executing + eagerly. If not provided when graph building, the default session is used. + + Returns: + The full path to the checkpoint. + """ + named_variables, serialized_graph = _serialize_object_graph( + root_checkpointable) + if context.in_graph_mode(): + if session is None: + session = ops.get_default_session() + else: + session = None + assert _OBJECT_GRAPH_PROTO_KEY not in named_variables + # TODO(allenl): Feed rather than embedding a constant. + named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable( + tensor=constant_op.constant( + serialized_graph.SerializeToString(), dtype=dtypes.string), + name=_OBJECT_GRAPH_PROTO_KEY) + with ops.device("/device:CPU:0"): + save_path = saver_lib.Saver(var_list=named_variables).save( + sess=session, + save_path=file_prefix, + write_meta_graph=False, + global_step=checkpoint_number) + return save_path + + +class CheckpointLoadStatus(object): + """Checks the status of checkpoint loading.""" + + def __init__(self, checkpoint): + self._checkpoint = checkpoint + + def assert_consumed(self): + """Asserts that all objects in the checkpoint have been created/matched.""" + for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): + checkpointable = self._checkpoint.object_by_proto_id.get(node_id, None) + if checkpointable is None: + raise AssertionError("Unresolved object in checkpoint: %s" % (node,)) + if checkpointable._update_uid < self._checkpoint.restore_uid: # pylint: disable=protected-access + raise AssertionError( + "Object not assigned a value from checkpoint: %s" % (node,)) + if self._checkpoint.slot_restorations: + # Sanity check; this collection should be clear if everything has been + # restored. + raise AssertionError("Unresolved slot restorations: %s" % ( + self._checkpoint.slot_restorations,)) + return self + + @property + def restore_ops(self): + """Operations to restore objects in the dependency graph.""" + return self._checkpoint.restore_ops + + +def restore(save_path, root_checkpointable, session=None): + """Restore a training checkpoint. + + Restores the values of variables created with `Checkpointable._add_variable` + in `root_checkpointable` and any objects that it tracks (transitive). Either + assigns values immediately if variables to restore have been created already, + or defers restoration until the variables are created. Dependencies added to + `root_checkpointable` after this call will be matched if they have a + corresponding object in the checkpoint. + + When building a graph, restorations are added to the graph but not run. A + session is required to retrieve checkpoint metadata. + + To disallow deferred loading, assert immediately that all checkpointed + variables have been matched to variable objects: + + ```python + restore(path, root).assert_consumed() + ``` + + An exception will be raised unless every object was matched and its variables + already exist. + + When graph building, `assert_consumed()` indicates that all of the restore ops + which will be created for this checkpoint have been created. They are + available in the `restore_ops` property of the status object: + + ```python + session.run(restore(path, root).assert_consumed().restore_ops) + ``` + + If the checkpoint has not been consumed completely, then the list of + `restore_ops` will grow as more objects are added to the dependency graph. + + Args: + save_path: The path to the checkpoint, as returned by `save` or + `tf.train.latest_checkpoint`. If None (as when there is no latest + checkpoint for `tf.train.latest_checkpoint` to return), does nothing. + root_checkpointable: The root of the object graph to restore. Variables to + restore need not have been created yet, but all dependencies on other + `Checkpointable` objects should already be declared. Objects in the + dependency graph are matched to objects in the checkpointed graph, and + matching objects have their variables restored (or the checkpointed values + saved for eventual restoration when the variable is created). + session: The session to retrieve metadata with. Ignored when executing + eagerly. If not provided when graph building, the default session is used. + Returns: + A `CheckpointLoadStatus` object, which can be used to make assertions about + the status of checkpoint restoration and fetch restore ops. + """ + if save_path is None: + return + if context.in_graph_mode(): + if session is None: + session = ops.get_default_session() + else: + session = None + object_graph_string, = io_ops.restore_v2( + prefix=save_path, + tensor_names=[_OBJECT_GRAPH_PROTO_KEY], + shape_and_slices=[""], + dtypes=[dtypes.string], + name="object_graph_proto_read") + if session is not None: + object_graph_string = session.run(object_graph_string) + else: + object_graph_string = object_graph_string.numpy() + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + object_graph_proto.ParseFromString(object_graph_string) + checkpoint = core_checkpointable._Checkpoint( # pylint: disable=protected-access + object_graph_proto=object_graph_proto, + save_path=save_path) + core_checkpointable._CheckpointPosition( # pylint: disable=protected-access + checkpoint=checkpoint, proto_id=0).restore(root_checkpointable) + load_status = CheckpointLoadStatus(checkpoint) + return load_status diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..21ba6adc6a26a11783264be2e217373453224e79 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py @@ -0,0 +1,886 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os +import unittest + +import six + +from tensorflow.contrib.eager.python import checkpointable_utils +from tensorflow.contrib.eager.python import network as network_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.layers import base +from tensorflow.python.layers import core +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import adam +from tensorflow.python.training import checkpointable +from tensorflow.python.training import saver as core_saver +from tensorflow.python.training import training_util + + +class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): + + def __init__(self, *args, **kwargs): + checkpointable.Checkpointable.__init__(self) + core.Dense.__init__(self, *args, **kwargs) + + def add_variable(self, name, shape, **kwargs): + # Calls both Checkpointable._add_variable and Layer.add_variable. Eventually + # Layer.add_variable should inherit from Checkpointable and simply call + # super and then do post-processing. + return checkpointable.Checkpointable._add_variable_with_custom_getter( + self, + name=name, + shape=shape, + getter=functools.partial(core.Dense.add_variable, self), + **kwargs) + + +# pylint: disable=not-callable +class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): + + def __setattr__(self, name, value): + if isinstance(value, base.Layer): + self.track_layer(value, name=name) + # Checkpointable is next in the method resolution order, so this will catch + # Checkpointable objects which aren't Layers. + super(CheckpointableNetwork, self).__setattr__(name, value) + + def track_layer(self, layer, name): + self._track_checkpointable(layer, name=name) + return super(CheckpointableNetwork, self).track_layer(layer) + + +class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): + + # NOTE: Copied from Optimizer with modifications to use add_variable + # for non-slot variables. These contortions are necessary to maintain + # checkpoint compatibility with variable.name based saving. + # TODO(allenl): Make this cleaner. + def _create_non_slot_variable(self, initial_value, name, colocate_with): + """Add an extra variable, not associated with a slot.""" + if context.in_graph_mode(): + graph = colocate_with.graph + else: + graph = None + + key = (name, graph) + v = self._non_slot_dict.get(key, None) + if v is None: + with ops.colocate_with(colocate_with): + def _variable_getter(name, shape, dtype, initializer): + del shape, dtype # not used, but there for compatibility + return variable_scope.variable( + name=name, initial_value=initializer, trainable=False) + + initial_value = ops.convert_to_tensor(initial_value) + v = self._add_variable_with_custom_getter( + name=name, + shape=initial_value.get_shape(), + initializer=initial_value, + getter=_variable_getter) + + self._non_slot_dict[key] = v + + return v + + +class NonLayerCheckpointable(checkpointable.Checkpointable): + + def __init__(self): + super(NonLayerCheckpointable, self).__init__() + self.a_variable = checkpointable_utils.add_variable( + self, name="a_variable", shape=[]) + + +class MyNetwork(CheckpointableNetwork): + """A concrete Network for testing.""" + + def __init__(self): + super(MyNetwork, self).__init__() + self._named_dense = CheckpointableDenseLayer(1, use_bias=True) + self._via_track_layer = self.track_layer( + CheckpointableDenseLayer(1, use_bias=False), name="via_track_layer") + # We can still track Checkpointables which aren't Layers. + self._non_layer = NonLayerCheckpointable() + + def call(self, values): + return self._via_track_layer(self._named_dense(values)) + + +class Checkpoint(checkpointable.Checkpointable): + """A utility class which groups `Checkpointable` objects.""" + + def __init__(self, **kwargs): + super(Checkpoint, self).__init__() + for k, v in sorted(kwargs.items(), key=lambda item: item[0]): + setattr(self, k, v) + self._save_counter = None + + @property + def save_counter(self): + """An integer variable which starts at zero and is incremented on save. + + Used to number checkpoints. + + Returns: + The save counter variable. + """ + if self._save_counter is None: + # Initialized to 0 and incremented before saving. + self._save_counter = checkpointable_utils.add_variable( + self, name="save_counter", initializer=0, dtype=dtypes.int64) + return self._save_counter + + def save(self, file_prefix, session=None): + assign_op = self.save_counter.assign_add(1) + if context.in_graph_mode(): + if session is None: + session = ops.get_default_session() + session.run(assign_op) + return checkpointable_utils.save( + file_prefix=file_prefix, + root_checkpointable=self, + checkpoint_number=self.save_counter, + session=session) + + def restore(self, save_path): + return checkpointable_utils.restore( + save_path=save_path, + root_checkpointable=self) + + +class InterfaceTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testAddVariable(self): + obj = NonLayerCheckpointable() + with self.assertRaisesRegexp(ValueError, "do not specify shape"): + checkpointable_utils.add_variable( + obj, name="shape_specified_twice", shape=[], initializer=1) + constant_initializer = checkpointable_utils.add_variable( + obj, name="constant_initializer", initializer=1) + with variable_scope.variable_scope("some_variable_scope"): + ones_initializer = checkpointable_utils.add_variable( + obj, + name="ones_initializer", + shape=[2], + initializer=init_ops.ones_initializer(dtype=dtypes.float32)) + bare_initializer = checkpointable_utils.add_variable( + obj, + name="bare_initializer", + shape=[2, 2], + dtype=dtypes.float64, + initializer=init_ops.zeros_initializer) + + # Even in graph mode, there are no naming conflicts between objects, only + # naming conflicts within an object. + other_duplicate = resource_variable_ops.ResourceVariable( + name="duplicate", initial_value=1.) + duplicate = checkpointable_utils.add_variable( + obj, name="duplicate", shape=[]) + with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"): + checkpointable_utils.add_variable(obj, name="duplicate", shape=[]) + + if context.in_graph_mode(): + self.evaluate(variables.global_variables_initializer()) + self.assertEqual("constant_initializer:0", constant_initializer.name) + self.assertEqual(1, self.evaluate(constant_initializer)) + self.assertEqual("some_variable_scope/ones_initializer:0", + ones_initializer.name) + self.assertAllEqual([1, 1], self.evaluate(ones_initializer)) + self.assertAllEqual([[0., 0.], + [0., 0.]], self.evaluate(bare_initializer)) + self.assertEqual("a_variable:0", obj.a_variable.name) + self.assertEqual("duplicate:0", other_duplicate.name) + if context.in_graph_mode(): + # The .name attribute may be globally influenced, but the checkpoint name + # won't be (tested below). + self.assertEqual("duplicate_1:0", duplicate.name) + else: + # When executing eagerly, there's no uniquification of variable names. The + # checkpoint name will be the same. + self.assertEqual("duplicate:0", duplicate.name) + named_variables, _ = checkpointable_utils._serialize_object_graph(obj) + expected_checkpoint_names = ( + "a_variable/.ATTRIBUTES/VARIABLE_VALUE", + "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE", + "constant_initializer/.ATTRIBUTES/VARIABLE_VALUE", + "duplicate/.ATTRIBUTES/VARIABLE_VALUE", + "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE", + ) + six.assertCountEqual( + self, expected_checkpoint_names, named_variables.keys()) + + def testInitNotCalled(self): + + class NoInit(checkpointable.Checkpointable): + + def __init__(self): + pass + + # __init__ for Checkpointable will be called implicitly. + checkpointable_utils.add_variable(NoInit(), "var", shape=[]) + + def testShapeDtype(self): + root = checkpointable.Checkpointable() + v1 = checkpointable_utils.add_variable( + root, name="v1", initializer=3., dtype=dtypes.float64) + self.assertEqual(dtypes.float64, v1.dtype) + v2 = checkpointable_utils.add_variable( + root, + name="v2", + shape=[3], + initializer=init_ops.ones_initializer, + dtype=dtypes.float64) + self.assertEqual(dtypes.float64, v2.dtype) + self.assertAllEqual([1., 1., 1.], self.evaluate(v2)) + + +class CheckpointingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNamingWithOptimizer(self): + input_value = constant_op.constant([[3.]]) + network = MyNetwork() + # A nuisance Network using the same optimizer. Its slot variables should not + # go in the checkpoint, since it is never depended on. + other_network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + optimizer_step = training_util.get_or_create_global_step() + root_checkpointable = Checkpoint( + optimizer=optimizer, network=network, optimizer_step=optimizer_step) + if context.in_eager_mode(): + optimizer.minimize( + lambda: network(input_value), + global_step=optimizer_step) + optimizer.minimize( + lambda: other_network(input_value), + global_step=optimizer_step) + else: + train_op = optimizer.minimize( + network(input_value), global_step=optimizer_step) + optimizer.minimize( + other_network(input_value), + global_step=optimizer_step) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(train_op) + named_variables, serialized_graph = ( + checkpointable_utils._serialize_object_graph(root_checkpointable)) + expected_checkpoint_names = ( + # Created in the root node, so no prefix. + "optimizer_step", + # No name provided to track_checkpointable(), so the position is used + # instead (one-based). + "network/via_track_layer/kernel", + # track_checkpointable() with a name provided, so that's used + "network/_named_dense/kernel", + "network/_named_dense/bias", + # non-Layer dependency of the network + "network/_non_layer/a_variable", + # The optimizer creates two non-slot variables + "optimizer/beta1_power", + "optimizer/beta2_power", + # Slot variables + "network/via_track_layer/kernel/.OPTIMIZER_SLOT/optimizer/m", + "network/via_track_layer/kernel/.OPTIMIZER_SLOT/optimizer/v", + "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m", + "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v", + "network/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m", + "network/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v", + ) + suffix = "/.ATTRIBUTES/VARIABLE_VALUE" + expected_checkpoint_names = [ + name + suffix for name in expected_checkpoint_names] + six.assertCountEqual(self, expected_checkpoint_names, + named_variables.keys()) + # Check that we've mapped to the right variable objects (not exhaustive) + self.assertEqual( + "global_step:0", + named_variables["optimizer_step" + suffix].name) + self.assertEqual( + "my_network/checkpointable_dense_layer_1/kernel:0", + named_variables["network/via_track_layer/kernel" + suffix].name) + self.assertEqual( + "my_network/checkpointable_dense_layer/kernel:0", + named_variables["network/_named_dense/kernel" + suffix].name) + self.assertEqual( + "beta1_power:0", + named_variables["optimizer/beta1_power" + suffix].name) + self.assertEqual( + "beta2_power:0", + named_variables["optimizer/beta2_power" + suffix].name) + # Spot check the generated protocol buffers. + self.assertEqual("optimizer", + serialized_graph.nodes[0].children[1].local_name) + optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ + 1].node_id] + self.assertEqual("beta1_power", + optimizer_node.children[0].local_name) + self.assertEqual("beta1_power", + serialized_graph.nodes[optimizer_node.children[0].node_id] + .attributes[0].full_name) + self.assertEqual( + "my_network/checkpointable_dense_layer/kernel", + serialized_graph.nodes[optimizer_node.slot_variables[0] + .original_variable_node_id] + .attributes[0].full_name) + # We strip off the :0 suffix, as variable.name-based saving does. + self.assertEqual( + "my_network/checkpointable_dense_layer/kernel/Adam", + serialized_graph.nodes[optimizer_node.slot_variables[0] + .slot_variable_node_id] + .attributes[0].full_name) + self.assertEqual( + "my_network/checkpointable_dense_layer/kernel/Adam:0", + optimizer.get_slot( + var=named_variables["network/_named_dense/kernel" + suffix], + name="m").name) + self.assertEqual( + "network/_named_dense/kernel" + suffix, + serialized_graph.nodes[ + optimizer_node.slot_variables[0] + .original_variable_node_id].attributes[0].checkpoint_key) + self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) + self.assertEqual( + "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix, + serialized_graph.nodes[ + optimizer_node.slot_variables[0] + .slot_variable_node_id].attributes[0].checkpoint_key) + + @test_util.run_in_graph_and_eager_modes() + def testSaveRestore(self): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root_checkpointable = Checkpoint(optimizer=optimizer, network=network) + input_value = constant_op.constant([[3.]]) + if context.in_eager_mode(): + optimizer.minimize( + lambda: network(input_value)) + else: + train_op = optimizer.minimize(network(input_value)) + # TODO(allenl): Make initialization more pleasant when graph building. + root_checkpointable.save_counter # pylint: disable=pointless-statement + self.evaluate(variables.global_variables_initializer()) + self.evaluate(train_op) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.])) + m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m") + self.evaluate(state_ops.assign(m_bias_slot, [1.5])) + save_path = root_checkpointable.save(file_prefix=prefix) + self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.])) + self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3)) + optimizer_variables = self.evaluate(optimizer.variables()) + self.evaluate(state_ops.assign(m_bias_slot, [-2.])) + # Immediate restoration + status = root_checkpointable.restore(save_path=save_path).assert_consumed() + self.evaluate(status.restore_ops) + self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1])) + self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter)) + self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) + if context.in_graph_mode(): + return # Restore-on-create is only supported when executing eagerly + on_create_network = MyNetwork() + on_create_optimizer = CheckpointableAdam(0.001) + on_create_root = Checkpoint( + optimizer=on_create_optimizer, network=on_create_network) + # Deferred restoration + status = on_create_root.restore(save_path=save_path) + on_create_network(constant_op.constant([[3.]])) # create variables + self.assertAllEqual(1, self.evaluate(on_create_root.save_counter)) + self.assertAllEqual([42.], + self.evaluate( + on_create_network._named_dense.variables[1])) + on_create_m_bias_slot = on_create_optimizer.get_slot( + on_create_network._named_dense.variables[1], "m") + # Optimizer slot variables are created when the original variable is + # restored. + self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) + self.assertAllEqual(optimizer_variables[2:], + self.evaluate(on_create_optimizer.variables())) + on_create_optimizer._create_slots( + [resource_variable_ops.ResourceVariable([1.])]) + status.assert_consumed() + beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() + self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) + self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power)) + + def testDeferredRestorationUsageEager(self): + """An idiomatic eager execution example.""" + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + for training_continuation in range(3): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root = Checkpoint( + optimizer=optimizer, network=network, + optimizer_step=training_util.get_or_create_global_step()) + root.restore(core_saver.latest_checkpoint(checkpoint_directory)) + for _ in range(num_training_steps): + # TODO(allenl): Use a Dataset and serialize/checkpoint it. + input_value = constant_op.constant([[3.]]) + optimizer.minimize( + lambda: network(input_value), # pylint: disable=cell-var-from-loop + global_step=root.optimizer_step) + root.save(file_prefix=checkpoint_prefix) + self.assertEqual((training_continuation + 1) * num_training_steps, + root.optimizer_step.numpy()) + + def testUsageGraph(self): + """Expected usage when graph building.""" + with context.graph_mode(): + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + for training_continuation in range(3): + with ops.Graph().as_default(): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root = Checkpoint( + optimizer=optimizer, network=network, + global_step=training_util.get_or_create_global_step()) + input_value = constant_op.constant([[3.]]) + train_op = optimizer.minimize( + network(input_value), + global_step=root.global_step) + root.save_counter # pylint: disable=pointless-statement + init_op = variables.global_variables_initializer() + checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + with self.test_session(graph=ops.get_default_graph()) as session: + if checkpoint_path is None: + self.assertEqual(0, training_continuation) + session.run(init_op) + # Another alternative would be to run initializers automatically + # if no checkpoint is being loaded. This would make deferred + # loading a bit more useful with graph execution. + else: + status = checkpointable_utils.restore( + save_path=checkpoint_path, + root_checkpointable=root, + session=session).assert_consumed() + session.run(status.restore_ops) + for _ in range(num_training_steps): + session.run(train_op) + root.save(file_prefix=checkpoint_prefix, + session=session) + self.assertEqual((training_continuation + 1) * num_training_steps, + session.run(root.global_step)) + self.assertEqual(training_continuation + 1, + session.run(root.save_counter)) + + def _get_checkpoint_name(self, name): + root = checkpointable.Checkpointable() + checkpointable_utils.add_variable( + root, name=name, shape=[1, 2], dtype=dtypes.float64) + named_variables, _ = checkpointable_utils._serialize_object_graph(root) + checkpoint_name, = named_variables.keys() + with ops.name_scope("root/" + checkpoint_name): + pass # Make sure we can use this as an op name if we prefix it. + return checkpoint_name + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableNameEscaping(self): + suffix = "/.ATTRIBUTES/VARIABLE_VALUE" + self.assertEqual(r"a.Sb.Sc" + suffix, self._get_checkpoint_name(r"a/b/c")) + self.assertEqual(r"b" + suffix, self._get_checkpoint_name(r"b")) + self.assertEqual(r"c.S" + suffix, self._get_checkpoint_name(r"c/")) + self.assertEqual(r"d.S..S" + suffix, self._get_checkpoint_name(r"d/.S")) + self.assertEqual(r"d.S..ATTRIBUTES.Sf" + suffix, + self._get_checkpoint_name(r"d/.ATTRIBUTES/f")) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNumberedPath(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + root.leaf = leaf + checkpointable_utils.add_variable(leaf, name="v", shape=[]) + named_variables, _ = checkpointable_utils._serialize_object_graph(root) + variable_name, = named_variables.keys() + self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", variable_name) + + @test_util.run_in_graph_and_eager_modes() + def testLocalNameValidation(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + # Dots are escaped, which avoids conflicts with reserved names. + root._track_checkpointable(leaf, name=".ATTRIBUTES") + checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[]) + named_variables, _ = checkpointable_utils._serialize_object_graph(root) + name, = named_variables.keys() + self.assertEqual(name, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE") + + @test_util.run_in_graph_and_eager_modes() + def testLateDependencyTracking(self): + + class Dependency(checkpointable.Checkpointable): + + def build(self): + self.var = checkpointable_utils.add_variable( + self, "var", initializer=0.) + + class LateDependencies(checkpointable.Checkpointable): + + def add_dep(self): + self.dep = Dependency() + self.dep.build() + + original = LateDependencies() + original.add_dep() + self.evaluate(state_ops.assign(original.dep.var, 123.)) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpointable_utils.save(checkpoint_prefix, original) + load_into = LateDependencies() + status = checkpointable_utils.restore(save_path, load_into) + with self.assertRaises(AssertionError): + status.assert_consumed() + load_into.add_dep() + status.assert_consumed() + self.evaluate(status.restore_ops) + self.assertEqual(123., self.evaluate(load_into.dep.var)) + + @test_util.run_in_graph_and_eager_modes() + def testDepAfterVar(self): + + class Dependency(checkpointable.Checkpointable): + + def build(self): + self.var = checkpointable_utils.add_variable( + self, "var", initializer=0.) + + class DepAfterVar(checkpointable.Checkpointable): + + def add_dep(self): + dep = Dependency() + dep.build() + self.dep = dep + + dep_after_var = DepAfterVar() + dep_after_var.add_dep() + self.evaluate(state_ops.assign(dep_after_var.dep.var, -14.)) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpointable_utils.save( + checkpoint_prefix, dep_after_var) + + loaded_dep_after_var = DepAfterVar() + status = checkpointable_utils.restore( + save_path, loaded_dep_after_var) + loaded_dep_after_var.add_dep() + status.assert_consumed() + self.evaluate(status.restore_ops) + self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var)) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testDeferredSlotRestoration(self): + checkpoint_directory = self.get_temp_dir() + + root = checkpointable.Checkpointable() + root.var = checkpointable_utils.add_variable( + root, name="var", initializer=0.) + optimizer = CheckpointableAdam(0.1) + if context.in_graph_mode(): + train_op = optimizer.minimize(root.var) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(train_op) + else: + optimizer.minimize(root.var.read_value) + self.evaluate(state_ops.assign(root.var, 12.)) + no_slots_path = checkpointable_utils.save( + os.path.join(checkpoint_directory, "no_slots"), root) + root.optimizer = optimizer + self.evaluate(state_ops.assign(root.var, 13.)) + self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), + 14.)) + slots_path = checkpointable_utils.save( + os.path.join(checkpoint_directory, "with_slots"), root) + new_root = checkpointable.Checkpointable() + # Load the slot-containing checkpoint (deferred), then immediately overwrite + # the non-slot variable (also deferred). + slot_status = checkpointable_utils.restore( + slots_path, new_root) + no_slot_status = checkpointable_utils.restore( + no_slots_path, new_root) + with self.assertRaises(AssertionError): + no_slot_status.assert_consumed() + new_root.var = checkpointable_utils.add_variable( + new_root, name="var", shape=[]) + no_slot_status.assert_consumed() + self.evaluate(no_slot_status.restore_ops) + self.assertEqual(12., self.evaluate(new_root.var)) + new_root.optimizer = CheckpointableAdam(0.1) + with self.assertRaisesRegexp(AssertionError, "beta1_power"): + slot_status.assert_consumed() + self.assertEqual(12., self.evaluate(new_root.var)) + if context.in_eager_mode(): + # Slot variables are only created with restoring initializers when + # executing eagerly. + self.assertEqual(14., self.evaluate( + new_root.optimizer.get_slot(name="m", var=new_root.var))) + else: + self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var), + None) + if context.in_graph_mode(): + train_op = new_root.optimizer.minimize(new_root.var) + # The slot variable now exists; restore() didn't create it, but we should + # now have a restore op for it. + self.evaluate(slot_status.restore_ops) + self.assertEqual(14., self.evaluate( + new_root.optimizer.get_slot(name="m", var=new_root.var))) + self.evaluate(train_op) + else: + new_root.optimizer.minimize(new_root.var.read_value) + slot_status.assert_consumed() + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testOverlappingRestores(self): + checkpoint_directory = self.get_temp_dir() + save_root = checkpointable.Checkpointable() + save_root.dep = checkpointable.Checkpointable() + save_root.dep.var = checkpointable_utils.add_variable( + save_root.dep, name="var", initializer=0.) + self.evaluate(state_ops.assign(save_root.dep.var, 12.)) + first_path = checkpointable_utils.save( + os.path.join(checkpoint_directory, "first"), save_root) + self.evaluate(state_ops.assign(save_root.dep.var, 13.)) + second_path = checkpointable_utils.save( + os.path.join(checkpoint_directory, "second"), save_root) + + first_root = checkpointable.Checkpointable() + second_root = checkpointable.Checkpointable() + first_status = checkpointable_utils.restore( + first_path, first_root) + second_status = checkpointable_utils.restore( + second_path, second_root) + load_dep = checkpointable.Checkpointable() + load_dep.var = checkpointable_utils.add_variable( + load_dep, name="var", shape=[]) + first_root.dep = load_dep + first_status.assert_consumed() + self.evaluate(first_status.restore_ops) + self.assertEqual([], second_status.restore_ops) + self.assertEqual(12., self.evaluate(load_dep.var)) + second_root.dep = load_dep + second_status.assert_consumed() + self.evaluate(second_status.restore_ops) + self.assertEqual(13., self.evaluate(load_dep.var)) + + # Try again with the order of the restore() reversed. The last restore + # determines the final value. + first_root = checkpointable.Checkpointable() + second_root = checkpointable.Checkpointable() + second_status = checkpointable_utils.restore( + second_path, second_root) + first_status = checkpointable_utils.restore( + first_path, first_root) + load_dep = checkpointable.Checkpointable() + load_dep.var = checkpointable_utils.add_variable( + load_dep, name="var", shape=[]) + first_root.dep = load_dep + first_status.assert_consumed() + self.assertEqual([], second_status.restore_ops) + self.evaluate(first_status.restore_ops) + self.assertEqual(12., self.evaluate(load_dep.var)) + second_root.dep = load_dep + second_status.assert_consumed() + self.evaluate(second_status.restore_ops) + self.assertEqual(12., self.evaluate(load_dep.var)) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testAmbiguousLoad(self): + # Not OK to split one checkpoint object into two + checkpoint_directory = self.get_temp_dir() + save_root = checkpointable.Checkpointable() + save_root.dep_one = checkpointable.Checkpointable() + save_root.dep_two = checkpointable.Checkpointable() + dep_three = checkpointable.Checkpointable() + save_root.dep_one.dep_three = dep_three + save_root.dep_two.dep_three = dep_three + checkpointable_utils.add_variable(dep_three, name="var", initializer=0.) + self.evaluate(variables.global_variables_initializer()) + save_path = checkpointable_utils.save( + os.path.join(checkpoint_directory, "ckpt"), save_root) + load_root = checkpointable.Checkpointable() + checkpointable_utils.restore(save_path, load_root) + load_root.dep_one = checkpointable.Checkpointable() + load_root.dep_two = checkpointable.Checkpointable() + load_root.dep_one.dep_three = checkpointable.Checkpointable() + with self.assertRaisesRegexp(AssertionError, + "resolved to different objects"): + load_root.dep_two.dep_three = checkpointable.Checkpointable() + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testObjectsCombined(self): + # Currently fine to load two checkpoint objects into one Python object + checkpoint_directory = self.get_temp_dir() + save_root = checkpointable.Checkpointable() + save_root.dep_one = checkpointable.Checkpointable() + save_root.dep_two = checkpointable.Checkpointable() + checkpointable_utils.add_variable( + save_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64) + checkpointable_utils.add_variable( + save_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64) + self.evaluate(variables.global_variables_initializer()) + save_path = checkpointable_utils.save( + os.path.join(checkpoint_directory, "ckpt"), save_root) + load_root = checkpointable.Checkpointable() + load_root.dep_one = checkpointable.Checkpointable() + load_root.dep_two = load_root.dep_one + v1 = checkpointable_utils.add_variable( + load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64) + v2 = checkpointable_utils.add_variable( + load_root.dep_one, name="var2", shape=[], dtype=dtypes.float64) + status = checkpointable_utils.restore( + save_path, load_root).assert_consumed() + self.evaluate(status.restore_ops) + self.assertEqual(32., self.evaluate(v1)) + self.assertEqual(64., self.evaluate(v2)) + + @test_util.run_in_graph_and_eager_modes() + def testDependencyLoop(self): + # Note: this test creates garbage during eager execution because it + # purposefully creates a reference cycle. + first = checkpointable.Checkpointable() + second = checkpointable.Checkpointable() + first.second = second + second.first = first + first.v = checkpointable_utils.add_variable( + first, "v1", initializer=[3., 1., 4.]) + second.v = checkpointable_utils.add_variable( + second, "v2", initializer=[1., 1., 2., 3.]) + self.evaluate(variables.global_variables_initializer()) + checkpoint_directory = self.get_temp_dir() + save_path = checkpointable_utils.save( + os.path.join(checkpoint_directory, "ckpt"), first) + + # Test deferred loading + first_load = checkpointable.Checkpointable() + status = checkpointable_utils.restore(save_path, first_load) + second_load = checkpointable.Checkpointable() + first_load.second = second_load + second_load.first = first_load + with self.assertRaises(AssertionError): + status.assert_consumed() + first_load.v = checkpointable_utils.add_variable( + first_load, "v1", shape=[3]) + second_load.v = checkpointable_utils.add_variable( + second_load, "v2", shape=[4]) + status.assert_consumed() + self.evaluate(status.restore_ops) + self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) + self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) + + # Test loading when variables have already been created + self.evaluate(first_load.v.assign([2., 7., 1.])) + self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v)) + self.evaluate(second_load.v.assign([2., 7., 1., 8.])) + self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v)) + status = checkpointable_utils.restore( + save_path, first_load).assert_consumed() + self.evaluate(status.restore_ops) + self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) + self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) + + @test_util.run_in_graph_and_eager_modes() + def testRestoreOnAssign(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session(save_graph): + first = checkpointable.Checkpointable() + first.var1 = variable_scope.get_variable( + name="outside_var", initializer=0.) + first.var2 = variable_scope.get_variable( + name="blah", initializer=0.) + self.evaluate(first.var1.assign(4.)) + self.evaluate(first.var2.assign(8.)) + save_path = checkpointable_utils.save( + checkpoint_prefix, root_checkpointable=first) + restore_graph = ops.Graph() + with restore_graph.as_default(), self.test_session(restore_graph): + second = checkpointable.Checkpointable() + second.var2 = variable_scope.get_variable( + name="blah", initializer=0.) + status = checkpointable_utils.restore( + save_path, root_checkpointable=second) + recreated_var1 = variable_scope.get_variable( + name="outside_var", initializer=0.) + self.evaluate(status.restore_ops) + self.assertEqual(8., self.evaluate(second.var2)) + self.evaluate(recreated_var1.assign(-2.)) + self.assertEqual(-2., self.evaluate(recreated_var1)) + second.var1 = recreated_var1 + self.evaluate(status.restore_ops) + self.assertEqual(4., self.evaluate(recreated_var1)) + + # TODO(allenl): Saver class that doesn't pollute the graph with constants. + @unittest.skip("todo") + def testManySavesGraph(self): + """Saves after the first should not modify the graph.""" + with context.graph_mode(): + graph = ops.Graph() + with graph.as_default(), self.test_session(graph): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + obj = checkpointable.Checkpointable() + obj.var = variable_scope.get_variable(name="v", initializer=0.) + obj.opt = CheckpointableAdam(0.1) + obj.opt.minimize(obj.var.read_value()) + self.evaluate(variables.global_variables_initializer()) + checkpointable_utils.save( + checkpoint_prefix, root_checkpointable=obj) + before_ops = graph.get_operations() + checkpointable_utils.save( + checkpoint_prefix, root_checkpointable=obj) + self.assertEqual(before_ops, graph.get_operations()) + + @unittest.skip("todo") + def testManyRestoresGraph(self): + """Restores after the first should not modify the graph.""" + with context.graph_mode(): + graph = ops.Graph() + with graph.as_default(), self.test_session(graph): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + obj = checkpointable.Checkpointable() + obj.var = variable_scope.get_variable(name="v", initializer=0.) + obj.opt = CheckpointableAdam(0.1) + obj.opt.minimize(obj.var.read_value()) + self.evaluate(variables.global_variables_initializer()) + save_path = checkpointable_utils.save( + checkpoint_prefix, root_checkpointable=obj) + checkpointable_utils.restore( + save_path, root_checkpointable=obj) + before_ops = graph.get_operations() + checkpointable_utils.restore( + save_path, root_checkpointable=obj) + self.assertEqual(before_ops, graph.get_operations()) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 544a3eafc08f892f6e3315f0656c97b9877cfa0e..d177bfeab2d1fdc05d7ced54df8723fae2c77fdb 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -112,7 +112,7 @@ class Iterator(object): remote_fn.add_to_graph(None) target = constant_op.constant("/device:CPU:0") with ops.device(self._device): - self._buffer_resource_handle = prefetching_ops.function_buffering_resource( + self._buffer_resource_handle = prefetching_ops.function_buffering_resource( # pylint: disable=line-too-long string_arg=iter_string_handle, f=remote_fn, target_device=target, @@ -120,8 +120,9 @@ class Iterator(object): thread_pool_size=1, container="", shared_name=_generate_shared_name("function_buffer_resource")) - self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._buffer_resource_handle, handle_device=self._device) + self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long + handle=self._buffer_resource_handle, + handle_device=self._device) def __iter__(self): return self diff --git a/tensorflow/contrib/eager/python/examples/gan/README.md b/tensorflow/contrib/eager/python/examples/gan/README.md index e8c9db1a1e2eb5881b08a4d3866c82b24d64be12..208a64b05d47eea10b49a1bf967a5453677bfd21 100644 --- a/tensorflow/contrib/eager/python/examples/gan/README.md +++ b/tensorflow/contrib/eager/python/examples/gan/README.md @@ -11,7 +11,7 @@ Other eager execution examples can be found under the parent directory. - `mnist.py`: Model definitions and training routines. - `mnist_test.py`: Benchmarks for training and using the models using eager execution. -- `mnist_graph_test.py`: Benchmarks for trainig and using the models using +- `mnist_graph_test.py`: Benchmarks for training and using the models using graph execution. The same model definitions and loss functions are used in all benchmarks. diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py index 2a7be95811f6fff06e2c489890703561ed879c42..58b1e89d15895cf38331e6f7bd5a311a2f5f6467 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -35,11 +35,11 @@ from tensorflow.examples.tutorials.mnist import input_data FLAGS = None -class MNISTModel(tfe.Network): +class MNISTModel(tf.keras.Model): """MNIST Network. Network structure is equivalent to: - https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py + https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/mnist/mnist_deep.py and https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py @@ -61,18 +61,17 @@ class MNISTModel(tfe.Network): else: assert data_format == 'channels_last' self._input_shape = [-1, 28, 28, 1] - self.conv1 = self.track_layer( - tf.layers.Conv2D(32, 5, data_format=data_format, activation=tf.nn.relu)) - self.conv2 = self.track_layer( - tf.layers.Conv2D(64, 5, data_format=data_format, activation=tf.nn.relu)) - self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.nn.relu)) - self.fc2 = self.track_layer(tf.layers.Dense(10)) - self.dropout = self.track_layer(tf.layers.Dropout(0.5)) - self.max_pool2d = self.track_layer( - tf.layers.MaxPooling2D( - (2, 2), (2, 2), padding='SAME', data_format=data_format)) - - def call(self, inputs, training): + self.conv1 = tf.layers.Conv2D( + 32, 5, data_format=data_format, activation=tf.nn.relu) + self.conv2 = tf.layers.Conv2D( + 64, 5, data_format=data_format, activation=tf.nn.relu) + self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu) + self.fc2 = tf.layers.Dense(10) + self.dropout = tf.layers.Dropout(0.5) + self.max_pool2d = tf.layers.MaxPooling2D( + (2, 2), (2, 2), padding='SAME', data_format=data_format) + + def call(self, inputs, training=False): """Computes labels from inputs. Users should invoke __call__ to run the network, which delegates to this @@ -95,8 +94,7 @@ class MNISTModel(tfe.Network): x = self.max_pool2d(x) x = tf.layers.flatten(x) x = self.fc1(x) - if training: - x = self.dropout(x) + x = self.dropout(x, training=training) x = self.fc2(x) return x diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 76e06269b6bbeb3386a6346244d294b1c5167b6e..0ff8746884c288f824f5f22ab4c550370d0e0302 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -22,6 +22,7 @@ import gc import tempfile import time +from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf import tensorflow.contrib.eager as tfe diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 40919f2d4cf511eb35fac954719286366aef6c7c..aa87b94e7b0876e65405f6bcb2d6aabde36582bf 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -65,7 +65,6 @@ import six import tensorflow as tf from tensorflow.contrib.eager.python import tfe -from tensorflow.python.eager import context try: import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index d34e9ea68b76373d4b5a9ee9e3852c60a7c81525..5c5c59c87744f4ffa6db90e5d8d3aa3bc8132756 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -339,8 +339,7 @@ if __name__ == "__main__": "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz") parser.add_argument( "--logdir", type=str, default="", help="Directory for checkpoint.") - parser.add_argument( - "--epoch", type=int, default=20, help="Number of epochs.") + parser.add_argument("--epoch", type=int, default=20, help="Number of epochs.") parser.add_argument("--batch-size", type=int, default=20, help="Batch size.") parser.add_argument( "--seq-len", type=int, default=35, help="Sequence length.") diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py index fcaae0a4f8c0bad916d74bd9b80fcfa55a63d84a..3bc3bb49bcbbc26f7a3134a8bfc385ec080dde1e 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/data.py +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -227,6 +227,29 @@ def calculate_bins(length2count, min_bin_size): return bounds +def encode_sentence(sentence, word2index): + """Encode a single sentence as word indices and shift-reduce code. + + Args: + sentence: The sentence with added binary parse information, represented as + a string, with all the word items and parentheses separated by spaces. + E.g., '( ( The dog ) ( ( is ( playing toys ) ) . ) )'. + word2index: A `dict` mapping words to their word indices. + + Returns: + 1. Word indices as a numpy array, with shape `(sequence_len, 1)`. + 2. Shift-reduce sequence as a numpy array, with shape + `(sequence_len * 2 - 3, 1)`. + """ + items = [w for w in sentence.split(" ") if w] + words = get_non_parenthesis_words(items) + shift_reduce = get_shift_reduce(items) + word_indices = pad_and_reverse_word_ids( + [[word2index.get(word, UNK_CODE) for word in words]]).T + return (word_indices, + np.expand_dims(np.array(shift_reduce, dtype=np.int64), -1)) + + class SnliData(object): """A split of SNLI data.""" diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py index e4f0b37c5099e45b7e3b258b258c0a203c36b3b7..54fef2c3fe4111cd2d93ac109a5b8fffad0c2fad 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/data_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py @@ -22,6 +22,7 @@ import os import shutil import tempfile +import numpy as np import tensorflow as tf from tensorflow.contrib.eager.python.examples.spinn import data @@ -173,14 +174,9 @@ class DataTest(tf.test.TestCase): ValueError, "Cannot find GloVe embedding file at"): data.load_word_vectors(self._temp_data_dir, vocab) - def testSnliData(self): - """Unit test for SnliData objects.""" - snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") - fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") - os.makedirs(snli_1_0_dir) - + def _createFakeSnliData(self, fake_snli_file): # Four sentences in total. - with open(fake_train_file, "wt") as f: + with open(fake_snli_file, "wt") as f: f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") @@ -205,10 +201,7 @@ class DataTest(tf.test.TestCase): "4705552913.jpg#2\t4705552913.jpg#2r1n\t" "neutral\tentailment\tneutral\tneutral\tneutral\n") - glove_dir = os.path.join(self._temp_data_dir, "glove") - os.makedirs(glove_dir) - glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") - + def _createFakeGloveData(self, glove_file): words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] with open(glove_file, "wt") as f: for i, word in enumerate(words): @@ -220,6 +213,40 @@ class DataTest(tf.test.TestCase): else: f.write("\n") + def testEncodeSingleSentence(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + self._createFakeSnliData(fake_train_file) + vocab = data.load_vocabulary(self._temp_data_dir) + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + self._createFakeGloveData(glove_file) + word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) + + sentence_variants = [ + "( Foo ( ( bar baz ) . ) )", + " ( Foo ( ( bar baz ) . ) ) ", + "( Foo ( ( bar baz ) . ) )"] + for sentence in sentence_variants: + word_indices, shift_reduce = data.encode_sentence(sentence, word2index) + self.assertEqual(np.int64, word_indices.dtype) + self.assertEqual((5, 1), word_indices.shape) + self.assertAllClose( + np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T, shift_reduce) + + def testSnliData(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + self._createFakeSnliData(fake_train_file) + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + self._createFakeGloveData(glove_file) + vocab = data.load_vocabulary(self._temp_data_dir) word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) @@ -230,7 +257,7 @@ class DataTest(tf.test.TestCase): self.assertEqual(1, train_data.num_batches(4)) generator = train_data.get_generator(2)() - for i in range(2): + for _ in range(2): label, prem, prem_trans, hypo, hypo_trans = next(generator) self.assertEqual(2, len(label)) self.assertEqual((4, 2), prem.shape) diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 84e25cf81a2223800c47994b26d000caddee6b01..081b0af14fcc983a3f85d2a50e2bb04d2f2493b3 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -26,6 +26,7 @@ import tempfile import time import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf # pylint: disable=g-bad-import-order @@ -35,6 +36,7 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util +from tensorflow.python.training import checkpoint_utils # pylint: enable=g-bad-import-order @@ -65,13 +67,30 @@ def _generate_synthetic_snli_data_batch(sequence_length, return labels, prem, prem_trans, hypo, hypo_trans -def _test_spinn_config(d_embed, d_out, logdir=None): +def _test_spinn_config(d_embed, d_out, logdir=None, inference_sentences=None): + """Generate a config tuple for testing. + + Args: + d_embed: Embedding dimensions. + d_out: Model output dimensions. + logdir: Optional logdir. + inference_sentences: A 2-tuple of strings representing the sentences (with + binary parsing result), e.g., + ("( ( The dog ) ( ( is running ) . ) )", "( ( The dog ) ( moves . ) )"). + + Returns: + A config tuple. + """ config_tuple = collections.namedtuple( "Config", ["d_hidden", "d_proj", "d_tracker", "predict", "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp", "d_out", "projection", "lr", "batch_size", "epochs", "force_cpu", "logdir", "log_every", "dev_every", "save_every", - "lr_decay_every", "lr_decay_by"]) + "lr_decay_every", "lr_decay_by", "inference_premise", + "inference_hypothesis"]) + + inference_premise = inference_sentences[0] if inference_sentences else None + inference_hypothesis = inference_sentences[1] if inference_sentences else None return config_tuple( d_hidden=d_embed, d_proj=d_embed * 2, @@ -85,14 +104,16 @@ def _test_spinn_config(d_embed, d_out, logdir=None): projection=True, lr=2e-2, batch_size=2, - epochs=10, + epochs=20, force_cpu=False, logdir=logdir, log_every=1, dev_every=2, save_every=2, lr_decay_every=1, - lr_decay_by=0.75) + lr_decay_by=0.75, + inference_premise=inference_premise, + inference_hypothesis=inference_hypothesis) class SpinnTest(test_util.TensorFlowTestCase): @@ -287,11 +308,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # Training on the batch should have led to a change in the loss value. self.assertNotEqual(loss1.numpy(), loss2.numpy()) - def testTrainSpinn(self): - """Test with fake toy SNLI data and GloVe vectors.""" - - # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. - snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + def _create_test_data(self, snli_1_0_dir): fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") os.makedirs(snli_1_0_dir) @@ -336,13 +353,52 @@ class SpinnTest(test_util.TensorFlowTestCase): else: f.write("\n") + return fake_train_file + + def testInferSpinnWorks(self): + """Test inference with the spinn model.""" + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + self._create_test_data(snli_1_0_dir) + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir"), + inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )")) + logits = spinn.train_or_infer_spinn( + embed, word2index, None, None, None, config) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((3,), logits.shape) + + def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + self._create_test_data(snli_1_0_dir) + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir"), + inference_sentences=("( foo ( bar . ) )", None)) + with self.assertRaises(ValueError): + spinn.train_or_infer_spinn(embed, word2index, None, None, None, config) + + def testTrainSpinn(self): + """Test with fake toy SNLI data and GloVe vectors.""" + + # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = self._create_test_data(snli_1_0_dir) + vocab = data.load_vocabulary(self._temp_data_dir) word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) train_data = data.SnliData(fake_train_file, word2index) dev_data = data.SnliData(fake_train_file, word2index) test_data = data.SnliData(fake_train_file, word2index) - print(embed) # 2. Create a fake config. config = _test_spinn_config( @@ -350,7 +406,8 @@ class SpinnTest(test_util.TensorFlowTestCase): logdir=os.path.join(self._temp_data_dir, "logdir")) # 3. Test training of a SPINN model. - spinn.train_spinn(embed, train_data, dev_data, test_data, config) + trainer = spinn.train_or_infer_spinn( + embed, word2index, train_data, dev_data, test_data, config) # 4. Load train loss values from the summary files and verify that they # decrease with training. @@ -362,6 +419,15 @@ class SpinnTest(test_util.TensorFlowTestCase): self.assertEqual(config.epochs, len(train_losses)) self.assertLess(train_losses[-1], train_losses[0]) + # 5. Verify that checkpoints exist and contains all the expected variables. + self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) + ckpt_variable_names = [ + item[0] for item in checkpoint_utils.list_variables(config.logdir)] + self.assertIn("global_step", ckpt_variable_names) + for v in trainer.variables: + variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name + self.assertIn(variable_name, ckpt_variable_names) + class EagerSpinnSNLIClassifierBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index 7eea93ce1f5aefe82d73b49f57b636692818ba16..d97ff6b74cf033617154f7cbbd00cb6492a1d2f4 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -19,29 +19,34 @@ to models defined without using eager execution. ## Installation -Eager execution is **not** included in the latest release (version 1.4) of -TensorFlow. To use it, you will need to [build TensorFlow from -source](https://www.tensorflow.org/install/install_sources) or install the -nightly builds. +Eager execution is included in TensorFlow versions 1.5 and above. +Installation instructions at https://www.tensorflow.org/install/ -For example, the nightly builds can be installed using `pip`: +The contents of this guide are compatible with TensorFlow 1.5. +However, if you run into bugs that are fixed in source but not the +release, you may want to either [build from +source](https://www.tensorflow.org/install/install_sources) +or try a nightly build. The nightly builds are available as: -- `pip install tf-nightly` (for CPU-only TensorFlow) -- `pip install tf-nightly-gpu` (for GPU-enabled TensorFlow) +- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and -Or using `docker`, with [Jupyter Notebook](http://jupyter.org/) support: +- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images. + +For example, to run the latest nightly docker image: ```sh -# For CPU-only TensorFlow +# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker +docker pull tensorflow/tensorflow:nightly-gpu +docker run --runtime=nvidia -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu + +# If you do not have a GPU, use the CPU-only image docker pull tensorflow/tensorflow:nightly docker run -it -p 8888:8888 tensorflow/tensorflow:nightly - -# For GPU-enabled TensorFlow: -# (Requires https://github.com/NVIDIA/nvidia-docker) -nvidia-docker pull tensorflow/tensorflow:nightly-gpu -nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu ``` +And then visit http://localhost:8888 in your browser for a Jupyter notebook +environment. + ## Getting Started With TensorFlow installed, eager execution is enabled via a single call: diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index bf029ca5f9dddb152274da6a1cc96bea7981d8fd..ea8dbf2b46ea4bd0e33645ae3c590c4dd13f7a52 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -291,6 +291,9 @@ class Mean(Metric): Args: values: Tensor with the per-example value. weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. """ if weights is None: self.denom.assign_add( @@ -302,6 +305,9 @@ class Mean(Metric): self.denom.assign_add(math_ops.reduce_sum(weights)) values = math_ops.cast(values, self.dtype) * weights self.numer.assign_add(math_ops.reduce_sum(values)) + if weights is None: + return values + return values, weights def result(self): t = self.numer / self.denom @@ -329,7 +335,13 @@ class Accuracy(Mean): per element of the Tensor. predictions: Tensor with the predicted label for each example. weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. """ matches = math_ops.equal(labels, predictions) matches = math_ops.cast(matches, dtypes.float64) super(Accuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 9cf34fd9b2dcf1b123cacc6863af817419eda007..a9ecaa3f8bced3043ea0eb0ac3aa8bfa65e9e1ff 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -180,6 +180,19 @@ class MetricsTest(test.TestCase): m2 = metrics.Mean() m2(2) + def testMetricsChain(self): + with context.graph_mode(), self.test_session(): + m1 = metrics.Mean() + m2 = metrics.Mean(name="m2") + update_m2 = m2(3.0) + update_m2_2 = m2(m1(1.0)) + m1.init_variables().run() + m2.init_variables().run() + update_m2.eval() + update_m2_2.eval() + self.assertAllEqual(m2.result().eval(), 2.0) + self.assertAllEqual(m1.result().eval(), 1.0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 81c77e41acf420fa84857ccb366aa2fbd6055f42..3329fc6c513265deff41a368f5688dd605209c14 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -539,7 +539,7 @@ class NetworkTest(test.TestCase): # No issue here since the name is unique within its scope. name_conflict3 = MyNetwork(name="name_conflict") net2 = MyNetwork() # name=outside_scope/my_network_2 to avoid the - # variable_scope my_network_1 below. + # variable_scope my_network_1 below. vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below with variable_scope.variable_scope("intervening_scope"): with variable_scope.variable_scope(captured_scope): diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index abc7e3690c76c4446bce6b945325f1ca15ef1c8b..1a7f7b85e688e80e3cf482f2754462888187d311 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -73,16 +73,6 @@ class SaverTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'v1'): saver.save(ckpt_prefix) - def testDifferentGraphError(self): - with ops.device(self._dev()): - with ops.Graph().as_default(): - v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') - with ops.Graph().as_default(): - saver = _saver.Saver([v1]) - ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') - with self.assertRaisesRegexp(ValueError, 'Graph'): - saver.save(ckpt_prefix) - def testSameObjectOK(self): with ops.device(self._dev()): v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 712d1cb94d2f565bf6216f6c07a45d3d855efe9c..d32bebf90c1e768d1efec26b3b78bf1a522a8f00 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -59,7 +59,6 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@in_eager_mode @@in_graph_mode -@@IsolateTest @@run_test_in_graph_and_eager_modes @@DEVICE_PLACEMENT_EXPLICIT @@ -101,7 +100,6 @@ from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run -from tensorflow.python.framework.test_util import IsolateTest from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 0dedb2fd7c0905801cd87c239ff2ee09eecb6080..b6659c2a1797feab261d756e78b45231dbea5a02 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -102,10 +102,6 @@ class TFETest(test_util.TensorFlowTestCase): # Expect at least one device. self.assertTrue(tfe.list_devices()) - def testNumGPUs(self): - devices = tfe.list_devices() - self.assertEqual(len(devices) - 1, tfe.num_gpus()) - def testAddCheckNumericsOpsRaisesError(self): with self.assertRaisesRegexp( RuntimeError, diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index cdbe05e4d2d7117c5acb12d679f359a9db17c9cc..6cdbed5b896577f5622b1bd0123c289c798bc0a5 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -163,7 +163,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", + "//tensorflow/python:check_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:lookup_ops", @@ -177,7 +177,6 @@ py_library( "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/estimator:util", "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:signature_constants", ], diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py index 5f4a3cc902c9cc07c0688ad41dab7391a641c133..ad1a8ef152b07ecbab33d9eb3184a2ae89def27d 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import linear from tensorflow.python.feature_column import feature_column as fc diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index fd0994490aac7b9a0ed628e0c3c624d0fefb1b81..238cf287b768eee28b20202084eb244c085c8b75 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.estimator import model_fn -from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import prediction_keys @@ -29,7 +28,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib @@ -45,6 +43,7 @@ def multi_class_head(n_classes, weight_column=None, label_vocabulary=None, loss_reduction=losses.Reduction.SUM, + loss_fn=None, name=None): """Creates a `_Head` for multi class classification. @@ -65,6 +64,12 @@ def multi_class_head(n_classes, labels have shape `[batch_size, 1]`, the loss is the weighted sum over `batch_size`. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with + shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to + the input labels before passing them to `loss_fn`. + Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use `binary_classification_head`). @@ -79,6 +84,7 @@ def multi_class_head(n_classes, `label_vocabulary` is not provided but labels are strings. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -94,12 +100,17 @@ def multi_class_head(n_classes, weight_column=weight_column, label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) def binary_classification_head( - weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, name=None): + weight_column=None, + thresholds=None, + label_vocabulary=None, + loss_reduction=losses.Reduction.SUM, + loss_fn=None, + name=None): """Creates a `_Head` for single label binary classification. This head uses `sigmoid_cross_entropy_with_logits` loss. @@ -119,6 +130,12 @@ def binary_classification_head( labels have shape `[batch_size, 1]`, the loss is the weighted sum over `batch_size`. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with + shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to + the input labels before passing them to `loss_fn`. + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -136,6 +153,7 @@ def binary_classification_head( is not provided but labels are strings. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -151,12 +169,14 @@ def binary_classification_head( thresholds=thresholds, label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) def regression_head(weight_column=None, label_dimension=1, loss_reduction=losses.Reduction.SUM, + loss_fn=None, name=None): """Creates a `_Head` for regression using the `mean_squared_error` loss. @@ -175,6 +195,10 @@ def regression_head(weight_column=None, `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN, label_dimension]`. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[D0, D1, ... DN, label_dimension]`. + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -185,6 +209,7 @@ def regression_head(weight_column=None, `[batch_size, label_dimension]`). loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -198,6 +223,7 @@ def regression_head(weight_column=None, weight_column=weight_column, label_dimension=label_dimension, loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) @@ -287,7 +313,7 @@ def multi_label_head(n_classes, 'Length of label_vocabulary must be n_classes ({}). ' 'Given: {}'.format(n_classes, len(label_vocabulary))) if loss_fn: - _validate_loss_fn_args(loss_fn) + head_lib._validate_loss_fn_args(loss_fn) # pylint:disable=protected-access if (loss_reduction not in losses.Reduction.all() or loss_reduction == losses.Reduction.NONE): raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) @@ -371,9 +397,9 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access labels=processed_labels, logits=logits, expected_labels_dimension=self.logits_dimension) if self._loss_fn: - unweighted_loss = _call_loss_fn( + unweighted_loss = head_lib._call_loss_fn( # pylint:disable=protected-access loss_fn=self._loss_fn, labels=processed_labels, logits=logits, - features=features) + features=features, expected_loss_dim=1) else: unweighted_loss = losses.sigmoid_cross_entropy( multi_class_labels=processed_labels, logits=logits, @@ -555,52 +581,3 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access threshold=threshold, name=recall_key)) return metric_ops - - -def _validate_loss_fn_args(loss_fn): - """Validates loss_fn arguments. - - Required arguments: labels, logits. - Optional arguments: features. - - Args: - loss_fn: The loss function. - Raises: - ValueError: If the signature is unexpected. - """ - loss_fn_args = util.fn_args(loss_fn) - for required_arg in ['labels', 'logits']: - if required_arg not in loss_fn_args: - raise ValueError( - 'loss_fn must contain argument: {}. ' - 'Given arguments: {}'.format(required_arg, loss_fn_args)) - invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features'])) - if invalid_args: - raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args)) - - -def _call_loss_fn(loss_fn, labels, logits, features): - """Calls loss_fn and checks the returned shape. - - Args: - loss_fn: The loss function. - labels: Processed labels Tensor. - logits: Logits Tensor of shape [batch_size, logits_dimension]. - features: Features dict. - Returns: - Loss Tensor with shape [batch_size, 1]. - """ - loss_fn_args = util.fn_args(loss_fn) - kwargs = {} - if 'features' in loss_fn_args: - kwargs['features'] = features - unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs) - batch_size = array_ops.shape(logits)[0] - loss_shape = array_ops.shape(unweighted_loss) - check_shape_op = control_flow_ops.Assert( - math_ops.reduce_all(math_ops.equal(loss_shape, [batch_size, 1])), - data=[ - 'loss_fn must return Tensor of shape [batch_size, 1]. Given: ', - loss_shape]) - with ops.control_dependencies([check_shape_op]): - return array_ops.identity(unweighted_loss) diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 1adbd6f0fe32df4a513a2683d03fcefca07e2a42..1411635228457218578c0297d4d901e9c86ca91a 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -381,8 +381,8 @@ class MultiLabelHead(test.TestCase): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r'loss_fn must return Tensor of shape \[batch_size, 1\]\. ' - r'Given: \] \[2\]'): + r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] ' + r'\[logits_shape: \] \[2 2\] \[loss_shape: \] \[2\]'): actual_training_loss.eval() def test_eval_labels_none(self): @@ -446,7 +446,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, + keys.AUC_PR: 0.5972, } self._test_eval( head=head, @@ -478,7 +478,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, + keys.AUC_PR: 0.5972, } self._test_eval( head=head, @@ -509,7 +509,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, + keys.AUC_PR: 0.5972, } self._test_eval( head=head, @@ -543,7 +543,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, + keys.AUC_PR: 0.5972, } self._test_eval( head=head, @@ -573,7 +573,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, + keys.AUC_PR: 0.5972, keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4., keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3., keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3., @@ -621,7 +621,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.2000, - keys.AUC_PR: 0.7833, + keys.AUC_PR: 0.5833, } # Assert spec contains expected tensors. @@ -1095,7 +1095,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.4977, - keys.AUC_PR: 0.6645, + keys.AUC_PR: 0.4037, } self._test_eval( head=head, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 65ea89ba1b9236d0bf4d2de430fab168ef50bf97..e47a6788f3b5440c4906b9f0430c802cf73237e3 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -306,8 +306,8 @@ class MultiHeadTest(test.TestCase): # this assert tests that the algorithm remains consistent. keys.AUC + '/head1': 0.1667, keys.AUC + '/head2': 0.3333, - keys.AUC_PR + '/head1': 0.6667, - keys.AUC_PR + '/head2': 0.5000, + keys.AUC_PR + '/head1': 0.49999964, + keys.AUC_PR + '/head2': 0.33333313, } # Assert spec contains expected tensors. diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index caa9dd83233b6b850385335fde96431271d85c3a..e0fae2c99292385c6dd32cc6002cee2076a2bb20 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -110,7 +110,8 @@ def replicate_model_fn(model_fn, Certain algorithms were chosen for aggregating results of computations on multiple towers: - Losses from all towers are reduced according to `loss_reduction`. - - Gradients are reduced using sum for each trainable variable. + - Gradients from all towers are reduced according to `loss_reduction` + for each trainable variable. - `eval_metrics_ops` are reduced per metric using `reduce_mean`. - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are reduced using concatenation. @@ -195,7 +196,7 @@ def _replicate_model_fn_with_mode( if not devices: devices = _get_local_devices('GPU') or _get_local_devices('CPU') - is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] + is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0].upper() consolidation_device = devices[0] if is_a_single_gpu_case else '/CPU:0' ps_devices = [consolidation_device] @@ -457,6 +458,13 @@ def _get_local_devices(device_type): def _split_batch(features, labels, number_of_shards, device): """Split input features and labes into batches.""" + def ensure_divisible_by_shards(sequence): + batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0] + if batch_size % number_of_shards != 0: + raise ValueError( + 'Batch size {} needs to be divisible by the number of GPUs, which ' + 'is {}.'.format(batch_size, number_of_shards)) + def split_dictionary(dictionary): """Split a dictionary into shards.""" shards = [{} for _ in range(number_of_shards)] @@ -467,6 +475,7 @@ def _split_batch(features, labels, number_of_shards, device): sp_input=tensor, num_split=number_of_shards, axis=0)): shards[i][name] = shard else: + ensure_divisible_by_shards(tensor) for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): shards[i][name] = shard return shards @@ -476,6 +485,7 @@ def _split_batch(features, labels, number_of_shards, device): if isinstance(features, dict): feature_shards = split_dictionary(features) else: + ensure_divisible_by_shards(features) feature_shards = array_ops.split(features, number_of_shards) if labels is None: @@ -483,6 +493,7 @@ def _split_batch(features, labels, number_of_shards, device): elif isinstance(labels, dict): label_shards = split_dictionary(labels) else: + ensure_divisible_by_shards(labels) label_shards = array_ops.split(labels, number_of_shards) return feature_shards, label_shards @@ -780,7 +791,7 @@ def _extract_tensors(tensors_and_vars): tensor, _ = tensor_and_var if isinstance(tensor, ops_lib.IndexedSlices): tensors.append(tensor.values) - else: + elif tensor is not None: tensors.append(tensor) return tensors diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index 03d31226af613960a19ce116b19b30153b1fdcee..d46a18aacfcd911c56a9f22dc9581060c7b458a6 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -37,6 +37,7 @@ from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -239,6 +240,13 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): labels = np.array([[1.0], [2.0]]) with self.test_session() as session: + # Add another trainable variable that doesn't produce a gradient to + # verify that None gradients are supported. + _ = variable_scope.get_variable( + 'another_variable', + initializer=constant_op.constant(1, dtype=dtypes.float64), + dtype=dtypes.float64) + replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn( @@ -433,12 +441,51 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): 'probabilities': np.array([[0.1], [0.02]]) }, session.run(estimator_spec.predictions)) + def test_batch_size_that_is_not_divisible_by_the_number_of_gpus(self): + features = np.array([[1.0], [2.0], [3.0]]) + labels = np.array([[1.0], [2.0], [3.0]]) + + with self.assertRaisesRegexp( + ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0', '/gpu:1']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + def test_unsupported_loss_reduction(self): with self.assertRaisesRegexp(ValueError, '.+none.+reduction.+is.+specified.+'): _ = replicate_model_fn.replicate_model_fn(self.model_fn, losses.Reduction.NONE) + def test_places_on_gpu_with_upper_case_spelling(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session(): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/GPU:0']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', c.device) + + def test_places_on_gpu_with_lower_case_spelling(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session(): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', c.device) + class ReplicateAcrossASingleDeviceWithoutTowerOptimizer( test_util.TensorFlowTestCase): @@ -981,8 +1028,13 @@ class SplitBatchTest(test_util.TensorFlowTestCase): return list(map(evaluate_items, first_list)), list( map(evaluate_items, second_list)) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + def test_simple_half_split(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = [0.0, 1.0, 2.0, 3.0] labels = [10.0, 11.0, 12.0, 13.0] feature_shards, label_shards = replicate_model_fn._split_batch( @@ -995,7 +1047,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards) def test_to_each_their_own(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = [0.0, 1.0, 2.0, 3.0] labels = [10.0, 11.0, 12.0, 13.0] feature_shards, label_shards = replicate_model_fn._split_batch( @@ -1008,7 +1060,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards) def test_one_batch(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = [0.0, 1.0, 2.0, 3.0] labels = [10.0, 11.0, 12.0, 13.0] feature_shards, label_shards = replicate_model_fn._split_batch( @@ -1021,7 +1073,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards) def test_half_split_in_dictionary(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} labels = [10.0, 11.0, 12.0, 13.0] @@ -1035,6 +1087,58 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([10.0, 11.0], label_shards[0].eval()) self.assertAllEqual([12.0, 13.0], label_shards[1].eval()) + def test_sparse_tensor_can_be_split_unevenly(self): + with self.test_session(): + features = { + 'x': + sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2], [2, 2]], + values=[1.0, 2.0, 3.0], + dense_shape=[3, 4]) + } + labels = np.array([[1.0], [2.0]]) + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[2, 4]), + feature_shards[0]['x'].eval()) + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 2]], values=[3.], dense_shape=[1, 4]), + feature_shards[1]['x'].eval()) + self.assertAllEqual([[1.0]], label_shards[0].eval()) + self.assertAllEqual([[2.0]], label_shards[1].eval()) + + def test_sparse_tensor_can_be_split_unevenly_repeated_row(self): + with self.test_session(): + features = { + 'x': + sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=[1.0, 2.0, 3.0], + dense_shape=[3, 4]) + } + labels = np.array([[1.0], [2.0]]) + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [1, 1]], + values=[1., 2., 3.], + dense_shape=[2, 4]), feature_shards[0]['x'].eval()) + + second_batch = feature_shards[1]['x'].eval() + self.assertFalse(len(second_batch.indices)) + self.assertFalse(len(second_batch.values)) + self.assertAllEqual([1, 4], second_batch.dense_shape) + self.assertAllEqual([[1.0]], label_shards[0].eval()) + self.assertAllEqual([[2.0]], label_shards[1].eval()) + def test_one_batch_in_dictionary(self): with self.test_session() as session: # pylint: disable=unused-variable features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index fe86a20ab1f69a0eaf9d7486142451dac6337274..180f1b68f3b56113dfbbfc100bd04efc3bb8b31f 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -221,6 +221,7 @@ py_test( name = "kmeans_test", size = "medium", srcs = ["python/ops/kmeans_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["notsan"], # b/67512932 deps = [ diff --git a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc index 31d08bfb65ea49e1378ffba480771d38ce16abec..a8c5d0763c28ba2b54f217405f0da65533f26b91 100644 --- a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc +++ b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc @@ -57,11 +57,11 @@ typedef Eigen::Map< class MaskedMatmulOp : public OpKernel { public: - explicit MaskedMatmulOp(OpKernelConstruction* context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->MatchSignature( - {DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL, DT_BOOL}, - {DT_FLOAT})); + explicit MaskedMatmulOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK( + context, + context->MatchSignature( + {DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL, DT_BOOL}, {DT_FLOAT})); } void Compute(OpKernelContext* context) override { @@ -110,12 +110,11 @@ class MaskedMatmulOp : public OpKernel { num_nonzero_elements, 2); Tensor* prod_values_tensor; - OP_REQUIRES_OK(context, - context->allocate_output( - 0, TensorShape({num_nonzero_elements}), - &prod_values_tensor)); - EigenMatFloatMap prod_values(prod_values_tensor->vec().data(), - 1, num_nonzero_elements); + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({num_nonzero_elements}), + &prod_values_tensor)); + EigenMatFloatMap prod_values(prod_values_tensor->vec().data(), 1, + num_nonzero_elements); auto get_a_index = [&indices_mat, &a_dim_0](int64 i) { int64 a_index = internal::SubtleMustCopy(indices_mat(i, 0)); @@ -182,8 +181,8 @@ class MaskedMatmulOp : public OpKernel { } }; // Shard the work. - worker_threads.workers->ParallelFor( - num_nonzero_elements, cost_per_unit, work); + worker_threads.workers->ParallelFor(num_nonzero_elements, cost_per_unit, + work); } }; REGISTER_KERNEL_BUILDER(Name("MaskedMatmul").Device(DEVICE_CPU), diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index 6d3acb2750743318aad83991bc1e89d64c329423..23137e0a973c0bdd2cdbd97159f7fd310178bf54 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -192,11 +192,11 @@ class KMeans(object): # Computes Euclidean distance. Note the first and third terms are # broadcast additions. squared_distance = ( - math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) - + math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) - 2 * math_ops.matmul(inp, clusters, transpose_b=True) + array_ops.transpose( math_ops.reduce_sum( - math_ops.square(clusters), 1, keep_dims=True))) + math_ops.square(clusters), 1, keepdims=True))) output.append(squared_distance) return output diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py index f72280c4ecf19e33278ffe74061f44bbb7b21709..b2dfe48b2dbe0ec0975f865bba95a7ceba0f590c 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm.py +++ b/tensorflow/contrib/factorization/python/ops/gmm.py @@ -24,17 +24,16 @@ import numpy as np from tensorflow.contrib import framework from tensorflow.contrib.factorization.python.ops import gmm_ops from tensorflow.contrib.framework.python.framework import checkpoint_utils -from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops as logging -from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops.control_flow_ops import with_dependencies from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util def _streaming_sum(scalar_tensor): @@ -70,8 +69,8 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook): class GMM(estimator.Estimator): """An estimator for GMM clustering.""" SCORES = 'scores' + LOG_LIKELIHOOD = 'loss' ASSIGNMENTS = 'assignments' - ALL_SCORES = 'all_scores' def __init__(self, num_clusters, @@ -113,10 +112,7 @@ class GMM(estimator.Estimator): yield result[GMM.ASSIGNMENTS] def score(self, input_fn=None, batch_size=None, steps=None): - """Predict total sum of distances to nearest clusters. - - Note that this function is different from the corresponding one in sklearn - which returns the negative of the sum of distances. + """Predict total log-likelihood. Args: input_fn: see predict. @@ -124,11 +120,11 @@ class GMM(estimator.Estimator): steps: see predict. Returns: - Total sum of distances to nearest clusters. + Total log-likelihood. """ results = self.evaluate(input_fn=input_fn, batch_size=batch_size, steps=steps) - return np.sum(results[GMM.SCORES]) + return np.log(np.sum(np.exp(results[GMM.SCORES]))) def weights(self): """Returns the cluster weights.""" @@ -158,9 +154,10 @@ class GMM(estimator.Estimator): def _model_fn(features, labels, mode, config): """Model function.""" assert labels is None, labels - (all_scores, + (loss, + scores, model_predictions, - losses, training_op, + training_op, init_op, is_initialized) = gmm_ops.gmm(self._parse_tensor_or_dict(features), self._training_initial_clusters, @@ -168,16 +165,15 @@ class GMM(estimator.Estimator): self._covariance_type, self._params) incr_step = state_ops.assign_add(training_util.get_global_step(), 1) - loss = math_ops.reduce_sum(losses) training_op = with_dependencies([training_op, incr_step], loss) training_hooks = [_InitializeClustersHook( init_op, is_initialized, config.is_chief)] predictions = { - GMM.ALL_SCORES: all_scores[0], GMM.ASSIGNMENTS: model_predictions[0][0], } eval_metric_ops = { - GMM.SCORES: _streaming_sum(loss), + GMM.SCORES: scores, + GMM.LOG_LIKELIHOOD: _streaming_sum(loss), } return model_fn_lib.ModelFnOps(mode=mode, predictions=predictions, eval_metric_ops=eval_metric_ops, diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index a61681c7f5a69a0fff1089404fc80b95c1c3106e..98d6434f4752b224201e38bed05ccd14428a758b 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -21,7 +21,6 @@ from __future__ import division from __future__ import print_function import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -36,7 +35,6 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.embedding_ops import embedding_lookup -from tensorflow.python.summary import summary # Machine epsilon. MEPS = np.finfo(float).eps @@ -253,14 +251,16 @@ class GmmAlgorithm(object): return ret def scores(self): - """Returns the distances to each class. + """Returns the per-sample likelihood fo the data. Returns: - A tuple with two Tensors. The first contains the distance to - each class. The second contains the distance to the assigned - class. + Log probabilities of each data point. """ - return (self._all_scores, self._scores) + return self._scores + + def log_likelihood_op(self): + """Returns the log-likelihood operation.""" + return self._log_likelihood_op def _define_graph(self, data): """Define graph for a single iteration. @@ -276,7 +276,8 @@ class GmmAlgorithm(object): self._define_expectation_operation(shard_id) self._define_partial_maximization_operation(shard_id, shard) self._define_maximization_operation(len(data)) - self._define_distance_to_clusters(data) + self._define_loglikelihood_operation() + self._define_score_samples() def _define_full_covariance_probs(self, shard_id, shard): """Defines the full covariance probabilties per example in a class. @@ -440,50 +441,20 @@ class GmmAlgorithm(object): state_ops.assign( self._covs, new_covs, validate_shape=False)) - def _define_distance_to_clusters(self, data): - """Defines the Mahalanobis distance to the assigned Gaussian.""" - # TODO(xavigonzalvo): reuse (input - mean) * cov^-1 * (input - - # mean) from log probability function. - self._all_scores = [] - for shard in data: - all_scores = [] - shard = array_ops.expand_dims(shard, 0) - for c in xrange(self._num_classes): - if self._covariance_type == FULL_COVARIANCE: - cov = self._covs[c, :, :] - elif self._covariance_type == DIAG_COVARIANCE: - cov = array_ops.diag(self._covs[c, :]) - inverse = linalg_ops.matrix_inverse(cov + self._min_var) - inv_cov = array_ops.tile( - array_ops.expand_dims(inverse, 0), - array_ops.stack([self._num_examples, 1, 1])) - diff = array_ops.transpose(shard - self._means[c, :, :], perm=[1, 0, 2]) - m_left = math_ops.matmul(diff, inv_cov) - all_scores.append( - math_ops.sqrt( - math_ops.matmul( - m_left, array_ops.transpose( - diff, perm=[0, 2, 1])))) - self._all_scores.append( - array_ops.reshape( - array_ops.concat(all_scores, 1), - array_ops.stack([self._num_examples, self._num_classes]))) - - # Distance to the associated class. - self._all_scores = array_ops.concat(self._all_scores, 0) - assignments = array_ops.concat(self.assignments(), 0) - rows = math_ops.to_int64(math_ops.range(0, self._num_examples)) - indices = array_ops.concat( - [array_ops.expand_dims(rows, 1), array_ops.expand_dims(assignments, 1)], - 1) - self._scores = array_ops.gather_nd(self._all_scores, indices) - def _define_loglikelihood_operation(self): """Defines the total log-likelihood of current iteration.""" - self._ll_op = [] + op = [] for prior_probs in self._prior_probs: - self._ll_op.append(math_ops.reduce_sum(math_ops.log(prior_probs))) - summary.scalar('ll', math_ops.reduce_sum(self._ll_op)) + op.append(math_ops.reduce_logsumexp(prior_probs)) + self._log_likelihood_op = math_ops.reduce_logsumexp(op) + + def _define_score_samples(self): + """Defines the likelihood of each data sample.""" + op = [] + for shard_id, prior_probs in enumerate(self._prior_probs): + op.append(prior_probs + math_ops.log(self._w[shard_id])) + self._scores = array_ops.squeeze( + math_ops.reduce_logsumexp(op, axis=2, keep_dims=True), axis=0) def gmm(inp, @@ -511,14 +482,9 @@ def gmm(inp, Returns: Note: tuple of lists returned to be consistent with skflow A tuple consisting of: - all_scores: A matrix (or list of matrices) of dimensions (num_input, - num_clusters) where the value is the distance of an input vector and a - cluster center. assignments: A vector (or list of vectors). Each element in the vector corresponds to an input row in 'inp' and specifies the cluster id corresponding to the input. - scores: Similar to assignments but specifies the distance to the - assigned cluster instead. training_op: an op that runs an iteration of training. init_op: an op that runs the initialization. """ @@ -532,6 +498,7 @@ def gmm(inp, gmm_tool = GmmAlgorithm(inp, num_clusters, initial_means, params, covariance_type, random_seed) assignments = gmm_tool.assignments() - all_scores, scores = gmm_tool.scores() - return ([all_scores], [assignments], [scores], gmm_tool.training_ops(), + scores = gmm_tool.scores() + loss = gmm_tool.log_likelihood_op() + return (loss, scores, [assignments], gmm_tool.training_ops(), gmm_tool.init_ops(), gmm_tool.is_initialized()) diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py index c50e82db8a230012ba13c1d7ad7e28c23bd27355..888c3c238c2654ea11ea3bf8270d6c3fcd951a03 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py @@ -122,17 +122,23 @@ class GmmOpsTest(test.TestCase): g.seed = 5 with self.test_session() as sess: data = constant_op.constant(self.data, dtype=dtypes.float32) - _, assignments, _, training_op, init_op, _ = gmm_ops.gmm( + loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm( data, 'random', num_classes, random_seed=self.seed) variables.global_variables_initializer().run() sess.run(init_op) + first_loss = sess.run(loss_op) for _ in xrange(self.iterations): sess.run(training_op) assignments = sess.run(assignments) + end_loss = sess.run(loss_op) + scores = sess.run(scores) + self.assertEqual((self.num_examples, 1), scores.shape) accuracy = np.mean( np.asarray(self.true_assignments) == np.squeeze(assignments)) logging.info('Accuracy: %f', accuracy) + logging.info('First loss: %f, end loss: %f', first_loss, end_loss) + self.assertGreater(end_loss, first_loss) self.assertGreater(accuracy, 0.98) def testParams(self): diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py index 7717b47daefce9ff65b1f1e84f671a463cf2e826..00a4734eb6d89cd02484f1c5161366377cc71208 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.factorization.python.ops import gmm as gmm_lib from tensorflow.contrib.learn.python.learn.estimators import kmeans @@ -30,12 +29,9 @@ from tensorflow.python.framework import random_seed as random_seed_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import random_ops -from tensorflow.python.platform import flags from tensorflow.python.platform import test from tensorflow.python.training import queue_runner -FLAGS = flags.FLAGS - class GMMTest(test.TestCase): @@ -64,9 +60,8 @@ class GMMTest(test.TestCase): self.batch_size = self.num_points self.true_centers = self.make_random_centers(self.num_centers, self.num_dims) - self.points, self.assignments, self.scores = self.make_random_points( + self.points, self.assignments = self.make_random_points( self.true_centers, self.num_points) - self.true_score = np.add.reduce(self.scores) # Use initial means from kmeans (just like scikit-learn does). clusterer = kmeans.KMeansClustering(num_clusters=self.num_centers) @@ -86,24 +81,7 @@ class GMMTest(test.TestCase): offsets = np.round( np.random.randn(num_points, num_dims).astype(np.float32) * 20) points = centers[assignments] + offsets - means = [ - np.mean( - points[assignments == center], axis=0) - for center in xrange(num_centers) - ] - covs = [ - np.cov(points[assignments == center].T) - for center in xrange(num_centers) - ] - scores = [] - for r in xrange(num_points): - scores.append( - np.sqrt( - np.dot( - np.dot(points[r, :] - means[assignments[r]], - np.linalg.inv(covs[assignments[r]])), points[r, :] - - means[assignments[r]]))) - return (points, assignments, scores) + return (points, assignments) def test_weights(self): """Tests the shape of the weights.""" @@ -136,8 +114,7 @@ class GMMTest(test.TestCase): gmm.fit(input_fn=self.input_fn(), steps=10) score2 = gmm.score(input_fn=self.input_fn(batch_size=self.num_points), steps=1) - self.assertGreater(score1, score2) - self.assertNear(self.true_score, score2, self.true_score * 0.15) + self.assertLess(score1, score2) def test_infer(self): gmm = gmm_lib.GMM(self.num_centers, @@ -149,8 +126,7 @@ class GMMTest(test.TestCase): # Make a small test set num_points = 40 - points, true_assignments, true_offsets = ( - self.make_random_points(clusters, num_points)) + points, true_assignments = self.make_random_points(clusters, num_points) assignments = [] for item in gmm.predict_assignments( @@ -159,11 +135,6 @@ class GMMTest(test.TestCase): assignments = np.ravel(assignments) self.assertAllEqual(true_assignments, assignments) - # Test score - score = gmm.score(input_fn=self.input_fn(points=points, - batch_size=num_points), steps=1) - self.assertNear(score, np.sum(true_offsets), 4.05) - def _compare_with_sklearn(self, cov_type): # sklearn version. iterations = 40 diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 4d0f9b24240ccbafe89ef912b4d3252cefb1f7f2..c861cfff544a78617aa1ace730b50c094cf16330 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -143,7 +143,7 @@ class _ModelFn(object): def model_fn(self, features, mode, config): """Model function for the estimator. - Note that this does not take a `1abels` arg. This works, but `input_fn` must + Note that this does not take a `labels` arg. This works, but `input_fn` must return either `features` or, equivalently, `(features, None)`. Args: diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..6fc053759c58d30c24657dd22e7d12be46fc7a7e --- /dev/null +++ b/tensorflow/contrib/feature_column/BUILD @@ -0,0 +1,37 @@ +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "feature_column_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":sequential_feature_column", + ], +) + +py_library( + name = "sequential_feature_column", + srcs = ["python/feature_column/sequential_feature_column.py"], + srcs_version = "PY2AND3", + deps = [], +) diff --git a/tensorflow/contrib/feature_column/__init__.py b/tensorflow/contrib/feature_column/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6da7b126931effae9cc97091a27070d7013450d4 --- /dev/null +++ b/tensorflow/contrib/feature_column/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental utilities for tf.feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.feature_column.python.feature_column.sequential_feature_column import * + +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py new file mode 100644 index 0000000000000000000000000000000000000000..690a44ff4368663306733300a1ea70397fb93e1e --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py @@ -0,0 +1,19 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental methods for tf.feature_column sequential input.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/ffmpeg/decode_video_op.cc b/tensorflow/contrib/ffmpeg/decode_video_op.cc index d44032968d559bec14722902a4d47d22c46ea4aa..6f8ad486d10a825a277749157d68fa671b9f8d3a 100644 --- a/tensorflow/contrib/ffmpeg/decode_video_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_video_op.cc @@ -102,16 +102,12 @@ REGISTER_OP("DecodeVideo") return Status::OK(); }) .Doc(R"doc( -Processes the contents of an audio file into a tensor using FFmpeg to decode +Processes the contents of an video file into a tensor using FFmpeg to decode the file. -One row of the tensor is created for each channel in the audio file. Each -channel contains audio samples starting at the beginning of the audio and -having `1/samples_per_second` time between them. If the `channel_count` is -different from the contents of the file, channels will be merged or created. - -contents: The binary audio file contents, as a string or rank-0 string - tensor. +contents: The binary contents of the video file to decode. This is a + scalar. +output: A rank-4 `Tensor` that has `[frames, height, width, 3]` RGB as output. )doc"); } // namespace ffmpeg diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index c85b1837ab5b0c1a3cea0525918f7717228d2fab..e61221a6b0d34373279a379f356c99c379488182 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -47,20 +47,19 @@ std::vector FfmpegAudioCommandLine(const string& input_filename, int32 channel_count, const string& stream) { std::vector command({ - "-nostats", // No additional progress display. - "-nostdin", // No interactive commands accepted. - "-f", input_format_id, // eg: "mp3" - "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, - "-loglevel", "error", // Print errors only. - "-hide_banner", // Skip printing build options, version, etc. - "-map_metadata", "-1", // Copy global metadata from input to output. - "-vn", // No video recording. - "-ac:a:0", StrCat(channel_count), "-ar:a:0", - StrCat(samples_per_second), - // Output set (in several ways) to signed 16-bit little-endian ints. - "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", - "-sn", // No subtitle recording. - "-y" // Overwrite output file. + "-nostats", // No additional progress display. + "-nostdin", // No interactive commands accepted. + "-f", input_format_id, // eg: "mp3" + "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, + "-loglevel", "error", // Print errors only. + "-hide_banner", // Skip printing build options, version, etc. + "-map_metadata", "-1", // Copy global metadata from input to output. + "-vn", // No video recording. + "-ac:a:0", StrCat(channel_count), "-ar:a:0", StrCat(samples_per_second), + // Output set (in several ways) to signed 16-bit little-endian ints. + "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", + "-sn", // No subtitle recording. + "-y" // Overwrite output file. }); if (!stream.empty()) { command.emplace_back("-map"); @@ -75,21 +74,13 @@ std::vector FfmpegVideoCommandLine(const string& input_filename, const string& output_filename) { return {"-nostats", // No additional progress display. "-nostdin", // No interactive commands accepted. - "-i", - input_filename, - "-f", - "image2pipe", - "-probesize", - StrCat(kDefaultProbeSize), - "-loglevel", + "-i", input_filename, "-f", "image2pipe", "-probesize", + StrCat(kDefaultProbeSize), "-loglevel", // Info is needed to get the information about stream, etc. // It is generated to a separate file, not stdout/stderr. "info", "-hide_banner", // Skip printing build options, version, etc. - "-vcodec", - "rawvideo", - "-pix_fmt", - "rgb24", + "-vcodec", "rawvideo", "-pix_fmt", "rgb24", "-y", // Overwrite output file. StrCat(output_filename)}; } diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc index 85b61b26163d87a10d4e316720b4f633e038bbec..05728b3d37570d06f2f8af67e3b0612d21d07601 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc @@ -32,10 +32,8 @@ namespace tensorflow { namespace ffmpeg { namespace { -const char kTestWavFilename[] = - "contrib/ffmpeg/testdata/mono_10khz.wav"; -const char kTestMp3Filename[] = - "contrib/ffmpeg/testdata/test_sound1.mp3"; +const char kTestWavFilename[] = "contrib/ffmpeg/testdata/mono_10khz.wav"; +const char kTestMp3Filename[] = "contrib/ffmpeg/testdata/test_sound1.mp3"; // Set to true via a command line flag iff the test is expected to have FFmpeg // installed. @@ -139,7 +137,7 @@ TEST(FfmpegLibTest, TestRoundTripWav) { } // namespace ffmpeg } // namespace tensorflow -int main(int argc, char **argv) { +int main(int argc, char** argv) { tensorflow::string usage = tensorflow::ffmpeg::ParseTestFlags(&argc, argv); testing::InitGoogleTest(&argc, argv); if (argc != 1) { diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc index 36fc71794b06e0f3cb86c40b325ce50e8999c667..d6c885a32424334bfc28c830e3701f219aa244ee 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -20,8 +20,6 @@ #include #include - -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 673c51784229bd88011f8b33fb851a2885566220..4746cfe0720cb20f530dc919fe062db17a1dfe84 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -53,6 +53,7 @@ See the @{$python/contrib.framework} guide. @@assign_from_values_fn @@create_global_step @@filter_variables +@@fuse_op @@get_global_step @@get_or_create_global_step @@get_local_variables @@ -84,7 +85,12 @@ See the @{$python/contrib.framework} guide. @@py_func @@sort +@@get_placeholders + @@CriticalSection + +@@BoundedTensorSpec +@@TensorSpec """ from __future__ import absolute_import @@ -98,6 +104,11 @@ from tensorflow.contrib.framework.python.ops import * from tensorflow.python.framework.ops import prepend_name_scope from tensorflow.python.framework.ops import strip_name_scope +from tensorflow.python.ops.control_flow_ops import smart_cond +from tensorflow.python.ops.control_flow_ops import smart_constant_value + +from tensorflow.python.framework.tensor_spec import BoundedTensorSpec +from tensorflow.python.framework.tensor_spec import TensorSpec from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc index 6677dca752f84fc1ba7548b7739df04b7aaf14f7..5bf6b67529579e71a615c27e035111a58d5c02e0 100644 --- a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc +++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/contrib/framework/kernels/zero_initializer_op.h" -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" namespace tensorflow { @@ -81,8 +81,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); #define REGISTER_GPU_KERNELS(T) REGISTER_KERNELS(GPU, T); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #undef REGISTER_KERNELS -} // namespace tensorflow +} // namespace tensorflow diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.h b/tensorflow/contrib/framework/kernels/zero_initializer_op.h index 14c9268efa869ffd48b01dd2add44990ef7a43f8..99389a5ab6aa73c2ab0e522dd0f9fbc7093c8f4a 100644 --- a/tensorflow/contrib/framework/kernels/zero_initializer_op.h +++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.h @@ -29,5 +29,5 @@ struct TensorSetZero { }; } // namespace functor -} // end namespace tensorflow -#endif // TENSORFLOW_CONTRIB_FRAMEWORK_KERNELS_ZERO_INITIALIZER_OP_H_ +} // end namespace tensorflow +#endif // TENSORFLOW_CONTRIB_FRAMEWORK_KERNELS_ZERO_INITIALIZER_OP_H_ diff --git a/tensorflow/contrib/framework/ops/variable_ops.cc b/tensorflow/contrib/framework/ops/variable_ops.cc index 1ee8e1498cf07559fe3db78ef832e2cdf26bea1c..706134ba9a51de6253ba7463b17ff662ea740ed0 100644 --- a/tensorflow/contrib/framework/ops/variable_ops.cc +++ b/tensorflow/contrib/framework/ops/variable_ops.cc @@ -26,8 +26,8 @@ REGISTER_OP("ZeroInitializer") .Attr("T: realnumbertype") .SetAllowsUninitializedInput() .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); - return Status::OK(); + c->set_output(0, c->input(0)); + return Status::OK(); }) .Doc(R"doc( Initialize 'ref' with all zeros. This op requires that the tensor is not diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index a18ff2320d99726bb355ff6179fc97a070c2fec7..49eec3a3f1a0f357ea3adfade51e71cb0f89942d 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -133,6 +133,18 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, def get_placeholders(graph): """Get placeholders of a graph. + For example: + + ```python + a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a') + a = tf.placeholder(dtype=tf.int32, shape=[3, 2], name='b') + + tf.contrib.framework.get_placeholders(tf.get_default_graph()) + # Returns: + # [, + # ] + ``` + Args: graph: A tf.Graph. Returns: diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py index 2375ee4f550616ff60d20b87b5773704d8fbbe1e..476528b0dd3df05239d5dc402b466e06dd789985 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -108,4 +109,3 @@ def _AddNGrad(op, grad): """Same as gradient for AddN. Copies the gradient to all inputs.""" # Not broadcasting. return [grad] * len(op.inputs) - diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py index 8f44698da851b48abf831e957c80fa1643a58bda..35974b9e21d2d7423777a95a99f51c9cb4b453b2 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py @@ -27,16 +27,11 @@ import numpy as np from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 from tensorflow.python.eager import backprop -from tensorflow.python.eager import context as eager_context -from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.ops import gradients -from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py index b5e9f8df79262635bf579a6bf2260bc40c140c6f..45962098e93acfac414396ddbeaa847701ff2b4b 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py @@ -22,7 +22,6 @@ import numpy as np from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -31,7 +30,6 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import googletest - class AccumulateNV2Test(test_util.TensorFlowTestCase): """Tests of the new, differentiable version of accumulate_n""" @@ -62,8 +60,9 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): accum_n = av2.accumulate_n_v2(input_vars) sess.run(variables.global_variables_initializer()) accum_n_grad = gradients.gradients(accum_n, input_vars) - self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 - [g.eval() for g in accum_n_grad]) + self.assertAllEqual( + np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + [g.eval() for g in accum_n_grad]) # The tests below used to be in a separate class under cwise_ops_test.py, # which did not run in the default test target. @@ -75,8 +74,8 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20) ] random_tensors = [ - ops.convert_to_tensor( - x, dtype=dtypes_lib.float32) for x in random_arrays + ops.convert_to_tensor(x, dtype=dtypes_lib.float32) + for x in random_arrays ] tf_val = av2.accumulate_n_v2(random_tensors) np_val = random_arrays[0] @@ -95,21 +94,21 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): a = variables.Variable(0.2) b = variables.Variable(0.1) - tf_val = av2.accumulate_n_v2([a,b], shape=[2,2]) # Should be shape=[] + tf_val = av2.accumulate_n_v2([a, b], shape=[2, 2]) # Should be shape=[] def testIncompatibleShapes(self): with self.test_session(): with self.assertRaises(ValueError): - a = variables.Variable(np.array([0.1,0.2])) - b = variables.Variable(np.array([[0.3],[0.4]])) - tf_val = av2.accumulate_n_v2([a,b]) + a = variables.Variable(np.array([0.1, 0.2])) + b = variables.Variable(np.array([[0.3], [0.4]])) + tf_val = av2.accumulate_n_v2([a, b]) def testWrongType(self): with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) b = variables.Variable(0.1, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) + tf_val = av2.accumulate_n_v2([a, b], tensor_dtype=np.int32) def testWrongTypeOneInput(self): # Scenario that used to trigger a bug, even when testWrongType() worked diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 2bce00fde2459878a12027bb4d98bd3818bc92a2..409657fe1da0e5540cd2ad6070d86737c039e91f 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -53,7 +53,8 @@ net = layers.conv2d(net, 256, [5, 5], scope='conv2') ``` - Example of how to use tf.contrib.framework.add_arg_scope to enable your function to be called within an arg_scope later: + Example of how to use tf.contrib.framework.add_arg_scope to enable your + function to be called within an arg_scope later: @tf.contrib.framework.add_arg_scope def conv2d(*args, **kwargs) @@ -65,11 +66,10 @@ from __future__ import print_function from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator -__all__ = ['arg_scope', - 'add_arg_scope', - 'current_arg_scope', - 'has_arg_scope', - 'arg_scoped_arguments'] +__all__ = [ + 'arg_scope', 'add_arg_scope', 'current_arg_scope', 'has_arg_scope', + 'arg_scoped_arguments' +] _ARGSTACK = [{}] @@ -172,6 +172,7 @@ def add_arg_scope(func): Returns: A tuple with the decorated function func_with_args(). """ + def func_with_args(*args, **kwargs): current_scope = current_arg_scope() current_args = kwargs @@ -180,6 +181,7 @@ def add_arg_scope(func): current_args = current_scope[key_func].copy() current_args.update(kwargs) return func(*args, **current_args) + _add_op(func) setattr(func_with_args, '_key_op', _key_op(func)) return tf_decorator.make_decorator(func, func_with_args) diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 3f1ece4510578b5ac39849c577fffbb2a3be45a7..0754c3e0e30a340910a43a3ce86f6ca10afe848e 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -25,6 +25,7 @@ import re from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope from tensorflow.contrib.framework.python.ops import gen_variable_ops from tensorflow.contrib.util import loader +from tensorflow.core.protobuf import saver_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import dtypes @@ -32,9 +33,8 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import gen_state_ops -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import training_util from tensorflow.python.util.deprecation import deprecated @@ -685,7 +685,8 @@ def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False, 'Variable %s missing in checkpoint %s', var, model_path) var_list = available_vars if var_list: - saver = tf_saver.Saver(var_list, reshape=reshape_variables) + saver = tf_saver.Saver(var_list, reshape=reshape_variables, + write_version=saver_pb2.SaverDef.V1) def callback(session): saver.restore(session, model_path) return callback diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 6a56237f67c844a3daa546eb02d64c9e2658f639..bafd1d59418f0ba47ebbdaabbf06f8e5471fc1a1 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -25,13 +25,6 @@ limitations under the License. namespace tensorflow { -namespace { -// Return the string containing the list of valid activation modes, that can be -// used as an Attr() in REGISTER_OP. -string GetAllActivationModeAttrString() { return "activation_mode: {'Relu'}"; } - -} // namespace - // -------------------------------------------------------------------------- // TODO(pauldonnelly): Add support for double inputs and scales to this Op, diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py index a65d4bc50ff796977e8ea7f652b7cbe3fe37f673..96cdd8b1ca4d56d12d38ea961ae73f3a3aa28968 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py @@ -116,7 +116,7 @@ def build_fused_conv_bias_relu_graph(device, input_shape, filter_shape, strides, for _ in range(1, num_iters): with ops.control_dependencies([fused_out]): # pylint: disable=g-line-too-long - fused_out = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + fused_out = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( # pylint: disable=line-too-long inp, filt, bias, @@ -166,10 +166,10 @@ class FusedConv2DBiasActivationBenchmark(test.Benchmark): duration = (time.time() - start_time) / num_iters print("%s inputshape:%s filtershape:%s strides:%s padding:%s " - "%d iters: %.8f sec" % - (device, str(input_shape).replace(" ", ""), - str(filter_shape).replace(" ", ""), - str(strides).replace(" ", ""), padding, num_iters, duration)) + "%d iters: %.8f sec" % (device, str(input_shape).replace(" ", ""), + str(filter_shape).replace(" ", ""), + str(strides).replace(" ", ""), padding, + num_iters, duration)) name_template = ( "conv2d_{device}_input_shape_{inputshape}_filter_shape_{filtershape}_" "strides_{strides}_padding_{padding}") diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 5db34f0f8db93620b8b4a6b71f63b66ac718ee30..0eb0e3cbe20f5804db5476c08167d4e1c9080cfa 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -55,6 +55,7 @@ py_test( name = "train_test", srcs = ["python/train_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":features", ":namedtuples", diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 0d51c282a8977871185fb4200082feb7868cdbae..082c42eba180917e732bb7890129dfa94bf00fec 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -59,7 +59,11 @@ _summary_type_map = { class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). - This Estimator is backed by TFGAN. + This Estimator is backed by TFGAN. The network functions follow the TFGAN API + except for one exception: if either `generator_fn` or `discriminator_fn` have + an argument called `mode`, then the tf.Estimator mode is passed in for that + argument. This helps with operations like batch normalization, which have + different train and evaluation behavior. Example: @@ -233,9 +237,11 @@ def _gan_model_fn( def _make_gan_model(generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope, add_summaries, mode): """Make a `GANModel`, and optionally pass in `mode`.""" - # If `generator_fn` has an argument `mode`, pass mode to it. + # If network functions have an argument `mode`, pass mode to it. if 'mode' in inspect.getargspec(generator_fn).args: generator_fn = functools.partial(generator_fn, mode=mode) + if 'mode' in inspect.getargspec(discriminator_fn).args: + discriminator_fn = functools.partial(discriminator_fn, mode=mode) gan_model = tfgan_train.gan_model( generator_fn, discriminator_fn, diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index e752f0bcccda418b79d4fdabb27807394cbbb425..387a62bd741bd42c03dc1bf70592060c29ccd7a8 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -54,7 +54,8 @@ def generator_fn(noise_dict, mode): return layers.fully_connected(noise, noise.shape[1].value) -def discriminator_fn(data, _): +def discriminator_fn(data, unused_conditioning, mode): + del unused_conditioning, mode return layers.fully_connected(data, 1) @@ -99,7 +100,6 @@ def mock_head(testcase, expected_generator_inputs, expected_real_data, else: testcase.assertEqual(discriminator_scope_name, gan_model.discriminator_scope.name) - testcase.assertEqual(_or_none(discriminator_fn), gan_model.discriminator_fn) with ops.control_dependencies(assertions): if mode == model_fn_lib.ModeKeys.TRAIN: diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 986a5ff6dcbeb2ff996f49137adc6d34e14c979f..fdfabd07c13f689d075ecbb8786d725fa8a62d01 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -28,6 +28,7 @@ from __future__ import division from __future__ import print_function import functools +import os import sys import tarfile @@ -189,20 +190,34 @@ def get_graph_def_from_resource(filename): return graph_pb2.GraphDef.FromString(resource_loader.load_resource(filename)) -def get_graph_def_from_url_tarball(url, filename): - """Get a GraphDef proto from a tarball on the web.""" - def _progress(count, block_size, total_size): - sys.stdout.write('\r>> Downloading %s %.1f%%' % ( - url, float(count * block_size) / float(total_size) * 100.0)) - sys.stdout.flush() - tar_filename, _ = urllib.request.urlretrieve(url, reporthook=_progress) +def get_graph_def_from_url_tarball(url, filename, tar_filename=None): + """Get a GraphDef proto from a tarball on the web. + + Args: + url: Web address of tarball + filename: Filename of graph definition within tarball + tar_filename: Temporary download filename (None = always download) + + Returns: + A GraphDef loaded from a file in the downloaded tarball. + """ + if not (tar_filename and os.path.exists(tar_filename)): + + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % + (url, + float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + + tar_filename, _ = urllib.request.urlretrieve(url, tar_filename, _progress) with tarfile.open(tar_filename, 'r:gz') as tar: proto_str = tar.extractfile(filename).read() return graph_pb2.GraphDef.FromString(proto_str) def _default_graph_def_fn(): - return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH) + return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH, + os.path.basename(INCEPTION_URL)) def run_inception(images, diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 1e18c699ba93b5f524341c65d0a2db84556b65a2..61dc8646ddc10605561ae6b19e90f4739c346608 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -181,7 +181,8 @@ class ClassifierMetricsTest(test.TestCase): batch_size = 3 img = array_ops.ones([batch_size, 299, 299, 3]) pool = _run_with_mock( - classifier_metrics.run_inception, img, + classifier_metrics.run_inception, + img, output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) self.assertTrue(isinstance(pool, ops.Tensor)) @@ -195,9 +196,12 @@ class ClassifierMetricsTest(test.TestCase): batch_size = 3 img = array_ops.ones([batch_size, 299, 299, 3]) logits, pool = _run_with_mock( - classifier_metrics.run_inception, img, - output_tensor=[classifier_metrics.INCEPTION_OUTPUT, - classifier_metrics.INCEPTION_FINAL_POOL]) + classifier_metrics.run_inception, + img, + output_tensor=[ + classifier_metrics.INCEPTION_OUTPUT, + classifier_metrics.INCEPTION_FINAL_POOL + ]) self.assertTrue(isinstance(logits, ops.Tensor)) self.assertTrue(isinstance(pool, ops.Tensor)) @@ -209,8 +213,10 @@ class ClassifierMetricsTest(test.TestCase): def test_inception_score_graph(self): """Test `inception_score` graph construction.""" - score = _run_with_mock(classifier_metrics.inception_score, - array_ops.zeros([6, 299, 299, 3]), num_batches=3) + score = _run_with_mock( + classifier_metrics.inception_score, + array_ops.zeros([6, 299, 299, 3]), + num_batches=3) self.assertTrue(isinstance(score, ops.Tensor)) score.shape.assert_has_rank(0) @@ -248,12 +254,14 @@ class ClassifierMetricsTest(test.TestCase): array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q) with self.assertRaisesRegexp(ValueError, 'must be floating type'): - classifier_metrics._kl_divergence( - p, array_ops.zeros([8, 10], dtype=dtypes.int32), q) + classifier_metrics._kl_divergence(p, + array_ops.zeros( + [8, 10], dtype=dtypes.int32), q) with self.assertRaisesRegexp(ValueError, 'must be floating type'): - classifier_metrics._kl_divergence( - p, p_logits, array_ops.zeros([10], dtype=dtypes.int32)) + classifier_metrics._kl_divergence(p, p_logits, + array_ops.zeros( + [10], dtype=dtypes.int32)) with self.assertRaisesRegexp(ValueError, 'must have rank 2'): classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q) @@ -266,8 +274,9 @@ class ClassifierMetricsTest(test.TestCase): def test_inception_score_value(self): """Test that `inception_score` gives the correct value.""" - logits = np.array([np.array([1, 2] * 500 + [4]), - np.array([4, 5] * 500 + [6])]) + logits = np.array( + [np.array([1, 2] * 500 + [4]), + np.array([4, 5] * 500 + [6])]) unused_image = array_ops.zeros([2, 299, 299, 3]) incscore = _run_with_mock(classifier_metrics.inception_score, unused_image) @@ -285,9 +294,11 @@ class ClassifierMetricsTest(test.TestCase): test_pool_real_a = np.float32(np.random.randn(512, 256)) test_pool_gen_a = np.float32(np.random.randn(512, 256)) - fid_op = _run_with_mock(classifier_metrics.frechet_classifier_distance, - test_pool_real_a, test_pool_gen_a, - classifier_fn=lambda x: x) + fid_op = _run_with_mock( + classifier_metrics.frechet_classifier_distance, + test_pool_real_a, + test_pool_gen_a, + classifier_fn=lambda x: x) with self.test_session() as sess: actual_fid = sess.run(fid_op) @@ -296,6 +307,33 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose(expected_fid, actual_fid, 0.0001) + def test_frechet_classifier_distance_covariance(self): + """Test that `frechet_classifier_distance` takes covariance into account.""" + np.random.seed(0) + + # Make num_examples > num_features to ensure scipy's sqrtm function + # doesn't return a complex matrix. + test_pool_reals, test_pool_gens = [], [] + for i in range(1, 11, 2): + test_pool_reals.append(np.float32(np.random.randn(2048, 256) * i)) + test_pool_gens.append(np.float32(np.random.randn(2048, 256) * i)) + + fid_ops = [] + for i in range(len(test_pool_reals)): + fid_ops.append(_run_with_mock( + classifier_metrics.frechet_classifier_distance, + test_pool_reals[i], + test_pool_gens[i], + classifier_fn=lambda x: x)) + + fids = [] + with self.test_session() as sess: + for fid_op in fid_ops: + fids.append(sess.run(fid_op)) + + # Check that the FIDs increase monotonically. + self.assertTrue(all(fid_a < fid_b for fid_a, fid_b in zip(fids, fids[1:]))) + def test_trace_sqrt_product_value(self): """Test that `trace_sqrt_product` gives the correct value.""" np.random.seed(0) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py index b960af28eaa969079b72c7aabcde2ad6cd1f5c68..871f1ad54e2559f5df28efa78f99997a866f7087 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py @@ -84,11 +84,11 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose( np.array([0.014, 0.014], 'f'), np.array([x[0] for x in wscores], 'f'), - rtol=0.1) + rtol=0.15) self.assertAllClose( np.array([0.014, 0.020], 'f'), np.array([x[1] for x in wscores], 'f'), - rtol=0.1) + rtol=0.15) def test_sliced_wasserstein_distance_svd(self): """Test the distance.""" diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 74811ff4096eb5215148f0565bf094b83408014c..0d1afad72da8a8e087239868e25ddebe23490d1e 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -39,12 +39,13 @@ def _assert_is_image(data): data.shape[1:].assert_is_fully_defined() -def add_gan_model_image_summaries(gan_model, grid_size=4): +def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): """Adds image summaries for real and fake images. Args: gan_model: A GANModel tuple. grid_size: The size of an image grid. + model_summaries: Also add summaries of the model. Raises: ValueError: If real and generated data aren't images. @@ -83,7 +84,9 @@ def add_gan_model_image_summaries(gan_model, grid_size=4): image_shape=generated_image_shape, num_channels=generated_channels), max_outputs=1) - add_gan_model_summaries(gan_model) + + if model_summaries: + add_gan_model_summaries(gan_model) def add_image_comparison_summaries(gan_model, num_comparisons=2, diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index a02d8772e130a2a927735e56c4272aba1f1a6996..45eb108586bed07434ac29595164745eac6054c1 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -72,8 +72,10 @@ def get_cyclegan_model(): class SummariesTest(test.TestCase): def _test_add_gan_model_image_summaries_impl(self, get_model_fn, - expected_num_summary_ops): - summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2) + expected_num_summary_ops, + model_summaries): + summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2, + model_summaries=model_summaries) self.assertEquals(expected_num_summary_ops, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) @@ -82,10 +84,13 @@ class SummariesTest(test.TestCase): summary.merge_all().eval() def test_add_gan_model_image_summaries(self): - self._test_add_gan_model_image_summaries_impl(get_gan_model, 5) + self._test_add_gan_model_image_summaries_impl(get_gan_model, 5, True) + + def test_add_gan_model_image_summaries_no_model(self): + self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False) def test_add_gan_model_image_summaries_for_cyclegan(self): - self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10) + self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, True) def _test_add_gan_model_summaries_impl(self, get_model_fn, expected_num_summary_ops): diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 23a3b60cc0055917bfc5243b0ebdbaea7b61edb9..39588b7219ebac1cc4855532be3fcc38e6381134 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -305,6 +305,7 @@ def wasserstein_gradient_penalty( discriminator_fn, discriminator_scope, epsilon=1e-10, + target=1.0, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -324,6 +325,8 @@ def wasserstein_gradient_penalty( discriminator_scope: If not `None`, reuse discriminators from this scope. epsilon: A small positive number added for numerical stability when computing the gradient norm. + target: Optional Python number or `Tensor` indicating the target value of + gradient norm. Defaults to 1.0. weights: Optional `Tensor` whose rank is either 0, or the same rank as `real_data` and `generated_data`, and must be broadcastable to them (i.e., all dimensions must be either `1`, or the same as the @@ -374,7 +377,7 @@ def wasserstein_gradient_penalty( # For numerical stability, add epsilon to the sum before taking the square # root. Note tf.norm does not add epsilon. slopes = math_ops.sqrt(gradient_squares + epsilon) - penalties = math_ops.square(slopes - 1.0) + penalties = math_ops.square(slopes / target - 1.0) penalty = losses.compute_weighted_loss( penalties, weights, scope=scope, loss_collection=loss_collection, reduction=reduction) diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 7d2a7a254f6656198e47325dbb351618d85d147c..dbaa624ae9d6a5a5949db692e52c0c1deb18b8df 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -481,6 +481,29 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): }) self.assertAlmostEqual(self._expected_loss, loss, 5) + def test_loss_with_gradient_norm_target(self): + """Test loss value with non default gradient norm target.""" + generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + real_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + + loss = tfgan_losses.wasserstein_gradient_penalty( + generated_data, + real_data, + self._kwargs['generator_inputs'], + self._kwargs['discriminator_fn'], + self._kwargs['discriminator_scope'], + target=2.0) + + with self.test_session() as sess: + variables.global_variables_initializer().run() + loss = sess.run( + loss, + feed_dict={ + generated_data: self._generated_data_np, + real_data: self._real_data_np, + }) + self.assertAlmostEqual(1.0, loss, 5) + def test_reuses_scope(self): """Test that gradient penalty reuses discriminator scope.""" num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) @@ -620,7 +643,7 @@ class CombineAdversarialLossTest(test.TestCase): with self.test_session(use_gpu=True) as sess: for _ in range(10): # spot check closeness on more than one sample. gnorm_np, precond_gnorm_np = sess.run([gnorm, precond_gnorm]) - self.assertNear(gnorm_np, precond_gnorm_np, 1e-5) + self.assertNear(gnorm_np, precond_gnorm_np, 1e-4) class CycleConsistencyLossTest(test.TestCase): diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 5d0ac93aec7869bb1d9b8a174ba50d4bec2c2826..776eb11ecb1624544d24611d8fe6ca19768b8313 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -460,6 +460,7 @@ def gan_loss( # Auxiliary losses. gradient_penalty_weight=None, gradient_penalty_epsilon=1e-10, + gradient_penalty_target=1.0, mutual_information_penalty_weight=None, aux_cond_generator_weight=None, aux_cond_discriminator_weight=None, @@ -481,6 +482,9 @@ def gan_loss( small positive value used by the gradient penalty function for numerical stability. Note some applications will need to increase this value to avoid NaNs. + gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python + number or `Tensor` indicating the target value of gradient norm. See the + CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0. mutual_information_penalty_weight: If not `None`, must be a non-negative Python number or Tensor indicating how much to weight the mutual information penalty. See https://arxiv.org/abs/1606.03657 for more @@ -539,7 +543,10 @@ def gan_loss( # Add optional extra losses. if _use_aux_loss(gradient_penalty_weight): gp_loss = tfgan_losses.wasserstein_gradient_penalty( - model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries) + model, + epsilon=gradient_penalty_epsilon, + target=gradient_penalty_target, + add_summaries=add_summaries) dis_loss += gradient_penalty_weight * gp_loss if _use_aux_loss(mutual_information_penalty_weight): info_loss = tfgan_losses.mutual_information_penalty( diff --git a/tensorflow/contrib/gdr/README.md b/tensorflow/contrib/gdr/README.md index 34ce60b360822888aa6223c89362ae1b0d9d991f..8242d93f129904828a11b61d48f2df8fb0f88bc3 100644 --- a/tensorflow/contrib/gdr/README.md +++ b/tensorflow/contrib/gdr/README.md @@ -119,4 +119,4 @@ In the original design (as in the reference), tensor buffers are only registered Reference === -Bairen Yi, Jiacheng Xia, Li Chen, and Kai Chen. 2017. Towards Zero Copy Dataflows using RDMA. In Proceedings of SIGCOMM Posters and Demos'17, Los Angeles, CA, USA, August 22-24, 2017, 3 pages. https://doi.org/10.1145/3123878.3123907 +Bairen Yi, Jiacheng Xia, Li Chen, and Kai Chen. 2017. Towards Zero Copy Dataflows using RDMA. In Proceedings of SIGCOMM Posters and Demos'17, Los Angeles, CA, USA, August 22-24, 2017, 3 pages. https://doi.org/10.1145/3123878.3131975 diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 5c7ac744289ab7729b4cc43ab9bedc9342284e65..81e70ae30a4c72dbcedd1aabfe758ecca4c8b366 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -86,8 +86,9 @@ int TryToReadNumaNode(ibv_device* device) { if (strings::safe_strto32(content, &value)) { if (value < 0) { LOG(INFO) << "Successful NUMA node read from SysFS had negative value (" - << value << "), but there must be at least one NUMA node" - ", so returning NUMA node zero"; + << value + << "), but there must be at least one NUMA node" + ", so returning NUMA node zero"; return 0; } LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; @@ -290,8 +291,8 @@ Status GdrMemoryManager::Init() { // Host memory allocators for (Allocator* allocator : allocators) { auto* visitable_allocator = dynamic_cast(allocator); - CHECK(visitable_allocator) << "is not visitable for instrumentation" - << allocator->Name(); + CHECK(visitable_allocator) + << "is not visitable for instrumentation" << allocator->Name(); // Make sure we don't instrument the same allocator twice if (instrumented_.find(allocator) == std::end(instrumented_)) { visitable_allocator->AddAllocVisitor(alloc_visitor); @@ -635,8 +636,8 @@ void GdrMemoryManager::TensorFromTransportOptions( } else { checksum = GPUUtil::Checksum(*tensor); } - CHECK(checksum == remote_mr.checksum()) << "Checksum mismatch: " << checksum - << "!=" << remote_mr.checksum(); + CHECK(checksum == remote_mr.checksum()) + << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); #endif } done(Status::OK()); diff --git a/tensorflow/contrib/hvx/README.md b/tensorflow/contrib/hvx/README.md index 5a6f2f3086d708e5264b0483c211902ac8dce5f6..163993a3f6bb1bedcdffb32944a98c7cc846878e 100644 --- a/tensorflow/contrib/hvx/README.md +++ b/tensorflow/contrib/hvx/README.md @@ -1,60 +1,67 @@ # TensorFlow Runtime with HVX Acceleration -## Description +This README explain how to build and use the TensorFlow runtime with HVX Acceleration. HVX is an extension of Hexagon, a DSP provided by Qualcomm, which can compute vector calculations faster using less energy than ARM processors. -This README explain how to build and use the TensorFlow Runtime with HVX Acceleration. HVX is an extension of Hexagon which is a DSP provided by qualcomm which can compute vector calculations faster using lower energy than ARM processors. +## Dependencies + +* [Android SDK](https://developer.android.com/studio/index.html). +* [Android NDK](https://developer.android.com/ndk/index.html). Save the path in `${NDK_ROOT}`. +* A rooted Qualcomm-based Android device connected to the computer (preferably, a [Snapdragon Development Board](https://developer.qualcomm.com/hardware/additional-snapdragon), but it could be a rooted phone with a Qualcomm SoC, albeit this guide may not work with it). The device needs to be rooted for development and testing purposes, and shouldn't be needed in production. See [Behold, The Snapdragon MDP](https://developer.qualcomm.com/blog/behold-snapdragon-mdp) for more information. +* [Hexagon SDK v3.0](https://developer.qualcomm.com/software/hexagon-dsp-sdk/tools). Save the path in `${QUALCOMM_SDK}`. +* The current directory should be TensorFlow source code (`git clone https://github.com/tensorflow/tensorflow.git && cd tensorflow`), and saved into `${TF_ROOT_DIR}`. + +You may also need to add a test signature in the device to run HVX-based binaries. Follow the instructions in `${QUALCOMM_SDK}/docs/Tools_Signing.html`, using Python 2. + +Note that if the device is not rooted, you may not be able to get the serial number, push the test signature and/or run binary files that call HVX libraries. ## Quick Start Guide -We provides several tools to build and run inference with this runtime quickly. +We provide several tools to build and run inference with this runtime quickly. -#### All-in-one script to run inception model with prebuild hexagon library -If you don’t need to build your own implementation of hexagon HVX, we provide a shortcut to execute graphs by using pre-compiled binaries. +### Run inception model with a prebuilt Hexagon library +If you don’t need to build your own implementation of Hexagon HVX, we provide a shortcut to execute graphs by using pre-compiled binaries. + +```shell +./tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh -p ``` -git clone https://github.com/tensorflow/tensorflow.git -cd tensorflow -NDK_ROOT="/path/to/ndk" ./tensorflow/contrib/makefile/build_all_android.sh -X -``` -(-X downloads dependencies to hexagon HVX and graphs, and copy all dependencies to android and execute a test) -#### All-in-one script to run inception model by building entire libraries from source code - If you want to build your own implementation of hexagon HVX, we provide a sample all-in-one script to execute graphs which downloads source and build everything for hexagon. +The `-p` option makes the script download dependencies (i.e., Hexagon HVX binaries and graphs models), copy them to the Android device and execute a test. -``` -git clone https://github.com/tensorflow/tensorflow.git -cd tensorflow -QUALCOMM_SDK="/path/to/qualcomm/sdk" NDK_ROOT="/path/to/ndk" ./tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh +### Run inception model by building all from the source code + +If you want to build your own implementation of Hexagon HVX, we provide a sample all-in-one script to execute graphs which downloads the source and builds everything that's necessary. + +```shell +./tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh ``` ## Building libraries If you've finished walking through the quick start guide, you may want to try building each binary manually. -#### Build libhexagon_nn_skel.so -Download hexagon nn library from codeaurora.org and build it. +### Build libhexagon\_nn\_skel.so -``` +Download Hexagon NN library from codeaurora.org and build it. + +```shell git clone https://source.codeaurora.org/quic/hexagon_nn/nnlib cd nnlib ``` -(Just follow instructions in README.HOW_TO_BUILD. You can find libhexagon_nn_skel.so in hexagon_Release_dynamic_toolv72_v60/ship) -Then copy the generated binary to GEN_LIBS_DIR +Just follow the instructions in `README.HOW_TO_BUILD`. You can find the file `libhexagon_nn_skel.so` in `hexagon_Release_dynamic_toolv72_v60/ship`. +Then copy the generated binary to `${GEN_LIBS_DIR}`. -``` +```shell GEN_LIBS_DIR="/path/to/a/dir/to/store/hexagon/libraries" cp -v "hexagon_Release_dynamic_toolv72_v60/ship/libhexagon_nn_skel.so" "${GEN_LIBS_DIR}" ``` -#### Build libhexagon_controller.so +### Build libhexagon\_controller.so + Download tensorflow and build hexagon controller. -``` -git clone https://github.com/tensorflow/tensorflow.git -cd tensorflow -TF_ROOT_DIR="$(pwd)" -QUALCOMM_SDK="/path/to/qualcomm/sdk" +```shell GENERATED_NNLIB_DIRECTORY="/path/to/nnlib" GENERATED_HEXAGON_CONTROLLER_DIRECTORY="${QUALCOMM_SDK}/examples/common/generated_hexagon_controller" rm -rf "${GENERATED_HEXAGON_CONTROLLER_DIRECTORY}" @@ -70,12 +77,12 @@ make tree VERBOSE=1 V=android_Release cp -v "${GENERATED_HEXAGON_CONTROLLER_DIRECTORY}/android_Release/ship/libhexagon_controller.so" "${GEN_LIBS_DIR}" ``` -#### Build tensorflow linking hexagon library -Build tensorflow with the build_all_android.sh with specifying -x option. +### Build TensorFlow linking Hexagon library -``` +Build TensorFlow with `build_all_android.sh` specifying the `-x` option. + +```shell BUILD_ALL_ANDROID_PATH="${TF_ROOT_DIR}/tensorflow/contrib/makefile/build_all_android.sh" -NDK_ROOT="/path/to/ndk/root" CC_PREFIX=${CC_PREFIX} NDK_ROOT=${NDK_ROOT} "${BUILD_ALL_ANDROID_PATH}" \ -x "${GEN_LIBS_DIR}" \ @@ -83,11 +90,11 @@ CC_PREFIX=${CC_PREFIX} NDK_ROOT=${NDK_ROOT} "${BUILD_ALL_ANDROID_PATH}" \ -t hexagon_graph_execution ``` -#### Push binaries to your Android device +### Push binaries to your Android device Before running tests on your Android device, you need to push several binaries to it. -``` +```shell adb push "${GEN_LIBS_DIR}/libhexagon_controller.so" "/data/local/tmp" adb push "${GEN_LIBS_DIR}/libhexagon_nn_skel.so" "/vendor/lib/rfsa/adsp" adb push -p \ @@ -100,40 +107,54 @@ adb shell chmod "${ANDROID_EXEC_FILE_MODE}" \ adb wait-for-device ``` -#### Run tests on the device +### Run tests on the device Finally, you can run the inference tests on your device. -``` +```shell adb shell 'LD_LIBRARY_PATH=/data/local/tmp:$LD_LIBRARY_PATH' \ "/data/local/tmp/hexagon_graph_execution" ``` -#### Troubleshooting -If you're using the Open-Q 820 Snapdragon development kit, you may run into an issue with running the executable due to a missing testsig library. From the Hexagon SDK documentation: *Dynamic shared objects are required to be digitally signed and then authenticated at runtime before they are allowed to be loaded and executed.* Generating a testsig library is necessary to run the unsigned sample library built from this project. +### Troubleshooting + +#### Testsig issue + +If you're using the Open-Q 820 Snapdragon Development Kit, you may run into an issue with running the executable due to a missing `testsig` library. From the Hexagon SDK documentation: *Dynamic shared objects are required to be digitally signed and then authenticated at runtime before they are allowed to be loaded and executed.* Generating a testsig library is necessary to run the unsigned sample library built from this project. -If the lack of a testsig library is your problem, you will see errors of the type: +If the lack of a `testsig` library is your problem, you will see errors of the type: `vendor/qcom/proprietary/adsprpc/src/fastrpc_apps_user.c:169::error: -1: 0 == (nErr = remotectl_open(name, (int*)ph, dlerrstr, sizeof(dlerrstr), &dlerr))` -appearing in adb logcat. - -There are several ways to create the testsig library, the only prerequisite is Python and the correct version of the Hexagon-SDK. The following steps is one way to create this library: -1. Run adb as root: `adb root` -2. Run the command `adb shell cat /sys/devices/soc0/serial_number` -3. Convert the decimal number you get as output to hex -4. Run the python script: `python ${QUALCOMM_SDK}/tools/elfsigner/elfsigner.py -t $(SERIAL_NUMBER_HEX_VALUE)` -5. The output of the python script is a shared library stored in ${QUALCOMM_SDK}/tools/elfsigner/output/testsig-$(SERIAL_NUMBER_HEX_VALUE).so -6. Push the shared library to your device: +appearing in `adb logcat` or ["Expected: (version) >= (1), actual: 0 vs 1" while running a binary from adb](https://github.com/tensorflow/tensorflow/issues/11210). + +You need to add a test signature, as described at the beginning of this README. After rebooting your device, you should be able to run the sample application. + +#### Qualcomm SDK Linux installation fails with "Malformed \uxxxx encoding" + +The installation file is based on LaunchAnywhere, which fails in Linux if the `PS1` env variable contains non-common Unicode chars: + ``` -adb root -adb wait-for-device -adb remount -adb wait-for-device -adb shell mkdir /system/lib/rfsa -adb shell mkdir /system/lib/rfsa/adsp -adb push ${QUALCOMM_SDK}/tools/elfsigner/output/testsig-$(SERIAL_NUMBER_HEX_VALUE).so /system/lib/rfsa/adsp/ +Preparing to install... +Extracting the JRE from the installer archive... +Unpacking the JRE... +Extracting the installation resources from the installer archive... +Configuring the installer for this system's environment... + +Launching installer... + +An internal LaunchAnywhere application error has occurred and this application cannot proceed. (LAX) + +Stack Trace: +java.lang.IllegalArgumentException: Malformed \uxxxx encoding. + at java.util.Properties.loadConvert(Properties.java:574) + at java.util.Properties.load0(Properties.java:391) + at java.util.Properties.load(Properties.java:317) + at com.zerog.common.java.util.PropertiesUtil.loadProperties(Unknown Source) + at com.zerog.lax.LAX.(Unknown Source) + at com.zerog.lax.LAX.main(Unknown Source) ``` -After rebooting your device, you should be able to run the sample application. +It can be solved by temporarily assigning the `PS1` environment variable to something simple, such as '$'. + +## Maintainers -Maintainers: -- Satoshi Kataoka (satok@google.com, github.com/satok16) +* Satoshi Kataoka (satok@google.com, github.com/satok16) diff --git a/tensorflow/contrib/image/kernels/bipartite_match_op.cc b/tensorflow/contrib/image/kernels/bipartite_match_op.cc index 7d207c388b159c4ad0f25032811e97b153fd50d6..726adb07775e3243fdc96a7f1a00dbb0304d3dd9 100644 --- a/tensorflow/contrib/image/kernels/bipartite_match_op.cc +++ b/tensorflow/contrib/image/kernels/bipartite_match_op.cc @@ -85,7 +85,7 @@ class BipartiteMatchOp : public OpKernel { context->allocate_output(1, TensorShape({num_input_columns}), &column_to_row_match_indices)); - typename TTypes::ConstTensor distance_mat = + TTypes::ConstTensor distance_mat = input_distance_mat.shaped( {num_input_rows, num_input_columns}); diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index 6adf837ca0ab506bd18f5e2e1fc1847e31d782bf..c2e32da133b32c8fe169302668031af8bace2c22 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -43,9 +43,9 @@ template struct FillProjectiveTransform; typedef Eigen::ThreadPoolDevice CPUDevice; using functor::FillProjectiveTransform; +using generator::Interpolation; using generator::INTERPOLATION_BILINEAR; using generator::INTERPOLATION_NEAREST; -using generator::Interpolation; using generator::ProjectiveGenerator; template @@ -72,11 +72,12 @@ class ImageProjectiveTransform : public OpKernel { const Tensor& transform_t = ctx->input(1); OP_REQUIRES(ctx, images_t.shape().dims() == 4, errors::InvalidArgument("Input images must have rank 4")); - OP_REQUIRES(ctx, (TensorShapeUtils::IsMatrix(transform_t.shape()) && - (transform_t.dim_size(0) == images_t.dim_size(0) || - transform_t.dim_size(0) == 1) && - transform_t.dim_size(1) == - ProjectiveGenerator::kNumParameters), + OP_REQUIRES(ctx, + (TensorShapeUtils::IsMatrix(transform_t.shape()) && + (transform_t.dim_size(0) == images_t.dim_size(0) || + transform_t.dim_size(0) == 1) && + transform_t.dim_size(1) == + ProjectiveGenerator::kNumParameters), errors::InvalidArgument( "Input transform should be num_images x 8 or 1 x 8")); auto images = images_t.tensor(); diff --git a/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc index 9f0bf37aed3fc9aeefb7602ef3fda4cfd76f1917..8f9a5c28039b74a874028826ca8a6d5a36ab7cf4 100755 --- a/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc @@ -143,8 +143,8 @@ class SingleImageRandomDotStereogramsOp : public OpKernel { } data_box_left = deltaX_border_image / 2; // Center DATA in X dimension - data_box_width = data_Xwindow; // width of scan line - data_box_height = data_Ywindow; // hight of image + data_box_width = data_Xwindow; // width of scan line + data_box_height = data_Ywindow; // hight of image const T* inputZ = input_tensor.flat().data(); // Flatten input Z buffer diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc index 1f41f243f2ebc0d1e884728defa160bf6d6c34ce..8139d4272d6950815bd39a64e86e0f7422e6f799 100755 --- a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc @@ -58,7 +58,9 @@ REGISTER_OP("SingleImageRandomDotStereograms") int colors; TF_RETURN_IF_ERROR(c->GetAttr("number_colors", &colors)); - c->set_output(0, c->MakeShape({y_dim, x_dim, colors > 256? c->MakeDim(3) : c->MakeDim(1)})); + c->set_output( + 0, c->MakeShape( + {y_dim, x_dim, colors > 256 ? c->MakeDim(3) : c->MakeDim(1)})); return Status::OK(); }) .Doc(R"doc( diff --git a/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py index bf0c97245fc5c70469350ec66023f4d1474930e2..3f4029e558d92a2b6539456bf9cf49ec2d21c9f3 100644 --- a/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py @@ -18,13 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from six.moves import xrange # pylint: disable=redefined-builtin - from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms \ import single_image_random_dot_stereograms -from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index 63377ae50310db51a3111c5a6e00df7d75dccc0b..c139ae89d8d682d6b87813c3a21703ffa762f28e 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -40,7 +40,7 @@ ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) def rotate(images, angles, interpolation="NEAREST", name=None): - """Rotate image(s) by the passed angle(s) in radians. + """Rotate image(s) counterclockwise by the passed angle(s) in radians. Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) @@ -290,31 +290,76 @@ def compose_transforms(*transforms): """ assert transforms, "transforms cannot be empty" with ops.name_scope("compose_transforms"): - composed = _flat_transforms_to_matrices(transforms[0]) + composed = flat_transforms_to_matrices(transforms[0]) for tr in transforms[1:]: # Multiply batches of matrices. - composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr)) - return _transform_matrices_to_flat(composed) + composed = math_ops.matmul(composed, flat_transforms_to_matrices(tr)) + return matrices_to_flat_transforms(composed) -def _flat_transforms_to_matrices(transforms): - # Make the transform(s) 2D in case the input is a single transform. - transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8])) - num_transforms = array_ops.shape(transforms)[0] - # Add a column of ones for the implicit last entry in the matrix. - return array_ops.reshape( - array_ops.concat( - [transforms, array_ops.ones([num_transforms, 1])], axis=1), - constant_op.constant([-1, 3, 3])) +def flat_transforms_to_matrices(transforms): + """Converts `tf.contrib.image` projective transforms to affine matrices. + Note that the output matrices map output coordinates to input coordinates. For + the forward transformation matrix, call `tf.linalg.inv` on the result. -def _transform_matrices_to_flat(transform_matrices): - # Flatten each matrix. - transforms = array_ops.reshape(transform_matrices, - constant_op.constant([-1, 9])) - # Divide each matrix by the last entry (normally 1). - transforms /= transforms[:, 8:9] - return transforms[:, :8] + Args: + transforms: Vector of length 8, or batches of transforms with shape + `(N, 8)`. + + Returns: + 3D tensor of matrices with shape `(N, 3, 3)`. The output matrices map the + *output coordinates* (in homogeneous coordinates) of each transform to the + corresponding *input coordinates*. + + Raises: + ValueError: If `transforms` have an invalid shape. + """ + with ops.name_scope("flat_transforms_to_matrices"): + transforms = ops.convert_to_tensor(transforms, name="transforms") + if transforms.shape.ndims not in (1, 2): + raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms) + # Make the transform(s) 2D in case the input is a single transform. + transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8])) + num_transforms = array_ops.shape(transforms)[0] + # Add a column of ones for the implicit last entry in the matrix. + return array_ops.reshape( + array_ops.concat( + [transforms, array_ops.ones([num_transforms, 1])], axis=1), + constant_op.constant([-1, 3, 3])) + + +def matrices_to_flat_transforms(transform_matrices): + """Converts affine matrices to `tf.contrib.image` projective transforms. + + Note that we expect matrices that map output coordinates to input coordinates. + To convert forward transformation matrices, call `tf.linalg.inv` on the + matrices and use the result here. + + Args: + transform_matrices: One or more affine transformation matrices, for the + reverse transformation in homogeneous coordinates. Shape `(3, 3)` or + `(N, 3, 3)`. + + Returns: + 2D tensor of flat transforms with shape `(N, 8)`, which may be passed into + `tf.contrib.image.transform`. + + Raises: + ValueError: If `transform_matrices` have an invalid shape. + """ + with ops.name_scope("matrices_to_flat_transforms"): + transform_matrices = ops.convert_to_tensor( + transform_matrices, name="transform_matrices") + if transform_matrices.shape.ndims not in (2, 3): + raise ValueError( + "Matrices should be 2D or 3D, got: %s" % transform_matrices) + # Flatten each matrix. + transforms = array_ops.reshape(transform_matrices, + constant_op.constant([-1, 9])) + # Divide each matrix by the last entry (normally 1). + transforms /= transforms[:, 8:9] + return transforms[:, :8] @ops.RegisterGradient("ImageProjectiveTransform") @@ -346,9 +391,9 @@ def _image_projective_transform_grad(op, grad): raise TypeError("Transforms should have rank 1 or 2.") # Invert transformations - transforms = _flat_transforms_to_matrices(transforms=transforms) + transforms = flat_transforms_to_matrices(transforms=transforms) inverse = linalg_ops.matrix_inverse(transforms) - transforms = _transform_matrices_to_flat(inverse) + transforms = matrices_to_flat_transforms(inverse) output = gen_image_ops.image_projective_transform( grad, transforms, interpolation=interpolation) if len(image_or_images.get_shape()) == 2: diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index bb766e59d2cee648042cc08be466796d9233ad66..d4a6a5bcbb52511d4093587814100b2a0e8b2420 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -26,18 +26,20 @@ _sirds_ops = loader.load_op_library( resource_loader.get_path_to_datafile( "_single_image_random_dot_stereograms.so")) -def single_image_random_dot_stereograms( - depth_values, - hidden_surface_removal=None, - convergence_dots_size=None, - dots_per_inch=None, - eye_separation=None, mu=None, - normalize=None, normalize_max=None, - normalize_min=None, - border_level=None, - number_colors=None, - output_image_shape=None, - output_data_window=None): + +def single_image_random_dot_stereograms(depth_values, + hidden_surface_removal=None, + convergence_dots_size=None, + dots_per_inch=None, + eye_separation=None, + mu=None, + normalize=None, + normalize_max=None, + normalize_min=None, + border_level=None, + number_colors=None, + output_image_shape=None, + output_data_window=None): """Output a RandomDotStereogram Tensor for export via encode_PNG/JPG OP. Given the 2-D tensor 'depth_values' with encoded Z values, this operation @@ -45,7 +47,8 @@ def single_image_random_dot_stereograms( for the encode_PNG/JPG ops. Be careful with image compression as this may corrupt the encode 3-D data witin the image. - Based upon [this paper](http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper). + Based upon [this + paper](http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper). This outputs a SIRDS image as picture_out.png: @@ -113,7 +116,8 @@ def single_image_random_dot_stereograms( hidden_surface_removal=hidden_surface_removal, convergence_dots_size=convergence_dots_size, dots_per_inch=dots_per_inch, - eye_separation=eye_separation, mu=mu, + eye_separation=eye_separation, + mu=mu, normalize=normalize, normalize_max=normalize_max, normalize_min=normalize_min, @@ -123,4 +127,5 @@ def single_image_random_dot_stereograms( output_data_window=output_data_window) return result + ops.NotDifferentiable("SingleImageRandomDotStereograms") diff --git a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc index ca288c1f737d25faac678f5c199d5c1e49f721cb..886f6798150c57d8066546b0919481d3878882fc 100644 --- a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc +++ b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc @@ -34,9 +34,8 @@ class ObtainNextOp : public OpKernel { // Allocate output. Tensor* output_tensor = nullptr; - OP_REQUIRES_OK( - ctx, - ctx->allocate_output("out_element", TensorShape({}), &output_tensor)); + OP_REQUIRES_OK(ctx, ctx->allocate_output("out_element", TensorShape({}), + &output_tensor)); // Obtain mutex for the "counter" tensor. mutex* mu; diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..efb403462a6e5df5b69ac0735ffc03f40d4a252c --- /dev/null +++ b/tensorflow/contrib/kafka/BUILD @@ -0,0 +1,105 @@ +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +tf_kernel_library( + name = "kafka_kernels", + srcs = ["kernels/kafka_dataset_ops.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:bounds_check_lib", + "//tensorflow/core/kernels:dataset", + "//third_party/eigen3", + "@kafka", + ], +) + +tf_gen_op_libs( + op_lib_names = ["kafka_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_kafka_ops", + out = "python/ops/gen_kafka_ops.py", + require_shape_functions = True, + deps = [":kafka_ops_op_lib"], +) + +py_library( + name = "kafka", + srcs = [ + "__init__.py", + "python/ops/kafka_dataset_ops.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":gen_kafka_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", + ], +) + +# The Kafka server has to be setup before running the test. +# The Kafka server is setup through Docker so the Docker engine +# has to be installed. +# +# Once the Docker engine is ready: +# To setup the Kafka server: +# $ bash tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh start kafka +# +# After the test is complete: +# To team down the Kafka server: +# $ bash tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh stop kafka +tf_py_test( + name = "kafka_test", + srcs = ["python/kernel_tests/kafka_test.py"], + additional_deps = [ + ":kafka", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "notap", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/ndlstm/python/__init__.py b/tensorflow/contrib/kafka/__init__.py similarity index 68% rename from tensorflow/contrib/ndlstm/python/__init__.py rename to tensorflow/contrib/kafka/__init__.py index 1aa51a6ec40c042ca3c26c6b08e5bdb8a42a12bd..4d755c40568dfa2f7f6f617cf3180268837a5ca0 100644 --- a/tensorflow/contrib/ndlstm/python/__init__.py +++ b/tensorflow/contrib/kafka/__init__.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,14 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Init file, giving convenient access to all ndlstm ops.""" +"""Kafka Dataset. + +@@KafkaDataset +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=wildcard-import,g-importing-member -from tensorflow.contrib.ndlstm.python.lstm1d import * -from tensorflow.contrib.ndlstm.python.lstm2d import * -from tensorflow.contrib.ndlstm.python.misc import * -# pylint: enable=wildcard-import +from tensorflow.contrib.kafka.python.ops.kafka_dataset_ops import KafkaDataset + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "KafkaDataset", +] + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..88ef5f357113372b0a2d0cb13382ac980a61252d --- /dev/null +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -0,0 +1,321 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/dataset.h" + +#include "tensorflow/core/framework/tensor.h" + +#include "src-cpp/rdkafkacpp.h" + +namespace tensorflow { + +class KafkaDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* topics_tensor; + OP_REQUIRES_OK(ctx, ctx->input("topics", &topics_tensor)); + OP_REQUIRES( + ctx, topics_tensor->dims() <= 1, + errors::InvalidArgument("`topics` must be a scalar or a vector.")); + + std::vector topics; + topics.reserve(topics_tensor->NumElements()); + for (int i = 0; i < topics_tensor->NumElements(); ++i) { + topics.push_back(topics_tensor->flat()(i)); + } + + std::string servers = ""; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "servers", &servers)); + std::string group = ""; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "group", &group)); + bool eof = false; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "eof", &eof)); + int64 timeout = -1; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "timeout", &timeout)); + OP_REQUIRES(ctx, (timeout > 0), + errors::InvalidArgument( + "Timeout value should be large than 0, got ", timeout)); + *output = new Dataset(ctx, std::move(topics), servers, group, eof, timeout); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, std::vector topics, + const string& servers, const string& group, const bool eof, + const int64 timeout) + : GraphDatasetBase(ctx), + topics_(std::move(topics)), + servers_(servers), + group_(group), + eof_(eof), + timeout_(timeout) {} + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Kafka")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() override { return "KafkaDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + Node* topics = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(topics_, &topics)); + Node* servers = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(servers_, &servers)); + Node* group = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(group_, &group)); + Node* eof = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(eof_, &eof)); + Node* timeout = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(timeout_, &timeout)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {topics, servers, group, eof, timeout}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + // We are currently processing a topic, so try to read the next line. + if (consumer_.get()) { + while (true) { + if (limit_ >= 0 && + (topic_partition_->offset() >= limit_ || offset_ >= limit_)) { + // EOF current topic + break; + } + std::unique_ptr message( + consumer_->consume(dataset()->timeout_)); + if (message->err() == RdKafka::ERR_NO_ERROR) { + // Produce the line as output. + Tensor line_tensor(cpu_allocator(), DT_STRING, {}); + line_tensor.scalar()() = + std::string(static_cast(message->payload()), + message->len()); + out_tensors->emplace_back(std::move(line_tensor)); + *end_of_sequence = false; + // Sync offset + offset_ = message->offset(); + return Status::OK(); + } + + if (message->err() == RdKafka::ERR__PARTITION_EOF && + dataset()->eof_) { + // EOF current topic + break; + } + if (message->err() != RdKafka::ERR__TIMED_OUT) { + return errors::Internal("Failed to consume:", + message->errstr()); + } + message.reset(nullptr); + consumer_->poll(0); + } + + // We have reached the end of the current topic, so maybe + // move on to next topic. + ResetStreamsLocked(); + ++current_topic_index_; + } + + // Iteration ends when there are no more topic to process. + if (current_topic_index_ == dataset()->topics_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_topic_index"), + current_topic_index_)); + + // `consumer_` is empty if + // 1. GetNext has not been called even once. + // 2. All topics have been read and iterator has been exhausted. + if (consumer_.get()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("current_pos"), offset_)); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + ResetStreamsLocked(); + int64 current_topic_index; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_topic_index"), + ¤t_topic_index)); + current_topic_index_ = size_t(current_topic_index); + // The key "current_pos" is written only if the iterator was saved + // with an open topic. + if (reader->Contains(full_name("current_pos"))) { + int64 current_pos; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("current_pos"), ¤t_pos)); + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + topic_partition_->set_offset(current_pos); + if (topic_partition_->offset() != current_pos) { + return errors::Internal("Failed to restore to offset ", + current_pos); + } + offset_ = current_pos; + } + return Status::OK(); + } + + private: + // Sets up Kafka streams to read from the topic at + // `current_topic_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_topic_index_ >= dataset()->topics_.size()) { + return errors::InvalidArgument( + "current_topic_index_:", current_topic_index_, + " >= topics_.size():", dataset()->topics_.size()); + } + + // Actually move on to next topic. + string entry = dataset()->topics_[current_topic_index_]; + + std::vector parts = str_util::Split(entry, ":"); + if (parts.size() < 1) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + string topic = parts[0]; + int32 partition = 0; + if (parts.size() > 1) { + if (!strings::safe_strto32(parts[1], &partition)) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + } + int64 offset = 0; + if (parts.size() > 2) { + if (!strings::safe_strto64(parts[2], &offset)) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + } + + topic_partition_.reset( + RdKafka::TopicPartition::create(topic, partition, offset)); + + offset_ = topic_partition_->offset(); + limit_ = -1; + if (parts.size() > 3) { + if (!strings::safe_strto64(parts[3], &limit_)) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + } + + std::unique_ptr conf( + RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL)); + std::unique_ptr topic_conf( + RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC)); + + std::string errstr; + + RdKafka::Conf::ConfResult result = + conf->set("default_topic_conf", topic_conf.get(), errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("Failed to set default_topic_conf:", errstr); + } + + result = conf->set("bootstrap.servers", dataset()->servers_, errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("Failed to set bootstrap.servers ", + dataset()->servers_, ":", errstr); + } + result = conf->set("group.id", dataset()->group_, errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("Failed to set group.id ", dataset()->group_, + ":", errstr); + } + + consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr)); + if (!consumer_.get()) { + return errors::Internal("Failed to create consumer:", errstr); + } + + std::vector partitions; + partitions.emplace_back(topic_partition_.get()); + RdKafka::ErrorCode err = consumer_->assign(partitions); + if (err != RdKafka::ERR_NO_ERROR) { + return errors::Internal( + "Failed to assign partition [", topic_partition_->topic(), ", ", + topic_partition_->partition(), ", ", topic_partition_->offset(), + "]:", RdKafka::err2str(err)); + } + + return Status::OK(); + } + + // Resets all Kafka streams. + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + consumer_->unassign(); + consumer_->close(); + consumer_.reset(nullptr); + } + + mutex mu_; + size_t current_topic_index_ GUARDED_BY(mu_) = 0; + int64 offset_ GUARDED_BY(mu_) = 0; + int64 limit_ GUARDED_BY(mu_) = -1; + std::unique_ptr topic_partition_ GUARDED_BY(mu_); + std::unique_ptr consumer_ GUARDED_BY(mu_); + }; + + const std::vector topics_; + const std::string servers_; + const std::string group_; + const bool eof_; + const int64 timeout_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("KafkaDataset").Device(DEVICE_CPU), + KafkaDatasetOp); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/kafka_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cdf16103bab2b22d51c144d21a589e1e39f2f0b --- /dev/null +++ b/tensorflow/contrib/kafka/ops/kafka_ops.cc @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("KafkaDataset") + .Input("topics: string") + .Input("servers: string") + .Input("group: string") + .Input("eof: bool") + .Input("timeout: int64") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that emits the messages of one or more Kafka topics. + +topics: A `tf.string` tensor containing one or more subscriptions, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. +servers: A list of bootstrap servers. +group: The consumer group id. +eof: If True, the kafka reader will stop on EOF. +timeout: The timeout value for the Kafka Consumer to wait + (in millisecond). +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py new file mode 100644 index 0000000000000000000000000000000000000000..621911876fc502ece76b08eb6c28697b3c12c863 --- /dev/null +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py @@ -0,0 +1,115 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for KafkaDataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class KafkaDatasetTest(test.TestCase): + + def setUp(self): + # The Kafka server has to be setup before the test + # and tear down after the test manually. + # The docker engine has to be installed. + # + # To setup the Kafka server: + # $ bash kafka_test.sh start kafka + # + # To team down the Kafka server: + # $ bash kafka_test.sh stop kafka + pass + + def testKafkaDataset(self): + topics = array_ops.placeholder(dtypes.string, shape=[None]) + num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_dataset_ops.KafkaDataset( + topics, group="test", eof=True).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.test_session() as sess: + # Basic test: read from topic 0. + sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual("D" + str(i), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from topic 1. + sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1}) + for i in range(5): + self.assertEqual("D" + str(i + 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from both topics. + sess.run( + init_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 1 + }) + for j in range(2): + for i in range(5): + self.assertEqual("D" + str(i + j * 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test repeated iteration through both files. + sess.run( + init_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10 + }) + for _ in range(10): + for j in range(2): + for i in range(5): + self.assertEqual("D" + str(i + j * 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test batched and repeated iteration through both files. + sess.run( + init_batch_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10, + batch_size: 5 + }) + for _ in range(10): + self.assertAllEqual(["D" + str(i) for i in range(5)], + sess.run(get_next)) + self.assertAllEqual(["D" + str(i + 5) for i in range(5)], + sess.run(get_next)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..adf027b8e714124cde2b4618546e20c6b7162e1f --- /dev/null +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e +set -o pipefail + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 start|stop " >&2 + exit 1 +fi + +container=$2 +if [ "$1" == "start" ]; then + docker run -d --rm --net=host --name=$container spotify/kafka + echo Wait 5 secs until kafka is up and running + sleep 5 + echo Create test topic + docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-topics.sh --create --zookeeper localhost:2181 --replication-factor 1 --partitions 1 --topic test' + echo Create test message + docker exec $container bash -c 'echo -e "D0\nD1\nD2\nD3\nD4\nD5\nD6\nD7\nD8\nD9" > /test' + echo Produce test message + docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-console-producer.sh --topic test --broker-list 127.0.0.1:9092 < /test' + + echo Container $container started successfully +elif [ "$1" == "stop" ]; then + docker rm -f $container + + echo Container $container stopped successfully +else + echo "Usage: $0 start|stop " >&2 + exit 1 +fi + + + diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8e51d27a342359881de072c3979a2b5a7fc034ea --- /dev/null +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -0,0 +1,74 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kafka Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kafka.python.ops import gen_kafka_ops +from tensorflow.python.data.ops.readers import Dataset +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class KafkaDataset(Dataset): + """A Kafka Dataset that consumes the message. + """ + + def __init__(self, + topics, + servers="localhost", + group="", + eof=False, + timeout=1000): + """Create a KafkaReader. + + Args: + topics: A `tf.string` tensor containing one or more subscriptions, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. + servers: A list of bootstrap servers. + group: The consumer group id. + eof: If True, the kafka reader will stop on EOF. + timeout: The timeout value for the Kafka Consumer to wait + (in millisecond). + """ + super(KafkaDataset, self).__init__() + self._topics = ops.convert_to_tensor( + topics, dtype=dtypes.string, name="topics") + self._servers = ops.convert_to_tensor( + servers, dtype=dtypes.string, name="servers") + self._group = ops.convert_to_tensor( + group, dtype=dtypes.string, name="group") + self._eof = ops.convert_to_tensor(eof, dtype=dtypes.bool, name="eof") + self._timeout = ops.convert_to_tensor( + timeout, dtype=dtypes.int64, name="timeout") + + def _as_variant_tensor(self): + return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group, + self._eof, self._timeout) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.string diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py index d38d8041ce1216dfb5af6e93984b35e71008610a..72507539f813d14064bc58f03b6db4781abc9438 100644 --- a/tensorflow/contrib/kernel_methods/python/losses_test.py +++ b/tensorflow/contrib/kernel_methods/python/losses_test.py @@ -119,19 +119,20 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testUnknownShape(self): """Result keeps same with `testZeroLossInt32Labels`""" - logits_np = np.array([[1.2, -1.4, -1.0], - [1.4, 1.8, 4.0], - [0.5, 1.8, -1.0]]) + logits_np = np.array([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], [0.5, 1.8, -1.0]]) labels_np = np.array([0, 2, 1], dtype=np.int32) - logits_shapes = [[3, 3], # batch_size, num_classes - [None, 3], - [3, None], - [None, None]] + logits_shapes = [ + [3, 3], # batch_size, num_classes + [None, 3], + [3, None], + [None, None] + ] for batch_size, num_classes in logits_shapes: with self.test_session(): - logits = array_ops.placeholder(dtypes.float32, shape=(batch_size, num_classes)) + logits = array_ops.placeholder( + dtypes.float32, shape=(batch_size, num_classes)) labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,)) loss = losses.sparse_multiclass_hinge_loss(labels, logits) result = loss.eval(feed_dict={logits: logits_np, labels: labels_np}) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py index 0f0dbb53f45dfefe69aaa9e25caf6ba0a3cf449e..87eed03888c894a04c0521d1ce5ee8975b60776b 100644 --- a/tensorflow/contrib/kfac/examples/mlp.py +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -317,7 +317,10 @@ def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): return tf.estimator.EstimatorSpec( mode=mode, loss=loss, train_op=train_op, training_hooks=hooks) + run_config = tf.estimator.RunConfig( + model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100) + # Train until input_fn() is empty with Estimator. This is a prerequisite for # TPU compatibility. - estimator = tf.estimator.Estimator(model_fn=model_fn) + estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) estimator.train(input_fn=input_fn) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index 82accd57f0c37d140238f1884fce956654d14227..fb4b3a241c1e9fd82e7bf630fd57295917048fbd 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops @@ -236,10 +237,10 @@ class NaiveDiagonalFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) -class FullyConnectedDiagonalFB(test.TestCase): +class FullyConnectedDiagonalFBTest(test.TestCase): def setUp(self): - super(FullyConnectedDiagonalFB, self).setUp() + super(FullyConnectedDiagonalFBTest, self).setUp() self.batch_size = 4 self.input_size = 6 @@ -375,6 +376,65 @@ class FullyConnectedDiagonalFB(test.TestCase): return multiply_result, multiply_inverse_result +class EmbeddingKFACFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + + # Create a Fisher Block. + vocab_size = 5 + block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) + + # Add some examples. + inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) + outputs = array_ops.constant([[0.], [1.], [2.]]) + block.register_additional_minibatch(inputs, outputs) + + # Instantiate factor's variables. Ensure it doesn't fail. + grads = outputs**2. + damping = array_ops.constant(0.) + block.instantiate_factors(([grads],), damping) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + # Create a Fisher Block. + vocab_size = 5 + block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) + + # Add some examples. + inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) + outputs = array_ops.constant([[0.], [1.], [2.]]) + block.register_additional_minibatch(inputs, outputs) + + # Instantiate factor's variables. Ensure it doesn't fail. + grads = outputs**2. + damping = array_ops.constant(0.) + block.instantiate_factors(([grads],), damping) + + # Create a sparse update. + indices = array_ops.constant([1, 3, 4]) + values = array_ops.constant([[1.], [1.], [1.]]) + sparse_vector = ops.IndexedSlices( + values, indices, dense_shape=[vocab_size, 1]) + dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1]) + + # Compare Fisher-vector product against explicit result. + result = block.multiply_inverse(sparse_vector) + expected_result = linalg_ops.matrix_solve(block.full_fisher_block(), + dense_vector) + + sess.run(tf_variables.global_variables_initializer()) + self.assertAlmostEqual( + sess.run(expected_result[1]), sess.run(result.values[0])) + self.assertAlmostEqual( + sess.run(expected_result[3]), sess.run(result.values[1])) + self.assertAlmostEqual( + sess.run(expected_result[4]), sess.run(result.values[2])) + + class FullyConnectedKFACBasicFBTest(test.TestCase): def testFullyConnectedKFACBasicFBInit(self): diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 753378d9f4a0d8762bafbee2ec27d6c71783dda1..66e18974abfadaad5d7a20b40d0b1352bfda67ee 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -89,6 +89,21 @@ class FisherFactorTestingDummy(ff.FisherFactor): def make_inverse_update_ops(self): return [] + def get_cov(self): + return NotImplementedError + + def left_multiply(self, x, damping): + return NotImplementedError + + def right_multiply(self, x, damping): + return NotImplementedError + + def left_multiply_inverse(self, x, damping): + return NotImplementedError + + def right_multiply_inverse(self, x, damping): + return NotImplementedError + class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. @@ -379,7 +394,7 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) - self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list()) def testNaiveDiagonalFactorInitFloat64(self): with tf_ops.Graph().as_default(): @@ -387,7 +402,7 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) - cov = factor.get_cov() + cov = factor.get_cov_var() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 1], cov.get_shape().as_list()) @@ -402,6 +417,29 @@ class NaiveDiagonalFactorTest(test.TestCase): self.assertAllClose([[0.75], [1.5]], new_cov) +class EmbeddingInputKroneckerFactorTest(test.TestCase): + + def testInitialization(self): + with tf_ops.Graph().as_default(): + input_ids = array_ops.constant([[0], [1], [4]]) + vocab_size = 5 + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + cov = factor.get_cov_var() + self.assertEqual(cov.shape.as_list(), [vocab_size]) + + def testCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + input_ids = array_ops.constant([[0], [1], [4]]) + vocab_size = 5 + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + cov_update_op = factor.make_covariance_update_op(0.0) + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(cov_update_op) + self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov) + + class FullyConnectedKroneckerFactorTest(test.TestCase): def _testFullyConnectedKroneckerFactorInit(self, diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 9436caf9618bc3d3c0dd7b3842420016b119464f..cf38d28b43836dced8babe2ffa7853b1c4b1b369 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -92,10 +92,22 @@ def compute_pi_tracenorm(left_cov, right_cov): Returns: The computed scalar constant pi for these Kronecker Factors (as a Tensor). """ + + def _trace(cov): + if len(cov.shape) == 1: + # Diagonal matrix. + return math_ops.reduce_sum(cov) + elif len(cov.shape) == 2: + # Full matrix. + return math_ops.trace(cov) + else: + raise ValueError( + "What's the trace of a Tensor of rank %d?" % len(cov.shape)) + # Instead of dividing by the dim of the norm, we multiply by the dim of the # other norm. This works out the same in the ratio. - left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0] - right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0] + left_norm = _trace(left_cov) * right_cov.shape.as_list()[0] + right_norm = _trace(right_cov) * left_cov.shape.as_list()[0] return math_ops.sqrt(left_norm / right_norm) @@ -201,15 +213,15 @@ class FullFB(FisherBlock): self._factor.register_damped_inverse(damping) def multiply_inverse(self, vector): - inverse = self._factor.get_damped_inverse(self._damping) - out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector)) + vector_flat = utils.tensors_to_column(vector) + out_flat = self._factor.left_multiply_inverse( + vector_flat, self._damping) return utils.column_to_tensors(vector, out_flat) def multiply(self, vector): vector_flat = utils.tensors_to_column(vector) - out_flat = ( - math_ops.matmul(self._factor.get_cov(), vector_flat) + - self._damping * vector_flat) + out_flat = self._factor.left_multiply( + vector_flat, self._damping) return utils.column_to_tensors(vector, out_flat) def full_fisher_block(self): @@ -265,16 +277,20 @@ class NaiveDiagonalFB(FisherBlock): def multiply_inverse(self, vector): vector_flat = utils.tensors_to_column(vector) - out_flat = vector_flat / (self._factor.get_cov() + self._damping) + print("vector_flat: %s" % vector_flat) + out_flat = self._factor.left_multiply_inverse( + vector_flat, self._damping) + print("out_flat: %s" % out_flat) return utils.column_to_tensors(vector, out_flat) def multiply(self, vector): vector_flat = utils.tensors_to_column(vector) - out_flat = vector_flat * (self._factor.get_cov() + self._damping) + out_flat = self._factor.left_multiply( + vector_flat, self._damping) return utils.column_to_tensors(vector, out_flat) def full_fisher_block(self): - return array_ops.diag(array_ops.reshape(self._factor.get_cov(), (-1,))) + return self._factor.get_cov() def tensors_to_compute_grads(self): return self._params @@ -356,8 +372,9 @@ class FullyConnectedDiagonalFB(FisherBlock): Tensor of the same shape, corresponding to the inverse Fisher-vector product. """ - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) + reshaped_vec = utils.layer_params_to_mat2d(vector) + reshaped_out = self._factor.left_multiply_inverse( + reshaped_vec, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def multiply(self, vector): @@ -372,8 +389,9 @@ class FullyConnectedDiagonalFB(FisherBlock): Returns: Tensor of the same shape, corresponding to the Fisher-vector product. """ - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) + reshaped_vec = utils.layer_params_to_mat2d(vector) + reshaped_out = self._factor.left_multiply( + reshaped_vec, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def tensors_to_compute_grads(self): @@ -457,7 +475,9 @@ class ConvDiagonalFB(FisherBlock): self._num_locations = ( inputs_shape[1] * inputs_shape[2] // (self._strides[1] * self._strides[2])) - self._damping = normalize_damping(damping, self._num_locations) + + self._damping = (self._num_locations + * normalize_damping(damping, self._num_locations)) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvDiagonalFactor, @@ -466,12 +486,14 @@ class ConvDiagonalFB(FisherBlock): def multiply_inverse(self, vector): reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) + reshaped_out = self._factor.left_multiply_inverse( + reshaped_vect, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def multiply(self, vector): reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) + reshaped_out = self._factor.left_multiply( + reshaped_vect, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def tensors_to_compute_grads(self): @@ -531,28 +553,24 @@ class KroneckerProductFB(FisherBlock): return 1.0 def multiply_inverse(self, vector): - left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping) - right_factor_inv = self._output_factor.get_damped_inverse( - self._output_damping) reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = math_ops.matmul(left_factor_inv, - math_ops.matmul(reshaped_vector, - right_factor_inv)) + reshaped_out = self._output_factor.right_multiply_inverse( + reshaped_vector, + self._output_damping) + reshaped_out = self._input_factor.left_multiply_inverse( + reshaped_out, self._input_damping) if self._renorm_coeff != 1.0: reshaped_out /= math_ops.cast( self._renorm_coeff, dtype=reshaped_out.dtype) return utils.mat2d_to_layer_params(vector, reshaped_out) def multiply(self, vector): - left_factor = self._input_factor.get_cov() - right_factor = self._output_factor.get_cov() reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = ( - math_ops.matmul(reshaped_vector, right_factor) + - self._output_damping * reshaped_vector) - reshaped_out = ( - math_ops.matmul(left_factor, reshaped_out) + - self._input_damping * reshaped_out) + reshaped_out = self._output_factor.right_multiply( + reshaped_vector, + self._output_damping) + reshaped_out = self._input_factor.left_multiply( + reshaped_out, self._input_damping) if self._renorm_coeff != 1.0: reshaped_out *= math_ops.cast( self._renorm_coeff, dtype=reshaped_out.dtype) @@ -572,6 +590,74 @@ class KroneckerProductFB(FisherBlock): right_factor) +class EmbeddingKFACFB(KroneckerProductFB): + """K-FAC FisherBlock for embedding layers. + + This FisherBlock is similar to EmbeddingKFACFB, except that its + input factor is approximated by a diagonal matrix. In the case that each + example references exactly one embedding, this approximation is exact. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size): + """Creates a EmbeddingKFACFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + """ + self._inputs = [] + self._outputs = [] + self._vocab_size = vocab_size + + super(EmbeddingKFACFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of Tensors. grads_list[i][j] is the + gradient of the loss with respect to 'outputs' from source 'i' and + tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + # TODO(b/68033310): Validate which of, + # (1) summing on a single device (as below), or + # (2) on each device in isolation and aggregating + # is faster. + inputs = _concat_along_batch_dim(self._inputs) + grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( # + fisher_factors.EmbeddingInputKroneckerFactor, # + ((inputs,), self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( # + fisher_factors.FullyConnectedKroneckerFactor, # + (grads_list,)) + self._register_damped_input_and_output_inverses(damping) + + def tensors_to_compute_grads(self): + return self._outputs + + def register_additional_minibatch(self, inputs, outputs): + """Registers an additional minibatch to the FisherBlock. + + Args: + inputs: Tensor of shape [batch_size, input_size]. Inputs to the + matrix-multiply. + outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. + """ + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_minibatches(self): + return len(self._inputs) + + class FullyConnectedKFACBasicFB(KroneckerProductFB): """K-FAC FisherBlock for fully-connected (dense) layers. diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py index ac396309206fe09af65c2b70840a513fb25b579b..c04cf727fa958160d61c7a3638ec65f6c93c2f24 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -29,6 +29,7 @@ _allowed_symbols = [ 'NaiveDiagonalFB', 'FullyConnectedDiagonalFB', 'KroneckerProductFB', + 'EmbeddingKFACFB', 'FullyConnectedKFACBasicFB', 'ConvKFCBasicFB', 'ConvDiagonalFB', @@ -36,7 +37,9 @@ _allowed_symbols = [ 'compute_pi_tracenorm', 'compute_pi_adjusted_damping', 'num_conv_locations', - 'normalize_damping' + 'normalize_damping', + 'LEFT_MULTIPLY', + 'RIGHT_MULTIPLY', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index f59168cbc05fffd104ff5a44308eefd206beb9db..603d8b8b210279ee6d8f1de0ce10869fde23f4d9 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -25,6 +25,7 @@ import numpy as np import six from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -210,12 +211,21 @@ def scalar_or_tensor_to_string(val): class FisherFactor(object): """Base class for objects modeling factors of approximate Fisher blocks. - Note that for blocks that aren't based on approximations, a 'factor' can - be the entire block itself, as is the case for the diagonal and full - representations. + A FisherFactor represents part of an approximate Fisher Information matrix. + For example, one approximation to the Fisher uses the Kronecker product of two + FisherFactors A and B, F = kron(A, B). FisherFactors are composed with + FisherBlocks to construct a block-diagonal approximation to the full Fisher. - Subclasses must implement the _compute_new_cov method, and the _var_scope - and _cov_shape properties. + FisherFactors are backed by a single, non-trainable variable that is updated + by running FisherFactor.make_covariance_update_op(). The shape and type of + this variable is implementation specific. + + Note that for blocks that aren't based on approximations, a 'factor' can + be the entire block itself, as is the case for the diagonal and full + representations. + + Subclasses must implement the _compute_new_cov() method, and the _var_scope + and _cov_shape properties. """ def __init__(self): @@ -223,16 +233,21 @@ class FisherFactor(object): @abc.abstractproperty def _var_scope(self): + """Variable scope for this FisherFactor instance. + + Returns: + string that unique identifies this FisherFactor instance. + """ pass @abc.abstractproperty def _cov_shape(self): - """The shape of the cov matrix.""" + """The shape of the variable backing this FisherFactor.""" pass @abc.abstractproperty def _num_sources(self): - """The number of things to sum over when computing cov. + """The number of things to sum over when updating covariance variable. The default make_covariance_update_op function will call _compute_new_cov with indices ranging from 0 to _num_sources-1. The typical situation is @@ -244,10 +259,12 @@ class FisherFactor(object): @abc.abstractproperty def _dtype(self): + """dtype for variable backing this factor.""" pass @property def _cov_initializer(self): + """Function for initializing covariance variable.""" return covariance_initializer def instantiate_covariance(self): @@ -262,6 +279,15 @@ class FisherFactor(object): @abc.abstractmethod def _compute_new_cov(self, idx=0): + """Computes minibatch-estimated covariance for a single source. + + Args: + idx: int in [0, self._num_sources). Which source to use when estimating + covariance. + + Returns: + Tensor of same shape as self.get_cov_var(). + """ pass def make_covariance_update_op(self, ema_decay): @@ -294,14 +320,101 @@ class FisherFactor(object): """Create and return update ops corresponding to registered computations.""" pass + @abc.abstractmethod def get_cov(self): + """Get full covariance matrix. + + Returns: + Tensor of shape [n, n]. Represents all parameter-parameter correlations + captured by this FisherFactor. + """ + pass + + def get_cov_var(self): + """Get variable backing this FisherFactor. + + May or may not be the same as self.get_cov() + + Returns: + Variable of shape self._cov_shape. + """ return self._cov + @abc.abstractmethod + def left_multiply(self, x, damping): + """Multiplies 'x' by the damped covariance of this factor. + + Let C be the covariance matrix this factor represents, and + D = C + damping * I be its damped variant. This method calculates + matmul(D, vec(x)). + + Args: + x: Tensor. Represents a single vector. Shape depends on implementation. + damping: 0-D Tensor. Damping to add to C's diagonal. + + Returns: + Tensor of same shape as 'x'. + """ + pass + + @abc.abstractmethod + def right_multiply(self, x, damping): + """Multiplies 'x' by the damped covariance of this factor. + + Let C be the covariance matrix this factor represents, and + D = C + damping * I be its damped variant. This method calculates + matmul(vec(x), D). + + Args: + x: Tensor. Represents a single vector. Shape depends on implementation. + damping: 0-D Tensor. Damping to add to C's diagonal. + + Returns: + Tensor of same shape as 'x'. + """ + pass + + @abc.abstractmethod + def left_multiply_inverse(self, x, damping): + """Multiplies 'x' by damped inverse of this factor. + + Let C be the covariance matrix this factor represents and + E = inv(C + damping * I) be its damped inverse. This method calculates + matmul(E, vec(x)). + + Args: + x: Tensor. Represents a single vector. Shape depends on implementation. + damping: 0-D Tensor. Damping to add to C's diagonal. + + Returns: + Tensor of same shape as 'x'. + """ + pass + + @abc.abstractmethod + def right_multiply_inverse(self, x, damping): + """Multiplies 'x' by damped inverse of this factor. + + Let C be the covariance matrix this factor represents and + E = inv(C + damping * I) be its damped inverse. This method calculates + matmul(vec(x), E). + + Args: + x: Tensor. Represents a single vector. Shape depends on implementation. + damping: 0-D Tensor. Damping to add to C's diagonal. + + Returns: + Tensor of same shape as 'x'. + """ + pass + class InverseProvidingFactor(FisherFactor): - """Base class for FisherFactors that maintain inverses, powers, etc of _cov. + """Base class for FisherFactors that maintain inverses explicitly. - Assumes that the _cov property is a square PSD matrix. + This class explicitly calculates and stores inverses of covariance matrices + provided by the underlying FisherFactor implementation. It is assumed that + vectors can be represented as 2-D matrices. Subclasses must implement the _compute_new_cov method, and the _var_scope and _cov_shape properties. @@ -436,6 +549,61 @@ class InverseProvidingFactor(FisherFactor): def reset_eigendecomp(self): self._eigendecomp = None + def get_cov(self): + # Variable contains full covariance matrix. + return self.get_cov_var() + + def left_multiply(self, x, damping): + n = self.get_cov().shape[0] + damped_cov = self.get_cov() + damping * array_ops.eye(n) + + if isinstance(x, tf_ops.IndexedSlices): + raise NotImplementedError( + "Left-multiply not yet supported for IndexedSlices.") + + if len(x.shape) != 2: + raise ValueError( + "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." + % (x,)) + + return math_ops.matmul(damped_cov, x) + + def right_multiply(self, x, damping): + n = self.get_cov().shape[0] + damped_cov = self.get_cov() + damping * array_ops.eye(n) + + if isinstance(x, tf_ops.IndexedSlices): + return utils.matmul_sparse_dense(x, damped_cov) + + if len(x.shape) != 2: + raise ValueError( + "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." + % (x,)) + + return math_ops.matmul(x, damped_cov) + + def left_multiply_inverse(self, x, damping): + if isinstance(x, tf_ops.IndexedSlices): + raise ValueError("Left-multiply not yet supported for IndexedSlices.") + + if x.shape.ndims != 2: + raise ValueError( + "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." + % (x,)) + + return math_ops.matmul(self.get_damped_inverse(damping), x) + + def right_multiply_inverse(self, x, damping): + if isinstance(x, tf_ops.IndexedSlices): + return utils.matmul_sparse_dense(x, self.get_damped_inverse(damping)) + + if x.shape.ndims != 2: + raise ValueError( + "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." + % (x,)) + + return math_ops.matmul(x, self.get_damped_inverse(damping)) + class FullFactor(InverseProvidingFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. @@ -481,7 +649,11 @@ class FullFactor(InverseProvidingFactor): class DiagonalFactor(FisherFactor): - """A base class for FisherFactors that use diagonal approximations.""" + """A base class for FisherFactors that use diagonal approximations. + + A DiagonalFactor's covariance variable can be of any shape, but must contain + exactly one entry per parameter. + """ def __init__(self): super(DiagonalFactor, self).__init__() @@ -493,6 +665,45 @@ class DiagonalFactor(FisherFactor): def make_inverse_update_ops(self): return [] + def get_cov(self): + # self.get_cov() could be any shape, but it must have one entry per + # parameter. Flatten it into a vector. + cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1]) + return array_ops.diag(cov_diag_vec) + + def left_multiply(self, x, damping): + damped_cov = self.get_cov_var() + damping + if isinstance(x, tf_ops.IndexedSlices): + return utils.matmul_diag_sparse(array_ops.reshape(damped_cov, [-1]), x) + + if x.shape != damped_cov.shape: + raise ValueError("x (%s) and cov (%s) must have same shape." % + (x, damped_cov)) + + return damped_cov * x + + def right_multiply(self, x, damping): + raise NotImplementedError("Only left-multiply is currently supported.") + + def left_multiply_inverse(self, x, damping): + inverse = 1. / (self.get_cov_var() + damping) + + if isinstance(x, tf_ops.IndexedSlices): + return utils.matmul_diag_sparse(array_ops.reshape(inverse, [-1]), x) + + if x.shape != inverse.shape: + raise ValueError("x (%s) and cov (%s) must have same shape." % + (x, inverse)) + + return inverse * x + + def right_multiply_inverse(self, x, damping): + raise NotImplementedError("Only left-multiply is currently supported.") + + def register_damped_inverse(self, damping): + # DiagonalFactors don't keep explicit inverses. + pass + class NaiveDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approximation of any type of param's Fisher. @@ -504,6 +715,14 @@ class NaiveDiagonalFactor(DiagonalFactor): def __init__(self, params_grads, batch_size): + """Initializes NaiveDiagonalFactor instance. + + Args: + params_grads: Sequence of Tensors, each with same shape as parameters this + FisherFactor corresponds to. For example, the gradient of the loss with + respect to parameters. + batch_size: int or 0-D Tensor. Size + """ self._params_grads = tuple(utils.ensure_sequence(params_grad) for params_grad in params_grads) self._batch_size = batch_size @@ -518,7 +737,7 @@ class NaiveDiagonalFactor(DiagonalFactor): def _cov_shape(self): size = sum(param_grad.shape.num_elements() for param_grad in self._params_grads[0]) - return (size, 1) + return [size, 1] @property def _num_sources(self): @@ -535,6 +754,84 @@ class NaiveDiagonalFactor(DiagonalFactor): self._batch_size, params_grads_flat.dtype)) +class EmbeddingInputKroneckerFactor(DiagonalFactor): + r"""FisherFactor for input to an embedding layer. + + Given input_ids = [batch_size, input_size] representing indices into an + [vocab_size, embedding_size] embedding matrix, approximate input covariance by + a diagonal matrix, + + Cov(input_ids, input_ids) = + (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2). + + where n_hot() constructs an n-hot binary vector and diag() constructs a + diagonal matrix of size [vocab_size, vocab_size]. + """ + + def __init__(self, input_ids, vocab_size, dtype=None): + """Instantiate EmbeddingInputKroneckerFactor. + + Args: + input_ids: Tuple of Tensors of shape [batch_size, input_size] and dtype + int32. Indices into embedding matrix. + vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'. + dtype: dtype for covariance statistics. Must be a floating point type. + Defaults to float32. + """ + self._input_ids = input_ids + self._vocab_size = vocab_size + self._cov_dtype = dtype or dtypes.float32 + + super(EmbeddingInputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_diag_embedding/" + scope_string_from_params(self._input_ids) + + @property + def _cov_shape(self): + return [self._vocab_size] + + @property + def _num_sources(self): + return len(self._input_ids) + + @property + def _dtype(self): + return self._cov_dtype + + def _compute_new_cov(self, idx=0): + with maybe_colocate_with(self._input_ids): + input_ids = self._input_ids[idx] + if len(input_ids.shape) > 2: + raise ValueError( + "Input to embeddings must have rank <= 2. Found rank %d." % len( + input_ids.shape)) + + batch_size = array_ops.shape(input_ids)[0] + + # Transform indices into one-hot vectors. + # + # TODO(b/72714822): There must be a faster way to construct the diagonal + # covariance matrix! This operation is O(batch_size * vocab_size), where + # it should be O(batch_size * input_size). + flat_input_ids = array_ops.reshape(input_ids, [-1]) + one_hots = array_ops.one_hot(flat_input_ids, + self._vocab_size) # [?, vocab_size] + + # Take average across examples. Note that, because all entries have + # magnitude zero or one, there's no need to square the entries. + # + # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation + # within an example such as average. + # + # TODO(b/72714822): Support for partitioned embeddings. + new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + + return new_cov + + class FullyConnectedDiagonalFactor(DiagonalFactor): r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher. @@ -574,8 +871,9 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): @property def _cov_shape(self): - return [self._inputs.shape[1] + self._has_bias, - self._outputs_grads[0].shape[1]] + input_size = self._inputs.shape[1] + self._has_bias + output_size = self._outputs_grads[0].shape[1] + return [input_size, output_size] @property def _num_sources(self): diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py index ad93919149c287b1932dd2b6bd772c0dab26192d..2d8e378a932c16d48360bc4b15ff4f3239c0ed1f 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -24,26 +24,15 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ - "inverse_initializer", - "covariance_initializer", - "diagonal_covariance_initializer", - "scope_string_from_params", - "scope_string_from_name", - "scalar_or_tensor_to_string", - "FisherFactor", - "InverseProvidingFactor", - "FullFactor", - "DiagonalFactor", - "NaiveDiagonalFactor", - "FullyConnectedDiagonalFactor", - "FullyConnectedKroneckerFactor", - "ConvInputKroneckerFactor", - "ConvOutputKroneckerFactor", - "ConvDiagonalFactor", - "set_global_constants", - "maybe_colocate_with", - "compute_cov", - "append_homog" + "inverse_initializer", "covariance_initializer", + "diagonal_covariance_initializer", "scope_string_from_params", + "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor", + "InverseProvidingFactor", "FullFactor", "DiagonalFactor", + "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor", + "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor", + "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor", + "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with", + "compute_cov", "append_homog" ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 8d450f04f379701e46a18b2e34bbbd6fcfcce2bb..ce9005b9ce99a4efa5f2821c56e199dd2086482e 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -143,6 +143,7 @@ class LayerCollection(object): self._loss_dict = {} # {str: LossFunction} self._subgraph = None self._default_generic_approximation = APPROX_FULL_NAME + self._default_embedding_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_multi_approximation = ( @@ -178,6 +179,17 @@ class LayerCollection(object): """ return self._linked_parameters + @property + def default_embedding_approximation(self): + return self._default_embedding_approximation + + def set_default_embedding_approximation(self, value): + if value != APPROX_KRONECKER_NAME: + raise ValueError( + "{} is not a valid approximation for embedding variables.".format( + value)) + self._default_embedding_approximation = value + @property def default_generic_approximation(self): return self._default_generic_approximation @@ -417,6 +429,46 @@ class LayerCollection(object): else: return None + def register_embedding(self, + params, + inputs, + outputs, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers a fully connnected layer. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices + into embedding matrix. + outputs: Tensor of shape [batch_size, output_size]. Outputs + produced by layer. + approx: str. Must be "kron". + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_embedding_approximation + + if approx != APPROX_KRONECKER_NAME: + raise ValueError("Bad value {} for approx.".format(approx)) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + + vocab_size = int(params.shape[0]) + block = self.register_block( + params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse) + block.register_additional_minibatch(inputs, outputs) + def register_fully_connected(self, params, inputs, diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py index 831870fca451c585cb1a1dc6b24aad757e2bbaa8..b6d9d37a31a949b154b79e6f3677289a0d167373 100644 --- a/tensorflow/contrib/kfac/python/ops/op_queue.py +++ b/tensorflow/contrib/kfac/python/ops/op_queue.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import ops as tf_ops diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index e89508fa46b6e2ce278e5373e6c9d17203ad1ef2..f5bd97cb4e7d547394050e944f75b43a40887f34 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -144,7 +144,9 @@ def layer_params_to_mat2d(vector): [-1, w_part.shape.as_list()[-1]]) return array_ops.concat( (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0) - else: + elif isinstance(vector, ops.IndexedSlices): + return vector + else: # Tensor or Tensor-like. return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]]) @@ -163,6 +165,11 @@ def mat2d_to_layer_params(vector_template, mat2d): if isinstance(vector_template, (tuple, list)): w_part, b_part = mat2d[:-1], mat2d[-1] return array_ops.reshape(w_part, vector_template[0].shape), b_part + elif isinstance(vector_template, ops.IndexedSlices): + if not isinstance(mat2d, ops.IndexedSlices): + raise TypeError( + "If vector_template is an IndexedSlices, so should mat2d.") + return mat2d else: return array_ops.reshape(mat2d, vector_template.shape) @@ -420,5 +427,57 @@ def batch_execute(global_step, thunks, batch_size, name=None): return result +def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name + """Computes matmul(A, B) where A is sparse, B is dense. + + Args: + A: tf.IndexedSlices with dense shape [m, n]. + B: tf.Tensor with shape [n, k]. + name: str. Name of op. + + Returns: + tf.IndexedSlices resulting from matmul(A, B). + + Raises: + ValueError: If A doesn't represent a matrix. + ValueError: If B is not rank-2. + """ + with ops.name_scope(name, "matmul_sparse_dense", [A, B]): + if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2: + raise ValueError("A must represent a matrix. Found: %s." % A) + if B.shape.ndims != 2: + raise ValueError("B must be a matrix.") + new_values = math_ops.matmul(A.values, B) + return ops.IndexedSlices( + new_values, + A.indices, + dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]])) + + +def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name + """Computes matmul(A, B) where A is a diagonal matrix, B is sparse. + + Args: + A_diag: diagonal entries of matrix A of shape [m, m]. + B: tf.IndexedSlices. Represents matrix of shape [m, n]. + name: str. Name of op. + + Returns: + tf.IndexedSlices resulting from matmul(A, B). + + Raises: + ValueError: If A_diag is not rank-1. + ValueError: If B doesn't represent a matrix. + """ + with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]): + A_diag = ops.convert_to_tensor(A_diag) + if A_diag.shape.ndims != 1: + raise ValueError("A_diag must be a rank-1 Tensor.") + if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2: + raise ValueError("B must represent a matrix. Found: %s." % B) + a = array_ops.gather(A_diag, B.indices) + a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) + return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) + # TODO(b/69623235): Add a function for finding tensors that share gradients # to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index cc48e3c69f24c2abd343e2e120d3589cd323fcdc..8e424a794691484fdea7d8481677aa641c433d4c 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -24,6 +24,7 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + "set_global_constants", "SequenceDict", "tensors_to_column", "column_to_tensors", @@ -39,6 +40,8 @@ _allowed_symbols = [ "fwd_gradients", "ensure_sequence", "batch_execute", + "matmul_sparse_dense", + "matmul_diag_sparse", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py index abc18aa123bb4d40b54d22ec03257c5350118d13..0c6bba758b429a8c4112bc6abb2fae542b5dfc14 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py @@ -361,6 +361,10 @@ class LabeledTensor(object): def dtype(self): return self._tensor.dtype + @property + def shape(self): + return self._tensor.shape + @property def name(self): return self._tensor.name diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py index e70b4923749d89aba1bd0187857d762305daeb07..e378db56afb1d4f9463d2c9b0f1fa4c0feea8fb0 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py @@ -244,6 +244,9 @@ class LabeledTensorTest(test_util.Base): def test_dtype(self): self.assertEqual(self.lt.dtype, self.lt.tensor.dtype) + def test_shape(self): + self.assertEqual(self.lt.shape, self.lt.tensor.shape) + def test_get_shape(self): self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape()) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index c957b41a49b292225e547ce17b0c5a247810325a..3ba1026383ef146adb32197ae41b5c251155bf46 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -951,7 +951,7 @@ def define_reduce_op(op_name, reduce_fn): intermediate_axes.append(axis) reduce_op = reduce_fn( - labeled_tensor.tensor, reduction_dimensions, keep_dims=True) + labeled_tensor.tensor, reduction_dimensions, keepdims=True) reduce_lt = core.LabeledTensor(reduce_op, intermediate_axes) return squeeze(reduce_lt, axes_to_squeeze, name=scope) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index ef419862b49f4d03d9b711c49155d4ae1252d5bc..337c9e06b870b2cca53fcdbf3d94225660e193c4 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -35,6 +35,7 @@ See the @{$python/contrib.layers} guide. @@fully_connected @@GDN @@gdn +@@images_to_sequence @@layer_norm @@linear @@max_pool2d @@ -50,6 +51,7 @@ See the @{$python/contrib.layers} guide. @@scale_gradient @@separable_conv2d @@separable_convolution2d +@@sequence_to_images @@softmax @@spatial_softmax @@stack diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index 932c5ab99249feda1e3a7f2d707ce4237fe7177f..01893d60615a9b4ded2afc88c6de0168d4be0921 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -423,8 +423,9 @@ class SparseFeatureCrossOp : public OpKernel { "Input values should be a std::vector but received shape ", values_list_in[i].shape().DebugString(), " at position ", i)); OP_REQUIRES( - context, indices_list_in[i].shape().dim_size(0) == - values_list_in[i].shape().dim_size(0), + context, + indices_list_in[i].shape().dim_size(0) == + values_list_in[i].shape().dim_size(0), errors::InvalidArgument( "Expected size of values to be ", indices_list_in[i].shape().dim_size(0), " got ", diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index b7d34d6435789e54403926a342481971e854b449..9ccb589d698ad83c9654f5523ccdcb35b031b3da 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -154,6 +154,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation +from tensorflow.python.util import nest # Imports the core `InputLayer` symbol in contrib during development. @@ -554,28 +555,70 @@ def sparse_column_with_integerized_feature(column_name, class _SparseColumnHashed(_SparseColumn): """See `sparse_column_with_hash_bucket`.""" + def __new__(cls, + column_name, + is_integerized=False, + bucket_size=None, + lookup_config=None, + combiner="sum", + dtype=dtypes.string, + hash_keys=None): + if hash_keys is not None: + if not isinstance(hash_keys, list) or not hash_keys: + raise ValueError("hash_keys must be a non-empty list.") + if (any([not isinstance(key_pair, list) for key_pair in hash_keys]) or + any([len(key_pair) != 2 for key_pair in hash_keys]) or + any([not isinstance(key, int) for key in nest.flatten(hash_keys)])): + raise ValueError( + "Each element of hash_keys must be a pair of integers.") + obj = super(_SparseColumnHashed, cls).__new__( + cls, + column_name, + is_integerized=is_integerized, + bucket_size=bucket_size, + lookup_config=lookup_config, + combiner=combiner, + dtype=dtype) + obj.hash_keys = hash_keys + return obj + def _do_transform(self, input_tensor): if self.dtype.is_integer: sparse_values = string_ops.as_string(input_tensor.values) else: sparse_values = input_tensor.values - sparse_id_values = string_ops.string_to_hash_bucket_fast( - sparse_values, self.bucket_size, name="lookup") - return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values, - input_tensor.dense_shape) + if self.hash_keys: + result = [] + for key in self.hash_keys: + sparse_id_values = string_ops.string_to_hash_bucket_strong( + sparse_values, self.bucket_size, key) + result.append( + sparse_tensor_py.SparseTensor(input_tensor.indices, + sparse_id_values, + input_tensor.dense_shape)) + return sparse_ops.sparse_concat(axis=1, sp_inputs=result, name="lookup") + else: + sparse_id_values = string_ops.string_to_hash_bucket_fast( + sparse_values, self.bucket_size, name="lookup") + return sparse_tensor_py.SparseTensor( + input_tensor.indices, sparse_id_values, input_tensor.dense_shape) def sparse_column_with_hash_bucket(column_name, hash_bucket_size, combiner="sum", - dtype=dtypes.string): + dtype=dtypes.string, + hash_keys=None): """Creates a _SparseColumn with hashed bucket configuration. Use this when your sparse features are in string or integer format, but you don't have a vocab file that maps each value to an integer ID. output_id = Hash(input_feature_string) % bucket_size + When hash_keys is set, multiple integer IDs would be created with each key + pair in the `hash_keys`. This is useful to reduce the collision of hashed ids. + Args: column_name: A string defining sparse column name. hash_bucket_size: An int that is > 1. The number of buckets. @@ -588,6 +631,9 @@ def sparse_column_with_hash_bucket(column_name, * "sqrtn": do l2 normalization on features in the column For more information: `tf.embedding_lookup_sparse`. dtype: The type of features. Only string and integer types are supported. + hash_keys: The hash keys to use. It is a list of lists of two uint64s. If + None, simple and fast hashing algorithm is used. Otherwise, multiple + strong hash ids would be produced with each two unit64s in this argument. Returns: A _SparseColumn with hashed bucket configuration @@ -600,7 +646,8 @@ def sparse_column_with_hash_bucket(column_name, column_name, bucket_size=hash_bucket_size, combiner=combiner, - dtype=dtype) + dtype=dtype, + hash_keys=hash_keys) class _SparseColumnKeys(_SparseColumn): diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index 2eaea231776bd2f5fb8bb4bd422074beacd61720..1de9ab705655db9863d9c7d2630f24283c83d44d 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -221,8 +221,8 @@ class FeatureColumnTest(test.TestCase): weighted_sparse_col = fc.weighted_sparse_column(ids, "weights") self.assertEqual(weighted_sparse_col.name, "ids_weighted_by_weights") - b = fc.shared_embedding_columns([sparse_col, weighted_sparse_col], - dimension=4, combiner="mean") + b = fc.shared_embedding_columns( + [sparse_col, weighted_sparse_col], dimension=4, combiner="mean") self.assertEqual(len(b), 2) self.assertEqual(b[0].shared_embedding_name, "a1_ids_weighted_by_weights_shared_embedding") @@ -230,8 +230,8 @@ class FeatureColumnTest(test.TestCase): "a1_ids_weighted_by_weights_shared_embedding") # Tries reversing order to check compatibility condition. - b = fc.shared_embedding_columns([weighted_sparse_col, sparse_col], - dimension=4, combiner="mean") + b = fc.shared_embedding_columns( + [weighted_sparse_col, sparse_col], dimension=4, combiner="mean") self.assertEqual(len(b), 2) self.assertEqual(b[0].shared_embedding_name, "a1_ids_weighted_by_weights_shared_embedding") @@ -240,18 +240,17 @@ class FeatureColumnTest(test.TestCase): # Tries adding two weighted columns to check compatibility between them. weighted_sparse_col_2 = fc.weighted_sparse_column(ids, "weights_2") - b = fc.shared_embedding_columns([weighted_sparse_col, - weighted_sparse_col_2], - dimension=4, combiner="mean") + b = fc.shared_embedding_columns( + [weighted_sparse_col, weighted_sparse_col_2], + dimension=4, + combiner="mean") self.assertEqual(len(b), 2) self.assertEqual( b[0].shared_embedding_name, - "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding" - ) + "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding") self.assertEqual( b[1].shared_embedding_name, - "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding" - ) + "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding") def testSharedEmbeddingColumnDeterminism(self): # Tests determinism in auto-generated shared_embedding_name. @@ -286,10 +285,10 @@ class FeatureColumnTest(test.TestCase): columns = fc.shared_embedding_columns( [a1, a2], dimension=4, combiner="mean") columns_copy = copy.deepcopy(columns) - self.assertEqual( - columns_copy[0].shared_embedding_name, "a1_a2_shared_embedding") - self.assertEqual( - columns_copy[1].shared_embedding_name, "a1_a2_shared_embedding") + self.assertEqual(columns_copy[0].shared_embedding_name, + "a1_a2_shared_embedding") + self.assertEqual(columns_copy[1].shared_embedding_name, + "a1_a2_shared_embedding") def testOneHotColumn(self): a = fc.sparse_column_with_keys("a", ["a", "b", "c", "d"]) @@ -330,17 +329,66 @@ class FeatureColumnTest(test.TestCase): self.assertEqual(one_hot.sparse_id_column.name, "ids_weighted_by_weights") self.assertEqual(one_hot.length, 3) + def testOneHotColumnWithSparseColumnWithHashKeys(self): + input_values = ["marlo", "unknown", "omar"] + inputs = constant_op.constant(input_values) + hash_keys = [[10, 20], [20, 30]] + hash_column = fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=10, hash_keys=hash_keys) + columns_to_tensors = {} + columns_to_tensors["ids"] = inputs + hash_column.insert_transformed_feature(columns_to_tensors) + self.assertEqual(len(columns_to_tensors), 2) + self.assertTrue(hash_column in columns_to_tensors) + + one_hot_column = fc.one_hot_column(hash_column) + one_hot_output = one_hot_column._to_dnn_input_layer( + columns_to_tensors[hash_column]) + + expected = np.array([[0., 1., 0., 0., 0., 0., 0., 1., 0., + 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 1.], + [1., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]) + with self.test_session() as sess: + one_hot_value = sess.run(one_hot_output) + self.assertTrue(np.array_equal(one_hot_value, expected)) + + def testSparseColumnWithHashKeysWithUnexpectedHashKeys(self): + with self.assertRaisesRegexp(ValueError, + "hash_keys must be a non-empty list."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=[]) + + with self.assertRaisesRegexp(ValueError, + "hash_keys must be a non-empty list."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=1) + + with self.assertRaisesRegexp( + ValueError, "Each element of hash_keys must be a pair of integers."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=[1, 2]) + + with self.assertRaisesRegexp( + ValueError, "Each element of hash_keys must be a pair of integers."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=["key"]) + + with self.assertRaisesRegexp( + ValueError, "Each element of hash_keys must be a pair of integers."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=[[1, 2.0]]) + def testMissingValueInOneHotColumnForWeightedSparseColumn(self): # Github issue 12583 ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"]) weighted_ids = fc.weighted_sparse_column(ids, "weights") one_hot = fc.one_hot_column(weighted_ids) features = { - 'ids': constant_op.constant([['marlo', 'unknown', 'omar']]), - 'weights': constant_op.constant([[2., 4., 6.]]) + "ids": constant_op.constant([["marlo", "unknown", "omar"]]), + "weights": constant_op.constant([[2., 4., 6.]]) } one_hot_tensor = feature_column_ops.input_from_feature_columns( - features, [one_hot]) + features, [one_hot]) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) @@ -349,11 +397,9 @@ class FeatureColumnTest(test.TestCase): def testMissingValueInOneHotColumnForSparseColumnWithKeys(self): ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"]) one_hot = fc.one_hot_column(ids) - features = { - 'ids': constant_op.constant([['marlo', 'unknown', 'omar']]) - } + features = {"ids": constant_op.constant([["marlo", "unknown", "omar"]])} one_hot_tensor = feature_column_ops.input_from_feature_columns( - features, [one_hot]) + features, [one_hot]) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) @@ -379,8 +425,7 @@ class FeatureColumnTest(test.TestCase): self.assertEqual(d4.default_value, None) self.assertEqual(d4.is_sparse, True) # Default value is a list but dimension is None. - with self.assertRaisesRegexp(ValueError, - "Only scalar default value.*"): + with self.assertRaisesRegexp(ValueError, "Only scalar default value.*"): fc._real_valued_var_len_column("g5", default_value=[2., 3.]) def testRealValuedVarLenColumnDtypes(self): @@ -390,18 +435,19 @@ class FeatureColumnTest(test.TestCase): "rvc": parsing_ops.VarLenFeature(dtype=dtypes.float32) }, rvc.config) - rvc = fc._real_valued_var_len_column("rvc", default_value=0, - is_sparse=False) - self.assertDictEqual( - { - "rvc": parsing_ops.FixedLenSequenceFeature(shape=[], - dtype=dtypes.float32, - allow_missing=True, - default_value=0.0) - }, rvc.config) - - rvc = fc._real_valued_var_len_column("rvc", dtype=dtypes.int32, - default_value=0, is_sparse=True) + rvc = fc._real_valued_var_len_column( + "rvc", default_value=0, is_sparse=False) + self.assertDictEqual({ + "rvc": + parsing_ops.FixedLenSequenceFeature( + shape=[], + dtype=dtypes.float32, + allow_missing=True, + default_value=0.0) + }, rvc.config) + + rvc = fc._real_valued_var_len_column( + "rvc", dtype=dtypes.int32, default_value=0, is_sparse=True) self.assertDictEqual( { "rvc": parsing_ops.VarLenFeature(dtype=dtypes.int32) @@ -409,8 +455,8 @@ class FeatureColumnTest(test.TestCase): with self.assertRaisesRegexp(TypeError, "dtype must be convertible to float"): - fc._real_valued_var_len_column("rvc", dtype=dtypes.string, - default_value="", is_sparse=True) + fc._real_valued_var_len_column( + "rvc", dtype=dtypes.string, default_value="", is_sparse=True) def testRealValuedColumn(self): a = fc.real_valued_column("aaa") @@ -504,13 +550,13 @@ class FeatureColumnTest(test.TestCase): for output_rank in range(1, 3 + len(dimensions)): with variable_scope.variable_scope("output_rank_{}".format(output_rank)): real_valued_output = real_valued_column._to_dnn_input_layer( - constant_op.constant( - real_valued_input, dtype=dtypes.float32), + constant_op.constant(real_valued_input, dtype=dtypes.float32), output_rank=output_rank) with self.test_session() as sess: real_valued_eval = sess.run(real_valued_output) - expected_shape = (input_shape[:output_rank - 1] + - [np.prod(input_shape[output_rank - 1:])]) + expected_shape = ( + input_shape[:output_rank - 1] + + [np.prod(input_shape[output_rank - 1:])]) self.assertEquals(expected_shape, list(real_valued_eval.shape)) def testRealValuedColumnDensification(self): @@ -520,8 +566,7 @@ class FeatureColumnTest(test.TestCase): "sparse_real_valued1", is_sparse=True) sparse_tensor = sparse_tensor_lib.SparseTensor( values=[2.0, 5.0], indices=[[0, 0], [2, 0]], dense_shape=[3, 1]) - with self.assertRaisesRegexp( - ValueError, "Set is_sparse to False"): + with self.assertRaisesRegexp(ValueError, "Set is_sparse to False"): real_valued_column._to_dnn_input_layer(sparse_tensor) def testRealValuedColumnDeepCopy(self): @@ -549,9 +594,8 @@ class FeatureColumnTest(test.TestCase): def testBucketizedColumnRequiresRealValuedColumnDimension(self): with self.assertRaisesRegexp( TypeError, "source_column must be an instance of _RealValuedColumn.*"): - fc.bucketized_column(fc._real_valued_var_len_column("bbb", - is_sparse=True), - [0]) + fc.bucketized_column( + fc._real_valued_var_len_column("bbb", is_sparse=True), [0]) def testBucketizedColumnRequiresSortedBuckets(self): with self.assertRaisesRegexp(ValueError, @@ -654,20 +698,14 @@ class FeatureColumnTest(test.TestCase): def testRealValuedColumnDtypes(self): rvc = fc.real_valued_column("rvc") - self.assertDictEqual( - { - "rvc": parsing_ops.FixedLenFeature( - [1], dtype=dtypes.float32) - }, - rvc.config) + self.assertDictEqual({ + "rvc": parsing_ops.FixedLenFeature([1], dtype=dtypes.float32) + }, rvc.config) rvc = fc.real_valued_column("rvc", dtype=dtypes.int32) - self.assertDictEqual( - { - "rvc": parsing_ops.FixedLenFeature( - [1], dtype=dtypes.int32) - }, - rvc.config) + self.assertDictEqual({ + "rvc": parsing_ops.FixedLenFeature([1], dtype=dtypes.int32) + }, rvc.config) with self.assertRaisesRegexp(ValueError, "dtype must be convertible to float"): @@ -702,8 +740,9 @@ class FeatureColumnTest(test.TestCase): batch_size = 4 dense_scalar_input = [1, 2, 3, 4] sparse_column = fc.sparse_column_with_integerized_feature("values", 10) - features = {"values": - constant_op.constant(dense_scalar_input, dtype=dtypes.int64)} + features = { + "values": constant_op.constant(dense_scalar_input, dtype=dtypes.int64) + } sparse_column.insert_transformed_feature(features) sparse_output = features[sparse_column] expected_shape = [batch_size, 1] @@ -731,8 +770,7 @@ class FeatureColumnTest(test.TestCase): def testSparseColumnKeysDeepCopy(self): """Tests deepcopy of sparse_column_with_keys.""" - column = fc.sparse_column_with_keys( - "a", keys=["key0", "key1", "key2"]) + column = fc.sparse_column_with_keys("a", keys=["key0", "key1", "key2"]) self.assertEqual("a", column.name) column_copy = copy.deepcopy(column) self.assertEqual("a", column_copy.name) @@ -785,8 +823,9 @@ class FeatureColumnTest(test.TestCase): a = fc.sparse_column_with_hash_bucket("cross_aaa", hash_bucket_size=100) b = fc.sparse_column_with_hash_bucket("cross_bbb", hash_bucket_size=100) cross_col = fc.crossed_column(set([a, b]), hash_bucket_size=10000) - one_hot_col = fc.one_hot_column(fc.sparse_column_with_hash_bucket( - "sparse_column_for_one_hot", hash_bucket_size=100)) + one_hot_col = fc.one_hot_column( + fc.sparse_column_with_hash_bucket( + "sparse_column_for_one_hot", hash_bucket_size=100)) scattered_embedding_col = fc.scattered_embedding_column( "scattered_embedding_column", size=100, dimension=10, hash_key=1) feature_columns = set([ @@ -809,17 +848,13 @@ class FeatureColumnTest(test.TestCase): "str_id_weights_column": parsing_ops.VarLenFeature(dtypes.float32), "real_valued_column1": - parsing_ops.FixedLenFeature( - [1], dtype=dtypes.float32), + parsing_ops.FixedLenFeature([1], dtype=dtypes.float32), "real_valued_column2": - parsing_ops.FixedLenFeature( - [5], dtype=dtypes.float32), + parsing_ops.FixedLenFeature([5], dtype=dtypes.float32), "real_valued_column_for_bucketization1": - parsing_ops.FixedLenFeature( - [1], dtype=dtypes.float32), + parsing_ops.FixedLenFeature([1], dtype=dtypes.float32), "real_valued_column_for_bucketization2": - parsing_ops.FixedLenFeature( - [4], dtype=dtypes.float32), + parsing_ops.FixedLenFeature([4], dtype=dtypes.float32), "cross_aaa": parsing_ops.VarLenFeature(dtypes.string), "cross_bbb": @@ -849,11 +884,14 @@ class FeatureColumnTest(test.TestCase): real_valued_col0 = fc._real_valued_var_len_column( "real_valued_column0", is_sparse=True) real_valued_col1 = fc._real_valued_var_len_column( - "real_valued_column1", dtype=dtypes.int64, default_value=0, + "real_valued_column1", + dtype=dtypes.int64, + default_value=0, is_sparse=False) feature_columns = set([real_valued_col0, real_valued_col1]) expected_config = { - "real_valued_column0": parsing_ops.VarLenFeature(dtype=dtypes.float32), + "real_valued_column0": + parsing_ops.VarLenFeature(dtype=dtypes.float32), "real_valued_column1": parsing_ops.FixedLenSequenceFeature( [], dtype=dtypes.int64, allow_missing=True, default_value=0), @@ -874,7 +912,9 @@ class FeatureColumnTest(test.TestCase): real_valued_col5 = fc._real_valued_var_len_column( "real_valued_column5", default_value=2, is_sparse=True) real_valued_col6 = fc._real_valued_var_len_column( - "real_valued_column6", dtype=dtypes.int64, default_value=1, + "real_valued_column6", + dtype=dtypes.int64, + default_value=1, is_sparse=False) feature_columns = [ real_valued_col1, real_valued_col2, real_valued_col3, real_valued_col4, @@ -902,8 +942,7 @@ class FeatureColumnTest(test.TestCase): parsing_ops.VarLenFeature(dtype=dtypes.float32), "real_valued_column6": parsing_ops.FixedLenSequenceFeature( - [], dtype=dtypes.int64, allow_missing=True, - default_value=1) + [], dtype=dtypes.int64, allow_missing=True, default_value=1) }, config) @@ -1104,8 +1143,8 @@ class FeatureColumnTest(test.TestCase): # This will initialize the crossed column weights from provided checkpoint # and return a [4, 1] tensor which is same as weights variable. Since we # won't modify weights, this should be same as 'saved_col_weights'. - _, col_weights, _ = (feature_column_ops.weighted_sum_from_feature_columns( - { + _, col_weights, _ = ( + feature_column_ops.weighted_sum_from_feature_columns({ sparse_col_1.name: input_tensor, sparse_col_2.name: input_tensor }, [crossed_col_initialized], 1)) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index c8e3307ee8b5ded30dc864c4e69452f58685b8f0..e27b36908eba7cc1b992079b1abb5c5367340de1 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -60,12 +60,12 @@ __all__ = [ 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', - 'dropout', 'elu', 'flatten', - 'fully_connected', 'GDN', 'gdn', 'layer_norm', 'linear', 'pool', - 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', - 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'softmax', - 'spatial_softmax', 'stack', 'unit_norm', 'legacy_fully_connected', - 'legacy_linear', 'legacy_relu', 'maxout' + 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', + 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', + 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', + 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', + 'sequence_to_images', 'softmax', 'spatial_softmax', 'stack', 'unit_norm', + 'legacy_fully_connected', 'legacy_linear', 'legacy_relu', 'maxout' ] DATA_FORMAT_NCHW = 'NCHW' @@ -518,8 +518,8 @@ def batch_norm(inputs, then the batch normalization uses weighted mean and variance. (This can be used to correct for bias in training example selection.) - fused: if `True`, use a faster, fused implementation if possible. - If `None`, use the system recommended implementation. + fused: if `None` or `True`, use a faster, fused implementation if possible. + If `False`, use the system recommended implementation. data_format: A string. `NHWC` (default) and `NCHW` are supported. zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new pair of variables 'moving_mean/biased' and 'moving_mean/local_step'. @@ -779,7 +779,7 @@ def batch_norm(inputs, else: if data_format == DATA_FORMAT_NCHW: mean, variance = nn.weighted_moments( - inputs, moments_axes, batch_weights, keep_dims=True) + inputs, moments_axes, batch_weights, keepdims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: @@ -1415,10 +1415,11 @@ def dense_to_sparse(tensor, eos_token=0, outputs_collections=None, scope=None): outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. """ - with variable_scope.variable_scope( - scope, 'dense_to_sparse', [tensor]) as sc: + with variable_scope.variable_scope(scope, 'dense_to_sparse', [tensor]) as sc: tensor = ops.convert_to_tensor(tensor) - indices = array_ops.where(math_ops.not_equal(tensor, constant_op.constant(eos_token, tensor.dtype))) + indices = array_ops.where( + math_ops.not_equal(tensor, constant_op.constant(eos_token, + tensor.dtype))) values = array_ops.gather_nd(tensor, indices) shape = array_ops.shape(tensor, out_type=dtypes.int64) outputs = sparse_tensor.SparseTensor(indices, values, shape) @@ -2185,6 +2186,36 @@ def layer_norm(inputs, return utils.collect_named_outputs(outputs_collections, sc.name, outputs) +@add_arg_scope +def images_to_sequence(inputs, + data_format=DATA_FORMAT_NHWC, + outputs_collections=None, + scope=None): + """Convert a batch of images into a batch of sequences. + Args: + inputs: a (num_images, height, width, depth) tensor + data_format: A string. `NHWC` (default) and `NCHW` are supported. + outputs_collections: The collections to which the outputs are added. + scope: Optional scope for name_scope. + Returns: + (width, num_images*height, depth) sequence tensor + """ + if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): + raise ValueError('data_format has to be either NCHW or NHWC.') + with ops.name_scope(scope, 'ImagesToSequence', [inputs]) as sc: + inputs = ops.convert_to_tensor(inputs) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + if df == 'channels_first': + inputs = array_ops.transpose(inputs, [0, 2, 3, 1]) + _, _, width, depth = inputs.get_shape().as_list() + s = array_ops.shape(inputs) + batch_size, height = s[0], s[1] + transposed = array_ops.transpose(inputs, [2, 0, 1, 3]) + outputs = array_ops.reshape(transposed, [width, batch_size * height, depth]) + return utils.collect_named_outputs(outputs_collections, sc, outputs) + + @add_arg_scope def max_pool2d(inputs, kernel_size, @@ -2664,6 +2695,39 @@ def separable_convolution2d( return utils.collect_named_outputs(outputs_collections, sc.name, outputs) +@add_arg_scope +def sequence_to_images(inputs, + height, + output_data_format='channels_last', + outputs_collections=None, + scope=None): + """Convert a batch of sequences into a batch of images. + Args: + inputs: (num_steps, num_batches, depth) sequence tensor + height: the height of the images + output_data_format: Format of output tensor. + Currently supports `'channels_first'` and `'channels_last'`. + outputs_collections: The collections to which the outputs are added. + scope: Optional scope for name_scope. + Returns: + A tensor representing the output of the operation. + """ + with ops.name_scope(scope, 'SequenceToImages', [inputs]) as sc: + inputs = ops.convert_to_tensor(inputs) + width, num_batches, depth = inputs.get_shape().as_list() + if num_batches is None: + num_batches = -1 + else: + num_batches = num_batches // height + reshaped = array_ops.reshape(inputs, + [width, num_batches, height, depth]) + if output_data_format == 'channels_first': + outputs = array_ops.transpose(reshaped, [1, 3, 2, 0]) + else: + outputs = array_ops.transpose(reshaped, [1, 2, 0, 3]) + return utils.collect_named_outputs(outputs_collections, sc, outputs) + + @add_arg_scope def softmax(logits, scope=None): """Performs softmax on Nth dimension of N-dimensional logit tensor. @@ -2774,9 +2838,9 @@ def spatial_softmax(features, softmax_attention = nn.softmax(features / temperature) expected_x = math_ops.reduce_sum( - pos_x * softmax_attention, [1], keep_dims=True) + pos_x * softmax_attention, [1], keepdims=True) expected_y = math_ops.reduce_sum( - pos_y * softmax_attention, [1], keep_dims=True) + pos_y * softmax_attention, [1], keepdims=True) expected_xy = array_ops.concat([expected_x, expected_y], 1) feature_keypoints = array_ops.reshape(expected_xy, [-1, num_channels.value * 2]) @@ -2909,7 +2973,7 @@ def poincare_normalize(x, axis=1, epsilon=1e-5, name=None): """ with ops.name_scope(name, 'poincare_normalize', [x]) as name: x = ops.convert_to_tensor(x, name='x') - square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keep_dims=True) + square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True) x_inv_norm = math_ops.rsqrt(square_sum) x_inv_norm = math_ops.minimum((1. - epsilon) * x_inv_norm, 1.) return math_ops.multiply(x, x_inv_norm, name=name) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index c5790c76221848524a106f1a218922f4e7a0b7e6..0f062adbab3ca9acfb89543b69c7c957bbdf5dd8 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -127,8 +127,8 @@ class AvgPool3DTest(test.TestCase): def testInvalidDataFormat(self): depth, height, width = 3, 6, 9 images = np.random.uniform(size=(5, depth, height, width, 3)) - with self.assertRaisesRegexp(ValueError, - 'data_format has to be either NCDHW or NDHWC.'): + with self.assertRaisesRegexp( + ValueError, 'data_format has to be either NCDHW or NDHWC.'): _layers.avg_pool3d(images, [3, 3, 3], data_format='CDHWN') def testCreateAvgPool(self): @@ -148,7 +148,8 @@ class AvgPool3DTest(test.TestCase): def testCollectOutputs(self): depth, height, width = 3, 6, 9 images = random_ops.random_uniform((5, depth, height, width, 3), seed=1) - output = _layers.avg_pool3d(images, [3, 3, 3], outputs_collections='outputs') + output = _layers.avg_pool3d( + images, [3, 3, 3], outputs_collections='outputs') output_collected = ops.get_collection('outputs')[0] self.assertEqual(output_collected.aliases, ['AvgPool3D']) self.assertEqual(output_collected, output) @@ -183,7 +184,8 @@ class AvgPool3DTest(test.TestCase): depth, height, width = 3, 6, 9 images = random_ops.random_uniform((5, depth, height, width, 3), seed=1) output = _layers.avg_pool3d(images, [3, 3, 3], stride=1, padding='SAME') - self.assertListEqual(output.get_shape().as_list(), [5, depth, height, width, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, depth, height, width, 3]) def testGlobalAvgPool(self): depth, height, width = 3, 6, 9 @@ -515,7 +517,9 @@ class ConvolutionTest(test.TestCase): with arg_scope( [layers_lib.convolution2d], normalizer_fn=_layers.batch_norm, - normalizer_params={'decay': 0.9}): + normalizer_params={ + 'decay': 0.9 + }): net = layers_lib.convolution2d(images, 32, [3, 3]) net = layers_lib.convolution2d(net, 32, [3, 3]) self.assertEqual(len(variables.get_variables()), 8) @@ -529,7 +533,9 @@ class ConvolutionTest(test.TestCase): with arg_scope( [layers_lib.convolution2d], normalizer_fn=_layers.batch_norm, - normalizer_params={'decay': 0.9}): + normalizer_params={ + 'decay': 0.9 + }): net = layers_lib.convolution2d(images, 32, [3, 3], scope='Conv') net = layers_lib.convolution2d( net, 32, [3, 3], scope='Conv', reuse=True) @@ -702,7 +708,7 @@ class Convolution2dTransposeTests(test.TestCase): _layers.convolution2d_transpose(images, 32, 3, data_format='CHWN') def testOutputSizeWithStrideOneSamePaddingNCHW(self): - # `NCHW` data fomat is only supported for `GPU` device. + # `NCHW` data format is only supported for `GPU` device. if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 32 @@ -1031,7 +1037,8 @@ class Convolution2dTransposeTests(test.TestCase): for _ in range(10): num_filters = 1 input_size = [ - 1, np.random.randint(1, max_image_size), + 1, + np.random.randint(1, max_image_size), np.random.randint(1, max_image_size), 1 ] filter_size = [ @@ -1185,8 +1192,10 @@ class ConvolutionInPlaneTest(test.TestCase): with self.test_session() as sess: sess.run(init_op) - result = sess.run(horz_gradients, - feed_dict={image: np.ones((1, 10, 10, 1))}) + result = sess.run( + horz_gradients, feed_dict={ + image: np.ones((1, 10, 10, 1)) + }) expected = np.zeros((1, 10, 9, 1)) self.assertAllEqual(result, expected) @@ -1299,11 +1308,13 @@ class DenseToSparseTest(test.TestCase): expected_constant = np.reshape(np.arange(24, dtype=np.int64), (3, 4, 2)) tensor = constant_op.constant(expected_constant) sparse = _layers.dense_to_sparse(tensor) - dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape, sparse.values) + dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape, + sparse.values) with self.test_session() as sess: constant = sess.run(dense) self.assertAllEqual(expected_constant, constant) + class DropoutTest(test.TestCase): def testCreateDropout(self): @@ -1418,8 +1429,7 @@ class FlattenTest(test.TestCase): with ops.Graph().as_default() as g, self.test_session(g): inputs = array_ops.placeholder(dtype=dtypes.float32) inputs.set_shape(tensor_shape.TensorShape((5,))) - with self.assertRaisesRegexp(ValueError, - 'incompatible with the layer'): + with self.assertRaisesRegexp(ValueError, 'incompatible with the layer'): _layers.flatten(inputs) def testUnknownLastDim(self): @@ -1729,7 +1739,9 @@ class FCTest(test.TestCase): with arg_scope( [_layers.fully_connected], normalizer_fn=_layers.batch_norm, - normalizer_params={'decay': 0.9}): + normalizer_params={ + 'decay': 0.9 + }): net = _layers.fully_connected(images, 27) net = _layers.fully_connected(net, 27) self.assertEqual(len(variables.get_variables()), 8) @@ -1745,7 +1757,9 @@ class FCTest(test.TestCase): with arg_scope( [_layers.fully_connected], normalizer_fn=_layers.batch_norm, - normalizer_params={'decay': 0.9}): + normalizer_params={ + 'decay': 0.9 + }): net = _layers.fully_connected(images, 27, scope='fc1') net = _layers.fully_connected(net, 27, scope='fc1', reuse=True) self.assertEqual(len(variables.get_variables()), 4) @@ -1762,8 +1776,8 @@ class BatchNormTest(test.TestCase): def testBatchNormCenterFalse(self): a = array_ops.placeholder(dtype=dtypes.float32, shape=(10, 10, 10, 10)) # Test that center=False builds a valid graph. - _layers.batch_norm(a, center=False, data_format='NCHW', - zero_debias_moving_mean=True) + _layers.batch_norm( + a, center=False, data_format='NCHW', zero_debias_moving_mean=True) def testUnknownShape(self): with ops.Graph().as_default() as g, self.test_session(g): @@ -1800,8 +1814,8 @@ class BatchNormTest(test.TestCase): images = np.random.uniform(size=(5, height, width, 3)).astype( dtype.as_numpy_dtype) output = _layers.batch_norm(images, fused=fused) - expected_name = ('BatchNorm/FusedBatchNorm' if fused else - 'BatchNorm/batchnorm') + expected_name = ('BatchNorm/FusedBatchNorm' + if fused else 'BatchNorm/batchnorm') self.assertTrue(output.op.name.startswith(expected_name)) self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3]) self.assertEqual( @@ -2020,8 +2034,8 @@ class BatchNormTest(test.TestCase): expected_var = np.var(image_values, axis=axis) if fused: # Add Bessel's correction - expected_var, _ = self._addBesselsCorrection(batch_size * height * - width, expected_var) + expected_var, _ = self._addBesselsCorrection( + batch_size * height * width, expected_var) images = constant_op.constant( image_values, shape=image_shape, dtype=dtypes.float32) output = _layers.batch_norm( @@ -2182,7 +2196,7 @@ class BatchNormTest(test.TestCase): # After initialization moving_mean == 0 and moving_variance == 1. self.assertAllClose(mean, [0] * 3) self.assertAllClose(variance, [1] * 3) - # Simulate assigment from saver restore. + # Simulate assignment from saver restore. init_assigns = [ state_ops.assign(moving_mean, expected_mean), state_ops.assign(moving_variance, expected_var) @@ -2540,8 +2554,8 @@ class BatchNormTest(test.TestCase): expected_var = np.var(image_values, axis=axis) if fused: # Add Bessel's correction - expected_var, _ = self._addBesselsCorrection(batch_size * height * - width, expected_var) + expected_var, _ = self._addBesselsCorrection( + batch_size * height * width, expected_var) images = constant_op.constant( image_values, shape=image_shape, dtype=dtypes.float32) output = _layers.batch_norm( @@ -2571,8 +2585,9 @@ class BatchNormTest(test.TestCase): np_output, new_images_gradients = sess.run([output, images_gradients]) # The outputs should be close to 0.0 mean and 1.0 variance self.assertAllClose( - np.mean( - np_output, axis=axis), [0] * channels, rtol=0.001, atol=0.001) + np.mean(np_output, axis=axis), [0] * channels, + rtol=0.001, + atol=0.001) self.assertAllClose( np.var(np_output, axis=axis), [1] * channels, rtol=0.01, atol=0.01) # The gradients should change slowly while updating moving_mean. @@ -2600,14 +2615,14 @@ class BatchNormTest(test.TestCase): channels = 3 with self.test_session() as sess: images = (np.ones((5, height, width, channels)) * 9.0).astype('f') - beta = init_ops.constant_initializer((np.ones(channels) * 5.0).astype( - 'f')) - gamma = init_ops.constant_initializer((np.ones(channels) * 2.0).astype( - 'f')) - mean = init_ops.constant_initializer((np.ones(channels) * 5.0).astype( - 'f')) - variance = init_ops.constant_initializer((np.ones(channels) * 4.0).astype( - 'f')) + beta = init_ops.constant_initializer( + (np.ones(channels) * 5.0).astype('f')) + gamma = init_ops.constant_initializer( + (np.ones(channels) * 2.0).astype('f')) + mean = init_ops.constant_initializer( + (np.ones(channels) * 5.0).astype('f')) + variance = init_ops.constant_initializer( + (np.ones(channels) * 4.0).astype('f')) output = _layers.batch_norm( images, is_training=False, @@ -2628,21 +2643,18 @@ class BatchNormTest(test.TestCase): with self.test_session(use_gpu=True) as sess: images = np.arange(np.product(shape), dtype=np.float32).reshape(shape) beta = init_ops.constant_initializer( - np.arange( - 2, channels + 2, dtype=np.float32)) + np.arange(2, channels + 2, dtype=np.float32)) gamma = init_ops.constant_initializer( - np.arange( - 10, channels + 10, dtype=np.float32) * 2.0) + np.arange(10, channels + 10, dtype=np.float32) * 2.0) mean = init_ops.constant_initializer( - np.arange( - 3, channels + 3, dtype=np.float32) * 5.0) + np.arange(3, channels + 3, dtype=np.float32) * 5.0) variance = init_ops.constant_initializer( - np.arange( - 1, channels + 1, dtype=np.float32) * 4.0) + np.arange(1, channels + 1, dtype=np.float32) * 4.0) if data_format == 'NCHW': # Reshape inputs from NHWC to NCHW format. images = array_ops.transpose( - images, [0, len(shape) - 1] + list(range(1, len(shape) - 1))) + images, [0, len(shape) - 1] + list(range(1, + len(shape) - 1))) output = _layers.batch_norm( images, is_training=is_training, @@ -2745,16 +2757,16 @@ class BatchNormTest(test.TestCase): # Tests that the adjustment is appropriately passed to and used by the core # BN layer. all_adjustments = [] + def _create_adjustment(shape): adjustments = [array_ops.ones(shape[-1:]), array_ops.zeros(shape[-1:])] all_adjustments.extend(adjustments) return adjustments + depth = 8 images = array_ops.zeros([10, 5, 5, depth]) output = _layers.batch_norm( - images, - is_training=True, - adjustment=_create_adjustment) + images, is_training=True, adjustment=_create_adjustment) self.assertListEqual(output.shape.as_list(), images.shape.as_list()) self.assertEqual(len(all_adjustments), 2) self.assertListEqual(all_adjustments[0].shape.as_list(), [depth]) @@ -2819,7 +2831,10 @@ class LayerNormTest(test.TestCase): # output_train and output_eval should be the same. self.assertAllClose(sess.run([output_train]), sess.run([output_eval])) - def doOutputTest(self, input_shape, tol=1e-5, begin_norm_axis=1, + def doOutputTest(self, + input_shape, + tol=1e-5, + begin_norm_axis=1, dtype=dtypes.float64): expected_mean = np.zeros(input_shape[:begin_norm_axis]) expected_var = np.ones(input_shape[:begin_norm_axis]) @@ -2850,13 +2865,10 @@ class LayerNormTest(test.TestCase): # Layer-norm implemented in numpy eps = 1e-12 expected_out = ( - (gamma * ( - input_values - - np.mean(input_values, axis=moments_axis, keepdims=True)) - / np.sqrt( - eps - + np.var(input_values, axis=moments_axis, keepdims=True))) - + beta) + (gamma * (input_values - np.mean( + input_values, axis=moments_axis, keepdims=True)) / + np.sqrt(eps + np.var( + input_values, axis=moments_axis, keepdims=True))) + beta) self.assertAllClose(expected_mean, mean, atol=tol, rtol=tol) self.assertAllClose(expected_var, var, atol=tol) # The full computation gets a bigger tolerance @@ -2874,10 +2886,10 @@ class LayerNormTest(test.TestCase): def testOutput4DInputNormOnInnermostAxis(self): # Equivalent tests - self.doOutputTest((100, 10, 10, 3), begin_norm_axis=3, tol=1e-4, - dtype=dtypes.float64) - self.doOutputTest((100, 10, 10, 3), begin_norm_axis=-1, tol=1e-4, - dtype=dtypes.float64) + self.doOutputTest( + (100, 10, 10, 3), begin_norm_axis=3, tol=1e-4, dtype=dtypes.float64) + self.doOutputTest( + (100, 10, 10, 3), begin_norm_axis=-1, tol=1e-4, dtype=dtypes.float64) def testOutputSmallInput(self): self.doOutputTest((10, 10, 10, 30)) @@ -2914,7 +2926,7 @@ class GDNTest(test.TestCase): x = np.random.uniform(size=(1, 2, 3, 4)[:ndim]) y = self._runGDN(x, x.shape, False, 'channels_last') self.assertEqual(x.shape, y.shape) - self.assertAllClose(y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6) + self.assertAllClose(y, x / np.sqrt(1 + .1 * (x**2)), rtol=0, atol=1e-6) def testChannelsFirst(self): # `bias_add` doesn't support NCHW on CPU. @@ -2923,8 +2935,7 @@ class GDNTest(test.TestCase): x = np.random.uniform(size=(4, 3, 2, 1)[:ndim]) y = self._runGDN(x, x.shape, False, 'channels_first') self.assertEqual(x.shape, y.shape) - self.assertAllClose( - y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6) + self.assertAllClose(y, x / np.sqrt(1 + .1 * (x**2)), rtol=0, atol=1e-6) def testWrongDims(self): for ndim in [1, 2, 6]: @@ -2936,7 +2947,29 @@ class GDNTest(test.TestCase): x = np.random.uniform(size=(1, 2, 3, 4)) y = self._runGDN(x, x.shape, True, 'channels_last') self.assertEqual(x.shape, y.shape) - self.assertAllClose(y, x * np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6) + self.assertAllClose(y, x * np.sqrt(1 + .1 * (x**2)), rtol=0, atol=1e-6) + + +class ImagesToSequenceTest(test.TestCase): + + def testInvalidDataFormat(self): + height, width = 7, 11 + images = np.random.uniform(size=(5, height, width, 2)) + with self.assertRaisesRegexp(ValueError, + 'data_format has to be either NCHW or NHWC.'): + _layers.images_to_sequence(images, data_format='CHWN') + + def testImagesToSequenceDims(self): + height, width = 7, 11 + images = np.random.uniform(size=(2, height, width, 5)).astype(np.float32) + output = _layers.images_to_sequence(images) + self.assertListEqual(output.get_shape().as_list(), [11, 14, 5]) + + def testImagesToSequenceNCHW(self): + height, width = 7, 11 + images = np.random.uniform(size=(2, 5, height, width)).astype(np.float32) + output = _layers.images_to_sequence(images, data_format='NCHW') + self.assertListEqual(output.get_shape().as_list(), [11, 14, 5]) class MaxPool2DTest(test.TestCase): @@ -3013,20 +3046,22 @@ class MaxPool3DTest(test.TestCase): def testInvalidDataFormat(self): depth, height, width = 3, 6, 9 images = np.random.uniform(size=(5, depth, height, width, 3)) - with self.assertRaisesRegexp(ValueError, - 'data_format has to be either NCDHW or NDHWC.'): + with self.assertRaisesRegexp( + ValueError, 'data_format has to be either NCDHW or NDHWC.'): _layers.max_pool3d(images, [3, 3, 3], data_format='CDHWN') def testCreateMaxPool(self): depth, height, width = 3, 6, 9 - images = np.random.uniform(size=(5, depth, height, width, 3)).astype(np.float32) + images = np.random.uniform(size=(5, depth, height, width, 3)).astype( + np.float32) output = _layers.max_pool3d(images, [3, 3, 3]) self.assertEqual(output.op.name, 'MaxPool3D/MaxPool3D') self.assertListEqual(output.get_shape().as_list(), [5, 1, 2, 4, 3]) def testCreateMaxPoolNCDHW(self): depth, height, width = 3, 6, 9 - images = np.random.uniform(size=(5, 3, depth, height, width)).astype(np.float32) + images = np.random.uniform(size=(5, 3, depth, height, width)).astype( + np.float32) output = _layers.max_pool3d(images, [3, 3, 3], data_format='NCDHW') self.assertEquals(output.op.name, 'MaxPool3D/transpose_1') self.assertListEqual(output.get_shape().as_list(), [5, 3, 1, 2, 4]) @@ -3034,7 +3069,8 @@ class MaxPool3DTest(test.TestCase): def testCollectOutputs(self): depth, height, width = 3, 6, 9 images = random_ops.random_uniform((5, depth, height, width, 3), seed=1) - output = _layers.max_pool3d(images, [3, 3, 3], outputs_collections='outputs') + output = _layers.max_pool3d( + images, [3, 3, 3], outputs_collections='outputs') output_collected = ops.get_collection('outputs')[0] self.assertEqual(output_collected.aliases, ['MaxPool3D']) self.assertEqual(output_collected, output) @@ -3069,7 +3105,8 @@ class MaxPool3DTest(test.TestCase): depth, height, width = 3, 6, 9 images = random_ops.random_uniform((5, depth, height, width, 3), seed=1) output = _layers.max_pool3d(images, [3, 3, 3], stride=1, padding='SAME') - self.assertListEqual(output.get_shape().as_list(), [5, depth, height, width, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, depth, height, width, 3]) def testGlobalMaxPool(self): depth, height, width = 3, 6, 9 @@ -3403,6 +3440,33 @@ class ScaleGradientTests(test.TestCase): np.testing.assert_array_equal([3 * 2], g_x.eval()) +class SequenceToImagesTest(test.TestCase): + + def testImagesToSequenceDims(self): + num_batches = 14 + num_time_steps = 11 + num_channels = 5 + desired_height = 7 + sequence = np.random.uniform(size=(num_time_steps, + num_batches, + num_channels)).astype(np.float32) + output = _layers.sequence_to_images(sequence, desired_height) + self.assertListEqual(output.get_shape().as_list(), [2, 7, 11, 5]) + + def testImagesToSequenceNCHW(self): + num_batches = 14 + num_time_steps = 11 + num_channels = 5 + desired_height = 7 + sequence = np.random.uniform(size=(num_time_steps, + num_batches, + num_channels)).astype(np.float32) + output = _layers.sequence_to_images(sequence, + desired_height, + output_data_format='channels_first') + self.assertListEqual(output.get_shape().as_list(), [2, 5, 7, 11]) + + class SoftmaxTests(test.TestCase): def setUp(self): @@ -3481,8 +3545,7 @@ class SpatialSoftmaxTests(test.TestCase): sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} keypoints = sess.run(spatial_softmax, feed_dict) - self.assertAllEqual(keypoints.shape, - (batch_shape[0], batch_shape[3] * 2)) + self.assertAllEqual(keypoints.shape, (batch_shape[0], batch_shape[3] * 2)) def testSpatialSoftmaxShapeNCHW(self): batch_shape = (2, 2, 35, 35) @@ -3493,8 +3556,7 @@ class SpatialSoftmaxTests(test.TestCase): sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} keypoints = sess.run(spatial_softmax, feed_dict) - self.assertAllEqual(keypoints.shape, - (batch_shape[0], batch_shape[1] * 2)) + self.assertAllEqual(keypoints.shape, (batch_shape[0], batch_shape[1] * 2)) def testTwoMaxActivationsSameChannel(self): batch_size, height, width, nchannels = (2, 35, 35, 1) @@ -3513,8 +3575,8 @@ class SpatialSoftmaxTests(test.TestCase): x_loc = [avg_x] y_loc = [avg_y] - np_keypoints = self._SpatialSoftmax( - x_loc, y_loc, height, width, batch_size, nchannels) + np_keypoints = self._SpatialSoftmax(x_loc, y_loc, height, width, batch_size, + nchannels) # Make sure expected location keypoints matches actual location keypoints. with self.test_session() as sess: @@ -3532,13 +3594,13 @@ class SpatialSoftmaxTests(test.TestCase): spatial_softmax = _layers.spatial_softmax(features) np_features = np.zeros(batch_shape, dtype=np.float32) - edges = [(0, 0), (0, width-1), (height-1, 0), (height-1, width-1)] + edges = [(0, 0), (0, width - 1), (height - 1, 0), (height - 1, width - 1)] x_loc, y_loc = zip(*edges) for c in range(nchannels): np_features[:, x_loc[c], y_loc[c], c] = 100. - np_keypoints = self._SpatialSoftmax( - x_loc, y_loc, height, width, batch_size, nchannels) + np_keypoints = self._SpatialSoftmax(x_loc, y_loc, height, width, batch_size, + nchannels) # Make sure expected location keypoints matches actual location keypoints. with self.test_session() as sess: @@ -3567,10 +3629,10 @@ class SpatialSoftmaxTests(test.TestCase): np_features1[:, x_loc[c], y_loc[c], c] = 100. np_features2[:, x_loc[c], y_loc[c], c] = 100. - np_keypoints1 = self._SpatialSoftmax( - x_loc, y_loc, height1, width1, batch_size, nchannels) - np_keypoints2 = self._SpatialSoftmax( - x_loc, y_loc, height2, width2, batch_size, nchannels) + np_keypoints1 = self._SpatialSoftmax(x_loc, y_loc, height1, width1, + batch_size, nchannels) + np_keypoints2 = self._SpatialSoftmax(x_loc, y_loc, height2, width2, + batch_size, nchannels) # Make sure expected location keypoints matches actual location keypoints. with self.test_session() as sess: @@ -3596,8 +3658,8 @@ class SpatialSoftmaxTests(test.TestCase): for c in range(nchannels): np_features[:, x_loc[c], y_loc[c], c] = 100. - np_keypoints = self._SpatialSoftmax( - x_loc, y_loc, height, width, batch_size, nchannels) + np_keypoints = self._SpatialSoftmax(x_loc, y_loc, height, width, batch_size, + nchannels) # Make sure expected location keypoints matches actual location keypoints. with self.test_session() as sess: @@ -3619,8 +3681,8 @@ class SpatialSoftmaxTests(test.TestCase): for c in range(nchannels): np_features[:, c, x_loc[c], y_loc[c]] = 100. - np_keypoints = self._SpatialSoftmax( - x_loc, y_loc, height, width, batch_size, nchannels) + np_keypoints = self._SpatialSoftmax(x_loc, y_loc, height, width, batch_size, + nchannels) # Make sure expected location keypoints matches actual location keypoints. with self.test_session() as sess: @@ -3715,8 +3777,7 @@ class UnitNormTests(test.TestCase): image = random_ops.random_uniform((height, width, 3)) output = _layers.unit_norm(image, dim=dim, epsilon=1e-6) norms = math_ops.sqrt( - math_ops.reduce_sum( - math_ops.square(output), reduction_indices=dim)) + math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) shape = [height, width, 3] del shape[dim] @@ -3752,8 +3813,7 @@ class UnitNormTests(test.TestCase): image = array_ops.placeholder(dtypes.float32, (None, None, 3)) output = _layers.unit_norm(image, dim=dim, epsilon=1e-6) norms = math_ops.sqrt( - math_ops.reduce_sum( - math_ops.square(output), reduction_indices=dim)) + math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) with self.test_session(): actual = norms.eval({image: placeholder_value}) @@ -3817,8 +3877,8 @@ class PoincareNormalizeTest(test.TestCase): with self.test_session(): x_tf = constant_op.constant(x_np, name='x') y_tf = _layers.poincare_normalize(x_tf, dim) - err = gradient_checker.compute_gradient_error(x_tf, x_shape, - y_tf, x_shape) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) print('PoinCareNormalize gradient err = %g ' % err) self.assertLess(err, 1e-4) @@ -3830,14 +3890,9 @@ class LegacyFullyConnectedTest(test.TestCase): test.TestCase.setUp(self) random_seed.set_random_seed(1234) self.input = constant_op.constant([[1., 2., 3.], [-4., 15., -6.]]) - self.input_3_dim_arr = [[[1., 1.1, 1.2], - [2., 2.1, 2.2], - [3., 3.1, 3.2], - [4., 4.1, 4.2]], - [[5., 5.1, 5.2], - [6., 6.1, 6.2], - [7., 7.1, 7.2], - [8., 8.1, 8.2]]] + self.input_3_dim_arr = [[[1., 1.1, 1.2], [2., 2.1, 2.2], [3., 3.1, 3.2], + [4., 4.1, 4.2]], [[5., 5.1, 5.2], [6., 6.1, 6.2], + [7., 7.1, 7.2], [8., 8.1, 8.2]]] self.input_3_dim = constant_op.constant(self.input_3_dim_arr) assert not ops.get_collection(ops.GraphKeys.SUMMARIES) @@ -3932,15 +3987,10 @@ class LegacyFullyConnectedTest(test.TestCase): self._custom_initializers(self.input, 2, [[13.0, 13.0], [11.0, 11.0]]) def test_custom_initializers_multi_dim(self): - self._custom_initializers(self.input_3_dim, 2, - [[[7.6, 7.6], - [13.6, 13.6], - [19.6, 19.6], - [25.6, 25.6]], - [[31.6, 31.6], - [37.6, 37.6], - [43.6, 43.6], - [49.6, 49.6]]]) + self._custom_initializers( + self.input_3_dim, 2, + [[[7.6, 7.6], [13.6, 13.6], [19.6, 19.6], [25.6, 25.6]], + [[31.6, 31.6], [37.6, 37.6], [43.6, 43.6], [49.6, 49.6]]]) def test_custom_collections(self): layers_lib.legacy_relu( @@ -4050,12 +4100,16 @@ class LegacyFullyConnectedTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() # we can feed in input with first dimension 2 - shape_value = sess.run(array_ops.shape(y), - feed_dict={x: self.input_3_dim_arr}) + shape_value = sess.run( + array_ops.shape(y), feed_dict={ + x: self.input_3_dim_arr + }) self.assertAllClose(shape_value, [2, 4, 1]) # we can feed in input with first dimension 1 - shape_value = sess.run(array_ops.shape(y), - feed_dict={x: [self.input_3_dim_arr[0]]}) + shape_value = sess.run( + array_ops.shape(y), feed_dict={ + x: [self.input_3_dim_arr[0]] + }) self.assertAllClose(shape_value, [1, 4, 1]) # we cannot feed in input with inconsistent dimensions with self.assertRaises(ValueError): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 3c782b54a8559a6aac19d12ea11a9c76bffdb9c3..abf6e393bb0fbbce4e43f6d209e9b30517df36c3 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -388,6 +388,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lookup_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/learn/python/learn/datasets/__init__.py b/tensorflow/contrib/learn/python/learn/datasets/__init__.py index a3521b4109ab40d8478f20afc317cf5154da2b43..7240b0de149051afa045a8113f9e9b212840c311 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/__init__.py +++ b/tensorflow/contrib/learn/python/learn/datasets/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Dataset utilities and synthetic/reference datasets.""" from __future__ import absolute_import @@ -46,11 +45,12 @@ DATASETS = { # List of all synthetic datasets SYNTHETIC = { - # All of these will return ['data', 'target'] -> base.Dataset - 'circles': synthetic.circles, - 'spirals': synthetic.spirals + # All of these will return ['data', 'target'] -> base.Dataset + 'circles': synthetic.circles, + 'spirals': synthetic.spirals } + def load_dataset(name, size='small', test_with_fake_data=False): """Loads dataset by name. @@ -83,23 +83,28 @@ def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs): seed: int or None, seed for noise Returns: - Shuffled features and labels for given synthetic dataset of type `base.Dataset` + Shuffled features and labels for given synthetic dataset of type + `base.Dataset` Raises: ValueError: Raised if `name` not found Note: - - This is a generic synthetic data generator - individual generators might have more parameters! + - This is a generic synthetic data generator - individual generators might + have more parameters! See documentation for individual parameters - - Note that the `noise` parameter uses `numpy.random.normal` and depends on `numpy`'s seed + - Note that the `noise` parameter uses `numpy.random.normal` and depends on + `numpy`'s seed TODO: - Support multiclass datasets - - Need shuffling routine. Currently synthetic datasets are reshuffled to avoid train/test correlation, + - Need shuffling routine. Currently synthetic datasets are reshuffled to + avoid train/test correlation, but that hurts reprodusability """ # seed = kwargs.pop('seed', None) if name not in SYNTHETIC: raise ValueError('Synthetic dataset not found or not implemeted: %s' % name) else: - return SYNTHETIC[name](n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs) + return SYNTHETIC[name]( + n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs) diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py index 71978d439449e29c7cb907b18bab5d6659a972b6..ca720ae5ed26e74da12bd6c5a37231b41442f76f 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/base.py +++ b/tensorflow/contrib/learn/python/learn/datasets/base.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Base utilities for loading datasets.""" from __future__ import absolute_import @@ -24,13 +23,11 @@ import csv import os from os import path import random -import tempfile import time import numpy as np from six.moves import urllib -from tensorflow.contrib.framework import deprecated from tensorflow.python.platform import gfile Dataset = collections.namedtuple('Dataset', ['data', 'target']) @@ -100,9 +97,7 @@ def load_iris(data_path=None): module_path = path.dirname(__file__) data_path = path.join(module_path, 'data', 'iris.csv') return load_csv_with_header( - data_path, - target_dtype=np.int, - features_dtype=np.float) + data_path, target_dtype=np.int, features_dtype=np.float) def load_boston(data_path=None): @@ -118,16 +113,10 @@ def load_boston(data_path=None): module_path = path.dirname(__file__) data_path = path.join(module_path, 'data', 'boston_house_prices.csv') return load_csv_with_header( - data_path, - target_dtype=np.float, - features_dtype=np.float) + data_path, target_dtype=np.float, features_dtype=np.float) -def retry(initial_delay, - max_delay, - factor=2.0, - jitter=0.25, - is_retriable=None): +def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): """Simple decorator for wrapping retriable functions. Args: @@ -152,7 +141,7 @@ def retry(initial_delay, def delays(): delay = initial_delay while delay <= max_delay: - yield delay * random.uniform(1 - jitter, 1 + jitter) + yield delay * random.uniform(1 - jitter, 1 + jitter) delay *= factor def wrap(fn): @@ -172,7 +161,9 @@ def retry(initial_delay, else: raise return fn(*args, **kwargs) + return wrapped_fn + return wrap diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py index 1f3295747e141760445b021bf4f59cc47b88b8b2..37f9175015a239f763c7721cf36ab8063c0a3e32 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py +++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Functions for downloading and reading MNIST data.""" from __future__ import absolute_import @@ -123,8 +122,8 @@ class DataSet(object): numpy.random.seed(seed1 if seed is None else seed2) dtype = dtypes.as_dtype(dtype).base_dtype if dtype not in (dtypes.uint8, dtypes.float32): - raise TypeError('Invalid image dtype %r, expected uint8 or float32' % - dtype) + raise TypeError( + 'Invalid image dtype %r, expected uint8 or float32' % dtype) if fake_data: self._num_examples = 10000 self.one_hot = one_hot @@ -202,7 +201,9 @@ class DataSet(object): end = self._index_in_epoch images_new_part = self._images[start:end] labels_new_part = self._labels[start:end] - return numpy.concatenate((images_rest_part, images_new_part), axis=0) , numpy.concatenate((labels_rest_part, labels_new_part), axis=0) + return numpy.concatenate( + (images_rest_part, images_new_part), axis=0), numpy.concatenate( + (labels_rest_part, labels_new_part), axis=0) else: self._index_in_epoch += batch_size end = self._index_in_epoch @@ -257,16 +258,14 @@ def read_data_sets(train_dir, test_labels = extract_labels(f, one_hot=one_hot) if not 0 <= validation_size <= len(train_images): - raise ValueError( - 'Validation size should be between 0 and {}. Received: {}.' - .format(len(train_images), validation_size)) + raise ValueError('Validation size should be between 0 and {}. Received: {}.' + .format(len(train_images), validation_size)) validation_images = train_images[:validation_size] validation_labels = train_labels[:validation_size] train_images = train_images[validation_size:] train_labels = train_labels[validation_size:] - options = dict(dtype=dtype, reshape=reshape, seed=seed) train = DataSet(train_images, train_labels, **options) diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py index 907dc0f3dfced7e55c5f46711fbe93f6400e1de7..9a843168c27d9cae3f55efe4fe4c688d86c745f3 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py +++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Synthetic dataset generators.""" from __future__ import absolute_import @@ -23,18 +22,27 @@ import numpy as np from tensorflow.contrib.learn.python.learn.datasets.base import Dataset -def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args, **kwargs): + +def circles(n_samples=100, + noise=None, + seed=None, + factor=0.8, + n_classes=2, + *args, + **kwargs): """Create circles separated by some value Args: n_samples: int, number of datapoints to generate noise: float or None, standard deviation of the Gaussian noise added seed: int or None, seed for the noise - factor: float, size factor of the inner circles with respect to the outer ones + factor: float, size factor of the inner circles with respect to the outer + ones n_classes: int, number of classes to generate Returns: - Shuffled features and labels for 'circles' synthetic dataset of type `base.Dataset` + Shuffled features and labels for 'circles' synthetic dataset of type + `base.Dataset` Note: The multi-class support might not work as expected if `noise` is enabled @@ -54,7 +62,7 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args if seed is not None: np.random.seed(seed) # Algo: 1) Generate initial circle, 2) For ever class generate a smaller radius circle - linspace = np.linspace(0, 2*np.pi, n_samples // n_classes) + linspace = np.linspace(0, 2 * np.pi, n_samples // n_classes) circ_x = np.empty(0, dtype=np.int32) circ_y = np.empty(0, dtype=np.int32) base_cos = np.cos(linspace) @@ -66,12 +74,12 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args circ_y = np.append(circ_y, base_sin) base_cos *= factor base_sin *= factor - y = np.append(y, label*np.ones(n_samples // n_classes, dtype=np.int32)) + y = np.append(y, label * np.ones(n_samples // n_classes, dtype=np.int32)) # Add more points if n_samples is not divisible by n_classes (unbalanced!) extras = n_samples % n_classes - circ_x = np.append(circ_x, np.cos(np.random.rand(extras)*2*np.pi)) - circ_y = np.append(circ_y, np.sin(np.random.rand(extras)*2*np.pi)) + circ_x = np.append(circ_x, np.cos(np.random.rand(extras) * 2 * np.pi)) + circ_y = np.append(circ_y, np.sin(np.random.rand(extras) * 2 * np.pi)) y = np.append(y, np.zeros(extras, dtype=np.int32)) # Reshape the features/labels @@ -85,10 +93,13 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args return Dataset(data=X[indices], target=y[indices]) -def spirals(n_samples=100, noise=None, seed=None, - mode = 'archimedes', - n_loops = 2, - *args, **kwargs): +def spirals(n_samples=100, + noise=None, + seed=None, + mode='archimedes', + n_loops=2, + *args, + **kwargs): """Create spirals Currently only binary classification is supported for spiral generation @@ -104,7 +115,8 @@ def spirals(n_samples=100, noise=None, seed=None, 'fermat': a spiral with branch distances decreasing (sqrt) Returns: - Shuffled features and labels for 'spirals' synthetic dataset of type `base.Dataset` + Shuffled features and labels for 'spirals' synthetic dataset of type + `base.Dataset` Raises: ValueError: If the generation `mode` is not valid @@ -112,34 +124,35 @@ def spirals(n_samples=100, noise=None, seed=None, TODO: - Generation of unbalanced data """ - n_classes = 2 # I am not sure how to make it multiclass + n_classes = 2 # I am not sure how to make it multiclass _modes = { - 'archimedes': _archimedes_spiral, - 'bernoulli': _bernoulli_spiral, - 'fermat': _fermat_spiral + 'archimedes': _archimedes_spiral, + 'bernoulli': _bernoulli_spiral, + 'fermat': _fermat_spiral } if mode is None or mode not in _modes: - raise ValueError("Cannot generate spiral with mode %s"%mode) + raise ValueError('Cannot generate spiral with mode %s' % mode) if seed is not None: np.random.seed(seed) - linspace = np.linspace(0, 2*n_loops*np.pi, n_samples // n_classes) + linspace = np.linspace(0, 2 * n_loops * np.pi, n_samples // n_classes) spir_x = np.empty(0, dtype=np.int32) spir_y = np.empty(0, dtype=np.int32) y = np.empty(0, dtype=np.int32) for label in range(n_classes): - base_cos, base_sin = _modes[mode](linspace, label*np.pi, *args, **kwargs) + base_cos, base_sin = _modes[mode](linspace, label * np.pi, *args, **kwargs) spir_x = np.append(spir_x, base_cos) spir_y = np.append(spir_y, base_sin) - y = np.append(y, label*np.ones(n_samples // n_classes, dtype=np.int32)) + y = np.append(y, label * np.ones(n_samples // n_classes, dtype=np.int32)) # Add more points if n_samples is not divisible by n_classes (unbalanced!) extras = n_samples % n_classes if extras > 0: - x_exrta, y_extra = _modes[mode](np.random.rand(extras)*2*np.pi, *args, **kwargs) + x_extra, y_extra = _modes[mode](np.random.rand(extras) * 2 * np.pi, *args, + **kwargs) spir_x = np.append(spir_x, x_extra) spir_y = np.append(spir_y, y_extra) y = np.append(y, np.zeros(extras, dtype=np.int32)) @@ -162,7 +175,8 @@ def _archimedes_spiral(theta, theta_offset=0., *args, **kwargs): theta: array-like, angles from polar coordinates to be converted theta_offset: float, angle offset in radians (2*pi = 0) """ - x, y = theta*np.cos(theta + theta_offset), theta*np.sin(theta + theta_offset) + x, y = theta * np.cos(theta + theta_offset), theta * np.sin( + theta + theta_offset) x_norm = np.max(np.abs(x)) y_norm = np.max(np.abs(y)) x, y = x / x_norm, y / y_norm @@ -181,7 +195,8 @@ def _bernoulli_spiral(theta, theta_offset=0., *args, **kwargs): """ exp_scale = kwargs.pop('exp_scale', 0.1) - x, y = np.exp(exp_scale*theta)*np.cos(theta + theta_offset), np.exp(exp_scale*theta)*np.sin(theta + theta_offset) + x, y = np.exp(exp_scale * theta) * np.cos(theta + theta_offset), np.exp( + exp_scale * theta) * np.sin(theta + theta_offset) x_norm = np.max(np.abs(x)) y_norm = np.max(np.abs(y)) x, y = x / x_norm, y / y_norm @@ -195,7 +210,8 @@ def _fermat_spiral(theta, theta_offset=0., *args, **kwargs): theta: array-like, angles from polar coordinates to be converted theta_offset: float, angle offset in radians (2*pi = 0) """ - x, y = np.sqrt(theta)*np.cos(theta + theta_offset), np.sqrt(theta)*np.sin(theta + theta_offset) + x, y = np.sqrt(theta) * np.cos(theta + theta_offset), np.sqrt(theta) * np.sin( + theta + theta_offset) x_norm = np.max(np.abs(x)) y_norm = np.max(np.abs(y)) x, y = x / x_norm, y / y_norm diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py index 5340afab46eba957d6d612bb583983b627537547..5809995c8c7d8e72eb47ee88a72547bae7fd3594 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py +++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py @@ -24,12 +24,14 @@ from tensorflow.python.platform import test from tensorflow.contrib.learn.python.learn import datasets from tensorflow.contrib.learn.python.learn.datasets import synthetic + class SyntheticTest(test.TestCase): """Test synthetic dataset generation""" def test_make_dataset(self): """Test if the synthetic routine wrapper complains about the name""" - self.assertRaises(ValueError, datasets.make_dataset, name='_non_existing_name') + self.assertRaises( + ValueError, datasets.make_dataset, name='_non_existing_name') def test_all_datasets_callable(self): """Test if all methods inside the `SYNTHETIC` are callable""" @@ -52,9 +54,10 @@ class SyntheticTest(test.TestCase): """ n_samples = 100 n_classes = 2 - circ = synthetic.circles(n_samples = n_samples, noise = None, n_classes = n_classes) + circ = synthetic.circles( + n_samples=n_samples, noise=None, n_classes=n_classes) self.assertIsInstance(circ, datasets.base.Dataset) - self.assertTupleEqual(circ.data.shape, (n_samples,2)) + self.assertTupleEqual(circ.data.shape, (n_samples, 2)) self.assertTupleEqual(circ.target.shape, (n_samples,)) self.assertSetEqual(set(circ.target), set(range(n_classes))) @@ -67,17 +70,24 @@ class SyntheticTest(test.TestCase): """ seed = 42 noise = 0.1 - circ0 = synthetic.circles(n_samples = 100, noise = noise, n_classes = 2, seed = seed) - circ1 = synthetic.circles(n_samples = 100, noise = noise, n_classes = 2, seed = seed) + circ0 = synthetic.circles( + n_samples=100, noise=noise, n_classes=2, seed=seed) + circ1 = synthetic.circles( + n_samples=100, noise=noise, n_classes=2, seed=seed) np.testing.assert_array_equal(circ0.data, circ1.data) np.testing.assert_array_equal(circ0.target, circ1.target) - circ1 = synthetic.circles(n_samples = 100, noise = noise, n_classes = 2, seed = seed+1) - self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data, circ1.data) - self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.target, circ1.target) + circ1 = synthetic.circles( + n_samples=100, noise=noise, n_classes=2, seed=seed + 1) + self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data, + circ1.data) + self.assertRaises(AssertionError, np.testing.assert_array_equal, + circ0.target, circ1.target) - circ1 = synthetic.circles(n_samples = 100, noise = noise/2., n_classes = 2, seed = seed) - self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data, circ1.data) + circ1 = synthetic.circles( + n_samples=100, noise=noise / 2., n_classes=2, seed=seed) + self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data, + circ1.data) def test_spirals(self): """Test if the circles are generated correctly @@ -89,13 +99,14 @@ class SyntheticTest(test.TestCase): - returned `target` shape is (n_samples,) - set of unique classes range is [0, n_classes) """ - self.assertRaises(ValueError, synthetic.spirals, mode='_unknown_mode_spiral_') + self.assertRaises( + ValueError, synthetic.spirals, mode='_unknown_mode_spiral_') n_samples = 100 modes = ('archimedes', 'bernoulli', 'fermat') for mode in modes: - spir = synthetic.spirals(n_samples = n_samples, noise = None, mode = mode) + spir = synthetic.spirals(n_samples=n_samples, noise=None, mode=mode) self.assertIsInstance(spir, datasets.base.Dataset) - self.assertTupleEqual(spir.data.shape, (n_samples,2)) + self.assertTupleEqual(spir.data.shape, (n_samples, 2)) self.assertTupleEqual(spir.target.shape, (n_samples,)) self.assertSetEqual(set(spir.target), set(range(2))) @@ -110,18 +121,24 @@ class SyntheticTest(test.TestCase): noise = 0.1 modes = ('archimedes', 'bernoulli', 'fermat') for mode in modes: - spir0 = synthetic.spirals(n_samples = 1000, noise = noise, seed = seed) - spir1 = synthetic.spirals(n_samples = 1000, noise = noise, seed = seed) + spir0 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed) + spir1 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed) np.testing.assert_array_equal(spir0.data, spir1.data) np.testing.assert_array_equal(spir0.target, spir1.target) - spir1 = synthetic.spirals(n_samples = 1000, noise = noise, seed = seed+1) - self.assertRaises(AssertionError, np.testing.assert_array_equal, spir0.data, spir1.data) - self.assertRaises(AssertionError, np.testing.assert_array_equal, spir0.target, spir1.target) + spir1 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed + 1) + self.assertRaises(AssertionError, np.testing.assert_array_equal, + spir0.data, spir1.data) + self.assertRaises(AssertionError, np.testing.assert_array_equal, + spir0.target, spir1.target) + + spir1 = synthetic.spirals(n_samples=1000, noise=noise / 2., seed=seed) + self.assertRaises(AssertionError, np.testing.assert_array_equal, + spir0.data, spir1.data) - spir1 = synthetic.spirals(n_samples = 1000, noise = noise/2., seed = seed) - self.assertRaises(AssertionError, np.testing.assert_array_equal, spir0.data, spir1.data) + def test_spirals_synthetic(self): + synthetic.spirals(3) -if __name__ == "__main__": +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/debug_test.py b/tensorflow/contrib/learn/python/learn/estimators/debug_test.py index 6b125534a42c5cdde69773d99cefd6e7b2d60c9c..b968aeed1b7a11d522b531783f04f0104b37904f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/debug_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/debug_test.py @@ -44,7 +44,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.training import input as input_lib - NUM_EXAMPLES = 100 N_CLASSES = 5 # Cardinality of multiclass labels. LABEL_DIMENSION = 3 # Dimensionality of regression labels. @@ -52,8 +51,10 @@ LABEL_DIMENSION = 3 # Dimensionality of regression labels. def _train_test_split(features_and_labels): features, labels = features_and_labels - train_set = (features[:int(len(features) / 2)], labels[:int(len(features) / 2)]) - test_set = (features[int(len(features) / 2):], labels[int(len(features) / 2):]) + train_set = (features[:int(len(features) / 2)], + labels[:int(len(features) / 2)]) + test_set = (features[int(len(features) / 2):], + labels[int(len(features) / 2):]) return train_set, test_set @@ -86,17 +87,17 @@ class DebugClassifierTest(test.TestCase): (train_features, train_labels), (test_features, test_labels) = _train_test_split( [self.features, self.labels]) - majority_class, _ = max(collections.Counter(train_labels).items(), - key=operator.itemgetter(1)) + majority_class, _ = max( + collections.Counter(train_labels).items(), key=operator.itemgetter(1)) expected_prediction = np.vstack( [[majority_class] for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=N_CLASSES) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) - pred = classifier.predict_classes(input_fn=_input_fn_builder(test_features, - None)) + pred = classifier.predict_classes( + input_fn=_input_fn_builder(test_features, None)) self.assertAllEqual(expected_prediction, np.vstack(pred)) def testPredictBinary(self): @@ -105,34 +106,34 @@ class DebugClassifierTest(test.TestCase): test_labels) = _train_test_split( [self.features, self.binary_labels]) - majority_class, _ = max(collections.Counter(train_labels).items(), - key=operator.itemgetter(1)) + majority_class, _ = max( + collections.Counter(train_labels).items(), key=operator.itemgetter(1)) expected_prediction = np.vstack( [[majority_class] for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=2) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) - pred = classifier.predict_classes(input_fn=_input_fn_builder(test_features, - None)) + pred = classifier.predict_classes( + input_fn=_input_fn_builder(test_features, None)) self.assertAllEqual(expected_prediction, np.vstack(pred)) - (train_features, train_labels), ( - test_features, test_labels) = _train_test_split( - [self.features, self.binary_float_labels]) + (train_features, + train_labels), (test_features, test_labels) = _train_test_split( + [self.features, self.binary_float_labels]) - majority_class, _ = max(collections.Counter(train_labels).items(), - key=operator.itemgetter(1)) + majority_class, _ = max( + collections.Counter(train_labels).items(), key=operator.itemgetter(1)) expected_prediction = np.vstack( [[majority_class] for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=2) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) - pred = classifier.predict_classes(input_fn=_input_fn_builder(test_features, - None)) + pred = classifier.predict_classes( + input_fn=_input_fn_builder(test_features, None)) self.assertAllEqual(expected_prediction, np.vstack(pred)) def testPredictProba(self): @@ -150,8 +151,8 @@ class DebugClassifierTest(test.TestCase): [class_distribution for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=N_CLASSES) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) pred = classifier.predict_proba( input_fn=_input_fn_builder(test_features, None)) @@ -173,17 +174,17 @@ class DebugClassifierTest(test.TestCase): [class_distribution for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=2) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) pred = classifier.predict_proba( input_fn=_input_fn_builder(test_features, None)) self.assertAllClose(expected_prediction, np.vstack(pred), atol=0.1) - (train_features, train_labels), ( - test_features, test_labels) = _train_test_split( - [self.features, self.binary_float_labels]) + (train_features, + train_labels), (test_features, test_labels) = _train_test_split( + [self.features, self.binary_float_labels]) class_distribution = np.zeros((1, 2)) for label in train_labels: @@ -194,8 +195,8 @@ class DebugClassifierTest(test.TestCase): [class_distribution for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=2) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) pred = classifier.predict_proba( input_fn=_input_fn_builder(test_features, None)) @@ -232,13 +233,12 @@ class DebugClassifierTest(test.TestCase): def _input_fn(): iris = test_data.prepare_iris_data_for_logistic_regression() return { - 'feature': constant_op.constant( - iris.data, dtype=dtypes.float32) + 'feature': constant_op.constant(iris.data, dtype=dtypes.float32) }, constant_op.constant( iris.target, shape=[100], dtype=dtypes.int32) - classifier = debug.DebugClassifier(config=run_config.RunConfig( - tf_random_seed=1)) + classifier = debug.DebugClassifier( + config=run_config.RunConfig(tf_random_seed=1)) classifier.fit(input_fn=_input_fn, steps=5) scores = classifier.evaluate(input_fn=_input_fn, steps=1) self.assertIn('loss', scores) @@ -342,8 +342,7 @@ class DebugClassifierTest(test.TestCase): def _input_fn(): iris = base.load_iris() return { - 'feature': constant_op.constant( - iris.data, dtype=dtypes.float32) + 'feature': constant_op.constant(iris.data, dtype=dtypes.float32) }, constant_op.constant( iris.target, shape=[150], dtype=dtypes.int32) @@ -387,7 +386,9 @@ class DebugClassifierTest(test.TestCase): # Create 4 rows, one of them (y = x), three of them (y=Not(x)) # The logistic prediction should be (y = 0.25). labels = constant_op.constant([[1], [0], [0], [0]]) - features = {'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32),} + features = { + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), + } return features, labels classifier = debug.DebugClassifier(n_classes=2) @@ -404,8 +405,7 @@ class DebugClassifierTest(test.TestCase): # The logistic prediction should be (y = 0.25). labels = constant_op.constant([[1.], [0.], [0.], [0.]]) features = { - 'x': array_ops.ones( - shape=[4, 1], dtype=dtypes.float32), + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), 'w': constant_op.constant([[1.], [1.], [1.], [1.]]) } return features, labels @@ -414,8 +414,7 @@ class DebugClassifierTest(test.TestCase): # 4 rows, with different weights. labels = constant_op.constant([[1.], [0.], [0.], [0.]]) features = { - 'x': array_ops.ones( - shape=[4, 1], dtype=dtypes.float32), + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), 'w': constant_op.constant([[7.], [1.], [1.], [1.]]) } return features, labels @@ -438,8 +437,7 @@ class DebugClassifierTest(test.TestCase): # than (y=Not(x)) due to the relative higher weight of the first row. labels = constant_op.constant([[1], [0], [0], [0]]) features = { - 'x': array_ops.ones( - shape=[4, 1], dtype=dtypes.float32), + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), 'w': constant_op.constant([[100.], [3.], [2.], [2.]]) } return features, labels @@ -448,8 +446,7 @@ class DebugClassifierTest(test.TestCase): # Create 4 rows (y = x) labels = constant_op.constant([[1], [1], [1], [1]]) features = { - 'x': array_ops.ones( - shape=[4, 1], dtype=dtypes.float32), + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), 'w': constant_op.constant([[1.], [1.], [1.], [1.]]) } return features, labels @@ -469,8 +466,7 @@ class DebugClassifierTest(test.TestCase): features = { 'x': input_lib.limit_epochs( - array_ops.ones( - shape=[4, 1], dtype=dtypes.float32), + array_ops.ones(shape=[4, 1], dtype=dtypes.float32), num_epochs=num_epochs), } return features, labels @@ -578,12 +574,11 @@ class DebugClassifierTest(test.TestCase): language = feature_column.sparse_column_with_hash_bucket('language', 100) feature_columns = [ feature_column.real_valued_column('age'), - feature_column.embedding_column( - language, dimension=1) + feature_column.embedding_column(language, dimension=1) ] - classifier = debug.DebugClassifier(config=run_config.RunConfig( - tf_random_seed=1)) + classifier = debug.DebugClassifier( + config=run_config.RunConfig(tf_random_seed=1)) classifier.fit(input_fn=input_fn, steps=5) def default_input_fn(unused_estimator, examples): @@ -614,8 +609,8 @@ class DebugRegressorTest(test.TestCase): classifier.fit( input_fn=_input_fn_builder(train_features, train_labels), steps=50) - pred = classifier.predict_scores(input_fn=_input_fn_builder(test_features, - None)) + pred = classifier.predict_scores( + input_fn=_input_fn_builder(test_features, None)) self.assertAllClose(expected_prediction, np.vstack(pred), atol=0.1) def testExperimentIntegration(self): @@ -698,7 +693,9 @@ class DebugRegressorTest(test.TestCase): # Create 4 rows, one of them (y = x), three of them (y=Not(x)) # The algorithm should learn (y = 0.25). labels = constant_op.constant([[1.], [0.], [0.], [0.]]) - features = {'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32),} + features = { + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), + } return features, labels regressor = debug.DebugRegressor( @@ -853,5 +850,6 @@ class DebugRegressorTest(test.TestCase): predictions2 = list(regressor2.predict_scores(input_fn=predict_input_fn)) self.assertAllClose(predictions, predictions2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 12f9bba531a296a00d17956b8ce32e5d7dead380..2bd57597c2e9444b51b1dacfbe4180b443c95a3d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -1224,7 +1224,7 @@ class DNNRegressorTest(test.TestCase): self, predictions, expected_shape): predictions_nparray = np.array(predictions) self.assertAllEqual(expected_shape, predictions_nparray.shape) - self.assertTrue(np.issubdtype(predictions_nparray.dtype, np.float)) + self.assertTrue(np.issubdtype(predictions_nparray.dtype, np.floating)) def testPredict_AsIterableFalse(self): """Tests predict method with as_iterable=False.""" diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 50c74add86fcf62c738e81426bfaf842fbac2b4e..4b63e08ab3372849309ee5d28d754de82e9632f4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Base Estimator class.""" from __future__ import absolute_import @@ -76,7 +75,6 @@ from tensorflow.python.util import compat from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect - AS_ITERABLE_DATE = '2016-09-15' AS_ITERABLE_INSTRUCTIONS = ( 'The default behavior of predict() is changing. The default value for\n' @@ -213,7 +211,7 @@ def _get_replica_device_setter(config): 'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable', 'MutableHashTableV2', 'MutableHashTableOfTensors', 'MutableHashTableOfTensorsV2', 'MutableDenseHashTable', - 'MutableDenseHashTableV2' + 'MutableDenseHashTableV2', 'VarHandleOp' ] if config.task_type: @@ -223,8 +221,11 @@ def _get_replica_device_setter(config): if config.num_ps_replicas > 0: return device_setter.replica_device_setter( - ps_tasks=config.num_ps_replicas, worker_device=worker_device, - merge_devices=True, ps_ops=ps_ops, cluster=config.cluster_spec) + ps_tasks=config.num_ps_replicas, + worker_device=worker_device, + merge_devices=True, + ps_ops=ps_ops, + cluster=config.cluster_spec) else: return None @@ -284,10 +285,10 @@ def _make_metrics_ops(metrics, features, labels, predictions): raise ValueError('Invalid metric for {}. It returned a tuple with ' 'len {}, expected 2.'.format(name, len(name))) if not isinstance(predictions, dict): - raise ValueError( - 'Metrics passed provide (name, prediction), ' - 'but predictions are not dict. ' - 'Metrics: %s, Predictions: %s.' % (metrics, predictions)) + raise ValueError('Metrics passed provide (name, prediction), ' + 'but predictions are not dict. ' + 'Metrics: %s, Predictions: %s.' % (metrics, + predictions)) # Here are two options: labels are single Tensor or a dict. if isinstance(labels, dict) and name[1] in labels: # If labels are dict and the prediction name is in it, apply metric. @@ -298,10 +299,10 @@ def _make_metrics_ops(metrics, features, labels, predictions): else: # Single head metrics. if isinstance(predictions, dict): - raise ValueError( - 'Metrics passed provide only name, no prediction, ' - 'but predictions are dict. ' - 'Metrics: %s, Labels: %s.' % (metrics, labels_tensor_or_dict)) + raise ValueError('Metrics passed provide only name, no prediction, ' + 'but predictions are dict. ' + 'Metrics: %s, Labels: %s.' % (metrics, + labels_tensor_or_dict)) result[name] = metric(predictions, labels_tensor_or_dict) return result @@ -369,9 +370,8 @@ def _write_dict_to_summary(output_dir, dictionary, current_global_step): logging.info( 'Summary for np.ndarray is not visible in Tensorboard by default. ' 'Consider using a Tensorboard plugin for visualization (see ' - 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md ' # pylint:disable=line-too-long - 'for more information).' - ) + 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md' + ' for more information).') else: logging.warn( 'Skipping summary for %s, must be a float, np.float32, np.int64, ' @@ -385,8 +385,8 @@ GraphRewriteSpec = collections.namedtuple('GraphRewriteSpec', ['tags', 'transforms']) -class BaseEstimator( - sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable): +class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, + trainable.Trainable): """Abstract BaseEstimator class to train and evaluate TensorFlow models. Users should not instantiate or subclass this class. Instead, use an @@ -428,7 +428,7 @@ class BaseEstimator( # necessary. # pylint: disable=g-doc-exception raise ValueError( - "model_dir are set both in constructor and RunConfig, but with " + 'model_dir are set both in constructor and RunConfig, but with ' "different values. In constructor: '{}', in RunConfig: " "'{}' ".format(model_dir, self._config.model_dir)) # pylint: enable=g-doc-exception @@ -457,12 +457,16 @@ class BaseEstimator( # TODO(wicke): make RunConfig immutable, and then return it without a copy. return copy.deepcopy(self._config) - @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), - ('y', None), ('batch_size', None) - ) - def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, - monitors=None, max_steps=None): + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, + ('x', None), ('y', None), ('batch_size', None)) + def fit(self, + x=None, + y=None, + input_fn=None, + steps=None, + batch_size=None, + monitors=None, + max_steps=None): # pylint: disable=g-doc-args,g-doc-return-or-yield """See `Trainable`. @@ -494,13 +498,15 @@ class BaseEstimator( logging.info('Loss for final step: %s.', loss) return self - @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), - ('y', None), ('batch_size', None) - ) - def partial_fit( - self, x=None, y=None, input_fn=None, steps=1, batch_size=None, - monitors=None): + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, + ('x', None), ('y', None), ('batch_size', None)) + def partial_fit(self, + x=None, + y=None, + input_fn=None, + steps=1, + batch_size=None, + monitors=None): """Incremental fit on a batch of samples. This method is expected to be called several times consecutively @@ -536,13 +542,16 @@ class BaseEstimator( """ logging.warning('The current implementation of partial_fit is not optimized' ' for use in a loop. Consider using fit() instead.') - return self.fit(x=x, y=y, input_fn=input_fn, steps=steps, - batch_size=batch_size, monitors=monitors) + return self.fit( + x=x, + y=y, + input_fn=input_fn, + steps=steps, + batch_size=batch_size, + monitors=monitors) - @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), - ('y', None), ('batch_size', None) - ) + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, + ('x', None), ('y', None), ('batch_size', None)) def evaluate(self, x=None, y=None, @@ -584,13 +593,15 @@ class BaseEstimator( eval_results.update({'global_step': global_step}) return eval_results - @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), - ('batch_size', None), ('as_iterable', True) - ) - def predict( - self, x=None, input_fn=None, batch_size=None, outputs=None, - as_iterable=True): + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, + ('x', None), ('batch_size', None), ('as_iterable', True)) + def predict(self, + x=None, + input_fn=None, + batch_size=None, + outputs=None, + as_iterable=True, + iterate_batches=False): """Returns predictions for given features. Args: @@ -606,6 +617,9 @@ class BaseEstimator( for each example until inputs are exhausted. Note: The inputs must terminate if you want the iterable to terminate (e.g. be sure to pass num_epochs=1 if you are using something like read_batch_features). + iterate_batches: If True, yield the whole batch at once instead of + decomposing the batch into individual samples. Only relevant when + as_iterable is True. Returns: A numpy array of predicted classes or regression values if the @@ -625,7 +639,8 @@ class BaseEstimator( input_fn=input_fn, feed_fn=feed_fn, outputs=outputs, - as_iterable=as_iterable) + as_iterable=as_iterable, + iterate_batches=iterate_batches) def get_variable_value(self, name): """Returns value of the variable given by name. @@ -651,16 +666,17 @@ class BaseEstimator( return self._model_dir @deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.') - def export(self, - export_dir, - input_fn=export._default_input_fn, # pylint: disable=protected-access - input_feature_key=None, - use_deprecated_input_fn=True, - signature_fn=None, - prediction_key=None, - default_batch_size=1, - exports_to_keep=None, - checkpoint_path=None): + def export( + self, + export_dir, + input_fn=export._default_input_fn, # pylint: disable=protected-access + input_feature_key=None, + use_deprecated_input_fn=True, + signature_fn=None, + prediction_key=None, + default_batch_size=1, + exports_to_keep=None, + checkpoint_path=None): """Exports inference graph into given dir. Args: @@ -798,8 +814,8 @@ class BaseEstimator( logging.debug('Setting feature info to %s.', str(self._features_info)) if labels is not None: if self._labels_info is not None: - logging.debug('Given labels: %s, required signatures: %s.', - str(labels), str(self._labels_info)) + logging.debug('Given labels: %s, required signatures: %s.', str(labels), + str(self._labels_info)) if not tensor_signature.tensors_compatible(labels, self._labels_info): raise ValueError('Labels are incompatible with given information. ' 'Given labels: %s, required signatures: %s.' % @@ -850,13 +866,13 @@ class BaseEstimator( if not checkpoint_path: latest_path = saver.latest_checkpoint(self._model_dir) if not latest_path: - raise NotFittedError("Couldn't find trained model at %s." - % self._model_dir) + raise NotFittedError( + "Couldn't find trained model at %s." % self._model_dir) checkpoint_path = latest_path # Setup output directory. - eval_dir = os.path.join(self._model_dir, 'eval' if not name else - 'eval_' + name) + eval_dir = os.path.join(self._model_dir, 'eval' + if not name else 'eval_' + name) with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) @@ -879,8 +895,7 @@ class BaseEstimator( 'Use steps=None if intended.') if steps: hooks.append( - evaluation.StopAfterNEvalsHook( - steps, log_progress=log_progress)) + evaluation.StopAfterNEvalsHook(steps, log_progress=log_progress)) global_step_key = 'global_step' while global_step_key in eval_dict: @@ -916,8 +931,8 @@ class BaseEstimator( # Check that model has been trained. checkpoint_path = saver.latest_checkpoint(self._model_dir) if not checkpoint_path: - raise NotFittedError("Couldn't find trained model at %s." - % self._model_dir) + raise NotFittedError( + "Couldn't find trained model at %s." % self._model_dir) with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) @@ -979,7 +994,8 @@ class BaseEstimator( existing_keys = predictions.keys() predictions = { key: value - for key, value in six.iteritems(predictions) if key in outputs + for key, value in six.iteritems(predictions) + if key in outputs } if not predictions: raise ValueError('Expected to run at least one output from %s, ' @@ -1045,8 +1061,7 @@ class BaseEstimator( chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks, save_checkpoint_secs=0, # Saving is handled by a hook. save_summaries_steps=self._config.save_summary_steps, - config=self._session_config - ) as mon_sess: + config=self._session_config) as mon_sess: loss = None while not mon_sess.should_stop(): _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss]) @@ -1137,8 +1152,7 @@ class Estimator(BaseEstimator): if params is not None and 'params' not in model_fn_args: raise ValueError('Estimator\'s model_fn (%s) does not have a params ' 'argument, but params (%s) were passed to the ' - 'Estimator\'s constructor.' % - (model_fn, params)) + 'Estimator\'s constructor.' % (model_fn, params)) if params is None and 'params' in model_fn_args: logging.warning('Estimator\'s model_fn (%s) includes params ' 'argument, but params are not passed to Estimator.', @@ -1192,8 +1206,9 @@ class Estimator(BaseEstimator): # Custom metrics should overwrite defaults. if metrics: - model_fn_ops.eval_metric_ops.update(_make_metrics_ops( - metrics, features, labels, model_fn_ops.predictions)) + model_fn_ops.eval_metric_ops.update( + _make_metrics_ops(metrics, features, labels, + model_fn_ops.predictions)) return model_fn_ops @@ -1238,8 +1253,8 @@ class Estimator(BaseEstimator): Raises: ValueError: if `metrics` don't match `labels`. """ - model_fn_ops = self._call_model_fn( - features, labels, model_fn_lib.ModeKeys.EVAL, metrics) + model_fn_ops = self._call_model_fn(features, labels, + model_fn_lib.ModeKeys.EVAL, metrics) if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops: model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = ( @@ -1263,14 +1278,16 @@ class Estimator(BaseEstimator): self._labels_info) return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER) - def export_savedmodel( - self, export_dir_base, serving_input_fn, - default_output_alternative_key=None, - assets_extra=None, - as_text=False, - checkpoint_path=None, - graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),), - strip_default_attrs=False): + def export_savedmodel(self, + export_dir_base, + serving_input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + checkpoint_path=None, + graph_rewrite_specs=(GraphRewriteSpec( + (tag_constants.SERVING,), ()),), + strip_default_attrs=False): # pylint: disable=line-too-long """Exports inference graph as a SavedModel into given dir. @@ -1297,7 +1314,8 @@ class Estimator(BaseEstimator): default serving tag ("serve") and no rewriting. strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see - [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + [Stripping Default-Valued + Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: The string path to the exported directory. @@ -1313,8 +1331,8 @@ class Estimator(BaseEstimator): # Locate the latest checkpoint checkpoint_path = saver.latest_checkpoint(self._model_dir) if not checkpoint_path: - raise NotFittedError("Couldn't find trained model at %s." - % self._model_dir) + raise NotFittedError( + "Couldn't find trained model at %s." % self._model_dir) export_dir = saved_model_export_utils.get_timestamped_export_dir( export_dir_base) @@ -1348,10 +1366,10 @@ class Estimator(BaseEstimator): saved_model_export_utils.get_output_alternatives( model_fn_ops, default_output_alternative_key)) - init_op = control_flow_ops.group( - variables.local_variables_initializer(), - resources.initialize_resources(resources.shared_resources()), - lookup_ops.tables_initializer()) + init_op = control_flow_ops.group(variables.local_variables_initializer(), + resources.initialize_resources( + resources.shared_resources()), + lookup_ops.tables_initializer()) # Build the SignatureDefs from all pairs of input and output alternatives signature_def_map = saved_model_export_utils.build_all_signature_defs( @@ -1381,10 +1399,10 @@ class Estimator(BaseEstimator): # TODO(soergel): switch to main_op or otherwise update when dust settles builder.add_meta_graph_and_variables( - session, untransformed_tags, + session, + untransformed_tags, signature_def_map=signature_def_map, - assets_collection=ops.get_collection( - ops.GraphKeys.ASSET_FILEPATHS), + assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), legacy_init_op=init_op, strip_default_attrs=strip_default_attrs) @@ -1395,12 +1413,16 @@ class Estimator(BaseEstimator): if graph_rewrite_specs[1:]: # Prepare the input_names and output_names needed for the # meta_graph_transform call below. - input_names = [tensor.name - for input_dict in input_alternatives.values() - for tensor in input_dict.values()] - output_names = [tensor.name - for output_alternative in output_alternatives.values() - for tensor in output_alternative[1].values()] + input_names = [ + tensor.name + for input_dict in input_alternatives.values() + for tensor in input_dict.values() + ] + output_names = [ + tensor.name + for output_alternative in output_alternatives.values() + for tensor in output_alternative[1].values() + ] # Write the additional MetaGraphDefs for graph_rewrite_spec in graph_rewrite_specs[1:]: @@ -1419,11 +1441,11 @@ class Estimator(BaseEstimator): # Add the extra assets if assets_extra: - assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir), - compat.as_bytes('assets.extra')) + assets_extra_path = os.path.join( + compat.as_bytes(temp_export_dir), compat.as_bytes('assets.extra')) for dest_relative, source in assets_extra.items(): - dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), - compat.as_bytes(dest_relative)) + dest_absolute = os.path.join( + compat.as_bytes(assets_extra_path), compat.as_bytes(dest_relative)) dest_path = os.path.dirname(dest_absolute) gfile.MakeDirs(dest_path) gfile.Copy(source, dest_absolute) @@ -1443,25 +1465,36 @@ class SKCompat(sklearn.BaseEstimator): def fit(self, x, y, batch_size=128, steps=None, max_steps=None, monitors=None): - input_fn, feed_fn = _get_input_fn(x, y, input_fn=None, feed_fn=None, - batch_size=batch_size, shuffle=True, - epochs=None) + input_fn, feed_fn = _get_input_fn( + x, + y, + input_fn=None, + feed_fn=None, + batch_size=batch_size, + shuffle=True, + epochs=None) all_monitors = [] if feed_fn: all_monitors = [basic_session_run_hooks.FeedFnHook(feed_fn)] if monitors: all_monitors.extend(monitors) - self._estimator.fit(input_fn=input_fn, - steps=steps, - max_steps=max_steps, - monitors=all_monitors) + self._estimator.fit( + input_fn=input_fn, + steps=steps, + max_steps=max_steps, + monitors=all_monitors) return self def score(self, x, y, batch_size=128, steps=None, metrics=None, name=None): - input_fn, feed_fn = _get_input_fn(x, y, input_fn=None, - feed_fn=None, batch_size=batch_size, - shuffle=False, epochs=1) + input_fn, feed_fn = _get_input_fn( + x, + y, + input_fn=None, + feed_fn=None, + batch_size=batch_size, + shuffle=False, + epochs=1) if metrics is not None and not isinstance(metrics, dict): raise ValueError('Metrics argument should be None or dict. ' 'Got %s.' % metrics) @@ -1477,8 +1510,13 @@ class SKCompat(sklearn.BaseEstimator): def predict(self, x, batch_size=128, outputs=None): input_fn, feed_fn = _get_input_fn( - x, None, input_fn=None, feed_fn=None, batch_size=batch_size, - shuffle=False, epochs=1) + x, + None, + input_fn=None, + feed_fn=None, + batch_size=batch_size, + shuffle=False, + epochs=1) results = list( self._estimator._infer_model( input_fn=input_fn, @@ -1489,7 +1527,6 @@ class SKCompat(sklearn.BaseEstimator): if not isinstance(results[0], dict): return np.concatenate([output for output in results], axis=0) return { - key: np.concatenate( - [output[key] for output in results], axis=0) + key: np.concatenate([output[key] for output in results], axis=0) for key in results[0] } diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py index 9d7c1a099aa4be64ca0296fa5b870597dabec7b4..d4a46b41d0c93ef58d5db8c433cbf348fec10f5e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py @@ -41,7 +41,6 @@ from tensorflow.python.platform import test from tensorflow.python.training import input as input_lib from tensorflow.python.training import queue_runner_impl - _BOSTON_INPUT_DIM = 13 _IRIS_INPUT_DIM = 4 @@ -93,8 +92,8 @@ def boston_eval_fn(): constant_op.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM]) labels = array_ops.reshape( constant_op.constant(boston.target), [n_examples, 1]) - return array_ops.concat([features, features], 0), array_ops.concat( - [labels, labels], 0) + return array_ops.concat([features, features], + 0), array_ops.concat([labels, labels], 0) def extract(data, key): @@ -129,7 +128,10 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return prediction, loss, train_op @@ -139,7 +141,10 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -150,7 +155,10 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction @@ -173,7 +181,9 @@ class EstimatorInputTest(test.TestCase): scores = est.evaluate( x=boston_input, y=float64_target, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) del est # Create another estimator object with the same output dir. est2 = estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir) @@ -182,7 +192,9 @@ class EstimatorInputTest(test.TestCase): scores2 = est2.evaluate( x=boston_input, y=float64_target, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) self.assertAllClose(scores2['MSE'], scores['MSE']) predictions = np.array(list(est2.predict(x=boston_input))) other_score = _sklearn.mean_squared_error(predictions, @@ -197,7 +209,9 @@ class EstimatorInputTest(test.TestCase): scores = est.score( x=boston.data, y=float64_labels, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) predictions = np.array(list(est.predict(x=boston.data))) other_score = _sklearn.mean_squared_error(predictions, boston.target) self.assertAllClose(scores['MSE'], other_score) @@ -213,7 +227,9 @@ class EstimatorInputTest(test.TestCase): scores = est.evaluate( x=boston_input, y=float64_target, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) predictions = np.array(list(est.predict(x=boston_input))) other_score = _sklearn.mean_squared_error(predictions, boston.target) self.assertAllClose(other_score, scores['MSE']) @@ -228,14 +244,15 @@ class EstimatorInputTest(test.TestCase): scores = est.score( x=iris.data, y=iris.target, - metrics={('accuracy', 'class'): metric_ops.streaming_accuracy}) + metrics={ + ('accuracy', 'class'): metric_ops.streaming_accuracy + }) predictions = est.predict(x=iris.data) predictions_class = est.predict(x=iris.data, outputs=['class'])['class'] self.assertEqual(predictions['prob'].shape[0], iris.target.shape[0]) self.assertAllClose(predictions['class'], predictions_class) - self.assertAllClose( - predictions['class'], np.argmax( - predictions['prob'], axis=1)) + self.assertAllClose(predictions['class'], + np.argmax(predictions['prob'], axis=1)) other_score = _sklearn.accuracy_score(iris.target, predictions['class']) self.assertAllClose(scores['accuracy'], other_score) self.assertTrue('global_step' in scores) @@ -250,17 +267,18 @@ class EstimatorInputTest(test.TestCase): scores = est.evaluate( x=iris_data, y=iris_target, - metrics={('accuracy', 'class'): metric_ops.streaming_accuracy}) + metrics={ + ('accuracy', 'class'): metric_ops.streaming_accuracy + }) predictions = list(est.predict(x=iris_data)) predictions_class = list(est.predict(x=iris_data, outputs=['class'])) self.assertEqual(len(predictions), iris.target.shape[0]) classes_batch = np.array([p['class'] for p in predictions]) self.assertAllClose(classes_batch, np.array([p['class'] for p in predictions_class])) - self.assertAllClose( - classes_batch, - np.argmax( - np.array([p['prob'] for p in predictions]), axis=1)) + self.assertAllClose(classes_batch, + np.argmax( + np.array([p['prob'] for p in predictions]), axis=1)) other_score = _sklearn.accuracy_score(iris.target, classes_batch) self.assertAllClose(other_score, scores['accuracy']) self.assertTrue('global_step' in scores) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 5f682838b7afadec7a54df782cb5b89ac6746659..d81a534b79bc90fe91ffd3cb97a7865a7cb4c2a9 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -111,8 +111,8 @@ def boston_eval_fn(): constant_op.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM]) labels = array_ops.reshape( constant_op.constant(boston.target), [n_examples, 1]) - return array_ops.concat([features, features], 0), array_ops.concat( - [labels, labels], 0) + return array_ops.concat([features, features], + 0), array_ops.concat([labels, labels], 0) def extract(data, key): @@ -147,7 +147,10 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return prediction, loss, train_op @@ -157,7 +160,10 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -168,7 +174,10 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction @@ -184,14 +193,12 @@ def _build_estimator_for_export_tests(tmpdir): def _input_fn(): iris = base.load_iris() return { - 'feature': constant_op.constant( - iris.data, dtype=dtypes.float32) + 'feature': constant_op.constant(iris.data, dtype=dtypes.float32) }, constant_op.constant( iris.target, shape=[150], dtype=dtypes.int32) feature_columns = [ - feature_column_lib.real_valued_column( - 'feature', dimension=4) + feature_column_lib.real_valued_column('feature', dimension=4) ] est = linear.LinearRegressor(feature_columns) @@ -291,8 +298,8 @@ class CheckCallsMonitor(monitors_lib.BaseMonitor): self.begin_calls == self.expect_calls) -def _model_fn_ops( - expected_features, expected_labels, actual_features, actual_labels, mode): +def _model_fn_ops(expected_features, expected_labels, actual_features, + actual_labels, mode): assert_ops = tuple([ check_ops.assert_equal( expected_features[k], actual_features[k], name='assert_%s' % k) @@ -310,11 +317,11 @@ def _model_fn_ops( def _make_input_fn(features, labels): + def _input_fn(): - return { - k: constant_op.constant(v) - for k, v in six.iteritems(features) - }, constant_op.constant(labels) + return {k: constant_op.constant(v) + for k, v in six.iteritems(features)}, constant_op.constant(labels) + return _input_fn @@ -369,11 +376,13 @@ class EstimatorModelFnTest(test.TestCase): self.assertEqual(expected_params, params) self.assertTrue(config.i_am_test) return _model_fn_ops(features, labels, arg0, arg1, mode) + partial_model_fn = functools.partial( _model_fn, foo=expected_foo, bar=expected_bar) est = estimator.Estimator( - model_fn=partial_model_fn, params=expected_params, + model_fn=partial_model_fn, + params=expected_params, config=expected_config) self.assertEqual(0, model_fn_call_count[0]) est.fit(input_fn=_make_input_fn(features, labels), steps=1) @@ -382,7 +391,12 @@ class EstimatorModelFnTest(test.TestCase): def testModelFnWithModelDir(self): expected_param = {'some_param': 'some_value'} expected_model_dir = tempfile.mkdtemp() - def _argument_checker(features, labels, mode, params, config=None, + + def _argument_checker(features, + labels, + mode, + params, + config=None, model_dir=None): _, _, _ = features, labels, config self.assertEqual(model_fn.ModeKeys.TRAIN, mode) @@ -390,9 +404,11 @@ class EstimatorModelFnTest(test.TestCase): self.assertEqual(model_dir, expected_model_dir) return (constant_op.constant(0.), constant_op.constant(0.), training_util.get_global_step().assign_add(1)) - est = estimator.Estimator(model_fn=_argument_checker, - params=expected_param, - model_dir=expected_model_dir) + + est = estimator.Estimator( + model_fn=_argument_checker, + params=expected_param, + model_dir=expected_model_dir) est.fit(input_fn=boston_input_fn, steps=1) def testInvalidModelFn_no_train_op(self): @@ -447,8 +463,7 @@ class EstimatorModelFnTest(test.TestCase): est.predict(input_fn=boston_input_fn) with self.assertRaisesRegexp(ValueError, 'Missing prediction'): est.predict( - input_fn=functools.partial( - boston_input_fn, num_epochs=1), + input_fn=functools.partial(boston_input_fn, num_epochs=1), as_iterable=True) def testModelFnScaffoldInTraining(self): @@ -498,15 +513,17 @@ class EstimatorModelFnTest(test.TestCase): self.assertTrue(self.mock_saver.restore.called) est.predict(input_fn=input_fn) self.assertTrue(self.mock_saver.restore.called) + def serving_input_fn(): - serialized_tf_example = array_ops.placeholder(dtype=dtypes.string, - shape=[None], - name='input_example_tensor') + serialized_tf_example = array_ops.placeholder( + dtype=dtypes.string, shape=[None], name='input_example_tensor') features, labels = input_fn() - return input_fn_utils.InputFnOps( - features, labels, {'examples': serialized_tf_example}) + return input_fn_utils.InputFnOps(features, labels, { + 'examples': serialized_tf_example + }) - est.export_savedmodel(os.path.join(est.model_dir, 'export'), serving_input_fn) + est.export_savedmodel( + os.path.join(est.model_dir, 'export'), serving_input_fn) self.assertTrue(self.mock_saver.restore.called) @@ -550,33 +567,28 @@ class EstimatorTest(test.TestCase): def testRunConfigModelDir(self): config = run_config.RunConfig(model_dir='test_dir') - est = estimator.Estimator(model_fn=linear_model_fn, - config=config) + est = estimator.Estimator(model_fn=linear_model_fn, config=config) self.assertEqual('test_dir', est.config.model_dir) self.assertEqual('test_dir', est.model_dir) def testModelDirAndRunConfigModelDir(self): config = run_config.RunConfig(model_dir='test_dir') - est = estimator.Estimator(model_fn=linear_model_fn, - config=config, - model_dir='test_dir') + est = estimator.Estimator( + model_fn=linear_model_fn, config=config, model_dir='test_dir') self.assertEqual('test_dir', est.config.model_dir) with self.assertRaisesRegexp( - ValueError, - 'model_dir are set both in constructor and RunConfig, ' + ValueError, 'model_dir are set both in constructor and RunConfig, ' 'but with different'): - estimator.Estimator(model_fn=linear_model_fn, - config=config, - model_dir='different_dir') + estimator.Estimator( + model_fn=linear_model_fn, config=config, model_dir='different_dir') def testModelDirIsCopiedToRunConfig(self): config = run_config.RunConfig() self.assertIsNone(config.model_dir) - est = estimator.Estimator(model_fn=linear_model_fn, - model_dir='test_dir', - config=config) + est = estimator.Estimator( + model_fn=linear_model_fn, model_dir='test_dir', config=config) self.assertEqual('test_dir', est.config.model_dir) self.assertEqual('test_dir', est.model_dir) @@ -656,25 +668,27 @@ class EstimatorTest(test.TestCase): boston = base.load_boston() output_dir = tempfile.mkdtemp() est = estimator.SKCompat( - estimator.Estimator( - model_fn=linear_model_fn, model_dir=output_dir)) + estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)) float64_labels = boston.target.astype(np.float64) est.fit(x=boston.data, y=float64_labels, steps=50) scores = est.score( x=boston.data, y=float64_labels, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) del est # Create another estimator object with the same output dir. est2 = estimator.SKCompat( - estimator.Estimator( - model_fn=linear_model_fn, model_dir=output_dir)) + estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)) # Check we can evaluate and predict. scores2 = est2.score( x=boston.data, y=float64_labels, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) self.assertAllClose(scores['MSE'], scores2['MSE']) predictions = np.array(list(est2.predict(x=boston.data))) other_score = _sklearn.mean_squared_error(predictions, float64_labels) @@ -685,14 +699,15 @@ class EstimatorTest(test.TestCase): scores3 = est2.score( x=boston.data, y=float64_labels, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) self.assertLess(scores3['MSE'], scores['MSE']) def test_checkpoint_contains_relative_paths(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator( - model_dir=tmpdir, - model_fn=linear_model_fn_with_model_fn_ops) + model_dir=tmpdir, model_fn=linear_model_fn_with_model_fn_ops) est.fit(input_fn=boston_input_fn, steps=5) checkpoint_file_content = file_io.read_file_to_string( @@ -700,22 +715,20 @@ class EstimatorTest(test.TestCase): ckpt = checkpoint_state_pb2.CheckpointState() text_format.Merge(checkpoint_file_content, ckpt) self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5') - self.assertAllEqual( - ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) + self.assertAllEqual(['model.ckpt-1', 'model.ckpt-5'], + ckpt.all_model_checkpoint_paths) def test_train_save_copy_reload(self): tmpdir = tempfile.mkdtemp() model_dir1 = os.path.join(tmpdir, 'model_dir1') est1 = estimator.Estimator( - model_dir=model_dir1, - model_fn=linear_model_fn_with_model_fn_ops) + model_dir=model_dir1, model_fn=linear_model_fn_with_model_fn_ops) est1.fit(input_fn=boston_input_fn, steps=5) model_dir2 = os.path.join(tmpdir, 'model_dir2') os.renames(model_dir1, model_dir2) est2 = estimator.Estimator( - model_dir=model_dir2, - model_fn=linear_model_fn_with_model_fn_ops) + model_dir=model_dir2, model_fn=linear_model_fn_with_model_fn_ops) self.assertEqual(5, est2.get_variable_value('global_step')) est2.fit(input_fn=boston_input_fn, steps=5) self.assertEqual(10, est2.get_variable_value('global_step')) @@ -724,7 +737,9 @@ class EstimatorTest(test.TestCase): boston = base.load_boston() est = estimator.SKCompat( estimator.Estimator( - model_fn=linear_model_params_fn, params={'learning_rate': 0.01})) + model_fn=linear_model_params_fn, params={ + 'learning_rate': 0.01 + })) est.fit(x=boston.data, y=boston.target, steps=100) def testHooksNotChanged(self): @@ -824,11 +839,13 @@ class EstimatorTest(test.TestCase): def testMonitorsForFit(self): est = estimator.Estimator(model_fn=linear_model_fn) - est.fit(input_fn=boston_input_fn, - steps=21, - monitors=[CheckCallsMonitor(expect_calls=21)]) + est.fit( + input_fn=boston_input_fn, + steps=21, + monitors=[CheckCallsMonitor(expect_calls=21)]) def testHooksForEvaluate(self): + class CheckCallHook(session_run_hook.SessionRunHook): def __init__(self): @@ -874,7 +891,9 @@ class EstimatorTest(test.TestCase): est.evaluate( input_fn=boston_input_fn, steps=200, - metrics={'MSE': _streaming_mean_squared_error_histogram}) + metrics={ + 'MSE': _streaming_mean_squared_error_histogram + }) events = util_test.latest_events(est.model_dir + '/eval') output_values = {} for e in events: @@ -903,7 +922,9 @@ class EstimatorTest(test.TestCase): est.evaluate( input_fn=boston_input_fn, steps=200, - metrics={'PMT': _streaming_precition_mean_tensor}) + metrics={ + 'PMT': _streaming_precition_mean_tensor + }) events = util_test.latest_events(est.model_dir + '/eval') output_values = {} for e in events: @@ -956,8 +977,8 @@ class EstimatorTest(test.TestCase): self.assertTrue( gfile.Exists( os.path.join( - compat.as_bytes(export_dir), compat.as_bytes( - 'saved_model.pb')))) + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) self.assertTrue( gfile.Exists( os.path.join( @@ -1017,11 +1038,11 @@ class EstimatorTest(test.TestCase): self.assertTrue('input_example_tensor' in graph_ops) self.assertTrue('ParseExample/ParseExample' in graph_ops) self.assertTrue('linear/linear/feature/matmul' in graph_ops) - self.assertItemsEqual( - ['bogus_lookup', 'feature'], - [compat.as_str_any(x) for x in graph.get_collection( - constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)]) - + self.assertItemsEqual(['bogus_lookup', 'feature'], [ + compat.as_str_any(x) + for x in graph.get_collection( + constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS) + ]) # cleanup gfile.DeleteRecursively(tmpdir) @@ -1039,8 +1060,8 @@ class EstimatorTest(test.TestCase): self.assertTrue( gfile.Exists( os.path.join( - compat.as_bytes(export_dir), compat.as_bytes( - 'saved_model.pb')))) + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) self.assertTrue( gfile.Exists( os.path.join( @@ -1083,19 +1104,22 @@ class EstimatorTest(test.TestCase): export_dir_base = os.path.join( compat.as_bytes(tmpdir), compat.as_bytes('export')) export_dir = est.export_savedmodel( - export_dir_base, serving_input_fn, assets_extra=assets_extra, + export_dir_base, + serving_input_fn, + assets_extra=assets_extra, graph_rewrite_specs=[ estimator.GraphRewriteSpec(['tag_1'], []), estimator.GraphRewriteSpec(['tag_2', 'tag_3'], - ['strip_unused_nodes'])]) + ['strip_unused_nodes']) + ]) self.assertTrue(gfile.Exists(export_dir_base)) self.assertTrue(gfile.Exists(export_dir)) self.assertTrue( gfile.Exists( os.path.join( - compat.as_bytes(export_dir), compat.as_bytes( - 'saved_model.pb')))) + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) self.assertTrue( gfile.Exists( os.path.join( @@ -1208,18 +1232,15 @@ class InferRealValuedColumnsTest(test.TestCase): self.assertEqual(1, len(feature_columns)) feature_column = feature_columns[0] self.assertEqual('', feature_column.name) - self.assertEqual( - { - '': - parsing_ops.FixedLenFeature( - shape=expected_shape, dtype=expected_dtype) - }, - feature_column.config) + self.assertEqual({ + '': + parsing_ops.FixedLenFeature( + shape=expected_shape, dtype=expected_dtype) + }, feature_column.config) def testInt32Input(self): feature_columns = estimator.infer_real_valued_columns_from_input( - np.ones( - shape=[7, 8], dtype=np.int32)) + np.ones(shape=[7, 8], dtype=np.int32)) self._assert_single_feature_column([8], dtypes.int32, feature_columns) def testInt32InputFn(self): @@ -1229,8 +1250,7 @@ class InferRealValuedColumnsTest(test.TestCase): def testInt64Input(self): feature_columns = estimator.infer_real_valued_columns_from_input( - np.ones( - shape=[7, 8], dtype=np.int64)) + np.ones(shape=[7, 8], dtype=np.int64)) self._assert_single_feature_column([8], dtypes.int64, feature_columns) def testInt64InputFn(self): @@ -1240,8 +1260,7 @@ class InferRealValuedColumnsTest(test.TestCase): def testFloat32Input(self): feature_columns = estimator.infer_real_valued_columns_from_input( - np.ones( - shape=[7, 8], dtype=np.float32)) + np.ones(shape=[7, 8], dtype=np.float32)) self._assert_single_feature_column([8], dtypes.float32, feature_columns) def testFloat32InputFn(self): @@ -1251,8 +1270,7 @@ class InferRealValuedColumnsTest(test.TestCase): def testFloat64Input(self): feature_columns = estimator.infer_real_valued_columns_from_input( - np.ones( - shape=[7, 8], dtype=np.float64)) + np.ones(shape=[7, 8], dtype=np.float64)) self._assert_single_feature_column([8], dtypes.float64, feature_columns) def testFloat64InputFn(self): @@ -1271,8 +1289,8 @@ class InferRealValuedColumnsTest(test.TestCase): ValueError, 'on integer or non floating types are not supported'): # pylint: disable=g-long-lambda estimator.infer_real_valued_columns_from_input_fn( - lambda: (constant_op.constant(False, shape=[7, 8], dtype=dtypes.bool), - None)) + lambda: (constant_op.constant(False, shape=[7, 8], dtype=dtypes.bool), None) + ) def testStringInput(self): with self.assertRaisesRegexp( @@ -1309,8 +1327,9 @@ class ReplicaDeviceSetterTest(test.TestCase): def testVariablesAreOnPs(self): tf_config = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}} - with test.mock.patch.dict('os.environ', - {'TF_CONFIG': json.dumps(tf_config)}): + with test.mock.patch.dict('os.environ', { + 'TF_CONFIG': json.dumps(tf_config) + }): config = run_config.RunConfig() with ops.device(estimator._get_replica_device_setter(config)): @@ -1337,14 +1356,14 @@ class ReplicaDeviceSetterTest(test.TestCase): def testMutableHashTableIsOnPs(self): tf_config = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}} - with test.mock.patch.dict('os.environ', - {'TF_CONFIG': json.dumps(tf_config)}): + with test.mock.patch.dict('os.environ', { + 'TF_CONFIG': json.dumps(tf_config) + }): config = run_config.RunConfig() with ops.device(estimator._get_replica_device_setter(config)): default_val = constant_op.constant([-1, -1], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) input_string = constant_op.constant(['brain', 'salad', 'tank']) output = table.lookup(input_string) self.assertDeviceEqual('/job:ps/task:0', table._table_ref.device) @@ -1354,8 +1373,7 @@ class ReplicaDeviceSetterTest(test.TestCase): with ops.device( estimator._get_replica_device_setter(run_config.RunConfig())): default_val = constant_op.constant([-1, -1], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) input_string = constant_op.constant(['brain', 'salad', 'tank']) output = table.lookup(input_string) self.assertDeviceEqual('', table._table_ref.device) @@ -1371,8 +1389,9 @@ class ReplicaDeviceSetterTest(test.TestCase): 'index': 3 } } - with test.mock.patch.dict('os.environ', - {'TF_CONFIG': json.dumps(tf_config)}): + with test.mock.patch.dict('os.environ', { + 'TF_CONFIG': json.dumps(tf_config) + }): config = run_config.RunConfig() with ops.device(estimator._get_replica_device_setter(config)): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py index 8131e0fde6fea5501cacc4714f53ed8d867ca70f..2113fae3940f14c8ca07e5f76986408ae8a33831 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py @@ -72,9 +72,11 @@ class FeatureEngineeringFunctionTest(test.TestCase): # predictions = transformed_x (9) self.assertEqual(9., prediction) metrics = estimator.evaluate( - input_fn=input_fn, steps=1, - metrics={"label": - metric_spec.MetricSpec(lambda predictions, labels: labels)}) + input_fn=input_fn, + steps=1, + metrics={ + "label": metric_spec.MetricSpec(lambda predictions, labels: labels) + }) # labels = transformed_y (99) self.assertEqual(99., metrics["label"]) @@ -82,10 +84,10 @@ class FeatureEngineeringFunctionTest(test.TestCase): def input_fn(): return { - "x": constant_op.constant(["9."]) - }, { - "y": constant_op.constant(["99."]) - } + "x": constant_op.constant(["9."]) + }, { + "y": constant_op.constant(["99."]) + } def feature_engineering_fn(features, labels): # Github #12205: raise a TypeError if called twice. @@ -104,15 +106,17 @@ class FeatureEngineeringFunctionTest(test.TestCase): return predictions, loss, update_global_step estimator = estimator_lib.Estimator( - model_fn=model_fn, feature_engineering_fn=feature_engineering_fn) + model_fn=model_fn, feature_engineering_fn=feature_engineering_fn) estimator.fit(input_fn=input_fn, steps=1) prediction = next(estimator.predict(input_fn=input_fn, as_iterable=True)) # predictions = transformed_x (9) self.assertEqual(9., prediction) metrics = estimator.evaluate( - input_fn=input_fn, steps=1, - metrics={"label": - metric_spec.MetricSpec(lambda predictions, labels: labels)}) + input_fn=input_fn, + steps=1, + metrics={ + "label": metric_spec.MetricSpec(lambda predictions, labels: labels) + }) # labels = transformed_y (99) self.assertEqual(99., metrics["label"]) @@ -150,12 +154,10 @@ class FeatureEngineeringFunctionTest(test.TestCase): # predictions = x prediction_with_fe_fn = next( - estimator_with_fe_fn.predict( - input_fn=input_fn, as_iterable=True)) + estimator_with_fe_fn.predict(input_fn=input_fn, as_iterable=True)) self.assertEqual(9., prediction_with_fe_fn) prediction_without_fe_fn = next( - estimator_without_fe_fn.predict( - input_fn=input_fn, as_iterable=True)) + estimator_without_fe_fn.predict(input_fn=input_fn, as_iterable=True)) self.assertEqual(1., prediction_without_fe_fn) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index bc0e6fc0091c9b5419ab526855b404eb4a927e97..9b124b2c19f16bbc9b2afeadb82a32006e1a0ae9 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -181,7 +181,8 @@ def regression_head(label_name=None, weight_column_name=None, label_dimension=1, enable_centered_bias=False, - head_name=None): + head_name=None, + link_fn=None): """Creates a `Head` for linear regression. Args: @@ -199,6 +200,8 @@ def regression_head(label_name=None, head_name: name of the head. If provided, predictions, summary and metrics keys will be suffixed by `"/" + head_name` and the default variable scope will be `head_name`. + link_fn: link function to convert logits to predictions. If provided, + this link function will be used instead of identity. Returns: An instance of `Head` for linear regression. @@ -210,7 +213,7 @@ def regression_head(label_name=None, enable_centered_bias=enable_centered_bias, head_name=head_name, loss_fn=_mean_squared_loss, - link_fn=array_ops.identity) + link_fn=(link_fn if link_fn is not None else array_ops.identity)) def poisson_regression_head(label_name=None, diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 3881bf533d642bef68fa9ab4ba908bbb8f7f8091..6d5da81b4c2087fb9c5307902e452a6220a17cd0 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -33,6 +33,7 @@ from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses as losses_lib from tensorflow.python.platform import test @@ -153,6 +154,25 @@ class RegressionHeadTest(test.TestCase): _assert_no_variables(self) _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) + def testRegressionWithLogitFn(self): + head = head_lib.regression_head(link_fn=math_ops.square) + def _assert_preditions(test_case, expected_predictions, model_fn_ops): + variables.initialize_local_variables().run() + test_case.assertAllClose(expected_predictions, + model_fn_ops.predictions["scores"].eval()) + with ops.Graph().as_default(), session.Session(): + model_fn_ops = head.create_model_fn_ops( + {}, + labels=((0.,), (1.,), (1.,)), + mode=model_fn.ModeKeys.TRAIN, + train_op_fn=head_lib.no_op_train_fn, + logits=((1.,), (1.,), (3.,))) + self._assert_output_alternatives(model_fn_ops) + _assert_summary_tags(self, ["loss"]) + _assert_no_variables(self) + _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) + _assert_preditions(self, ([1.0, 1.0, 9.0]), model_fn_ops) + def testRegressionWithInvalidLogits(self): head = head_lib.regression_head() with ops.Graph().as_default(), session.Session(): @@ -342,7 +362,7 @@ class MultiLabelHeadTest(test.TestCase): "auc_precision_recall": 0.166667, "auc_precision_recall/class0": 0, "auc_precision_recall/class1": 0., - "auc_precision_recall/class2": 1., + "auc_precision_recall/class2": 0.49999, "labels/actual_label_mean/class0": self._labels[0][0], "labels/actual_label_mean/class1": self._labels[0][1], "labels/actual_label_mean/class2": self._labels[0][2], @@ -728,7 +748,7 @@ class BinaryClassificationHeadTest(test.TestCase): "accuracy/baseline_label_mean": label_mean, "accuracy/threshold_0.500000_mean": 1. / 2, "auc": 1. / 2, - "auc_precision_recall": 0.749999, + "auc_precision_recall": 0.25, "labels/actual_label_mean": label_mean, "labels/prediction_mean": .731059, # softmax "loss": expected_loss, diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py index 656d68b76888d9319c0b9be481f9b0478ac4314c..ac2d10011e222eb9c534d7fbae3c0cb5f4820945 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py @@ -57,7 +57,10 @@ def _logistic_regression_model_fn(features, labels, mode): predictions = math_ops.sigmoid(logits) loss = losses.sigmoid_cross_entropy(labels, logits) train_op = optimizers.optimize_loss( - loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return predictions, loss, train_op diff --git a/tensorflow/contrib/learn/python/learn/evaluable.py b/tensorflow/contrib/learn/python/learn/evaluable.py index 66e15265171679dcd710fdf05bed3105de6bab99..8f6cd39864b437f163dd7c1140dc88755ce98529 100644 --- a/tensorflow/contrib/learn/python/learn/evaluable.py +++ b/tensorflow/contrib/learn/python/learn/evaluable.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """`Evaluable` interface.""" from __future__ import absolute_import @@ -59,9 +58,12 @@ class Evaluable(object): for which this evaluation was performed. Args: - x: Matrix of shape [n_samples, n_features...] or dictionary of many matrices - containing the input samples for fitting the model. Can be iterator that returns - arrays of features or dictionary of array of features. If set, `input_fn` must + x: Matrix of shape [n_samples, n_features...] or dictionary of many + matrices + containing the input samples for fitting the model. Can be iterator that + returns + arrays of features or dictionary of array of features. If set, + `input_fn` must be `None`. y: Vector or matrix [n_samples] or [n_samples, n_outputs] containing the label values (class labels in classification, real numbers in diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 9576ff21c243022276bb0641882dfaf0decf05c0..bec976afd2719138117976381669ca3292360480 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Experiment class collecting information needed for a single training run.""" from __future__ import absolute_import @@ -43,7 +42,6 @@ from tensorflow.python.training import saver from tensorflow.python.training import server_lib from tensorflow.python.util import compat - __all__ = ["Experiment"] @@ -278,8 +276,7 @@ class Experiment(object): self._train_steps_per_iteration = train_steps_per_iteration if (self._train_steps_per_iteration is not None and not isinstance(self._train_steps_per_iteration, int)): - raise ValueError( - "`train_steps_per_iteration` must be an integer.") + raise ValueError("`train_steps_per_iteration` must be an integer.") @property def estimator(self): @@ -359,9 +356,10 @@ class Experiment(object): config.cluster_spec and config.master): self._start_server() elif config.cluster_spec and config.master: - raise ValueError('For distributed runtime, Experiment class only works with' - 'tf.contrib.learn.RunConfig for now, but provided {}' - .format(type(config))) + raise ValueError( + "For distributed runtime, Experiment class only works with" + "tf.contrib.learn.RunConfig for now, but provided {}".format( + type(config))) extra_hooks = [] if delay_secs is None: @@ -414,11 +412,12 @@ class Experiment(object): logging.info("Waiting %d secs before starting eval.", delay_secs) time.sleep(delay_secs) - return self._call_evaluate(input_fn=self._eval_input_fn, - steps=self._eval_steps, - metrics=self._eval_metrics, - name=(name or "one_pass"), - hooks=self._eval_hooks) + return self._call_evaluate( + input_fn=self._eval_input_fn, + steps=self._eval_steps, + metrics=self._eval_metrics, + name=(name or "one_pass"), + hooks=self._eval_hooks) @deprecated( "2016-10-23", @@ -499,15 +498,12 @@ class Experiment(object): previous_path = None eval_result = None last_warning_time = 0 - while (not predicate_fn or - predicate_fn( - eval_result, - checkpoint_path=previous_path if eval_result else None)): + while (not predicate_fn or predicate_fn( + eval_result, checkpoint_path=previous_path if eval_result else None)): # Exit if we have already reached number of steps to train. if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " - "train_step=%s", - eval_result[ops.GraphKeys.GLOBAL_STEP], + "train_step=%s", eval_result[ops.GraphKeys.GLOBAL_STEP], self._train_steps) return @@ -528,12 +524,13 @@ class Experiment(object): logging.warning(error_msg) last_warning_time = time.time() else: - eval_result = self._call_evaluate(input_fn=input_fn, - steps=self._eval_steps, - metrics=self._eval_metrics, - name=name, - checkpoint_path=latest_path, - hooks=self._eval_hooks) + eval_result = self._call_evaluate( + input_fn=input_fn, + steps=self._eval_steps, + metrics=self._eval_metrics, + name=name, + checkpoint_path=latest_path, + hooks=self._eval_hooks) # Ensure eval result is not None for next round of evaluation. if not eval_result: eval_result = {} @@ -558,8 +555,8 @@ class Experiment(object): return False global_step = eval_result.get(ops.GraphKeys.GLOBAL_STEP) - return global_step and self._train_steps and ( - global_step >= self._train_steps) + return global_step and self._train_steps and (global_step >= + self._train_steps) def continuous_eval(self, delay_secs=None, @@ -678,8 +675,7 @@ class Experiment(object): return eval_result, export_results @experimental - def continuous_train_and_eval(self, - continuous_eval_predicate_fn=None): + def continuous_train_and_eval(self, continuous_eval_predicate_fn=None): """Interleaves training and evaluation. The frequency of evaluation is controlled by the `train_steps_per_iteration` @@ -752,10 +748,9 @@ class Experiment(object): elif self._train_steps is not None: train_steps_per_iteration = int(self._train_steps / 10) - while (not predicate_fn or - predicate_fn( - eval_result, - checkpoint_path=latest_checkpoint if eval_result else None)): + while (not predicate_fn or predicate_fn( + eval_result, checkpoint_path=latest_checkpoint + if eval_result else None)): if self._has_training_stopped(eval_result): # Exits once max steps of training is satisfied. @@ -785,8 +780,7 @@ class Experiment(object): def _maybe_export(self, eval_result, checkpoint_path=None): """Export the Estimator using export_fn, if defined.""" export_dir_base = os.path.join( - compat.as_bytes(self._estimator.model_dir), - compat.as_bytes("export")) + compat.as_bytes(self._estimator.model_dir), compat.as_bytes("export")) export_results = [] for strategy in self._export_strategies: @@ -824,10 +818,11 @@ class Experiment(object): hooks=self._train_monitors, saving_listeners=self._saving_listeners) - eval_result = self._call_evaluate(input_fn=self._eval_input_fn, - steps=1, - metrics=self._eval_metrics, - name="one_pass") + eval_result = self._call_evaluate( + input_fn=self._eval_input_fn, + steps=1, + metrics=self._eval_metrics, + name="one_pass") _ = self._maybe_export(eval_result) return eval_result @@ -849,9 +844,14 @@ class Experiment(object): server.start() return server - def _call_train(self, _sentinel=None, # pylint: disable=invalid-name, - input_fn=None, steps=None, hooks=None, max_steps=None, - saving_listeners=None): + def _call_train( + self, + _sentinel=None, # pylint: disable=invalid-name, + input_fn=None, + steps=None, + hooks=None, + max_steps=None, + saving_listeners=None): if _sentinel is not None: raise ValueError("_call_train should be called with keyword args only") @@ -867,14 +867,18 @@ class Experiment(object): hooks=hooks, saving_listeners=saving_listeners) else: - return self._estimator.fit(input_fn=input_fn, - steps=steps, - max_steps=max_steps, - monitors=hooks) - - def _call_evaluate(self, _sentinel=None, # pylint: disable=invalid-name, - input_fn=None, steps=None, metrics=None, name=None, - checkpoint_path=None, hooks=None): + return self._estimator.fit( + input_fn=input_fn, steps=steps, max_steps=max_steps, monitors=hooks) + + def _call_evaluate( + self, + _sentinel=None, # pylint: disable=invalid-name, + input_fn=None, + steps=None, + metrics=None, + name=None, + checkpoint_path=None, + hooks=None): if _sentinel is not None: raise ValueError("_call_evaluate should be called with keyword args only") @@ -882,18 +886,20 @@ class Experiment(object): if metrics is not None: raise ValueError( "`eval_metrics` must be `None` with `tf.estimator.Estimator`") - return self._estimator.evaluate(input_fn=input_fn, - steps=steps, - name=name, - checkpoint_path=checkpoint_path, - hooks=hooks) + return self._estimator.evaluate( + input_fn=input_fn, + steps=steps, + name=name, + checkpoint_path=checkpoint_path, + hooks=hooks) else: - return self._estimator.evaluate(input_fn=input_fn, - steps=steps, - metrics=metrics, - name=name, - checkpoint_path=checkpoint_path, - hooks=hooks) + return self._estimator.evaluate( + input_fn=input_fn, + steps=steps, + metrics=metrics, + name=name, + checkpoint_path=checkpoint_path, + hooks=hooks) @contextlib.contextmanager diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index f36a778b529a83f158241ddb060959c4b33e2e95..96be8b1bc402479d5611965f27abb197363cb939 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -35,6 +35,7 @@ from tensorflow.python.platform import tf_logging as logging # pylint: disable=g-multiple-import,g-bad-import-order from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels + # pylint: enable=g-multiple-import,g-bad-import-order @@ -74,11 +75,11 @@ def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None): if not y_is_dict: output_shape = out_el_shape(y_shape, n_classes) else: - output_shape = dict([ - (k, out_el_shape(v, n_classes[k] - if n_classes is not None and k in n_classes else None)) - for k, v in list(y_shape.items()) - ]) + output_shape = dict([(k, + out_el_shape(v, n_classes[k] + if n_classes is not None and + k in n_classes else None)) + for k, v in list(y_shape.items())]) return input_shape, output_shape, batch_size @@ -314,23 +315,23 @@ class DataFeeder(object): input_dtype: DType of input (or dictionary of shapes). output_dtype: DType of output (or dictionary of shapes. """ - x_is_dict, y_is_dict = isinstance(x, dict), y is not None and isinstance( - y, dict) + x_is_dict, y_is_dict = isinstance( + x, dict), y is not None and isinstance(y, dict) if isinstance(y, list): y = np.array(y) self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items()) ]) if x_is_dict else check_array(x, x.dtype) - self._y = None if y is None else ( - dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) - if y_is_dict else check_array(y, y.dtype)) + self._y = None if y is None else (dict( + [(k, check_array(v, v.dtype)) for k, v in list(y.items())]) + if y_is_dict else check_array(y, y.dtype)) # self.n_classes is not None means we're converting raw target indices # to one-hot. if n_classes is not None: if not y_is_dict: - y_dtype = (np.int64 - if n_classes is not None and n_classes > 1 else np.float32) + y_dtype = ( + np.int64 if n_classes is not None and n_classes > 1 else np.float32) self._y = (None if y is None else check_array(y, dtype=y_dtype)) self.n_classes = n_classes @@ -352,8 +353,8 @@ class DataFeeder(object): # self._output_dtype == np.float32 when y is None self._output_dtype = ( dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) - if y_is_dict else ( - _check_dtype(self._y.dtype) if y is not None else np.float32)) + if y_is_dict else (_check_dtype(self._y.dtype) + if y is not None else np.float32)) # self.n_classes is None means we're passing in raw target indices if n_classes is not None and y_is_dict: @@ -478,8 +479,8 @@ class DataFeeder(object): # Assign input features from random indices. def extract(data, indices): - return (np.array(_access(data, indices)).reshape((indices.shape[0], 1)) if - len(data.shape) == 1 else _access(data, indices)) + return (np.array(_access(data, indices)).reshape((indices.shape[0], 1)) + if len(data.shape) == 1 else _access(data, indices)) # assign labels from random indices def assign_label(data, shape, dtype, n_classes, indices): @@ -511,16 +512,18 @@ class DataFeeder(object): feed_dict[self._epoch_placeholder.name] = [self.epoch] # Take next batch of indices. - x_len = list(self._x.values())[0].shape[ - 0] if x_is_dict else self._x.shape[0] + x_len = list( + self._x.values())[0].shape[0] if x_is_dict else self._x.shape[0] end = min(x_len, self.offset + self._batch_size) batch_indices = self.indices[self.offset:end] # adding input placeholder feed_dict.update( dict([(self._input_placeholder[k].name, extract(v, batch_indices)) - for k, v in list(self._x.items())]) if x_is_dict else - {self._input_placeholder.name: extract(self._x, batch_indices)}) + for k, v in list(self._x.items())]) if x_is_dict else { + self._input_placeholder.name: + extract(self._x, batch_indices) + }) # move offset and reset it if necessary self.offset += self._batch_size @@ -545,7 +548,8 @@ class DataFeeder(object): assign_label(v, shape, dtype, n_classes, batch_indices) }) else: - shape, dtype, n_classes = self.output_shape, self._output_dtype, self.n_classes + shape, dtype, n_classes = (self.output_shape, self._output_dtype, + self.n_classes) feed_dict.update({ self._output_placeholder.name: assign_label(self._y, shape, dtype, n_classes, batch_indices) @@ -621,8 +625,9 @@ class StreamingDataFeeder(DataFeeder): elif y is None: y_first_el_shape = None else: - y_first_el_shape = ([1] + list(y_first_el[0].shape if isinstance( - y_first_el, list) else y_first_el.shape)) + y_first_el_shape = ( + [1] + list(y_first_el[0].shape + if isinstance(y_first_el, list) else y_first_el.shape)) self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape( x_first_el_shape, y_first_el_shape, n_classes, batch_size) @@ -683,8 +688,8 @@ class StreamingDataFeeder(DataFeeder): if shape is None: return None elif isinstance(shape, dict): - return dict([(k, np.zeros(shape[k], dtype[k])) - for k in list(shape.keys())]) + return dict( + [(k, np.zeros(shape[k], dtype[k])) for k in list(shape.keys())]) else: return np.zeros(shape, dtype=dtype) diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 3e0b1ad21a9a4a08fa94c8e9796f2b0dd5f8d622..51381a7427c919592b8e818c4b46dba974992610 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Monitors instrument the training process. @@get_default_monitors @@ -151,8 +150,8 @@ class BaseMonitor(object): ValueError: if we've not begun an epoch, or `epoch` number does not match. """ if self._current_epoch != epoch: - raise ValueError( - "epoch_end expected %s but got %s.", self._current_epoch, epoch) + raise ValueError("epoch_end expected %s but got %s.", self._current_epoch, + epoch) self._current_epoch = None def step_begin(self, step): @@ -171,8 +170,8 @@ class BaseMonitor(object): ValueError: if we've already begun a step, or `step` < 0, or `step` > `max_steps`. """ - if (step < 0) or ( - (self._max_steps is not None) and (step > self._max_steps)): + if (step < 0) or ((self._max_steps is not None) and + (step > self._max_steps)): raise ValueError("Invalid step %s." % step) self._current_step = step return [] @@ -203,8 +202,8 @@ class BaseMonitor(object): ValueError: if we've not begun a step, or `step` number does not match. """ if self._current_step != step: - raise ValueError( - "step_end expected %s but got %s.", self._current_step, step) + raise ValueError("step_end expected %s but got %s.", self._current_step, + step) self._current_step = None return False @@ -253,6 +252,7 @@ class EveryN(BaseMonitor): treatment. """ + # TODO(ipolosukhin): Add also every n seconds. def __init__(self, every_n_steps=100, first_n_steps=1): @@ -475,8 +475,8 @@ class LoggingTrainable(EveryN): super(LoggingTrainable, self).every_n_step_begin(step) # Get a list of trainable variables at the beginning of every N steps. # We cannot get this in __init__ because train_op has not been generated. - trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, - scope=self._scope) + trainables = ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, scope=self._scope) self._names = {} for var in trainables: self._names[var.name] = var.value().name @@ -561,12 +561,19 @@ class ValidationMonitor(EveryN): provided. """ - def __init__(self, x=None, y=None, input_fn=None, batch_size=None, + def __init__(self, + x=None, + y=None, + input_fn=None, + batch_size=None, eval_steps=None, - every_n_steps=100, metrics=None, hooks=None, + every_n_steps=100, + metrics=None, + hooks=None, early_stopping_rounds=None, early_stopping_metric="loss", - early_stopping_metric_minimize=True, name=None): + early_stopping_metric_minimize=True, + name=None): """Initializes a ValidationMonitor. Args: @@ -597,8 +604,8 @@ class ValidationMonitor(EveryN): Raises: ValueError: If both x and input_fn are provided. """ - super(ValidationMonitor, self).__init__(every_n_steps=every_n_steps, - first_n_steps=-1) + super(ValidationMonitor, self).__init__( + every_n_steps=every_n_steps, first_n_steps=-1) # TODO(mdan): Checks like this are already done by evaluate. if x is None and input_fn is None: raise ValueError("Either x or input_fn should be provided.") @@ -654,20 +661,27 @@ class ValidationMonitor(EveryN): def _evaluate_estimator(self): if isinstance(self._estimator, core_estimator.Estimator): - if any((x is not None for x in - [self.x, self.y, self.batch_size, self.metrics])): + if any((x is not None + for x in [self.x, self.y, self.batch_size, self.metrics])): raise ValueError( "tf.estimator.Estimator does not support following " "arguments: x, y, batch_size, metrics. Should set as `None` " "in ValidationMonitor") return self._estimator.evaluate( - input_fn=self.input_fn, steps=self.eval_steps, hooks=self.hooks, + input_fn=self.input_fn, + steps=self.eval_steps, + hooks=self.hooks, name=self.name) else: return self._estimator.evaluate( - x=self.x, y=self.y, input_fn=self.input_fn, - batch_size=self.batch_size, steps=self.eval_steps, - metrics=self.metrics, hooks=self.hooks, name=self.name) + x=self.x, + y=self.y, + input_fn=self.input_fn, + batch_size=self.batch_size, + steps=self.eval_steps, + metrics=self.metrics, + hooks=self.hooks, + name=self.name) def every_n_step_end(self, step, outputs): super(ValidationMonitor, self).every_n_step_end(step, outputs) @@ -700,8 +714,9 @@ class ValidationMonitor(EveryN): # Early stopping logic. if self.early_stopping_rounds is not None: if self.early_stopping_metric not in validation_outputs: - raise ValueError("Metric %s missing from outputs %s." % ( - self.early_stopping_metric, set(validation_outputs.keys()))) + raise ValueError("Metric %s missing from outputs %s." % + (self.early_stopping_metric, + set(validation_outputs.keys()))) current_value = validation_outputs[self.early_stopping_metric] if (self._best_value is None or (self.early_stopping_metric_minimize and (current_value < self._best_value)) or @@ -712,9 +727,9 @@ class ValidationMonitor(EveryN): self._best_value_step = step stop_now = (step - self._best_value_step >= self.early_stopping_rounds) if stop_now: - logging.info("Stopping. Best step: {} with {} = {}." - .format(self._best_value_step, - self.early_stopping_metric, self._best_value)) + logging.info("Stopping. Best step: {} with {} = {}.".format( + self._best_value_step, self.early_stopping_metric, + self._best_value)) self._early_stopped = True return True return False @@ -763,8 +778,11 @@ class CaptureVariable(EveryN): self._var_values[step] = _extract_output(outputs, self._var_name) -def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100, - output_dir=None, summary_writer=None): +def get_default_monitors(loss_op=None, + summary_op=None, + save_summary_steps=100, + output_dir=None, + summary_writer=None): """Returns a default set of typically-used monitors. Args: @@ -782,9 +800,12 @@ def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100, if loss_op is not None: monitors.append(PrintTensor(tensor_names={"loss": loss_op.name})) if summary_op is not None: - monitors.append(SummarySaver(summary_op, save_steps=save_summary_steps, - output_dir=output_dir, - summary_writer=summary_writer)) + monitors.append( + SummarySaver( + summary_op, + save_steps=save_summary_steps, + output_dir=output_dir, + summary_writer=summary_writer)) return monitors @@ -794,8 +815,10 @@ class GraphDump(BaseMonitor): Note, this is very expensive, prefer `PrintTensor` in production. """ - IGNORE_OPS = ["Const", "Assign", "Identity", "Placeholder", - "RandomUniform", "Cast", "RestoreSlice"] + IGNORE_OPS = [ + "Const", "Assign", "Identity", "Placeholder", "RandomUniform", "Cast", + "RestoreSlice" + ] def __init__(self, ignore_ops=None): """Initializes GraphDump monitor. @@ -856,7 +879,7 @@ class GraphDump(BaseMonitor): this_output = self.data[step] if step in self.data else {} other_output = other_dump.data[step] if step in other_dump.data else {} for key in this_output: - if not isinstance(key, str) and not isinstance(key, unicode): + if not isinstance(key, six.string_types): continue if key not in other_output: raise ValueError("%s missing at step %s.", (key, step)) @@ -881,8 +904,8 @@ class ExportMonitor(EveryN): """Monitor that exports Estimator every N steps.""" @deprecation.deprecated("2017-03-25", - "ExportMonitor is deprecated. Please pass an " - "ExportStrategy to Experiment instead.") + "ExportMonitor is deprecated. Please pass an " + "ExportStrategy to Experiment instead.") def __init__(self, every_n_steps, export_dir, @@ -1088,8 +1111,7 @@ class CheckpointSaver(BaseMonitor): class StepCounter(EveryN): """Steps per second monitor.""" - def __init__(self, every_n_steps=100, output_dir=None, - summary_writer=None): + def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None): super(StepCounter, self).__init__(every_n_steps=every_n_steps) self._summary_tag = "global_step/sec" self._last_reported_step = None @@ -1101,7 +1123,8 @@ class StepCounter(EveryN): def set_estimator(self, estimator): super(StepCounter, self).set_estimator(estimator) if self._summary_writer is None: - self._summary_writer = core_summary.FileWriterCache.get(estimator.model_dir) + self._summary_writer = core_summary.FileWriterCache.get( + estimator.model_dir) def every_n_step_end(self, current_step, outputs): current_time = time.time() @@ -1109,8 +1132,9 @@ class StepCounter(EveryN): added_steps = current_step - self._last_reported_step elapsed_time = current_time - self._last_reported_time steps_per_sec = added_steps / elapsed_time - summary = Summary(value=[Summary.Value(tag=self._summary_tag, - simple_value=steps_per_sec)]) + summary = Summary(value=[ + Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec) + ]) self._summary_writer.add_summary(summary, current_step) self._last_reported_step = current_step self._last_reported_time = current_time diff --git a/tensorflow/contrib/learn/python/learn/ops/ops_test.py b/tensorflow/contrib/learn/python/learn/ops/ops_test.py index d0b9eb8abcbee187b6c53b7b419882f0a1e7da51..80d4923db37feb2a1304218f501ab51f9e0d9a14 100644 --- a/tensorflow/contrib/learn/python/learn/ops/ops_test.py +++ b/tensorflow/contrib/learn/python/learn/ops/ops_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.layers import conv2d from tensorflow.contrib.learn.python.learn import ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/learn/python/learn/trainable.py b/tensorflow/contrib/learn/python/learn/trainable.py index 972fec026f25d39dca75e8c5bafffb57fcd323fa..429b6040be21d8cbe1f2bba58090366552fdfbe7 100644 --- a/tensorflow/contrib/learn/python/learn/trainable.py +++ b/tensorflow/contrib/learn/python/learn/trainable.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """`Trainable` interface.""" from __future__ import absolute_import @@ -28,18 +27,31 @@ class Trainable(object): __metaclass__ = abc.ABCMeta @abc.abstractmethod - def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, - monitors=None, max_steps=None): + def fit(self, + x=None, + y=None, + input_fn=None, + steps=None, + batch_size=None, + monitors=None, + max_steps=None): """Trains a model given training data `x` predictions and `y` labels. Args: - x: Matrix of shape [n_samples, n_features...] or the dictionary of Matrices. - Can be iterator that returns arrays of features or dictionary of arrays of features. - The training input samples for fitting the model. If set, `input_fn` must be `None`. - y: Vector or matrix [n_samples] or [n_samples, n_outputs] or the dictionary of same. - Can be iterator that returns array of labels or dictionary of array of labels. - The training label values (class labels in classification, real numbers in regression). - If set, `input_fn` must be `None`. Note: For classification, label values must + x: Matrix of shape [n_samples, n_features...] or the dictionary of + Matrices. + Can be iterator that returns arrays of features or dictionary of arrays + of features. + The training input samples for fitting the model. If set, `input_fn` + must be `None`. + y: Vector or matrix [n_samples] or [n_samples, n_outputs] or the + dictionary of same. + Can be iterator that returns array of labels or dictionary of array of + labels. + The training label values (class labels in classification, real numbers + in regression). + If set, `input_fn` must be `None`. Note: For classification, label + values must be integers representing the class index (i.e. values from 0 to n_classes-1). input_fn: Input function returning a tuple of: diff --git a/tensorflow/contrib/learn/python/learn/utils/export_test.py b/tensorflow/contrib/learn/python/learn/utils/export_test.py index 95070ada3b9d3ccb00009bd9b885e8163d7fbed4..9bfb1fc952c07bd6c09d1f1074e8dc5539dc0529 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/export_test.py @@ -50,6 +50,7 @@ def _training_input_fn(): class ExportTest(test.TestCase): + def _get_default_signature(self, export_meta_filename): """ Gets the default signature from the export.meta file. """ with session.Session(): @@ -69,18 +70,18 @@ class ExportTest(test.TestCase): # Only the written checkpoints are exported. self.assertTrue( saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')), - 'Exported checkpoint expected but not found: %s' % - os.path.join(export_dir, '00000001', 'export')) + 'Exported checkpoint expected but not found: %s' % os.path.join( + export_dir, '00000001', 'export')) self.assertTrue( saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')), - 'Exported checkpoint expected but not found: %s' % - os.path.join(export_dir, '00000010', 'export')) + 'Exported checkpoint expected but not found: %s' % os.path.join( + export_dir, '00000010', 'export')) self.assertEquals( six.b(os.path.join(export_dir, '00000010')), export_monitor.last_export_dir) # Validate the signature signature = self._get_default_signature( - os.path.join(export_dir, '00000010', 'export.meta')) + os.path.join(export_dir, '00000010', 'export.meta')) self.assertTrue(signature.HasField(expected_signature)) def testExportMonitor_EstimatorProvidesSignature(self): @@ -116,8 +117,7 @@ class ExportTest(test.TestCase): def _serving_input_fn(): return { _X_KEY: - random_ops.random_uniform( - shape=(1,), minval=0.0, maxval=1000.0) + random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0) }, None input_feature_key = 'my_example_key' @@ -160,8 +160,7 @@ class ExportTest(test.TestCase): input_feature_key: None, _X_KEY: - random_ops.random_uniform( - shape=(1,), minval=0.0, maxval=1000.0) + random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0) }, None monitor = learn.monitors.ExportMonitor( @@ -182,8 +181,7 @@ class ExportTest(test.TestCase): def _serving_input_fn(): return { input_feature_key: - array_ops.placeholder( - dtype=dtypes.string, shape=(1,)) + array_ops.placeholder(dtype=dtypes.string, shape=(1,)) }, None monitor = learn.monitors.ExportMonitor( @@ -204,11 +202,9 @@ class ExportTest(test.TestCase): def _serving_input_fn(): return { input_feature_key: - array_ops.placeholder( - dtype=dtypes.string, shape=(1,)), + array_ops.placeholder(dtype=dtypes.string, shape=(1,)), _X_KEY: - random_ops.random_uniform( - shape=(1,), minval=0.0, maxval=1000.0) + random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0) }, None export_dir = os.path.join(tempfile.mkdtemp(), 'export') @@ -227,8 +223,8 @@ class ExportTest(test.TestCase): def _regression_signature(examples, unused_features, predictions): signatures = {} - signatures['regression'] = (exporter.regression_signature(examples, - predictions)) + signatures['regression'] = ( + exporter.regression_signature(examples, predictions)) return signatures['regression'], signatures random.seed(42) @@ -248,10 +244,10 @@ class ExportTest(test.TestCase): with self.assertRaises(errors.NotFoundError): saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export')) self.assertTrue( - saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export'))) + saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export'))) # Validate the signature signature = self._get_default_signature( - os.path.join(export_dir, '00000010', 'export.meta')) + os.path.join(export_dir, '00000010', 'export.meta')) self.assertTrue(signature.HasField('regression_signature')) diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py index 76cfd88e1d68856907131f7e2bae65d4c9fcc4b1..e7d091e18a8f186f89f5217442c24fb106c5cdab 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py @@ -34,12 +34,13 @@ def _create_parser(base_dir): # create a simple parser that pulls the export_version from the directory. def parser(path): # Modify the path object for RegEx match for Windows Paths - if os.name == 'nt': - match = re.match("^" + compat.as_str_any(base_dir).replace('\\','/') + "/(\\d+)$", - compat.as_str_any(path.path).replace('\\','/')) + if os.name == "nt": + match = re.match( + "^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$", + compat.as_str_any(path.path).replace("\\", "/")) else: match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$", - compat.as_str_any(path.path)) + compat.as_str_any(path.path)) if not match: return None return path._replace(export_version=int(match.group(1))) @@ -63,7 +64,9 @@ class GcTest(test_util.TensorFlowTestCase): def testModExportVersion(self): paths = [ - gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 4), + gc.Path("/foo", 5), + gc.Path("/foo", 6), gc.Path("/foo", 9) ] mod = gc.mod_export_version(2) @@ -73,14 +76,21 @@ class GcTest(test_util.TensorFlowTestCase): def testOneOfEveryNExportVersions(self): paths = [ - gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3), - gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7), - gc.Path("/foo", 8), gc.Path("/foo", 33) + gc.Path("/foo", 0), + gc.Path("/foo", 1), + gc.Path("/foo", 3), + gc.Path("/foo", 5), + gc.Path("/foo", 6), + gc.Path("/foo", 7), + gc.Path("/foo", 8), + gc.Path("/foo", 33) ] one_of = gc.one_of_every_n_export_versions(3) self.assertEqual( one_of(paths), [ - gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8), + gc.Path("/foo", 3), + gc.Path("/foo", 6), + gc.Path("/foo", 8), gc.Path("/foo", 33) ]) @@ -98,13 +108,19 @@ class GcTest(test_util.TensorFlowTestCase): f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3)) self.assertEqual( f(paths), [ - gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6), - gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9) + gc.Path("/foo", 0), + gc.Path("/foo", 3), + gc.Path("/foo", 6), + gc.Path("/foo", 7), + gc.Path("/foo", 8), + gc.Path("/foo", 9) ]) def testNegation(self): paths = [ - gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 4), + gc.Path("/foo", 5), + gc.Path("/foo", 6), gc.Path("/foo", 9) ] mod = gc.negation(gc.mod_export_version(2)) @@ -121,8 +137,7 @@ class GcTest(test_util.TensorFlowTestCase): gfile.MakeDirs(os.path.join(base_dir, "ignore")) self.assertEqual( - gc.get_paths(base_dir, _create_parser(base_dir)), - [ + gc.get_paths(base_dir, _create_parser(base_dir)), [ gc.Path(os.path.join(base_dir, "0"), 0), gc.Path(os.path.join(base_dir, "1"), 1), gc.Path(os.path.join(base_dir, "2"), 2) @@ -131,10 +146,10 @@ class GcTest(test_util.TensorFlowTestCase): def testMixedStrTypes(self): temp_dir = compat.as_bytes(test.get_temp_dir()) - for sub_dir in ['str', b'bytes', u'unicode']: + for sub_dir in ["str", b"bytes", u"unicode"]: base_dir = os.path.join( - (temp_dir if isinstance(sub_dir, bytes) else temp_dir.decode()), - sub_dir) + (temp_dir + if isinstance(sub_dir, bytes) else temp_dir.decode()), sub_dir) self.assertFalse(gfile.Exists(base_dir)) gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42")) gc.get_paths(base_dir, _create_parser(base_dir)) diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index fe2f183ac970cef4ebf6ca1a927b5a48eefb7d7b..cea3627ed565f0de86d8d9bb6b45c4b19c5b5558 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -126,6 +126,7 @@ py_library( py_test( name = "sdca_estimator_test", srcs = ["python/sdca_estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":sdca_estimator_py", diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 7526f3ae0dbdb3d6827e9d7f690090b8438e4f6e..3f5fdc18bb8f47cceee8f81dd5ded02059344b8b 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -211,9 +211,8 @@ class SdcaModel(object): sums.append( math_ops.reduce_sum( math_ops.abs(math_ops.cast(weights, dtypes.float64)))) - sum = math_ops.add_n(sums) # SDCA L1 regularization cost is: l1 * sum(|weights|) - return self._options['symmetric_l1_regularization'] * sum + return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums) def _l2_loss(self, l2): """Computes the (un-normalized) l2 loss of the model.""" @@ -225,9 +224,8 @@ class SdcaModel(object): sums.append( math_ops.reduce_sum( math_ops.square(math_ops.cast(weights, dtypes.float64)))) - sum = math_ops.add_n(sums) # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2 - return l2 * sum / 2.0 + return l2 * math_ops.add_n(sums) / 2.0 def _convert_n_to_tensor(self, input_list, as_ref=False): """Converts input list to a set of tensors.""" diff --git a/tensorflow/contrib/lite/Android.bp b/tensorflow/contrib/lite/Android.bp index 8a5d9d5df29cc98f3aeab0632295288c4c196c9d..bb1ea46b61bb7330c688dc42cd8a32f1dfeeb0b0 100644 --- a/tensorflow/contrib/lite/Android.bp +++ b/tensorflow/contrib/lite/Android.bp @@ -29,6 +29,9 @@ cc_library_static { name: "libtflite_context", defaults: ["tflite_defaults"], srcs: ["context.c"], + cflags: [ + "-Wno-typedef-redefinition", + ], } cc_library_static { @@ -39,6 +42,7 @@ cc_library_static { "allocation.cc", "arena_planner.cc", "error_reporter.cc", + "graph_info.cc", "interpreter.cc", "model.cc", "nnapi_delegate.cc", diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 13350c5a438b75fe14e8753e5bb1bb77ec8f655b..44c4a7e2ca8d019ca602c7f2b492cd1e70b17561 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -6,8 +6,11 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") +exports_files(["LICENSE"]) + exports_files(glob([ "testdata/*.bin", + "testdata/*.pb", "models/testdata/*", ])) @@ -25,11 +28,6 @@ config_setting( }, ) -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", -) - cc_library( name = "schema_fbs_version", hdrs = ["version.h"], @@ -53,6 +51,8 @@ cc_test( srcs = ["arena_planner_test.cc"], deps = [ ":arena_planner", + "//tensorflow/contrib/lite/testing:util", + "//tensorflow/core:lib", "@com_google_googletest//:gtest", ], ) @@ -107,6 +107,7 @@ cc_library( srcs = [ "allocation.cc", "error_reporter.cc", + "graph_info.cc", "interpreter.cc", "model.cc", "nnapi_delegate.cc", @@ -116,6 +117,7 @@ cc_library( "allocation.h", "context.h", "error_reporter.h", + "graph_info.h", "interpreter.h", "model.h", "nnapi_delegate.h", @@ -167,6 +169,22 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test graph utils +cc_test( + name = "graph_info_test", + size = "small", + srcs = ["graph_info_test.cc"], + deps = [ + ":framework", + ":string_util", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -218,18 +236,18 @@ cc_test( # Model tests -cc_library( - name = "models_test_utils", - testonly = 1, - hdrs = ["models/test_utils.h"], - deps = select({ - "//tensorflow:android": [], - "//conditions:default": [ - "@com_google_absl//absl/strings", - "//tensorflow/core:test", - ], - }), -) +#cc_library( +# name = "models_test_utils", +# testonly = 1, +# hdrs = ["models/test_utils.h"], +# deps = select({ +# "//tensorflow:android": [], +# "//conditions:default": [ +# "@com_google_absl//absl/strings", +# "//tensorflow/core:test", +# ], +# }), +#) filegroup( name = "all_files", diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index 55a524b207b258e794f97e68a96cf01dc60efb7f..00e93d2c4f3ab27057b855fba6fccf2ec8d7a1c1 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -6,7 +6,7 @@ TensorFlow Lite uses many techniques for achieving low latency like optimizing t ![image](g3doc/TFLite-Architecture.jpg) # Getting Started with an Android Demo App -This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo. +This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. A device running Android 5.0 ( API 21) or higher is required to run the demo. There are 3 ways to get the demo app to your device - Download the prebuilt binary or @@ -29,9 +29,16 @@ The simplest way to compile the demo app, and try out changes to the project cod - Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings). - Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project. - Click through installing all the Gradle extensions it requests. - - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) - - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: - `tensorflow/contrib/lite/java/demo/app/src/main/assets/` + - Either + - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) + - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: + `tensorflow/contrib/lite/java/demo/app/src/main/assets/` + - Or download the floating point Inception-v3 model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) + - unzip and copy inceptionv3_non_slim_2015.tflite to the assets directory + - change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java) from + `classifier = new ImageClassifierQuantizedMobileNet(getActivity());` + to + `classifier = new ImageClassifierFloatInception(getActivity());` - Build and run the demo app ## Building TensorFlow Lite and the demo app from source @@ -84,7 +91,7 @@ Currently, we only support building the Android demo app within a Python 2 environment (due to a Bazel bug). ### More about the demo -The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app. +The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used (229 * 229 for Inception-v3). The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch. 224 * 224 (299 * 299) is the width and height of the image. 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. Both models have 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The model file must be downloaded and bundled within the assets directory of the app. # iOS Demo App @@ -92,7 +99,7 @@ Similar to the Android demo app, there's an iOS camera app that uses exactly the This demo app requires a camera so it doesn't work with simulators. It need to be executed on a real iOS device. Follow the instructions to build and run the demo app: -1. Follow the Building section [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md#building) to build the universal iOS library for TensorFlow Lite. +1. Run `third_party/tensorflow/contrib/lite/examples/ios/download_models.sh` to download the model files used by the demo app. 1. Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`. 1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. 1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode. @@ -142,7 +149,7 @@ Since we employ several formats, the following definitions may be useful: - SavedModel - A collection of GraphDef and CheckPoint together with a signature that labels input and output arguments to a model. A GraphDef and Checkpoint can be extracted from a saved model. - - TensorFlow lite model (.lite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. + - TensorFlow lite model (.tflite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. ### Freeze Graph To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as "freezing" the graph. @@ -164,18 +171,18 @@ bazel-bin/tensorflow/python/tools/freeze_graph\ The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3). -This frozen Graphdef is now ready to be converted to flatbuffer format (.lite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. +This frozen Graphdef is now ready to be converted to flatbuffer format (.tflite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. -Here is a sample command line to convert the frozen Graphdef to '.lite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. +Here is a sample command line to convert the frozen Graphdef to '.tflite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. (Here is a link to the pb [file](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)). ``` bazel build tensorflow/contrib/lite/toco:toco -bazel-bin/tensorflow/contrib/lite/toco/toco -- \ +bazel-bin/tensorflow/contrib/lite/toco/toco \ --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \ --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ - --output_file=/tmp/mobilenet_v1_1.0_224.lite --inference_type=FLOAT \ + --output_file=/tmp/mobilenet_v1_1.0_224.tflite --inference_type=FLOAT \ --input_type=FLOAT --input_arrays=input \ --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3 ``` @@ -211,7 +218,7 @@ and then visualize the resulting HTML file in a browser. ## Step 3. Use the TensorFlow Lite model for inference in a mobile app -After completion of Step 2 the developer should have a .lite model. +After completion of Step 2 the developer should have a .tflite model. ### For Android Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index bf1bcdd1a7a7d3395c45ae95abd5980e9ffc0fc6..87b17c338e7afc33d32dd9688cc0825ac319dd19 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -185,8 +185,12 @@ TfLiteStatus ArenaPlanner::CalculateAllocations(int first_node, int last_node) { TfLiteStatus ArenaPlanner::ResolveTensorAllocation(int tensor_index) { TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_STATUS( - arena_.ResolveAlloc(context_, allocs_[tensor_index], &tensor.data.raw)); + // Skip resolution if the size of the tensor is zero, leaving it as a + // nullptr. + if (allocs_[tensor_index].size != 0) { + TF_LITE_ENSURE_STATUS(arena_.ResolveAlloc(context_, allocs_[tensor_index], + &tensor.data.raw)); + } } if (tensor.allocation_type == kTfLiteArenaRwPersistent) { TF_LITE_ENSURE_STATUS(persistent_arena_.ResolveAlloc( diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index c27c327abc63d7bd1e3912d368a1dacb62c50ca8..a8a8755e2c9e81474f2ff9cd2b85c0eb3d5c3441 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/core/platform/logging.h" namespace tflite { namespace { @@ -191,8 +193,8 @@ TEST_F(ArenaPlannerTest, GraphWithNoOps) { EXPECT_EQ(GetOffset(10), GetOffsetAfter(0)); // The outputs are never allocated because they are not connected to any // inputs. - EXPECT_EQ(GetOffset(5), 0); - EXPECT_EQ(GetOffset(11), 0); + EXPECT_TRUE((*graph.tensors())[5].data.raw == nullptr); + EXPECT_TRUE((*graph.tensors())[11].data.raw == nullptr); } TEST_F(ArenaPlannerTest, GraphWithOneOp) { @@ -371,11 +373,7 @@ TEST_F(ArenaPlannerTest, LargerGraphAndStepwiseAllocation) { SetGraph(&graph); auto is_unallocated = [&](int tensor_index) { - // TODO(ahentz): We'd to use nullptr to represent unallocated tensors, but - // the current code still points them all to the beginning fo the alloc - // (that is, zero offset). - // return (*graph.tensors())[tensor_index].data.raw == nullptr; - return GetOffset(tensor_index) == 0; + return (*graph.tensors())[tensor_index].data.raw == nullptr; }; // The allocation plan is made at the beginning and is independent of @@ -464,9 +462,7 @@ TEST_F(ArenaPlannerTest, LargerGraphAndStepwiseAllocation) { } // namespace tflite int main(int argc, char** argv) { - // ::tflite::LogToStderr(); - FLAGS_logtostderr = true; - + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 0b48ef4741ac921e34dd56930783499c5040d581..5fc8954743e5b3b458e5c2004f4378cbad6056c0 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -116,25 +116,9 @@ typedef struct { } TfLiteAddParams; typedef struct { - // Number of spatial dimensions. - // For now only NHWC is supported, and the value should always be 2. - int num_spatial_dimensions; - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int block_shape[2]; - int before_paddings[2]; - int after_paddings[2]; } TfLiteSpaceToBatchNDParams; typedef struct { - // Number of spatial dimensions. - // For now only NHWC is supported, and the value should always be 2. - int num_spatial_dimensions; - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int block_shape[2]; - int before_crops[2]; - int after_crops[2]; } TfLiteBatchToSpaceNDParams; typedef struct { @@ -167,8 +151,7 @@ typedef struct { } TfLiteLSTMParams; typedef struct { - int new_height; - int new_width; + bool align_corners; } TfLiteResizeBilinearParams; typedef struct { @@ -206,20 +189,16 @@ typedef struct { } TfLiteGatherParams; typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int perm[8]; - int num_dimensions; } TfLiteTransposeParams; typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int axis[8]; - int num_axis_dimensions; bool keep_dims; } TfLiteMeanParams; +typedef struct { + int num_splits; +} TfLiteSplitParams; + typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..4ebd1586de791eecf0304637bde76232d9f0a11d --- /dev/null +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ + +// DO NOT EDIT MANUALLY: This file is automatically generated by +// `schema_builtin_ops_header_generator.py`. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { + kTfLiteBuiltinAdd = 0, + kTfLiteBuiltinAveragePool2d = 1, + kTfLiteBuiltinConcatenation = 2, + kTfLiteBuiltinConv2d = 3, + kTfLiteBuiltinDepthwiseConv2d = 4, + kTfLiteBuiltinEmbeddingLookup = 7, + kTfLiteBuiltinFullyConnected = 9, + kTfLiteBuiltinHashtableLookup = 10, + kTfLiteBuiltinL2Normalization = 11, + kTfLiteBuiltinL2Pool2d = 12, + kTfLiteBuiltinLocalResponseNormalization = 13, + kTfLiteBuiltinLogistic = 14, + kTfLiteBuiltinLshProjection = 15, + kTfLiteBuiltinLstm = 16, + kTfLiteBuiltinMaxPool2d = 17, + kTfLiteBuiltinMul = 18, + kTfLiteBuiltinRelu = 19, + kTfLiteBuiltinReluN1To1 = 20, + kTfLiteBuiltinRelu6 = 21, + kTfLiteBuiltinReshape = 22, + kTfLiteBuiltinResizeBilinear = 23, + kTfLiteBuiltinRnn = 24, + kTfLiteBuiltinSoftmax = 25, + kTfLiteBuiltinSpaceToDepth = 26, + kTfLiteBuiltinSvdf = 27, + kTfLiteBuiltinTanh = 28, + kTfLiteBuiltinConcatEmbeddings = 29, + kTfLiteBuiltinSkipGram = 30, + kTfLiteBuiltinCall = 31, + kTfLiteBuiltinCustom = 32, + kTfLiteBuiltinEmbeddingLookupSparse = 33, + kTfLiteBuiltinPad = 34, + kTfLiteBuiltinUnidirectionalSequenceRnn = 35, + kTfLiteBuiltinGather = 36, + kTfLiteBuiltinBatchToSpaceNd = 37, + kTfLiteBuiltinSpaceToBatchNd = 38, + kTfLiteBuiltinTranspose = 39, + kTfLiteBuiltinMean = 40, + kTfLiteBuiltinSub = 41, + kTfLiteBuiltinDiv = 42, + kTfLiteBuiltinSqueeze = 43, + kTfLiteBuiltinUnidirectionalSequenceLstm = 44, + kTfLiteBuiltinStridedSlice = 45, + kTfLiteBuiltinBidirectionalSequenceRnn = 46, + kTfLiteBuiltinExp = 47, + kTfLiteBuiltinTopkV2 = 48, + kTfLiteBuiltinSplit = 49, +} TfLiteBuiltinOperator; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +} diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index d6dfc20ae829b13e9cb45efcf9e14af5d4b69b48..b0c4d3431f9a67bc87d51ada91ed73f1661023a2 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -38,6 +38,9 @@ extern "C" { typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; +// Forward declare so GetNode can use this is in Context. +typedef struct _TfLiteRegistration TfLiteRegistration; + #define kOptionalTensor (-1) // Fixed size list of integers. Used for dimensions and inputs/outputs tensor @@ -205,9 +208,56 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, // Resize the allocated data of a (dynamic) tensor. void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); +// A structure representing an instance of a node. +// This structure only exhibits the inputs, outputs and user defined data, not +// other features like the type. +typedef struct { + // Inputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* inputs; + + // Outputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* outputs; + + // Temporary tensors uses during the computations. This usually contains no + // tensors, but ops are allowed to change that if they need scratch space of + // any sort. + TfLiteIntArray* temporaries; + + // Opaque data provided by the node implementer through `Registration.init`. + void* user_data; + + // Opaque data provided to the node if the node is a builtin. This is usually + // a structure defined in builtin_op_data.h + void* builtin_data; + + // Custom initial data. This is the opaque data provided in the flatbuffer. + // WARNING: This is an experimental interface that is subject to change. + const void* custom_initial_data; + int custom_initial_data_size; +} TfLiteNode; + typedef struct TfLiteContext { // Number of tensors in the context. int tensors_size; + + // The execution plan contains a list of the node indices in execution + // order. execution_plan->size is the current number of nodes. And, + // execution_plan->data[0] is the first node that needs to be run. + // TfLiteDelegates can traverse the current execution plan by iterating + // through each member of this array and using GetNodeAndRegistration() to + // access details about a node. i.e. + // TfLiteIntArray* execution_plan; + // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan)); + // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) { + // int node_index = execution_plan->data[exec_index]; + // TfLiteNode* node; + // TfLiteRegistration* reg; + // context->GetNodeAndRegistration(context, node_index, &node, ®); + // } + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, + TfLiteIntArray** execution_plan); + // An tensor of tensors in the interpreter context (of length `tensors_size`) TfLiteTensor* tensors; @@ -227,34 +277,23 @@ typedef struct TfLiteContext { TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, int* first_new_tensor_index); + // Get a Tensor node by node_index. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index, + TfLiteNode** node, + TfLiteRegistration** registration); + + // Replace ops with delegate. + TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( + struct TfLiteContext*, TfLiteRegistration registration, + const TfLiteIntArray* nodes_to_replace); + // TODO(ahentz): we should create a more general mechanism for this sort of // library-global objects. void* gemm_context; } TfLiteContext; -// A structure representing an instance of a node. -// This structure only exhibits the inputs, outputs and user defined data, not -// other features like the type. -typedef struct { - // Inputs to this node expressed as indices into the simulator's tensors. - TfLiteIntArray* inputs; - - // Outputs to this node expressed as indices into the simulator's tensors. - TfLiteIntArray* outputs; - - // Temporary tensors uses during the computations. This usually contains no - // tensors, but ops are allowed to change that if they need scratch space of - // any sort. - TfLiteIntArray* temporaries; - - // Opaque data provided by the node implementer through `Registration.init`. - void* user_data; - - // Opaque data provided to the node if the node is a builtin. - void* builtin_data; -} TfLiteNode; - -typedef struct { +typedef struct _TfLiteRegistration { // Initializes the op from serialized data. // If a built-in op: // `buffer` is the op's params data (TfLiteLSTMParams*). @@ -291,8 +330,26 @@ typedef struct { // NN API. Note, it is the responsibility of the registration binder to // set this properly. int32_t builtin_code; + + // Custom op name. If the op is a builtin, this will be null. + // WARNING: This is an experimental interface that is subject to change. + const char* custom_name; } TfLiteRegistration; +// WARNING: This is an experimental interface that is subject to change. +typedef struct { + // Data that delegate needs to identify itself. This data is owned by the + // delegate. The delegate is owned in the user code, so the delegate is + // responsible for doing this when it is destroyed. + void* data_; + // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the + // delegate a view of the current graph through TfLiteContext*. It typically + // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() + // to ask the TensorFlow lite runtime to create macro-nodes to represent + // delegated subgraphs of the original graph. + TfLiteStatus (*Prepare)(TfLiteContext* context, void* data); +} TfLiteDelegate; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index e1b7b3613a041287ff3cc4eeff8afd7cfcede174..a93ed201d647ddf2359a57254a959871c13fb94f 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -36,8 +36,6 @@ ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_ NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip" -MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_ios_lite_float_2017_11_08.zip" -QUANTIZED_MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. @@ -93,8 +91,6 @@ download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl" download_and_extract "${NEON_2_SSE_URL}" "${DOWNLOADS_DIR}/neon_2_sse" download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" -download_and_extract "${MODELS_URL}" "${DOWNLOADS_DIR}/models" -download_and_extract "${QUANTIZED_MODELS_URL}" "${DOWNLOADS_DIR}/quantized_models" replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" @@ -103,7 +99,4 @@ replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#s replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA );#static uint64x2_t p2ul_CONJ_XOR;// = vld1q_u64( p2ul_conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" -cp ${DOWNLOADS_DIR}/models/models/* tensorflow/contrib/lite/examples/ios/simple/data/ -cp ${DOWNLOADS_DIR}/quantized_models/* tensorflow/contrib/lite/examples/ios/camera/data/ - echo "download_dependencies.sh completed successfully." >&2 diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index 10f31bb6f17242c9f7f70f0648ec643f99c5ac86..d74e275f0439b1ce56b29e0eadff5f211f6a4faa 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -225,14 +225,8 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const assert(pixelBuffer != NULL); OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); - int doReverseChannels; - if (kCVPixelFormatType_32ARGB == sourcePixelFormat) { - doReverseChannels = 1; - } else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) { - doReverseChannels = 0; - } else { - assert(false); // Unknown source format - } + assert(sourcePixelFormat == kCVPixelFormatType_32ARGB || + sourcePixelFormat == kCVPixelFormatType_32BGRA); const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer); const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer); diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile index 4ae6fb6b94e4489f63506b05a2f348b7daafd3b7..c7d3b1c966eaa0de71f5c37a6a77b3881e30ddd7 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/Podfile +++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! target 'tflite_camera_example' - pod 'TensorFlow-experimental' + pod 'TensorFlowLite' diff --git a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj index c98183276bd60d2a0ad023ba26aad12572a02786..b0236e9c608ec35437bcfe79c51149a76f9f416e 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj +++ b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj @@ -16,7 +16,6 @@ 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 1CDB2D4D1ED3AA35007929E9 /* Info.plist */; }; 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */; }; AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = AC1F82641FBA3CBD0052BA77 /* labels.txt */; }; - AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */; }; ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */; }; /* End PBXBuildFile section */ @@ -38,7 +37,6 @@ 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.debug.xcconfig"; sourceTree = ""; }; 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.release.xcconfig"; sourceTree = ""; }; AC1F82641FBA3CBD0052BA77 /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = ""; }; - AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = ""; }; ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_quant_v1_224.tflite; sourceTree = ""; }; /* End PBXFileReference section */ @@ -47,7 +45,6 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */, 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */, 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */, 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */, @@ -60,7 +57,6 @@ 24D7686C331131624F4454A0 /* Frameworks */ = { isa = PBXGroup; children = ( - AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */, 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */, 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */, 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, @@ -336,7 +332,6 @@ ../../../downloads/, ); IPHONEOS_DEPLOYMENT_TARGET = 8.0; - LIBRARY_SEARCH_PATHS = ../../../gen/lib/; MTL_ENABLE_DEBUG_INFO = YES; ONLY_ACTIVE_ARCH = YES; SDKROOT = iphoneos; @@ -384,7 +379,6 @@ ../../../downloads/, ); IPHONEOS_DEPLOYMENT_TARGET = 8.0; - LIBRARY_SEARCH_PATHS = ../../../gen/lib/; MTL_ENABLE_DEBUG_INFO = NO; SDKROOT = iphoneos; TARGETED_DEVICE_FAMILY = "1,2"; diff --git a/tensorflow/contrib/lite/examples/ios/download_models.sh b/tensorflow/contrib/lite/examples/ios/download_models.sh new file mode 100755 index 0000000000000000000000000000000000000000..ccd163758c5830dc9367e023dcb3a604e07ca5db --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/download_models.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -ex + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_ios_lite_float_2017_11_08.zip" +QUANTIZED_MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" +DOWNLOADS_DIR=$(mktemp -d) + +cd $SCRIPT_DIR + +download_and_extract() { + local usage="Usage: download_and_extract URL DIR" + local url="${1:?${usage}}" + local dir="${2:?${usage}}" + echo "downloading ${url}" >&2 + mkdir -p "${dir}" + tempdir=$(mktemp -d) + tempdir2=$(mktemp -d) + + curl -L ${url} > ${tempdir}/zipped.zip + unzip ${tempdir}/zipped.zip -d ${tempdir2} + + # If the zip file contains nested directories, extract the files from the + # inner directory. + if ls ${tempdir2}/*/* 1> /dev/null 2>&1; then + # unzip has no strip components, so unzip to a temp dir, and move the + # files we want from the tempdir to destination. + cp -R ${tempdir2}/*/* ${dir}/ + else + cp -R ${tempdir2}/* ${dir}/ + fi + rm -rf ${tempdir2} ${tempdir} +} + +download_and_extract "${MODELS_URL}" "${DOWNLOADS_DIR}/models" +download_and_extract "${QUANTIZED_MODELS_URL}" "${DOWNLOADS_DIR}/quantized_models" + +file ${DOWNLOADS_DIR}/models + +cp ${DOWNLOADS_DIR}/models/models/* simple/data/ +cp ${DOWNLOADS_DIR}/quantized_models/* camera/data/ + diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile index 1740ad64573a84fae6de0fcf284eb06afec67e25..e4aca2be82d437a0225d2c15d3e486b0344aa978 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/Podfile +++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile @@ -1,5 +1,5 @@ platform :ios, '8.0' inhibit_all_warnings! -target 'tf_simple_example' - pod 'TensorFlow-experimental' +target 'tflite_simple_example' + pod 'TensorFlowLite' diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist b/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist index 1a3eaa8a2c18d1cd24dfd475d396b00ec4d86c9d..a19a43a7541e3d751116e868dbcbdd607d15ab4a 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist @@ -7,7 +7,7 @@ CFBundleDisplayName tflite-simple-example CFBundleExecutable - tf_simple_example + tflite_simple_example CFBundleIdentifier $(PRODUCT_BUNDLE_IDENTIFIER) CFBundleInfoDictionaryVersion diff --git a/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj index 9277c230b8cce1b5673a50d32d7640d52e2e8f9d..f5b8382d5ae4ac80a7edb52c34ebaf12ad65f4db 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj +++ b/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj @@ -9,7 +9,7 @@ /* Begin PBXBuildFile section */ 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */; }; 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */; }; - 594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */; }; + 1E6F42DBB39A4A3871D4F848 /* libPods-tflite_simple_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 73DBC33C5DD9A526EE6D1EF2 /* libPods-tflite_simple_example.a */; }; 594C14B11FB9037100EE8BFE /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = 594C14AF1FB9037100EE8BFE /* labels.txt */; }; 594C14B21FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */; }; 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; @@ -24,8 +24,7 @@ 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; - 5911579B1CF4011C00C31E3A /* tf_simple_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; - 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = ""; }; + 5911579B1CF4011C00C31E3A /* tflite_simple_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tflite_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; 594C14AF1FB9037100EE8BFE /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = ""; }; 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = ""; }; 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; @@ -38,7 +37,9 @@ 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RunModelViewController.h; sourceTree = ""; }; 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RunModelViewController.mm; sourceTree = ""; }; 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = RunModelViewController.xib; sourceTree = ""; }; - 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; + 5D6203B9FAEEB9824194DBE8 /* Pods-tflite_simple_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_simple_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_simple_example/Pods-tflite_simple_example.release.xcconfig"; sourceTree = ""; }; + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tflite_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tflite_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; + 987DD5BCAB2DD8B682674E20 /* Pods-tflite_simple_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_simple_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_simple_example/Pods-tflite_simple_example.debug.xcconfig"; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -46,9 +47,9 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - 594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */, 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */, 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */, + 1E6F42DBB39A4A3871D4F848 /* libPods-tflite_simple_example.a in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -58,11 +59,10 @@ 24D7686C331131624F4454A0 /* Frameworks */ = { isa = PBXGroup; children = ( - 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */, 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, - 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */, + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tflite_simple_example.a */, ); name = Frameworks; sourceTree = ""; @@ -82,13 +82,14 @@ 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */, 5911579C1CF4011C00C31E3A /* Products */, 24D7686C331131624F4454A0 /* Frameworks */, + 5CE7E4179B26BF77944D8637 /* Pods */, ); sourceTree = ""; }; 5911579C1CF4011C00C31E3A /* Products */ = { isa = PBXGroup; children = ( - 5911579B1CF4011C00C31E3A /* tf_simple_example.app */, + 5911579B1CF4011C00C31E3A /* tflite_simple_example.app */, ); name = Products; sourceTree = ""; @@ -103,24 +104,36 @@ path = data; sourceTree = ""; }; + 5CE7E4179B26BF77944D8637 /* Pods */ = { + isa = PBXGroup; + children = ( + 987DD5BCAB2DD8B682674E20 /* Pods-tflite_simple_example.debug.xcconfig */, + 5D6203B9FAEEB9824194DBE8 /* Pods-tflite_simple_example.release.xcconfig */, + ); + name = Pods; + sourceTree = ""; + }; /* End PBXGroup section */ /* Begin PBXNativeTarget section */ - 5911579A1CF4011C00C31E3A /* tf_simple_example */ = { + 5911579A1CF4011C00C31E3A /* tflite_simple_example */ = { isa = PBXNativeTarget; - buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */; + buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tflite_simple_example" */; buildPhases = ( + A507411BCC70190B9ABD2721 /* [CP] Check Pods Manifest.lock */, 591157971CF4011C00C31E3A /* Sources */, 591157981CF4011C00C31E3A /* Frameworks */, 591157991CF4011C00C31E3A /* Resources */, + 25E1671BDC7334C678FB5DFB /* [CP] Embed Pods Frameworks */, + 10976C49D86B7F8A59157601 /* [CP] Copy Pods Resources */, ); buildRules = ( ); dependencies = ( ); - name = tf_simple_example; + name = tflite_simple_example; productName = tf_ios_makefile_example; - productReference = 5911579B1CF4011C00C31E3A /* tf_simple_example.app */; + productReference = 5911579B1CF4011C00C31E3A /* tflite_simple_example.app */; productType = "com.apple.product-type.application"; }; /* End PBXNativeTarget section */ @@ -152,7 +165,7 @@ projectDirPath = ""; projectRoot = ""; targets = ( - 5911579A1CF4011C00C31E3A /* tf_simple_example */, + 5911579A1CF4011C00C31E3A /* tflite_simple_example */, ); }; /* End PBXProject section */ @@ -171,6 +184,57 @@ }; /* End PBXResourcesBuildPhase section */ +/* Begin PBXShellScriptBuildPhase section */ + 10976C49D86B7F8A59157601 /* [CP] Copy Pods Resources */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Copy Pods Resources"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_simple_example/Pods-tflite_simple_example-resources.sh\"\n"; + showEnvVarsInLog = 0; + }; + 25E1671BDC7334C678FB5DFB /* [CP] Embed Pods Frameworks */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Embed Pods Frameworks"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_simple_example/Pods-tflite_simple_example-frameworks.sh\"\n"; + showEnvVarsInLog = 0; + }; + A507411BCC70190B9ABD2721 /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-tflite_simple_example-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; +/* End PBXShellScriptBuildPhase section */ + /* Begin PBXSourcesBuildPhase section */ 591157971CF4011C00C31E3A /* Sources */ = { isa = PBXSourcesBuildPhase; @@ -274,6 +338,7 @@ }; 591157B31CF4011D00C31E3A /* Debug */ = { isa = XCBuildConfiguration; + baseConfigurationReference = 987DD5BCAB2DD8B682674E20 /* Pods-tflite_simple_example.debug.xcconfig */; buildSettings = { CLANG_DEBUG_INFORMATION_LEVEL = default; CODE_SIGN_IDENTITY = "iPhone Developer"; @@ -283,15 +348,10 @@ GCC_ENABLE_CPP_RTTI = YES; HEADER_SEARCH_PATHS = ( "$(inherited)", - ../../../../../../, - ../../../downloads/flatbuffers/include/, - ../../../downloads/eigen/, - ../../../downloads/, ); INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; IPHONEOS_DEPLOYMENT_TARGET = 9.2; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ../../../gen/lib/; OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; OTHER_LDFLAGS = "$(inherited)"; PRODUCT_BUNDLE_IDENTIFIER = "com.google.tflite-simple-example"; @@ -304,6 +364,7 @@ }; 591157B41CF4011D00C31E3A /* Release */ = { isa = XCBuildConfiguration; + baseConfigurationReference = 5D6203B9FAEEB9824194DBE8 /* Pods-tflite_simple_example.release.xcconfig */; buildSettings = { CLANG_DEBUG_INFORMATION_LEVEL = default; CODE_SIGN_IDENTITY = "iPhone Developer"; @@ -313,15 +374,10 @@ GCC_ENABLE_CPP_RTTI = YES; HEADER_SEARCH_PATHS = ( "$(inherited)", - ../../../../../../, - ../../../downloads/flatbuffers/include/, - ../../../downloads/eigen/, - ../../../downloads/, ); INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; IPHONEOS_DEPLOYMENT_TARGET = 9.2; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ../../../gen/lib/; ONLY_ACTIVE_ARCH = YES; OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; OTHER_LDFLAGS = "$(inherited)"; @@ -344,7 +400,7 @@ defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; - 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */ = { + 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tflite_simple_example" */ = { isa = XCConfigurationList; buildConfigurations = ( 591157B31CF4011D00C31E3A /* Debug */, diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD index 476d85c0314e331d6d3bad382c331a8458fd01a1..959347b5491514ddc13af57ea6f7385a0d39e418 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -42,7 +42,15 @@ cc_library( "bitmap_helpers_impl.h", "label_image.h", ], - deps = ["//tensorflow/contrib/lite:string"], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/schema:schema_fbs", + ], ) # TODO(ahentz): Test disabled as it has a memory leek from read_bmp diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h index 860e27e5ba9cc9fe23d2a7f9f65dd53bbf76f7a3..97343dde6b31694e5b2de20b35a7083fb8fe4a0e 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H -#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ #include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h" #include "tensorflow/contrib/lite/examples/label_image/label_image.h" @@ -26,15 +26,15 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, int* channels, Settings* s); template -void downsize(T* out, uint8_t* in, int image_height, int image_width, - int image_channels, int wanted_height, int wanted_width, - int wanted_channels, Settings* s); +void resize(T* out, uint8_t* in, int image_height, int image_width, + int image_channels, int wanted_height, int wanted_width, + int wanted_channels, Settings* s); // explicit instantiation -template void downsize(uint8_t*, unsigned char*, int, int, int, int, - int, int, Settings*); -template void downsize(float*, unsigned char*, int, int, int, int, int, +template void resize(uint8_t*, unsigned char*, int, int, int, int, int, int, Settings*); +template void resize(float*, unsigned char*, int, int, int, int, int, + int, Settings*); } // namespace label_image } // namespace tflite diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index 64a931082b0cbb4632ec3a814ce654d4f9106bc1..2a64c1de725b601e9b6e9325d9faacb37df0e626 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -13,8 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H -#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/version.h" + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/version.h" #include "tensorflow/contrib/lite/examples/label_image/label_image.h" @@ -22,28 +34,70 @@ namespace tflite { namespace label_image { template -void downsize(T* out, uint8_t* in, int image_height, int image_width, - int image_channels, int wanted_height, int wanted_width, - int wanted_channels, Settings* s) { - for (int y = 0; y < wanted_height; ++y) { - const int in_y = (y * image_height) / wanted_height; - uint8_t* in_row = in + (in_y * image_width * image_channels); - T* out_row = out + (y * wanted_width * wanted_channels); - for (int x = 0; x < wanted_width; ++x) { - const int in_x = (x * image_width) / wanted_width; - uint8_t* in_pixel = in_row + (in_x * image_channels); - T* out_pixel = out_row + (x * wanted_channels); - for (int c = 0; c < wanted_channels; ++c) { - if (s->input_floating) - out_pixel[c] = (in_pixel[c] - s->input_mean) / s->input_std; - else - out_pixel[c] = in_pixel[c]; - } - } +void resize(T* out, uint8_t* in, int image_height, int image_width, + int image_channels, int wanted_height, int wanted_width, + int wanted_channels, Settings* s) { + int number_of_pixels = image_height * image_width * image_channels; + std::unique_ptr interpreter(new Interpreter); + + int base_index = 0; + + // two inputs: input and new_sizes + interpreter->AddTensors(2, &base_index); + // one output + interpreter->AddTensors(1, &base_index); + // set input and output tensors + interpreter->SetInputs({0, 1}); + interpreter->SetOutputs({2}); + + // set parameters of tensors + TfLiteQuantizationParams quant; + interpreter->SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "input", + {1, image_height, image_width, image_channels}, quant); + interpreter->SetTensorParametersReadWrite(1, kTfLiteInt32, "new_size", {2}, + quant); + interpreter->SetTensorParametersReadWrite( + 2, kTfLiteFloat32, "output", + {1, wanted_height, wanted_width, wanted_channels}, quant); + + ops::builtin::BuiltinOpResolver resolver; + TfLiteRegistration* resize_op = + resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR); + auto* params = reinterpret_cast( + malloc(sizeof(TfLiteResizeBilinearParams))); + params->align_corners = false; + interpreter->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params, resize_op, + nullptr); + + interpreter->AllocateTensors(); + + // fill input image + // in[] are integers, cannot do memcpy() directly + auto input = interpreter->typed_tensor(0); + for (int i = 0; i < number_of_pixels; i++) { + input[i] = in[i]; + } + + // fill new_sizes + interpreter->typed_tensor(1)[0] = wanted_height; + interpreter->typed_tensor(1)[1] = wanted_width; + + interpreter->Invoke(); + + auto output = interpreter->typed_tensor(2); + auto output_number_of_pixels = + wanted_height * wanted_height * wanted_channels; + + for (int i = 0; i < output_number_of_pixels; i++) { + if (s->input_floating) + out[i] = (output[i] - s->input_mean) / s->input_std; + else + out[i] = (uint8_t)output[i]; } } } // namespace label_image } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index d7f49ad8757e8899fe9c23b985edff6ba7f68750..a91467d345fdce1268635a69a96939921dc170e8 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -151,14 +151,14 @@ void RunInference(Settings* s) { switch (interpreter->tensor(input)->type) { case kTfLiteFloat32: s->input_floating = true; - downsize(interpreter->typed_tensor(input), in, - image_height, image_width, image_channels, - wanted_height, wanted_width, wanted_channels, s); + resize(interpreter->typed_tensor(input), in, image_height, + image_width, image_channels, wanted_height, wanted_width, + wanted_channels, s); break; case kTfLiteUInt8: - downsize(interpreter->typed_tensor(input), in, - image_height, image_width, image_channels, - wanted_height, wanted_width, wanted_channels, s); + resize(interpreter->typed_tensor(input), in, + image_height, image_width, image_channels, wanted_height, + wanted_width, wanted_channels, s); break; default: LOG(FATAL) << "cannot handle input type " @@ -188,9 +188,8 @@ void RunInference(Settings* s) { int output = interpreter->outputs()[0]; switch (interpreter->tensor(output)->type) { case kTfLiteFloat32: - get_top_n(interpreter->typed_output_tensor(0), - output_size, num_results, threshold, &top_results, - true); + get_top_n(interpreter->typed_output_tensor(0), output_size, + num_results, threshold, &top_results, true); break; case kTfLiteUInt8: get_top_n(interpreter->typed_output_tensor(0), diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index 204a489a93519309bb09238f1b2c8bbd4f1f19e4..d7cc854ebac08e79d346df0aca6e1fa56b490156 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -73,7 +73,7 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration* Register_SIN() { - static TfLiteRegistration r = {nullptr, nullptr, SinResize, SinEval}; + static TfLiteRegistration r = {nullptr, nullptr, SinPrepare, SinEval}; return &r; } ``` diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 8e5e694a5cbe7f908572114db33c8257db6151f0..b1bbb7c67013acfb575cc1e9f9390ba191cbd08e 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -1,4 +1,4 @@ -# TensorFlow Compatibility Guide +# TensorFlow Lite & TensorFlow Compatibility Guide TensorFlow Lite supports a number of TensorFlow operations used in common inference models. As they are processed by the TensorFlow Lite Optimizing diff --git a/tensorflow/contrib/lite/graph_info.cc b/tensorflow/contrib/lite/graph_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..e60ed2c2463cb621015ba725ca030e8d8c02f3c7 --- /dev/null +++ b/tensorflow/contrib/lite/graph_info.cc @@ -0,0 +1,224 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/graph_info.h" +#include + +namespace tflite { + +namespace { + +// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite +// C api uses. Can't use the google array_view, since we can't depend on even +// absl for embedded device reasons. +// TODO(aselle): Move this into central utilities. +class TfLiteIntArrayView { + public: + // Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null + // and this view does not take ownership of it. + explicit TfLiteIntArrayView(const TfLiteIntArray* int_array) + : int_array_(int_array) {} + + typedef const int* const_iterator; + const_iterator begin() const { return int_array_->data; } + const_iterator end() const { return &int_array_->data[int_array_->size]; } + + TfLiteIntArrayView(const TfLiteIntArrayView&) = default; + TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default; + + private: + const TfLiteIntArray* int_array_; +}; + +// Helper class that actually performs partitioning by subgraph. +// Outputs to a provided `subgraphs` structure. +// +// Example usage: +// PartitionGraphIntoIndependentSubgraphsImpl partitioner( +// info, nodes_to_part, subgraphs); +// partitioner.Partition(); +class PartitionGraphIntoIndependentSubgraphsImpl { + public: + PartitionGraphIntoIndependentSubgraphsImpl( + const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, + std::vector* subgraphs) + : info_(info), + subgraphs_(subgraphs), + node_type_(info->num_nodes(), Subgraph::kTfNonPartition) { + // Populate the node_type_ map. + for (auto node_index : TfLiteIntArrayView(nodes_to_partition)) { + node_type_[node_index] = Subgraph::kTfPartition; + } + } + + // Actually partition the graph. + void Partition() { + // Initialize here to make Partition() re-entrant. + subgraphs_->clear(); + tensor_epochs_.clear(); + tensor_epochs_.resize(info_->num_tensors(), kEpochAlwaysReady); + node_epochs_.clear(); + node_epochs_.resize(info_->num_nodes(), kEpochNotReady); + // Set computed tensors to be kEpochNotReady (initializer set everything to + // AlwaysReady). + for (int node_index = 0; node_index < info_->num_nodes(); node_index++) { + const TfLiteNode& node = info_->node(node_index); + for (int output_tensor_index : TfLiteIntArrayView(node.outputs)) { + tensor_epochs_[output_tensor_index] = kEpochNotReady; + } + } + + // Do a graph traversal where each iteration in the loop is an epoch + // that corresponds to a subgraph that only contains nodes that are of + // the same node_type_. + while (true) { + BuildSubgraph(); + if (subgraphs_->back().nodes.empty()) { + subgraphs_->pop_back(); + break; + } + } + + // Mark model outputs as subgraph outputs. All the rest have already been + // identified. + for (int output_index : info_->outputs()) { + int output_epoch = tensor_epochs_[output_index]; + Subgraph& output_subgraph = (*subgraphs_)[output_epoch]; + output_subgraph.output_tensors.push_back(output_index); + } + // Make sure every subgraph's inputs and outputs are unique. Since the + // list of inputs and outputs is generated in a way that produces + // duplicates. + for (Subgraph& subgraph : *subgraphs_) { + // Sort and uniquefy using standard library algorithms. + auto uniquefy = [](std::vector* items) { + std::sort(items->begin(), items->end()); + auto last = std::unique(items->begin(), items->end()); + items->erase(last, items->end()); + }; + uniquefy(&subgraph.input_tensors); + uniquefy(&subgraph.output_tensors); + } + } + + private: + // Special integer values needed for tensor_epochs_ and node_epochs_. + enum { + // The node or tensor is not ready to be assigned an epoch. e.g. a node's + // inputs have not all been assigned epochs. + kEpochNotReady = -1, + // Used for tensor_epochs_. This means that the tensor is always ready. + // e.g. an input to the whole model or a constant that has no dependencies. + kEpochAlwaysReady = -2 + }; + + // Updates the node `node_index` and returns true if it is assigned to an + // epoch. False is returned if the node is already set to an epoch, its inputs + // are not all assigned to epochs, or if it cannot be assigned to the current + // epoch since the epoch's node_type doesn't match. + bool UpdateNode(int node_index) { + const TfLiteNode& node = info_->node(node_index); + Subgraph& current_subgraph = subgraphs_->back(); + int current_epoch = subgraphs_->size() - 1; + // Check if node is already done. + if (node_epochs_[node_index] != kEpochNotReady) { + return false; + } + // See if all dependencies of this node are already assigned to a + // subgraph. + for (int input_tensor_index : TfLiteIntArrayView(node.inputs)) { + if (tensor_epochs_[input_tensor_index] == kEpochNotReady) { + return false; + } + } + // When we are starting a new epoch, the first ready node defines + // the type of that epoch. + if (current_subgraph.type == Subgraph::kTfUnexplored) { + current_subgraph.type = node_type_[node_index]; + } + // The node gets assigned to this epoch if it is the same type as + // the epoch's assigned type. Note, if this is the current ready + // node encountered during this epoch, this condition will be + // automatically true. + if (current_subgraph.type == node_type_[node_index]) { + node_epochs_[node_index] = current_epoch; + current_subgraph.nodes.push_back(node_index); + // All outputs of this node now are assigned to this epoch as + // well. + for (int output_tensor_index : TfLiteIntArrayView(node.outputs)) { + tensor_epochs_[output_tensor_index] = current_epoch; + } + // Look at our inputs one more time to update that tensor's + // epochs' outputs + for (int input_tensor_index : TfLiteIntArrayView(node.inputs)) { + int input_epoch = tensor_epochs_[input_tensor_index]; + int node_epoch = current_epoch; + if (input_epoch != node_epoch) { + current_subgraph.input_tensors.push_back(input_tensor_index); + // Set inputs to be outputs of the subgraph where they reside. + // the if condition makes sure inputs to the whole computation + // are not included (i.e. those initialized to -2 above). + if (input_epoch >= 0) { + Subgraph& input_subgraph = (*subgraphs_)[input_epoch]; + input_subgraph.output_tensors.push_back(input_tensor_index); + } + } + } + return true; + } else { + return false; + } + } + + // Completely populates the current subgraph by doing graph traversal + void BuildSubgraph() { + subgraphs_->emplace_back(Subgraph()); + // loop until no more nodes can be updated. + while (true) { + bool did_something = false; + for (int node_index = 0; node_index < info_->num_nodes(); node_index++) { + if (UpdateNode(node_index)) { + did_something = true; + } + } + if (!did_something) return; + } + } + + // Temporary data needed for partitioning. + const GraphInfo* info_; + // List of subgraphs to populate + std::vector* subgraphs_; + std::vector node_type_; + // Maps from tensor index to the epoch in which it is assigned. Also special + // negative values of kEpochNotAssigned if not assigned, kEpochNotReady if it + // is an input or constant. + std::vector tensor_epochs_; + // Maps from tensor index to the epoch in which it is assigned. Also special + // negative values of kEpochNotAssigned if not assigned. + std::vector node_epochs_; +}; + +} // namespace + +TfLiteStatus PartitionGraphIntoIndependentSubgraphs( + const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, + std::vector* subgraphs) { + PartitionGraphIntoIndependentSubgraphsImpl(info, nodes_to_partition, + subgraphs) + .Partition(); + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h index 57690058c4630f75f8b23073f4ab44f27090c51b..313af5fb7574b42bcdd53b4baad06e4ccfb34053 100644 --- a/tensorflow/contrib/lite/graph_info.h +++ b/tensorflow/contrib/lite/graph_info.h @@ -48,6 +48,32 @@ class GraphInfo { virtual const std::vector& outputs() const = 0; }; +// Represents a subgraph of a TensorFlow Lite graph. +struct Subgraph { + enum Type { + kTfUnexplored = 0, // temporarily used during creation + kTfPartition, + kTfNonPartition + }; + Type type = kTfUnexplored; + // Nodes within the subgraph + std::vector nodes; + // Tensors that stride output from another subgraph that this depends on, + // or global inputs to the TensorFlow Lite full graph. + std::vector input_tensors; + // Outputs that are consumed by other subgraphs or are global output tensors. + // All output tensors of the nodes in the subgraph that do not appear in this + // list are intermediate results that can be potentially elided. + std::vector output_tensors; +}; + +// Partitions a list of node indices `nodes_to_partition` into subgraphs. +// Each subgraph is in dependency order (i.e. all members of the subgraph). +// `subgraphs` is assumed to be empty. +TfLiteStatus PartitionGraphIntoIndependentSubgraphs( + const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, + std::vector* subgraphs); + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ diff --git a/tensorflow/contrib/lite/graph_info_test.cc b/tensorflow/contrib/lite/graph_info_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea38b43993fef71c6820c7a978351d92d5420287 --- /dev/null +++ b/tensorflow/contrib/lite/graph_info_test.cc @@ -0,0 +1,270 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/contrib/lite/graph_info.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace { + +// Makes a TfLiteIntArray* from std::vector, must free with TfLiteIntFree(). +TfLiteIntArray* ConvertVector(const std::vector& x) { + TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); + for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i]; + return lite; +} + +// A very simple test graph that supports setting in/out tensors on nodes. +class SimpleTestGraph : public GraphInfo { + public: + ~SimpleTestGraph() override { + for (auto& node : nodes_) { + TfLiteIntArrayFree(node.inputs); + TfLiteIntArrayFree(node.outputs); + } + } + + size_t num_tensors() const override { return tensors_.size(); } + size_t num_nodes() const override { return nodes_.size(); } + const TfLiteNode& node(size_t index) const override { return nodes_[index]; } + TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; } + const std::vector& inputs() const override { return inputs_; } + const std::vector& outputs() const override { return outputs_; } + + void AddNode(const std::vector& inputs, + const std::vector& outputs) { + nodes_.push_back(TfLiteNode()); + TfLiteNode& node = nodes_.back(); + node.inputs = ConvertVector(inputs); + node.outputs = ConvertVector(outputs); + } + + void AddTensors(int count) { tensors_.resize(count + tensors_.size()); } + + void SetInputsAndOutputs(const std::vector& inputs, + const std::vector& outputs) { + inputs_ = inputs; + outputs_ = outputs; + } + + private: + std::vector nodes_; + std::vector tensors_; + std::vector inputs_; + std::vector outputs_; +}; + +// Partition a graph to generate a list of subgraphs. This wraps the API call +// we are testing and handles memory management and conversion to +// TfLiteIntArray. Populates `subgraphs` with resulting generated subgraphs. +void PartitionGraph(const SimpleTestGraph& graph, + const std::vector& nodes_to_partition, + std::vector* subgraphs) { + TfLiteIntArray* nodes_to_partition_int_array = + ConvertVector(nodes_to_partition); + PartitionGraphIntoIndependentSubgraphs(&graph, nodes_to_partition_int_array, + subgraphs); + TfLiteIntArrayFree(nodes_to_partition_int_array); +} + +// Check a generated list of subgraphs against the expected list of subgraphs. +void CheckPartitionSubgraphs(const std::vector& generated_subgraphs, + const std::vector& expected_subgraphs) { + ASSERT_EQ(generated_subgraphs.size(), expected_subgraphs.size()); + for (int subgraph_index = 0; subgraph_index < generated_subgraphs.size(); + subgraph_index++) { + EXPECT_EQ(generated_subgraphs[subgraph_index].nodes, + expected_subgraphs[subgraph_index].nodes); + EXPECT_EQ(generated_subgraphs[subgraph_index].input_tensors, + expected_subgraphs[subgraph_index].input_tensors); + EXPECT_EQ(generated_subgraphs[subgraph_index].output_tensors, + expected_subgraphs[subgraph_index].output_tensors); + } +} + +// Test an empty trivial graph with no partitions. +TEST(PartitionTest, Nodes0_PartitionNodes0) { + SimpleTestGraph graph; + std::vector nodes_to_partition = {}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + CheckPartitionSubgraphs(generated_subgraphs, {}); +} + +// Test a 1 node graph with no partitions. +// Input: tensor(0) -> node(0) -> tensor(1), nodes_to_partition=[] +// Output: [kTfNoPartition, tensor(0) -> node(0) -> tensor(1)] +TEST(PartitionTest, Nodes1PartitionNodes0) { + SimpleTestGraph graph; + graph.AddTensors(2); + graph.AddNode({0}, {1}); + graph.SetInputsAndOutputs({0}, {1}); + std::vector nodes_to_partition = {}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph; + expected_subgraph.type = Subgraph::kTfNonPartition; + expected_subgraph.nodes = {0}; + expected_subgraph.input_tensors = {0}; + expected_subgraph.output_tensors = {1}; + CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph}); +} + +// Test a 1 node graph with no inputs that is fully partitioned. +// Input: node(0) -> tensor(1), nodes_to_partition=[node0] +// Output: [kTfPartition, node(0) -> tensor(1)] +TEST(PartitionTest, Nodes1PartitionNodes0Inputs0) { + SimpleTestGraph graph; + graph.AddTensors(1); + graph.AddNode({}, {0}); + graph.SetInputsAndOutputs({}, {0}); + std::vector generated_subgraphs; + std::vector nodes_to_partition = {0}; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph; + expected_subgraph.type = Subgraph::kTfPartition; + expected_subgraph.nodes = {0}; + expected_subgraph.input_tensors = {}; + expected_subgraph.output_tensors = {0}; + CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph}); +} + +// Test a 1 node graph that is partitioned completely. +// Input: tensor(0) -> node(0) -> tensor(1), nodes_to_partition=[node0] +// Output: [kTfPartition, tensor(0) -> node(0) -> tensor(1)] +TEST(PartitionTest, Nodes1PartitionNodes1) { + SimpleTestGraph graph; + graph.AddTensors(2); + graph.AddNode({0}, {1}); + graph.SetInputsAndOutputs({0}, {1}); + std::vector nodes_to_partition = {0}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph; + expected_subgraph.type = Subgraph::kTfPartition; + expected_subgraph.nodes = {0}; + expected_subgraph.input_tensors = {0}; + expected_subgraph.output_tensors = {1}; + CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph}); +} + +// Test a 2 node graph where 1 node is partitioned and the other is not. +// Input: tensor(0) -> node(0) -> tensor(1) -> node(1) -> tensor(2), +// nodes_to_partition = [1] +// Output: [kTfNonPartition, tensor(0) -> node(0) -> tensor(1), +// kTfPartition, tensor(1) -> node(1), tensor(2)] +TEST(PartitionTest, Nodes2PartitionNodes1) { + SimpleTestGraph graph; + graph.AddTensors(3); + graph.AddNode({0}, {1}); + graph.AddNode({1}, {2}); + graph.SetInputsAndOutputs({0}, {2}); + std::vector nodes_to_partition = {1}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph0; + expected_subgraph0.type = Subgraph::kTfPartition; + expected_subgraph0.nodes = {0}; + expected_subgraph0.input_tensors = {0}; + expected_subgraph0.output_tensors = {1}; + Subgraph expected_subgraph1; + expected_subgraph1.type = Subgraph::kTfPartition; + expected_subgraph1.nodes = {1}; + expected_subgraph1.input_tensors = {1}; + expected_subgraph1.output_tensors = {2}; + CheckPartitionSubgraphs(generated_subgraphs, + {expected_subgraph0, expected_subgraph1}); +} + +// Test a 2 node graph where both nodes are fully partitioned. +// Input: tensor(0) -> node(0) -> tensor(1) -> node(1) -> tensor(2), +// nodes_to_partition = [0, 1] +// Output: [kTfPartition, tensor(0) -> node(0) -> node(1) -> tensor(1)] +TEST(PartitionTest, Nodes2PartitionNodes2) { + SimpleTestGraph graph; + graph.AddTensors(3); + graph.AddNode({0}, {1}); + graph.AddNode({1}, {2}); + graph.SetInputsAndOutputs({0}, {2}); + std::vector nodes_to_partition = {0, 1}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph0; + expected_subgraph0.type = Subgraph::kTfPartition; + expected_subgraph0.nodes = {0, 1}; + expected_subgraph0.input_tensors = {0}; + expected_subgraph0.output_tensors = {2}; + CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph0}); +} + +// Test a three node model where we want to partition nodes 0 and nodes +// 2, but nodes 0 and nodes 2 cannot be in the same subgraph since node 2 +// depends on node 1 which depends on node 0. Thus, we need to produce three +// subgraphs. +// +// Input: tensor(0) -> node(0) -> tensor(1) +// tensor(1) -> node(1) -> tensor(2) +// [tensor(2), tensor(1)] -> node(2) -> tensor(3) +// nodes_to_partition = [0, 2] +// Output: [[kTfPartition, tensor(0) -> node(0) -> tensor(1), +// [kTfNonPartition, tensor(1) -> node(1) -> tensor(2)], +// [kTfPartition, [tensor(2), tensor(1)] -> node(2) -> node(3)] +TEST(PartitionTest, Nodes3PartitionNodes2) { + SimpleTestGraph graph; + graph.AddTensors(4); + graph.AddNode({0}, {1}); + graph.AddNode({1}, {2}); + graph.AddNode({1, 2}, {3}); + graph.SetInputsAndOutputs({0}, {3}); + std::vector nodes_to_partition = {0, 2}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph0; + expected_subgraph0.type = Subgraph::kTfPartition; + expected_subgraph0.nodes = {0}; + expected_subgraph0.input_tensors = {0}; + expected_subgraph0.output_tensors = {1}; + Subgraph expected_subgraph1; + expected_subgraph1.type = Subgraph::kTfNonPartition; + expected_subgraph1.nodes = {1}; + expected_subgraph1.input_tensors = {1}; + expected_subgraph1.output_tensors = {2}; + Subgraph expected_subgraph2; + expected_subgraph2.type = Subgraph::kTfPartition; + expected_subgraph2.nodes = {2}; + expected_subgraph2.input_tensors = {1, 2}; + expected_subgraph2.output_tensors = {3}; + CheckPartitionSubgraphs( + generated_subgraphs, + {expected_subgraph0, expected_subgraph1, expected_subgraph2}); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 69a597dc5a219b55eced6ec8da5b388caf372b8e..028449211b8108d004df4d1cd8a58b4a08df6604 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -36,6 +36,10 @@ constexpr const int kSlotsToReserve = 128; namespace tflite { // A trivial implementation of GraphInfo around the Interpreter. +// NOTE: this interpreter info represents the subset of the +// graph that is executed according to execution plan. Thus, +// the indices are execution plan indices rather than raw node +// indices. class InterpreterInfo : public GraphInfo { public: explicit InterpreterInfo(Interpreter* interpreter) @@ -45,9 +49,12 @@ class InterpreterInfo : public GraphInfo { TfLiteTensor* tensor(size_t index) override { return interpreter_->tensor(index); } - size_t num_nodes() const override { return interpreter_->nodes_size(); } + size_t num_nodes() const override { + return interpreter_->execution_plan().size(); + } const TfLiteNode& node(size_t index) const override { - return interpreter_->node_and_registration(index)->first; + int node_index = interpreter_->execution_plan()[index]; + return interpreter_->node_and_registration(node_index)->first; } const std::vector& inputs() const override { return interpreter_->inputs(); @@ -70,10 +77,16 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) context_.tensors = nullptr; context_.tensors_size = 0; context_.gemm_context = nullptr; + + // Invalid to call these these except from TfLiteDelegate + context_.GetNodeAndRegistration = nullptr; + context_.ReplaceSubgraphsWithDelegateKernels = nullptr; + context_.GetExecutionPlan = nullptr; + // Reserve some space for the tensors to avoid excessive resizing. tensors_.reserve(kSlotsToReserve); nodes_and_registration_.reserve(kSlotsToReserve); - next_node_to_prepare_ = 0; + next_execution_plan_index_to_prepare_ = 0; UseNNAPI(false); } @@ -93,6 +106,78 @@ Interpreter::~Interpreter() { } } +TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( + TfLiteContext* context, TfLiteRegistration registration, + const TfLiteIntArray* nodes_to_replace) { + return static_cast(context->impl_) + ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace); +} + +TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( + TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) { + // Analyze the graph to find all independent subgraphs that are either + // fully not-this-delegate or this-delegate computation. + InterpreterInfo info(this); + std::vector subgraphs; + PartitionGraphIntoIndependentSubgraphs(&info, nodes_to_replace, &subgraphs); + + execution_plan_.clear(); + for (auto& subgraph : subgraphs) { + // Turn subgraph.nodes into a TfLiteIntArray compatible data structure. + // TODO(aselle): Avoid this copy by constructing subgraph.nodes that way + // in the first place + subgraph.nodes.insert(subgraph.nodes.begin(), + static_cast(subgraph.nodes.size())); + // Subgraphs calimed by the delegate should have a "macro" op created, the + // other subgraphs (kTfNonPartition) just have their nodes added back to + // the execution plan. + switch (subgraph.type) { + case Subgraph::kTfNonPartition: + for (auto it = subgraph.nodes.begin() + 1; it != subgraph.nodes.end(); + ++it) { + execution_plan_.push_back(*it); + } + break; + case Subgraph::kTfPartition: { + void* builtin_data = nullptr; + int node_index; + // Create a node that represents computation of this subgraph. + AddNodeWithParameters( + subgraph.input_tensors, subgraph.output_tensors, + reinterpret_cast(subgraph.nodes.data()), + subgraph.nodes.size() * sizeof(subgraph.nodes[0]), builtin_data, + ®istration, &node_index); + } break; + case Subgraph::kTfUnexplored: + return kTfLiteError; + break; + } + } + return kTfLiteOk; +} + +// Gets an TfLiteIntArray* representing the execution plan. The interpreter owns +// this memory and it is only guaranteed to exist during the invocation of the +// delegate prepare. +TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) { + // TODO(aselle): Do not make a copy here + plan_cache_.reset(TfLiteIntArrayCreate(execution_plan_.size())); + *execution_plan = plan_cache_.get(); + static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]), + "TfLiteIntArray and execution_plan do not contain same type."); + memcpy(plan_cache_->data, execution_plan_.data(), + sizeof(plan_cache_->data[0]) * execution_plan_.size()); + return kTfLiteOk; +} + +// WARNING: This is an experimental interface that is subject to change. +// Entry point for C node plugin API to get the execution plan +TfLiteStatus Interpreter::GetExecutionPlan(struct TfLiteContext* context, + TfLiteIntArray** execution_plan) { + return static_cast(context->impl_) + ->GetExecutionPlan(execution_plan); +} + TfLiteStatus Interpreter::SetInputs(std::vector inputs) { TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("inputs", inputs.data(), inputs.size())); @@ -160,7 +245,7 @@ TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { } // namespace TfLiteStatus Interpreter::AllocateTensors() { - next_node_to_prepare_ = 0; + next_execution_plan_index_to_prepare_ = 0; if (memory_planner_) { TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); } @@ -190,8 +275,10 @@ TfLiteStatus Interpreter::AddNodeWithParameters( &context_, CheckTensorIndices("node outputs", outputs.data(), outputs.size())); - if (node_index) *node_index = nodes_and_registration_.size(); + int new_node_index = nodes_and_registration_.size(); + if (node_index) *node_index = new_node_index; nodes_and_registration_.resize(nodes_and_registration_.size() + 1); + auto& node_and_reg = nodes_and_registration_.back(); TfLiteNode& node = node_and_reg.first; if (node.inputs) TfLiteIntArrayFree(node.inputs); @@ -213,6 +300,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters( } node.builtin_data = builtin_data_deleter.release(); node_and_reg.second = *registration; + execution_plan_.push_back(new_node_index); return kTfLiteOk; } @@ -240,16 +328,19 @@ bool HasDynamicTensor(const TfLiteContext& context, return false; } -TfLiteStatus Interpreter::PrepareOpsStartingAt(int first_node, - int* last_node_prepared) { - for (int i = first_node; i < nodes_and_registration_.size(); i++) { - TfLiteNode& node = nodes_and_registration_[i].first; - const TfLiteRegistration& registration = nodes_and_registration_[i].second; +TfLiteStatus Interpreter::PrepareOpsStartingAt( + int first_execution_plan_index, int* last_execution_plan_index_prepared) { + for (int execution_plan_index = first_execution_plan_index; + execution_plan_index < execution_plan_.size(); execution_plan_index++) { + int node_index = execution_plan_[execution_plan_index]; + TfLiteNode& node = nodes_and_registration_[node_index].first; + const TfLiteRegistration& registration = + nodes_and_registration_[node_index].second; if (OpPrepare(registration, &node) == kTfLiteError) { return kTfLiteError; } - *last_node_prepared = i; + *last_execution_plan_index_prepared = execution_plan_index; // Discontinue if the node has dynamic outputs. Note that we don't // stop for dynamic temporary tensors since they won't affect the @@ -268,14 +359,14 @@ TfLiteStatus Interpreter::PrepareOpsAndTensors() { memory_planner_->PlanAllocations(); } - int last_node_prepared = 0; + int last_exec_plan_index_prepared = 0; - TF_LITE_ENSURE_STATUS( - PrepareOpsStartingAt(next_node_to_prepare_, &last_node_prepared)); + TF_LITE_ENSURE_STATUS(PrepareOpsStartingAt( + next_execution_plan_index_to_prepare_, &last_exec_plan_index_prepared)); TF_LITE_ENSURE_STATUS(memory_planner_->ExecuteAllocations( - next_node_to_prepare_, last_node_prepared)); + next_execution_plan_index_to_prepare_, last_exec_plan_index_prepared)); - next_node_to_prepare_ = last_node_prepared + 1; + next_execution_plan_index_to_prepare_ = last_exec_plan_index_prepared + 1; return kTfLiteOk; } @@ -291,7 +382,7 @@ TfLiteStatus Interpreter::Invoke() { TfLiteStatus status = kTfLiteOk; if (nnapi_delegate_) { - if (next_node_to_prepare_ == nodes_and_registration_.size()) { + if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) { TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); return kTfLiteOk; } else { @@ -311,13 +402,17 @@ TfLiteStatus Interpreter::Invoke() { // TODO(b/71913981): we should force recalculation in the presence of dynamic // tensors, because they may have new value which in turn may affect shapes // and allocations. - for (int i = 0; i < nodes_and_registration_.size(); i++) { - if (i == next_node_to_prepare_) { + for (int execution_plan_index = 0; + execution_plan_index < execution_plan_.size(); execution_plan_index++) { + if (execution_plan_index == next_execution_plan_index_to_prepare_) { TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); - TF_LITE_ENSURE(&context_, next_node_to_prepare_ >= i); + TF_LITE_ENSURE(&context_, next_execution_plan_index_to_prepare_ >= + execution_plan_index); } - TfLiteNode& node = nodes_and_registration_[i].first; - const TfLiteRegistration& registration = nodes_and_registration_[i].second; + int node_index = execution_plan_[execution_plan_index]; + TfLiteNode& node = nodes_and_registration_[node_index].first; + const TfLiteRegistration& registration = + nodes_and_registration_[node_index].second; if (OpInvoke(registration, &node) == kTfLiteError) { status = kTfLiteError; } @@ -372,6 +467,22 @@ TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add, ->AddTensors(tensors_to_add, first_new_tensor_index); } +TfLiteStatus Interpreter::GetNodeAndRegistration( + int node_index, TfLiteNode** node, TfLiteRegistration** registration) { + TF_LITE_ENSURE(&context_, node_index < nodes_size() && node_index >= 0); + TF_LITE_ENSURE(&context_, node != nullptr && registration != nullptr); + *node = &nodes_and_registration_[node_index].first; + *registration = &nodes_and_registration_[node_index].second; + return kTfLiteOk; +} + +TfLiteStatus Interpreter::GetNodeAndRegistration( + struct TfLiteContext* context, int node_index, TfLiteNode** node, + TfLiteRegistration** registration) { + return static_cast(context->impl_) + ->GetNodeAndRegistration(node_index, node, registration); +} + TfLiteStatus Interpreter::SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const std::vector& dims, TfLiteQuantizationParams quantization, @@ -421,6 +532,14 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite( return kTfLiteOk; } +TfLiteStatus Interpreter::SetExecutionPlan(const std::vector& new_plan) { + for (int node_index : new_plan) { + TF_LITE_ENSURE(&context_, node_index >= 0 && node_index < nodes_size()); + } + execution_plan_ = new_plan; + return kTfLiteOk; +} + TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size) { // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. @@ -434,6 +553,9 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArrayFree(new_size); return kTfLiteError; } + + // Realloc space for kTfLiteDynamic tensors. + TfLiteTensorRealloc(bytesRequired, tensor); tensor->bytes = bytesRequired; } if (tensor->dims) TfLiteIntArrayFree(tensor->dims); @@ -471,4 +593,20 @@ void Interpreter::SetNumThreads(int num_threads) { tflite::gemm_support::SetMaxNumThreads(&context_, num_threads); } +TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { + // TODO(aselle): Consider if it is worth storing pointers to delegates. + // Setup additional context interface + context_.GetNodeAndRegistration = GetNodeAndRegistration; + context_.ReplaceSubgraphsWithDelegateKernels = + ReplaceSubgraphsWithDelegateKernels; + context_.GetExecutionPlan = GetExecutionPlan; + + TfLiteStatus status = delegate->Prepare(&context_, delegate->data_); + // Remove additional context info. + context_.GetNodeAndRegistration = nullptr; + context_.ReplaceSubgraphsWithDelegateKernels = nullptr; + context_.GetExecutionPlan = nullptr; + return status; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 4f732769f9f921a9debd5213547d2baccfa69426..bab56a9d72f8992a9d8af23f92133c7c918fd46d 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -80,6 +80,12 @@ class NNAPIDelegate; // foo.Invoke(); // +struct TfLiteIntArrayDeleter { + void operator()(TfLiteIntArray* a) { + if (a) TfLiteIntArrayFree(a); + } +}; + class Interpreter { public: // Instantiate an interpreter. All errors associated with reading and @@ -108,7 +114,7 @@ class Interpreter { // Adds a node with the given parameters and returns the index of the new // node in `node_index` (optionally). Interpreter will take ownership of - // `builtin_data` and destroy it with `delete`. Ownership of 'init_data' + // `builtin_data` and destroy it with `free`. Ownership of 'init_data' // remains with the caller. TfLiteStatus AddNodeWithParameters(const std::vector& inputs, const std::vector& outputs, @@ -166,12 +172,19 @@ class Interpreter { // Return the number of ops in the model. int nodes_size() const { return nodes_and_registration_.size(); } + // WARNING: Experimental interface, subject to change + const std::vector& execution_plan() const { return execution_plan_; } + + // WARNING: Experimental interface, subject to change + // Overrides execution plan. This bounds checks indices sent in. + TfLiteStatus SetExecutionPlan(const std::vector& new_plan); + // Get a tensor data structure. // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { if (tensor_index >= context_.tensors_size || tensor_index < 0) - return nullptr; + return nullptr; return &context_.tensors[tensor_index]; } @@ -240,6 +253,11 @@ class Interpreter { // Set the number of threads available to the interpreter. void SetNumThreads(int num_threads); + // Allow a delegate to look at the graph and modify the graph to handle + // parts of the graph themselves. After this is called, the graph may + // contain new nodes that replace 1 more nodes. + TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); + private: // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. @@ -279,7 +297,8 @@ class Interpreter { // dynamic tensors is found or all ops have been prepared. Fill // 'last_node_prepared' with the id of the op containing dynamic tensors, or // the last in the graph. - TfLiteStatus PrepareOpsStartingAt(int first_node, int* last_node_prepared); + TfLiteStatus PrepareOpsStartingAt(int first_execution_plan_index, + int* last_execution_plan_index_prepared); // Tensors needed by the interpreter. Use `AddTensors` to add more blank // tensor entries. Note, `tensors_.data()` needs to be synchronized to the @@ -299,7 +318,8 @@ class Interpreter { TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size, size_t* bytes); - // Request an tensor be resized implementation. + // Request an tensor be resized implementation. If the given tensor is of + // type kTfLiteDynamic it will also be allocated new memory. TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size); // Report a detailed error string (will be printed to stderr). @@ -316,6 +336,40 @@ class Interpreter { static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add, int* first_new_tensor_index); + // WARNING: This is an experimental API and subject to change. + // Entry point for C API ReplaceSubgraphsWithDelegateKernels + static TfLiteStatus ReplaceSubgraphsWithDelegateKernels( + TfLiteContext* context, TfLiteRegistration registration, + const TfLiteIntArray* nodes_to_replace); + + // Update the execution graph to replace some of the nodes with stub + // nodes. Specifically any node index that has `nodes[index]==1` will be + // slated for replacement with a delegate kernel specified by registration. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus ReplaceSubgraphsWithDelegateKernels( + TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace); + + // WARNING: This is an experimental interface that is subject to change. + // Gets the internal pointer to a TensorFlow lite node by node_index. + TfLiteStatus GetNodeAndRegistration(int node_index, TfLiteNode** node, + TfLiteRegistration** registration); + + // WARNING: This is an experimental interface that is subject to change. + // Entry point for C node plugin API to get a node by index. + static TfLiteStatus GetNodeAndRegistration(struct TfLiteContext*, + int node_index, TfLiteNode** node, + TfLiteRegistration** registration); + + // WARNING: This is an experimental interface that is subject to change. + // Gets an TfLiteIntArray* representing the execution plan. The caller owns + // this memory and must free it with TfLiteIntArrayFree(). + TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan); + + // WARNING: This is an experimental interface that is subject to change. + // Entry point for C node plugin API to get the execution plan + static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, + TfLiteIntArray** execution_plan); + // A pure C data structure used to communicate with the pure C plugin // interface. To avoid copying tensor metadata, this is also the definitive // structure to store tensors. @@ -354,7 +408,18 @@ class Interpreter { // node id, and execute the node to generate the output tensor before continue // to allocate successors. This process repeats until all nodes are executed. // NOTE: this relies on the order of nodes that is in topological order. - int next_node_to_prepare_; + int next_execution_plan_index_to_prepare_; + + // WARNING: This is an experimental interface that is subject to change. + // This is a list of node indices (to index into nodes_and_registration). + // This represents a valid topological sort (dependency ordered) execution + // plan. In particular, it is valid for this ordering to contain only a + // subset of the node indices. + std::vector execution_plan_; + + // In the future, we'd like a TfLiteIntArray compatible representation. + // TODO(aselle): replace execution_plan_ with this. + std::unique_ptr plan_cache_; // Whether to delegate to NN API std::unique_ptr nnapi_delegate_; diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index edff2109430c6e1ec6c481619ed7772237a3301d..28c96e5dde6ffa62bb073db9716a00f91c6e0bdf 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/contrib/lite/interpreter.h" #include #include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/string_util.h" - +#include "tensorflow/contrib/lite/testing/util.h" namespace tflite { namespace { @@ -282,6 +284,51 @@ TEST(BasicInterpreter, NoOpInterpreter) { ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); } +TEST(BasicInterpreter, ResizingTensors) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + int t = interpreter.inputs()[0]; + TfLiteTensor* tensor = interpreter.tensor(t); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + tensor->data.f[5] = 0.123f; + + // Changing from kTfLiteArenaRw to kTfLiteDynamic is quite complicate: we need + // to unset data.raw, otherwise Realloc will try to free that memory. + tensor->data.raw = nullptr; + tensor->allocation_type = kTfLiteDynamic; + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 4}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 8 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // TODO(ahentz): We shouldn't have to force reallocation, but + // ResizeInputTensor doesn't realloc dynamic tensors. Also note that + // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op. + TfLiteTensorRealloc(9 * sizeof(float), tensor); + tensor->data.f[7] = 0.123f; + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {2, 2, 4}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 16 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // TODO(ahentz): We shouldn't have to force reallocation, but + // ResizeInputTensor doesn't realloc dynamic tensors. Also note that + // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op. + TfLiteTensorRealloc(17 * sizeof(float), tensor); + tensor->data.f[15] = 0.123f; +} + TEST(BasicInterpreter, OneOpInterpreter) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); @@ -514,13 +561,283 @@ TEST(BasicInterpreter, TestCustomErrorReporter) { ASSERT_EQ(reporter.calls, 1); } +// Test fixture that allows playing with execution plans. It creates a two +// node graph that can be executed in either [0,1] order or [1,0] order. +// The CopyOp records when it is invoked in the class member run_order_ +// so we can test whether the execution plan was honored. +class TestExecutionPlan : public ::testing::Test { + // Encapsulates the node ids and provides them to a C primitive data type + // Allocatable with placement new, but never destructed, so make sure this + // doesn't own any heap allocated data. This is then is used as op local + // data to allow access to the test fixture data. + class CallReporting { + public: + CallReporting(int node_id, std::vector* run_order) + : node_id_(node_id), run_order_(run_order) {} + + void Record() { run_order_->push_back(node_id_); } + + private: + // The node id for this particular node + int node_id_; + // A pointer to the global run-order + std::vector* run_order_; + }; + + // Build a kernel registration for an op that copies its one input + // to an output + TfLiteRegistration CopyOpRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + // Set output size to input size + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + return context->ResizeTensor(context, tensor1, newSize); + }; + + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + CallReporting* call_reporting = + reinterpret_cast(node->builtin_data); + // Copy input data to output data. + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + a1->data.f[i] = a0->data.f[i]; + } + call_reporting->Record(); + return kTfLiteOk; + }; + return reg; + } + + // Adds a copy node going from tensor `input` to output tensor `output`. + // Note, input is used as the node_id. Inject run_order as op accessible + // data. Note: this is a little strange of a way to do this, but it is + // using op functionality to avoid static global variables. + void MakeCopyNode(int input, int output) { + // Ownership of call_reporting is taken by interpreter (malloc is used due + // to nodes being a C99 interface so free() is used). + TfLiteRegistration copy_op = CopyOpRegistration(); + CallReporting* call_reporting_1 = + reinterpret_cast(malloc(sizeof(CallReporting))); + new (call_reporting_1) CallReporting(input, &run_order_); + ASSERT_EQ(interpreter_.AddNodeWithParameters( + {0}, {2}, nullptr, 0, + reinterpret_cast(call_reporting_1), ©_op), + kTfLiteOk); + ASSERT_EQ(interpreter_.ResizeInputTensor(input, {3}), kTfLiteOk); + } + + void SetUp() final { + // Add two inputs and two outputs that don't depend on each other + ASSERT_EQ(interpreter_.AddTensors(4), kTfLiteOk); + interpreter_.SetInputs({0, 1}); + interpreter_.SetOutputs({2, 3}); + TfLiteQuantizationParams quantized; + for (int tensor_index = 0; tensor_index < 4; tensor_index++) { + ASSERT_EQ(interpreter_.SetTensorParametersReadWrite( + tensor_index, kTfLiteFloat32, "", {3}, quantized), + kTfLiteOk); + } + + // Define two copy functions that also use the user_data to report that + // they were called. + // i.e. tensor[2] = copy(tensor[0]); tensor[3] = copy(tensor[1]); + // thus we can reorder the two nodes arbitrary and still satisfy dependency + // order. + MakeCopyNode(0, 2); + MakeCopyNode(1, 3); + + ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk); + } + + protected: + Interpreter interpreter_; + + // list of node_ids that were run + std::vector run_order_; +}; + +TEST_F(TestExecutionPlan, DefaultExecutionPlan) { + // Check default order + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector({0, 1})); +} + +TEST_F(TestExecutionPlan, ReversedExecutionPlan) { + // Check reversed order + interpreter_.SetExecutionPlan({1, 0}); + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector({1, 0})); +} + +TEST_F(TestExecutionPlan, SubsetExecutionPlan) { + // Check running only node index 1 + interpreter_.SetExecutionPlan({1}); + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector({1})); +} + +TEST_F(TestExecutionPlan, NullExecutionPlan) { + // Check nothing executed. + interpreter_.SetExecutionPlan({}); + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector()); +} + +// Build a kernel registration for an op that copies its one input +// to an output +TfLiteRegistration AddOpRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + reg.custom_name = "my_add"; + reg.builtin_code = tflite::BuiltinOperator_CUSTOM; + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + // Set output size to input size + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* tensor2 = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + TfLiteIntArray* newSizeOther = TfLiteIntArrayCopy(tensor1->dims); + TF_LITE_ENSURE_EQ(context, newSize->size, newSizeOther->size); + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor2, newSize)); + return kTfLiteOk; + }; + + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + // Copy input data to output data. + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* out = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + out->data.f[i] = a0->data.f[i] + a1->data.f[i]; + } + return kTfLiteOk; + }; + return reg; +} + +class TestDelegate : public ::testing::Test { + public: + TestDelegate() { + interpreter_.AddTensors(5); + interpreter_.SetInputs({0, 1}); + interpreter_.SetOutputs({3, 4}); + TfLiteQuantizationParams quant; + interpreter_.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quant); + interpreter_.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quant); + interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, + quant); + interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, + quant); + TfLiteRegistration reg = AddOpRegistration(); + interpreter_.AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); + interpreter_.AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); + interpreter_.AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®); + } + + protected: + class SimpleDelegate { + public: + // Create a simple implementation of a TfLiteDelegate. We use the C++ class + // SimpleDelegate and it can produce a handle TfLiteDelegate that is + // value-copyable and compatible with TfLite. + explicit SimpleDelegate(const std::vector& nodes) : nodes_(nodes) { + delegate_.Prepare = [](TfLiteContext* context, + void* data) -> TfLiteStatus { + auto* simple = reinterpret_cast(data); + TfLiteIntArray* nodes_to_separate = + TfLiteIntArrayCreate(simple->nodes_.size()); + // Mark nodes that we want in TfLiteIntArray* structure. + int index = 0; + for (auto node_index : simple->nodes_) { + nodes_to_separate->data[index++] = node_index; + // make sure node is add + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM); + TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0); + } + // Check that all nodes are available + TfLiteIntArray* execution_plan; + TF_LITE_ENSURE_STATUS( + context->GetExecutionPlan(context, &execution_plan)); + for (int exec_index = 0; exec_index < execution_plan->size; + exec_index++) { + int node_index = execution_plan->data[exec_index]; + // Check that we are an identity map to start. + TFLITE_CHECK_EQ(exec_index, node_index); + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM); + TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0); + } + + context->ReplaceSubgraphsWithDelegateKernels( + context, FakeFusedRegistration(), nodes_to_separate); + TfLiteIntArrayFree(nodes_to_separate); + return kTfLiteOk; + }; + // Store type-punned data SimpleDelegate structure. + delegate_.data_ = reinterpret_cast(this); + } + + static TfLiteRegistration FakeFusedRegistration() { + TfLiteRegistration reg = {nullptr}; + reg.custom_name = "fake_fused_op"; + return reg; + } + + TfLiteDelegate* get_tf_lite_delegate() { return &delegate_; } + + private: + std::vector nodes_; + TfLiteDelegate delegate_; + }; + Interpreter interpreter_; +}; + +TEST_F(TestDelegate, BasicDelegate) { + interpreter_.Invoke(); + SimpleDelegate simple({0, 1, 2}); + interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate()); + + ASSERT_EQ(interpreter_.execution_plan().size(), 1); + int node = interpreter_.execution_plan()[0]; + const auto* node_and_reg = interpreter_.node_and_registration(node); + ASSERT_EQ(node_and_reg->second.custom_name, + SimpleDelegate::FakeFusedRegistration().custom_name); +} + +TEST_F(TestDelegate, ComplexDeligate) { + interpreter_.Invoke(); + SimpleDelegate simple({1, 2}); + interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate()); + + ASSERT_EQ(interpreter_.execution_plan().size(), 2); + // 0th should be a non-delegated original op + ASSERT_EQ(interpreter_.execution_plan()[0], 0); + // 1st should be a new macro op (3) which didn't exist) + ASSERT_EQ(interpreter_.execution_plan()[1], 3); + const auto* node_and_reg = interpreter_.node_and_registration(3); + ASSERT_EQ(node_and_reg->second.custom_name, + SimpleDelegate::FakeFusedRegistration().custom_name); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { -#ifdef OS_LINUX - FLAGS_logtostderr = true; -#endif + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/ios_makefile.inc index 26cfe6c3e286ed603c2183986c697562e846889c..fc6594c3a04ba6aabba99bb631f85737baf389f1 100644 --- a/tensorflow/contrib/lite/ios_makefile.inc +++ b/tensorflow/contrib/lite/ios_makefile.inc @@ -22,6 +22,7 @@ ifeq ($(TARGET), IOS) IOS_ARCH := x86_64 CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -DTFLITE_USE_APPLE_ACCELERATE_FOR_CONV \ -fembed-bitcode \ -Wno-c++11-narrowing \ -mno-thumb \ @@ -42,6 +43,7 @@ ifeq ($(TARGET), IOS) -O3 LDFLAGS := -fembed-bitcode \ -miphoneos-version-min=${MIN_SDK_VERSION} \ + -framework Accelerate \ -arch $(IOS_ARCH) OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/ LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/ diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 9a1a888b93ff981b1d14faa7e847e80be1f167f2..35aacb70002d1d454f675484e4398bcdffc4acf1 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -111,6 +111,26 @@ java_test( ], ) +# TODO: generate large models at runtime, instead of storing them. +java_test( + name = "InterpreterTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/InterpreterTest.java"], + data = [ + "src/testdata/add.bin", + "src/testdata/mobilenet.tflite.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.InterpreterTest", + visibility = ["//visibility:private"], + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + java_test( name = "TensorTest", size = "small", diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt new file mode 100644 index 0000000000000000000000000000000000000000..572eccf90087c1c19874e40b950c1610f59cc9c2 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt @@ -0,0 +1,1001 @@ +dummy +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt rename to tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java index 74737a8b883d23684220dd32bbd7a9e8ab4b2123..d2048b41b1e76fe42c919c9b889df5be8a94957f 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -296,7 +296,8 @@ public class Camera2BasicFragment extends Fragment public void onActivityCreated(Bundle savedInstanceState) { super.onActivityCreated(savedInstanceState); try { - classifier = new ImageClassifier(getActivity()); + // create either a new ImageClassifierQuantizedMobileNet or an ImageClassifierFloatInception + classifier = new ImageClassifierQuantizedMobileNet(getActivity()); } catch (IOException e) { Log.e(TAG, "Failed to initialize an image classifier."); } @@ -659,7 +660,7 @@ public class Camera2BasicFragment extends Fragment return; } Bitmap bitmap = - textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y); + textureView.getBitmap(classifier.getImageSizeX(), classifier.getImageSizeY()); String textToShow = classifier.classifyFrame(bitmap); bitmap.recycle(); showToast(textToShow); diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java index e44c5ae6b48eda187079dd3a0a1bc563276d816e..c319bff9f11546ac4d49c6b34c5ecdbc41547d58 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -20,6 +20,9 @@ import android.content.res.AssetFileDescriptor; import android.graphics.Bitmap; import android.os.SystemClock; import android.util.Log; + +import org.tensorflow.lite.Interpreter; + import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; @@ -34,20 +37,15 @@ import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.PriorityQueue; -import org.tensorflow.lite.Interpreter; -/** Classifies images with Tensorflow Lite. */ -public class ImageClassifier { +/** + * Classifies images with Tensorflow Lite. + */ +public abstract class ImageClassifier { /** Tag for the {@link Log}. */ private static final String TAG = "TfLiteCameraDemo"; - /** Name of the model file stored in Assets. */ - private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite"; - - /** Name of the label file stored in Assets. */ - private static final String LABEL_PATH = "labels.txt"; - /** Number of results to show in the UI. */ private static final int RESULTS_TO_SHOW = 3; @@ -56,23 +54,18 @@ public class ImageClassifier { private static final int DIM_PIXEL_SIZE = 3; - static final int DIM_IMG_SIZE_X = 224; - static final int DIM_IMG_SIZE_Y = 224; - /* Preallocated buffers for storing image data in. */ - private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y]; + private int[] intValues = new int[getImageSizeX() * getImageSizeY()]; /** An instance of the driver class to run model inference with Tensorflow Lite. */ - private Interpreter tflite; + protected Interpreter tflite; /** Labels corresponding to the output of the vision model. */ private List labelList; /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */ - private ByteBuffer imgData = null; + protected ByteBuffer imgData = null; - /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */ - private byte[][] labelProbArray = null; /** multi-stage low pass filter * */ private float[][] filterLabelProbArray = null; @@ -95,10 +88,10 @@ public class ImageClassifier { labelList = loadLabelList(activity); imgData = ByteBuffer.allocateDirect( - DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE); + DIM_BATCH_SIZE * getImageSizeX() * getImageSizeY() * DIM_PIXEL_SIZE * + getNumBytesPerChannel()); imgData.order(ByteOrder.nativeOrder()); - labelProbArray = new byte[1][labelList.size()]; - filterLabelProbArray = new float[FILTER_STAGES][labelList.size()]; + filterLabelProbArray = new float[FILTER_STAGES][getNumLabels()]; Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); } @@ -111,7 +104,7 @@ public class ImageClassifier { convertBitmapToByteBuffer(bitmap); // Here's where the magic happens!!! long startTime = SystemClock.uptimeMillis(); - tflite.run(imgData, labelProbArray); + runInference(); long endTime = SystemClock.uptimeMillis(); Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime)); @@ -125,12 +118,12 @@ public class ImageClassifier { } void applyFilter() { - int numLabels = labelList.size(); + int numLabels = getNumLabels(); // Low pass filter `labelProbArray` into the first stage of the filter. for (int j = 0; j < numLabels; ++j) { filterLabelProbArray[0][j] += - FILTER_FACTOR * (labelProbArray[0][j] - filterLabelProbArray[0][j]); + FILTER_FACTOR * (getProbability(j) - filterLabelProbArray[0][j]); } // Low pass filter each stage into the next. for (int i = 1; i < FILTER_STAGES; ++i) { @@ -142,7 +135,7 @@ public class ImageClassifier { // Copy the last stage filter output back to `labelProbArray`. for (int j = 0; j < numLabels; ++j) { - labelProbArray[0][j] = (byte)filterLabelProbArray[FILTER_STAGES - 1][j]; + setProbability(j, filterLabelProbArray[FILTER_STAGES - 1][j]); } } @@ -156,7 +149,7 @@ public class ImageClassifier { private List loadLabelList(Activity activity) throws IOException { List labelList = new ArrayList(); BufferedReader reader = - new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH))); + new BufferedReader(new InputStreamReader(activity.getAssets().open(getLabelPath()))); String line; while ((line = reader.readLine()) != null) { labelList.add(line); @@ -167,7 +160,7 @@ public class ImageClassifier { /** Memory-map the model file in Assets. */ private MappedByteBuffer loadModelFile(Activity activity) throws IOException { - AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH); + AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath()); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); @@ -185,12 +178,10 @@ public class ImageClassifier { // Convert the image to floating point. int pixel = 0; long startTime = SystemClock.uptimeMillis(); - for (int i = 0; i < DIM_IMG_SIZE_X; ++i) { - for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) { + for (int i = 0; i < getImageSizeX(); ++i) { + for (int j = 0; j < getImageSizeY(); ++j) { final int val = intValues[pixel++]; - imgData.put((byte) ((val >> 16) & 0xFF)); - imgData.put((byte) ((val >> 8) & 0xFF)); - imgData.put((byte) (val & 0xFF)); + addPixelValue(val); } } long endTime = SystemClock.uptimeMillis(); @@ -199,9 +190,9 @@ public class ImageClassifier { /** Prints top-K labels, to be shown in UI as the results. */ private String printTopKLabels() { - for (int i = 0; i < labelList.size(); ++i) { + for (int i = 0; i < getNumLabels(); ++i) { sortedLabels.add( - new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f)); + new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i))); if (sortedLabels.size() > RESULTS_TO_SHOW) { sortedLabels.poll(); } @@ -214,4 +205,80 @@ public class ImageClassifier { } return textToShow; } + + /** + * Get the name of the model file stored in Assets. + * @return + */ + protected abstract String getModelPath(); + + /** + * Get the name of the label file stored in Assets. + * @return + */ + protected abstract String getLabelPath(); + + /** + * Get the image size along the x axis. + * @return + */ + protected abstract int getImageSizeX(); + + /** + * Get the image size along the y axis. + * @return + */ + protected abstract int getImageSizeY(); + + /** + * Get the number of bytes that is used to store a single color channel value. + * @return + */ + protected abstract int getNumBytesPerChannel(); + + /** + * Add pixelValue to byteBuffer. + * @param pixelValue + */ + protected abstract void addPixelValue(int pixelValue); + + /** + * Read the probability value for the specified label + * This is either the original value as it was read from the net's output or the updated value + * after the filter was applied. + * @param labelIndex + * @return + */ + protected abstract float getProbability(int labelIndex); + + /** + * Set the probability value for the specified label. + * @param labelIndex + * @param value + */ + protected abstract void setProbability(int labelIndex, Number value); + + /** + * Get the normalized probability value for the specified label. + * This is the final value as it will be shown to the user. + * @return + */ + protected abstract float getNormalizedProbability(int labelIndex); + + /** + * Run inference using the prepared input in {@link #imgData}. + * Afterwards, the result will be provided by getProbability(). + * + * This additional method is necessary, because we don't have a common base for different + * primitive data types. + */ + protected abstract void runInference(); + + /** + * Get the total number of labels. + * @return + */ + protected int getNumLabels() { + return labelList.size(); + } } diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java new file mode 100644 index 0000000000000000000000000000000000000000..be17b85e0cd93778fd123663595c43b730fb44f7 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; + +import java.io.IOException; + +/** + * This classifier works with the Inception-v3 slim model. + * It applies floating point inference rather than using a quantized model. + */ +public class ImageClassifierFloatInception extends ImageClassifier { + + /** + * The inception net requires additional normalization of the used input. + */ + private static final int IMAGE_MEAN = 128; + private static final float IMAGE_STD = 128.0f; + + /** + * An array to hold inference results, to be feed into Tensorflow Lite as outputs. + * This isn't part of the super class, because we need a primitive array here. + */ + private float[][] labelProbArray = null; + + /** + * Initializes an {@code ImageClassifier}. + * + * @param activity + */ + ImageClassifierFloatInception(Activity activity) throws IOException { + super(activity); + labelProbArray = new float[1][getNumLabels()]; + } + + @Override + protected String getModelPath() { + // you can download this file from + // https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip + return "inceptionv3_slim_2016.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_imagenet_slim.txt"; + } + + @Override + protected int getImageSizeX() { + return 299; + } + + @Override + protected int getImageSizeY() { + return 299; + } + + @Override + protected int getNumBytesPerChannel() { + // a 32bit float value requires 4 bytes + return 4; + } + + @Override + protected void addPixelValue(int pixelValue) { + imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + } + + @Override + protected float getProbability(int labelIndex) { + return labelProbArray[0][labelIndex]; + } + + @Override + protected void setProbability(int labelIndex, Number value) { + labelProbArray[0][labelIndex] = value.floatValue(); + } + + @Override + protected float getNormalizedProbability(int labelIndex) { + // TODO the following value isn't in [0,1] yet, but may be greater. Why? + return getProbability(labelIndex); + } + + @Override + protected void runInference() { + tflite.run(imgData, labelProbArray); + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..156c895146940adfe71f111be6e354e02b75ea48 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; + +import java.io.IOException; + +/** + * This classifier works with the quantized MobileNet model. + */ +public class ImageClassifierQuantizedMobileNet extends ImageClassifier { + + /** + * An array to hold inference results, to be feed into Tensorflow Lite as outputs. + * This isn't part of the super class, because we need a primitive array here. + */ + private byte[][] labelProbArray = null; + + /** + * Initializes an {@code ImageClassifier}. + * + * @param activity + */ + ImageClassifierQuantizedMobileNet(Activity activity) throws IOException { + super(activity); + labelProbArray = new byte[1][getNumLabels()]; + } + + @Override + protected String getModelPath() { + // you can download this file from + // https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip + return "mobilenet_quant_v1_224.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_mobilenet_quant_v1_224.txt"; + } + + @Override + protected int getImageSizeX() { + return 224; + } + + @Override + protected int getImageSizeY() { + return 224; + } + + @Override + protected int getNumBytesPerChannel() { + // the quantized model uses a single byte only + return 1; + } + + @Override + protected void addPixelValue(int pixelValue) { + imgData.put((byte) ((pixelValue >> 16) & 0xFF)); + imgData.put((byte) ((pixelValue >> 8) & 0xFF)); + imgData.put((byte) (pixelValue & 0xFF)); + } + + @Override + protected float getProbability(int labelIndex) { + return labelProbArray[0][labelIndex]; + } + + @Override + protected void setProbability(int labelIndex, Number value) { + labelProbArray[0][labelIndex] = value.byteValue(); + } + + @Override + protected float getNormalizedProbability(int labelIndex) { + return (labelProbArray[0][labelIndex] & 0xff) / 255.0f; + } + + @Override + protected void runInference() { + tflite.run(imgData, labelProbArray); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index f3f51b668f068ffcd02862a79b72dbae31d31c02..c346f9f92e360c0722ebac440d790da6441ceecf 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -200,6 +200,12 @@ TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, return kTfLiteOk; } +// TODO(yichengfan): evaluate the benefit to use tflite verifier. +bool VerifyModel(const void* buf, size_t len) { + flatbuffers::Verifier verifier(static_cast(buf), len); + return tflite::VerifyModelBuffer(verifier); +} + } // namespace JNIEXPORT jobjectArray JNICALL @@ -271,6 +277,17 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( convertLongToErrorReporter(env, error_handle); if (error_reporter == nullptr) return 0; const char* path = env->GetStringUTFChars(model_file, nullptr); + + { + tflite::FileCopyAllocation allocation(path, nullptr); + if (!VerifyModel(allocation.base(), allocation.bytes())) { + throwException(env, kIllegalArgumentException, + "Contents of %s is not a valid flatbuffer model", path); + env->ReleaseStringUTFChars(model_file, path); + return 0; + } + } + auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter); if (!model) { throwException(env, kIllegalArgumentException, @@ -293,6 +310,12 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( const char* buf = static_cast(env->GetDirectBufferAddress(model_buffer)); jlong capacity = env->GetDirectBufferCapacity(model_buffer); + if (!VerifyModel(buf, capacity)) { + throwException(env, kIllegalArgumentException, + "MappedByteBuffer is not a valid flatbuffer model"); + return 0; + } + auto model = tflite::FlatBufferModel::BuildFromBuffer( buf, static_cast(capacity), error_reporter); if (!model) { diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 473f73816fd3c0a414a2c2e232dec299579fcbb6..90323555d88419d837a76bca7de6d9998e388fca 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -60,9 +60,7 @@ public final class NativeInterpreterWrapperTest { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); fail(); } catch (IllegalArgumentException e) { - assertThat(e) - .hasMessageThat() - .contains("Model provided has model identifier ' is ', should be 'TFL3'"); + assertThat(e).hasMessageThat().contains("is not a valid flatbuffer model"); } } diff --git a/tensorflow/contrib/lite/kernels/Android.bp b/tensorflow/contrib/lite/kernels/Android.bp index a171eba48d2a5314b3c9b3c553f89725e79d83a8..f2a51c0270862833c1f66894370e311e9c82f094 100644 --- a/tensorflow/contrib/lite/kernels/Android.bp +++ b/tensorflow/contrib/lite/kernels/Android.bp @@ -23,6 +23,9 @@ cc_library_static { "internal/reference/portable_tensor_utils.cc", "internal/optimized/neon_tensor_utils.cc", ], + header_libs: [ + "gemmlowp_headers", + ], cflags: [ "-Wno-extern-c-compat", ] @@ -36,12 +39,14 @@ cc_library_static { "add.cc", "basic_rnn.cc", "batch_to_space_nd.cc", + "bidirectional_sequence_rnn.cc", "concatenation.cc", "conv.cc", "depthwise_conv.cc", "div.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", + "exp.cc", "fully_connected.cc", "gather.cc", "hashtable_lookup.cc", @@ -60,13 +65,16 @@ cc_library_static { "skip_gram.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "split.cc", "squeeze.cc", "strided_slice.cc", "sub.cc", "svdf.cc", + "topk_v2.cc", "transpose.cc", "unidirectional_sequence_lstm.cc", "unidirectional_sequence_rnn.cc", + "internal/kernel_utils.cc", "internal/tensor_utils.cc", "internal/quantization_util.cc", "internal/reference/portable_tensor_utils.cc", diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 4195e7553c48028d56e80db0d204ef5656be874d..b59dc5ffb339caade28626d1954d41bc821fae41 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -71,6 +71,32 @@ cc_library( ], ) +cc_library( + name = "kernel_util", + srcs = [ + "kernel_util.cc", + ], + hdrs = [ + "kernel_util.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/kernels/internal:round", + ], +) + +tf_cc_test( + name = "kernel_util_test", + size = "small", + srcs = ["kernel_util_test.cc"], + deps = [ + ":kernel_util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "builtin_ops", srcs = [ @@ -78,16 +104,17 @@ cc_library( "add.cc", "basic_rnn.cc", "batch_to_space_nd.cc", + "bidirectional_sequence_rnn.cc", "concatenation.cc", "conv.cc", "depthwise_conv.cc", "div.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", + "exp.cc", "fully_connected.cc", "gather.cc", "hashtable_lookup.cc", - "kernel_util.cc", "l2norm.cc", "local_response_norm.cc", "lsh_projection.cc", @@ -102,16 +129,17 @@ cc_library( "skip_gram.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "split.cc", "squeeze.cc", "strided_slice.cc", "sub.cc", "svdf.cc", + "topk_v2.cc", "transpose.cc", "unidirectional_sequence_lstm.cc", "unidirectional_sequence_rnn.cc", ], hdrs = [ - "kernel_util.h", "padding.h", "register.h", ], @@ -125,11 +153,13 @@ cc_library( }), deps = [ ":activation_functor", + ":kernel_util", ":op_macros", "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels/internal:kernel_utils", "//tensorflow/contrib/lite/kernels/internal:optimized", "//tensorflow/contrib/lite/kernels/internal:optimized_base", "//tensorflow/contrib/lite/kernels/internal:quantization_util", @@ -223,6 +253,7 @@ tf_cc_test( ":builtin_ops", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], ) @@ -263,6 +294,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "bidirectional_sequence_rnn_test", + size = "small", + srcs = ["bidirectional_sequence_rnn_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "unidirectional_sequence_rnn_test", size = "small", @@ -287,6 +330,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "exp_test", + size = "small", + srcs = ["exp_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "mean_test", size = "small", @@ -348,6 +403,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "topk_v2_test", + size = "small", + srcs = ["topk_v2_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "resize_bilinear_test", size = "small", @@ -507,6 +575,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "split_test", + size = "small", + srcs = ["split_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "squeeze_test", size = "small", diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 8ac93bc8c8dcfc66d3822e01b6f9b29a3e49c446..3c5c77815d0f2592ab549152b4d77f45b967a660 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -15,8 +15,8 @@ limitations under the License. #include #include #include -#include #include +#include #include #include @@ -134,8 +134,7 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { float* out = output->data.f; for (; in < in_end; in++, out++) *out = std::max(0.f, *in); return kTfLiteOk; - } - break; + } break; default: context->ReportError(context, "Only float32 supported currently."); return kTfLiteError; @@ -173,8 +172,7 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { float* out = output->data.f; for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f); return kTfLiteOk; - } - break; + } break; default: context->ReportError(context, "Only float32 supported currently."); return kTfLiteError; @@ -192,8 +190,7 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { float* out = output->data.f; for (; in < in_end; in++, out++) *out = std::tanh(*in); return kTfLiteOk; - } - break; + } break; default: context->ReportError(context, "Only float32 supported currently."); return kTfLiteError; diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index 0e10a249abac3ba19cf107e055aa71d1eee00122..63ea89df56bafa995950afec3a58267681af304f 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -37,7 +37,23 @@ constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; constexpr int kOutputTensor = 0; +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -45,43 +61,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); - for (int i = 0; i < NumDimensions(input1); ++i) { - TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), - SizeOfDimension(input2, i)); - } + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + output->type = input2->type; - TF_LITE_ENSURE_EQ(context, input1->type, output->type); - TF_LITE_ENSURE_EQ(context, input2->type, output->type); + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } - TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); return context->ResizeTensor(context, output, output_size); } template void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteAddParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); -#define TF_LITE_ADD(type) \ - type::Add(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops); +#define TF_LITE_ADD(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd); } else { - TF_LITE_ADD(optimized_ops); + TF_LITE_ADD(reference_ops, Add); + } + } else { + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd); + } else { + TF_LITE_ADD(optimized_ops, Add); + } } #undef TF_LITE_ADD } template void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteAddParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { auto input1_offset = -input1->params.zero_point; auto input2_offset = -input2->params.zero_point; auto output_offset = output->params.zero_point; @@ -112,19 +141,20 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeUint8(params->activation, output, &output_activation_min, &output_activation_max); -#define TF_LITE_ADD(type) \ - type::BroadcastAdd( \ - left_shift, GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, input1_multiplier, input1_shift, \ - GetTensorData(input2), GetTensorDims(input2), input2_offset, \ - input2_multiplier, input2_shift, output_offset, output_multiplier, \ - output_shift, output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); - +#define TF_LITE_ADD(type, opname) \ + type::opname(left_shift, GetTensorData(input1), \ + GetTensorDims(input1), input1_offset, input1_multiplier, \ + input1_shift, GetTensorData(input2), \ + GetTensorDims(input2), input2_offset, input2_multiplier, \ + input2_shift, output_offset, output_multiplier, output_shift, \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + // The quantized version of Add doesn't support activations, so we + // always use BroadcastAdd. if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops); + TF_LITE_ADD(reference_ops, BroadcastAdd); } else { - TF_LITE_ADD(optimized_ops); + TF_LITE_ADD(optimized_ops, BroadcastAdd); } #undef TF_LITE_ADD } @@ -132,15 +162,17 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { - EvalAddFloat(context, node, params, input1, input2, output); + EvalAddFloat(context, node, params, data, input1, input2, + output); } else if (output->type == kTfLiteUInt8) { - EvalAddQuantized(context, node, params, input1, input2, + EvalAddQuantized(context, node, params, data, input1, input2, output); } else { context->ReportError(context, @@ -154,19 +186,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace add TfLiteRegistration* Register_ADD_REF() { - static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval}; return &r; } TfLiteRegistration* Register_ADD_GENERIC_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval}; return &r; } TfLiteRegistration* Register_ADD_NEON_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc index 306dfc3e803d3df34061767ba9ced032299bfa26..956d05bed5162f6ce59705d59aad77ff056dda77 100644 --- a/tensorflow/contrib/lite/kernels/add_test.cc +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -25,10 +25,11 @@ using ::testing::ElementsAreArray; class BaseAddOpModel : public SingleOpModel { public: - BaseAddOpModel(const TensorData& input, const TensorData& output, + BaseAddOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, ActivationFunctionType activation_type) { - input1_ = AddInput(input); - input2_ = AddInput(input); + input1_ = AddInput(input1); + input2_ = AddInput(input2); output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, CreateAddOptions(builder_, activation_type).Union()); @@ -70,6 +71,7 @@ float GetTolerance(int min, int max) { TEST(FloatAddOpModel, NoActivation) { FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); @@ -78,9 +80,9 @@ TEST(FloatAddOpModel, NoActivation) { } TEST(FloatAddOpModel, ActivationRELU_N1_TO_1) { - FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {}}, - ActivationFunctionType_RELU_N1_TO_1); + FloatAddOpModel m( + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU_N1_TO_1); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); m.Invoke(); @@ -92,6 +94,7 @@ TEST(FloatAddOpModel, VariousInputShapes) { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, test_shapes[i]}, {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); @@ -102,6 +105,23 @@ TEST(FloatAddOpModel, VariousInputShapes) { } } +TEST(FloatAddOpModel, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, // always a scalar + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}))) + << "With shape number " << i; + } +} + TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { @@ -112,6 +132,7 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; for (int i = 0; i < inputs1.size(); ++i) { QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}, ActivationFunctionType_NONE); m.QuantizeAndPopulate(m.input1(), inputs1[i]); @@ -133,6 +154,7 @@ TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { {-0.2, 0.6, -0.1, 0.8}}; for (int i = 0; i < inputs1.size(); ++i) { QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}, ActivationFunctionType_RELU_N1_TO_1); m.QuantizeAndPopulate(m.input1(), inputs1[i]); @@ -150,6 +172,7 @@ TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, test_shapes[i], -3.0, 3.0}, {TensorType_UINT8, {}, -3.0, 3.0}, ActivationFunctionType_NONE); m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); @@ -162,6 +185,25 @@ TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { } } +TEST(QuantizedAddOpModel, QuantizedWithBroadcast) { + float kQuantizedTolerance = GetTolerance(-3.0, 3.0); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}, + kQuantizedTolerance))) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index 3cee43c68b2a0af5a3fd84b33a980b74bb8f0cb4..2c5074eca3176c7f33a6f051b492dc41333257ed 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -15,14 +15,15 @@ limitations under the License. #include #include #include -#include #include +#include #include #include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -76,8 +77,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); output_size_array->data[0] = batch_size; output_size_array->data[1] = num_units; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, - output_size_array)); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); return kTfLiteOk; } @@ -101,50 +102,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int batch_size = input->dims->data[0]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[1]; - const int input_weights_stride = input_weights->dims->data[1]; - const int recurrent_weights_stride = recurrent_weights->dims->data[1]; - - // For each batch - for (int b = 0; b < batch_size; b++) { - // Initialize the pointer to input, output and bias. - const float* input_ptr_batch = input->data.f + b * input_size; - float* output_ptr_batch = output->data.f + b * num_units; - float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; - - // Initialize input_weights and recurrent_weights. - const float* input_weights_ptr = input_weights->data.f; - const float* recurrent_weights_ptr = recurrent_weights->data.f; - - // Output = bias - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = bias_ptr[o]; - } - - // Output += input * input_weights - for (int o = 0; o < num_units; o++) { - for (int i = 0; i < input_size; i++) { - output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; - } - input_weights_ptr += input_weights_stride; - } - - // Output += recurrent_weights * hidden_state - for (int o = 0; o < num_units; o++) { - for (int h = 0; h < num_units; h++) { - output_ptr_batch[o] += - hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; - } - recurrent_weights_ptr += recurrent_weights_stride; - } - - // Output = activation(Output) and update hidden_state - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = - (ActivationFunctor(params->activation))(output_ptr_batch[o]); - hidden_state_ptr_batch[o] = output_ptr_batch[o]; - } - } + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f; + // Initialize the pointer to input and output. + const float* input_ptr_batch = input->data.f; + float* output_ptr_batch = output->data.f; + // Initialize input_weights and recurrent_weights. + const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr, + recurrent_weights_ptr, bias_ptr, input_size, + num_units, batch_size, params->activation, + hidden_state_ptr_batch, output_ptr_batch); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc index 5ecccb985e91238f1183c8f94a2b5f468758ce55..fa7ef525db47c93f98951604cd04da66196422d7 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite RNN op. -#include #include +#include #include #include @@ -120,8 +120,7 @@ static float rnn_golden_output[] = { 0.415153, 0.210318, 0, 0, 0, 0, 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, - 0.628881, 3.58099, 1.49974, 0 -}; + 0.628881, 3.58099, 1.49974, 0}; class RNNOpModel : public SingleOpModel { public: diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc index 0eed680fdcc2afc4bc72be55a5e7722310fa4538..bc438f99c6a72fdbc2794dee03524db6a7523834 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -35,12 +35,14 @@ enum KernelType { struct BatchToSpaceNDContext { BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); + block_shape = GetInput(context, node, 1); + crops = GetInput(context, node, 2); output = GetOutput(context, node, 0); } - TfLiteBatchToSpaceNDParams* params; TfLiteTensor* input; + TfLiteTensor* block_shape; + TfLiteTensor* crops; TfLiteTensor* output; }; @@ -48,23 +50,28 @@ struct BatchToSpaceNDContext { // The 4D array need to have exactly 2 spatial dimensions. // TODO(ycling): Support arbitrary dimension in BatchToSpaceND. const int kInputDimensionNum = 4; -const int kOutputDimensionNum = 4; +const int kBlockSizeDimensionNum = 1; const int kSpatialDimensionNum = 2; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - // The 2nd tensor (block_shape) and the 3rd tensor (crops) are ignored now. - TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + BatchToSpaceNDContext* op_context) { + TfLiteIntArray* input_size = op_context->input->dims; + const int* block_shape = GetTensorData(op_context->block_shape); + const int* crops = GetTensorData(op_context->crops); - BatchToSpaceNDContext op_context(context, node); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), - kInputDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions, + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), + kBlockSizeDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0], + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), kSpatialDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - const TfLiteIntArray* input_size = op_context.input->dims; - const int* block_shape = op_context.params->block_shape; + // TODO(ycling): Add crops as part of calculation. Remove check for a crops + // containing all zeroes. + TF_LITE_ENSURE_EQ(context, crops[0], 0); + TF_LITE_ENSURE_EQ(context, crops[1], 0); + TF_LITE_ENSURE_EQ(context, crops[2], 0); + TF_LITE_ENSURE_EQ(context, crops[3], 0); // Number of batch must be multiple of (block_shape[0] * block_shape[1]). TF_LITE_ENSURE_EQ(context, @@ -76,27 +83,47 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int output_width = input_size->data[2] * block_shape[1]; const int output_channel_size = input_size->data[3]; - TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum); + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); output_size->data[0] = output_batch_size; output_size->data[1] = output_height; output_size->data[2] = output_width; output_size->data[3] = output_channel_size; - return context->ResizeTensor(context, op_context.output, output_size); + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + BatchToSpaceNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + if (!IsConstantTensor(op_context.block_shape) || + !IsConstantTensor(op_context.crops)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { BatchToSpaceNDContext op_context(context, node); - int block_shape_dims_array[1] = {kSpatialDimensionNum}; - Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } -#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ - type::BatchToSpaceND(GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), \ - op_context.params->block_shape, block_shape_dims, \ - GetTensorData(op_context.output), \ +#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ + type::BatchToSpaceND(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + GetTensorData(op_context.block_shape), \ + GetTensorDims(op_context.block_shape), \ + GetTensorData(op_context.output), \ GetTensorDims(op_context.output)) switch (op_context.input->type) { // Already know in/out types are same. case kTfLiteFloat32: diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc index 3ec4efbebcef9d55d0042d93007018c9f6ee3b58..8485cde1b40066f2070855bca91ea78a9f80e83c 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc @@ -26,36 +26,76 @@ using ::testing::ElementsAreArray; class BatchToSpaceNDOpModel : public SingleOpModel { public: - BatchToSpaceNDOpModel(std::initializer_list input_shape, - std::initializer_list block_shape, - std::initializer_list before_crops, - std::initializer_list after_crops) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, - BuiltinOptions_BatchToSpaceNDOptions, - CreateBatchToSpaceNDOptions( - builder_, builder_.CreateVector(block_shape), - builder_.CreateVector(before_crops), - builder_.CreateVector(after_crops)) - .Union()); - BuildInterpreter({input_shape}); - } - void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetBlockShape(std::initializer_list data) { + PopulateTensor(block_shape_, data); + } + + void SetCrops(std::initializer_list data) { + PopulateTensor(crops_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; + int block_shape_; + int crops_; int output_; }; -TEST(BatchToSpaceNDOpTest, SimpleTest) { - BatchToSpaceNDOpModel m({4, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}); +// Tests case where block_shape and crops are const tensors. +// +// Example usage is as follows: +// BatchToSpaceNDOpConstModel m(input_shape, block_shape, crops); +// m.SetInput(input_data); +// m.Invoke(); +class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel { + public: + BatchToSpaceNDOpConstModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list crops) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + crops_ = AddConstInput(TensorType_INT32, crops, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Tests case where block_shape and crops are non-const tensors. +// +// Example usage is as follows: +// BatchToSpaceNDOpDynamicModel m(input_shape); +// m.SetInput(input_data); +// m.SetBlockShape(block_shape); +// m.SetPaddings(crops); +// m.Invoke(); +class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel { + public: + BatchToSpaceNDOpDynamicModel(std::initializer_list input_shape) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddInput(TensorType_INT32); + crops_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions(builder_).Union()); + BuildInterpreter({input_shape, {2}, {2, 2}}); + } +}; + +TEST(BatchToSpaceNDOpTest, SimpleConstTest) { + BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); @@ -63,11 +103,35 @@ TEST(BatchToSpaceNDOpTest, SimpleTest) { 4, 8, 11, 15, 12, 16})); } +TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) { + BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetCrops({0, 0, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, + 4, 8, 11, 15, 12, 16})); +} + TEST(BatchToSpaceNDOpTest, InvalidShapeTest) { - EXPECT_DEATH(BatchToSpaceNDOpModel({3, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}), + EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}), "Cannot allocate tensors"); } +TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) { + EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 1}), + "1 != 0"); +} + +TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) { + BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetCrops({0, 0, 1, 0}); + EXPECT_DEATH(m.Invoke(), "1 != 0"); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc new file mode 100644 index 0000000000000000000000000000000000000000..aa24c1f34cd1e8c02a6a75b62fbe5f3c629498ca --- /dev/null +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -0,0 +1,205 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace bidirectional_sequence_rnn { + +constexpr int kInputTensor = 0; +// Forward and backward cell tensors. +constexpr int kFwWeightsTensor = 1; +constexpr int kFwRecurrentWeightsTensor = 2; +constexpr int kFwBiasTensor = 3; +constexpr int kBwWeightsTensor = 4; +constexpr int kBwRecurrentWeightsTensor = 5; +constexpr int kBwBiasTensor = 6; +// State and output tensors. +constexpr int kFwHiddenStateTensor = 0; +constexpr int kFwOutputTensor = 1; +constexpr int kBwHiddenStateTensor = 2; +constexpr int kBwOutputTensor = 3; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 7); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 4); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* fw_input_weights = + &context->tensors[node->inputs->data[kFwWeightsTensor]]; + TfLiteTensor* fw_recurrent_weights = + &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]]; + TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]]; + TfLiteTensor* bw_input_weights = + &context->tensors[node->inputs->data[kBwWeightsTensor]]; + TfLiteTensor* bw_recurrent_weights = + &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]]; + TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int batch_size = input->dims->data[0]; + const int max_time = input->dims->data[1]; + const int fw_num_units = fw_input_weights->dims->data[0]; + const int bw_num_units = bw_input_weights->dims->data[0]; + TF_LITE_ASSERT_EQ(input->dims->data[2], fw_input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(input->dims->data[2], bw_input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(fw_input_weights->dims->data[0], fw_bias->dims->data[0]); + TF_LITE_ASSERT_EQ(bw_input_weights->dims->data[0], bw_bias->dims->data[0]); + TF_LITE_ASSERT_EQ(fw_recurrent_weights->dims->data[0], + fw_bias->dims->data[0]); + TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1], + bw_bias->dims->data[0]); + + TfLiteTensor* fw_output = + &context->tensors[node->outputs->data[kFwOutputTensor]]; + TfLiteTensor* bw_output = + &context->tensors[node->outputs->data[kBwOutputTensor]]; + + // Resize hidden states. + TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2); + fw_hidden_state_size_array->data[0] = batch_size; + fw_hidden_state_size_array->data[1] = fw_num_units; + TfLiteTensor* fw_hidden_state = + &context->tensors[node->outputs->data[kFwHiddenStateTensor]]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state, + fw_hidden_state_size_array)); + + TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2); + bw_hidden_state_size_array->data[0] = batch_size; + bw_hidden_state_size_array->data[1] = fw_num_units; + TfLiteTensor* bw_hidden_state = + &context->tensors[node->outputs->data[kBwHiddenStateTensor]]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state, + bw_hidden_state_size_array)); + + // Mark hidden states as a persistent tensor. + fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent; + bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize outputs. + TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3); + fw_output_size_array->data[0] = batch_size; + fw_output_size_array->data[1] = max_time; + fw_output_size_array->data[2] = fw_num_units; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, fw_output, fw_output_size_array)); + TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3); + bw_output_size_array->data[0] = batch_size; + bw_output_size_array->data[1] = max_time; + bw_output_size_array->data[2] = bw_num_units; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, bw_output, bw_output_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* fw_input_weights = + &context->tensors[node->inputs->data[kFwWeightsTensor]]; + TfLiteTensor* fw_recurrent_weights = + &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]]; + TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]]; + TfLiteTensor* fw_hidden_state = + &context->tensors[node->outputs->data[kFwHiddenStateTensor]]; + TfLiteTensor* fw_output = + &context->tensors[node->outputs->data[kFwOutputTensor]]; + + TfLiteTensor* bw_input_weights = + &context->tensors[node->inputs->data[kBwWeightsTensor]]; + TfLiteTensor* bw_recurrent_weights = + &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]]; + TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]]; + TfLiteTensor* bw_hidden_state = + &context->tensors[node->outputs->data[kBwHiddenStateTensor]]; + TfLiteTensor* bw_output = + &context->tensors[node->outputs->data[kBwOutputTensor]]; + + const int batch_size = input->dims->data[0]; + const int max_time = input->dims->data[1]; + const int input_size = input->dims->data[2]; + + const int fw_num_units = fw_input_weights->dims->data[0]; + const float* fw_bias_ptr = fw_bias->data.f; + const float* fw_input_weights_ptr = fw_input_weights->data.f; + const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f; + + const int bw_num_units = bw_input_weights->dims->data[0]; + const float* bw_bias_ptr = bw_bias->data.f; + const float* bw_input_weights_ptr = bw_input_weights->data.f; + const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f; + + for (int b = 0; b < batch_size; b++) { + // Forward cell. + float* fw_hidden_state_ptr_batch = + fw_hidden_state->data.f + b * fw_num_units; + for (int s = 0; s < max_time; s++) { + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units; + + kernel_utils::RnnBatchStep( + input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr, + fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1, + params->activation, fw_hidden_state_ptr_batch, output_ptr_batch); + } + // Backward cell. + float* bw_hidden_state_ptr_batch = + bw_hidden_state->data.f + b * bw_num_units; + for (int s = max_time - 1; s >= 0; s--) { + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units; + + kernel_utils::RnnBatchStep( + input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr, + bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1, + params->activation, bw_hidden_state_ptr_batch, output_ptr_batch); + } + } + return kTfLiteOk; +} + +} // namespace bidirectional_sequence_rnn + +TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + bidirectional_sequence_rnn::Prepare, + bidirectional_sequence_rnn::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..12f4ff97cfd90e3a6894a24d15fcbc356f96cde2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -0,0 +1,931 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Bidirectional RNN op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_fw_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0}; + +static float rnn_golden_bw_output[] = { + 0.496726, 0, 1.00883, 0, 0.0584256, 0, 0, + 0.236412, 0, 0, 0.612267, 0.487726, 0, 0.54883, + 1.16099, 0.0291233, 0, 0, 0.428302, 0, 0, + 0, 0, 1.13262, 0, 1.64415, 0, 0.311249, + 0.570804, 0.259696, 0, 0, 0, 0, 0, + 0.262334, 0, 0, 0, 1.23781, 0, 2.86532, + 0, 0, 1.34389, 2.76409, 0, 0, 1.03969, + 0, 0.00410865, 0, 0.0470295, 0, 0, 0, + 0.371556, 0.27175, 1.36614, 1.63956, 0.683887, 1.06176, 0.719552, + 0.301314, 0.971195, 0, 0.697143, 0, 0.215219, 0.210693, + 0.363027, 0, 0.501283, 0, 1.13399, 0.623774, 0, + 1.09851, 1.33313, 0.470441, 0.210965, 0, 0.664178, 0, + 0.839686, 0, 0, 0.147834, 0, 0, 0, + 0.58786, 0.490128, 0, 0.905806, 0, 0.932134, 0.424257, + 0, 0, 0.860629, 0, 0, 0, 0.476425, + 0, 0.566017, 0.513721, 0.207341, 1.09508, 1.08385, 0, + 0.973787, 0, 0, 0, 0, 0, 0, + 1.20698, 0, 0, 0, 1.56135, 1.12369, 0.99588, + 0.459803, 0, 0.915854, 0, 0, 0, 0, + 0, 0, 2.03206, 0, 0.773264, 0.267228, 1.55012, + 1.202, 1.51611, 0.701202, 0, 0.725088, 0, 0.509069, + 0, 0.671349, 0.581129, 0.343447, 0, 0.107755, 0.611838, + 1.4331, 1.55871, 0.015242, 0.140624, 0.492562, 0.395095, 0.147722, + 0, 0.784925, 0, 1.65477, 0.715257, 0, 0, + 0, 0.685024, 0, 1.89505, 1.00037, 0, 0, + 0, 0, 0, 1.52659, 0, 0, 0, + 0, 0.618583, 0, 0.11115, 0, 1.37194, 0.630225, + 0, 0, 0, 0, 0, 0.0322124, 0, + 0, 0, 0, 0.430834, 0.252786, 0, 0, + 0, 0.991297, 1.98451, 0, 0, 0.111511, 0, + 1.05513, 0, 0, 0, 0, 0, 0, + 0.290445, 0.412559, 0.0429958, 0.256564, 1.27858, 0.289948, 0, + 1.01693, 0.327141, 0, 0, 0, 0, 0, + 1.83508, 0.346248, 0, 0.961535, 0.790026, 0.552203, 2.13457, + 2.19233, 0.333244, 0.316526, 0.179398, 0, 0, 0, + 0, 0, 1.86126, 0, 0.728256, 0.750013, 0.011861, + 0.576383, 3.38891, 1.29273, 0}; + +constexpr std::initializer_list weights = { + 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}; + +static float endtoend_input[] = { + 0.996808, 0.060710, 0.981855, 0.570017, 0.525164, 0.796859, 0.696547, + 0.505925, 0.991844, 0.461208, 0.949371, 0.027624, 0.539236, 0.841854, + 0.915222, 0.538569, 0.069375, 0.237905, 0.903700, 0.441703, 0.536196, + 0.402724, 0.761635, 0.025063, 0.082592, 0.688245, 0.239310, 0.256931, + 0.658900, 0.105695, 0.301983, 0.655708, 0.166405, 0.283837, 0.225725, + 0.691569, 0.080696, 0.922272, 0.197494, 0.072540, 0.383481, 0.146865, + 0.100163, 0.922717, 0.988720, 0.015386, 0.461286, 0.058095, 0.253290, + 0.364986, 0.499797, 0.789487, 0.767709, 0.261433, 0.814549, 0.850302, + 0.949678, 0.053859, 0.107233, 0.608577, 0.159554, 0.409215, 0.264285, + 0.325960, 0.693053, 0.490011, 0.017529, 0.773749, 0.412283, 0.215023, + 0.846288, 0.795764, 0.361889, 0.946452, 0.718481, 0.350608, 0.961837, + 0.179767, 0.408703, 0.215128, 0.544753, 0.908500, 0.004614, 0.312462, + 0.169933, 0.819163, 0.162764, 0.119611, 0.873022, 0.269997, 0.728188, + 0.032576, 0.679212, 0.992474, 0.358536, 0.372265, 0.482484, 0.376065, + 0.146014, 0.894767, 0.591088, 0.992302, 0.690531, 0.952977, 0.938754, + 0.409012, 0.303585, 0.900591, 0.588780, 0.712287, 0.115719, 0.133533, + 0.620788, 0.120334, 0.445995, 0.790720, 0.939497, 0.608759, 0.910331, + 0.812519, 0.878756, 0.638519, 0.845096, 0.557968, 0.630993, 0.203632, + 0.930233, 0.113477, 0.579697, 0.076247, 0.008244, 0.170785, 0.068549, + 0.698776, 0.123761, 0.007303, 0.107788, 0.427346, 0.907894, 0.696568, + 0.139633, 0.023613, 0.830100, 0.760421, 0.143947, 0.276096, 0.551141, + 0.083444, 0.884855, 0.461472, 0.895963, 0.763611, 0.099992, 0.741059, + 0.321579, 0.730984, 0.944691, 0.251812, 0.844461, 0.524388, 0.328059, + 0.852706, 0.695172, 0.396607, 0.551482, 0.818934, 0.403910, 0.659270, + 0.246280, 0.311804, 0.355838, 0.385913, 0.335418, 0.185938, 0.146334, + 0.479364, 0.462034, 0.697475, 0.562808, 0.346888, 0.158948, 0.458771, + 0.110499, 0.258939, 0.199830, 0.432078, 0.989924, 0.144521, 0.683890, + 0.834385, 0.668908, 0.011949, 0.687091, 0.364081, 0.408556, 0.238572, + 0.183015, 0.812466, 0.897842, 0.429294, 0.124271, 0.253680, 0.815207, + 0.459688, 0.439618, 0.961541, 0.939053, 0.901651, 0.659016, 0.501861, + 0.248539, 0.817964, 0.960632, 0.359038, 0.076903, 0.160462, 0.791117, + 0.066826, 0.304983, 0.475007, 0.901211, 0.973891, 0.486955, 0.588302, + 0.337972, 0.895512, 0.826874, 0.520987, 0.707978, 0.724716, 0.950281, + 0.832249, 0.978396, 0.765488, 0.291937, 0.418014, 0.727029, 0.230990, + 0.319665, 0.386045, 0.732850, 0.568204, 0.204009, 0.693482, 0.927242, + 0.280912, 0.853944, 0.718359, 0.347738, 0.158927, 0.193366, 0.248950, + 0.132818, 0.680321, 0.837252, 0.470790, 0.575833, 0.664126, 0.991777, + 0.283811, 0.388843, 0.942058, 0.116060, 0.367239, 0.707546, 0.407997, + 0.785253, 0.434575, 0.638986, 0.104917, 0.820620, 0.371837, 0.673121, + 0.024629, 0.065319, 0.600363, 0.305541, 0.919263, 0.318722, 0.653279, + 0.078190, 0.512088, 0.902229, 0.211009, 0.192409, 0.739480, 0.681799, + 0.768242, 0.403607, 0.673576, 0.052052, 0.792450, 0.615634, 0.168112, + 0.159689, 0.323180, 0.576109, 0.944941, 0.757755, 0.215095, 0.049858, + 0.578375, 0.586932, 0.722979, 0.603003, 0.652251, 0.323343, 0.908544, + 0.571514, 0.642065, 0.561823, 0.649704, 0.154153, 0.464051, 0.860713, + 0.346562, 0.203532, 0.542512, 0.114804, 0.607139, 0.216088, 0.166856, + 0.399588, 0.831722, 0.334968, 0.559277, 0.154902, 0.911077, 0.504218, + 0.912656, 0.126172, 0.554076, 0.491031, 0.713104, 0.277055, 0.094034, + 0.365355, 0.600398, 0.002578, 0.936869, 0.242463, 0.564401, 0.586574, + 0.396616, 0.028452, 0.447287, 0.743178, 0.231984, 0.989799, 0.857982, + 0.839122, 0.205887, 0.024838, 0.238711, 0.037608, 0.359806, 0.797987, + 0.192510, 0.270883, 0.302205, 0.105166, 0.397055, 0.856281, 0.596197, + 0.110160, 0.133336, 0.690231, 0.475515, 0.733734, 0.692809, 0.412384, + 0.976196, 0.257209, 0.998958, 0.372812, 0.285661, 0.446245, 0.115990, + 0.517645, 0.436044, 0.973972, 0.356767, 0.641930, 0.998810, 0.595478, + 0.679539, 0.358617, 0.393465, 0.872049, 0.629500, 0.695670, 0.977215, + 0.026555, 0.551951, 0.573412, 0.136715, 0.685287, 0.263643, 0.612229, + 0.419020, 0.956451, 0.024613, 0.395216, 0.213661, 0.023572, 0.768029, + 0.499322, 0.469816, 0.884019, 0.016967, 0.905860, 0.857991, 0.373734, + 0.547791, 0.856802, 0.969211, 0.227330, 0.215418, 0.362676, 0.099378, + 0.844918, 0.058346, 0.076594, 0.871473, 0.610297, 0.650006, 0.008188, + 0.295583, 0.913648, 0.620417, 0.714603, 0.870100, 0.645031, 0.109820, + 0.083760, 0.668602, 0.877849, 0.583082, 0.138419, 0.761868, 0.600049, + 0.044279, 0.619859, 0.973783, 0.592069, 0.476661, 0.942994, 0.819399, + 0.692079, 0.305670, 0.918778, 0.536997, 0.364016, 0.995371, 0.408470, + 0.974313, 0.645377, 0.416658, 0.269896, 0.559025, 0.037075, 0.984499, + 0.429125, 0.682105, 0.094319, 0.512885, 0.350707, 0.972168, 0.095967, + 0.489126, 0.734035, 0.696016, 0.533405, 0.353894, 0.669799, 0.125474, + 0.830555, 0.612793, 0.944873, 0.522634, 0.918463, 0.863651, 0.059631, + 0.282479, 0.859022, 0.468101, 0.256791, 0.504398, 0.884758, 0.526687, + 0.063423, 0.921833, 0.511186, 0.492548, 0.603939, 0.605505, 0.005433, + 0.954646, 0.577673, 0.101400, 0.443772, 0.311708, 0.797417, 0.977176, + 0.665602, 0.467216, 0.102650, 0.496157, 0.080009, 0.047524, 0.018791, + 0.998471, 0.911174, 0.078422, 0.280950, 0.770196, 0.546523, 0.537741, + 0.274594, 0.431281, 0.064428, 0.338017, 0.353115, 0.575615, 0.830565, + 0.957053, 0.181120, 0.835998, 0.911699, 0.758793, 0.937398, 0.355471, + 0.070501, 0.734815, 0.332647, 0.736103, 0.202031, 0.435297, 0.232261, + 0.282039, 0.482821, 0.251052, 0.280511, 0.393995, 0.329474, 0.561460, + 0.164191, 0.875997, 0.099202, 0.438785, 0.307278, 0.163630, 0.776802, + 0.660393, 0.739244, 0.607367, 0.617446, 0.920364, 0.443365, 0.529145, + 0.679157, 0.380763, 0.884616, 0.749658, 0.115578, 0.217263, 0.485761, + 0.317609, 0.652560, 0.718021, 0.599648, 0.135381, 0.969073, 0.880159, + 0.529376, 0.298547, 0.441619, 0.693567, 0.174544, 0.540821, 0.132351, + 0.481822, 0.704450, 0.909153, 0.142215, 0.443695, 0.516520, 0.759661, + 0.364059, 0.959885, 0.288806, 0.043216, 0.340648, 0.173422, 0.792874, + 0.456226, 0.390685, 0.278634, 0.773834, 0.043245, 0.996656, 0.373483, + 0.178625, 0.965729, 0.253641, 0.708001, 0.264276, 0.695260, 0.401568, + 0.438820, 0.236081, 0.533919, 0.920642, 0.940531, 0.443072, 0.062857, + 0.384226, 0.959592, 0.822518, 0.748285, 0.919477, 0.111325, 0.791501, + 0.260124, 0.284747, 0.584375, 0.716350, 0.675431, 0.863009, 0.490184, + 0.718676, 0.859665, 0.863666, 0.897301, 0.825393, 0.117308, 0.605302, + 0.089669, 0.812568, 0.006870, 0.528489, 0.048649, 0.540788, 0.449131, + 0.989180, 0.983860, 0.511988, 0.373407, 0.943452, 0.334506, 0.121692, + 0.862929, 0.445831, 0.913193, 0.123053, 0.730578, 0.497568, 0.839402, + 0.406009, 0.360577, 0.329586, 0.124685, 0.220241, 0.193253, 0.021986, + 0.045634, 0.310560, 0.627288, 0.135303, 0.123128, 0.634158, 0.663792, + 0.171777, 0.174946, 0.112923, 0.160958, 0.158806, 0.624911, 0.534364, + 0.102259, 0.959418, 0.656056, 0.965187, 0.405249, 0.569249, 0.088240, + 0.135827, 0.066817, 0.927642, 0.541836, 0.427393, 0.257229, 0.666520, + 0.647634, 0.450481, 0.688506, 0.693269, 0.761042, 0.315794, 0.828572, + 0.884170, 0.949952, 0.492364, 0.055947, 0.124898, 0.605288, 0.216905, + 0.283705, 0.230199, 0.751269, 0.385963, 0.189616, 0.407326, 0.351151, + 0.594865, 0.976575, 0.439391, 0.730692, 0.043392, 0.367033, 0.272527, + 0.470785, 0.624261, 0.939048, 0.118419, 0.074743, 0.627554, 0.811688, + 0.835784, 0.943348, 0.640260, 0.719954, 0.893300, 0.132625, 0.775901, + 0.018199, 0.737913, 0.992806, 0.301903, 0.968111, 0.744076, 0.687867, + 0.157728, 0.151401, 0.039017, 0.752593, 0.127976, 0.478408, 0.483284, + 0.171368, 0.845441, 0.755811, 0.642153, 0.469702, 0.694859, 0.760572, + 0.544445, 0.322413, 0.572260, 0.380229, 0.265761, 0.212521, 0.100183, + 0.159062, 0.345146, 0.876084, 0.177261, 0.083058, 0.868891, 0.479164, + 0.051169, 0.612966, 0.167030, 0.208897, 0.764367, 0.206048, 0.961490, + 0.892343, 0.684456, 0.444774, 0.063711, 0.529896, 0.200585, 0.705863, + 0.999598, 0.895444, 0.466435, 0.544043, 0.217857, 0.038696, 0.924272, + 0.483618, 0.251217, 0.024455, 0.642680, 0.596362, 0.900539, 0.819941, + 0.679420, 0.769430, 0.299105, 0.730590, 0.382396, 0.466135, 0.939487, + 0.146763, 0.672183, 0.900977, 0.039106, 0.356638, 0.345750, 0.102817, + 0.886535, 0.546336, 0.808681, 0.886133, 0.441780, 0.275116, 0.430176, + 0.659637, 0.313812, 0.354448, 0.143255, 0.565028, 0.378903, 0.785935, + 0.161391, 0.279443, 0.605876, 0.840811, 0.048873, 0.904980, 0.571401, + 0.431269, 0.371115, 0.510887, 0.578032, 0.043298, 0.411864, 0.617138, + 0.399936, 0.757614, 0.719955, 0.286471, 0.303950, 0.528636, 0.172604, + 0.745730, 0.803752, 0.602780, 0.405367, 0.117564, 0.957228, 0.548622, + 0.682592, 0.336131, 0.334557, 0.843983, 0.615574, 0.940433, 0.684794, + 0.664447, 0.845413, 0.256194, 0.095715, 0.216529, 0.767082, 0.673747, + 0.259827, 0.178946, 0.290885, 0.659763, 0.936560, 0.010840, 0.946234, + 0.240510, 0.539476, 0.118838, 0.986240, 0.343228, 0.721618, 0.391606, + 0.460792, 0.678846, 0.940228, 0.143384, 0.014977, 0.274785, 0.987367, + 0.630551, 0.215218, 0.672161, 0.294998, 0.060631, 0.928355, 0.390713, + 0.277160, 0.695436, 0.064460, 0.536987, 0.874382, 0.355345, 0.196751, + 0.810942, 0.366185, 0.142985, 0.051452, 0.905661, 0.261823, 0.037691, + 0.248889, 0.983441, 0.429297, 0.709681, 0.662286, 0.369525, 0.853066, + 0.677263, 0.644310, 0.840433, 0.307814, 0.859528, 0.512593, 0.602812, + 0.920160, 0.440948, 0.993525, 0.197320, 0.136384, 0.057984, 0.734307, + 0.010766, 0.413329, 0.931058, 0.821707, 0.779514, 0.074043, 0.873159, + 0.685175, 0.335865, 0.910850, 0.934065, 0.319306, 0.340147, 0.643746, + 0.981592, 0.709673, 0.496812, 0.658856, 0.353983, 0.337245, 0.966670, + 0.213511, 0.849838, 0.569482, 0.133671, 0.290786, 0.563007, 0.330991, + 0.427170, 0.620991, 0.065299, 0.437936, 0.034320, 0.996356, 0.259643, + 0.813834, 0.070399, 0.132802, 0.499009, 0.406265, 0.043652, 0.433074, + 0.725570, 0.383800, 0.076820, 0.707163, 0.093473, 0.573632, 0.366018, + 0.447456, 0.910877, 0.332688, 0.660967, 0.760714, 0.902170, 0.794638, + 0.051500, 0.465177, 0.125630, 0.478670, 0.086168, 0.190928, 0.916605, + 0.120488, 0.187285, 0.176248, 0.934322, 0.257684, 0.309050, 0.433331, + 0.663949, 0.352703, 0.866405, 0.389519, 0.736502, 0.943226, 0.096682, + 0.829975, 0.516858, 0.462700, 0.277430, 0.427734, 0.795388, 0.938398, + 0.188449, 0.697558, 0.733036, 0.239948, 0.162735, 0.858666, 0.718618, + 0.248903, 0.049594, 0.635223, 0.369391, 0.236879, 0.811472, 0.303713, + 0.494563, 0.120522, 0.737044, 0.158511, 0.473225, 0.603450, 0.548030, + 0.209727, 0.546675, 0.644712, 0.039702, 0.063533, 0.107412, 0.317132, + 0.491267, 0.902800, 0.255530, 0.679716, 0.600359, 0.988566, 0.919664, + 0.763094, 0.847232, 0.638283, 0.011997, 0.896825, 0.273506, 0.381388, + 0.133704, 0.084978, 0.685101, 0.628267, 0.205500, 0.422145, 0.786778, + 0.678725, 0.025595, 0.334808, 0.888452, 0.572271, 0.979520, 0.928154, + 0.635804, 0.086932, 0.245286, 0.127071, 0.989732, 0.500816, 0.806787, + 0.590091, 0.489382, 0.726451, 0.353185, 0.336614, 0.364734, 0.365182, + 0.233439, 0.638240, 0.746570, 0.367143, 0.723218, 0.431671, 0.995410, + 0.928718, 0.853816, 0.782188, 0.607442, 0.879411, 0.116995, 0.495894, + 0.451682, 0.096515, 0.424048, 0.087485, 0.183447, 0.669334, 0.214556, + 0.173179, 0.170151, 0.021343, 0.763269, 0.659533, 0.747794, 0.116454, + 0.996147, 0.112528, 0.481635, 0.229586, 0.750768, 0.228205, 0.596730, + 0.473985, 0.659876, 0.592139, 0.402703, 0.513692, 0.374327, 0.010145, + 0.393103, 0.491322, 0.506039, 0.844785, 0.587837, 0.930088, 0.932270, + 0.771284, 0.599422, 0.146826, 0.944463, 0.769573, 0.168169, 0.707732, + 0.429106, 0.915964, 0.824186, 0.425253, 0.028492, 0.305821, 0.654839, + 0.779259, 0.534026, 0.251569, 0.253245, 0.193901, 0.843708, 0.655947, + 0.707593, 0.218035, 0.666093, 0.100696, 0.709357, 0.172132, 0.945481, + 0.297195, 0.102220, 0.877751, 0.068479, 0.701642, 0.024577, 0.012941, + 0.471215, 0.192747, 0.720673, 0.900321, 0.108710, 0.544859, 0.325574, + 0.137202, 0.850679, 0.980413, 0.916462, 0.384705, 0.231982, 0.169706, + 0.578607, 0.075690, 0.825654, 0.286200, 0.293725, 0.491746, 0.386896, + 0.003083, 0.663878, 0.332377, 0.300278, 0.766098, 0.210128, 0.368756, + 0.467740, 0.234705, 0.381697, 0.938955, 0.427451, 0.102370, 0.839275, + 0.536162, 0.647229, 0.164849, 0.673364, 0.497908, 0.145262, 0.589825, + 0.882613, 0.377244, 0.759532, 0.461220, 0.452934, 0.585185, 0.747420, + 0.746660, 0.076932, 0.134316, 0.749743, 0.740810, 0.466692, 0.050020, + 0.506908, 0.676820, 0.418776, 0.974648, 0.911525, 0.800474, 0.913602, + 0.338976, 0.902844, 0.752878, 0.875138, 0.550072, 0.917727, 0.548502, + 0.047981, 0.062989, 0.138327, 0.930594, 0.440233, 0.897859, 0.391814, + 0.893168, 0.483044, 0.139234, 0.639828, 0.559975, 0.273549, 0.389570, + 0.300785, 0.740242, 0.439590, 0.807693, 0.417062, 0.858367, 0.782341, + 0.328586, 0.658840, 0.695943, 0.667562, 0.561684, 0.448821, 0.542700, + 0.111756, 0.366548, 0.091202, 0.159737, 0.429537, 0.229529, 0.090331, + 0.869770, 0.127388, 0.482145, 0.762938, 0.610432, 0.621379, 0.402765, + 0.170407, 0.894928, 0.792336, 0.471192, 0.635170, 0.231926, 0.278886, + 0.052232, 0.090293, 0.061226, 0.380818, 0.749133, 0.757170, 0.048380, + 0.310817, 0.205990, 0.591080, 0.422573, 0.572538, 0.682282, 0.582310, + 0.002075, 0.911812, 0.672641, 0.871845, 0.039199, 0.154786, 0.634783, + 0.649631, 0.776165, 0.037548, 0.820038, 0.671093, 0.829884, 0.291231, + 0.306263, 0.061810, 0.570116, 0.358495, 0.152103, 0.631343, 0.739313, + 0.901236, 0.388512, 0.787693, 0.212053, 0.594503, 0.378773, 0.634626, + 0.167040, 0.061056, 0.216937, 0.169115, 0.972867, 0.889578, 0.040960, + 0.012067, 0.044364, 0.675743, 0.661698, 0.820529, 0.713291, 0.481736, + 0.491623, 0.543175, 0.772966, 0.797886, 0.604985, 0.343083, 0.156380, + 0.757088, 0.974425, 0.895693, 0.658324, 0.362938, 0.683386, 0.870376, + 0.957440, 0.062159, 0.505002, 0.124481, 0.123215, 0.721939, 0.293596, + 0.096082, 0.611517, 0.334556, 0.108149, 0.655881, 0.010299, 0.769846, + 0.476411, 0.723590, 0.251582, 0.968033, 0.266765, 0.024548, 0.765919, + 0.871750, 0.367631, 0.922299, 0.628838, 0.342056, 0.817992, 0.287162, + 0.704994, 0.501378, 0.157538, 0.662434, 0.563537, 0.662541, 0.786915, + 0.686752, 0.384480, 0.080511, 0.782834, 0.995997, 0.415067, 0.890983, + 0.651878, 0.425365, 0.660829, 0.128289, 0.148956, 0.912411, 0.096322, + 0.415721, 0.936959, 0.862241, 0.287471, 0.304590, 0.784540, 0.916309, + 0.646646, 0.602533, 0.203471, 0.351640, 0.103911, 0.361009, 0.014074, + 0.667448, 0.023550, 0.800989, 0.354200, 0.408030, 0.881500, 0.137034, + 0.404026, 0.296566, 0.028017, 0.055904, 0.721932, 0.688846, 0.184193, + 0.870887, 0.601257, 0.280515, 0.286608, 0.538216, 0.142755, 0.574079, + 0.842806, 0.927296, 0.490388, 0.489452, 0.529828, 0.693859, 0.841092, + 0.633739, 0.054869, 0.855167, 0.301187, 0.078419, 0.656156, 0.655388, + 0.486448, 0.537656, 0.792422, 0.890475, 0.834222, 0.820439, 0.946379, + 0.556153, 0.509285, 0.130571, 0.427041, 0.110542, 0.411086, 0.713648, + 0.648758, 0.553842, 0.287727, 0.491563, 0.481137, 0.778116, 0.981015, + 0.010966, 0.471975, 0.822107, 0.644705, 0.526844, 0.677274, 0.945892, + 0.605263, 0.333430, 0.601280, 0.091711, 0.871086, 0.393702, 0.982186, + 0.705307, 0.214141, 0.928564, 0.261461, 0.723426, 0.059136, 0.688501, + 0.833968, 0.470222, 0.402150, 0.482725, 0.024063, 0.689877, 0.974289, + 0.505201, 0.467993, 0.955304, 0.516166, 0.939968, 0.777411, 0.160871, + 0.466812, 0.454685, 0.106763, 0.072075, 0.788115, 0.708043, 0.163786, + 0.659201, 0.101744, 0.145971, 0.364508, 0.315885, 0.074536, 0.625969, + 0.039311, 0.133672, 0.314471, 0.873279, 0.603893, 0.716620, 0.356004, + 0.627957, 0.406498, 0.330292, 0.133157, 0.874490, 0.285596, 0.649324, + 0.814458, 0.063007, 0.810195, 0.281270, 0.517693, 0.916958, 0.353345, + 0.305808, 0.625000, 0.517131, 0.965009, 0.726745, 0.663102, 0.329518, + 0.042630, 0.737638, 0.955487, 0.081940, 0.871310, 0.269957, 0.955219, + 0.475203, 0.986578, 0.311223, 0.103160, 0.393075, 0.641515, 0.236317, + 0.267566, 0.927112, 0.885641, 0.082024, 0.990119, 0.695835, 0.363295, + 0.507812, 0.612793, 0.716640, 0.813620, 0.237793, 0.233770, 0.778629, + 0.964538, 0.896872, 0.108147, 0.007167, 0.634510, 0.063633, 0.089108, + 0.505820, 0.333591, 0.044327, 0.981023, 0.320168, 0.355550, 0.084182, + 0.713244, 0.997065, 0.320499, 0.980810, 0.924177, 0.206140, 0.062834, + 0.914296, 0.901975, 0.426129, 0.422107, 0.514768, 0.142768, 0.235727, + 0.752561, 0.376539, 0.014356, 0.717099, 0.273411, 0.122502, 0.724266, + 0.907921, 0.186136, 0.813374, 0.413741, 0.519726, 0.857701, 0.394764, + 0.839895, 0.213251, 0.478946, 0.553139, 0.210317, 0.799446, 0.533948, + 0.134493, 0.005586, 0.596782, 0.048789, 0.907561, 0.022911, 0.470896, + 0.422329, 0.165679, 0.706623, 0.174890, 0.542218, 0.720979, 0.891989, + 0.815629, 0.843481, 0.616255, 0.723551, 0.029617, 0.429630, 0.137292, + 0.549343, 0.287331, 0.532056, 0.389238, 0.500583, 0.011002, 0.942377, + 0.710899, 0.810448, 0.476326, 0.845392, 0.816033, 0.073108, 0.894181, + 0.723594, 0.096019, 0.365077, 0.145923, 0.261699, 0.071700, 0.320813, + 0.803917, 0.792679, 0.212802, 0.619546, 0.636160, 0.829057, 0.343096, + 0.665777, 0.258687, 0.480388, 0.215121, 0.546018, 0.012444, 0.604359, + 0.046601, 0.023446, 0.546736, 0.757500, 0.833893, 0.023062, 0.602892, + 0.649927, 0.096170, 0.497074, 0.373521, 0.192189, 0.862151, 0.519444, + 0.453887, 0.933851, 0.840257, 0.257804, 0.726531, 0.053058, 0.877350, + 0.362691, 0.882115, 0.220446, 0.028468, 0.140802, 0.700834, 0.243589, + 0.686821, 0.713278, 0.847948, 0.733421, 0.736723, 0.394684, 0.490921, + 0.570617, 0.417746, 0.093813, 0.220543, 0.513916, 0.590887, 0.594064, + 0.706105, 0.453038, 0.113508, 0.159992, 0.386889, 0.953765, 0.417796, + 0.113420, 0.006823, 0.295146, 0.476111, 0.888938, 0.515592, 0.504579, + 0.029741, 0.216426, 0.748168, 0.716561, 0.929703, 0.596117, 0.449982, + 0.666427, 0.990801, 0.940903, 0.237043, 0.408547, 0.034717, 0.457587, + 0.922463, 0.625603, 0.051651, 0.628568, 0.078641, 0.165159, 0.788560, + 0.465530, 0.118923, 0.206356, 0.578950, 0.125746, 0.501502, 0.055060, + 0.014685, 0.017094, 0.559640, 0.044425, 0.233519, 0.307808, 0.760986, + 0.163223, 0.903925, 0.210969, 0.829650, 0.894726, 0.151872, 0.066693, + 0.303273, 0.186589, 0.524279, 0.225736, 0.812192, 0.575930, 0.854304, + 0.890833, 0.741089, 0.642864, 0.356363, 0.860012, 0.849220, 0.935313, + 0.985758, 0.350722, 0.990373, 0.000443, 0.367815, 0.550013, 0.044868, + 0.601335, 0.857820, 0.805855, 0.764557, 0.761745, 0.016823, 0.594207, + 0.656471, 0.168696, 0.660900, 0.959744, 0.355284, 0.185179, 0.185480, + 0.167477, 0.761110, 0.039784, 0.058310, 0.502199, 0.682648, 0.414673, + 0.362211, 0.531868, 0.349985, 0.347969, 0.882589, 0.340358, 0.348412, + 0.250404, 0.890371, 0.393280, 0.851739, 0.748191, 0.199135, 0.616297, + 0.509936, 0.215958, 0.210504, 0.166407, 0.384654, 0.871404, 0.126151, + 0.739938, 0.056583, 0.311631, 0.907415, 0.817693, 0.351415, 0.965724, + 0.319891, 0.034062, 0.380397, 0.682102, 0.565930, 0.730382, 0.030072, + 0.448519, 0.070741, 0.378484, 0.698924, 0.961112, 0.771764, 0.550663, + 0.709303, 0.970899, 0.166959, 0.219239, 0.186857, 0.377463, 0.385647, + 0.571511, 0.248867, 0.511798, 0.311449, 0.305450, 0.823429, 0.218864, + 0.123142, 0.174844, 0.184588, 0.443034, 0.208906, 0.564986, 0.125136, + 0.774836, 0.295368, 0.155207, 0.223355, 0.366109, 0.533691, 0.922279, + 0.327221, 0.305455, 0.472942, 0.036524, 0.276354, 0.639901, 0.255763, + 0.463211, 0.017364, 0.641410, 0.034722, 0.266231, 0.153207, 0.346171, + 0.571680, 0.976636, 0.565036, 0.694822, 0.151480, 0.749624, 0.137856, + 0.360386, 0.314610, 0.262992, 0.135222, 0.609978, 0.418200, 0.358578, + 0.976087, 0.951891, 0.280856, 0.303307, 0.257346, 0.753798, 0.339831, + 0.533700, 0.393699, 0.595594, 0.996911, 0.411063, 0.237003, 0.031634, + 0.677294, 0.390211, 0.377805, 0.248974, 0.366847, 0.942841, 0.943796, + 0.518327, 0.692465, 0.081653, 0.878713, 0.007074, 0.344645, 0.013936, + 0.617052, 0.762845, 0.372513, 0.593138, 0.714736, 0.653370, 0.896446, + 0.972082, 0.407168, 0.236276, 0.505782, 0.800867, 0.831870, 0.502693, + 0.211930, 0.068873, 0.534327, 0.889224, 0.459084, 0.912132, 0.138197, + 0.825931, 0.854972, 0.081994, 0.344259, 0.547437, 0.163646, 0.222972, + 0.554511, 0.508291, 0.236908, 0.171563, 0.271135, 0.609421, 0.764701, + 0.985871, 0.262790, 0.661147, 0.957953, 0.669958, 0.897423, 0.463734, + 0.470825, 0.729293, 0.966427, 0.682755, 0.798166, 0.500754, 0.571978, + 0.257251, 0.412886, 0.710176, 0.083182, 0.267858, 0.792169, 0.427441, + 0.815295, 0.955815, 0.650413, 0.369805, 0.464106, 0.887320, 0.541368, + 0.735242, 0.496741, 0.306069, 0.721113, 0.759531, 0.967216, 0.679065, + 0.429489, 0.864639, 0.142799, 0.900314, 0.593932, 0.109227, 0.583069, + 0.392098, 0.609981, 0.155047, 0.649349, 0.022867, 0.865222, 0.732531, + 0.290725, 0.657392, 0.159972, 0.106019, 0.613207, 0.810384, 0.475824, + 0.077313, 0.697704, 0.017192, 0.812555}; + +static float golden_endtoend_output[] = { + -1.881211, -0.028385, -3.585066, 1.939770, -3.461155, 1.280415, -4.408978, + 0.608663, -2.704937, 1.859742, -5.777429, 2.691839, -1.049012, 1.640870, + -4.856245, 1.604236, 0.992707, 0.422858, -4.307465, 1.887332, -0.884831, + -0.154277, -2.634801, 0.586827, -1.849960, 1.399608, -4.531559, 1.943591, + 0.271676, -2.893054, -2.066826, 0.235467, -1.248263, -1.164534, -2.640174, + -0.112878, -4.386484, 1.253024, -4.135623, 1.068984, -0.043579, -0.832957, + -3.257258, -0.514396, -1.651174, 0.638630, -4.364372, 1.548441, -0.289455, + 0.539845, -4.097627, 0.635001, -0.465071, -0.927701, -2.481498, 0.356616, + -2.355012, 0.728806, -3.340283, 1.609038, -4.786268, -0.532272, -1.886150, + 0.254797, 0.746620, -1.657134, -3.264265, 0.525551, -1.756837, 0.845446, + -5.572190, 1.715797, -2.856942, 3.394245, -5.803662, 2.281806, -3.014739, + 2.616136, -4.728482, 1.659984, -2.106307, 2.711709, -6.173832, 1.352869, + -0.038035, 0.107619, -4.279774, 2.341930, -0.980413, -0.119538, -4.049717, + 1.172128, -3.477744, 2.602274, -6.231380, 2.537300, -0.862214, 0.568722, + -3.858362, 0.197867, -1.725885, 3.687312, -7.067363, 2.403544, -0.944963, + 0.235639, -3.250094, 0.659117, -1.459576, 0.426128, -3.637207, 1.030386, + -4.224351, 3.516220, -6.053367, 0.993473, -2.182416, -0.762625, -1.884405, + -0.113736, -2.572602, 0.329290, -1.913233, 0.517418, -0.019757, 0.203176, + -3.715881, 0.482136, -1.912823, 1.357907, -5.473043, 1.714658, -3.177160, + 0.089285, -3.127669, 1.268076, 0.772498, -1.622712, -3.850314, 0.436124, + -1.495983, 3.439982, -7.623405, 1.726721, -0.423979, 0.180201, -2.902406, + 0.986457, -1.845638, 0.460903, -5.359343, -1.133931, -1.074456, 0.717304, + -3.519856, 1.012126, -0.562301, 1.881967, -6.716627, 2.525036, 0.945480, + 0.337081, -5.210562, 2.572035, -0.943370, 0.442026, -2.666313, 0.411296, + 0.002787, -0.000735, -2.498933, 0.771719, -3.568153, 3.833721, -6.617026, + 2.813922, -0.573970, 1.025208, -3.909923, 1.722648, -1.406849, 0.719783, + -5.207438, 1.819442, -0.530895, -0.010887, -2.939614, 0.971225, -1.660297, + 1.345243, -4.454571, 2.244876, -2.021213, 1.756090, -4.880947, 0.364597, + -2.380270, 2.763117, -5.613013, 2.137534, 0.289101, -2.279400, -3.365582, + 0.170028, -1.142254, -0.709604, -3.656223, 1.804870, -0.854690, 0.592102, + -5.010415, 2.462687, -1.474710, 0.566002, -3.621819, -0.391946, -0.423524, + -0.631428, -3.513310, 0.962825, -1.480262, 0.319791, -3.610137, 1.842339, + -0.250073, 1.182022, -6.249267, 1.604172, 1.153759, -0.734054, -4.620415, + -0.030858, 0.050911, 1.524406, -4.724010, 1.451846, -3.277104, 2.414182, + -4.605285, 1.846092, -1.503047, -0.618200, -2.746546, -0.459332, -0.980326, + -1.199977, -2.043865, -0.165793, -2.214698, 3.108281, -7.127830, -0.123065, + 1.244948, -3.039923, -4.660061, -0.225957, -0.307210, -1.513205, -2.456005, + 0.840048, -0.741445, 2.328635, -6.015267, 2.723240, -1.381171, -0.728878, + -5.114925, -0.362034, -0.574923, 0.518080, -3.892457, 1.798948, 0.435119, + -0.371696, -2.807571, 1.302864, -2.063052, 1.036388, -4.232038, 1.397059, + -1.615668, -1.511019, -3.095508, 1.290955, -3.428723, 2.000287, -4.196487, + 1.566983, 0.196957, 0.224343, -4.926359, -0.691975, -0.214941, 1.546821, + -5.384868, 2.290820, -1.878865, 0.493692, -4.129823, 2.112036, 0.516558, + -2.553077, -2.717338, 0.017146, -2.016057, 1.628995, -4.240602, 1.189533, + -5.460220, 1.254738, -4.214903, 0.755659, -2.893235, 2.937762, -6.169453, + 2.035456, -5.613212, -0.122254, -1.973646, -0.060619, -2.119598, 1.413512, + -4.938738, 1.890244, 0.544169, -2.062413, -3.329637, -0.062515, -1.855805, + -0.791297, -2.570353, 0.607615, 0.305812, 0.338930, -4.150270, 2.274937, + 0.042653, 0.133825, -3.538155, 1.523639, -3.173690, -1.496599, -2.414655, + 0.464687, -1.448998, -0.368907, -3.520129, 0.203382, -2.443626, 1.266233, + -3.393848, 0.605911, -0.015353, 1.402006, -4.441003, 1.419281, 0.603587, + 0.434146, -4.966566, 2.171872, -0.688264, -0.009981, -4.461103, 1.538354, + -5.029816, -0.264424, -1.713510, -0.315258, -1.891606, 0.252074, -2.419428, + 0.043970, -1.291143, 2.048704, -4.590105, 0.524734, -1.889576, 0.134836, + -3.462745, 1.390663, -0.112773, 0.402735, -4.203784, 1.381043, -1.201634, + -1.968277, -1.425637, -0.181725, -1.250742, -2.102041, -3.925464, -1.256797, + -3.701354, -1.754610, -1.917231, -1.455910, -1.838006, 2.041781, -5.666212, + 2.752957, -2.659553, 2.553637, -4.872212, 1.443437, -2.081846, 3.311263, + -5.912457, 1.871049, 0.196148, -0.307044, -4.024967, 2.149149, 0.361809, + 0.620415, -5.939984, 0.180672, -1.209180, -0.269122, -3.240285, 1.460315, + -1.040803, 1.125700, -6.060366, 0.887767, -3.214111, 1.314368, -3.026808, + 1.023640, -3.815175, 1.795642, -4.355603, 1.064454, -0.046472, 0.618463, + -5.941646, 2.861891, -2.852155, -0.990457, -2.624445, 1.794494, -1.176747, + -0.358159, -3.206776, 1.138721, -2.819523, -1.825522, -1.450902, -0.187312, + -0.808727, 0.636872, -4.120567, 1.192623, 0.810731, -1.768519, -3.699450, + 1.527116, -2.772720, 3.012835, -5.912736, 1.599365, -4.696381, 2.234591, + -4.139552, 1.061768, -1.880089, 3.596274, -7.006379, 2.382152, -3.158115, + 3.844430, -7.044156, 2.307596, -2.473970, 1.312644, -5.467269, 0.197154, + -1.530040, 1.762275, -5.550757, 0.630276, -3.048947, 1.043777, -3.096658, + 1.345893, -1.329494, 2.065748, -4.711032, 2.227600, -0.413321, -0.032428, + -4.599650, 1.668734, -4.351490, -0.200022, -2.359903, 0.021997, 0.116028, + 1.159718, -5.093972, -0.142951, -2.409895, 0.906133, -2.728812, 0.809932, + -2.597363, 0.494130, -2.357861, 0.369825, -2.165235, 1.148522, -3.130562, + 0.759034, 0.646335, -1.463660, -3.508299, 1.059679, -1.485465, 1.007319, + -4.340716, 1.789864, -1.590654, 1.612324, -4.452007, 2.389805, -5.200148, + -1.068398, -1.306923, -0.472408, -0.392165, -0.524996, -2.933478, 1.518430, + -1.287781, 0.113422, -3.020525, 1.338359, -0.105982, 0.936014, -4.132197, + 1.836807, -0.616589, -1.029716, -3.271347, 0.284889, -2.653359, 2.135829, + -4.643613, 1.627981, 0.287733, -2.017263, -2.776574, 1.184792, 1.004161, + -1.483019, -4.339290, -0.787322, 0.582420, 1.137839, -5.673941, -0.001862, + -1.219142, 0.532561, -4.457245, 1.826807, -3.343291, 3.034610, -6.179855, + 2.235917, -4.369989, 4.018128, -6.632714, 0.926585, -0.485469, 0.536073, + -4.179557, 1.489637, -0.521762, 1.636089, -6.137912, 1.500867, -4.086009, + 1.961372, -3.688977, 1.358220, -1.544034, 1.763837, -4.357567, 1.852201, + -2.018725, 1.046264, -6.211127, 1.609419, -0.118441, 1.602284, -6.242423, + 1.518578, -0.604078, 1.106613, -5.393445, 2.595629, 0.142712, -1.903953, + -2.821177, 0.032758, -0.009152, 0.184628, -4.227636, 2.046843, -2.240138, + 1.256176, -5.108516, -0.308447, -2.998571, 4.657396, -7.582112, 2.510951, + -3.535784, 1.704560, -5.068484, 1.318466, -3.058265, 3.073172, -6.998089, + 3.178849, -2.420286, 2.277806, -4.999528, 1.423890, -1.672914, 0.447460, + -4.088940, 1.351087, -1.051546, -0.417955, -4.042147, 1.604102, -1.700931, + 2.796663, -6.497579, 2.857974, -0.240828, 0.858001, -5.778933, 2.778508, + -0.406211, 1.300766, -5.073671, 2.089362, -0.201673, 1.588396, -6.000150, + 2.185055, -2.332125, 0.768216, -2.609184, 0.327277, -3.358943, -1.020736, + -2.389984, 0.315512, -0.561905, 1.948740, -6.408485, 2.231985, -0.603652, + 0.661829, -5.070386, -1.063058, -0.624796, 1.375772, -4.379606, 1.929358, + -1.047263, 0.739100, -5.217857, 2.127625, -5.025338, 0.650344, -2.068460, + 0.076936, -0.457505, -1.050984, -1.917765, 1.150908, 0.782625, 0.855595, + -5.321719, 0.787209, -0.460232, 1.106736, -5.552326, 2.801043, -0.360217, + -0.434432, -4.273378, 0.967556, -0.972652, 0.874811, -5.429918, -0.331039, + 0.115477, 0.111883, -5.418786, 1.240546, -1.842794, 0.505880, -3.676064, + -0.682369, 1.858984, -0.742566, -5.784060, 0.673239, -1.280398, 0.280842, + -4.848077, 2.214860, -0.785100, -0.588488, -2.438206, 0.786651, -1.568752, + 1.935400, -6.320256, 2.125338, -1.476457, -1.651941, -2.695734, 0.007338, + -3.280860, 2.310385, -5.319578, 1.890123, -0.775723, 0.630606, -4.321582, + 1.085521, -1.847371, 1.188521, -4.596577, 2.056443, -2.340172, -0.108501, + -3.156392, 0.933279, -0.495331, 0.122405, -5.171133, 1.763245, -0.796913, + 2.310487, -7.247197, 2.401678, -1.908860, 0.043798, -2.393796, 0.573806, + -0.608531, 0.154710, -4.669001, 0.750680, 0.468380, 0.392591, -4.755001, + 2.615217, -1.957774, 1.153513, -4.530099, 1.124362, -3.569415, 1.697154, + -3.536335, 0.910758, -2.976264, 1.833129, -4.287203, -0.547050, -2.409768, + 0.061585, -1.324116, 0.268497, -2.962222, -1.524245, -2.063413, 0.442058, + -4.292337, 3.538863, -6.699603, 1.718664, -2.290363, 1.994596, -6.245037, + -0.433084, -0.367059, 1.020297, -4.940721, 2.902264, -0.577056, -0.709887, + -5.001413, -0.268316, -1.112048, -1.083307, -1.753492, 0.209973, 0.139540, + 0.917602, -5.232745, 2.538467, -2.139234, -0.187388, -1.837249, -0.478582, + -0.731653, -0.481550, -2.531261, 1.044770, 0.707750, 0.279971, -3.221119, + 1.552074, -2.373144, 0.859518, -3.665156, 1.620278, -1.440871, -0.525581, + -2.758271, 1.491873, -2.302013, 1.119935, -5.257080, 2.627170, -3.174739, + 1.363282, -4.831639, 1.101076, -4.337008, 2.689639, -5.165915, 1.069201, + -1.882078, -0.120370, -2.287967, 1.147619, -1.403616, 1.077150, -5.084296, + 1.658236, -0.919642, 0.487423, -3.001075, 0.741268, 0.107300, 0.943556, + -3.544311, 1.000239, -1.627171, 2.871253, -5.179172, 1.429893, -0.826040, + 0.188670, -4.499894, 1.013447, -2.101299, 0.317516, -3.452141, -0.833776, + -1.362144, 1.272437, -4.449355, 1.613591, -2.039873, 2.613175, -6.229640, + 1.659790, -1.595520, -0.237462, -2.744997, 0.337841, 0.148981, -1.703771, + -2.388023, 1.276469, 1.058508, -0.401642, -4.680769, 0.861881, -1.336381, + 1.153080, -2.834378, 0.721075, 0.900115, 1.360511, -5.573611, 0.949182, + -2.970844, 2.017563, -5.186108, -0.201038, -1.192824, 0.610142, -4.450919, + -0.897114, -1.812093, 0.422310, -5.245487, 0.256549, 0.320275, -2.324150, + -2.967040, -0.260536, -0.721467, 0.454148, -5.058031, 0.526370, -0.895656, + 0.732240, -3.327363, 1.353953, -1.277912, -0.483171, -1.926713, 0.065044, + -2.167506, -0.196606, -1.923437, 0.604962, -2.088319, 1.406834, -5.227296, + 2.247351, -4.421744, 1.729791, -5.007922, 1.264769, -0.897019, 0.922902, + -3.887108, 2.087432, -1.310226, -0.101938, -3.359082, -0.079662, -0.514988, + -0.963179, -4.038209, 2.223278, -0.590083, -2.310458, -1.748338, 0.363406, + -0.540731, -0.885913, -4.179595, 2.216781, -3.044339, -0.447100, -2.446098, + 0.931101, -1.676190, 2.096175, -4.980755, 2.262151, -1.095047, 1.897516, + -5.996138, 2.191038, 0.297128, -0.780974, -2.884299, 1.195408, -0.521065, + -1.955837, -3.091064, -0.404183, -1.961519, 4.076096, -7.521851, 2.242064, + -1.988043, 0.303300, -2.422585, 0.322230, -3.377634, 3.499955, -7.084434, + 2.375587, -0.718851, 2.150076, -5.412241, 2.374280, -2.006088, 2.229828, + -5.848188, 2.543077, -2.171042, 2.096026, -5.300007, 0.141405, -1.187745, + 0.105340, -4.003816, 1.034281, -3.980804, 1.856709, -5.103042, 0.623737, + -2.080307, 0.896140, -3.104050, 0.983158, -0.424898, -1.154270, -3.805728, + 1.978917, -1.314387, 1.235096, -3.148906, 1.113173, 0.111713, 2.055213, + -7.565283, 2.100342}; +constexpr std::initializer_list biases = { + 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568, + -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, + 0.37197268, 0.61957061, 0.3956964, -0.37609905}; + +constexpr std::initializer_list recurrent_weights = { + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}; + +class BidirectionalRNNOpModel : public SingleOpModel { + public: + BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units, + int bw_units, int input_size) + : batches_(batches), + sequence_len_(sequence_len), + fw_units_(fw_units), + bw_units_(bw_units), + input_size_(input_size) { + input_ = AddInput(TensorType_FLOAT32); + fw_weights_ = AddInput(TensorType_FLOAT32); + fw_recurrent_weights_ = AddInput(TensorType_FLOAT32); + fw_bias_ = AddInput(TensorType_FLOAT32); + fw_hidden_state_ = AddOutput(TensorType_FLOAT32); + fw_output_ = AddOutput(TensorType_FLOAT32); + bw_weights_ = AddInput(TensorType_FLOAT32); + bw_recurrent_weights_ = AddInput(TensorType_FLOAT32); + bw_bias_ = AddInput(TensorType_FLOAT32); + bw_hidden_state_ = AddOutput(TensorType_FLOAT32); + bw_output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOptions_SequenceRNNOptions, + CreateSequenceRNNOptions(builder_, /*time_major=*/false, + ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({ + {batches_, sequence_len_, input_size_}, // input + {fw_units_, input_size_}, // fw_weights + {fw_units_, fw_units_}, // fw_recurrent_weights + {fw_units_}, // fw_bias + {bw_units_, input_size_}, // bw_weights + {bw_units_, bw_units_}, // bw_recurrent_weights + {bw_units_} // bw_bias + }); + } + + void SetFwBias(std::initializer_list f) { + PopulateTensor(fw_bias_, f); + } + + void SetBwBias(std::initializer_list f) { + PopulateTensor(bw_bias_, f); + } + + void SetFwWeights(std::initializer_list f) { + PopulateTensor(fw_weights_, f); + } + + void SetBwWeights(std::initializer_list f) { + PopulateTensor(bw_weights_, f); + } + + void SetFwRecurrentWeights(std::initializer_list f) { + PopulateTensor(fw_recurrent_weights_, f); + } + + void SetBwRecurrentWeights(std::initializer_list f) { + PopulateTensor(bw_recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenStates() { + const int fw_zero_buffer_size = fw_units_ * batches_; + std::unique_ptr fw_zero_buffer(new float[fw_zero_buffer_size]); + memset(fw_zero_buffer.get(), 0, fw_zero_buffer_size * sizeof(float)); + PopulateTensor(fw_hidden_state_, 0, fw_zero_buffer.get(), + fw_zero_buffer.get() + fw_zero_buffer_size); + const int bw_zero_buffer_size = bw_units_ * batches_; + std::unique_ptr bw_zero_buffer(new float[bw_zero_buffer_size]); + memset(bw_zero_buffer.get(), 0, bw_zero_buffer_size * sizeof(float)); + PopulateTensor(bw_hidden_state_, 0, bw_zero_buffer.get(), + bw_zero_buffer.get() + bw_zero_buffer_size); + } + + std::vector GetFwOutput() { return ExtractVector(fw_output_); } + std::vector GetBwOutput() { return ExtractVector(bw_output_); } + + int input_size() { return input_size_; } + int num_fw_units() { return fw_units_; } + int num_bw_units() { return bw_units_; } + int num_batches() { return batches_; } + int sequence_len() { return sequence_len_; } + + private: + int input_; + int fw_weights_; + int fw_recurrent_weights_; + int fw_bias_; + int fw_hidden_state_; + int fw_output_; + int bw_weights_; + int bw_recurrent_weights_; + int bw_bias_; + int bw_hidden_state_; + int bw_output_; + + int batches_; + int sequence_len_; + int fw_units_; + int bw_units_; + int input_size_; +}; + +// TODO(mirkov): add another test which directly compares to TF once TOCO +// supports the conversion from dynamic_rnn with BasicRNNCell. +TEST(BidirectionalRNNOpTest, BlackBoxTest) { + BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8); + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + rnn.ResetHiddenStates(); + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + float* batch_start = rnn_input; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(input_sequence_size, batch_start, batch_end); + + rnn.Invoke(); + + float* golden_fw_start = rnn_golden_fw_output; + float* golden_fw_end = + golden_fw_start + rnn.num_fw_units() * rnn.sequence_len(); + std::vector fw_expected; + fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); + fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); + EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); + + float* golden_bw_start = rnn_golden_bw_output; + float* golden_bw_end = + golden_bw_start + rnn.num_bw_units() * rnn.sequence_len(); + std::vector bw_expected; + bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); + bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); + EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); +} + +// Check that if the input sequence is reversed the outputs are the same just +// forward and backward are swapped (and reversed). +TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { + BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8); + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + rnn.ResetHiddenStates(); + + // Reverse inputs in each batch: in_1, in_2,..., in_k is inserted in the + // following order: [in_k,..., in_2, in_1, in_k,...,in_2, in_1]. + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + const int reverse_idx = rnn.sequence_len() - i - 1; + rnn.SetInput(reverse_idx * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((rnn.sequence_len() + reverse_idx) * rnn.input_size(), + batch_start, batch_end); + } + + rnn.Invoke(); + + // The forward and backward outputs are swapped. + std::vector fw_expected; // consider using std::deque instead. + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_fw_start = rnn_golden_bw_output + i * rnn.num_fw_units(); + float* golden_fw_end = golden_fw_start + rnn.num_fw_units(); + fw_expected.insert(fw_expected.begin(), golden_fw_start, golden_fw_end); + } + fw_expected.insert(fw_expected.end(), fw_expected.begin(), fw_expected.end()); + EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); + + std::vector bw_expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_bw_start = rnn_golden_fw_output + i * rnn.num_bw_units(); + float* golden_bw_end = golden_bw_start + rnn.num_bw_units(); + bw_expected.insert(bw_expected.begin(), golden_bw_start, golden_bw_end); + } + bw_expected.insert(bw_expected.end(), bw_expected.begin(), bw_expected.end()); + EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); +} + +// Tests an end-to-end neural network with a Bidirectional RNN followed by a +// DNN that aggregates the outputs from the two sequences. +TEST(BidirectionalRNNOpTest, EndToEndTest) { + BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8); + const int output_size = 4; + float dnn_weights[] = { + -0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139, + -0.23420811, -0.39647382, 0.31423986, 0.61819065, -0.73659575, + -0.89698344, -0.8931554, -0.0845688, 0.5617367, 0.38415289, + -0.11487955, -0.7617774, 0.17927337, 0.15726972, 0.059798479, + 0.19009054, -0.27616632, -0.39142907, 0.77744663, -0.046830714, + -0.6603595, 0.21945822, 0.051494241, 0.23785079, 0.19239247, + -0.53268754, 0.65961659, -0.85981959, -0.80232513, 0.84745562, + -0.66070104, -0.036533296, -0.54901814, 0.65353882, -0.41834265, + -0.28561389, 0.75655544, -0.31149811, 0.62981737, 0.31829214, + -0.92734522, -0.48506218, 0.55651462, 0.25192821, 0.67220747, + -0.3836869, -0.55798125, -0.60395885, 0.22488403, -0.78053463, + 0.3492105, 0.56452453, 0.4389236, -0.59929526, -0.19762468, + -0.36868393, -0.13198286, -0.53800809, -0.22850353}; + + std::initializer_list dnn_biases = { + 0.29177809, -0.98799044, 0.065919638, 0.68781924}; + + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + rnn.ResetHiddenStates(); + + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + const int output_sequence_size = output_size * rnn.sequence_len(); + const int num_examples = 64; + for (int k = 0; k < num_examples; k++) { + float* batch_start = endtoend_input + k * input_sequence_size; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + + rnn.Invoke(); + + std::vector fw_output = rnn.GetFwOutput(); + std::vector bw_output = rnn.GetBwOutput(); + EXPECT_EQ(fw_output.size(), bw_output.size()); + + std::transform(fw_output.begin(), fw_output.end(), bw_output.begin(), + fw_output.begin(), std::plus()); + + std::vector sequence_result; + for (int s = 0; s < rnn.sequence_len(); s++) { + const float* rnn_output = fw_output.data() + s * rnn.num_fw_units(); + std::vector results(dnn_biases); + for (int i = 0; i < output_size; i++) { + for (int j = 0; j < rnn.num_fw_units(); j++) { + results[i] += *(rnn_output + j) * dnn_weights[output_size * j + i]; + } + } + sequence_result.insert(sequence_result.end(), results.begin(), + results.end()); + } + + float* golden_start = golden_endtoend_output + k * output_sequence_size; + float* golden_end = golden_start + output_sequence_size; + + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(sequence_result, ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index 9e7a1233dac0f3cd02dc386f9d194597f38ca3b8..a619ada86af64c299f8e518a7493db20f1011a50 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -49,6 +49,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // dimensions except 'axis' must be equal. TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]]; TfLiteType input_type = t0->type; + if (axis < 0) axis += t0->dims->size; TF_LITE_ENSURE(context, axis >= 0); TF_LITE_ENSURE(context, axis < t0->dims->size); @@ -95,53 +96,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } -template -class VectorOfInputs { - public: - VectorOfInputs(const TfLiteContext& context, const TfLiteIntArray& inputs) { - int num_inputs = inputs.size; - - all_data_.reserve(num_inputs); - all_dims_.reserve(num_inputs); - all_dims_ptr_.reserve(num_inputs); - - for (int i = 0; i < num_inputs; ++i) { - TfLiteTensor* input = &context.tensors[inputs.data[i]]; - all_data_.push_back(GetTensorData(input)); - all_dims_.push_back(GetTensorDims(input)); - } - - // Taking the pointer from inside a std::vector is only OK if the vector is - // never modified, so we populate all_dims in the previous loop and then we - // are free to grab iterators here. - for (int i = 0; i < num_inputs; ++i) { - all_dims_ptr_.push_back(&all_dims_[i]); - } - } - const T* const* data() const { return all_data_.data(); } - const Dims<4>* const* dims() const { return all_dims_ptr_.data(); } - - private: - std::vector all_data_; - std::vector> all_dims_; - std::vector*> all_dims_ptr_; -}; - template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - + int axis = params->axis; TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + if (axis < 0) axis += output->dims->size; // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should // allocate and populate these during Prepare(). // TODO(ycling): Activation function parameter is ignored. For now we dont have // a model with a Concatenation with fused activation function. #define TF_LITE_CONCATENATION(type, scalar) \ - VectorOfInputs all_inputs(*context, *node->inputs); \ + VectorOfTensors all_inputs(*context, *node->inputs); \ type::Concatenation( \ - RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \ + RemapDim(NumDimensions(output), axis), all_inputs.data(), \ all_inputs.dims(), node->inputs->size, GetTensorData(output), \ GetTensorDims(output)) diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc index 499856a93cbbfbf9aa1a326912e52ce32bbbdf83..ba1ffc5f8423b9626c9c8e2a1086ea0dcca43f50 100644 --- a/tensorflow/contrib/lite/kernels/concatenation_test.cc +++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc @@ -94,7 +94,7 @@ TEST(ConcatenationOpTest, TwoDimensionalOneInput) { EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TEST(ConcatenationOpTest, TwoInputsTwoAxis) { +TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxes) { // We will concatenate two tensors along different dimensions. auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; @@ -107,6 +107,14 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxis) { EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + ConcatenationOpModel m0_negative({TensorType_FLOAT32, {2, 3}}, /*axis=*/-2, + /*num_inputs=*/2); + m0_negative.SetInput(0, tensor0); + m0_negative.SetInput(1, tensor1); + m0_negative.Invoke(); + EXPECT_THAT(m0_negative.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + ConcatenationOpModel m1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1, /*num_inputs=*/2); m1.SetInput(0, tensor0); @@ -114,6 +122,14 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxis) { m1.Invoke(); EXPECT_THAT(m1.GetOutput(), ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); + + ConcatenationOpModel m1_negative({TensorType_FLOAT32, {2, 3}}, /*axis=*/-1, + /*num_inputs=*/2); + m1_negative.SetInput(0, tensor0); + m1_negative.SetInput(1, tensor1); + m1_negative.Invoke(); + EXPECT_THAT(m1_negative.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); } TEST(ConcatenationOpTest, FourInputs) { diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 37f499a4d09a38765aa4b8db8aa91b708edd7823..66d2c04bba4a164bbcdcf4b1a097d9aac0b3aeeb 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" @@ -38,11 +39,16 @@ namespace ops { namespace builtin { namespace conv { -// This file has three implementation of Conv. +// This file has 4 implementation of Conv. enum KernelType { kReference, kGenericOptimized, // Neon-free - kNeonOptimized, + kMultithreadOptimized, + // The kernel uses use CBLAS interface for matrix multiplication. + // It's fast when an optimized CBLAS implementation is available (e.g. Apple + // Accelerate Framework), and it's slow when falling back to naive + // implementation. + kCblasOptimized, }; struct OpData { @@ -265,10 +271,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { free(hwcn_weights->data.raw); hwcn_weights->data.raw = nullptr; } + + // Note that hwcn_weights_status is a kTfLiteDynamic tensor, and + // ResizeTensor will actually allocate space for it. The would be more + // efficient if we placed hwcn_weights_status in the persistent arena. auto hwcn_weights_status = context->ResizeTensor(context, hwcn_weights, hwcn_weights_size); if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status; - hwcn_weights->data.raw = static_cast(malloc(hwcn_weights->bytes)); // TODO(petewarden): If Resize() is called when the size hasn't actually // changed, this will do extra redundant work. @@ -290,26 +299,34 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, auto filter_offset = -filter->params.zero_point; auto output_offset = output->params.zero_point; - if (kernel_type == kReference) { - reference_ops::Conv( - GetTensorData(input), GetTensorDims(input), input_offset, - GetTensorData(filter), GetTensorDims(filter), filter_offset, - GetTensorData(bias), GetTensorDims(bias), params->stride_width, - params->stride_height, data->padding.width, data->padding.height, - output_offset, data->output_multiplier, data->output_shift, - data->output_activation_min, data->output_activation_max, - GetTensorData(output), GetTensorDims(output), - GetTensorData(im2col), GetTensorDims(im2col), gemm_context); - } else { - optimized_ops::Conv( - GetTensorData(input), GetTensorDims(input), input_offset, - GetTensorData(filter), GetTensorDims(filter), filter_offset, - GetTensorData(bias), GetTensorDims(bias), params->stride_width, - params->stride_height, data->padding.width, data->padding.height, - output_offset, data->output_multiplier, data->output_shift, - data->output_activation_min, data->output_activation_max, - GetTensorData(output), GetTensorDims(output), - GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + switch (kernel_type) { + case kReference: + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, output_offset, data->output_multiplier, + data->output_shift, data->output_activation_min, + data->output_activation_max, GetTensorData(output), + GetTensorDims(output), GetTensorData(im2col), + GetTensorDims(im2col), gemm_context); + break; + case kGenericOptimized: + case kMultithreadOptimized: + case kCblasOptimized: + // There is only one optimized implementation for Quantized Conv. + optimized_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, output_offset, data->output_multiplier, + data->output_shift, data->output_activation_min, + data->output_activation_max, GetTensorData(output), + GetTensorDims(output), GetTensorData(im2col), + GetTensorDims(im2col), gemm_context); + break; } } @@ -322,31 +339,57 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); - if (kernel_type == kReference) { - reference_ops::Conv(GetTensorData(input), GetTensorDims(input), - GetTensorData(filter), GetTensorDims(filter), - GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, - data->padding.width, data->padding.height, - output_activation_min, output_activation_max, - GetTensorData(output), GetTensorDims(output), - GetTensorData(im2col), GetTensorDims(im2col)); - } else { - const float* filter_data; - if (data->need_hwcn_weights) { - filter_data = GetTensorData(hwcn_weights); - } else { - filter_data = GetTensorData(filter); + switch (kernel_type) { + case kReference: { + reference_ops::Conv(GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, + data->padding.width, data->padding.height, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; + } + case kGenericOptimized: { + optimized_ops::Conv(GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, + data->padding.width, data->padding.height, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; + } + case kMultithreadOptimized: { + const float* filter_data; + if (data->need_hwcn_weights) { + filter_data = GetTensorData(hwcn_weights); + } else { + filter_data = GetTensorData(filter); + } + multithreaded_ops::Conv( + GetTensorData(input), GetTensorDims(input), filter_data, + GetTensorDims(filter), GetTensorData(bias), + GetTensorDims(bias), params->stride_width, params->stride_height, + data->padding.width, data->padding.height, params->padding, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; + } + case kCblasOptimized: { + cblas_ops::Conv(GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, + data->padding.width, data->padding.height, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; } - - multithreaded_ops::Conv( - GetTensorData(input), GetTensorDims(input), filter_data, - GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, data->padding.width, - data->padding.height, params->padding, output_activation_min, - output_activation_max, GetTensorData(output), - GetTensorDims(output), GetTensorData(im2col), - GetTensorDims(im2col)); } } @@ -407,17 +450,23 @@ TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() { return &r; } -TfLiteRegistration* Register_CONVOLUTION_NEON_OPT() { +TfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT() { static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, - conv::Eval}; + conv::Eval}; return &r; } TfLiteRegistration* Register_CONV_2D() { -#ifdef USE_NEON - return Register_CONVOLUTION_NEON_OPT(); +#ifdef TFLITE_USE_APPLE_ACCELERATE_FOR_CONV + return Register_CONVOLUTION_CBLAS_OPT(); #else - return Register_CONVOLUTION_GENERIC_OPT(); + return Register_CONVOLUTION_MULTITHREADED_OPT(); #endif } diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc index 1d0a81c3135625c07a3566f5f9a8e5401f0d4db7..d2393c3c97bb9516e2b8a6c8ae037dc0dfdfe64b 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -15,12 +15,25 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" #include "tensorflow/contrib/lite/model.h" namespace tflite { + +namespace ops { +namespace builtin { + +TfLiteRegistration* Register_CONVOLUTION_REF(); +TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT(); +TfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT(); +TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT(); + +} // namespace builtin +} // namespace ops + namespace { using ::testing::ElementsAreArray; @@ -30,9 +43,9 @@ class BaseConvolutionOpModel : public SingleOpModel { // TODO(ahentz): Also test different activation types, bias, padding types, // stride values. BaseConvolutionOpModel( - const TensorData& input, const TensorData& filter, - const TensorData& output, int stride_width = 2, int stride_height = 2, - enum Padding padding = Padding_VALID, + TfLiteRegistration* registration, const TensorData& input, + const TensorData& filter, const TensorData& output, int stride_width = 2, + int stride_height = 2, enum Padding padding = Padding_VALID, enum ActivationFunctionType activation = ActivationFunctionType_NONE) { input_ = AddInput(input); filter_ = AddInput(filter); @@ -62,6 +75,8 @@ class BaseConvolutionOpModel : public SingleOpModel { stride_height, activation) .Union()); + resolver_ = absl::make_unique(BuiltinOperator_CONV_2D, + registration); BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); } @@ -83,12 +98,26 @@ class ConvolutionOpModel : public BaseConvolutionOpModel { void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } - std::vector GetOutput() { return ExtractVector(output_); } }; -TEST(ConvolutionOpTest, SimpleTestFloat32) { - ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, +const auto kKernelMap = new std::map({ + {"Reference", ops::builtin::Register_CONVOLUTION_REF()}, + {"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()}, + {"MultithreadedOptimized", + ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()}, + {"CblasOptimized", ops::builtin::Register_CONVOLUTION_CBLAS_OPT()}, +}); + +class ConvolutionOpTest : public SingleOpTest { + protected: + const std::map& GetKernelMap() override { + return *kKernelMap; + } +}; + +TEST_P(ConvolutionOpTest, SimpleTestFloat32) { + ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}}, {TensorType_FLOAT32, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}}); @@ -117,8 +146,8 @@ TEST(ConvolutionOpTest, SimpleTestFloat32) { })); } -TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { - ConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 6, 1}}, +TEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { + ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, /*stride_width=*/3, /*stride_height=*/1); @@ -139,7 +168,7 @@ TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { })); } -TEST(ConvolutionOpTest, HandCalculatedFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -150,6 +179,7 @@ TEST(ConvolutionOpTest, HandCalculatedFloat32) { const int stride_height = 1; const Padding padding = Padding_SAME; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -192,7 +222,7 @@ TEST(ConvolutionOpTest, HandCalculatedFloat32) { 178, 187, 234, 261, 121})); } -TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -203,6 +233,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { const int stride_height = 1; const Padding padding = Padding_SAME; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -245,7 +276,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { 367, 188, 197, 244, 271, 131})); } -TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedWithReluFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -256,6 +287,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { const int stride_height = 1; const Padding padding = Padding_SAME; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -300,7 +332,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0})); } -TEST(ConvolutionOpTest, HandCalculatedValidFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedValidFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -311,6 +343,7 @@ TEST(ConvolutionOpTest, HandCalculatedValidFloat32) { const int stride_height = 1; const Padding padding = Padding_VALID; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -366,8 +399,9 @@ class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { // In this tests we set the input and output scales so that the results // match exactly the 'non-quantized' version. -TEST(ConvolutionOpTest, SimpleTestQuantized) { - QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, +TEST_P(ConvolutionOpTest, SimpleTestQuantized) { + QuantizedConvolutionOpModel m(GetRegistration(), + {TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, {TensorType_UINT8, {}, -127, 128}); m.SetInput({ @@ -405,8 +439,9 @@ TEST(ConvolutionOpTest, SimpleTestQuantized) { })); } -TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { - QuantizedConvolutionOpModel m({TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, +TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { + QuantizedConvolutionOpModel m(GetRegistration(), + {TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64}, {TensorType_UINT8, {}, -127, 128}, /*stride_width=*/3, /*stride_height=*/1); @@ -430,6 +465,11 @@ TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { 167, 93, // })); } + +INSTANTIATE_TEST_CASE_P( + ConvolutionOpTest, ConvolutionOpTest, + ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc index dcdc5fffad9ceac1a9d23a4e91637a9ff92a8dda..ef2b5422253ea880a9ded4d3c0efc5cec07178a9 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc @@ -123,18 +123,16 @@ TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) { [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); m.Invoke(); - EXPECT_THAT( - m.GetOutput(), - ElementsAreArray(ArrayFloatNear({ - 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 - 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - - 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f), - 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f), - 7.20f / std::sqrt(20.0f), - 7.26f / - std::sqrt( - 20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * Row 3 + 4 * Row 0 - }))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f), + 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f), + 7.20f / std::sqrt(20.0f), + 7.26f / std::sqrt(20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * + // Row 3 + 4 * Row 0 + }))); } TEST(EmbeddingLookupOpTest, Indices3DTest) { diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9e79b742dc2c80ce4ed9a3aa786814265dcb660 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/exp.cc @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace exp { + +// This file has reference implementation of Exp. +enum KernelType { + kReference, +}; + +struct ExpContext { + ExpContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + ExpContext op_context(context, node); + TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims); + op_context.output->type = op_context.input->type; + return context->ResizeTensor(context, op_context.output, output_dims); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + ExpContext op_context(context, node); + +#define TF_LITE_EXP(kernel_type, data_type) \ + kernel_type::Exp(GetTensorData(op_context.input), \ + NumElements(op_context.input), \ + GetTensorData(op_context.output)) + + // TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128. + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_EXP(reference_ops, float); + break; + default: + context->ReportError(context, + "Type %d is currently not supported by Exp.", + op_context.input->type); + return kTfLiteError; + } + } +#undef TF_LITE_EXP + return kTfLiteOk; +} + +} // namespace exp + +TfLiteRegistration* Register_EXP_REF() { + static TfLiteRegistration r = {nullptr, nullptr, exp::Prepare, + exp::Eval}; + return &r; +} + +// TODO(kanlig): add optimized implementation of Exp. +TfLiteRegistration* Register_EXP() { return Register_EXP_REF(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/exp_test.cc b/tensorflow/contrib/lite/kernels/exp_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..eed67369a1f30e57cd29a3975a899db41938def0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/exp_test.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ExpOpModel : public SingleOpModel { + public: + ExpOpModel(const TensorData& input, const TensorType& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_EXP, BuiltinOptions_ExpOptions, + CreateExpOptions(builder_).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int output_; +}; + +TEST(ExpOpTest, FloatTest) { + std::initializer_list data = {1.0, 0.0, -1.0, 1.0, 1.0, -1.0}; + ExpOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {2.71828, 1, 0.367879, 2.71828, 2.71828, 0.367879}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc index 658d977b8dc7fffcdde69d74ba2564dfa1b5709e..cdadbeda1884ba0186846826dd16be6ff69878d9 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -81,10 +81,8 @@ TEST(GatherOpTest, Test0DIndex) { m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); m.SetPositions({1}); m.Invoke(); - EXPECT_THAT(m.GetOutputFloat(), - ElementsAreArray(ArrayFloatNear({0.7, 0.8}))); - EXPECT_THAT(m.GetOutputShape(), - ElementsAreArray({2})); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({0.7, 0.8}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); } TEST(GatherOpTest, Test0DIndexWith0DResult) { @@ -94,8 +92,7 @@ TEST(GatherOpTest, Test0DIndexWith0DResult) { m.SetInputFloat({1.0, 2.0, 3.0}); m.SetPositions({1}); m.Invoke(); - EXPECT_THAT(m.GetOutputFloat(), - ElementsAreArray(ArrayFloatNear({2.0}))); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({2.0}))); EXPECT_TRUE(m.GetOutputShape().empty()); } diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc index cb6038f9009a3865661e7b4f075c3033166d0f91..ba0ed5ce06392613238b757308dddc2b22e7eb30 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc @@ -116,7 +116,10 @@ TEST(HashtableLookupOpTest, Test2DInput) { 1.0, 1.1, // 1-st item }))); EXPECT_THAT(m.GetHit(), ElementsAreArray({ - 1, 0, 1, 1, + 1, + 0, + 1, + 1, })); } diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 21118fc96d804654a33d5c693d496b05e2dc59d2..f47fb04cbaa688b75e763ff9d3cb7df44ac3f166 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -124,6 +124,13 @@ config_setting( }, ) +config_setting( + name = "darwin_x86_64", + values = { + "cpu": "darwin_x86_64", + }, +) + config_setting( name = "freebsd", values = { @@ -154,6 +161,7 @@ cc_library( ":x86": tflite_deps_intel, ":x86_64": tflite_deps_intel, ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, ":freebsd": tflite_deps_intel, "//conditions:default": [], }), @@ -162,6 +170,8 @@ cc_library( cc_library( name = "optimized", hdrs = [ + "optimized/cblas_conv.h", + "optimized/cblas_reference.h", "optimized/eigen_spatial_convolutions.h", "optimized/eigen_tensor_reduced_instantiations_oss.h", "optimized/multithreaded_conv.h", @@ -232,6 +242,7 @@ cc_library( ":x86": tflite_deps_intel, ":x86_64": tflite_deps_intel, ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, ":freebsd": tflite_deps_intel, "//conditions:default": [], }), @@ -267,6 +278,8 @@ cc_library( "optimized/neon_tensor_utils.cc", ], hdrs = [ + "common.h", + "optimized/cpu_check.h", "optimized/neon_tensor_utils.h", "optimized/tensor_utils_impl.h", ], @@ -274,8 +287,21 @@ cc_library( deps = [ ":cpu_check", ":portable_tensor_utils", + ":types", "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite/kernels:activation_functor", + "@arm_neon_2_x86_sse", + "@gemmlowp", + ], +) + +cc_library( + name = "kernel_utils", + srcs = ["kernel_utils.cc"], + hdrs = ["kernel_utils.h"], + deps = [ + ":tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", ], ) @@ -285,14 +311,21 @@ cc_library( "tensor_utils.cc", ], hdrs = [ + "common.h", + "compatibility.h", + "optimized/cpu_check.h", + "optimized/neon_tensor_utils.h", "optimized/tensor_utils_impl.h", "reference/portable_tensor_utils.h", "tensor_utils.h", + "types.h", ], copts = NEON_FLAGS_IF_APPLICABLE, deps = [ "//tensorflow/contrib/lite/kernels:activation_functor", "//tensorflow/contrib/lite:builtin_op_data", + "@arm_neon_2_x86_sse", + "@gemmlowp", ] + select({ ":arm": [ ":neon_tensor_utils", @@ -312,6 +345,21 @@ cc_library( ":ios_arm64": [ ":neon_tensor_utils", ], + ":ios_x86_64": [ + ":neon_tensor_utils", + ], + ":x86_64": [ + ":neon_tensor_utils", + ], + ":x86": [ + ":neon_tensor_utils", + ], + ":k8": [ + ":neon_tensor_utils", + ], + ":darwin": [ + ":neon_tensor_utils", + ], "//conditions:default": [ ":portable_tensor_utils", ], diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index 4d7a3a4e98497653071f9bbee464bc05a8e821e5..e126774ebc866fee7476e2ba98a5863085965215 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -104,6 +104,17 @@ inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( quantized_multiplier); } +inline int32 MultiplyByQuantizedMultiplier(int32 x, int32 quantized_multiplier, + int shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 0 : -shift; + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + x * (1 << left_shift), quantized_multiplier), + right_shift); +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h index 1d963afb7e1ce414f251f090208923ca0c68cee1..51426bb1c584b82af7b1a2ffaf5a675a1dd9a6fd 100644 --- a/tensorflow/contrib/lite/kernels/internal/compatibility.h +++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h @@ -27,6 +27,10 @@ limitations under the License. #define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false) #endif +#ifndef TFLITE_DCHECK_NE +#define TFLITE_DCHECK_NE(x, y) ((x) != (y)) ? (void)0 : assert(false) +#endif + #ifndef TFLITE_DCHECK_GE #define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false) #endif @@ -52,6 +56,10 @@ limitations under the License. #define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort() #endif +#ifndef TFLITE_CHECK_NE +#define TFLITE_CHECK_NE(x, y) ((x) != (y)) ? (void)0 : abort() +#endif + #ifndef TFLITE_CHECK_GE #define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort() #endif diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..510395126ce3785b1d44fec1e0eb994c29ff0db7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" + +namespace tflite { +namespace kernel_utils { + +void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, + const float* recurrent_weights_ptr, const float* bias_ptr, + int input_size, int num_units, int batch_size, + TfLiteFusedActivation activation, + float* hidden_state_ptr_batch, float* output_ptr_batch) { + // Output = bias + tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, + output_ptr_batch); + // Output += input * input_weights + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size, + output_ptr_batch, /*result_stride=*/1); + // Output += recurrent_weights * hidden_state + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch, + batch_size, output_ptr_batch, /*result_stride=*/1); + // Output = activation(Output) and update hidden_state + tensor_utils::ApplyActivationToVector( + output_ptr_batch, num_units * batch_size, activation, output_ptr_batch); + tensor_utils::VectorBatchVectorAssign(output_ptr_batch, num_units, batch_size, + hidden_state_ptr_batch); +} + +} // namespace kernel_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..9872d4500b862388ed4b96c97e3755f548e35d35 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { +namespace kernel_utils { + +// Performs an RNN batch inference step for inputs specified by input_ptr_batch. +// The RNN cell is specified by the pointers to its input and recurrent weights, +// and biases, along with the input size, number of units, activation. +// +// The pointers to the hidden state and the output are updated as a result. +// +// The pointers with the suffix "_batch" point to data aligned in batch_major +// order, and each step processes batch_size many inputs from input_ptr_batch, +// and updates batch_size many outputs and hidden states. +void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, + const float* recurrent_weights_ptr, const float* bias_ptr, + int input_size, int num_units, int batch_size, + TfLiteFusedActivation activation, + float* hidden_state_ptr_batch, float* output_ptr_batch); + +} // namespace kernel_utils +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..4a90e7e640ef29b675c236d8bbb479aa16560761 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ + +// The Conv implementation based on CBLAS interface. This is only used on iOS +// for now, utilizing Apple's Accelerate framework. + +#if TFLITE_USE_APPLE_ACCELERATE_FOR_CONV +#include +#else +#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h" +#endif + +#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" + +namespace tflite { +namespace cblas_ops { + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + gemmlowp::ScopedProfilingLabel label("Conv/cblas"); + + const float* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + optimized_ops::Im2col(input_data, input_dims, stride_width, stride_height, + pad_width, pad_height, filter_height, filter_width, 0, + im2col_data, im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + // The following code computes matrix multiplication c = a * transponse(b) + // with CBLAS, where: + // * `a` is a matrix with dimensions (m, k). + // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n). + // * `c` is a matrix with dimensions (m, n). + // The naming of variables are aligned with CBLAS specification here. + const float* a = gemm_input_data; + const float* b = filter_data; + float* c = output_data; + int m = gemm_input_dims->sizes[1] * gemm_input_dims->sizes[2] * + gemm_input_dims->sizes[3]; + int n = output_dims.sizes[0]; + int k = gemm_input_dims->sizes[0]; + // The stride of matrix a, b and c respectively. + int stride_a = k; + int stride_b = k; + int stride_c = n; + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a, + stride_a, b, stride_b, 0.0f, c, stride_c); + + optimized_ops::AddBiasAndEvalActivationFunction( + bias_data, bias_dims, output_data, output_dims, output_activation_min, + output_activation_max); +} + +} // namespace cblas_ops +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h new file mode 100644 index 0000000000000000000000000000000000000000..6acc513805c9398c304f3e24175d3bd6c96938f6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +// The reference implementation for a small subset of CBLAS interface. +// This is only used for testing CBLAS implementation, and should never be used +// in production code. + +namespace tflite { +namespace cblas_ops { + +// The following code follows the original CBLAS specification, and it might +// conflict with the TensorFlow naming convention. +// TODO(ycling): Find another way to test CBLAS with bazel, without writing +// a reference implementation by ourselves. +enum CBLAS_ORDER { CblasRowMajor = 0, CblasColMajor = 1 }; + +enum CBLAS_TRANSPOSE { CblasNoTrans = 0, CblasTrans = 1, CblasConjTrans = 2 }; + +// A reference implementation for matrix multiplication. +// The following code computes, c = a * transponse(b) matrix multiplication +// with CBLAS, where: +// * `a` is a matrix with dimensions (m, k). +// * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n). +// * `c` is a matrix with dimensions (m, n). +// The naming of variables is aligned with CBLAS specification here. +void cblas_sgemm(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE trans_a, + const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, + const int k, const float alpha, const float *a, + const int stride_a, const float *b, const int stride_b, + const float beta, float *c, const int stride_c) { + TFLITE_DCHECK(order == CblasRowMajor); + TFLITE_DCHECK(trans_a == CblasNoTrans); + TFLITE_DCHECK(trans_b == CblasTrans); + TFLITE_DCHECK(beta == 0.0f); + for (int row = 0; row < m; ++row) { + for (int col = 0; col < n; ++col) { + // If `beta` non-zero, multiple it with the original values in output. + // Otherwise, ignore the original value in output completely. + float value = 0.0f; + for (int idx = 0; idx < k; ++idx) { + value += alpha * a[stride_a * row + idx] * b[stride_b * col + idx]; + } + c[stride_c * row + col] = value; + } + } +} + +} // namespace cblas_ops +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h index 76687dce333e46d124e105b6e79b96cbd52ae1e2..11fffa7b2095e2355ef064ced7acb56d80b00f86 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h @@ -29,17 +29,13 @@ inline bool TestCPUFeatureNeon() { #endif } -#elif __ARM_NEON +#elif defined USE_NEON || defined __ARM_NEON -inline bool TestCPUFeatureNeon() { - return true; -} +inline bool TestCPUFeatureNeon() { return true; } #else -inline bool TestCPUFeatureNeon() { - return false; -} +inline bool TestCPUFeatureNeon() { return false; } #endif diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h index 81796e295d9c7ae1f04163467c8b2af851b632c2..7f6eea2d5d1cfd6f4e2a569760ecbe0d96f754c8 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -573,6 +573,46 @@ struct FloatDepthwiseConvKernel { } }; +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3); + float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4); + // Multiply-accumulate + acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val); + acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val); + acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val); + acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val); + acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4); + acc_buffer_ptr += 20; + } + } +}; + template <> struct FloatDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -926,6 +966,7 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 20) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 2) @@ -992,11 +1033,11 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, for (int k = 0; k < 4; k++) { acc[k] = vld1q_f32(acc_buffer + i + 4 * k); } - for (int k = 0; k < 4; k++) { - acc[k] = vmaxq_f32( - vdupq_n_f32(output_activation_min), - vminq_f32(vdupq_n_f32(output_activation_max), acc[k])); - } + for (int k = 0; k < 4; k++) { + acc[k] = vmaxq_f32( + vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc[k])); + } for (int k = 0; k < 4; k++) { vst1q_f32(output_ptr + 4 * k, acc[k]); } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h index f993fd6a00f054c670b247e886a1d9d2a34643e7..dbc4f0d6fdca8279072d6ea225334722d6a89eb2 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -1205,6 +1205,55 @@ struct QuantizedDepthwiseConvKernel { } }; +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + // NEON wants to load 8 bytes at a time, but 20 is not divisible by 8. + // We load the first 16 bytes into filter_u8_{0,1} as usual. + // Then we load the 8 last bytes into filter_u8_x (x for 'extra'). + // This is redundant: the first 4 bytes of filter_u8_x are the same + // as the last 4 bytes of filter_u8_x. + uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1); + uint8x8_t filter_u8_x = vld1_u8(filter_ptr + 8 * 1 + 4); + int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + int16x8_t filter_x = vreinterpretq_s16_u16(vmovl_u8(filter_u8_x)); + filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset)); + filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset)); + filter_x = vaddq_s16(filter_x, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3); + int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4); + // Multiply-accumulate + acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input); + acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input); + acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input); + acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input); + acc_4 = vmlal_n_s16(acc_4, vget_high_s16(filter_x), input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4); + acc_buffer_ptr += 20; + } + } +}; + template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -1504,7 +1553,7 @@ inline void QuantizedDepthwiseConvAccumRowGeneric( << "*\n" << "* If you would like to carry on with the slow code, compile\n" << "* with this preprocessor token defined:\n" - << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" + << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" << "*\n" << "* The right thing to do, if you care about performance, is to add\n" << "* a new DepthwiseConv kernel to tfmini to cover your case.\n" @@ -1691,6 +1740,7 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 20) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h index 86bd8eced5cddd85a7281e99edbdee8cb6ab85cb..5726a0dd5a1f44cbb517f0ddc6cbb97ecaf43578 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h @@ -39,7 +39,6 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #endif - namespace Eigen { /** SpatialConvolution @@ -215,13 +214,12 @@ EIGEN_DEVICE_FUNC } // TODO(yangke): choose() is defined in TensorContraction.h -- consider // moving it to somewhere more "common". - return - input - .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, - row_in_stride, col_in_stride, padding_type) - .reshape(pre_contract_dims) - .contract(kernel.reshape(kernel_dims), contract_dims) - .reshape(post_contract_dims); + return input + .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, padding_type) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims) + .reshape(post_contract_dims); } } // end namespace Eigen diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index bf0bdfb1fb875c4b54c55e25d4a17541507ecd4c..780401e052733cccae0cc34f495df090c1530624 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -15,12 +15,13 @@ limitations under the License. #include #include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" #ifdef USE_NEON -#include #define kFloatWeightsPerNeonLane 4 namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index f65ca6adad71b078a8c71880a1d620545206457d..d4f031e130f2277f199440bf9099ad4be8520e9a 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -2081,6 +2081,438 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, output_state_map.tanh(); } +#ifdef GEMMLOWP_NEON +// In the common case of batch size 1, a fully-connected node degenerates +// to a matrix*vector product. LSTM cells contain a fully-connected node; +// when quantized, this becomes a special type of GEMV operation where +// the output is 16bit-quantized, thus needs its own special path. +inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims, + const uint8* weights_data, + const Dims<4>& weights_dims, + uint8 weights_zero_point, const int32* bias_data, + const Dims<4>& bias_dims, int32 accum_multiplier, + int accum_shift, int16* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3), + 1); + const int input_size = input_dims.strides[3]; + const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0); + // This special fast path for quantized LSTM cells does not try to support + // odd sizes that we haven't encountered in any LSTM cell, that would + // require special code (that would go untested until any LSTM cell + // exercises it). We just guard our assumptions about size evenness with + // the following assertions. + TFLITE_DCHECK(!(output_size % 4)); + TFLITE_DCHECK(!(input_size % 8)); + const int32* bias_ptr = bias_data; + int16* output_ptr = output_data; + for (int out = 0; out < output_size; out += 4) { + int32x4_t acc_0 = vdupq_n_s32(0); + int32x4_t acc_1 = vdupq_n_s32(0); + int32x4_t acc_2 = vdupq_n_s32(0); + int32x4_t acc_3 = vdupq_n_s32(0); + const int16x8_t input_offset_vec = vdupq_n_s16(-128); + const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point); + int in = 0; + // Handle 16 levels of depth at a time. + for (; in <= input_size - 16; in += 16) { + const uint8x16_t input_val_u8 = vld1q_u8(input_data + in); + const uint8* weights_ptr = weights_data + in + out * input_size; + uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size); + uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size); + uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size); + uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size); + int16x8_t input_val_0, input_val_1; + const uint8x8_t low = vget_low_u8(input_val_u8); + const uint8x8_t high = vget_high_u8(input_val_u8); + input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low)); + input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high)); + input_val_0 = vaddq_s16(input_val_0, input_offset_vec); + input_val_1 = vaddq_s16(input_val_1, input_offset_vec); + int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0, + weights_val_3_0; + int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1, + weights_val_3_1; + weights_val_0_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))), + weights_offset_vec); + weights_val_0_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))), + weights_offset_vec); + weights_val_1_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))), + weights_offset_vec); + weights_val_1_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))), + weights_offset_vec); + weights_val_2_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))), + weights_offset_vec); + weights_val_2_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))), + weights_offset_vec); + weights_val_3_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))), + weights_offset_vec); + weights_val_3_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))), + weights_offset_vec); + acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0), + vget_low_s16(input_val_0)); + acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0), + vget_low_s16(input_val_0)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0), + vget_low_s16(input_val_0)); + acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0), + vget_low_s16(input_val_0)); + acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0), + vget_high_s16(input_val_0)); + acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0), + vget_high_s16(input_val_0)); + acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0), + vget_high_s16(input_val_0)); + acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0), + vget_high_s16(input_val_0)); + acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1), + vget_low_s16(input_val_1)); + acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1), + vget_low_s16(input_val_1)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1), + vget_low_s16(input_val_1)); + acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1), + vget_low_s16(input_val_1)); + acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1), + vget_high_s16(input_val_1)); + acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1), + vget_high_s16(input_val_1)); + acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1), + vget_high_s16(input_val_1)); + acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1), + vget_high_s16(input_val_1)); + } + // Handle 8 levels of depth at a time. + for (; in < input_size; in += 8) { + const uint8x8_t input_val_u8 = vld1_u8(input_data + in); + const uint8* weights_ptr = weights_data + in + out * input_size; + uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size); + uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size); + uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size); + uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size); + int16x8_t input_val; + input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8)); + input_val = vaddq_s16(input_val, input_offset_vec); + int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3; + weights_val_0 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)), + weights_offset_vec); + weights_val_1 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)), + weights_offset_vec); + weights_val_2 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)), + weights_offset_vec); + weights_val_3 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)), + weights_offset_vec); + acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0), + vget_low_s16(input_val)); + acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1), + vget_low_s16(input_val)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2), + vget_low_s16(input_val)); + acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3), + vget_low_s16(input_val)); + acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0), + vget_high_s16(input_val)); + acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1), + vget_high_s16(input_val)); + acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2), + vget_high_s16(input_val)); + acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3), + vget_high_s16(input_val)); + } + // Horizontally reduce accumulators + int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1, + pairwise_reduced_acc_2, pairwise_reduced_acc_3; + pairwise_reduced_acc_0 = + vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0)); + pairwise_reduced_acc_1 = + vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1)); + pairwise_reduced_acc_2 = + vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2)); + pairwise_reduced_acc_3 = + vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3)); + const int32x2_t reduced_lo = + vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); + const int32x2_t reduced_hi = + vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); + int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); + // Add bias values. + int32x4_t bias_vec = vld1q_s32(bias_ptr); + bias_ptr += 4; + reduced = vaddq_s32(reduced, bias_vec); + int left_shift = accum_shift > 0 ? accum_shift : 0; + int right_shift = accum_shift > 0 ? 0 : -accum_shift; + reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift)); + // Multiply by the fixed-point multiplier. + reduced = vqrdmulhq_n_s32(reduced, accum_multiplier); + // Rounding-shift-right. + using gemmlowp::RoundingDivideByPOT; + reduced = RoundingDivideByPOT(reduced, right_shift); + // Narrow values down to 16 bit signed. + const int16x4_t res16 = vqmovn_s32(reduced); + vst1_s16(output_ptr, res16); + output_ptr += 4; + } +} +#endif + +// Quantized LSTM cell. Currently just a copy of the reference impl in +// reference_ops.h. See the big function comment there, not replicating it +// here. +template +void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, + const uint8* prev_activ_data_uint8, + const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8, + const Dims<4>& weights_dims, const int32* bias_data_int32, + const Dims<4>& bias_dims, const int16* prev_state_data_int16, + const Dims<4>& prev_state_dims, int16* output_state_data_int16, + const Dims<4>& output_state_dims, uint8* output_activ_data_uint8, + const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8, + const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, + const Dims<4>& activ_temp_dims, int32 weights_zero_point, + int32 accum_multiplier, int accum_shift, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label( + "LstmCell/quantized (8bit external, 16bit internal)"); + // Gather dimensions information, and perform consistency checks. + const int batches = + MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, + output_state_dims, 3, output_activ_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, + output_state_dims, 2, output_activ_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, + output_state_dims, 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + const int fc_batches = ArraySize(activ_temp_dims, 1) * + ArraySize(activ_temp_dims, 2) * + ArraySize(activ_temp_dims, 3); + const int fc_output_depth = + MatchingArraySize(weights_dims, 1, activ_temp_dims, 0); + const int fc_accum_depth = ArraySize(weights_dims, 0); + TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth); + + // Depth-concatenate prev_activ and input data together. + uint8 const* concat_input_arrays_data[2] = {input_data_uint8, + prev_activ_data_uint8}; + Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims}; + Concatenation( + 0, concat_input_arrays_data, concat_input_arrays_dims, 2, + concat_temp_data_uint8, concat_temp_dims); + + // Implementation of the fully connected node inside the LSTM cell. + // The operands are 8-bit integers, the accumulators are internally 32bit + // integers, and the output is 16-bit fixed-point with 3 integer bits so + // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that + // is explained in the function comment above. + bool gemm_already_performed = false; +#ifdef GEMMLOWP_NEON + if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) { + GEMVForLstmCell(concat_temp_data_uint8, concat_temp_dims, + weights_data_uint8, weights_dims, weights_zero_point, + bias_data_int32, bias_dims, accum_multiplier, accum_shift, + activ_temp_data_int16, activ_temp_dims); + gemm_already_performed = true; + } +#endif + if (!gemm_already_performed) { + gemmlowp::MatrixMap + weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth); + gemmlowp::MatrixMap input_matrix( + concat_temp_data_uint8, fc_accum_depth, fc_batches); + gemmlowp::MatrixMap output_matrix( + activ_temp_data_int16, fc_output_depth, fc_batches); + typedef gemmlowp::VectorMap + ColVectorMap; + ColVectorMap bias_vector(bias_data_int32, fc_output_depth); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage; + scale_stage.result_offset_after_shift = 0; + scale_stage.result_fixedpoint_multiplier = accum_multiplier; + scale_stage.result_exponent = accum_shift; + gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage; + auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage, + saturating_cast_int16_stage); + gemmlowp::GemmWithOutputPipeline< + uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + gemm_context, weights_matrix, input_matrix, &output_matrix, + -weights_zero_point, -128, output_pipeline); + } + + // Rest of the LSTM cell: tanh and logistic math functions, and some adds + // and muls, all done in 16-bit fixed-point. + const int outer_size = batches * width * height; + const int16* input_gate_input_ptr = activ_temp_data_int16; + const int16* input_modulation_gate_input_ptr = + activ_temp_data_int16 + output_depth; + const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth; + const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth; + const int16* prev_state_ptr = prev_state_data_int16; + int16* output_state_data_ptr = output_state_data_int16; + uint8* output_activ_data_ptr = output_activ_data_uint8; + + for (int b = 0; b < outer_size; ++b) { + int c = 0; +#ifdef GEMMLOWP_NEON + for (; c <= output_depth - 8; c += 8) { + // Define the fixed-point data types that we will use here. All use + // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // They only differ by the number of integral vs. fractional bits, + // determining the range of values that they can represent. + // + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8]. + // This is the range of the previous fully-connected node's output, + // which is our input here. + using F3 = gemmlowp::FixedPoint; + // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits, + // 2^StateIntegerBits]. It's used to represent the internal state, whose + // number of integer bits is currently dictated by the model. See comment + // on the StateIntegerBits template parameter above. + using FS = gemmlowp::FixedPoint; + // Implementation of input gate, using fixed-point logistic function. + F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr)); + input_gate_input_ptr += 8; + F0 input_gate_output = gemmlowp::logistic(input_gate_input); + // Implementation of input modulation gate, using fixed-point tanh + // function. + F3 input_modulation_gate_input = + F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr)); + input_modulation_gate_input_ptr += 8; + F0 input_modulation_gate_output = + gemmlowp::tanh(input_modulation_gate_input); + // Implementation of forget gate, using fixed-point logistic function. + F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr)); + forget_gate_input_ptr += 8; + F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); + // Implementation of output gate, using fixed-point logistic function. + F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr)); + output_gate_input_ptr += 8; + F0 output_gate_output = gemmlowp::logistic(output_gate_input); + // Implementation of internal multiplication nodes, still in fixed-point. + F0 input_times_input_modulation = + input_gate_output * input_modulation_gate_output; + FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr)); + prev_state_ptr += 8; + FS prev_state_times_forget_state = forget_gate_output * prev_state; + // Implementation of internal addition node, saturating. + FS new_state = gemmlowp::SaturatingAdd( + gemmlowp::Rescale(input_times_input_modulation), + prev_state_times_forget_state); + // Implementation of last internal tanh node, still in fixed-point. + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state); + // Store the new internal state back to memory, as 16-bit integers. + vst1q_s16(output_state_data_ptr, new_state.raw()); + output_state_data_ptr += 8; + // Down-scale the output activations to 8-bit integers, saturating, + // and store back to memory. + int16x8_t rescaled_output_activ = + gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); + int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ); + uint8x8_t uint8_output_activ = + vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ)); + vst1_u8(output_activ_data_ptr, uint8_output_activ); + output_activ_data_ptr += 8; + } +#endif + for (; c < output_depth; ++c) { + // Define the fixed-point data types that we will use here. All use + // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // They only differ by the number of integral vs. fractional bits, + // determining the range of values that they can represent. + // + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8]. + // This is the range of the previous fully-connected node's output, + // which is our input here. + using F3 = gemmlowp::FixedPoint; + // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits, + // 2^StateIntegerBits]. It's used to represent the internal state, whose + // number of integer bits is currently dictated by the model. See comment + // on the StateIntegerBits template parameter above. + using FS = gemmlowp::FixedPoint; + // Implementation of input gate, using fixed-point logistic function. + F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++); + F0 input_gate_output = gemmlowp::logistic(input_gate_input); + // Implementation of input modulation gate, using fixed-point tanh + // function. + F3 input_modulation_gate_input = + F3::FromRaw(*input_modulation_gate_input_ptr++); + F0 input_modulation_gate_output = + gemmlowp::tanh(input_modulation_gate_input); + // Implementation of forget gate, using fixed-point logistic function. + F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++); + F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); + // Implementation of output gate, using fixed-point logistic function. + F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++); + F0 output_gate_output = gemmlowp::logistic(output_gate_input); + // Implementation of internal multiplication nodes, still in fixed-point. + F0 input_times_input_modulation = + input_gate_output * input_modulation_gate_output; + FS prev_state = FS::FromRaw(*prev_state_ptr++); + FS prev_state_times_forget_state = forget_gate_output * prev_state; + // Implementation of internal addition node, saturating. + FS new_state = gemmlowp::SaturatingAdd( + gemmlowp::Rescale(input_times_input_modulation), + prev_state_times_forget_state); + // Implementation of last internal tanh node, still in fixed-point. + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state); + // Store the new internal state back to memory, as 16-bit integers. + *output_state_data_ptr++ = new_state.raw(); + // Down-scale the output activations to 8-bit integers, saturating, + // and store back to memory. + int16 rescaled_output_activ = + gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); + int16 clamped_output_activ = + std::max(-128, std::min(127, rescaled_output_activ)); + *output_activ_data_ptr++ = 128 + clamped_output_activ; + } + input_gate_input_ptr += 3 * output_depth; + input_modulation_gate_input_ptr += 3 * output_depth; + forget_gate_input_ptr += 3 * output_depth; + output_gate_input_ptr += 3 * output_depth; + } +} + template void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int outputs_count, Scalar* const* output_data, @@ -2706,74 +3138,194 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - gemmlowp::ScopedProfilingLabel label("Softmax"); + gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); const int height = MatchingArraySize(input_dims, 2, output_dims, 2); const int width = MatchingArraySize(input_dims, 1, output_dims, 1); const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int x = 0; x < width; ++x) { - for (int y = 0; y < height; ++y) { - uint8 max_in_row = 0; - for (int c = 0; c < depth; ++c) { - max_in_row = - std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]); - } + const int outer_size = batches * height * width; - FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); - for (int c = 0; c < depth; ++c) { - int32 input_diff = - static_cast(input_data[Offset(input_dims, c, x, y, b)]) - - max_in_row; - if (input_diff >= diff_min) { - const int32 input_diff_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_diff, input_beta_multiplier, input_beta_left_shift); - const FixedPointScaledDiff scaled_diff_f8 = - FixedPointScaledDiff::FromRaw(input_diff_rescaled); - sum_of_exps = - sum_of_exps + gemmlowp::Rescale( - exp_on_negative_values(scaled_diff_f8)); - } + for (int b = 0; b < outer_size; ++b) { + const uint8* input_data_ptr = input_data + b * depth; + uint8* output_data_ptr = output_data + b * depth; + + // Determine the largest entry in the current row + uint8 max_in_row = 0; + { + int c = 0; +#ifdef USE_NEON + uint8x16_t max16_0 = vdupq_n_u8(0); + uint8x16_t max16_1 = vdupq_n_u8(0); + for (; c <= depth - 32; c += 32) { + max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0)); + max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16)); + } + uint8x16_t max16 = vmaxq_u8(max16_0, max16_1); + if (c <= depth - 16) { + max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c)); + c += 16; + } + uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16)); + if (c <= depth - 8) { + max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c)); + c += 8; + } + uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4)); + uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2)); + uint8x8_t max1 = vpmax_u8(max2, max2); + max_in_row = vget_lane_u8(max1, 0); +#endif + for (; c < depth; ++c) { + max_in_row = std::max(max_in_row, input_data_ptr[c]); + } + } + +#ifdef USE_NEON + using FixedPointAccumInt32x4 = + gemmlowp::FixedPoint; + using FixedPointScaledDiffInt32x4 = + gemmlowp::FixedPoint; + using FixedPoint0Int32x4 = gemmlowp::FixedPoint; + FixedPoint0Int32x4 input_beta_multiplier_f0 = + FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier); + int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row); +#endif + + // Compute the sum of exponentials of the differences of entries in the + // current row from the largest entry in the current row. + FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + { + int c = 0; +#ifdef USE_NEON + int32x4_t diff_min_s32 = vdupq_n_s32(diff_min); + FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero(); + FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero(); + FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero(); + for (; c <= depth - 8; c += 8) { + uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c)); + int16x8_t input_diff_s16 = + vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16); + int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); + int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); + int32x4_t mask_0 = + gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32); + int32x4_t mask_1 = + gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32); + FixedPointScaledDiffInt32x4 scaled_diff_0 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift)); + FixedPointScaledDiffInt32x4 scaled_diff_1 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift)); + FixedPointAccumInt32x4 exps_0 = + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_0)); + FixedPointAccumInt32x4 exps_1 = + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_1)); + FixedPointAccumInt32x4 masked_exps_0 = + SelectUsingMask(mask_0, exps_0, zeros); + FixedPointAccumInt32x4 masked_exps_1 = + SelectUsingMask(mask_1, exps_1, zeros); + sum_of_exps_0 = sum_of_exps_0 + masked_exps_0; + sum_of_exps_1 = sum_of_exps_1 + masked_exps_1; + } + int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw(); + int32x2_t sum_of_exps_reduced_2 = + vadd_s32(vget_low_s32(sum_of_exps_reduced_4), + vget_high_s32(sum_of_exps_reduced_4)); + int32x2_t sum_of_exps_reduced_1 = + vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2); + sum_of_exps = + FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0)); +#endif + for (; c < depth; ++c) { + int32 input_diff = static_cast(input_data_ptr[c]) - max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); } + } + } - int32 fixed_sum_of_exps = sum_of_exps.raw(); - // TODO(starka): Use a NEON intrinsic like vclzq_u32 instead. - int headroom_plus_one = - __builtin_clz(static_cast(fixed_sum_of_exps)); - // This is the number of bits to the left of the binary point above 1.0. - // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and - // no later adjustment will be needed. - int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; - int32 shifted_sum_minus_one = static_cast( - (static_cast(fixed_sum_of_exps) << headroom_plus_one) - - (static_cast(1) << 31)); - - FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( - FixedPoint0::FromRaw(shifted_sum_minus_one)); + // Compute the fixed-point multiplier and shift that we need to apply to + // perform a division by the above-computed sum-of-exponentials. + int32 fixed_sum_of_exps = sum_of_exps.raw(); + int headroom_plus_one = + __builtin_clz(static_cast(fixed_sum_of_exps)); + // This is the number of bits to the left of the binary point above 1.0. + // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and + // no later adjustment will be needed. + int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; + int32 shifted_sum_minus_one = static_cast( + (static_cast(fixed_sum_of_exps) << headroom_plus_one) - + (static_cast(1) << 31)); + FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( + FixedPoint0::FromRaw(shifted_sum_minus_one)); + + // Compute the quotients of exponentials of differences of entries in the + // current row from the largest entry, over the previously-computed sum of + // exponentials. + { + int c = 0; +#ifdef USE_NEON + int16x8_t diff_min_s16 = vdupq_n_s16(diff_min); + for (; c <= depth - 8; c += 8) { + uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c)); + int16x8_t input_diff_s16 = + vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16); + int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); + int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); + uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16)); + FixedPointScaledDiffInt32x4 scaled_diff_0 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift)); + FixedPointScaledDiffInt32x4 scaled_diff_1 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift)); + FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0); + FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1); + int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT( + vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()), + num_bits_over_unit + 31 - 8); + int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT( + vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()), + num_bits_over_unit + 31 - 8); + int16x8_t output_s16 = + vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1)); + uint8x8_t output_u8 = vqmovun_s16(output_s16); + uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0)); + vst1_u8(output_data_ptr + c, masked_output); + } +#endif + for (; c < depth; ++c) { + int32 input_diff = static_cast(input_data_ptr[c]) - max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); - for (int c = 0; c < depth; ++c) { - int32 input_diff = - static_cast(input_data[Offset(input_dims, c, x, y, b)]) - - max_in_row; - if (input_diff >= diff_min) { - const int32 input_diff_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_diff, input_beta_multiplier, input_beta_left_shift); - const FixedPointScaledDiff scaled_diff_f8 = - FixedPointScaledDiff::FromRaw(input_diff_rescaled); - - FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); - int32 unsat_output = gemmlowp::RoundingDivideByPOT( - (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); - - output_data[Offset(output_dims, c, x, y, b)] = - std::max(std::min(unsat_output, 255), 0); - - } else { - output_data[Offset(output_dims, c, x, y, b)] = 0; - } + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0); + + } else { + output_data_ptr[c] = 0; } } } @@ -2938,6 +3490,156 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, output_map.array() = input_map.array().tanh(); } +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + // Note that this is almost the exact same code as in Logistic(). + gemmlowp::ScopedProfilingLabel label("Tanh"); + /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* height */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* width */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0); + const int size = RequiredBufferSizeForDims(input_dims); + + int c = 0; + int32_t output_zero_point = 128; +#ifdef USE_NEON + // Handle 16 values at a time + for (; c <= size - 16; c += 16) { + // Read input uint8 values, cast to int16 and subtract input_zero_point + uint8x16_t input_val_u8 = vld1q_u8(input_data + c); + int16x8_t input_val_centered_0 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + int16x8_t input_val_centered_1 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + + // Prepare the bit masks that we will use at the end to implement the logic + // that was expressed in the scalar code with branching: + // if (input_val_centered < -input_range_radius) { + // output_val = 0; + // } else if (input_val_centered > input_range_radius) { + // output_val = 255; + // } else { + // ... + uint16x8_t mask_rightclamp_0 = + vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_rightclamp_1 = + vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_leftclamp_0 = + vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius)); + uint16x8_t mask_leftclamp_1 = + vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius)); + uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), + vshrn_n_u16(mask_rightclamp_1, 8)); + uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), + vshrn_n_u16(mask_leftclamp_1, 8)); + + // This performs what is expressed in the scalar code as + // const int32 input_val_rescaled = + // MultiplyByQuantizedMultiplierGreaterThanOne( + // input_val_centered, input_multiplier, input_left_shift); + int32x4_t input_val_rescaled_0 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_1 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_2 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_3 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + input_val_rescaled_0 = + vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier); + input_val_rescaled_1 = + vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier); + input_val_rescaled_2 = + vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier); + input_val_rescaled_3 = + vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier); + + // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4_0 = + FixedPoint4::FromRaw(input_val_rescaled_0); + const FixedPoint4 input_val_f4_1 = + FixedPoint4::FromRaw(input_val_rescaled_1); + const FixedPoint4 input_val_f4_2 = + FixedPoint4::FromRaw(input_val_rescaled_2); + const FixedPoint4 input_val_f4_3 = + FixedPoint4::FromRaw(input_val_rescaled_3); + const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0); + const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1); + const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2); + const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3); + + // Divide by 2^24 as in the scalar code + using gemmlowp::RoundingDivideByPOT; + int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24); + int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24); + int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24); + int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24); + + // Add the output zero point + int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point); + output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32); + output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32); + output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32); + output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32); + + // Cast output values to uint8, saturating + int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), + vqmovn_s32(output_val_s32_1)); + int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), + vqmovn_s32(output_val_s32_3)); + uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0), + vqmovun_s16(output_val_s16_1)); + + // Perform the bit-masking with the bit masks computed at the beginning, + // see the comment there. + output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp); + output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp); + + // Store back to memory + vst1q_u8(output_data + c, output_val_u8); + } +#endif + // Leftover loop: handle one value at a time with scalar code. + for (; c < size; ++c) { + const uint8 input_val_u8 = input_data[c]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered < -input_range_radius) { + output_val = 0; + } else if (input_val_centered > input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); + output_val_s32 += output_zero_point; + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[c] = output_val; + } +} inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, int32 zero_point, double scale, float* output_data, const Dims<4>& output_dims) { @@ -3410,7 +4112,7 @@ inline void ResizeBilinearGeneric(const float* input_data, inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, bool align_corners) { gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); @@ -3425,13 +4127,20 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; // Specialize for 2x2 upsample. - if (output_height == 2 * input_height && output_width == 2 * input_width) { + if (!align_corners && output_height == 2 * input_height && + output_width == 2 * input_width) { ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches, input_height, input_width, depth, output_height, output_width); } else { float height_scale = static_cast(input_height) / output_height; float width_scale = static_cast(input_width) / output_width; + if (align_corners && output_height > 1) { + height_scale = static_cast(input_height - 1) / (output_height - 1); + } + if (align_corners && output_width > 1) { + width_scale = static_cast(input_width - 1) / (output_width - 1); + } ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims, batches, input_height, input_width, depth, @@ -3440,6 +4149,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); +} + template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index f8be99e82fb8721ced7a3e5da686b20ce241ea2d..4e324a5e107cf5a90c0042331899edab831c8e51 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ #define TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ -// TDOD(ghodrat): Remove this header file and the dependency to internal data +// TODO(ghodrat): Remove this header file and the dependency to internal data // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index 98f2e365c5249a6c28673fc185ebec34cc2105b2..18be6777a5caeb45a4ffabd8b7f1793de7b053f8 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -22,27 +22,20 @@ limitations under the License. namespace tflite { -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift) { - TFLITE_CHECK(double_multiplier >= 0.); - TFLITE_CHECK(double_multiplier < 1.); +void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, + int* shift) { if (double_multiplier == 0.) { *quantized_multiplier = 0; - *right_shift = 0; + *shift = 0; return; } - TFLITE_CHECK(double_multiplier > 0.); - const double q = std::frexp(double_multiplier, right_shift); - *right_shift *= -1; - + const double q = std::frexp(double_multiplier, shift); auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); TFLITE_CHECK(q_fixed <= (1ll << 31)); if (q_fixed == (1ll << 31)) { q_fixed /= 2; - --*right_shift; + ++*shift; } - TFLITE_CHECK_GE(*right_shift, 0); TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); *quantized_multiplier = static_cast(q_fixed); } @@ -50,17 +43,20 @@ void QuantizeMultiplierSmallerThanOne(double double_multiplier, void QuantizeMultiplierGreaterThanOne(double double_multiplier, int32_t* quantized_multiplier, int* left_shift) { - TFLITE_CHECK(double_multiplier > 1.); - const double q = std::frexp(double_multiplier, left_shift); - auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); - TFLITE_CHECK(q_fixed <= (1ll << 31)); - if (q_fixed == (1ll << 31)) { - q_fixed /= 2; - ++*left_shift; - } + TFLITE_CHECK_GT(double_multiplier, 1.); + QuantizeMultiplier(double_multiplier, quantized_multiplier, left_shift); TFLITE_CHECK_GE(*left_shift, 0); - TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); - *quantized_multiplier = static_cast(q_fixed); +} + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* right_shift) { + TFLITE_CHECK_LT(double_multiplier, 1.); + TFLITE_CHECK_GT(double_multiplier, 0.); + int shift; + QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); + TFLITE_CHECK_LE(shift, 0); + *right_shift = -shift; } void PreprocessSoftmaxScaling(double beta, double input_scale, diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index efb7191c8deb2a23ea5473ab131d2b6537202765..ba06bc0975b6847b24592daa60efe99983d03707 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -20,7 +20,8 @@ limitations under the License. namespace tflite { // Decompose a double multiplier into a Q0.31 int32 representation of its -// significand, and shift representation of its exponent. +// significand, and shift representation of NEGATIVE its exponent --- +// this is intended as a RIGHT-shift. // // Restricted to the case where the multiplier < 1 (and non-negative). void QuantizeMultiplierSmallerThanOne(double double_multiplier, @@ -35,6 +36,16 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier, int32_t* quantized_multiplier, int* left_shift); +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Handles an arbitrary positive multiplier. The 'shift' output-value is +// basically the 'floating-point exponent' of the multiplier: +// Negative for a right-shift (when the multiplier is <1), positive for a +// left-shift (when the multiplier is >1) +void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, + int* shift); + // This first creates a multiplier in a double equivalent of // Q(input_integer_bits).(31-input_integer_bits) representation, with extra // precision in the double's fractional bits. It then splits the result into diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index d6f306e2cbae3c780b3d773638ba46cd2abf02f5..19b1b408ec74b0939065b0ad10b91ecfc2cd4765 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -31,7 +31,7 @@ TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { }; EXPECT_DEATH(quantize(-0.1), ""); - EXPECT_THAT(quantize(0.0), Pair(0, 0)); + EXPECT_DEATH(quantize(0.0), ""); EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); // Around 0.5 we can see the change in exponent and how we try hard to diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index afc3e26e7988a369fb777ae99c08c4e98f26ebb8..c05c21b472b05f2cbe133adf94d91ab0c6d9ef40 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ -// TDOD(ghodrat): Remove this header file and the dependency to internal data +// TODO(ghodrat): Remove this header file and the dependency to internal data // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 5ad1178f8c473261a8024da1b91533095e82a2d4..0c6f198170384556a4ab0287512524eab17f064e 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -1358,6 +1358,278 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, } } +// Quantized LSTM cell implementation. +// The quantization of the input, output arrays is as follows: +// - The input activations are quantized as uint8 on the interval +// [-1, 127/128]. +// The rationale for that is that that is the natural interval for output +// activations (see next point) and these need to be concatenated together. +// We could accommodate different ranges by re-scaling, but we empirically +// found that setting the input activations range to be [-1, 127/128] in the +// first place, removing the need for re-scaling, greatly improves accuracy. +// - The output activations are quantized as uint8 on the interval +// [-1, 127/128]. +// The rationale for that is that the definition of a LSTM cell makes them +// intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128] +// makes for simpler, more accurate fixed-point arithmetic. +// - The output-at-previous-timestep state array is obviously quantized as +// the output activations. +// - The internal LSTM memory (not the output-at-previous-timestep, the other +// internal state array) is int16-quantized and may use any power-of-two, +// symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call +// StateIntegerBits below, see the below discussion of that template +// parameter ("The StateIntegerBits template parameter"). +// - The output of the internal fully-connected node is int16-quantized +// on the interval [-8, 8 * 32767/32768], the rationale for which is +// explained just below ("Why [-8, 8] for fully-connected output?"). +// +// +// === The StateIntegerBits template parameter === +// +// The StateIntegerBits template parameter controls the fixed-point format used +// to represent the internal memory of the LSTM cell (not the +// output-at-previous-timestep, the other internal state array). It's currently +// a template parameter so that the model can control that. The most typical +// value for StateIntegerBits is 4. Other plausible values are anywhere between +// 3 and 5. We might eventually standardize on a single supported value, e.g. 4, +// and drop that template parameter. The reason why it can't be a runtime +// parameter is that this controls the fixed-point format used, i.e. we need to +// generate actually different code based on it. In particular, we generate code +// for a fixed-point tanh() implementation for that format, which internally +// uses a fixed-point exp() implementation, which internally uses a +// barrel-shifter with a number of steps that depends on StateIntegerBits. +// Another consequence of that is that a higher value of StateIntegerBits +// results in a more expensive implementation (more barrel shifter steps +// needed). +// +// +// === Why [-8, 8] for fully-connected output? === +// +// This array is only fed to Logistic and Tanh functions, for which +// the quantized implementation will want to use fixed-point arithmetic, +// requiring a power-of-two representation interval. Thus, we should right +// away quantize this array to a power-of-two interval; otherwise, +// implementation will need to rescale that, losing any benefit that a tighter +// representation interval might otherwise yield, while introducting some +// numerical error and computational overhead. +// +// Now, Logistic and Tanh +// are nearly constant (nearly equal to their horizontal asymptotes) +// outside of a small bounded interval around 0: +// +// Logistic(4) = 1 - 1.8e-2 Tanh(4) = 1 - 6.7e-4 +// Logistic(8) = 1 - 3.4e-4 Tanh(8) = 1 - 2.3e-7 +// Logistic(16) = 1 - 1.1e-7 Tanh(16) = 1 - 2.5e-14 +// +// From this, we see that clamping to [-4, 4] would be too inaccurate +// (the error of 1.8e-2 on Logistic would be felt even in 8bit precision) +// while clamping to [-16, 16] would make no difference even in float32. +// However, for a fixed-point implementation in 16-bit integers, using 5 +// integer bits to represent the [-16, 16] range would leave only 11 +// fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive +// representable values. Notice that that is higher than the +// worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic. +// Using [-8, 8] thus seems like the better compromise overall, enjoying +// an increment of 2.4e-4 between representable values and a worst-case +// clamping error of 3.4e-4, both better than the increment of 4.9e-4 with +// [-16, 16]. +// +// Moreover, all other things being equal, it is nice to choose the narrower +// representation range, as that makes the implementation of fixed-point +// math functions a little cheaper (each integer bit requires an additional +// barrel-shifter atep in the implementation of exp(-x)). That is further +// reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make +// sense for 32-bit float or 32-bit fixed-point quantization, but we are +// aiming for 16-bit fixed-point quantization of these internal nodes here. +// +template +void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, + const uint8* prev_activ_data_uint8, + const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8, + const Dims<4>& weights_dims, const int32* bias_data_int32, + const Dims<4>& bias_dims, const int16* prev_state_data_int16, + const Dims<4>& prev_state_dims, int16* output_state_data_int16, + const Dims<4>& output_state_dims, uint8* output_activ_data_uint8, + const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8, + const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, + const Dims<4>& activ_temp_dims, int32 weights_zero_point, + int32 accum_multiplier, int accum_shift, + gemmlowp::GemmContext* gemm_context) { + (void)gemm_context; // only used in optimized code. + + // Gather dimensions information, and perform consistency checks. + const int batches = + MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, + output_state_dims, 3, output_activ_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, + output_state_dims, 2, output_activ_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, + output_state_dims, 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + const int fc_batches = ArraySize(activ_temp_dims, 1) * + ArraySize(activ_temp_dims, 2) * + ArraySize(activ_temp_dims, 3); + const int fc_output_depth = + MatchingArraySize(weights_dims, 1, activ_temp_dims, 0); + const int fc_accum_depth = ArraySize(weights_dims, 0); + TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth); + + // Depth-concatenate prev_activ and input data together. + uint8 const* concat_input_arrays_data[2] = {input_data_uint8, + prev_activ_data_uint8}; + Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims}; + Concatenation( + 0, concat_input_arrays_data, concat_input_arrays_dims, 2, + concat_temp_data_uint8, concat_temp_dims); + + // Implementation of the fully connected node inside the LSTM cell. + // The operands are 8-bit integers, the accumulators are internally 32bit + // integers, and the output is 16-bit fixed-point with 3 integer bits so + // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that + // is explained in the function comment above. + for (int b = 0; b < fc_batches; ++b) { + for (int out_c = 0; out_c < fc_output_depth; ++out_c) { + // Internal accumulation. + // Initialize accumulator with the bias-value. + int32 accum = bias_data_int32[out_c]; + // Accumulation loop. + for (int d = 0; d < fc_accum_depth; ++d) { + int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128; + int16 weights_val = + weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point; + accum += input_val * weights_val; + } + // Down-scale the final int32 accumulator to the scale used by our + // (16-bit, using 3 integer bits) fixed-point format. The quantized + // multiplier and shift here have been pre-computed offline + // (e.g. by toco). + accum = + MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift); + // Saturate, cast to int16, and store to the temporary activations array. + accum = std::max(-32768, std::min(32767, accum)); + activ_temp_data_int16[out_c + fc_output_depth * b] = accum; + } + } + + // Rest of the LSTM cell: tanh and logistic math functions, and some adds + // and muls, all done in 16-bit fixed-point. + const int outer_size = batches * width * height; + for (int b = 0; b < outer_size; ++b) { + for (int c = 0; c < output_depth; ++c) { + // Define the fixed-point data types that we will use here. All use + // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // They only differ by the number of integral vs. fractional bits, + // determining the range of values that they can represent. + // + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8]. + // This is the range of the previous fully-connected node's output, + // which is our input here. + using F3 = gemmlowp::FixedPoint; + // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits, + // 2^StateIntegerBits]. It's used to represent the internal state, whose + // number of integer bits is currently dictated by the model. See comment + // on the StateIntegerBits template parameter above. + using FS = gemmlowp::FixedPoint; + // Implementation of input gate, using fixed-point logistic function. + F3 input_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]); + F0 input_gate_output = gemmlowp::logistic(input_gate_input); + // Implementation of input modulation gate, using fixed-point tanh + // function. + F3 input_modulation_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]); + F0 input_modulation_gate_output = + gemmlowp::tanh(input_modulation_gate_input); + // Implementation of forget gate, using fixed-point logistic function. + F3 forget_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]); + F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); + // Implementation of output gate, using fixed-point logistic function. + F3 output_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]); + F0 output_gate_output = gemmlowp::logistic(output_gate_input); + // Implementation of internal multiplication nodes, still in fixed-point. + F0 input_times_input_modulation = + input_gate_output * input_modulation_gate_output; + FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]); + FS prev_state_times_forget_state = forget_gate_output * prev_state; + // Implementation of internal addition node, saturating. + FS new_state = gemmlowp::SaturatingAdd( + gemmlowp::Rescale(input_times_input_modulation), + prev_state_times_forget_state); + // Implementation of last internal Tanh node, still in fixed-point. + // Since a Tanh fixed-point implementation is specialized for a given + // number or integer bits, and each specialization can have a substantial + // code size, and we already used above a Tanh on an input with 3 integer + // bits, and per the table in the above function comment there is no + // significant accuracy to be lost by clamping to [-8, +8] for a + // 3-integer-bits representation, let us just do that. This helps people + // porting this to targets where code footprint must be minimized. + F3 new_state_f3 = gemmlowp::Rescale<3>(new_state); + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3); + // Store the new internal state back to memory, as 16-bit integers. + // Note: here we store the original value with StateIntegerBits, not + // the rescaled 3-integer-bits value fed to tanh. + output_state_data_int16[b * output_depth + c] = new_state.raw(); + // Down-scale the output activations to 8-bit integers, saturating, + // and store back to memory. + int16 rescaled_output_activ = + gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); + int16 clamped_output_activ = + std::max(-128, std::min(127, rescaled_output_activ)); + output_activ_data_uint8[b * output_depth + c] = + 128 + clamped_output_activ; + } + } +} + +template +void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, + int axis, int outputs_count, Scalar* const* output_data, + const Dims<4>* const* output_dims) { + const int batches = ArraySize(*output_dims[0], 3); + const int height = ArraySize(*output_dims[0], 2); + const int width = ArraySize(*output_dims[0], 1); + const int depth = ArraySize(*output_dims[0], 0); + + const int slice_size = ArraySize(*output_dims[0], axis); + + for (int i = 0; i < outputs_count; ++i) { + int offset = i * slice_size * input_dims.strides[axis]; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + auto out = Offset(*output_dims[i], c, x, y, b); + auto in = Offset(input_dims, c, x, y, b); + output_data[i][out] = input_data[offset + in]; + } + } + } + } + } +} + template void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int outputs_count, Scalar* const* output_data, @@ -1368,28 +1640,12 @@ void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); } - const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); - const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); - const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1); // for now we dont have a model with a TensorFlowSplit // with fused activation function. TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - int in_c = 0; - for (int i = 0; i < outputs_count; ++i) { - const int depth = ArraySize(*output_dims[i], 0); - for (int c = 0; c < depth; ++c) { - output_data[i][Offset(*output_dims[i], c, x, y, b)] = - input_data[Offset(input_dims, in_c, x, y, b)]; - in_c++; - } - } - TFLITE_DCHECK(in_c == ArraySize(input_dims, 0)); - } - } - } + + TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count, + output_data, output_dims); } // TODO(benoitjacob) make this a proper reference impl without Eigen! @@ -2043,6 +2299,54 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, } } +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + const int32 output_zero_point = 128; + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered <= -input_range_radius) { + output_val = 0; + } else if (input_val_centered >= input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = + FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); + + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); + output_val_s32 += output_zero_point; + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[Offset(output_dims, c, x, y, b)] = output_val; + } + } + } + } +} + inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, int32 zero_point, double scale, float* output_data, const Dims<4>& output_dims) { @@ -2202,7 +2506,7 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, bool align_corners) { int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); int32 input_width = ArraySize(input_dims, 1); @@ -2216,6 +2520,12 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; float height_scale = static_cast(input_height) / output_height; float width_scale = static_cast(input_width) / output_width; + if (align_corners && output_height > 1) { + height_scale = static_cast(input_height - 1) / (output_height - 1); + } + if (align_corners && output_width > 1) { + width_scale = static_cast(input_width - 1) / (output_width - 1); + } for (int b = 0; b < batches; ++b) { for (int y = 0; y < output_height; ++y) { @@ -2243,6 +2553,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); +} + template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, @@ -2370,13 +2689,15 @@ inline int StartIndex(int start, int stride, int dim, bool masked) { return masked ? (stride > 0 ? 0 : dim - 1) : start; } -inline int StopIndex(int stop, int stride, int dim, bool masked) { - return masked ? (stride > 0 ? dim : -1) : stop; +inline int StopIndex(int start, int stop, int stride, int dim, bool masked, + bool shrink_axis_masked) { + return shrink_axis_masked ? stride > 0 ? start + 1 : start - 1 + : masked ? (stride > 0 ? dim : -1) : stop; } template inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, + int begin_mask, int end_mask, int shrink_axis_mask, const std::vector& starts, const std::vector& stops, const std::vector& strides, T* output_data, @@ -2387,19 +2708,23 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, const int start_b = StartIndex(starts[3], strides[3], input_dims.sizes[3], begin_mask & 8); const int stop_b = - StopIndex(stops[3], strides[3], input_dims.sizes[3], end_mask & 8); + StopIndex(start_b, stops[3], strides[3], input_dims.sizes[3], + end_mask & 8, shrink_axis_mask & 8); const int start_h = StartIndex(starts[2], strides[2], input_dims.sizes[2], begin_mask & 4); const int stop_h = - StopIndex(stops[2], strides[2], input_dims.sizes[2], end_mask & 4); + StopIndex(start_h, stops[2], strides[2], input_dims.sizes[2], + end_mask & 4, shrink_axis_mask & 4); const int start_w = StartIndex(starts[1], strides[1], input_dims.sizes[1], begin_mask & 2); const int stop_w = - StopIndex(stops[1], strides[1], input_dims.sizes[1], end_mask & 2); + StopIndex(start_w, stops[1], strides[1], input_dims.sizes[1], + end_mask & 2, shrink_axis_mask & 2); const int start_d = StartIndex(starts[0], strides[0], input_dims.sizes[0], begin_mask & 1); const int stop_d = - StopIndex(stops[0], strides[0], input_dims.sizes[0], end_mask & 1); + StopIndex(start_d, stops[0], strides[0], input_dims.sizes[0], + end_mask & 1, shrink_axis_mask & 1); T* out_ptr = output_data; for (int in_b = start_b; LoopCondition(in_b, stop_b, strides[3]); @@ -2417,6 +2742,18 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, } } +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + StridedSlice(input_data, input_dims, begin_mask, end_mask, + /*shrink_axis_mask=*/0, starts, stops, strides, output_data, + output_dims); +} + template inline void Slice(const T* input_data, const Dims<4>& input_dims, const std::vector& begin, const std::vector& size, @@ -2449,6 +2786,14 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } } +template +inline void Exp(const T* input_data, const size_t num_elements, + T* output_data) { + for (size_t idx = 0; idx < num_elements; ++idx) { + output_data[idx] = exp(input_data[idx]); + } +} + template inline void Mean(T* input_data, const int* input_dims, const int input_num_dims, T* output_data, const int* output_dims, diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index dfe76c2afd40c692063710a4d98464b55e40feb9..62e38e0d4c3e023d0ed2242fc9438b096b86dc59 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -81,6 +81,51 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { return GetTensorDims(dims->data, dims->size); } +// A list of tensors in a format that can be used by kernels like split and +// concatenation. +template +class VectorOfTensors { + public: + // Build with the tensors in 'tensor_list'. + VectorOfTensors(const TfLiteContext& context, + const TfLiteIntArray& tensor_list) { + int num_tensors = tensor_list.size; + + all_data_.reserve(num_tensors); + all_dims_.reserve(num_tensors); + all_dims_ptr_.reserve(num_tensors); + + for (int i = 0; i < num_tensors; ++i) { + TfLiteTensor* t = &context.tensors[tensor_list.data[i]]; + all_data_.push_back(GetTensorData(t)); + all_dims_.push_back(GetTensorDims(t)); + } + + // Taking the pointer from inside a std::vector is only OK if the vector is + // never modified, so we populate all_dims in the previous loop and then we + // are free to grab iterators here. + for (int i = 0; i < num_tensors; ++i) { + all_dims_ptr_.push_back(&all_dims_[i]); + } + } + // Return a pointer to the data pointers of all tensors in the list. For + // example: + // float* const* f = v.data(); + // f[0][1] is the second element of the first tensor. + T* const* data() const { return all_data_.data(); } + + // Return a pointer the dim pointers of all tensors in the list. For + // example: + // const Dims<4>* const* d = v.dims(); + // dims[1] are the dimensions of the second tensor in the list. + const Dims<4>* const* dims() const { return all_dims_ptr_.data(); } + + private: + std::vector all_data_; + std::vector> all_dims_; + std::vector*> all_dims_ptr_; +}; + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc index 904a97803a6a9ba369c1e64c711b12d19ffc10c4..f4181b18a8f46fd9bef4b81a210a6b8134a4e9d0 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" #ifndef USE_NEON #if defined(__ARM_NEON__) || defined(__ARM_NEON) diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc index b0546c00cf977af5f722a802866448b0cb293b8d..955e8c5764c6adad37a0009f4ddf8accb437b174 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.cc +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/kernel_util.h" + #include #include +#include + #include "tensorflow/contrib/lite/kernels/internal/round.h" namespace tflite { @@ -84,4 +87,27 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation, } } +bool HaveSameShapes(TfLiteTensor* input1, TfLiteTensor* input2) { + return TfLiteIntArrayEqual(input1->dims, input2->dims); +} + +TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, + TfLiteTensor* input1, + TfLiteTensor* input2, + TfLiteIntArray** output_shape) { + int64_t dims1 = NumDimensions(input1); + int64_t dims2 = NumDimensions(input2); + int64_t out_dims = std::max(dims1, dims2); + std::unique_ptr shape( + TfLiteIntArrayCreate(out_dims), TfLiteIntArrayFree); + for (int i = 0; i < out_dims; ++i) { + int64_t d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1); + int64_t d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1); + TF_LITE_ENSURE(context, d1 == d2 || d1 == 1 || d2 == 1); + shape->data[out_dims - i - 1] = std::max(d1, d2); + } + *output_shape = shape.release(); + return kTfLiteOk; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index 1cf30ecff9760d218d279cc6c7132589e11cc15c..28f53b9fbbc5620f2fab5c73e40bed8af4af5f1e 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -35,6 +35,14 @@ inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node, inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; } inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; } +inline int64_t NumElements(const TfLiteTensor* t) { + int64_t count = 1; + for (int i = 0; i < NumDimensions(t); ++i) { + count *= SizeOfDimension(t, i); + } + return count; +} + inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, const TfLiteNode* node, int index) { const bool use_tensor = node->inputs->data[index] != kOptionalTensor; @@ -44,6 +52,25 @@ inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, return nullptr; } +// Determines whether tensor is constant. +inline bool IsConstantTensor(TfLiteTensor* tensor) { + return tensor->allocation_type == kTfLiteMmapRo; +} + +// Determines whether tensor is dynamic. Note that a tensor can be non-const and +// not dynamic. This function specificially checks for a dynamic tensor. +inline bool IsDynamicTensor(TfLiteTensor* tensor) { + return tensor->allocation_type == kTfLiteDynamic; +} + +// Sets tensor to dynamic. +inline void SetTensorToDynamic(TfLiteTensor* tensor) { + if (tensor->allocation_type != kTfLiteDynamic) { + tensor->allocation_type = kTfLiteDynamic; + tensor->data.raw = nullptr; + } +} + // Calculates the multiplication factor for a quantized convolution (or // quantized depthwise convolution) involving the given tensors. Returns an // error if the scales of the tensors are not compatible. @@ -60,6 +87,15 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation, float* activation_min, float* activation_max); +// Return true if the given tensors have the same shape. +bool HaveSameShapes(TfLiteTensor* input1, TfLiteTensor* input2); + +// Calculate the output_shape that is necessary for element-wise operations +// with broadcasting involving the two input tensors. +TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, + TfLiteTensor* input1, + TfLiteTensor* input2, + TfLiteIntArray** output_shape); } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util_test.cc b/tensorflow/contrib/lite/kernels/kernel_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c65b68970f6853e17af3a70aad7a2bc982a1ee60 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/kernel_util_test.cc @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace { + +void ReportError(TfLiteContext* context, const char* format, ...) {} + +class KernelUtilTest : public ::testing::Test { + public: + KernelUtilTest() { + context_.ReportError = ReportError; + + tensor1_.dims = nullptr; + tensor2_.dims = nullptr; + tensor1_.allocation_type = kTfLiteMmapRo; + tensor2_.allocation_type = kTfLiteMmapRo; + } + ~KernelUtilTest() { + TfLiteTensorFree(&tensor1_); + TfLiteTensorFree(&tensor2_); + } + + void SetShape(TfLiteTensor* tensor, std::initializer_list dims) { + TfLiteTensorFree(tensor); + tensor->dims = TfLiteIntArrayCreate(dims.size()); + int i = 0; + for (int d : dims) { + tensor->dims->data[i] = d; + ++i; + } + } + + std::vector GetShape(TfLiteIntArray* dims) { + std::vector result; + for (int i = 0; i < dims->size; ++i) { + result.push_back(dims->data[i]); + } + return result; + } + + protected: + TfLiteContext context_; + TfLiteTensor tensor1_; + TfLiteTensor tensor2_; +}; + +TEST_F(KernelUtilTest, SameShapeEmpty) { + EXPECT_TRUE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor1_, {1, 2, 3}); + EXPECT_FALSE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor2_, {1, 2}); + EXPECT_FALSE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor2_, {1, 2, 3, 4}); + EXPECT_FALSE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor2_, {1, 2, 3}); + EXPECT_TRUE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor2_, {}); + EXPECT_FALSE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor1_, {}); + EXPECT_TRUE(HaveSameShapes(&tensor1_, &tensor2_)); +} + +TEST_F(KernelUtilTest, BroadcastShapeIncompatibleDim) { + TfLiteIntArray* output = nullptr; + SetShape(&tensor1_, {1, 2}); + SetShape(&tensor2_, {1, 3}); + EXPECT_NE(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + EXPECT_EQ(output, nullptr); +} + +TEST_F(KernelUtilTest, BroadcastShapeOnes) { + TfLiteIntArray* output = nullptr; + SetShape(&tensor1_, {1, 1}); + SetShape(&tensor2_, {1, 3}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + TfLiteIntArrayFree(output); + + SetShape(&tensor1_, {1, 2}); + SetShape(&tensor2_, {1, 1}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + TfLiteIntArrayFree(output); +} + +TEST_F(KernelUtilTest, BroadcastShapeScalars) { + TfLiteIntArray* output = nullptr; + SetShape(&tensor1_, {1, 2}); + SetShape(&tensor2_, {}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + EXPECT_THAT(GetShape(output), ::testing::ElementsAre(1, 2)); + TfLiteIntArrayFree(output); + + SetShape(&tensor1_, {}); + SetShape(&tensor2_, {2}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + EXPECT_THAT(GetShape(output), ::testing::ElementsAre(2)); + TfLiteIntArrayFree(output); +} + +TEST_F(KernelUtilTest, BroadcastShapeDifferentSizes) { + TfLiteIntArray* output = nullptr; + SetShape(&tensor1_, {1, 2}); + SetShape(&tensor2_, {3, 1, 1}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + EXPECT_THAT(GetShape(output), ::testing::ElementsAre(3, 1, 2)); + TfLiteIntArrayFree(output); + + SetShape(&tensor1_, {1, 2, 3, 4}); + SetShape(&tensor2_, {1, 3, 1}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + EXPECT_THAT(GetShape(output), ::testing::ElementsAre(1, 2, 3, 4)); + TfLiteIntArrayFree(output); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/mean.cc index 540e5a364dd60a42c316199d0ebe878ae07e6756..aff19581ea56f94c08638b7b388ae181f566cf4f 100644 --- a/tensorflow/contrib/lite/kernels/mean.cc +++ b/tensorflow/contrib/lite/kernels/mean.cc @@ -35,10 +35,12 @@ struct MeanContext { MeanContext(TfLiteContext* context, TfLiteNode* node) { params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); + axis = GetInput(context, node, 1); output = GetOutput(context, node, 0); } TfLiteMeanParams* params; TfLiteTensor* input; + TfLiteTensor* axis; TfLiteTensor* output; }; @@ -54,45 +56,26 @@ void Free(TfLiteContext* context, void* buffer) { delete reinterpret_cast(buffer); } -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - - MeanContext op_context(context, node); - int input_num_dims = NumDimensions(op_context.input); - int axis_num_dims = op_context.params->num_axis_dimensions; - - // Creates a temp index to iterate through input data. - int* scratch_tensor_index = reinterpret_cast(node->user_data); - TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); - node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; - scratch_tensor->type = kTfLiteInt32; - scratch_tensor->allocation_type = kTfLiteArenaRw; - TfLiteIntArray* index_size = TfLiteIntArrayCreate(1); - index_size->data[0] = input_num_dims; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, scratch_tensor, index_size)); - - // Creates a temp tensor to store resolved axis given input data. - node->temporaries->data[1] = *scratch_tensor_index + 1; - TfLiteTensor* axis_tensor = &context->tensors[node->temporaries->data[1]]; - axis_tensor->type = kTfLiteInt32; - axis_tensor->allocation_type = kTfLiteArenaRw; +// Resizes the temp tensor that stores resolved axis. +TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, + TfLiteTensor* resolved_axis) { TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1); - axis_size->data[0] = op_context.params->num_axis_dimensions; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, axis_tensor, axis_size)); + axis_size->data[0] = static_cast(NumElements(op_context->axis)); + return context->ResizeTensor(context, resolved_axis, axis_size); +} - // Determines size of output tensor. - const TfLiteIntArray* input_dims = op_context.input->dims; - const int* axis = op_context.params->axis; - if (op_context.params->keep_dims) { +// Resizes output array based on the input size and resolved axis. +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + MeanContext* op_context) { + size_t num_axis = NumElements(op_context->axis); + const TfLiteIntArray* input_dims = op_context->input->dims; + int input_num_dims = NumDimensions(op_context->input); + const int* axis = GetTensorData(op_context->axis); + if (op_context->params->keep_dims) { TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims); for (int idx = 0; idx < input_num_dims; ++idx) { bool is_axis = false; - for (int axis_idx = 0; axis_idx < axis_num_dims; ++axis_idx) { + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) { is_axis = true; break; @@ -104,11 +87,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_dims->data[idx] = input_dims->data[idx]; } } - return context->ResizeTensor(context, op_context.output, output_dims); + return context->ResizeTensor(context, op_context->output, output_dims); } else { // Calculates size of reducing axis. - int num_reduce_axis = axis_num_dims; - for (int i = 0; i < axis_num_dims; ++i) { + int num_reduce_axis = num_axis; + for (int i = 0; i < num_axis; ++i) { int current = axis[i]; if (current < 0) { current += input_num_dims; @@ -131,7 +114,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { int num_skip_axis = 0; for (int idx = 0; idx < input_num_dims; ++idx) { bool is_axis = false; - for (int axis_idx = 0; axis_idx < axis_num_dims; ++axis_idx) { + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) { ++num_skip_axis; is_axis = true; @@ -142,24 +125,74 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_dims->data[idx - num_skip_axis] = input_dims->data[idx]; } } - return context->ResizeTensor(context, op_context.output, output_dims); + return context->ResizeTensor(context, op_context->output, output_dims); + } +} + +// Initializes temp tensors to store index and resolved axis. +TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, + MeanContext* op_context) { + // Creates a temp index to iterate through input data. + int* scratch_tensor_index = reinterpret_cast(node->user_data); + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries->data[0] = *scratch_tensor_index; + TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; + scratch_tensor->type = kTfLiteInt32; + scratch_tensor->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* index_size = TfLiteIntArrayCreate(1); + index_size->data[0] = NumDimensions(op_context->input); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, scratch_tensor, index_size)); + + // Creates a temp tensor to store resolved axis given input data. + node->temporaries->data[1] = *scratch_tensor_index + 1; + TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + resolved_axis->type = kTfLiteInt32; + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + MeanContext op_context(context, node); + TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context)); + + TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + // Leaves work to Eval if axis is not constant; else resizes output. + if (!IsConstantTensor(op_context.axis)) { + SetTensorToDynamic(op_context.output); + SetTensorToDynamic(resolved_axis); + return kTfLiteOk; } + resolved_axis->allocation_type = kTfLiteArenaRw; + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + return ResizeOutputTensor(context, &op_context); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { MeanContext op_context(context, node); + int num_axis = static_cast(NumElements(op_context.axis)); TfLiteTensor* temp_index = &context->tensors[node->temporaries->data[0]]; TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } -#define TF_LITE_MEAN(kernel_type, data_type) \ - kernel_type::Mean<>( \ - GetTensorData(op_context.input), \ - op_context.input->dims->data, op_context.input->dims->size, \ - GetTensorData(op_context.output), \ - op_context.output->dims->data, op_context.output->dims->size, \ - op_context.params->axis, op_context.params->num_axis_dimensions, \ - op_context.params->keep_dims, GetTensorData(temp_index), \ +#define TF_LITE_MEAN(kernel_type, data_type) \ + kernel_type::Mean<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ GetTensorData(resolved_axis)) if (kernel_type == kReference) { diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/mean_test.cc index 4305c0632f5a52b858a056109187ad4a0cc2e46e..c4c53c2ded351849e7c458fc754c36395a25ebd0 100644 --- a/tensorflow/contrib/lite/kernels/mean_test.cc +++ b/tensorflow/contrib/lite/kernels/mean_test.cc @@ -25,58 +25,108 @@ using ::testing::ElementsAreArray; class BaseMeanOpModel : public SingleOpModel { public: - BaseMeanOpModel(const TensorData& input, const TensorData& output, - std::initializer_list axis, bool keep_dims) { - input_ = AddInput(input); - output_ = AddOutput(output); - SetBuiltinOp( - BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, - CreateMeanOptions(builder_, builder_.CreateVector(axis), keep_dims) - .Union()); - BuildInterpreter({GetShape(input_)}); + void SetAxis(std::initializer_list data) { PopulateTensor(axis_, data); } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); } - int input() { return input_; } + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } protected: int input_; + int axis_; int output_; }; -class FloatMeanOpModel : public BaseMeanOpModel { +// Model for the tests case where axis is a const tensor. +class MeanOpConstModel : public BaseMeanOpModel { public: - using BaseMeanOpModel::BaseMeanOpModel; - - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); + MeanOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, + CreateMeanOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); } +}; - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } +// Model for the tests case where axis is a dynamic tensor. +class MeanOpDynamicModel : public BaseMeanOpModel { + public: + MeanOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, + CreateMeanOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } }; -TEST(FloatMeanOpTest, NotKeepDims) { +TEST(ConstMeanOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); +} + +TEST(ConstMeanOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); +} + +TEST(DynamicMeanOpTest, NotKeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; - FloatMeanOpModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, - {1, 0, -3, -3}, false); + MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::initializer_list axis = {1, 0, -3, -3}; + m.SetAxis(axis); m.SetInput(data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); } -TEST(FloatMeanOpTest, KeepDims) { +TEST(DynamicMeanOpTest, KeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; - FloatMeanOpModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, - {0, 2}, true); + MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, + true); + std::initializer_list axis = {0, 2}; + m.SetAxis(axis); m.SetInput(data); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); - EXPECT_THAT(m.GetOutput(), + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); } diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 81c73f2523186c2d4072d56bdc8980fcdbb588a3..54575019de4c678ce25561cf2ac8dc80c9973363 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -37,7 +37,23 @@ constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; constexpr int kOutputTensor = 0; +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -45,43 +61,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); - for (int i = 0; i < NumDimensions(input1); ++i) { - TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), - SizeOfDimension(input2, i)); - } + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + output->type = input2->type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); - TF_LITE_ENSURE_EQ(context, input1->type, output->type); - TF_LITE_ENSURE_EQ(context, input2->type, output->type); + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } - TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); return context->ResizeTensor(context, output, output_size); } template void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteMulParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); -#define TF_LITE_MUL(type) \ - type::Mul(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops); + if (data->requires_broadcast) { + TF_LITE_MUL(reference_ops, BroadcastMul); + } else { + TF_LITE_MUL(reference_ops, Mul); + } } else { - TF_LITE_MUL(optimized_ops); + if (data->requires_broadcast) { + TF_LITE_MUL(optimized_ops, BroadcastMul); + } else { + TF_LITE_MUL(optimized_ops, Mul); + } } #undef TF_LITE_MUL } template void EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteMulParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { auto input1_offset = -input1->params.zero_point; auto input2_offset = -input2->params.zero_point; auto output_offset = output->params.zero_point; @@ -98,17 +127,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeUint8(params->activation, output, &output_activation_min, &output_activation_max); -#define TF_LITE_MUL(type) \ - type::BroadcastMul(GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, output_offset, \ - output_multiplier, output_shift, output_activation_min, \ - output_activation_max, GetTensorData(output), \ - GetTensorDims(output)); +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, GetTensorData(input2), \ + GetTensorDims(input2), input2_offset, output_offset, \ + output_multiplier, output_shift, output_activation_min, \ + output_activation_max, GetTensorData(output), \ + GetTensorDims(output)); + // The quantized version of Mul doesn't support activations, so we + // always use BroadcastMul. if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops); + TF_LITE_MUL(reference_ops, BroadcastMul); } else { - TF_LITE_MUL(optimized_ops); + TF_LITE_MUL(optimized_ops, BroadcastMul); } #undef TF_LITE_MUL } @@ -116,15 +147,17 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { - EvalFloat(context, node, params, input1, input2, output); + EvalFloat(context, node, params, data, input1, input2, output); } else if (output->type == kTfLiteUInt8) { - EvalQuantized(context, node, params, input1, input2, output); + EvalQuantized(context, node, params, data, input1, input2, + output); } else { context->ReportError(context, "Mul only supports FLOAT32 and quantized UINT8 now."); @@ -137,19 +170,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace mul TfLiteRegistration* Register_MUL_REF() { - static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, mul::Eval}; return &r; } TfLiteRegistration* Register_MUL_GENERIC_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, mul::Eval}; return &r; } TfLiteRegistration* Register_MUL_NEON_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, mul::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc index 8838b300c0af167bf2ffcf944fc7c31d6173f462..f1a30f82634631ba8320421d5b36ffe446f443fa 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -25,10 +25,11 @@ using ::testing::ElementsAreArray; class BaseMulOpModel : public SingleOpModel { public: - BaseMulOpModel(TensorData input, TensorData output, + BaseMulOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, ActivationFunctionType activation_type) { - input1_ = AddInput(input); - input2_ = AddInput(input); + input1_ = AddInput(input1); + input2_ = AddInput(input2); output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, CreateMulOptions(builder_, activation_type).Union()); @@ -70,6 +71,7 @@ class QuantizedMulOpModel : public BaseMulOpModel { TEST(FloatMulOpTest, NoActivation) { FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); @@ -79,9 +81,9 @@ TEST(FloatMulOpTest, NoActivation) { } TEST(FloatMulOpTest, ActivationRELU_N1_TO_1) { - FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {}}, - ActivationFunctionType_RELU_N1_TO_1); + FloatMulOpModel m( + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU_N1_TO_1); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 5}); m.Invoke(); @@ -94,6 +96,7 @@ TEST(FloatMulOpTest, VariousInputShapes) { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, test_shapes[i]}, {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); @@ -105,8 +108,26 @@ TEST(FloatMulOpTest, VariousInputShapes) { } } +TEST(FloatMulOpTest, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, // always a scalar + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}))) + << "With shape number " << i; + } +} + TEST(QuantizedMulOpTest, NoActivation) { QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}, ActivationFunctionType_NONE); m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); @@ -117,6 +138,32 @@ TEST(QuantizedMulOpTest, NoActivation) { kQuantizedTolerance))); } +// for quantized Mul, the error shouldn't exceed 2*step +float GetTolerance(int min, int max) { + float kQuantizedStep = (max - min) / 255.0; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + +TEST(QuantizedMulOpTest, WithBroadcast) { + float kQuantizedTolerance = GetTolerance(-3.0, 3.0); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedMulOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, // always a scalar + {TensorType_UINT8, {}, -3.0, 3.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}, kQuantizedTolerance))) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index 17166715ca30ff3d8ba3d384110e403f8910e39d..cee3ec6197c698a11004d42dccdfe2bcca088015 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -243,7 +243,6 @@ class LSTMOpModel : public SingleOpModel { int n_output_; }; - TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { const int n_batch = 1; const int n_input = 2; @@ -282,7 +281,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, 0.04717243, 0.48944736, -0.38535351, -0.17212132}); diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 569bf0fe8fc9964a1299911d248d53862c99cbdf..c29da3862e84d6756bf5ef34b2ca06307b0a065d 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -51,17 +51,14 @@ struct PadContext { // paddings data is present. TfLiteStatus ResizeOutputTensor(TfLiteContext* context, PadContext* op_context) { - // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. - TF_LITE_ENSURE_EQ(context, op_context->dims, 4); - // Ensures the paddings array is dims x 2. TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0), op_context->dims); TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2); // Determines the size of the output tensor. - const TfLiteIntArray* input_size = op_context->input->dims; - TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context->dims); + TfLiteIntArray* input_size = op_context->input->dims; + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); const int32* paddings_data = GetTensorData(op_context->paddings); for (int idx = 0; idx < op_context->dims; ++idx) { @@ -85,11 +82,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { PadContext op_context(context, node); TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - // TODO(nupurgarg): Create wrapper functions for dynamic tensor logic. + // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, op_context.dims, 4); + // Exit early if paddings is a non-const tensor. Set output tensor to // dynamic so output size can be determined in Eval. - if (op_context.paddings->allocation_type != kTfLiteMmapRo) { - op_context.output->allocation_type = kTfLiteDynamic; + if (!IsConstantTensor(op_context.paddings)) { + SetTensorToDynamic(op_context.output); return kTfLiteOk; } return ResizeOutputTensor(context, &op_context); @@ -100,9 +99,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { PadContext op_context(context, node); // Resize the output tensor if the output tensor is dynamic. - if (op_context.output->allocation_type == kTfLiteDynamic) { + if (IsDynamicTensor(op_context.output)) { TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); - TfLiteTensorRealloc(op_context.output->bytes, op_context.output); } // TODO(nupurgarg): Change kernel implementation to take in int* instead of @@ -178,9 +176,7 @@ TfLiteRegistration* Register_PAD_GENERIC_OPT() { return &r; } -TfLiteRegistration* Register_PAD() { - return Register_PAD_GENERIC_OPT(); -} +TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); } } // namespace builtin } // namespace ops diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index f605deaa5b4a3a8572c4be16cb1d301dbc49e5ba..edc4e26edbd44784f8604e7da32156a8e695d2e2 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -31,6 +31,7 @@ TfLiteRegistration* Register_CONV_2D(); TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); TfLiteRegistration* Register_SVDF(); TfLiteRegistration* Register_RNN(); +TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN(); TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN(); TfLiteRegistration* Register_EMBEDDING_LOOKUP(); TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE(); @@ -57,8 +58,11 @@ TfLiteRegistration* Register_SPACE_TO_DEPTH(); TfLiteRegistration* Register_GATHER(); TfLiteRegistration* Register_TRANSPOSE(); TfLiteRegistration* Register_MEAN(); +TfLiteRegistration* Register_SPLIT(); TfLiteRegistration* Register_SQUEEZE(); TfLiteRegistration* Register_STRIDED_SLICE(); +TfLiteRegistration* Register_EXP(); +TfLiteRegistration* Register_TOPK_V2(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -73,6 +77,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D()); AddBuiltin(BuiltinOperator_SVDF, Register_SVDF()); AddBuiltin(BuiltinOperator_RNN, Register_RNN()); + AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + Register_BIDIRECTIONAL_SEQUENCE_RNN()); AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, Register_UNIDIRECTIONAL_SEQUENCE_RNN()); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); @@ -103,8 +109,11 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_MEAN, Register_MEAN()); AddBuiltin(BuiltinOperator_DIV, Register_DIV()); AddBuiltin(BuiltinOperator_SUB, Register_SUB()); + AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT()); AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); + AddBuiltin(BuiltinOperator_EXP, Register_EXP()); + AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index 9a419af0238e1a25e4b9e81f109b54de6b49097b..9e3e19c09a4012ebdadbc2a7c2ba06c4bfefd206 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -36,6 +36,17 @@ constexpr int kInputTensor = 0; constexpr int kSizeTensor = 1; constexpr int kOutputTensor = 0; +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TfLiteTensor* input, + TfLiteTensor* size, TfLiteTensor* output) { + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + const int32* size_data = GetTensorData(size); + output_size->data[1] = size_data[0]; + output_size->data[2] = size_data[1]; + output_size->data[3] = input->dims->data[3]; + return context->ResizeTensor(context, output, output_size); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -55,32 +66,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // integers. output->type = kTfLiteFloat32; - // TODO(ahentz): if the input is constant, we can allocate here. - output->allocation_type = kTfLiteDynamic; - return kTfLiteOk; + if (!IsConstantTensor(size)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, input, size, output); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* size = GetInput(context, node, kSizeTensor); - // TODO(ahentz): we only need to do this here if it wasn't done in Eval(). - TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); - output_size->data[0] = input->dims->data[0]; - const int32* size_data = GetTensorData(size); - output_size->data[1] = size_data[0]; - output_size->data[2] = size_data[1]; - output_size->data[3] = input->dims->data[3]; - context->ResizeTensor(context, output, output_size); - TfLiteTensorRealloc(output->bytes, output); + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputTensor(context, input, size, output)); + } if (output->type == kTfLiteFloat32) { -#define TF_LITE_RESIZE_BILINEAR(type) \ - type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ - GetTensorData(size), GetTensorDims(size), \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_RESIZE_BILINEAR(type) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(size), GetTensorDims(size), \ + GetTensorData(output), GetTensorDims(output), \ + params->align_corners) if (kernel_type == kReference) { TF_LITE_RESIZE_BILINEAR(reference_ops); diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 2b1aaf654f87f435ec464b2cc1a63c77ba86ae5b..4e03f3820a5c14ee1692c553db61e385716b1723 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -25,14 +25,24 @@ using ::testing::ElementsAreArray; class ResizeBilinearOpModel : public SingleOpModel { public: - ResizeBilinearOpModel(std::initializer_list input_shape) { - input_ = AddInput(TensorType_FLOAT32); - size_ = AddInput(TensorType_INT32); - output_ = AddOutput(TensorType_FLOAT32); + ResizeBilinearOpModel(const TensorData& input, + std::initializer_list size_data = {}) { + bool const_size = size_data.size() != 0; + input_ = AddInput(input); + if (const_size) { + size_ = AddConstInput(TensorType_INT32, size_data, {2}); + } else { + size_ = AddInput({TensorType_INT32, {2}}); + } + output_ = AddOutput(TensorType_FLOAT32); // Always float. SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, CreateResizeBilinearOptions(builder_).Union()); - BuildInterpreter({input_shape, {2}}); + if (const_size) { + BuildInterpreter({GetShape(input_)}); + } else { + BuildInterpreter({GetShape(input_), GetShape(size_)}); + } } void SetInput(std::initializer_list data) { @@ -49,23 +59,33 @@ class ResizeBilinearOpModel : public SingleOpModel { }; TEST(ResizeBilinearOpTest, HorizontalResize) { - ResizeBilinearOpModel m({1, 1, 2, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } TEST(ResizeBilinearOpTest, VerticalResize) { - ResizeBilinearOpModel m({1, 2, 1, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } TEST(ResizeBilinearOpTest, TwoDimensionalResize) { - ResizeBilinearOpModel m({1, 2, 2, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); m.SetInput({ 3, 6, // 9, 12 // @@ -77,10 +97,22 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResize) { 7, 9, 10, // 9, 11, 12, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); } TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { - ResizeBilinearOpModel m({2, 2, 2, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); m.SetInput({ 3, 6, // 9, 12, // @@ -97,10 +129,27 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { 8, 12, 14, // 10, 14, 16, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { - ResizeBilinearOpModel m({1, 2, 2, 2}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // @@ -112,6 +161,18 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { 7, 8, 9, 12, 10, 14, // 9, 10, 11, 14, 12, 16, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc index 2e22d0db56a233bf554c57cf86275832ce941a18..d8c9e352f00627eee45ae836b720f2af77140538 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -33,17 +33,16 @@ enum KernelType { kGenericOptimized, }; -// Inputs specified in the 2nd tensor (block_shape) and 3rd tensor (paddings) -// are ignored. Only use the `block_shape` and `paddings` specified in params. -// TODO(nupurgarg): Support inputs as tensors in SpaceToBatchND. struct SpaceToBatchNDContext { SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); + block_shape = GetInput(context, node, 1); + paddings = GetInput(context, node, 2); output = GetOutput(context, node, 0); } - TfLiteSpaceToBatchNDParams* params; TfLiteTensor* input; + TfLiteTensor* block_shape; + TfLiteTensor* paddings; TfLiteTensor* output; }; @@ -51,32 +50,29 @@ struct SpaceToBatchNDContext { // The 4D array need to have exactly 2 spatial dimensions. // TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND. const int kInputDimensionNum = 4; -const int kOutputDimensionNum = 4; +const int kBlockSizeDimensionNum = 1; const int kSpatialDimensionNum = 2; -const int kPaddingDimensionNum = 4; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + SpaceToBatchNDContext* op_context) { + TfLiteIntArray* input_size = op_context->input->dims; + const int32* block_shape = GetTensorData(op_context->block_shape); + const int32* paddings_data = GetTensorData(op_context->paddings); - SpaceToBatchNDContext op_context(context, node); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), - kInputDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions, + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), + kBlockSizeDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0], + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings), kSpatialDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - - const TfLiteIntArray* input_size = op_context.input->dims; - const int* block_shape = op_context.params->block_shape; - TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum); + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); // Ensures the input height and width (with padding) is a multiple of block // shape height and width. for (int dim = 0; dim < kSpatialDimensionNum; ++dim) { - int final_dim_size = - (input_size->data[dim + 1] + op_context.params->before_paddings[dim] + - op_context.params->after_paddings[dim]); + int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] + + paddings_data[dim * 2 + 1]); TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0); output_size->data[dim + 1] = final_dim_size / block_shape[dim]; } @@ -88,33 +84,43 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size->data[0] = output_batch_size; output_size->data[3] = output_channel_size; - return context->ResizeTensor(context, op_context.output, output_size); + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + SpaceToBatchNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + if (!IsConstantTensor(op_context.block_shape) || + !IsConstantTensor(op_context.paddings)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { SpaceToBatchNDContext op_context(context, node); - int block_shape_dims_array[1] = {kSpatialDimensionNum}; - Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1); - - // Initialize padding array in the format accepted by the kernel code. - // TODO(nupurgarg): Make kernel code accept padding array format that is - // consistent with Pad operation (i.e. before_paddings and after_paddings). - TfLiteIntArray* padding_data = TfLiteIntArrayCreate(kPaddingDimensionNum); - padding_data->data[0] = op_context.params->before_paddings[0]; - padding_data->data[1] = op_context.params->after_paddings[0]; - padding_data->data[2] = op_context.params->before_paddings[1]; - padding_data->data[3] = op_context.params->after_paddings[1]; - int padding_dims_array[1] = {kPaddingDimensionNum}; - Dims<4> padding_dims = GetTensorDims(padding_dims_array, 1); - -#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ - type::SpaceToBatchND(GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), \ - op_context.params->block_shape, block_shape_dims, \ - padding_data->data, padding_dims, \ - GetTensorData(op_context.output), \ + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ + type::SpaceToBatchND(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + GetTensorData(op_context.block_shape), \ + GetTensorDims(op_context.block_shape), \ + GetTensorData(op_context.paddings), \ + GetTensorDims(op_context.paddings), \ + GetTensorData(op_context.output), \ GetTensorDims(op_context.output)) switch (op_context.input->type) { // Already know in/out types are same. case kTfLiteFloat32: @@ -151,8 +157,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } #undef TF_LITE_SPACE_TO_BATCH_ND - - TfLiteIntArrayFree(padding_data); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc index 45a6aef73d05b57a7f9a7fc6f58c3971c6e03118..92a4a037d5873e608ee7bdbdfc5eaa5e9b62bc8c 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc @@ -26,41 +26,81 @@ using ::testing::ElementsAreArray; class SpaceToBatchNDOpModel : public SingleOpModel { public: - SpaceToBatchNDOpModel(std::initializer_list input_shape, - std::initializer_list block_shape, - std::initializer_list before_paddings, - std::initializer_list after_paddings) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, - BuiltinOptions_SpaceToBatchNDOptions, - CreateSpaceToBatchNDOptions( - builder_, builder_.CreateVector(block_shape), - builder_.CreateVector(before_paddings), - builder_.CreateVector(after_paddings)) - .Union()); - BuildInterpreter({input_shape}); - } - void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetBlockShape(std::initializer_list data) { + PopulateTensor(block_shape_, data); + } + + void SetPaddings(std::initializer_list data) { + PopulateTensor(paddings_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; + int block_shape_; + int paddings_; int output_; }; +// Tests case where block_shape and paddings are const tensors. +// +// Example usage is as follows: +// SpaceToBatchNDOpConstModel m(input_shape, block_shape, paddings); +// m.SetInput(input_data); +// m.Invoke(); +class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { + public: + SpaceToBatchNDOpConstModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list paddings) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOptions_SpaceToBatchNDOptions, + CreateSpaceToBatchNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Tests case where block_shape and paddings are non-const tensors. +// +// Example usage is as follows: +// SpaceToBatchNDOpDynamicModel m(input_shape); +// m.SetInput(input_data); +// m.SetBlockShape(block_shape); +// m.SetPaddings(paddings); +// m.Invoke(); +class SpaceToBatchNDOpDynamicModel : public SpaceToBatchNDOpModel { + public: + SpaceToBatchNDOpDynamicModel(std::initializer_list input_shape) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddInput(TensorType_INT32); + paddings_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOptions_SpaceToBatchNDOptions, + CreateSpaceToBatchNDOptions(builder_).Union()); + BuildInterpreter({input_shape, {2}, {2, 2}}); + } +}; + TEST(SpaceToBatchNDOpTest, InvalidShapeTest) { - EXPECT_DEATH(SpaceToBatchNDOpModel({1, 3, 3, 1}, {2, 2}, {0, 0}, {0, 0}), + EXPECT_DEATH(SpaceToBatchNDOpConstModel({1, 3, 3, 1}, {2, 2}, {0, 0, 0, 0}), "Cannot allocate tensors"); } -TEST(SpaceToBatchNDOpTest, SimpleTest) { - SpaceToBatchNDOpModel m({1, 4, 4, 1}, {2, 2}, {0, 0}, {0, 0}); +TEST(SpaceToBatchNDOpTest, SimpleConstTest) { + SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); @@ -68,17 +108,39 @@ TEST(SpaceToBatchNDOpTest, SimpleTest) { 13, 15, 6, 8, 14, 16})); } -TEST(SpaceToBatchNDOpTest, MultipleInputBatches) { - SpaceToBatchNDOpModel m({2, 2, 4, 1}, {2, 2}, {0, 0}, {0, 0}); +TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) { + SpaceToBatchNDOpDynamicModel m({1, 4, 4, 1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetPaddings({0, 0, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) { + SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) { + SpaceToBatchNDOpDynamicModel m({2, 2, 4, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetPaddings({0, 0, 0, 0}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16})); } -TEST(SpaceToBatchNDOpTest, SimplePadding) { - SpaceToBatchNDOpModel m({1, 5, 2, 1}, {3, 2}, {1, 2}, {0, 0}); +TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) { + SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); @@ -88,9 +150,36 @@ TEST(SpaceToBatchNDOpTest, SimplePadding) { })); } -TEST(SpaceToBatchNDOpTest, ComplexPadding) { - SpaceToBatchNDOpModel m({1, 4, 2, 1}, {3, 2}, {1, 2}, {1, 4}); +TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) { + SpaceToBatchNDOpDynamicModel m({1, 5, 2, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 0, 2, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7, + 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10, + })); +} + +TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) { + SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, + 0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0, + 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, + })); +} + +TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { + SpaceToBatchNDOpDynamicModel m({1, 4, 2, 1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 1, 2, 4}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({ diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc new file mode 100644 index 0000000000000000000000000000000000000000..b524c79f8779b0119781679c0af9fe354e38ad4f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/split.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace split { + +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + axis = GetInput(context, node, 0); + input = GetInput(context, node, 1); + } + TfLiteSplitParams* params; + TfLiteTensor* axis; + TfLiteTensor* input; +}; + +TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) { + for (int i = 0; i < NumOutputs(node); ++i) { + SetTensorToDynamic(GetOutput(context, node, i)); + } + return kTfLiteOk; +} + +TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node, + TfLiteTensor* axis, TfLiteTensor* input, + int num_splits) { + int axis_value = GetTensorData(axis)[0]; + if (axis_value < 0) { + axis_value += NumDimensions(input); + } + + const int input_size = SizeOfDimension(input, axis_value); + TF_LITE_ENSURE_MSG(context, input_size % num_splits == 0, + "Not an even split"); + const int slice_size = input_size / num_splits; + + for (int i = 0; i < NumOutputs(node); ++i) { + TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims); + output_dims->data[axis_value] = slice_size; + TfLiteTensor* output = GetOutput(context, node, i); + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims)); + } + + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + + OpContext op_context(context, node); + + TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits); + + auto input_type = op_context.input->type; + TF_LITE_ENSURE(context, + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + for (int i = 0; i < NumOutputs(node); ++i) { + GetOutput(context, node, i)->type = input_type; + } + + // If we know the contents of the 'axis' tensor, resize all outputs. + // Otherwise, wait until Eval(). + if (IsConstantTensor(op_context.axis)) { + return ResizeOutputTensors(context, node, op_context.axis, op_context.input, + op_context.params->num_splits); + } else { + return UseDynamicOutputTensors(context, node); + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + + // When the 'axis' tensor is non-const we can't resize output tensors in + // Prepare(), and we have to do it now. + if (!IsConstantTensor(op_context.axis)) { + TF_LITE_ENSURE_OK( + context, + ResizeOutputTensors(context, node, op_context.axis, op_context.input, + op_context.params->num_splits)); + } + + int axis_value = GetTensorData(op_context.axis)[0]; + if (axis_value < 0) { + axis_value += NumDimensions(op_context.input); + } + axis_value = RemapDim(NumDimensions(op_context.input), axis_value); + + // TODO(ahentz): Our usage of VectorOfTensors could be optimized by + // calculating it in Prepare, unless we defer shape calculation. + // TODO(ahentz): We can improve the optimized_ops version to handle other + // cases too. +#define TF_LITE_SPLIT(scalar) \ + VectorOfTensors all_outputs(*context, *node->outputs); \ + if (axis_value == NumDimensions(op_context.input)) { \ + optimized_ops::TensorFlowSplit( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), NumOutputs(node), all_outputs.data(), \ + all_outputs.dims()); \ + } else { \ + reference_ops::TensorFlowSplit( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), axis_value, NumOutputs(node), \ + all_outputs.data(), all_outputs.dims()); \ + } + switch (op_context.input->type) { + case kTfLiteFloat32: { + TF_LITE_SPLIT(float); + break; + } + case kTfLiteUInt8: { + TF_LITE_SPLIT(uint8_t); + break; + } + default: + context->ReportError(context, + "Only float32 and uint8 are currently supported."); + return kTfLiteError; + } +#undef TF_LITE_SPLIT + + return kTfLiteOk; +} + +} // namespace split + +TfLiteRegistration* Register_SPLIT() { + static TfLiteRegistration r = {nullptr, nullptr, split::Prepare, split::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/split_test.cc b/tensorflow/contrib/lite/kernels/split_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..61a0759c6475795c06a9b55d3586d2b818f298b2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/split_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +constexpr int kAxisIsATensor = -1000; + +class SplitOpModel : public SingleOpModel { + public: + SplitOpModel(const TensorData& input, int num_splits, + int axis = kAxisIsATensor) { + if (axis == kAxisIsATensor) { + axis_ = AddInput({TensorType_INT32, {1}}); + } else { + axis_ = AddConstInput(TensorType_INT32, {axis}, {1}); + } + input_ = AddInput(input); + for (int i = 0; i < num_splits; ++i) { + outputs_.push_back(AddOutput(input.type)); + } + SetBuiltinOp(BuiltinOperator_SPLIT, BuiltinOptions_SplitOptions, + CreateSplitOptions(builder_, num_splits).Union()); + if (axis == kAxisIsATensor) { + BuildInterpreter({GetShape(axis_), GetShape(input_)}); + } else { + BuildInterpreter({{}, GetShape(input_)}); + } + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } + + std::vector GetOutput(int i) { + return ExtractVector(outputs_[i]); + } + std::vector GetOutputShape(int i) { return GetTensorShape(outputs_[i]); } + + private: + int input_; + int axis_; + std::vector outputs_; +}; + +using TensorValues = std::initializer_list; + +void Check(int axis, int num_splits, std::initializer_list input_shape, + std::initializer_list output_shape, + const TensorValues& input_data, + const std::vector& output_data) { + auto debug = [&](int i) { + std::stringstream ss; + ss << "for output tensor " << i << " axis=" << axis + << " and num_splits=" << num_splits; + return ss.str(); + }; + SplitOpModel m({TensorType_FLOAT32, input_shape}, num_splits); + m.SetInput(input_data); + m.SetAxis(axis); + m.Invoke(); + for (int i = 0; i < num_splits; ++i) { + EXPECT_THAT(m.GetOutput(i), ElementsAreArray(output_data[i])) << debug(i); + EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shape)) + << debug(i); + } + + SplitOpModel const_m({TensorType_FLOAT32, input_shape}, num_splits, axis); + const_m.SetInput(input_data); + const_m.Invoke(); + for (int i = 0; i < num_splits; ++i) { + EXPECT_THAT(const_m.GetOutput(i), ElementsAreArray(output_data[i])) + << debug(i); + EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shape)) + << debug(i); + } +} + +TEST(SplitOpTest, FourDimensional) { + Check(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }); + Check(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 9, 10, 11, 12}, + {5, 6, 7, 8, 13, 14, 15, 16}, + }); + Check(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 5, 6, 9, 10, 13, 14}, + {3, 4, 7, 8, 11, 12, 15, 16}, + }); + Check(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 3, 5, 7, 9, 11, 13, 15}, + {2, 4, 6, 8, 10, 12, 14, 16}, + }); +} + +TEST(SplitOpTest, OneDimensional) { + Check(/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8}, + {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}); +} + +TEST(SplitOpTest, NegativeAxis) { + Check(/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index 91ba4a9b7851c35a5138f4ccea307c810a4731a1..fb1e11e0ca00abb36d7f29d562711a7bbcbeca1c 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -57,65 +57,6 @@ struct StridedSliceContext { int dims; }; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - - StridedSliceContext op_context(context, node); - - // Ensure validity of input tensor and its dimension - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); - TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - // Only INT32 begin/end/strides are supported - // TODO(soroosh) add support for INT64 - TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); - TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); - TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); - TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, - "StridedSlice op only supports 1D-4D input arrays."); - - // TODO(soroosh): add the following missing functionalities - TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, - "ellipsis_mask is not implemented yet."); - TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, - "new_axis_mask is not implemented yet."); - TF_LITE_ENSURE_MSG(context, op_context.params->shrink_axis_mask == 0, - "shrink_axis_mask is not implemented yet."); - - // TODO(soroosh): optimize for constant tensors to do allocation in Prepare - op_context.output->allocation_type = kTfLiteDynamic; - return kTfLiteOk; -} // namespace strided_slice - -// TODO(soroosh): consolidate with BytesRequired in interpreter.h -TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type, - const int* dims, int dims_size, size_t* bytes) { - // TODO(aselle): Check for overflow here using overflow.h in TensorFlow - // MultiplyWithoutOverflow. - TF_LITE_ENSURE(context, bytes != nullptr); - size_t count = 1; - for (int k = 0; k < dims_size; k++) count *= dims[k]; - switch (type) { - case kTfLiteFloat32: - *bytes = sizeof(float) * count; - break; - case kTfLiteInt32: - *bytes = sizeof(int32_t) * count; - break; - case kTfLiteUInt8: - *bytes = sizeof(uint8_t) * count; - break; - case kTfLiteInt64: - *bytes = sizeof(int64_t) * count; - break; - default: - return kTfLiteError; - } - return kTfLiteOk; -} - // Reverse order of bits in the mask to match the expected order in kernel inline int ReverseMaskBits(int mask, int num_dimensions) { int out = 0; @@ -146,40 +87,110 @@ inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) { std::min(std::max(index, -dim), dim - 1), dim)); } +inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) { + const int dim = op_context->input->dims->data[idx]; + const bool pos_stride = GetTensorData(op_context->strides)[idx] > 0; + return op_context->params->begin_mask & (1 << idx) + ? pos_stride ? 0 : dim - 1 + : ClampedIndex(GetTensorData(op_context->begin)[idx], dim, + pos_stride); +} + +inline int32_t GetEndValueAtIndex(StridedSliceContext* op_context, int idx) { + const int dim = op_context->input->dims->data[idx]; + const bool pos_stride = GetTensorData(op_context->strides)[idx] > 0; + return op_context->params->end_mask & (1 << idx) + ? pos_stride ? dim : -1 + : ClampedIndex(GetTensorData(op_context->end)[idx], dim, + pos_stride); +} + +// Processes the indexing tensors (begin, end and strides) to resize the +// output tensor. This function is callable from both Prepare() and Eval() as +// long as the caller ensures the indexing tensors are present. +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + StridedSliceContext* op_context) { + std::vector output_shape_vector; + + for (int idx = op_context->dims - 1; idx >= 0; --idx) { + int32_t stride = GetTensorData(op_context->strides)[idx]; + TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); + + int32_t begin = GetBeginValueAtIndex(op_context, idx); + int32_t end = GetEndValueAtIndex(op_context, idx); + + // This is valid for both positive and negative strides + int32_t dim_shape = ceil((end - begin) / static_cast(stride)); + dim_shape = dim_shape < 0 ? 0 : dim_shape; + if (!(op_context->params->shrink_axis_mask & (1 << idx))) { + output_shape_vector.push_back(dim_shape); + } + } + + TfLiteIntArray* output_shape = + TfLiteIntArrayCreate(output_shape_vector.size()); + + std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(), + output_shape->data); + + TF_LITE_ENSURE_STATUS( + context->ResizeTensor(context, op_context->output, output_shape)); + + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + StridedSliceContext op_context(context, node); + + // Ensure validity of input tensor and its dimension + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + // Only INT32 begin/end/strides are supported + // TODO(soroosh) add support for INT64 + TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); + TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, + "StridedSlice op only supports 1D-4D input arrays."); + + // TODO(soroosh): add the following missing functionalities + TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, + "ellipsis_mask is not implemented yet."); + TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, + "new_axis_mask is not implemented yet."); + + // Postpone allocation of output if any of the indexing tensors is not + // constant + if (!(IsConstantTensor(op_context.begin) && + IsConstantTensor(op_context.end) && + IsConstantTensor(op_context.strides))) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); +} + template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { StridedSliceContext op_context(context, node); - std::vector starts; - std::vector stops; - std::vector strides; + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } - // Determine size of output tensor and map indices - TfLiteIntArray* output_shape = TfLiteIntArrayCreate(op_context.dims); - for (int idx = op_context.dims - 1; idx >= 0; --idx) { - int dim = op_context.input->dims->data[idx]; - int32_t stride = GetTensorData(op_context.strides)[idx]; - TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); - bool pos_stride = stride > 0; - - int32_t begin = - op_context.params->begin_mask & (1 << idx) - ? pos_stride ? 0 : dim - 1 - : ClampedIndex(GetTensorData(op_context.begin)[idx], dim, - pos_stride); - int32_t end = - op_context.params->end_mask & (1 << idx) - ? pos_stride ? dim : -1 - : ClampedIndex(GetTensorData(op_context.end)[idx], dim, - pos_stride); + std::vector starts; + std::vector stops; + std::vector strides; - // This is valid for both positive and negative strides - output_shape->data[idx] = ceil((end - begin) / static_cast(stride)); - output_shape->data[idx] = - output_shape->data[idx] < 0 ? 0 : output_shape->data[idx]; - starts.emplace_back(begin); - stops.emplace_back(end); - strides.emplace_back(stride); + for (int idx = op_context.dims - 1; idx >= 0; --idx) { + starts.emplace_back(GetBeginValueAtIndex(&op_context, idx)); + stops.emplace_back(GetEndValueAtIndex(&op_context, idx)); + strides.emplace_back(GetTensorData(op_context.strides)[idx]); } for (int i = op_context.dims; i < kMaxDim; i++) { @@ -188,27 +199,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { strides.emplace_back(1); } - TF_LITE_ENSURE_STATUS( - context->ResizeTensor(context, op_context.output, output_shape)); - - size_t required_bytes; - TF_LITE_ENSURE_OK( - context, - BytesRequired(context, op_context.output->type, output_shape->data, - output_shape->size, &required_bytes)); - TfLiteTensorRealloc(required_bytes, op_context.output); - op_context.params->begin_mask = ReverseMaskBits(op_context.params->begin_mask, op_context.dims); op_context.params->end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims); - -#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ - kernel_type::StridedSlice( \ - GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), op_context.params->begin_mask, \ - op_context.params->end_mask, starts, stops, strides, \ - GetTensorData(op_context.output), \ + op_context.params->shrink_axis_mask = + ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); + +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), op_context.params->begin_mask, \ + op_context.params->end_mask, op_context.params->shrink_axis_mask, \ + starts, stops, strides, GetTensorData(op_context.output), \ GetTensorDims(op_context.output)) switch (op_context.input->type) { diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc index cd4a364682c0e66b2ceec92c0b34461945caf779..5cac04b38364958c5b0794c21742e8b592372ae9 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -21,6 +21,7 @@ limitations under the License. namespace tflite { namespace { +using ::int32; using ::testing::ElementsAreArray; class StridedSliceOpModel : public SingleOpModel { @@ -79,8 +80,6 @@ TEST(StridedSliceOpTest, UnssupportedArgs) { "ellipsis_mask is not implemented yet."); EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), "new_axis_mask is not implemented yet."); - EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 0, 1), - "shrink_axis_mask is not implemented yet."); } TEST(StridedSliceOpTest, In1D) { @@ -213,6 +212,7 @@ TEST(StridedSliceOpTest, In1D_EndMask) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); } + TEST(StridedSliceOpTest, In1D_NegStride) { StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3}); @@ -234,6 +234,7 @@ TEST(StridedSliceOpTest, In1D_EvenLenStride2) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); } + TEST(StridedSliceOpTest, In1D_OddLenStride2) { StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3}); @@ -255,6 +256,7 @@ TEST(StridedSliceOpTest, In2D_Identity) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } + TEST(StridedSliceOpTest, In2D) { StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); @@ -320,6 +322,7 @@ TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); } + TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6}); @@ -354,6 +357,7 @@ TEST(StridedSliceOpTest, In3D_NegStride) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})); } + TEST(StridedSliceOpTest, In3D_Strided2) { StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); @@ -365,6 +369,159 @@ TEST(StridedSliceOpTest, In3D_Strided2) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5})); } +TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({2}); + m.SetEnd({1}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-2}); + m.SetEnd({-3}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 7, 8})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 7})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 72f705fe4242b01c1516c99d3500484e8729fd9a..c69755447d5093e25d408eb6dea80750937465e7 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -15,8 +15,8 @@ limitations under the License. #include #include #include -#include #include +#include #include #include diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index 4de2ceaf053df31a4bc857fb250db416c071e80f..0f166dc69b95f3459388135b3a6c4d9b73a31cb4 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite SVDF op. -#include #include +#include #include #include diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index 3a58e7ec321f649a6cae4cc0969807c2c74c6529..373310bd87370a670a847cf5328633956028a850 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -172,11 +172,14 @@ void SingleOpModel::BuildInterpreter( auto* model = GetModel(builder_.GetBufferPointer()); - ops::builtin::BuiltinOpResolver builtins; - for (const auto& reg : custom_registrations_) { - builtins.AddCustom(reg.first.data(), reg.second()); + if (!resolver_) { + auto resolver = new ops::builtin::BuiltinOpResolver(); + for (const auto& reg : custom_registrations_) { + resolver->AddCustom(reg.first.data(), reg.second()); + } + resolver_ = std::unique_ptr(resolver); } - InterpreterBuilder(model, builtins)(&interpreter_); + InterpreterBuilder(model, *resolver_)(&interpreter_); CHECK(interpreter_ != nullptr); @@ -184,6 +187,7 @@ void SingleOpModel::BuildInterpreter( for (const auto& shape : input_shapes) { int input_idx = interpreter_->inputs()[i++]; if (input_idx == kOptionalTensor) continue; + if (shape.empty()) continue; CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk); } CHECK(interpreter_->AllocateTensors() == kTfLiteOk) diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index cc445299ff9f0b75610c7ff38f28facbbbe5587d..7d476ba1eaffbb24fb77390c0e71c32d60b6411e 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -85,6 +85,23 @@ struct TensorData { int32_t zero_point; }; +class SingleOpResolver : public OpResolver { + public: + SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration) + : op_(op), registration_(registration) {} + TfLiteRegistration* FindOp(BuiltinOperator op) const override { + if (op == op_) { + return registration_; + } + return nullptr; + } + TfLiteRegistration* FindOp(const char* op) const override { return nullptr; } + + private: + const BuiltinOperator op_; + TfLiteRegistration* registration_; +}; + class SingleOpModel { public: SingleOpModel() {} @@ -178,11 +195,16 @@ class SingleOpModel { return result; } + void SetResolver(std::unique_ptr resolver) { + resolver_ = std::move(resolver); + } + protected: int32_t GetTensorSize(int index) const; flatbuffers::FlatBufferBuilder builder_; std::unique_ptr interpreter_; + std::unique_ptr resolver_; private: int AddTensor(TensorData t, std::initializer_list data); @@ -197,6 +219,36 @@ class SingleOpModel { std::map> custom_registrations_; }; +// Base class for single op unit tests. +// The tests are parameterized to test multiple kernels for a single op. +// The parameters are strings like "optimized" and "reference" to have better +// readability in test reports. +// +// To use this class: +// * Define a constant map from strings to TfLiteRegistration. +// * Implement a test class that inherits SingleOpTest. +// * Instantiate the test cases with SingleOpTest::GetKernelTags helper +// function. +// * Call GetRegistration to get the TfLiteRegistration to be used before +// building the interpreter. +class SingleOpTest : public ::testing::TestWithParam { + public: + static std::vector GetKernelTags( + const std::map& kernel_map) { + std::vector tags; + for (auto it : kernel_map) { + tags.push_back(it.first); + } + return tags; + } + + protected: + virtual const std::map& GetKernelMap() = 0; + TfLiteRegistration* GetRegistration() { + return GetKernelMap().at(GetParam()); + } +}; + // Strings have a special implementation that is in test_util.cc template <> std::vector SingleOpModel::ExtractVector(int index); diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc new file mode 100644 index 0000000000000000000000000000000000000000..807e84609f8b23d25324d99d26086331d78a0684 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/topk_v2.cc @@ -0,0 +1,232 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +namespace tflite { +namespace ops { +namespace builtin { +namespace topk_v2 { +constexpr int kInputTensor = 0; +constexpr int kInputTopK = 1; +constexpr int kOutputIndexes = 0; +constexpr int kOutputValues = 1; + +namespace { +TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* top_k = GetInput(context, node, kInputTopK); + // INT32 number of top results is supported. + TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32); + // Check that the tensor contains only one value. + TF_LITE_ENSURE_EQ(context, NumDimensions(top_k), 1); + TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1); + const int32 k = top_k->data.i32[0]; + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + const int num_dimensions = NumDimensions(input); + // Check that input has one or more dimensions. + TF_LITE_ENSURE_MSG(context, input->dims->size >= 1, + "TopK k input must have 1 or more dimensions."); + // Check that k is less or equal the internal dimension. + TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1], + "TopK k is higher than the internal dimension."); + + TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions); + TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions); + for (int i = 0; i < num_dimensions - 1; ++i) { + output_indexes_shape->data[i] = input->dims->data[i]; + output_values_shape->data[i] = input->dims->data[i]; + } + output_indexes_shape->data[num_dimensions - 1] = k; + output_values_shape->data[num_dimensions - 1] = k; + TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); + TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size, + TfLiteIntArray* delete_on_error) { + TfLiteStatus status = context->ResizeTensor(context, tensor, new_size); + if (status != kTfLiteOk) { + TfLiteIntArrayFree(new_size); + if (delete_on_error != nullptr) { + TfLiteIntArrayFree(delete_on_error); + } + } + return status; + }; + TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape, + output_values_shape)); + TF_LITE_ENSURE_OK(context, + resize_tensor(output_values, output_values_shape, nullptr)); + return kTfLiteOk; +} + +// The class that collects top indexes of k values. Based on template +// tensorflow::gtl::TopN<> but, for optimization, +// it re-uses the same container. +template +class TopContainer { + public: + TopContainer() = delete; + TopContainer(int32 k, int32 row_size) : k_(k) { + container_.reserve(std::min(k, row_size) + 1); + } + + void start_collecting(const T* values) { + values_ = values; + container_.clear(); + } + void push(int32 a) { + auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); }; + if (container_.size() <= k_) { + container_.push_back(a); + if (container_.size() == k_ + 1) { + std::make_heap(container_.begin(), container_.end(), comparator); + std::pop_heap(container_.begin(), container_.end(), comparator); + } + } else if (comparator(a, container_.front())) { + container_.back() = a; + std::push_heap(container_.begin(), container_.end(), comparator); + std::pop_heap(container_.begin(), container_.end(), comparator); + } + } + + const std::vector& sorted_result() { + auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); }; + if (container_.size() <= k_) { + std::sort(container_.begin(), container_.end(), comparator); + } else { + std::sort_heap(container_.begin(), container_.end() - 1, comparator); + container_.resize(k_); + } + return container_; + } + + private: + int32 k_; + std::vector container_; + const T* values_ = nullptr; + + bool compare_fun(int32 a, int32 b) const { + if (values_[b] < values_[a]) { + return true; + } else if (values_[b] > values_[a]) { + return false; + } else { + return a < b; + } + } +}; + +// Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU. +template +void TopK(int32 row_size, int32 num_rows, const T* data, int32 k, + int32* output_indexes, T* output_values) { + TopContainer topc(k, row_size); + for (int row = 0; row < num_rows; ++row) { + const T* values_row = data + row * row_size; + topc.start_collecting(values_row); + for (int32 c = 0; c < row_size; ++c) { + topc.push(c); + } + + // Prepare output buffers. + int32* indexes_row = output_indexes + row * k; + T* output_row = output_values + row * k; + // We always assume that the output is sorted. + const auto& top_k = topc.sorted_result(); + std::copy(top_k.begin(), top_k.end(), indexes_row); + std::transform(top_k.begin(), top_k.end(), output_row, + [values_row](const int32 loc) { return values_row[loc]; }); + } +} + +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check that the inputs and outputs have the right sizes and types. + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + TF_LITE_ENSURE_EQ(context, input->type, output_values->type); + + TfLiteTensor* top_k = GetInput(context, node, kInputTopK); + TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32); + + // Set output dynamic if the input is not const. + if (IsConstantTensor(top_k)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } else { + TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); + TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + SetTensorToDynamic(output_indexes); + SetTensorToDynamic(output_values); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); + if (IsDynamicTensor(output_values)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } + TfLiteTensor* top_k = GetInput(context, node, kInputTopK); + const int32 k = top_k->data.i32[0]; + // The tensor can have more than 2 dimensions or even be a vector, the code + // anyway calls the internal dimension as row; + TfLiteTensor* input = GetInput(context, node, kInputTensor); + const int32 row_size = input->dims->data[input->dims->size - 1]; + int32 num_rows = 1; + for (int i = 0; i < input->dims->size - 1; ++i) { + num_rows *= input->dims->data[i]; + } + switch (output_values->type) { + case kTfLiteFloat32: + TopK(row_size, num_rows, input->data.f, k, output_indexes->data.i32, + output_values->data.f); + break; + case kTfLiteUInt8: + TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32, + output_values->data.uint8); + break; + case kTfLiteInt32: + TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32, + output_values->data.i32); + break; + case kTfLiteInt64: + TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32, + output_values->data.i64); + break; + default: + context->ReportError(context, "Type is currently not supported by TopK."); + return kTfLiteError; + } + + return kTfLiteOk; +} +} // namespace topk_v2 +TfLiteRegistration* Register_TOPK_V2() { + static TfLiteRegistration r = {nullptr, nullptr, topk_v2::Prepare, + topk_v2::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..29f2a057cd45e1cded3ff1aa0f0fdcad666ce2fa --- /dev/null +++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc @@ -0,0 +1,155 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class TopKV2OpModel : public SingleOpModel { + public: + TopKV2OpModel(std::initializer_list input_shape, TensorType input_type, + int top_k) { + input_ = AddInput(input_type); + top_k_ = AddInput(TensorType_INT32); + output_indexes_ = AddOutput(TensorType_INT32); + output_values_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0); + BuildInterpreter({input_shape, {1}}); + PopulateTensor(top_k_, {top_k}); + } + + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt64(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetIndexes() { + return ExtractVector(output_indexes_); + } + + std::vector GetValuesFloat() { + return ExtractVector(output_values_); + } + + std::vector GetValuesUInt8() { + return ExtractVector(output_values_); + } + + std::vector GetValuesInt32() { + return ExtractVector(output_values_); + } + + std::vector GetValuesInt64() { + return ExtractVector(output_values_); + } + + protected: + int input_; + int top_k_; + int output_indexes_; + int output_values_; +}; + +// The test where the tensor dimension is equal to top. +TEST(TopKV2OpTest, EqualFloat) { + TopKV2OpModel m({2, 2}, TensorType_FLOAT32, 2); + m.SetInputFloat({-2.0, 0.2, 0.8, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({1, 0, 0, 1})); + EXPECT_THAT(m.GetValuesFloat(), + ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1}))); +} + +// Test when internal dimension is k+1. +TEST(TopKV2OpTest, BorderFloat) { + TopKV2OpModel m({2, 3}, TensorType_FLOAT32, 2); + m.SetInputFloat({-2.0, -3.0, 0.2, 0.8, 0.1, -0.1}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 0, 0, 1})); + EXPECT_THAT(m.GetValuesFloat(), + ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1}))); +} +// Test when internal dimension is higher than k. +TEST(TopKV2OpTest, LargeFloat) { + TopKV2OpModel m({2, 4}, TensorType_FLOAT32, 2); + m.SetInputFloat({-2.0, -3.0, -4.0, 0.2, 0.8, 0.1, -0.1, -0.8}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({3, 0, 0, 1})); + EXPECT_THAT(m.GetValuesFloat(), + ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1}))); +} + +// Test 1D case. +TEST(TopKV2OpTest, VectorFloat) { + TopKV2OpModel m({8}, TensorType_FLOAT32, 2); + m.SetInputFloat({-2.0, -3.0, -4.0, 0.2, 0.8, 0.1, -0.1, -0.8}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({4, 3})); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(ArrayFloatNear({0.8, 0.2}))); +} + +// Check that uint8 works. +TEST(TopKV2OpTest, TypeUint8) { + TopKV2OpModel m({2, 3}, TensorType_UINT8, 2); + m.SetInputUInt8({1, 2, 3, 251, 250, 249}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1})); + EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250})); +} + +// Check that int32 works. +TEST(TopKV2OpTest, TypeInt32) { + TopKV2OpModel m({2, 3}, TensorType_INT32, 2); + m.SetInputInt32({1, 2, 3, 10251, 10250, 10249}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1})); + EXPECT_THAT(m.GetValuesInt32(), ElementsAreArray({3, 2, 10251, 10250})); +} + +// Check that int64 works. +TEST(TopKV2OpTest, TypeInt64) { + TopKV2OpModel m({2, 3}, TensorType_INT64, 2); + m.SetInputInt64({1, 2, 3, -1, -2, -3}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1})); + EXPECT_THAT(m.GetValuesInt64(), ElementsAreArray({3, 2, -1, -2})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc index 75d8136b6a26efd805d9fc8e9db26dce2cfcfcb1..d3c10a9bb7b07404ccd8cfe2636473a622b91787 100644 --- a/tensorflow/contrib/lite/kernels/transpose.cc +++ b/tensorflow/contrib/lite/kernels/transpose.cc @@ -31,60 +31,77 @@ enum KernelType { kReference, }; -// TODO(nupurgarg): Permutation arrays represented as a tensor are ignored. Only -// use the `perm` specified in `params`. struct TransposeContext { TransposeContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); + perm = GetInput(context, node, 1); output = GetOutput(context, node, 0); } - TfLiteTransposeParams* params; TfLiteTensor* input; + TfLiteTensor* perm; TfLiteTensor* output; }; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + TransposeContext* op_context) { + int dims = NumDimensions(op_context->input); + const int* perm_data = GetTensorData(op_context->perm); - TransposeContext op_context(context, node); - int dims = NumDimensions(op_context.input); - - // Ensure validity of input tensor and permutation array. - TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions); - TF_LITE_ENSURE_MSG(context, dims <= 4, - "Transpose op only supports 1D-4D input arrays."); + // Ensure validity of the permutations tensor as a 1D tensor. + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->perm), 1); + TF_LITE_ENSURE_EQ(context, op_context->perm->dims->data[0], dims); for (int idx = 0; idx < dims; ++idx) { - TF_LITE_ENSURE_MSG(context, - op_context.params->perm[idx] >= 0 && - op_context.params->perm[idx] < dims, + TF_LITE_ENSURE_MSG(context, (perm_data[idx] >= 0 && perm_data[idx] < dims), "Transpose op permutations array is out of bounds."); } // Determine size of output tensor. - const TfLiteIntArray* input_size = op_context.input->dims; - TfLiteIntArray* output_size = TfLiteIntArrayCreate(dims); + TfLiteIntArray* input_size = op_context->input->dims; + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); for (int idx = 0; idx < dims; ++idx) { - output_size->data[idx] = input_size->data[op_context.params->perm[idx]]; + output_size->data[idx] = input_size->data[perm_data[idx]]; } - return context->ResizeTensor(context, op_context.output, output_size); + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TransposeContext op_context(context, node); + + // Ensure validity of input tensor. + TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 4, + "Transpose op only supports 1D-4D input arrays."); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + if (!IsConstantTensor(op_context.perm)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TransposeContext op_context(context, node); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + // Reverse the permuted axes and convert to 4D due to the way Dims are // constructed in GetTensorDims. + const int* perm_data = GetTensorData(op_context.perm); + const int size = op_context.perm->dims->data[0]; const int kOutputDimensionNum = 4; int reversed_perm[kOutputDimensionNum]; - int size = op_context.params->num_dimensions; + for (int output_k = 0, input_k = size - 1; output_k < size; ++output_k, --input_k) { - reversed_perm[output_k] = size - op_context.params->perm[input_k] - 1; + reversed_perm[output_k] = size - perm_data[input_k] - 1; } for (int k = size; k < kOutputDimensionNum; ++k) { reversed_perm[k] = k; diff --git a/tensorflow/contrib/lite/kernels/transpose_test.cc b/tensorflow/contrib/lite/kernels/transpose_test.cc index 7f5832cd5fa3d502b52bf5554111b45136b588ae..337bc144b967392523bf784603cca4c1b968cdf2 100644 --- a/tensorflow/contrib/lite/kernels/transpose_test.cc +++ b/tensorflow/contrib/lite/kernels/transpose_test.cc @@ -127,61 +127,124 @@ TEST(TransposeTest, TestRefOps4D) { class TransposeOpModel : public SingleOpModel { public: - TransposeOpModel(std::initializer_list input_shape, - std::initializer_list perm) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp( - BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions, - CreateTransposeOptions(builder_, builder_.CreateVector(perm)) - .Union()); - BuildInterpreter({input_shape}); - } - void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetPerm(std::initializer_list data) { + PopulateTensor(perm_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; + int perm_; int output_; }; +// Tests case where perm is a const tensor. +// +// Example usage is as follows: +// SpaceToBatchNDOpConstModel m(input_shape, perm_shape, perm_data); +// m.SetInput(input_data); +// m.Invoke(); +class TransposeOpConstModel : public TransposeOpModel { + public: + TransposeOpConstModel(std::initializer_list input_shape, + std::initializer_list perm_shape, + std::initializer_list perm) { + input_ = AddInput(TensorType_FLOAT32); + perm_ = AddConstInput(TensorType_INT32, perm, perm_shape); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions, + CreateTransposeOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Tests case where perm is a non-const tensor. +// +// Example usage is as follows: +// TransposeOpDynamicModel m(input_shape, perm_shape); +// m.SetInput(input_data); +// m.SetPerm(perm_data); +// m.Invoke(); +class TransposeOpDynamicModel : public TransposeOpModel { + public: + TransposeOpDynamicModel(std::initializer_list input_shape, + std::initializer_list perm_shape) { + input_ = AddInput(TensorType_FLOAT32); + perm_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions, + CreateTransposeOptions(builder_).Union()); + BuildInterpreter({input_shape, perm_shape}); + } +}; + TEST(TransposeTest, TestUnequalPermSize) { - EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {2, 2}), - "dims != op_context.params->num_dimensions"); + EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {2}, {2, 2}), "2 != 4"); } TEST(TransposeTest, TestPermOutOfBounds) { - EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {0, -1, -2, -3}), + EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {4}, {0, -1, -2, -3}), "Transpose op permutations array is out of bounds."); - EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {0, 1, 2, 4}), + EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {4}, {0, 1, 2, 4}), "Transpose op permutations array is out of bounds."); } -TEST(TransposeTest, Test1DInputTensor) { - TransposeOpModel m({3}, {0}); +TEST(TransposeTest, Test1DInputConstTensor) { + TransposeOpConstModel m({3}, {1}, {0}); m.SetInput({1, 2, 3}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); } -TEST(TransposeTest, Test2DInputTensor) { - TransposeOpModel m({3, 2}, {1, 0}); +TEST(TransposeTest, Test1DInputDynamicTensor) { + TransposeOpDynamicModel m({3}, {1}); + m.SetInput({1, 2, 3}); + m.SetPerm({0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(TransposeTest, Test2DInputConstTensor) { + TransposeOpConstModel m({3, 2}, {2}, {1, 0}); + m.SetInput({0, 1, 2, 3, 4, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5})); +} + +TEST(TransposeTest, Test2DInputDynamicTensor) { + TransposeOpDynamicModel m({3, 2}, {2}); m.SetInput({0, 1, 2, 3, 4, 5}); + m.SetPerm({1, 0}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5})); } -TEST(TransposeTest, Test3DInputTensor) { - TransposeOpModel m({2, 3, 4}, {2, 0, 1}); +TEST(TransposeTest, Test3DInputConstTensor) { + TransposeOpConstModel m({2, 3, 4}, {3}, {2, 0, 1}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, + 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23})); +} + +TEST(TransposeTest, Test3DInputDynamicTensor) { + TransposeOpDynamicModel m({2, 3, 4}, {3}); m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + m.SetPerm({2, 0, 1}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3})); EXPECT_THAT(m.GetOutput(), @@ -190,28 +253,64 @@ TEST(TransposeTest, Test3DInputTensor) { } TEST(TransposeTest, Test5DInputTensor) { - EXPECT_DEATH(TransposeOpModel({1, 2, 3, 4, 5}, {0, 1, 2, 3, 4}), + EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5}, {5}, {0, 1, 2, 3, 4}), "Transpose op only supports 1D-4D input arrays."); } -TEST(TransposeTest, SimpleTestNoReorder) { - TransposeOpModel m({1, 2, 3, 1}, {0, 1, 2, 3}); +TEST(TransposeTest, SimpleTestNoReorderConstTensor) { + TransposeOpConstModel m({1, 2, 3, 1}, {4}, {0, 1, 2, 3}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(TransposeTest, SimpleTestNoReorderDynamicTensor) { + TransposeOpDynamicModel m({1, 2, 3, 1}, {4}); m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetPerm({0, 1, 2, 3}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TEST(TransposeTest, SimpleTestWithReorder) { - TransposeOpModel m({1, 2, 3, 1}, {2, 1, 3, 0}); +TEST(TransposeTest, SimpleTestWithReorderConstTensor) { + TransposeOpConstModel m({1, 2, 3, 1}, {4}, {2, 1, 3, 0}); m.SetInput({1, 2, 3, 4, 5, 6}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2, 1, 1})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); } -TEST(TransposeTest, ComplexTestWithReorder) { - TransposeOpModel m({2, 3, 4, 5}, {2, 0, 1, 3}); +TEST(TransposeTest, ComplexTestWithReorderConstTensor) { + TransposeOpConstModel m({2, 3, 4, 5}, {4}, {2, 0, 1, 3}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, + 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, + 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119}); + m.Invoke(); + + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3, 5})); + auto result = ElementsAreArray( + {0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44, + 60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104, + 5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49, + 65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109, + 10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50, 51, 52, 53, 54, + 70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114, + 15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55, 56, 57, 58, 59, + 75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119}); + EXPECT_THAT(m.GetOutput(), result); +} + +TEST(TransposeTest, ComplexTestWithReorderDynamicTensor) { + TransposeOpDynamicModel m({2, 3, 4, 5}, {4}); m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, @@ -222,6 +321,7 @@ TEST(TransposeTest, ComplexTestWithReorder) { 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119}); + m.SetPerm({2, 0, 1, 3}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3, 5})); diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index f5f1ec2cf3f45ae730b849b18e2b85fac50159c7..ac00c37b67dcbe77023a2495a698967ca555b1d5 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -15,14 +15,15 @@ limitations under the License. #include #include #include -#include #include +#include #include #include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -82,48 +83,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size_array->data[0] = (time_major) ? max_time : batch_size; output_size_array->data[1] = (time_major) ? batch_size : max_time; output_size_array->data[2] = num_units; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, - output_size_array)); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); return kTfLiteOk; } -namespace { -void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr, - const float* recurrent_weights_ptr, const float* bias_ptr, - int input_size, int num_units, int input_weights_stride, - int recurrent_weights_stride, TfLiteFusedActivation activation, - float* hidden_state_ptr_batch, float* output_ptr_batch) { - // Output = bias - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = bias_ptr[o]; - } - - // Output += input * input_weights - for (int o = 0; o < num_units; o++) { - for (int i = 0; i < input_size; i++) { - output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; - } - input_weights_ptr += input_weights_stride; - } - - // Output += recurrent_weights * hidden_state - for (int o = 0; o < num_units; o++) { - for (int h = 0; h < num_units; h++) { - output_ptr_batch[o] += - hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; - } - recurrent_weights_ptr += recurrent_weights_stride; - } - - // Output = activation(Output) and update hidden_state - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]); - hidden_state_ptr_batch[o] = output_ptr_batch[o]; - } -} -} // namespace - TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -147,30 +112,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { (time_major) ? input->dims->data[0] : input->dims->data[1]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[2]; - const int input_weights_stride = input_weights->dims->data[1]; - const int recurrent_weights_stride = recurrent_weights->dims->data[1]; // Initialize input_weights and recurrent_weights. const float* input_weights_ptr = input_weights->data.f; const float* recurrent_weights_ptr = recurrent_weights->data.f; if (time_major) { - // Unroll the sequence + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f; + // Unroll the sequence and use batch batch operations for efficiency. for (int s = 0; s < max_time; s++) { - for (int b = 0; b < batch_size; b++) { - // Initialize the pointer to hidden state. - float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; - // Initialize the pointer to input and output. - const float* input_ptr_batch = - input->data.f + s * input_size * batch_size + b * input_size; - float* output_ptr_batch = - output->data.f + s * num_units * batch_size + b * num_units; - - RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, - bias_ptr, input_size, num_units, input_weights_stride, - recurrent_weights_stride, params->activation, - hidden_state_ptr_batch, output_ptr_batch); - } + // Initialize the pointer to input and output. + const float* input_ptr_batch = + input->data.f + s * input_size * batch_size; + float* output_ptr_batch = output->data.f + s * num_units * batch_size; + + kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr, + recurrent_weights_ptr, bias_ptr, input_size, + num_units, batch_size, params->activation, + hidden_state_ptr_batch, output_ptr_batch); } } else { // For each batch @@ -184,10 +144,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float* output_ptr_batch = output->data.f + b * num_units * max_time + s * num_units; - RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, - bias_ptr, input_size, num_units, input_weights_stride, - recurrent_weights_stride, params->activation, - hidden_state_ptr_batch, output_ptr_batch); + kernel_utils::RnnBatchStep( + input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr, + input_size, num_units, /*batch_size=*/1, params->activation, + hidden_state_ptr_batch, output_ptr_batch); } } } diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc index 82c680ec3d8656004d721c8498292677cb061b6b..7e32969763b59620dc3534708f965750680002d2 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite Sequential RNN op. -#include #include +#include #include #include @@ -120,8 +120,7 @@ static float rnn_golden_output[] = { 0.415153, 0.210318, 0, 0, 0, 0, 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, - 0.628881, 3.58099, 1.49974, 0 -}; + 0.628881, 3.58099, 1.49974, 0}; class UnidirectionalRNNOpModel : public SingleOpModel { public: diff --git a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh new file mode 100755 index 0000000000000000000000000000000000000000..b58ae266017caf8781c28331f49a8f5bc1550767 --- /dev/null +++ b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh @@ -0,0 +1,81 @@ +#!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e + +echo "Starting" +TFLITE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/.." + +TMP_DIR=$(mktemp -d) +echo "Package dir: " $TMP_DIR +FW_DIR=$TMP_DIR/tensorflow_lite_ios_frameworks +FW_DIR_TFLITE=$FW_DIR/tensorflow_lite.framework +FW_DIR_TFLITE_HDRS=$FW_DIR_TFLITE/Headers + +echo "Creating target Headers directories" +mkdir -p $FW_DIR_TFLITE_HDRS + +echo "Headers, populating: TensorFlow Lite" +cd $TFLITE_DIR/../../.. + +find tensorflow/contrib/lite -name '*.h' \ + -not -path 'tensorflow/contrib/lite/downloads/*' \ + -not -path 'tensorflow/contrib/lite/examples/*' \ + -not -path 'tensorflow/contrib/lite/gen/*' \ + -not -path 'tensorflow/contrib/lite/toco/*' \ + -not -path 'tensorflow/contrib/lite/nnapi/*' \ + -not -path 'tensorflow/contrib/lite/java/*' \ + | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T - +cd $FW_DIR_TFLITE_HDRS +tar xf tmp.tar +rm -f tmp.tar + +echo "Headers, populating: Flatbuffer" +cd $TFLITE_DIR/downloads/flatbuffers/include/ +find . -name '*.h' | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T - +cd $FW_DIR_TFLITE_HDRS +tar xf tmp.tar +rm -f tmp.tar + +cd $TFLITE_DIR/../../.. +echo "Generate master LICENSE file and copy to target" +bazel build //tensorflow/tools/lib_package:clicenses_generate +cp $TFLITE_DIR/../../../bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE \ + $FW_DIR_TFLITE + +echo "Copying static libraries" +cp $TFLITE_DIR/gen/lib/libtensorflow-lite.a \ + $FW_DIR_TFLITE/tensorflow_lite + +# This is required, otherwise they interfere with the documentation of the +# pod at cocoapods.org. +echo "Remove all README files" +cd $FW_DIR_TFLITE_HDRS +find . -type f -name README\* -exec rm -f {} \; +find . -type f -name readme\* -exec rm -f {} \; + +TARGET_GEN_LOCATION="$TFLITE_DIR/gen/ios_frameworks" +echo "Moving results to target: " $TARGET_GEN_LOCATION +cd $FW_DIR +zip -q -r tensorflow_lite.framework.zip tensorflow_lite.framework -x .DS_Store +rm -rf $TARGET_GEN_LOCATION +mkdir -p $TARGET_GEN_LOCATION +cp -r tensorflow_lite.framework.zip $TARGET_GEN_LOCATION + +echo "Cleaning up" +rm -rf $TMP_DIR + +echo "Finished" diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 303a10af03e582d5e4e641c15072e1c9d594e1f4..d6522fc077d03bb49fe54b7a04fa6341ccf4cf3a 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -30,17 +30,6 @@ limitations under the License. namespace tflite { -namespace { -inline const tflite::Model* VerifyAndGetModel(const void* buf, size_t len) { - ::flatbuffers::Verifier verifier(static_cast(buf), len); - if (VerifyModelBuffer(verifier)) { - return ::tflite::GetModel(buf); - } else { - return nullptr; - } -} -} // namespace - const char* kEmptyTensorName = ""; std::unique_ptr FlatBufferModel::BuildFromFile( @@ -82,7 +71,7 @@ FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, } if (!allocation_->valid() || !CheckModelIdentifier()) return; - model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); + model_ = ::tflite::GetModel(allocation_->base()); } bool FlatBufferModel::CheckModelIdentifier() const { @@ -103,7 +92,7 @@ FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter); if (!allocation_->valid()) return; - model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); + model_ = ::tflite::GetModel(allocation_->base()); } FlatBufferModel::FlatBufferModel(const Model* model, @@ -147,7 +136,7 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { } } else if (!opcode->custom_code()) { error_reporter_->Report( - "Operator with builtin_code==0 has no custom_code.\n"); + "Operator with CUSTOM builtin_code has no custom_code.\n"); status = kTfLiteError; } else { const char* name = opcode->custom_code()->c_str(); @@ -289,6 +278,8 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_RELU_N1_TO_1: case BuiltinOperator_RELU6: case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_EXP: + case BuiltinOperator_TOPK_V2: break; case BuiltinOperator_LSH_PROJECTION: { TfLiteLSHProjectionParams* params = @@ -339,6 +330,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { TfLiteSequenceRNNParams* params = MallocPOD(); if (auto* sequence_rnn_params = @@ -476,6 +468,12 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_RESIZE_BILINEAR: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_ResizeBilinearOptions()) { + params->align_corners = schema_params->align_corners(); + } + builtin_data = reinterpret_cast(params); break; } case BuiltinOperator_PAD: { @@ -521,62 +519,26 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_SPACE_TO_BATCH_ND: { - auto* params = MallocPOD(); - if (auto* schema_params = - op->builtin_options_as_SpaceToBatchNDOptions()) { - const auto& block_shape = schema_params->block_shape(); - FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape, - params->block_shape, error_reporter); - const auto& before_paddings = schema_params->before_paddings(); - FlatBufferIntVectorToArray(sizeof(params->before_paddings), - before_paddings, params->before_paddings, - error_reporter); - const auto& after_paddings = schema_params->after_paddings(); - FlatBufferIntVectorToArray(sizeof(params->after_paddings), - after_paddings, params->after_paddings, - error_reporter); - params->num_spatial_dimensions = block_shape->Length(); - } - builtin_data = reinterpret_cast(params); break; } case BuiltinOperator_BATCH_TO_SPACE_ND: { - auto* params = MallocPOD(); - if (auto* schema_params = - op->builtin_options_as_BatchToSpaceNDOptions()) { - const auto& block_shape = schema_params->block_shape(); - FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape, - params->block_shape, error_reporter); - const auto& before_crops = schema_params->before_crops(); - FlatBufferIntVectorToArray(sizeof(params->before_crops), before_crops, - params->before_crops, error_reporter); - const auto& after_crops = schema_params->after_crops(); - FlatBufferIntVectorToArray(sizeof(params->after_crops), after_crops, - params->after_crops, error_reporter); - params->num_spatial_dimensions = block_shape->Length(); - } - builtin_data = reinterpret_cast(params); break; } case BuiltinOperator_TRANSPOSE: { - auto* params = MallocPOD(); - if (auto* schema_params = op->builtin_options_as_TransposeOptions()) { - const auto& perm = schema_params->perm(); - FlatBufferIntVectorToArray(sizeof(params->perm), perm, params->perm, - error_reporter); - params->num_dimensions = perm->Length(); - } - builtin_data = reinterpret_cast(params); break; } case BuiltinOperator_MEAN: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_MeanOptions()) { - const auto& axis = schema_params->axis(); - FlatBufferIntVectorToArray(sizeof(params->axis), axis, params->axis, - error_reporter); params->keep_dims = schema_params->keep_dims(); - params->num_axis_dimensions = axis->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SPLIT: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_SplitOptions()) { + params->num_splits = schema_params->num_splits(); } builtin_data = reinterpret_cast(params); break; diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index 5330c8f594593655b2a8776cf6b399c0d16cdc19..66f22fd66a9ae0d35553a1f780ef73a5c5994c99 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include "tensorflow/contrib/lite/model.h" @@ -247,14 +246,6 @@ TEST(BasicFlatBufferModel, TestNullErrorReporter) { ASSERT_NE(interpreter->Invoke(), kTfLiteOk); } -// Test what happens if we cannot bind any of the ops. -TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { - std::string corrupted_data = "123"; - auto model = FlatBufferModel::BuildFromBuffer(corrupted_data.c_str(), - corrupted_data.length()); - ASSERT_FALSE(model); -} - // Test that loading model directly from a Model flatbuffer works. TEST(BasicFlatBufferModel, TestBuildFromModel) { TestErrorReporter reporter; diff --git a/tensorflow/contrib/lite/models/BUILD b/tensorflow/contrib/lite/models/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..6a1255b586ef04b80159156a78f0c4569a4661c5 --- /dev/null +++ b/tensorflow/contrib/lite/models/BUILD @@ -0,0 +1,26 @@ +# Model tests +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +exports_files(glob([ + "testdata/*", +])) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h index 7019c29959fc02f4f84d1e4c8cf280751e585de0..76032771af2c8e099aed498b2071816646f3b606 100644 --- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h +++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h @@ -1571,7 +1571,7 @@ inline int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model, } /** - * Specfifies which operands will be the model's inputs and outputs. + * Specifies which operands will be the model's inputs and outputs. * * An operand cannot be used for both input and output. Doing so will * return an error. diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index d5b9319407a461c571411c44ae702c137c914fa9..02e8499f61c6a3d5fceb978aa0e63a4ee90cf19a 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -319,6 +319,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SVDF: case tflite::BuiltinOperator_HASHTABLE_LOOKUP: case tflite::BuiltinOperator_RNN: + case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_EMBEDDING_LOOKUP: case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: @@ -334,12 +335,15 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_GATHER: case tflite::BuiltinOperator_SPACE_TO_BATCH_ND: case tflite::BuiltinOperator_BATCH_TO_SPACE_ND: + case tflite::BuiltinOperator_TOPK_V2: case tflite::BuiltinOperator_TRANSPOSE: case tflite::BuiltinOperator_MEAN: case tflite::BuiltinOperator_DIV: case tflite::BuiltinOperator_SUB: + case tflite::BuiltinOperator_SPLIT: case tflite::BuiltinOperator_SQUEEZE: case tflite::BuiltinOperator_STRIDED_SLICE: + case tflite::BuiltinOperator_EXP: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 3d6a3ec0fd4c673f601254b19452bbf8b9454e27..82feae0f0041997949212613c654a5695f468d56 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -13,6 +13,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":op_hint", "//tensorflow/contrib/lite/toco:model_flags_proto_py", "//tensorflow/contrib/lite/toco:toco_flags_proto_py", "//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco", @@ -20,6 +21,18 @@ py_library( ], ) +py_library( + name = "op_hint", + srcs = ["op_hint.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform", + ], +) + py_test( name = "lite_test", srcs = ["lite_test.py"], @@ -27,6 +40,7 @@ py_test( tags = ["no_oss"], deps = [ ":lite", + ":op_hint", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 3c369774beda57cca3bc1ea0ab9a9ad619841e7e..5d2f21653762a405a57288a7ba38323e5e42b3e1 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -18,16 +18,21 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. @@toco_convert @@toco_convert_protos +@@OpHint +@@convert_op_hints_to_stubs """ from __future__ import absolute_import from __future__ import division from __future__ import print_function - import os import subprocess import tempfile +# pylint: disable=unused-import +from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs +from tensorflow.contrib.lite.python.op_hint import OpHint +# pylint: enable=unused-import from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 7d55f3fe6fe41a5d9e4e57c7a8e664bba6887fc7..b8b4510188bee867b32ffde714b27f41a1df778a 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -18,10 +18,14 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python.op_hint import _tensor_name_base as _tensor_name_base from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util +from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes +from tensorflow.python.framework.graph_util_impl import _extract_graph_summary from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -35,7 +39,8 @@ class LiteTest(test_util.TensorFlowTestCase): # Try running on valid graph result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor]) self.assertTrue(result) - # TODO(aselle): remove tests that fail. + # TODO(aselle): remove tests that fail (we must get TOCO to not fatal + # all the time). # Try running on identity graph (known fail) # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"): # result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor]) @@ -51,5 +56,116 @@ class LiteTest(test_util.TensorFlowTestCase): quantized_input_stats=[(0., 1.)]) self.assertTrue(result) + +class LiteTestOpHint(test_util.TensorFlowTestCase): + """Test the hint to stub functionality.""" + + def _getGraphOpTypes(self, graphdef, output_nodes): + """Returns used op types in `graphdef` reachable from `output_nodes`. + + This is used to check that after the stub transformation the expected + nodes are there. Typically use this with self.assertCountEqual(...). + + NOTE: this is not a exact test that the graph is the correct output, but + it balances compact expressibility of test with sanity checking. + + Args: + graphdef: TensorFlow proto graphdef. + output_nodes: A list of output node names that we need to reach. + + Returns: + A set of node types reachable from `output_nodes`. + """ + name_to_input_name, name_to_node, _ = ( + _extract_graph_summary(graphdef)) + # Find all nodes that are needed by the outputs + used_node_names = _bfs_for_reachable_nodes(output_nodes, name_to_input_name) + return set([name_to_node[node_name].op for node_name in used_node_names]) + + def _countIdentities(self, nodes): + """Count the number of "Identity" op types in the list of proto nodes. + + Args: + nodes: NodeDefs of the graph. + + Returns: + The number of nodes with op type "Identity" found. + """ + return len([x for x in nodes if x.op == "Identity"]) + + def testSwishLiteHint(self): + """Makes a custom op swish and makes sure it gets converted as a unit.""" + image = array_ops.constant([1., 2., 3., 4.]) + swish_scale = array_ops.constant(1.0) + + def _swish(input_tensor, scale): + custom = lite.OpHint("cool_activation") + input_tensor, scale = custom.add_inputs(input_tensor, scale) + output = math_ops.sigmoid(input_tensor) * input_tensor * scale + output, = custom.add_outputs(output) + return output + output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput") + + with self.test_session() as sess: + # check if identities have been put into the graph (2 input, 1 output, + # and 1 final output). + self.assertEqual(self._countIdentities(sess.graph_def.node), 4) + + stubbed_graphdef = lite.convert_op_hints_to_stubs(sess) + + self.assertCountEqual( + self._getGraphOpTypes( + stubbed_graphdef, output_nodes=[_tensor_name_base(output)]), + ["cool_activation", "Const", "Identity"]) + + def testScaleAndBiasAndIdentity(self): + """This tests a scaled add which has 3 inputs and 2 outputs.""" + a = array_ops.constant(1.) + x = array_ops.constant([2., 3.]) + b = array_ops.constant([4., 5.]) + + def _scaled_and_bias_and_identity(a, x, b): + custom = lite.OpHint("scale_and_bias_and_identity") + a, x, b = custom.add_inputs(a, x, b) + return custom.add_outputs(a * x + b, x) + output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b), + name="ModelOutput") + + with self.test_session() as sess: + # make sure one identity for each input (3) and output (2) => 3 + 2 = 5 + # +1 for the final output + self.assertEqual(self._countIdentities(sess.graph_def.node), 6) + + stubbed_graphdef = lite.convert_op_hints_to_stubs(sess) + + self.assertCountEqual( + self._getGraphOpTypes( + stubbed_graphdef, output_nodes=[_tensor_name_base(output)]), + ["scale_and_bias_and_identity", "Const", "Identity", "Pack"]) + + def testTwoFunctions(self): + """Tests if two functions are converted correctly.""" + a = array_ops.constant([1.]) + b = array_ops.constant([1.]) + def _double_values(x): + custom = lite.OpHint("add_test") + x = custom.add_inputs(x) + output = math_ops.multiply(x, x) + output, = custom.add_outputs(output) + return output + output = array_ops.identity( + math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput") + + with self.test_session() as sess: + # make sure one identity for each input (2) and output (2) => 2 + 2 + # +1 for the final output + self.assertEqual(self._countIdentities(sess.graph_def.node), 5) + stubbed_graphdef = lite.convert_op_hints_to_stubs(sess) + self.assertCountEqual( + self._getGraphOpTypes( + stubbed_graphdef, output_nodes=[_tensor_name_base(output)]), + ["add_test", "Const", "Identity", "Add"]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3971228a683211e84b4c55d3a3e8d574b5ed94 --- /dev/null +++ b/tensorflow/contrib/lite/python/op_hint.py @@ -0,0 +1,306 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Define tflite op hints (intrinsic operations). + +This essentially allows defining a TensorFlow API for tflite operations in +Python with hints on how they are represented in TensorFlow Lite. This basically +is a form of tflite intrinsic. It wraps a subpart of a TensorFlow execution +graph and is useful for LSTMs and other complicated TensorFlow constructions +that are difficult to pattern match in TOCO, but are represented by a single +accelerated tflite op. + +Example: + def tflite_cool_activation(input): + # A cool activation function. + custom = tf.contrib.lite.OpHint("cool_activation") + input = custom.add_inputs(input) + output = tf.sigmoid(input) * input + custom.add_outputs(output) + return output + + image = tf.placeholder(tf.float32, (1, 16, 16, 1)) + output = tf.identity(tflite_cool_activation(image)) + + session = tf.Session() + + graphdef_to_convert = tf.contrib.lite.convert_op_hints_to_stubs(session) + tflite_graph = tf.contrib.lite.toco_convert(graphdef_to_convert, + [image], [output]) + [image], [output]) + with open("/tmp/graph.fb", "wb") as fp: + fp.write(tflite_graph) + +How does it work?: + +OpHint is a helper that you use when defining a vanilla python function. +It allows you to wrap arguments with tf.identities with some custom attributes. +These attributes allow you to find the original block of ops that was created. +For example, if you use cool_activation above you essentially get: + +a_input = tf.identity() +result = tf.multiply(tf.sigmoid(a_input), a_input) +output = tf.identity() + +a_input, output are identities that have parameters representing +what argument they are, what the name of the function they should turn into +in tf lite as well as a guid that uniquely identifies a particular invocation. + +Once you have built your whole tensorflow graph, you can run it and train it +as usual, but after you have done that, you need to convert the graph into +a form that replaces these subgraphs wrapped in identities to stub ops. These +ops don't actually exist in the normal TensorFlow runtime, but will be +understood by toco later. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections as _collections +import itertools as _itertools +import uuid as _uuid + +from tensorflow.contrib import framework as _framework +from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2 +from tensorflow.python.framework import ops as _ops +from tensorflow.python.ops import array_ops as _array_ops +from tensorflow.python.util.all_util import remove_undocumented + + +class OpHint(object): + """A class that helps build tflite function invocations. + + It allows you to take a bunch of TensorFlow ops and annotate the construction + such that toco knows how to convert it to tflite. This embeds a pseudo + function in a TensorFlow graph. This allows embedding high-level API usage + information in a lower level TensorFlow implementation so that an alternative + implementation can be substituted later. + + Essentially, any "input" into this pseudo op is fed into an identity, and + attributes are added to that input before being used by the constituent ops + that make up the pseudo op. A similar process is done to any output that + is to be exported from the current op. + + TODO(aselle): When TensorFlow functions functionality works for arbitrary + constructs, this mechanism can be retired and changed to use python defun's. + """ + + # Attr constants that are used for representation in the GraphDef + FUNCTION_NAME_ATTR = "_tflite_function_name" + FUNCTION_UUID_ATTR = "_tflite_function_uuid" + FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index" + FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index" + + def __init__(self, function_name, **kwargs): + """Create a OpHint. + + Args: + function_name: Name of the function (the custom op name in tflite) + **kwargs: Keyword arguments of any constant attributes for the function. + """ + self._function_name = function_name + self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough? + self._curr_input_index = 0 + self._curr_output_index = 0 + self._attrs_to_store_later = kwargs + self._stored_attrs = False + + def _setattr(self, dest_op, name, value): + tensor_value = _ops.convert_to_tensor(value) + dest_op.op.node_def.attr[name].tensor.CopyFrom( + tensor_value.op.node_def.attr["value"].tensor) + + def add_inputs(self, *args): + """Add a sequence of inputs to the function invocation. + + Args: + *args: List of inputs to be converted (should be Tf.Tensor). + Returns: + Wrapped inputs (identity standins that have additional metadata). These + are also are also tf.Tensor's. + """ + + def augmented_identity(arg): + identity_op = _array_ops.identity(arg) + # pylint: disable=protected-access + identity_op.op._set_attr( + OpHint.FUNCTION_NAME_ATTR, + _attr_value_pb2.AttrValue(s=self._function_name)) + identity_op.op._set_attr( + OpHint.FUNCTION_UUID_ATTR, + _attr_value_pb2.AttrValue(s=self._unique_function_id)) + identity_op.op._set_attr( + OpHint.FUNCTION_INPUT_INDEX_ATTR, + _attr_value_pb2.AttrValue(i=self._curr_input_index)) + # pylint: enable=protected-access + self._curr_input_index += 1 + return identity_op + + return [augmented_identity(arg) for arg in args] + + def add_outputs(self, *args): + """Add a sequence of outputs to the function invocation. + + Args: + *args: List of outputs to be converted (should be tf.Tensor). + Returns: + Wrapped outputs (identity standins that have additional metadata). These + are also tf.Tensor's. + """ + + def augmented_identity(arg): + identity_op = _array_ops.identity(arg) + # pylint: disable=protected-access + identity_op.op._set_attr( + OpHint.FUNCTION_NAME_ATTR, + _attr_value_pb2.AttrValue(s=self._function_name)) + identity_op.op._set_attr( + OpHint.FUNCTION_UUID_ATTR, + _attr_value_pb2.AttrValue(s=self._unique_function_id)) + identity_op.op._set_attr( + OpHint.FUNCTION_OUTPUT_INDEX_ATTR, + _attr_value_pb2.AttrValue(i=self._curr_output_index)) + # pylint: enable=protected-access + self._curr_output_index += 1 + return identity_op + + wrapped_outputs = [augmented_identity(arg) for arg in args] + + if not self._stored_attrs: + for key, value in self._attrs_to_store_later.iteritems(): + self._setattr(wrapped_outputs[0], "_tflite_attr_" + key, value) + self._stored_attrs = True + + return wrapped_outputs + + +class _LiteFuncCall(object): + """Represent a TensorFlow Lite custom function. + + This is uses to accumulate found hints in the graphdef into a single + conceptual unit. + + Properties: + self.inputs: inputs to the op (hash from index # to argument) + self.outputs: outputs to the op (hash from index # to argument) + self.function_name: the tflite custom op name to use + self.uuid: a unique call id for this particular call (i.e. + multiple function calls would have the same function_name but different + uuids. + self.params: A param name to key value for op constant data. I.e. for + axis on a reduction, strides on a convolution, etc. + """ + + def __init__(self): + self.inputs = {} + self.outputs = {} + self.function_name = None + self.uuid = None + self.params = {} + + def __str__(self): + return "tflite function %s call %s\n\tinputs: %r\n\toutputs: %r" % ( + self.function_name, self.uuid, self.inputs, self.outputs) + + +def _find_all_hints_in_graph_def(session): + """Look at the current default graph and return a list of LiteFuncCall objs. + + Args: + session: A TensorFlow session that contains the graph to convert. + Returns: + a list of `LifeFuncCall` objects in the form + + """ + func_calls = _collections.defaultdict(_LiteFuncCall) + seen_ops = set() + + for op in session.graph.get_operations(): + for operand in _itertools.chain(op.inputs, op.outputs): + if operand in seen_ops: + continue + seen_ops.add(operand) + attr = operand.op.node_def.attr + uuid = attr[OpHint.FUNCTION_UUID_ATTR].s + if OpHint.FUNCTION_UUID_ATTR not in attr: + continue + call_def = func_calls[uuid] + call_def.uuid = uuid + if OpHint.FUNCTION_UUID_ATTR in attr: + call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s + if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr: + call_def.inputs[attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i] = operand + if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr: + call_def.outputs[attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i] = operand + + for a in attr: + if a.startswith("_tflite_attr_"): + # TODO(aselle): Remember the attribute tensors so we can put them + # in collapse. + call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor + + return func_calls + + +def _tensor_name_base(full_tensor_name): + """Removes the device assignment code from a tensor. + + e.g. _tensor_name_base("foo:3") => "foo" + + Args: + full_tensor_name: A tensor name that is annotated with a device placement + (this is what tensor flow introspection gives). + Returns: + A name without any device assignment. + """ + return full_tensor_name.name.split(":")[0] + + +def convert_op_hints_to_stubs(session): + """Converts a graphdef with LiteOp hints into stub operations. + + This is used to prepare for toco conversion of complex intrinsic usages. + + Args: + session: A TensorFlow session that contains the graph to convert. + Returns: + A new graphdef with all ops contained in OpHints being replaced by + a single op call with the right parameters. + """ + hints = _find_all_hints_in_graph_def(session) + current_graph_def = session.graph_def + for call in hints.values(): + input_names = [None] * len(call.inputs) + output_names = [None] * len(call.outputs) + output_dtypes = [None] * len(call.outputs) + output_quantized = False + for input_index, tensor in call.inputs.items(): + input_names[input_index] = _tensor_name_base(tensor) + for output_index, tensor in call.outputs.items(): + output_names[output_index] = _tensor_name_base(tensor) + output_dtypes[output_index] = tensor.dtype.as_datatype_enum + # TODO(aselle): Support quantized flag properly + current_graph_def = _framework.fuse_op( + current_graph_def, input_names, output_names, output_dtypes, + output_quantized, call.uuid, call.function_name) + for node in current_graph_def.node: + if node.name == call.uuid: + for param, tensor in call.params.items(): + node.attr[param].tensor.CopyFrom(tensor) + return current_graph_def + + +_allowed_symbols = ["OpHint", "convert_op_hints_to_stubs"] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD b/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0148149a6adc141d67e82808f7e8c72ddb7e309a --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD @@ -0,0 +1,43 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "generator", + srcs = ["generator.cc"], + hdrs = ["generator.h"], + deps = [ + "//tensorflow/contrib/lite/schema:schema_fbs", + ], +) + +cc_binary( + name = "generate", + srcs = ["generate.cc"], + deps = [ + ":generator", + ], +) + +cc_test( + name = "generator_test", + srcs = ["generator_test.cc"], + deps = [ + ":generator", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "consistency_test", + srcs = ["consistency_test.cc"], + data = [ + "//tensorflow/contrib/lite:builtin_ops.h", + ], + deps = [ + ":generator", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/README.md b/tensorflow/contrib/lite/schema/builtin_ops_header/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f20d4f664e62fdd52e55339e45b9603307a2b671 --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/README.md @@ -0,0 +1,12 @@ +# Builtin Ops Header Generator. + +This directory contains a code generator to generate a pure C header for +builtin op definition. + +Whenever you add a new builtin op, please execute: + +```sh +bazel run \ + //tensorflow/contrib/lite/schema/builtin_ops_header:generate > \ + tensorflow/contrib/lite/builtin_ops.h +``` diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/consistency_test.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/consistency_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d55c125c117db3c1b8d67ab0b674abe2e7c39d94 --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/consistency_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" + +namespace { + +const char* kHeaderFileName = + "tensorflow/contrib/lite/builtin_ops.h"; + +// The test ensures that `builtin_ops.h` is consistent with the FlatBuffer +// schema definition. When the schema is modified, it's required to run the +// generator to re-generate the header. +// Please see README.md for more details. +TEST(BuiltinOpsHeaderTest, TestConsistency) { + std::ifstream input_stream(kHeaderFileName, std::ios::binary); + ASSERT_TRUE(input_stream); + std::string file_content((std::istreambuf_iterator(input_stream)), + std::istreambuf_iterator()); + + std::ostringstream output_stream; + tflite::builtin_ops_header::GenerateHeader(output_stream); + std::string generated_content = output_stream.str(); + + EXPECT_EQ(file_content, generated_content); +} + +} // anonymous namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generate.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generate.cc new file mode 100644 index 0000000000000000000000000000000000000000..72a28987b8d4863b0f03f7861177940177edd884 --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generate.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" + +// This executable is used to generate builtin_ops.h in TensorFlow Lite. +// Please see README.md for more details. +int main() { + if (!tflite::builtin_ops_header::GenerateHeader(std::cout)) { + std::cerr << "Failed to generate the header file.\n"; + } + return 0; +} diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc new file mode 100644 index 0000000000000000000000000000000000000000..b983d59d85955b241a22012b4e9adbeea346f80d --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc @@ -0,0 +1,132 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { +namespace builtin_ops_header { + +namespace { +const char* kFileHeader = + R"(/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ + +// DO NOT EDIT MANUALLY: This file is automatically generated by +// `schema_builtin_ops_header_generator.py`. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { +)"; + +const char* kFileFooter = + R"(} TfLiteBuiltinOperator; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +} +)"; +} // anonymous namespace + +bool IsValidInputEnumName(const std::string& name) { + const char* begin = name.c_str(); + const char* ch = begin; + while (*ch != '\0') { + // If it's not the first character, expect an underscore. + if (ch != begin) { + if (*ch != '_') { + return false; + } + ++ch; + } + + // Expecting a word with upper case letters or digits, like "CONV", + // "CONV2D", "2D"...etc. + bool empty = true; + while (isupper(*ch) || isdigit(*ch)) { + // It's not empty if at least one character is consumed. + empty = false; + ++ch; + } + if (empty) { + return false; + } + } + return true; +} + +std::string ConstantizeVariableName(const std::string& name) { + std::string result = "kTfLiteBuiltin"; + bool uppercase = true; + for (char input_char : name) { + if (input_char == '_') { + uppercase = true; + } else if (uppercase) { + result += toupper(input_char); + uppercase = false; + } else { + result += tolower(input_char); + } + } + + return result; +} + +bool GenerateHeader(std::ostream& os) { + auto enum_names = tflite::EnumNamesBuiltinOperator(); + + // Check if all the input enum names are valid. + for (auto enum_value : EnumValuesBuiltinOperator()) { + auto enum_name = enum_names[enum_value]; + if (!IsValidInputEnumName(enum_name)) { + std::cerr << "Invalid input enum name: " << enum_name << std::endl; + return false; + } + } + + os << kFileHeader; + for (auto enum_value : EnumValuesBuiltinOperator()) { + auto enum_name = enum_names[enum_value]; + os << " "; + os << ConstantizeVariableName(enum_name); + os << " = "; + os << enum_value; + os << ",\n"; + } + os << kFileFooter; + return true; +} + +} // namespace builtin_ops_header +} // namespace tflite diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.h b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.h new file mode 100644 index 0000000000000000000000000000000000000000..3241ff83d599ed8a476fc1d5a88c26143ebfbaf2 --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.h @@ -0,0 +1,38 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// An utility library to generate pure C header for builtin ops definition. +#ifndef TENSORFLOW_CONTRIB_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ +#define TENSORFLOW_CONTRIB_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ + +#include + +namespace tflite { +namespace builtin_ops_header { + +// Check if the input enum name (from the Flatbuffer definition) is valid. +bool IsValidInputEnumName(const std::string& name); + +// Convert the enum name from Flatbuffer convention to C enum name convention. +// E.g. `L2_POOL_2D` becomes `kTfLiteBuiltinL2Pool2d`. +std::string ConstantizeVariableName(const std::string& name); + +// The function generates a pure C header for builtin ops definition, and write +// it to the output stream. +bool GenerateHeader(std::ostream& os); + +} // namespace builtin_ops_header +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator_test.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a7dc8e1b0486eda6e09f38a209dca95c0317a1fb --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator_test.cc @@ -0,0 +1,63 @@ + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" +#include +#include + +namespace { + +using tflite::builtin_ops_header::ConstantizeVariableName; +using tflite::builtin_ops_header::IsValidInputEnumName; + +TEST(TestIsValidInputEnumName, TestWithValidInputNames) { + EXPECT_TRUE(IsValidInputEnumName("ADD")); + EXPECT_TRUE(IsValidInputEnumName("CONV_2D")); + EXPECT_TRUE(IsValidInputEnumName("L2_POOL_2D")); +} + +TEST(TestIsValidInputEnumName, TestWithLeadingUnderscore) { + EXPECT_FALSE(IsValidInputEnumName("_ADD")); + EXPECT_FALSE(IsValidInputEnumName("_CONV_2D")); +} + +TEST(TestIsValidInputEnumName, TestWithLowerCase) { + EXPECT_FALSE(IsValidInputEnumName("_AdD")); + EXPECT_FALSE(IsValidInputEnumName("_COnV_2D")); +} + +TEST(TestIsValidInputEnumName, TestWithOtherCharacters) { + EXPECT_FALSE(IsValidInputEnumName("_AdD!2D")); + EXPECT_FALSE(IsValidInputEnumName("_COnV?2D")); +} + +TEST(TestIsValidInputEnumName, TestWithDoubleUnderscores) { + EXPECT_FALSE(IsValidInputEnumName("ADD__2D")); + EXPECT_FALSE(IsValidInputEnumName("CONV__2D")); +} + +TEST(TestConstantizeVariableName, TestWithValidInputNames) { + EXPECT_EQ(ConstantizeVariableName("ADD"), "kTfLiteBuiltinAdd"); + EXPECT_EQ(ConstantizeVariableName("CONV_2D"), "kTfLiteBuiltinConv2d"); + EXPECT_EQ(ConstantizeVariableName("L2_POOL_2D"), "kTfLiteBuiltinL2Pool2d"); +} + +} // anonymous namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index ec202cd4073f152e1b2f4d5efd443615e901afc6..75970b41267613058199c22a0fcb0c80a1c8f04f 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -119,6 +119,10 @@ enum BuiltinOperator : byte { SQUEEZE = 43, UNIDIRECTIONAL_SEQUENCE_LSTM = 44, STRIDED_SLICE = 45, + BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, } // Options for the builtin operators. @@ -155,6 +159,9 @@ union BuiltinOptions { SqueezeOptions, SequenceRNNOptions, StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, } enum Padding : byte { SAME, VALID } @@ -224,6 +231,12 @@ table SequenceRNNOptions { fused_activation_function:ActivationFunctionType; } +// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. +table BidirectionalSequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; +} + // An implementation of TensorFlow fully_connected (a.k.a Dense) layer. table FullyConnectedOptions { fused_activation_function:ActivationFunctionType; @@ -266,6 +279,9 @@ table LSTMOptions { } table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; } // A call operation options @@ -282,15 +298,9 @@ table ReshapeOptions { } table SpaceToBatchNDOptions { - block_shape:[int]; - before_paddings:[int]; - after_paddings:[int]; } table BatchToSpaceNDOptions { - block_shape:[int]; - before_crops:[int]; - after_crops:[int]; } table SkipGramOptions { @@ -311,6 +321,9 @@ table DivOptions { fused_activation_function:ActivationFunctionType; } +table TopKV2Options { +} + enum CombinerType : byte { SUM = 0, MEAN = 1, @@ -326,11 +339,12 @@ table GatherOptions { } table TransposeOptions { - perm:[int]; +} + +table ExpOptions { } table MeanOptions { - axis:[int]; keep_dims: bool; } @@ -338,6 +352,10 @@ table SqueezeOptions { squeeze_dims:[int]; } +table SplitOptions { + num_splits: int; +} + table StridedSliceOptions { begin_mask: int; end_mask: int; diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h old mode 100644 new mode 100755 index c04a73a2bf00807442967499cceaaee941e54278..06989c7b61dc9904bff380e7f1cdc11097cb340d --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ // automatically generated by the FlatBuffers compiler, do not modify + #ifndef FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ #define FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ @@ -51,6 +52,9 @@ struct RNNOptionsT; struct SequenceRNNOptions; struct SequenceRNNOptionsT; +struct BidirectionalSequenceRNNOptions; +struct BidirectionalSequenceRNNOptionsT; + struct FullyConnectedOptions; struct FullyConnectedOptionsT; @@ -105,6 +109,9 @@ struct SubOptionsT; struct DivOptions; struct DivOptionsT; +struct TopKV2Options; +struct TopKV2OptionsT; + struct EmbeddingLookupSparseOptions; struct EmbeddingLookupSparseOptionsT; @@ -114,12 +121,18 @@ struct GatherOptionsT; struct TransposeOptions; struct TransposeOptionsT; +struct ExpOptions; +struct ExpOptionsT; + struct MeanOptions; struct MeanOptionsT; struct SqueezeOptions; struct SqueezeOptionsT; +struct SplitOptions; +struct SplitOptionsT; + struct StridedSliceOptions; struct StridedSliceOptionsT; @@ -150,15 +163,27 @@ enum TensorType { }; inline TensorType (&EnumValuesTensorType())[6] { - static TensorType values[] = {TensorType_FLOAT32, TensorType_FLOAT16, - TensorType_INT32, TensorType_UINT8, - TensorType_INT64, TensorType_STRING}; + static TensorType values[] = { + TensorType_FLOAT32, + TensorType_FLOAT16, + TensorType_INT32, + TensorType_UINT8, + TensorType_INT64, + TensorType_STRING + }; return values; } inline const char **EnumNamesTensorType() { - static const char *names[] = {"FLOAT32", "FLOAT16", "INT32", "UINT8", - "INT64", "STRING", nullptr}; + static const char *names[] = { + "FLOAT32", + "FLOAT16", + "INT32", + "UINT8", + "INT64", + "STRING", + nullptr + }; return names; } @@ -211,106 +236,121 @@ enum BuiltinOperator { BuiltinOperator_SQUEEZE = 43, BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM = 44, BuiltinOperator_STRIDED_SLICE = 45, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN = 46, + BuiltinOperator_EXP = 47, + BuiltinOperator_TOPK_V2 = 48, + BuiltinOperator_SPLIT = 49, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_STRIDED_SLICE + BuiltinOperator_MAX = BuiltinOperator_SPLIT }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[43] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[47] { static BuiltinOperator values[] = { - BuiltinOperator_ADD, - BuiltinOperator_AVERAGE_POOL_2D, - BuiltinOperator_CONCATENATION, - BuiltinOperator_CONV_2D, - BuiltinOperator_DEPTHWISE_CONV_2D, - BuiltinOperator_EMBEDDING_LOOKUP, - BuiltinOperator_FULLY_CONNECTED, - BuiltinOperator_HASHTABLE_LOOKUP, - BuiltinOperator_L2_NORMALIZATION, - BuiltinOperator_L2_POOL_2D, - BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, - BuiltinOperator_LOGISTIC, - BuiltinOperator_LSH_PROJECTION, - BuiltinOperator_LSTM, - BuiltinOperator_MAX_POOL_2D, - BuiltinOperator_MUL, - BuiltinOperator_RELU, - BuiltinOperator_RELU_N1_TO_1, - BuiltinOperator_RELU6, - BuiltinOperator_RESHAPE, - BuiltinOperator_RESIZE_BILINEAR, - BuiltinOperator_RNN, - BuiltinOperator_SOFTMAX, - BuiltinOperator_SPACE_TO_DEPTH, - BuiltinOperator_SVDF, - BuiltinOperator_TANH, - BuiltinOperator_CONCAT_EMBEDDINGS, - BuiltinOperator_SKIP_GRAM, - BuiltinOperator_CALL, - BuiltinOperator_CUSTOM, - BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, - BuiltinOperator_PAD, - BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, - BuiltinOperator_GATHER, - BuiltinOperator_BATCH_TO_SPACE_ND, - BuiltinOperator_SPACE_TO_BATCH_ND, - BuiltinOperator_TRANSPOSE, - BuiltinOperator_MEAN, - BuiltinOperator_SUB, - BuiltinOperator_DIV, - BuiltinOperator_SQUEEZE, - BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, - BuiltinOperator_STRIDED_SLICE}; + BuiltinOperator_ADD, + BuiltinOperator_AVERAGE_POOL_2D, + BuiltinOperator_CONCATENATION, + BuiltinOperator_CONV_2D, + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOperator_EMBEDDING_LOOKUP, + BuiltinOperator_FULLY_CONNECTED, + BuiltinOperator_HASHTABLE_LOOKUP, + BuiltinOperator_L2_NORMALIZATION, + BuiltinOperator_L2_POOL_2D, + BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOperator_LOGISTIC, + BuiltinOperator_LSH_PROJECTION, + BuiltinOperator_LSTM, + BuiltinOperator_MAX_POOL_2D, + BuiltinOperator_MUL, + BuiltinOperator_RELU, + BuiltinOperator_RELU_N1_TO_1, + BuiltinOperator_RELU6, + BuiltinOperator_RESHAPE, + BuiltinOperator_RESIZE_BILINEAR, + BuiltinOperator_RNN, + BuiltinOperator_SOFTMAX, + BuiltinOperator_SPACE_TO_DEPTH, + BuiltinOperator_SVDF, + BuiltinOperator_TANH, + BuiltinOperator_CONCAT_EMBEDDINGS, + BuiltinOperator_SKIP_GRAM, + BuiltinOperator_CALL, + BuiltinOperator_CUSTOM, + BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + BuiltinOperator_PAD, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOperator_GATHER, + BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOperator_TRANSPOSE, + BuiltinOperator_MEAN, + BuiltinOperator_SUB, + BuiltinOperator_DIV, + BuiltinOperator_SQUEEZE, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOperator_STRIDED_SLICE, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOperator_EXP, + BuiltinOperator_TOPK_V2, + BuiltinOperator_SPLIT + }; return values; } inline const char **EnumNamesBuiltinOperator() { - static const char *names[] = {"ADD", - "AVERAGE_POOL_2D", - "CONCATENATION", - "CONV_2D", - "DEPTHWISE_CONV_2D", - "", - "", - "EMBEDDING_LOOKUP", - "", - "FULLY_CONNECTED", - "HASHTABLE_LOOKUP", - "L2_NORMALIZATION", - "L2_POOL_2D", - "LOCAL_RESPONSE_NORMALIZATION", - "LOGISTIC", - "LSH_PROJECTION", - "LSTM", - "MAX_POOL_2D", - "MUL", - "RELU", - "RELU_N1_TO_1", - "RELU6", - "RESHAPE", - "RESIZE_BILINEAR", - "RNN", - "SOFTMAX", - "SPACE_TO_DEPTH", - "SVDF", - "TANH", - "CONCAT_EMBEDDINGS", - "SKIP_GRAM", - "CALL", - "CUSTOM", - "EMBEDDING_LOOKUP_SPARSE", - "PAD", - "UNIDIRECTIONAL_SEQUENCE_RNN", - "GATHER", - "BATCH_TO_SPACE_ND", - "SPACE_TO_BATCH_ND", - "TRANSPOSE", - "MEAN", - "SUB", - "DIV", - "SQUEEZE", - "UNIDIRECTIONAL_SEQUENCE_LSTM", - "STRIDED_SLICE", - nullptr}; + static const char *names[] = { + "ADD", + "AVERAGE_POOL_2D", + "CONCATENATION", + "CONV_2D", + "DEPTHWISE_CONV_2D", + "", + "", + "EMBEDDING_LOOKUP", + "", + "FULLY_CONNECTED", + "HASHTABLE_LOOKUP", + "L2_NORMALIZATION", + "L2_POOL_2D", + "LOCAL_RESPONSE_NORMALIZATION", + "LOGISTIC", + "LSH_PROJECTION", + "LSTM", + "MAX_POOL_2D", + "MUL", + "RELU", + "RELU_N1_TO_1", + "RELU6", + "RESHAPE", + "RESIZE_BILINEAR", + "RNN", + "SOFTMAX", + "SPACE_TO_DEPTH", + "SVDF", + "TANH", + "CONCAT_EMBEDDINGS", + "SKIP_GRAM", + "CALL", + "CUSTOM", + "EMBEDDING_LOOKUP_SPARSE", + "PAD", + "UNIDIRECTIONAL_SEQUENCE_RNN", + "GATHER", + "BATCH_TO_SPACE_ND", + "SPACE_TO_BATCH_ND", + "TRANSPOSE", + "MEAN", + "SUB", + "DIV", + "SQUEEZE", + "UNIDIRECTIONAL_SEQUENCE_LSTM", + "STRIDED_SLICE", + "BIDIRECTIONAL_SEQUENCE_RNN", + "EXP", + "TOPK_V2", + "SPLIT", + nullptr + }; return names; } @@ -353,83 +393,95 @@ enum BuiltinOptions { BuiltinOptions_SqueezeOptions = 30, BuiltinOptions_SequenceRNNOptions = 31, BuiltinOptions_StridedSliceOptions = 32, + BuiltinOptions_ExpOptions = 33, + BuiltinOptions_TopKV2Options = 34, + BuiltinOptions_SplitOptions = 35, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_StridedSliceOptions + BuiltinOptions_MAX = BuiltinOptions_SplitOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[33] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[36] { static BuiltinOptions values[] = { - BuiltinOptions_NONE, - BuiltinOptions_Conv2DOptions, - BuiltinOptions_DepthwiseConv2DOptions, - BuiltinOptions_ConcatEmbeddingsOptions, - BuiltinOptions_LSHProjectionOptions, - BuiltinOptions_Pool2DOptions, - BuiltinOptions_SVDFOptions, - BuiltinOptions_RNNOptions, - BuiltinOptions_FullyConnectedOptions, - BuiltinOptions_SoftmaxOptions, - BuiltinOptions_ConcatenationOptions, - BuiltinOptions_AddOptions, - BuiltinOptions_L2NormOptions, - BuiltinOptions_LocalResponseNormalizationOptions, - BuiltinOptions_LSTMOptions, - BuiltinOptions_ResizeBilinearOptions, - BuiltinOptions_CallOptions, - BuiltinOptions_ReshapeOptions, - BuiltinOptions_SkipGramOptions, - BuiltinOptions_SpaceToDepthOptions, - BuiltinOptions_EmbeddingLookupSparseOptions, - BuiltinOptions_MulOptions, - BuiltinOptions_PadOptions, - BuiltinOptions_GatherOptions, - BuiltinOptions_BatchToSpaceNDOptions, - BuiltinOptions_SpaceToBatchNDOptions, - BuiltinOptions_TransposeOptions, - BuiltinOptions_MeanOptions, - BuiltinOptions_SubOptions, - BuiltinOptions_DivOptions, - BuiltinOptions_SqueezeOptions, - BuiltinOptions_SequenceRNNOptions, - BuiltinOptions_StridedSliceOptions}; + BuiltinOptions_NONE, + BuiltinOptions_Conv2DOptions, + BuiltinOptions_DepthwiseConv2DOptions, + BuiltinOptions_ConcatEmbeddingsOptions, + BuiltinOptions_LSHProjectionOptions, + BuiltinOptions_Pool2DOptions, + BuiltinOptions_SVDFOptions, + BuiltinOptions_RNNOptions, + BuiltinOptions_FullyConnectedOptions, + BuiltinOptions_SoftmaxOptions, + BuiltinOptions_ConcatenationOptions, + BuiltinOptions_AddOptions, + BuiltinOptions_L2NormOptions, + BuiltinOptions_LocalResponseNormalizationOptions, + BuiltinOptions_LSTMOptions, + BuiltinOptions_ResizeBilinearOptions, + BuiltinOptions_CallOptions, + BuiltinOptions_ReshapeOptions, + BuiltinOptions_SkipGramOptions, + BuiltinOptions_SpaceToDepthOptions, + BuiltinOptions_EmbeddingLookupSparseOptions, + BuiltinOptions_MulOptions, + BuiltinOptions_PadOptions, + BuiltinOptions_GatherOptions, + BuiltinOptions_BatchToSpaceNDOptions, + BuiltinOptions_SpaceToBatchNDOptions, + BuiltinOptions_TransposeOptions, + BuiltinOptions_MeanOptions, + BuiltinOptions_SubOptions, + BuiltinOptions_DivOptions, + BuiltinOptions_SqueezeOptions, + BuiltinOptions_SequenceRNNOptions, + BuiltinOptions_StridedSliceOptions, + BuiltinOptions_ExpOptions, + BuiltinOptions_TopKV2Options, + BuiltinOptions_SplitOptions + }; return values; } inline const char **EnumNamesBuiltinOptions() { - static const char *names[] = {"NONE", - "Conv2DOptions", - "DepthwiseConv2DOptions", - "ConcatEmbeddingsOptions", - "LSHProjectionOptions", - "Pool2DOptions", - "SVDFOptions", - "RNNOptions", - "FullyConnectedOptions", - "SoftmaxOptions", - "ConcatenationOptions", - "AddOptions", - "L2NormOptions", - "LocalResponseNormalizationOptions", - "LSTMOptions", - "ResizeBilinearOptions", - "CallOptions", - "ReshapeOptions", - "SkipGramOptions", - "SpaceToDepthOptions", - "EmbeddingLookupSparseOptions", - "MulOptions", - "PadOptions", - "GatherOptions", - "BatchToSpaceNDOptions", - "SpaceToBatchNDOptions", - "TransposeOptions", - "MeanOptions", - "SubOptions", - "DivOptions", - "SqueezeOptions", - "SequenceRNNOptions", - "StridedSliceOptions", - nullptr}; + static const char *names[] = { + "NONE", + "Conv2DOptions", + "DepthwiseConv2DOptions", + "ConcatEmbeddingsOptions", + "LSHProjectionOptions", + "Pool2DOptions", + "SVDFOptions", + "RNNOptions", + "FullyConnectedOptions", + "SoftmaxOptions", + "ConcatenationOptions", + "AddOptions", + "L2NormOptions", + "LocalResponseNormalizationOptions", + "LSTMOptions", + "ResizeBilinearOptions", + "CallOptions", + "ReshapeOptions", + "SkipGramOptions", + "SpaceToDepthOptions", + "EmbeddingLookupSparseOptions", + "MulOptions", + "PadOptions", + "GatherOptions", + "BatchToSpaceNDOptions", + "SpaceToBatchNDOptions", + "TransposeOptions", + "MeanOptions", + "SubOptions", + "DivOptions", + "SqueezeOptions", + "SequenceRNNOptions", + "StridedSliceOptions", + "ExpOptions", + "TopKV2Options", + "SplitOptions", + nullptr + }; return names; } @@ -438,206 +490,170 @@ inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { return EnumNamesBuiltinOptions()[index]; } -template -struct BuiltinOptionsTraits { +template struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_NONE; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_Conv2DOptions; }; -template <> -struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = - BuiltinOptions_DepthwiseConv2DOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DepthwiseConv2DOptions; }; -template <> -struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = - BuiltinOptions_ConcatEmbeddingsOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ConcatEmbeddingsOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LSHProjectionOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_Pool2DOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SVDFOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_RNNOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_FullyConnectedOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SoftmaxOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ConcatenationOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_AddOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_L2NormOptions; }; -template <> -struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = - BuiltinOptions_LocalResponseNormalizationOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LocalResponseNormalizationOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LSTMOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ResizeBilinearOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_CallOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ReshapeOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SkipGramOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SpaceToDepthOptions; }; -template <> -struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = - BuiltinOptions_EmbeddingLookupSparseOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EmbeddingLookupSparseOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_MulOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_PadOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_GatherOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_BatchToSpaceNDOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SpaceToBatchNDOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_MeanOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SubOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_DivOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ExpOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TopKV2Options; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SplitOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; BuiltinOptionsUnion() : type(BuiltinOptions_NONE), value(nullptr) {} - BuiltinOptionsUnion(BuiltinOptionsUnion &&u) FLATBUFFERS_NOEXCEPT - : type(BuiltinOptions_NONE), - value(nullptr) { - std::swap(type, u.type); - std::swap(value, u.value); - } + BuiltinOptionsUnion(BuiltinOptionsUnion&& u) FLATBUFFERS_NOEXCEPT : + type(BuiltinOptions_NONE), value(nullptr) + { std::swap(type, u.type); std::swap(value, u.value); } BuiltinOptionsUnion(const BuiltinOptionsUnion &) FLATBUFFERS_NOEXCEPT; - BuiltinOptionsUnion &operator=(const BuiltinOptionsUnion &u) - FLATBUFFERS_NOEXCEPT { - BuiltinOptionsUnion t(u); - std::swap(type, t.type); - std::swap(value, t.value); - return *this; - } - BuiltinOptionsUnion &operator=(BuiltinOptionsUnion &&u) FLATBUFFERS_NOEXCEPT { - std::swap(type, u.type); - std::swap(value, u.value); - return *this; - } + BuiltinOptionsUnion &operator=(const BuiltinOptionsUnion &u) FLATBUFFERS_NOEXCEPT + { BuiltinOptionsUnion t(u); std::swap(type, t.type); std::swap(value, t.value); return *this; } + BuiltinOptionsUnion &operator=(BuiltinOptionsUnion &&u) FLATBUFFERS_NOEXCEPT + { std::swap(type, u.type); std::swap(value, u.value); return *this; } ~BuiltinOptionsUnion() { Reset(); } void Reset(); #ifndef FLATBUFFERS_CPP98_STL template - void Set(T &&val) { + void Set(T&& val) { Reset(); type = BuiltinOptionsTraits::enum_value; if (type != BuiltinOptions_NONE) { @@ -646,342 +662,293 @@ struct BuiltinOptionsUnion { } #endif // FLATBUFFERS_CPP98_STL - static void *UnPack(const void *obj, BuiltinOptions type, - const flatbuffers::resolver_function_t *resolver); - flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const flatbuffers::rehasher_function_t *_rehasher = nullptr) const; + static void *UnPack(const void *obj, BuiltinOptions type, const flatbuffers::resolver_function_t *resolver); + flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher = nullptr) const; Conv2DOptionsT *AsConv2DOptions() { - return type == BuiltinOptions_Conv2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_Conv2DOptions ? + reinterpret_cast(value) : nullptr; } const Conv2DOptionsT *AsConv2DOptions() const { - return type == BuiltinOptions_Conv2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_Conv2DOptions ? + reinterpret_cast(value) : nullptr; } DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() { - return type == BuiltinOptions_DepthwiseConv2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_DepthwiseConv2DOptions ? + reinterpret_cast(value) : nullptr; } const DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() const { - return type == BuiltinOptions_DepthwiseConv2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_DepthwiseConv2DOptions ? + reinterpret_cast(value) : nullptr; } ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() { - return type == BuiltinOptions_ConcatEmbeddingsOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ConcatEmbeddingsOptions ? + reinterpret_cast(value) : nullptr; } const ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() const { - return type == BuiltinOptions_ConcatEmbeddingsOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ConcatEmbeddingsOptions ? + reinterpret_cast(value) : nullptr; } LSHProjectionOptionsT *AsLSHProjectionOptions() { - return type == BuiltinOptions_LSHProjectionOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_LSHProjectionOptions ? + reinterpret_cast(value) : nullptr; } const LSHProjectionOptionsT *AsLSHProjectionOptions() const { - return type == BuiltinOptions_LSHProjectionOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_LSHProjectionOptions ? + reinterpret_cast(value) : nullptr; } Pool2DOptionsT *AsPool2DOptions() { - return type == BuiltinOptions_Pool2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_Pool2DOptions ? + reinterpret_cast(value) : nullptr; } const Pool2DOptionsT *AsPool2DOptions() const { - return type == BuiltinOptions_Pool2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_Pool2DOptions ? + reinterpret_cast(value) : nullptr; } SVDFOptionsT *AsSVDFOptions() { - return type == BuiltinOptions_SVDFOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SVDFOptions ? + reinterpret_cast(value) : nullptr; } const SVDFOptionsT *AsSVDFOptions() const { - return type == BuiltinOptions_SVDFOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SVDFOptions ? + reinterpret_cast(value) : nullptr; } RNNOptionsT *AsRNNOptions() { - return type == BuiltinOptions_RNNOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_RNNOptions ? + reinterpret_cast(value) : nullptr; } const RNNOptionsT *AsRNNOptions() const { - return type == BuiltinOptions_RNNOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_RNNOptions ? + reinterpret_cast(value) : nullptr; } FullyConnectedOptionsT *AsFullyConnectedOptions() { - return type == BuiltinOptions_FullyConnectedOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_FullyConnectedOptions ? + reinterpret_cast(value) : nullptr; } const FullyConnectedOptionsT *AsFullyConnectedOptions() const { - return type == BuiltinOptions_FullyConnectedOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_FullyConnectedOptions ? + reinterpret_cast(value) : nullptr; } SoftmaxOptionsT *AsSoftmaxOptions() { - return type == BuiltinOptions_SoftmaxOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SoftmaxOptions ? + reinterpret_cast(value) : nullptr; } const SoftmaxOptionsT *AsSoftmaxOptions() const { - return type == BuiltinOptions_SoftmaxOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SoftmaxOptions ? + reinterpret_cast(value) : nullptr; } ConcatenationOptionsT *AsConcatenationOptions() { - return type == BuiltinOptions_ConcatenationOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ConcatenationOptions ? + reinterpret_cast(value) : nullptr; } const ConcatenationOptionsT *AsConcatenationOptions() const { - return type == BuiltinOptions_ConcatenationOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ConcatenationOptions ? + reinterpret_cast(value) : nullptr; } AddOptionsT *AsAddOptions() { - return type == BuiltinOptions_AddOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_AddOptions ? + reinterpret_cast(value) : nullptr; } const AddOptionsT *AsAddOptions() const { - return type == BuiltinOptions_AddOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_AddOptions ? + reinterpret_cast(value) : nullptr; } L2NormOptionsT *AsL2NormOptions() { - return type == BuiltinOptions_L2NormOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_L2NormOptions ? + reinterpret_cast(value) : nullptr; } const L2NormOptionsT *AsL2NormOptions() const { - return type == BuiltinOptions_L2NormOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_L2NormOptions ? + reinterpret_cast(value) : nullptr; } LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() { - return type == BuiltinOptions_LocalResponseNormalizationOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_LocalResponseNormalizationOptions ? + reinterpret_cast(value) : nullptr; } - const LocalResponseNormalizationOptionsT * - AsLocalResponseNormalizationOptions() const { - return type == BuiltinOptions_LocalResponseNormalizationOptions - ? reinterpret_cast( - value) - : nullptr; + const LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() const { + return type == BuiltinOptions_LocalResponseNormalizationOptions ? + reinterpret_cast(value) : nullptr; } LSTMOptionsT *AsLSTMOptions() { - return type == BuiltinOptions_LSTMOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_LSTMOptions ? + reinterpret_cast(value) : nullptr; } const LSTMOptionsT *AsLSTMOptions() const { - return type == BuiltinOptions_LSTMOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_LSTMOptions ? + reinterpret_cast(value) : nullptr; } ResizeBilinearOptionsT *AsResizeBilinearOptions() { - return type == BuiltinOptions_ResizeBilinearOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ResizeBilinearOptions ? + reinterpret_cast(value) : nullptr; } const ResizeBilinearOptionsT *AsResizeBilinearOptions() const { - return type == BuiltinOptions_ResizeBilinearOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ResizeBilinearOptions ? + reinterpret_cast(value) : nullptr; } CallOptionsT *AsCallOptions() { - return type == BuiltinOptions_CallOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_CallOptions ? + reinterpret_cast(value) : nullptr; } const CallOptionsT *AsCallOptions() const { - return type == BuiltinOptions_CallOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_CallOptions ? + reinterpret_cast(value) : nullptr; } ReshapeOptionsT *AsReshapeOptions() { - return type == BuiltinOptions_ReshapeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ReshapeOptions ? + reinterpret_cast(value) : nullptr; } const ReshapeOptionsT *AsReshapeOptions() const { - return type == BuiltinOptions_ReshapeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ReshapeOptions ? + reinterpret_cast(value) : nullptr; } SkipGramOptionsT *AsSkipGramOptions() { - return type == BuiltinOptions_SkipGramOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SkipGramOptions ? + reinterpret_cast(value) : nullptr; } const SkipGramOptionsT *AsSkipGramOptions() const { - return type == BuiltinOptions_SkipGramOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SkipGramOptions ? + reinterpret_cast(value) : nullptr; } SpaceToDepthOptionsT *AsSpaceToDepthOptions() { - return type == BuiltinOptions_SpaceToDepthOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SpaceToDepthOptions ? + reinterpret_cast(value) : nullptr; } const SpaceToDepthOptionsT *AsSpaceToDepthOptions() const { - return type == BuiltinOptions_SpaceToDepthOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SpaceToDepthOptions ? + reinterpret_cast(value) : nullptr; } EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() { - return type == BuiltinOptions_EmbeddingLookupSparseOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_EmbeddingLookupSparseOptions ? + reinterpret_cast(value) : nullptr; } const EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() const { - return type == BuiltinOptions_EmbeddingLookupSparseOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_EmbeddingLookupSparseOptions ? + reinterpret_cast(value) : nullptr; } MulOptionsT *AsMulOptions() { - return type == BuiltinOptions_MulOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_MulOptions ? + reinterpret_cast(value) : nullptr; } const MulOptionsT *AsMulOptions() const { - return type == BuiltinOptions_MulOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_MulOptions ? + reinterpret_cast(value) : nullptr; } PadOptionsT *AsPadOptions() { - return type == BuiltinOptions_PadOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_PadOptions ? + reinterpret_cast(value) : nullptr; } const PadOptionsT *AsPadOptions() const { - return type == BuiltinOptions_PadOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_PadOptions ? + reinterpret_cast(value) : nullptr; } GatherOptionsT *AsGatherOptions() { - return type == BuiltinOptions_GatherOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_GatherOptions ? + reinterpret_cast(value) : nullptr; } const GatherOptionsT *AsGatherOptions() const { - return type == BuiltinOptions_GatherOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_GatherOptions ? + reinterpret_cast(value) : nullptr; } BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() { - return type == BuiltinOptions_BatchToSpaceNDOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_BatchToSpaceNDOptions ? + reinterpret_cast(value) : nullptr; } const BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() const { - return type == BuiltinOptions_BatchToSpaceNDOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_BatchToSpaceNDOptions ? + reinterpret_cast(value) : nullptr; } SpaceToBatchNDOptionsT *AsSpaceToBatchNDOptions() { - return type == BuiltinOptions_SpaceToBatchNDOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SpaceToBatchNDOptions ? + reinterpret_cast(value) : nullptr; } const SpaceToBatchNDOptionsT *AsSpaceToBatchNDOptions() const { - return type == BuiltinOptions_SpaceToBatchNDOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SpaceToBatchNDOptions ? + reinterpret_cast(value) : nullptr; } TransposeOptionsT *AsTransposeOptions() { - return type == BuiltinOptions_TransposeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_TransposeOptions ? + reinterpret_cast(value) : nullptr; } const TransposeOptionsT *AsTransposeOptions() const { - return type == BuiltinOptions_TransposeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_TransposeOptions ? + reinterpret_cast(value) : nullptr; } MeanOptionsT *AsMeanOptions() { - return type == BuiltinOptions_MeanOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_MeanOptions ? + reinterpret_cast(value) : nullptr; } const MeanOptionsT *AsMeanOptions() const { - return type == BuiltinOptions_MeanOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_MeanOptions ? + reinterpret_cast(value) : nullptr; } SubOptionsT *AsSubOptions() { - return type == BuiltinOptions_SubOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SubOptions ? + reinterpret_cast(value) : nullptr; } const SubOptionsT *AsSubOptions() const { - return type == BuiltinOptions_SubOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SubOptions ? + reinterpret_cast(value) : nullptr; } DivOptionsT *AsDivOptions() { - return type == BuiltinOptions_DivOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_DivOptions ? + reinterpret_cast(value) : nullptr; } const DivOptionsT *AsDivOptions() const { - return type == BuiltinOptions_DivOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_DivOptions ? + reinterpret_cast(value) : nullptr; } SqueezeOptionsT *AsSqueezeOptions() { - return type == BuiltinOptions_SqueezeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SqueezeOptions ? + reinterpret_cast(value) : nullptr; } const SqueezeOptionsT *AsSqueezeOptions() const { - return type == BuiltinOptions_SqueezeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SqueezeOptions ? + reinterpret_cast(value) : nullptr; } SequenceRNNOptionsT *AsSequenceRNNOptions() { - return type == BuiltinOptions_SequenceRNNOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SequenceRNNOptions ? + reinterpret_cast(value) : nullptr; } const SequenceRNNOptionsT *AsSequenceRNNOptions() const { - return type == BuiltinOptions_SequenceRNNOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SequenceRNNOptions ? + reinterpret_cast(value) : nullptr; } StridedSliceOptionsT *AsStridedSliceOptions() { - return type == BuiltinOptions_StridedSliceOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_StridedSliceOptions ? + reinterpret_cast(value) : nullptr; } const StridedSliceOptionsT *AsStridedSliceOptions() const { - return type == BuiltinOptions_StridedSliceOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_StridedSliceOptions ? + reinterpret_cast(value) : nullptr; + } + ExpOptionsT *AsExpOptions() { + return type == BuiltinOptions_ExpOptions ? + reinterpret_cast(value) : nullptr; + } + const ExpOptionsT *AsExpOptions() const { + return type == BuiltinOptions_ExpOptions ? + reinterpret_cast(value) : nullptr; + } + TopKV2OptionsT *AsTopKV2Options() { + return type == BuiltinOptions_TopKV2Options ? + reinterpret_cast(value) : nullptr; + } + const TopKV2OptionsT *AsTopKV2Options() const { + return type == BuiltinOptions_TopKV2Options ? + reinterpret_cast(value) : nullptr; + } + SplitOptionsT *AsSplitOptions() { + return type == BuiltinOptions_SplitOptions ? + reinterpret_cast(value) : nullptr; + } + const SplitOptionsT *AsSplitOptions() const { + return type == BuiltinOptions_SplitOptions ? + reinterpret_cast(value) : nullptr; } }; -bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, - BuiltinOptions type); -bool VerifyBuiltinOptionsVector( - flatbuffers::Verifier &verifier, - const flatbuffers::Vector> *values, - const flatbuffers::Vector *types); +bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); +bool VerifyBuiltinOptionsVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); enum Padding { Padding_SAME = 0, @@ -991,12 +958,19 @@ enum Padding { }; inline Padding (&EnumValuesPadding())[2] { - static Padding values[] = {Padding_SAME, Padding_VALID}; + static Padding values[] = { + Padding_SAME, + Padding_VALID + }; return values; } inline const char **EnumNamesPadding() { - static const char *names[] = {"SAME", "VALID", nullptr}; + static const char *names[] = { + "SAME", + "VALID", + nullptr + }; return names; } @@ -1018,15 +992,26 @@ enum ActivationFunctionType { inline ActivationFunctionType (&EnumValuesActivationFunctionType())[6] { static ActivationFunctionType values[] = { - ActivationFunctionType_NONE, ActivationFunctionType_RELU, - ActivationFunctionType_RELU_N1_TO_1, ActivationFunctionType_RELU6, - ActivationFunctionType_TANH, ActivationFunctionType_SIGN_BIT}; + ActivationFunctionType_NONE, + ActivationFunctionType_RELU, + ActivationFunctionType_RELU_N1_TO_1, + ActivationFunctionType_RELU6, + ActivationFunctionType_TANH, + ActivationFunctionType_SIGN_BIT + }; return values; } inline const char **EnumNamesActivationFunctionType() { - static const char *names[] = {"NONE", "RELU", "RELU_N1_TO_1", "RELU6", - "TANH", "SIGN_BIT", nullptr}; + static const char *names[] = { + "NONE", + "RELU", + "RELU_N1_TO_1", + "RELU6", + "TANH", + "SIGN_BIT", + nullptr + }; return names; } @@ -1044,14 +1029,21 @@ enum LSHProjectionType { }; inline LSHProjectionType (&EnumValuesLSHProjectionType())[3] { - static LSHProjectionType values[] = {LSHProjectionType_UNKNOWN, - LSHProjectionType_SPARSE, - LSHProjectionType_DENSE}; + static LSHProjectionType values[] = { + LSHProjectionType_UNKNOWN, + LSHProjectionType_SPARSE, + LSHProjectionType_DENSE + }; return values; } inline const char **EnumNamesLSHProjectionType() { - static const char *names[] = {"UNKNOWN", "SPARSE", "DENSE", nullptr}; + static const char *names[] = { + "UNKNOWN", + "SPARSE", + "DENSE", + nullptr + }; return names; } @@ -1069,13 +1061,21 @@ enum CombinerType { }; inline CombinerType (&EnumValuesCombinerType())[3] { - static CombinerType values[] = {CombinerType_SUM, CombinerType_MEAN, - CombinerType_SQRTN}; + static CombinerType values[] = { + CombinerType_SUM, + CombinerType_MEAN, + CombinerType_SQRTN + }; return values; } inline const char **EnumNamesCombinerType() { - static const char *names[] = {"SUM", "MEAN", "SQRTN", nullptr}; + static const char *names[] = { + "SUM", + "MEAN", + "SQRTN", + nullptr + }; return names; } @@ -1091,12 +1091,17 @@ enum CustomOptionsFormat { }; inline CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] { - static CustomOptionsFormat values[] = {CustomOptionsFormat_FLEXBUFFERS}; + static CustomOptionsFormat values[] = { + CustomOptionsFormat_FLEXBUFFERS + }; return values; } inline const char **EnumNamesCustomOptionsFormat() { - static const char *names[] = {"FLEXBUFFERS", nullptr}; + static const char *names[] = { + "FLEXBUFFERS", + nullptr + }; return names; } @@ -1111,13 +1116,18 @@ struct QuantizationParametersT : public flatbuffers::NativeTable { std::vector max; std::vector scale; std::vector zero_point; - QuantizationParametersT() {} + QuantizationParametersT() { + } }; -struct QuantizationParameters FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef QuantizationParametersT NativeTableType; - enum { VT_MIN = 4, VT_MAX = 6, VT_SCALE = 8, VT_ZERO_POINT = 10 }; + enum { + VT_MIN = 4, + VT_MAX = 6, + VT_SCALE = 8, + VT_ZERO_POINT = 10 + }; const flatbuffers::Vector *min() const { return GetPointer *>(VT_MIN); } @@ -1131,20 +1141,20 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS return GetPointer *>(VT_ZERO_POINT); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_MIN) && - verifier.Verify(min()) && VerifyOffset(verifier, VT_MAX) && - verifier.Verify(max()) && VerifyOffset(verifier, VT_SCALE) && - verifier.Verify(scale()) && VerifyOffset(verifier, VT_ZERO_POINT) && - verifier.Verify(zero_point()) && verifier.EndTable(); - } - QuantizationParametersT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - QuantizationParametersT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_MIN) && + verifier.Verify(min()) && + VerifyOffset(verifier, VT_MAX) && + verifier.Verify(max()) && + VerifyOffset(verifier, VT_SCALE) && + verifier.Verify(scale()) && + VerifyOffset(verifier, VT_ZERO_POINT) && + verifier.Verify(zero_point()) && + verifier.EndTable(); + } + QuantizationParametersT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(QuantizationParametersT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct QuantizationParametersBuilder { @@ -1159,16 +1169,14 @@ struct QuantizationParametersBuilder { void add_scale(flatbuffers::Offset> scale) { fbb_.AddOffset(QuantizationParameters::VT_SCALE, scale); } - void add_zero_point( - flatbuffers::Offset> zero_point) { + void add_zero_point(flatbuffers::Offset> zero_point) { fbb_.AddOffset(QuantizationParameters::VT_ZERO_POINT, zero_point); } explicit QuantizationParametersBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - QuantizationParametersBuilder &operator=( - const QuantizationParametersBuilder &); + QuantizationParametersBuilder &operator=(const QuantizationParametersBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1190,23 +1198,21 @@ inline flatbuffers::Offset CreateQuantizationParameters( return builder_.Finish(); } -inline flatbuffers::Offset -CreateQuantizationParametersDirect( +inline flatbuffers::Offset CreateQuantizationParametersDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *min = nullptr, const std::vector *max = nullptr, const std::vector *scale = nullptr, const std::vector *zero_point = nullptr) { return tflite::CreateQuantizationParameters( - _fbb, min ? _fbb.CreateVector(*min) : 0, + _fbb, + min ? _fbb.CreateVector(*min) : 0, max ? _fbb.CreateVector(*max) : 0, scale ? _fbb.CreateVector(*scale) : 0, zero_point ? _fbb.CreateVector(*zero_point) : 0); } -flatbuffers::Offset CreateQuantizationParameters( - flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateQuantizationParameters(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct TensorT : public flatbuffers::NativeTable { typedef Tensor TableType; @@ -1215,7 +1221,10 @@ struct TensorT : public flatbuffers::NativeTable { uint32_t buffer; std::string name; std::unique_ptr quantization; - TensorT() : type(TensorType_FLOAT32), buffer(0) {} + TensorT() + : type(TensorType_FLOAT32), + buffer(0) { + } }; struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -1233,7 +1242,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { TensorType type() const { return static_cast(GetField(VT_TYPE, 0)); } - uint32_t buffer() const { return GetField(VT_BUFFER, 0); } + uint32_t buffer() const { + return GetField(VT_BUFFER, 0); + } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } @@ -1241,20 +1252,20 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer(VT_QUANTIZATION); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && - verifier.Verify(shape()) && VerifyField(verifier, VT_TYPE) && + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_SHAPE) && + verifier.Verify(shape()) && + VerifyField(verifier, VT_TYPE) && VerifyField(verifier, VT_BUFFER) && - VerifyOffset(verifier, VT_NAME) && verifier.Verify(name()) && + VerifyOffset(verifier, VT_NAME) && + verifier.Verify(name()) && VerifyOffset(verifier, VT_QUANTIZATION) && - verifier.VerifyTable(quantization()) && verifier.EndTable(); + verifier.VerifyTable(quantization()) && + verifier.EndTable(); } - TensorT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t *_resolver = - nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct TensorBuilder { @@ -1272,11 +1283,11 @@ struct TensorBuilder { void add_name(flatbuffers::Offset name) { fbb_.AddOffset(Tensor::VT_NAME, name); } - void add_quantization( - flatbuffers::Offset quantization) { + void add_quantization(flatbuffers::Offset quantization) { fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); } - explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } TensorBuilder &operator=(const TensorBuilder &); @@ -1290,7 +1301,8 @@ struct TensorBuilder { inline flatbuffers::Offset CreateTensor( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> shape = 0, - TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, + TensorType type = TensorType_FLOAT32, + uint32_t buffer = 0, flatbuffers::Offset name = 0, flatbuffers::Offset quantization = 0) { TensorBuilder builder_(_fbb); @@ -1305,17 +1317,20 @@ inline flatbuffers::Offset CreateTensor( inline flatbuffers::Offset CreateTensorDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *shape = nullptr, - TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, + TensorType type = TensorType_FLOAT32, + uint32_t buffer = 0, const char *name = nullptr, flatbuffers::Offset quantization = 0) { return tflite::CreateTensor( - _fbb, shape ? _fbb.CreateVector(*shape) : 0, type, buffer, - name ? _fbb.CreateString(name) : 0, quantization); + _fbb, + shape ? _fbb.CreateVector(*shape) : 0, + type, + buffer, + name ? _fbb.CreateString(name) : 0, + quantization); } -flatbuffers::Offset CreateTensor( - flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct Conv2DOptionsT : public flatbuffers::NativeTable { typedef Conv2DOptions TableType; @@ -1327,7 +1342,8 @@ struct Conv2DOptionsT : public flatbuffers::NativeTable { : padding(Padding_SAME), stride_w(0), stride_h(0), - fused_activation_function(ActivationFunctionType_NONE) {} + fused_activation_function(ActivationFunctionType_NONE) { + } }; struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -1341,11 +1357,14 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { Padding padding() const { return static_cast(GetField(VT_PADDING, 0)); } - int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); } - int32_t stride_h() const { return GetField(VT_STRIDE_H, 0); } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1355,22 +1374,16 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - Conv2DOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - Conv2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + Conv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Conv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct Conv2DOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_padding(Padding padding) { - fbb_.AddElement(Conv2DOptions::VT_PADDING, - static_cast(padding), 0); + fbb_.AddElement(Conv2DOptions::VT_PADDING, static_cast(padding), 0); } void add_stride_w(int32_t stride_w) { fbb_.AddElement(Conv2DOptions::VT_STRIDE_W, stride_w, 0); @@ -1378,13 +1391,11 @@ struct Conv2DOptionsBuilder { void add_stride_h(int32_t stride_h) { fbb_.AddElement(Conv2DOptions::VT_STRIDE_H, stride_h, 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit Conv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } Conv2DOptionsBuilder &operator=(const Conv2DOptionsBuilder &); @@ -1396,10 +1407,11 @@ struct Conv2DOptionsBuilder { }; inline flatbuffers::Offset CreateConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, Padding padding = Padding_SAME, - int32_t stride_w = 0, int32_t stride_h = 0, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + flatbuffers::FlatBufferBuilder &_fbb, + Padding padding = Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { Conv2DOptionsBuilder builder_(_fbb); builder_.add_stride_h(stride_h); builder_.add_stride_w(stride_w); @@ -1408,9 +1420,7 @@ inline flatbuffers::Offset CreateConv2DOptions( return builder_.Finish(); } -flatbuffers::Offset CreateConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct Pool2DOptionsT : public flatbuffers::NativeTable { typedef Pool2DOptions TableType; @@ -1426,7 +1436,8 @@ struct Pool2DOptionsT : public flatbuffers::NativeTable { stride_h(0), filter_width(0), filter_height(0), - fused_activation_function(ActivationFunctionType_NONE) {} + fused_activation_function(ActivationFunctionType_NONE) { + } }; struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -1442,15 +1453,20 @@ struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { Padding padding() const { return static_cast(GetField(VT_PADDING, 0)); } - int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); } - int32_t stride_h() const { return GetField(VT_STRIDE_H, 0); } - int32_t filter_width() const { return GetField(VT_FILTER_WIDTH, 0); } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + int32_t filter_width() const { + return GetField(VT_FILTER_WIDTH, 0); + } int32_t filter_height() const { return GetField(VT_FILTER_HEIGHT, 0); } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1462,22 +1478,16 @@ struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - Pool2DOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - Pool2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + Pool2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Pool2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct Pool2DOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_padding(Padding padding) { - fbb_.AddElement(Pool2DOptions::VT_PADDING, - static_cast(padding), 0); + fbb_.AddElement(Pool2DOptions::VT_PADDING, static_cast(padding), 0); } void add_stride_w(int32_t stride_w) { fbb_.AddElement(Pool2DOptions::VT_STRIDE_W, stride_w, 0); @@ -1491,13 +1501,11 @@ struct Pool2DOptionsBuilder { void add_filter_height(int32_t filter_height) { fbb_.AddElement(Pool2DOptions::VT_FILTER_HEIGHT, filter_height, 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(Pool2DOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Pool2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit Pool2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } Pool2DOptionsBuilder &operator=(const Pool2DOptionsBuilder &); @@ -1509,11 +1517,13 @@ struct Pool2DOptionsBuilder { }; inline flatbuffers::Offset CreatePool2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, Padding padding = Padding_SAME, - int32_t stride_w = 0, int32_t stride_h = 0, int32_t filter_width = 0, + flatbuffers::FlatBufferBuilder &_fbb, + Padding padding = Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + int32_t filter_width = 0, int32_t filter_height = 0, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { Pool2DOptionsBuilder builder_(_fbb); builder_.add_filter_height(filter_height); builder_.add_filter_width(filter_width); @@ -1524,9 +1534,7 @@ inline flatbuffers::Offset CreatePool2DOptions( return builder_.Finish(); } -flatbuffers::Offset CreatePool2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreatePool2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable { typedef DepthwiseConv2DOptions TableType; @@ -1540,11 +1548,11 @@ struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable { stride_w(0), stride_h(0), depth_multiplier(0), - fused_activation_function(ActivationFunctionType_NONE) {} + fused_activation_function(ActivationFunctionType_NONE) { + } }; -struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef DepthwiseConv2DOptionsT NativeTableType; enum { VT_PADDING = 4, @@ -1556,14 +1564,17 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS Padding padding() const { return static_cast(GetField(VT_PADDING, 0)); } - int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); } - int32_t stride_h() const { return GetField(VT_STRIDE_H, 0); } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } int32_t depth_multiplier() const { return GetField(VT_DEPTH_MULTIPLIER, 0); } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1574,22 +1585,16 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - DepthwiseConv2DOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - DepthwiseConv2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + DepthwiseConv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DepthwiseConv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct DepthwiseConv2DOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_padding(Padding padding) { - fbb_.AddElement(DepthwiseConv2DOptions::VT_PADDING, - static_cast(padding), 0); + fbb_.AddElement(DepthwiseConv2DOptions::VT_PADDING, static_cast(padding), 0); } void add_stride_w(int32_t stride_w) { fbb_.AddElement(DepthwiseConv2DOptions::VT_STRIDE_W, stride_w, 0); @@ -1598,21 +1603,16 @@ struct DepthwiseConv2DOptionsBuilder { fbb_.AddElement(DepthwiseConv2DOptions::VT_STRIDE_H, stride_h, 0); } void add_depth_multiplier(int32_t depth_multiplier) { - fbb_.AddElement(DepthwiseConv2DOptions::VT_DEPTH_MULTIPLIER, - depth_multiplier, 0); + fbb_.AddElement(DepthwiseConv2DOptions::VT_DEPTH_MULTIPLIER, depth_multiplier, 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement( - DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit DepthwiseConv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - DepthwiseConv2DOptionsBuilder &operator=( - const DepthwiseConv2DOptionsBuilder &); + DepthwiseConv2DOptionsBuilder &operator=(const DepthwiseConv2DOptionsBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1621,10 +1621,12 @@ struct DepthwiseConv2DOptionsBuilder { }; inline flatbuffers::Offset CreateDepthwiseConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, Padding padding = Padding_SAME, - int32_t stride_w = 0, int32_t stride_h = 0, int32_t depth_multiplier = 0, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + flatbuffers::FlatBufferBuilder &_fbb, + Padding padding = Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + int32_t depth_multiplier = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { DepthwiseConv2DOptionsBuilder builder_(_fbb); builder_.add_depth_multiplier(depth_multiplier); builder_.add_stride_h(stride_h); @@ -1634,34 +1636,33 @@ inline flatbuffers::Offset CreateDepthwiseConv2DOptions( return builder_.Finish(); } -flatbuffers::Offset CreateDepthwiseConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateDepthwiseConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct ConcatEmbeddingsOptionsT : public flatbuffers::NativeTable { typedef ConcatEmbeddingsOptions TableType; int32_t num_channels; std::vector num_columns_per_channel; std::vector embedding_dim_per_channel; - ConcatEmbeddingsOptionsT() : num_channels(0) {} + ConcatEmbeddingsOptionsT() + : num_channels(0) { + } }; -struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ConcatEmbeddingsOptionsT NativeTableType; enum { VT_NUM_CHANNELS = 4, VT_NUM_COLUMNS_PER_CHANNEL = 6, VT_EMBEDDING_DIM_PER_CHANNEL = 8 }; - int32_t num_channels() const { return GetField(VT_NUM_CHANNELS, 0); } + int32_t num_channels() const { + return GetField(VT_NUM_CHANNELS, 0); + } const flatbuffers::Vector *num_columns_per_channel() const { - return GetPointer *>( - VT_NUM_COLUMNS_PER_CHANNEL); + return GetPointer *>(VT_NUM_COLUMNS_PER_CHANNEL); } const flatbuffers::Vector *embedding_dim_per_channel() const { - return GetPointer *>( - VT_EMBEDDING_DIM_PER_CHANNEL); + return GetPointer *>(VT_EMBEDDING_DIM_PER_CHANNEL); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1669,43 +1670,31 @@ struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS VerifyOffset(verifier, VT_NUM_COLUMNS_PER_CHANNEL) && verifier.Verify(num_columns_per_channel()) && VerifyOffset(verifier, VT_EMBEDDING_DIM_PER_CHANNEL) && - verifier.Verify(embedding_dim_per_channel()) && verifier.EndTable(); + verifier.Verify(embedding_dim_per_channel()) && + verifier.EndTable(); } - ConcatEmbeddingsOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - ConcatEmbeddingsOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ConcatEmbeddingsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConcatEmbeddingsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ConcatEmbeddingsOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_num_channels(int32_t num_channels) { - fbb_.AddElement(ConcatEmbeddingsOptions::VT_NUM_CHANNELS, - num_channels, 0); + fbb_.AddElement(ConcatEmbeddingsOptions::VT_NUM_CHANNELS, num_channels, 0); } - void add_num_columns_per_channel( - flatbuffers::Offset> - num_columns_per_channel) { - fbb_.AddOffset(ConcatEmbeddingsOptions::VT_NUM_COLUMNS_PER_CHANNEL, - num_columns_per_channel); + void add_num_columns_per_channel(flatbuffers::Offset> num_columns_per_channel) { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_NUM_COLUMNS_PER_CHANNEL, num_columns_per_channel); } - void add_embedding_dim_per_channel( - flatbuffers::Offset> - embedding_dim_per_channel) { - fbb_.AddOffset(ConcatEmbeddingsOptions::VT_EMBEDDING_DIM_PER_CHANNEL, - embedding_dim_per_channel); + void add_embedding_dim_per_channel(flatbuffers::Offset> embedding_dim_per_channel) { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_EMBEDDING_DIM_PER_CHANNEL, embedding_dim_per_channel); } explicit ConcatEmbeddingsOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - ConcatEmbeddingsOptionsBuilder &operator=( - const ConcatEmbeddingsOptionsBuilder &); + ConcatEmbeddingsOptionsBuilder &operator=(const ConcatEmbeddingsOptionsBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1713,13 +1702,11 @@ struct ConcatEmbeddingsOptionsBuilder { } }; -inline flatbuffers::Offset -CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, - int32_t num_channels = 0, - flatbuffers::Offset> - num_columns_per_channel = 0, - flatbuffers::Offset> - embedding_dim_per_channel = 0) { +inline flatbuffers::Offset CreateConcatEmbeddingsOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_channels = 0, + flatbuffers::Offset> num_columns_per_channel = 0, + flatbuffers::Offset> embedding_dim_per_channel = 0) { ConcatEmbeddingsOptionsBuilder builder_(_fbb); builder_.add_embedding_dim_per_channel(embedding_dim_per_channel); builder_.add_num_columns_per_channel(num_columns_per_channel); @@ -1727,61 +1714,54 @@ CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, return builder_.Finish(); } -inline flatbuffers::Offset -CreateConcatEmbeddingsOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, int32_t num_channels = 0, +inline flatbuffers::Offset CreateConcatEmbeddingsOptionsDirect( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_channels = 0, const std::vector *num_columns_per_channel = nullptr, const std::vector *embedding_dim_per_channel = nullptr) { return tflite::CreateConcatEmbeddingsOptions( - _fbb, num_channels, - num_columns_per_channel - ? _fbb.CreateVector(*num_columns_per_channel) - : 0, - embedding_dim_per_channel - ? _fbb.CreateVector(*embedding_dim_per_channel) - : 0); + _fbb, + num_channels, + num_columns_per_channel ? _fbb.CreateVector(*num_columns_per_channel) : 0, + embedding_dim_per_channel ? _fbb.CreateVector(*embedding_dim_per_channel) : 0); } -flatbuffers::Offset CreateConcatEmbeddingsOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct LSHProjectionOptionsT : public flatbuffers::NativeTable { typedef LSHProjectionOptions TableType; LSHProjectionType type; - LSHProjectionOptionsT() : type(LSHProjectionType_UNKNOWN) {} + LSHProjectionOptionsT() + : type(LSHProjectionType_UNKNOWN) { + } }; -struct LSHProjectionOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct LSHProjectionOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef LSHProjectionOptionsT NativeTableType; - enum { VT_TYPE = 4 }; + enum { + VT_TYPE = 4 + }; LSHProjectionType type() const { return static_cast(GetField(VT_TYPE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_TYPE) && verifier.EndTable(); + VerifyField(verifier, VT_TYPE) && + verifier.EndTable(); } - LSHProjectionOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - LSHProjectionOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + LSHProjectionOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LSHProjectionOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct LSHProjectionOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_type(LSHProjectionType type) { - fbb_.AddElement(LSHProjectionOptions::VT_TYPE, - static_cast(type), 0); + fbb_.AddElement(LSHProjectionOptions::VT_TYPE, static_cast(type), 0); } explicit LSHProjectionOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } LSHProjectionOptionsBuilder &operator=(const LSHProjectionOptionsBuilder &); @@ -1800,25 +1780,29 @@ inline flatbuffers::Offset CreateLSHProjectionOptions( return builder_.Finish(); } -flatbuffers::Offset CreateLSHProjectionOptions( - flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateLSHProjectionOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SVDFOptionsT : public flatbuffers::NativeTable { typedef SVDFOptions TableType; int32_t rank; ActivationFunctionType fused_activation_function; SVDFOptionsT() - : rank(0), fused_activation_function(ActivationFunctionType_NONE) {} + : rank(0), + fused_activation_function(ActivationFunctionType_NONE) { + } }; struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SVDFOptionsT NativeTableType; - enum { VT_RANK = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 }; - int32_t rank() const { return GetField(VT_RANK, 0); } + enum { + VT_RANK = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + int32_t rank() const { + return GetField(VT_RANK, 0); + } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1826,14 +1810,9 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - SVDFOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SVDFOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SVDFOptionsBuilder { @@ -1842,13 +1821,11 @@ struct SVDFOptionsBuilder { void add_rank(int32_t rank) { fbb_.AddElement(SVDFOptions::VT_RANK, rank, 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SVDFOptionsBuilder &operator=(const SVDFOptionsBuilder &); @@ -1860,57 +1837,51 @@ struct SVDFOptionsBuilder { }; inline flatbuffers::Offset CreateSVDFOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t rank = 0, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + flatbuffers::FlatBufferBuilder &_fbb, + int32_t rank = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { SVDFOptionsBuilder builder_(_fbb); builder_.add_rank(rank); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateSVDFOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct RNNOptionsT : public flatbuffers::NativeTable { typedef RNNOptions TableType; ActivationFunctionType fused_activation_function; - RNNOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + RNNOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef RNNOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - RNNOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - RNNOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct RNNOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } RNNOptionsBuilder &operator=(const RNNOptionsBuilder &); @@ -1923,16 +1894,13 @@ struct RNNOptionsBuilder { inline flatbuffers::Offset CreateRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { RNNOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SequenceRNNOptionsT : public flatbuffers::NativeTable { typedef SequenceRNNOptions TableType; @@ -1940,16 +1908,21 @@ struct SequenceRNNOptionsT : public flatbuffers::NativeTable { ActivationFunctionType fused_activation_function; SequenceRNNOptionsT() : time_major(false), - fused_activation_function(ActivationFunctionType_NONE) {} + fused_activation_function(ActivationFunctionType_NONE) { + } }; struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SequenceRNNOptionsT NativeTableType; - enum { VT_TIME_MAJOR = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 }; - bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; } + enum { + VT_TIME_MAJOR = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + bool time_major() const { + return GetField(VT_TIME_MAJOR, 0) != 0; + } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1957,30 +1930,22 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - SequenceRNNOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SequenceRNNOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SequenceRNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SequenceRNNOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_time_major(bool time_major) { - fbb_.AddElement(SequenceRNNOptions::VT_TIME_MAJOR, - static_cast(time_major), 0); + fbb_.AddElement(SequenceRNNOptions::VT_TIME_MAJOR, static_cast(time_major), 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SequenceRNNOptionsBuilder &operator=(const SequenceRNNOptionsBuilder &); @@ -1992,59 +1957,117 @@ struct SequenceRNNOptionsBuilder { }; inline flatbuffers::Offset CreateSequenceRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + flatbuffers::FlatBufferBuilder &_fbb, + bool time_major = false, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { SequenceRNNOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); builder_.add_time_major(time_major); return builder_.Finish(); } -flatbuffers::Offset CreateSequenceRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSequenceRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable { + typedef BidirectionalSequenceRNNOptions TableType; + bool time_major; + ActivationFunctionType fused_activation_function; + BidirectionalSequenceRNNOptionsT() + : time_major(false), + fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef BidirectionalSequenceRNNOptionsT NativeTableType; + enum { + VT_TIME_MAJOR = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + bool time_major() const { + return GetField(VT_TIME_MAJOR, 0) != 0; + } + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_TIME_MAJOR) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BidirectionalSequenceRNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BidirectionalSequenceRNNOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_time_major(bool time_major) { + fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_TIME_MAJOR, static_cast(time_major), 0); + } + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + BidirectionalSequenceRNNOptionsBuilder &operator=(const BidirectionalSequenceRNNOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateBidirectionalSequenceRNNOptions( + flatbuffers::FlatBufferBuilder &_fbb, + bool time_major = false, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + BidirectionalSequenceRNNOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_time_major(time_major); + return builder_.Finish(); +} + +flatbuffers::Offset CreateBidirectionalSequenceRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct FullyConnectedOptionsT : public flatbuffers::NativeTable { typedef FullyConnectedOptions TableType; ActivationFunctionType fused_activation_function; FullyConnectedOptionsT() - : fused_activation_function(ActivationFunctionType_NONE) {} + : fused_activation_function(ActivationFunctionType_NONE) { + } }; -struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef FullyConnectedOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - FullyConnectedOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - FullyConnectedOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FullyConnectedOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct FullyConnectedOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FullyConnectedOptionsBuilder &operator=(const FullyConnectedOptionsBuilder &); @@ -2057,39 +2080,38 @@ struct FullyConnectedOptionsBuilder { inline flatbuffers::Offset CreateFullyConnectedOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { FullyConnectedOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateFullyConnectedOptions( - flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateFullyConnectedOptions(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SoftmaxOptionsT : public flatbuffers::NativeTable { typedef SoftmaxOptions TableType; float beta; - SoftmaxOptionsT() : beta(0.0f) {} + SoftmaxOptionsT() + : beta(0.0f) { + } }; struct SoftmaxOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SoftmaxOptionsT NativeTableType; - enum { VT_BETA = 4 }; - float beta() const { return GetField(VT_BETA, 0.0f); } + enum { + VT_BETA = 4 + }; + float beta() const { + return GetField(VT_BETA, 0.0f); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_BETA) && verifier.EndTable(); + VerifyField(verifier, VT_BETA) && + verifier.EndTable(); } - SoftmaxOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SoftmaxOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SoftmaxOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SoftmaxOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SoftmaxOptionsBuilder { @@ -2099,7 +2121,7 @@ struct SoftmaxOptionsBuilder { fbb_.AddElement(SoftmaxOptions::VT_BETA, beta, 0.0f); } explicit SoftmaxOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SoftmaxOptionsBuilder &operator=(const SoftmaxOptionsBuilder &); @@ -2111,32 +2133,36 @@ struct SoftmaxOptionsBuilder { }; inline flatbuffers::Offset CreateSoftmaxOptions( - flatbuffers::FlatBufferBuilder &_fbb, float beta = 0.0f) { + flatbuffers::FlatBufferBuilder &_fbb, + float beta = 0.0f) { SoftmaxOptionsBuilder builder_(_fbb); builder_.add_beta(beta); return builder_.Finish(); } -flatbuffers::Offset CreateSoftmaxOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSoftmaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct ConcatenationOptionsT : public flatbuffers::NativeTable { typedef ConcatenationOptions TableType; int32_t axis; ActivationFunctionType fused_activation_function; ConcatenationOptionsT() - : axis(0), fused_activation_function(ActivationFunctionType_NONE) {} + : axis(0), + fused_activation_function(ActivationFunctionType_NONE) { + } }; -struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ConcatenationOptionsT NativeTableType; - enum { VT_AXIS = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 }; - int32_t axis() const { return GetField(VT_AXIS, 0); } + enum { + VT_AXIS = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -2144,14 +2170,9 @@ struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - ConcatenationOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - ConcatenationOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ConcatenationOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConcatenationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ConcatenationOptionsBuilder { @@ -2160,13 +2181,11 @@ struct ConcatenationOptionsBuilder { void add_axis(int32_t axis) { fbb_.AddElement(ConcatenationOptions::VT_AXIS, axis, 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(ConcatenationOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(ConcatenationOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit ConcatenationOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ConcatenationOptionsBuilder &operator=(const ConcatenationOptionsBuilder &); @@ -2178,57 +2197,51 @@ struct ConcatenationOptionsBuilder { }; inline flatbuffers::Offset CreateConcatenationOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { ConcatenationOptionsBuilder builder_(_fbb); builder_.add_axis(axis); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateConcatenationOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateConcatenationOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct AddOptionsT : public flatbuffers::NativeTable { typedef AddOptions TableType; ActivationFunctionType fused_activation_function; - AddOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + AddOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct AddOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef AddOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - AddOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - AddOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + AddOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(AddOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct AddOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(AddOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(AddOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit AddOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } AddOptionsBuilder &operator=(const AddOptionsBuilder &); @@ -2241,55 +2254,48 @@ struct AddOptionsBuilder { inline flatbuffers::Offset CreateAddOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { AddOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateAddOptions( - flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateAddOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct MulOptionsT : public flatbuffers::NativeTable { typedef MulOptions TableType; ActivationFunctionType fused_activation_function; - MulOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + MulOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct MulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef MulOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - MulOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - MulOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + MulOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct MulOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(MulOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(MulOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit MulOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } MulOptionsBuilder &operator=(const MulOptionsBuilder &); @@ -2302,55 +2308,48 @@ struct MulOptionsBuilder { inline flatbuffers::Offset CreateMulOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { MulOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateMulOptions( - flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct L2NormOptionsT : public flatbuffers::NativeTable { typedef L2NormOptions TableType; ActivationFunctionType fused_activation_function; - L2NormOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + L2NormOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct L2NormOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef L2NormOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - L2NormOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - L2NormOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + L2NormOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(L2NormOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct L2NormOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(L2NormOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(L2NormOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit L2NormOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } L2NormOptionsBuilder &operator=(const L2NormOptionsBuilder &); @@ -2363,16 +2362,13 @@ struct L2NormOptionsBuilder { inline flatbuffers::Offset CreateL2NormOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { L2NormOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateL2NormOptions( - flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateL2NormOptions(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct LocalResponseNormalizationOptionsT : public flatbuffers::NativeTable { typedef LocalResponseNormalizationOptions TableType; @@ -2381,61 +2377,66 @@ struct LocalResponseNormalizationOptionsT : public flatbuffers::NativeTable { float alpha; float beta; LocalResponseNormalizationOptionsT() - : radius(0), bias(0.0f), alpha(0.0f), beta(0.0f) {} + : radius(0), + bias(0.0f), + alpha(0.0f), + beta(0.0f) { + } }; -struct LocalResponseNormalizationOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct LocalResponseNormalizationOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef LocalResponseNormalizationOptionsT NativeTableType; - enum { VT_RADIUS = 4, VT_BIAS = 6, VT_ALPHA = 8, VT_BETA = 10 }; - int32_t radius() const { return GetField(VT_RADIUS, 0); } - float bias() const { return GetField(VT_BIAS, 0.0f); } - float alpha() const { return GetField(VT_ALPHA, 0.0f); } - float beta() const { return GetField(VT_BETA, 0.0f); } + enum { + VT_RADIUS = 4, + VT_BIAS = 6, + VT_ALPHA = 8, + VT_BETA = 10 + }; + int32_t radius() const { + return GetField(VT_RADIUS, 0); + } + float bias() const { + return GetField(VT_BIAS, 0.0f); + } + float alpha() const { + return GetField(VT_ALPHA, 0.0f); + } + float beta() const { + return GetField(VT_BETA, 0.0f); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_RADIUS) && VerifyField(verifier, VT_BIAS) && VerifyField(verifier, VT_ALPHA) && - VerifyField(verifier, VT_BETA) && verifier.EndTable(); + VerifyField(verifier, VT_BETA) && + verifier.EndTable(); } - LocalResponseNormalizationOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - LocalResponseNormalizationOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const LocalResponseNormalizationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + LocalResponseNormalizationOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LocalResponseNormalizationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct LocalResponseNormalizationOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_radius(int32_t radius) { - fbb_.AddElement(LocalResponseNormalizationOptions::VT_RADIUS, - radius, 0); + fbb_.AddElement(LocalResponseNormalizationOptions::VT_RADIUS, radius, 0); } void add_bias(float bias) { - fbb_.AddElement(LocalResponseNormalizationOptions::VT_BIAS, bias, - 0.0f); + fbb_.AddElement(LocalResponseNormalizationOptions::VT_BIAS, bias, 0.0f); } void add_alpha(float alpha) { - fbb_.AddElement(LocalResponseNormalizationOptions::VT_ALPHA, alpha, - 0.0f); + fbb_.AddElement(LocalResponseNormalizationOptions::VT_ALPHA, alpha, 0.0f); } void add_beta(float beta) { - fbb_.AddElement(LocalResponseNormalizationOptions::VT_BETA, beta, - 0.0f); + fbb_.AddElement(LocalResponseNormalizationOptions::VT_BETA, beta, 0.0f); } - explicit LocalResponseNormalizationOptionsBuilder( - flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit LocalResponseNormalizationOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - LocalResponseNormalizationOptionsBuilder &operator=( - const LocalResponseNormalizationOptionsBuilder &); + LocalResponseNormalizationOptionsBuilder &operator=(const LocalResponseNormalizationOptionsBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -2443,10 +2444,12 @@ struct LocalResponseNormalizationOptionsBuilder { } }; -inline flatbuffers::Offset -CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, - int32_t radius = 0, float bias = 0.0f, - float alpha = 0.0f, float beta = 0.0f) { +inline flatbuffers::Offset CreateLocalResponseNormalizationOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t radius = 0, + float bias = 0.0f, + float alpha = 0.0f, + float beta = 0.0f) { LocalResponseNormalizationOptionsBuilder builder_(_fbb); builder_.add_beta(beta); builder_.add_alpha(alpha); @@ -2455,11 +2458,7 @@ CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, return builder_.Finish(); } -flatbuffers::Offset -CreateLocalResponseNormalizationOptions( - flatbuffers::FlatBufferBuilder &_fbb, - const LocalResponseNormalizationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct LSTMOptionsT : public flatbuffers::NativeTable { typedef LSTMOptions TableType; @@ -2469,41 +2468,43 @@ struct LSTMOptionsT : public flatbuffers::NativeTable { LSTMOptionsT() : fused_activation_function(ActivationFunctionType_NONE), cell_clip(0.0f), - proj_clip(0.0f) {} + proj_clip(0.0f) { + } }; struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef LSTMOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, VT_PROJ_CLIP = 8 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { + return GetField(VT_CELL_CLIP, 0.0f); + } + float proj_clip() const { + return GetField(VT_PROJ_CLIP, 0.0f); } - float cell_clip() const { return GetField(VT_CELL_CLIP, 0.0f); } - float proj_clip() const { return GetField(VT_PROJ_CLIP, 0.0f); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_CELL_CLIP) && - VerifyField(verifier, VT_PROJ_CLIP) && verifier.EndTable(); + VerifyField(verifier, VT_PROJ_CLIP) && + verifier.EndTable(); } - LSTMOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - LSTMOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct LSTMOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(LSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(LSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } void add_cell_clip(float cell_clip) { fbb_.AddElement(LSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); @@ -2512,7 +2513,7 @@ struct LSTMOptionsBuilder { fbb_.AddElement(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); } explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } LSTMOptionsBuilder &operator=(const LSTMOptionsBuilder &); @@ -2525,9 +2526,9 @@ struct LSTMOptionsBuilder { inline flatbuffers::Offset CreateLSTMOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE, - float cell_clip = 0.0f, float proj_clip = 0.0f) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + float cell_clip = 0.0f, + float proj_clip = 0.0f) { LSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); @@ -2535,50 +2536,42 @@ inline flatbuffers::Offset CreateLSTMOptions( return builder_.Finish(); } -flatbuffers::Offset CreateLSTMOptions( - flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct ResizeBilinearOptionsT : public flatbuffers::NativeTable { typedef ResizeBilinearOptions TableType; - int32_t new_height; - int32_t new_width; - ResizeBilinearOptionsT() : new_height(0), new_width(0) {} + bool align_corners; + ResizeBilinearOptionsT() + : align_corners(false) { + } }; -struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ResizeBilinearOptionsT NativeTableType; - enum { VT_NEW_HEIGHT = 4, VT_NEW_WIDTH = 6 }; - int32_t new_height() const { return GetField(VT_NEW_HEIGHT, 0); } - int32_t new_width() const { return GetField(VT_NEW_WIDTH, 0); } + enum { + VT_ALIGN_CORNERS = 8 + }; + bool align_corners() const { + return GetField(VT_ALIGN_CORNERS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_NEW_HEIGHT) && - VerifyField(verifier, VT_NEW_WIDTH) && verifier.EndTable(); + VerifyField(verifier, VT_ALIGN_CORNERS) && + verifier.EndTable(); } - ResizeBilinearOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - ResizeBilinearOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ResizeBilinearOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ResizeBilinearOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_new_height(int32_t new_height) { - fbb_.AddElement(ResizeBilinearOptions::VT_NEW_HEIGHT, new_height, - 0); - } - void add_new_width(int32_t new_width) { - fbb_.AddElement(ResizeBilinearOptions::VT_NEW_WIDTH, new_width, 0); + void add_align_corners(bool align_corners) { + fbb_.AddElement(ResizeBilinearOptions::VT_ALIGN_CORNERS, static_cast(align_corners), 0); } explicit ResizeBilinearOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ResizeBilinearOptionsBuilder &operator=(const ResizeBilinearOptionsBuilder &); @@ -2590,40 +2583,39 @@ struct ResizeBilinearOptionsBuilder { }; inline flatbuffers::Offset CreateResizeBilinearOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t new_height = 0, - int32_t new_width = 0) { + flatbuffers::FlatBufferBuilder &_fbb, + bool align_corners = false) { ResizeBilinearOptionsBuilder builder_(_fbb); - builder_.add_new_width(new_width); - builder_.add_new_height(new_height); + builder_.add_align_corners(align_corners); return builder_.Finish(); } -flatbuffers::Offset CreateResizeBilinearOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct CallOptionsT : public flatbuffers::NativeTable { typedef CallOptions TableType; uint32_t subgraph; - CallOptionsT() : subgraph(0) {} + CallOptionsT() + : subgraph(0) { + } }; struct CallOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef CallOptionsT NativeTableType; - enum { VT_SUBGRAPH = 4 }; - uint32_t subgraph() const { return GetField(VT_SUBGRAPH, 0); } + enum { + VT_SUBGRAPH = 4 + }; + uint32_t subgraph() const { + return GetField(VT_SUBGRAPH, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_SUBGRAPH) && verifier.EndTable(); + VerifyField(verifier, VT_SUBGRAPH) && + verifier.EndTable(); } - CallOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - CallOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + CallOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CallOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct CallOptionsBuilder { @@ -2633,7 +2625,7 @@ struct CallOptionsBuilder { fbb_.AddElement(CallOptions::VT_SUBGRAPH, subgraph, 0); } explicit CallOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } CallOptionsBuilder &operator=(const CallOptionsBuilder &); @@ -2645,41 +2637,37 @@ struct CallOptionsBuilder { }; inline flatbuffers::Offset CreateCallOptions( - flatbuffers::FlatBufferBuilder &_fbb, uint32_t subgraph = 0) { + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t subgraph = 0) { CallOptionsBuilder builder_(_fbb); builder_.add_subgraph(subgraph); return builder_.Finish(); } -flatbuffers::Offset CreateCallOptions( - flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateCallOptions(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct PadOptionsT : public flatbuffers::NativeTable { typedef PadOptions TableType; - PadOptionsT() {} + PadOptionsT() { + } }; struct PadOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef PadOptionsT NativeTableType; bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && verifier.EndTable(); + return VerifyTableStart(verifier) && + verifier.EndTable(); } - PadOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - PadOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + PadOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PadOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct PadOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; explicit PadOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } PadOptionsBuilder &operator=(const PadOptionsBuilder &); @@ -2696,45 +2684,42 @@ inline flatbuffers::Offset CreatePadOptions( return builder_.Finish(); } -flatbuffers::Offset CreatePadOptions( - flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreatePadOptions(flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct ReshapeOptionsT : public flatbuffers::NativeTable { typedef ReshapeOptions TableType; std::vector new_shape; - ReshapeOptionsT() {} + ReshapeOptionsT() { + } }; struct ReshapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ReshapeOptionsT NativeTableType; - enum { VT_NEW_SHAPE = 4 }; + enum { + VT_NEW_SHAPE = 4 + }; const flatbuffers::Vector *new_shape() const { return GetPointer *>(VT_NEW_SHAPE); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NEW_SHAPE) && - verifier.Verify(new_shape()) && verifier.EndTable(); + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NEW_SHAPE) && + verifier.Verify(new_shape()) && + verifier.EndTable(); } - ReshapeOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - ReshapeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ReshapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReshapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ReshapeOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_new_shape( - flatbuffers::Offset> new_shape) { + void add_new_shape(flatbuffers::Offset> new_shape) { fbb_.AddOffset(ReshapeOptions::VT_NEW_SHAPE, new_shape); } explicit ReshapeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ReshapeOptionsBuilder &operator=(const ReshapeOptionsBuilder &); @@ -2757,70 +2742,34 @@ inline flatbuffers::Offset CreateReshapeOptionsDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *new_shape = nullptr) { return tflite::CreateReshapeOptions( - _fbb, new_shape ? _fbb.CreateVector(*new_shape) : 0); + _fbb, + new_shape ? _fbb.CreateVector(*new_shape) : 0); } -flatbuffers::Offset CreateReshapeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateReshapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SpaceToBatchNDOptionsT : public flatbuffers::NativeTable { typedef SpaceToBatchNDOptions TableType; - std::vector block_shape; - std::vector before_paddings; - std::vector after_paddings; - SpaceToBatchNDOptionsT() {} + SpaceToBatchNDOptionsT() { + } }; -struct SpaceToBatchNDOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct SpaceToBatchNDOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SpaceToBatchNDOptionsT NativeTableType; - enum { VT_BLOCK_SHAPE = 4, VT_BEFORE_PADDINGS = 6, VT_AFTER_PADDINGS = 8 }; - const flatbuffers::Vector *block_shape() const { - return GetPointer *>(VT_BLOCK_SHAPE); - } - const flatbuffers::Vector *before_paddings() const { - return GetPointer *>(VT_BEFORE_PADDINGS); - } - const flatbuffers::Vector *after_paddings() const { - return GetPointer *>(VT_AFTER_PADDINGS); - } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_BLOCK_SHAPE) && - verifier.Verify(block_shape()) && - VerifyOffset(verifier, VT_BEFORE_PADDINGS) && - verifier.Verify(before_paddings()) && - VerifyOffset(verifier, VT_AFTER_PADDINGS) && - verifier.Verify(after_paddings()) && verifier.EndTable(); - } - SpaceToBatchNDOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SpaceToBatchNDOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + verifier.EndTable(); + } + SpaceToBatchNDOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SpaceToBatchNDOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SpaceToBatchNDOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_block_shape( - flatbuffers::Offset> block_shape) { - fbb_.AddOffset(SpaceToBatchNDOptions::VT_BLOCK_SHAPE, block_shape); - } - void add_before_paddings( - flatbuffers::Offset> before_paddings) { - fbb_.AddOffset(SpaceToBatchNDOptions::VT_BEFORE_PADDINGS, before_paddings); - } - void add_after_paddings( - flatbuffers::Offset> after_paddings) { - fbb_.AddOffset(SpaceToBatchNDOptions::VT_AFTER_PADDINGS, after_paddings); - } explicit SpaceToBatchNDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SpaceToBatchNDOptionsBuilder &operator=(const SpaceToBatchNDOptionsBuilder &); @@ -2832,90 +2781,35 @@ struct SpaceToBatchNDOptionsBuilder { }; inline flatbuffers::Offset CreateSpaceToBatchNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset> block_shape = 0, - flatbuffers::Offset> before_paddings = 0, - flatbuffers::Offset> after_paddings = 0) { + flatbuffers::FlatBufferBuilder &_fbb) { SpaceToBatchNDOptionsBuilder builder_(_fbb); - builder_.add_after_paddings(after_paddings); - builder_.add_before_paddings(before_paddings); - builder_.add_block_shape(block_shape); return builder_.Finish(); } -inline flatbuffers::Offset -CreateSpaceToBatchNDOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *block_shape = nullptr, - const std::vector *before_paddings = nullptr, - const std::vector *after_paddings = nullptr) { - return tflite::CreateSpaceToBatchNDOptions( - _fbb, block_shape ? _fbb.CreateVector(*block_shape) : 0, - before_paddings ? _fbb.CreateVector(*before_paddings) : 0, - after_paddings ? _fbb.CreateVector(*after_paddings) : 0); -} - -flatbuffers::Offset CreateSpaceToBatchNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSpaceToBatchNDOptions(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct BatchToSpaceNDOptionsT : public flatbuffers::NativeTable { typedef BatchToSpaceNDOptions TableType; - std::vector block_shape; - std::vector before_crops; - std::vector after_crops; - BatchToSpaceNDOptionsT() {} + BatchToSpaceNDOptionsT() { + } }; -struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef BatchToSpaceNDOptionsT NativeTableType; - enum { VT_BLOCK_SHAPE = 4, VT_BEFORE_CROPS = 6, VT_AFTER_CROPS = 8 }; - const flatbuffers::Vector *block_shape() const { - return GetPointer *>(VT_BLOCK_SHAPE); - } - const flatbuffers::Vector *before_crops() const { - return GetPointer *>(VT_BEFORE_CROPS); - } - const flatbuffers::Vector *after_crops() const { - return GetPointer *>(VT_AFTER_CROPS); - } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_BLOCK_SHAPE) && - verifier.Verify(block_shape()) && - VerifyOffset(verifier, VT_BEFORE_CROPS) && - verifier.Verify(before_crops()) && - VerifyOffset(verifier, VT_AFTER_CROPS) && - verifier.Verify(after_crops()) && verifier.EndTable(); - } - BatchToSpaceNDOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - BatchToSpaceNDOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + verifier.EndTable(); + } + BatchToSpaceNDOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BatchToSpaceNDOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct BatchToSpaceNDOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_block_shape( - flatbuffers::Offset> block_shape) { - fbb_.AddOffset(BatchToSpaceNDOptions::VT_BLOCK_SHAPE, block_shape); - } - void add_before_crops( - flatbuffers::Offset> before_crops) { - fbb_.AddOffset(BatchToSpaceNDOptions::VT_BEFORE_CROPS, before_crops); - } - void add_after_crops( - flatbuffers::Offset> after_crops) { - fbb_.AddOffset(BatchToSpaceNDOptions::VT_AFTER_CROPS, after_crops); - } explicit BatchToSpaceNDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } BatchToSpaceNDOptionsBuilder &operator=(const BatchToSpaceNDOptionsBuilder &); @@ -2927,32 +2821,12 @@ struct BatchToSpaceNDOptionsBuilder { }; inline flatbuffers::Offset CreateBatchToSpaceNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset> block_shape = 0, - flatbuffers::Offset> before_crops = 0, - flatbuffers::Offset> after_crops = 0) { + flatbuffers::FlatBufferBuilder &_fbb) { BatchToSpaceNDOptionsBuilder builder_(_fbb); - builder_.add_after_crops(after_crops); - builder_.add_before_crops(before_crops); - builder_.add_block_shape(block_shape); return builder_.Finish(); } -inline flatbuffers::Offset -CreateBatchToSpaceNDOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *block_shape = nullptr, - const std::vector *before_crops = nullptr, - const std::vector *after_crops = nullptr) { - return tflite::CreateBatchToSpaceNDOptions( - _fbb, block_shape ? _fbb.CreateVector(*block_shape) : 0, - before_crops ? _fbb.CreateVector(*before_crops) : 0, - after_crops ? _fbb.CreateVector(*after_crops) : 0); -} - -flatbuffers::Offset CreateBatchToSpaceNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateBatchToSpaceNDOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SkipGramOptionsT : public flatbuffers::NativeTable { typedef SkipGramOptions TableType; @@ -2960,13 +2834,22 @@ struct SkipGramOptionsT : public flatbuffers::NativeTable { int32_t max_skip_size; bool include_all_ngrams; SkipGramOptionsT() - : ngram_size(0), max_skip_size(0), include_all_ngrams(false) {} + : ngram_size(0), + max_skip_size(0), + include_all_ngrams(false) { + } }; struct SkipGramOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SkipGramOptionsT NativeTableType; - enum { VT_NGRAM_SIZE = 4, VT_MAX_SKIP_SIZE = 6, VT_INCLUDE_ALL_NGRAMS = 8 }; - int32_t ngram_size() const { return GetField(VT_NGRAM_SIZE, 0); } + enum { + VT_NGRAM_SIZE = 4, + VT_MAX_SKIP_SIZE = 6, + VT_INCLUDE_ALL_NGRAMS = 8 + }; + int32_t ngram_size() const { + return GetField(VT_NGRAM_SIZE, 0); + } int32_t max_skip_size() const { return GetField(VT_MAX_SKIP_SIZE, 0); } @@ -2980,14 +2863,9 @@ struct SkipGramOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_INCLUDE_ALL_NGRAMS) && verifier.EndTable(); } - SkipGramOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SkipGramOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SkipGramOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SkipGramOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SkipGramOptionsBuilder { @@ -2997,15 +2875,13 @@ struct SkipGramOptionsBuilder { fbb_.AddElement(SkipGramOptions::VT_NGRAM_SIZE, ngram_size, 0); } void add_max_skip_size(int32_t max_skip_size) { - fbb_.AddElement(SkipGramOptions::VT_MAX_SKIP_SIZE, max_skip_size, - 0); + fbb_.AddElement(SkipGramOptions::VT_MAX_SKIP_SIZE, max_skip_size, 0); } void add_include_all_ngrams(bool include_all_ngrams) { - fbb_.AddElement(SkipGramOptions::VT_INCLUDE_ALL_NGRAMS, - static_cast(include_all_ngrams), 0); + fbb_.AddElement(SkipGramOptions::VT_INCLUDE_ALL_NGRAMS, static_cast(include_all_ngrams), 0); } explicit SkipGramOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SkipGramOptionsBuilder &operator=(const SkipGramOptionsBuilder &); @@ -3017,8 +2893,10 @@ struct SkipGramOptionsBuilder { }; inline flatbuffers::Offset CreateSkipGramOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t ngram_size = 0, - int32_t max_skip_size = 0, bool include_all_ngrams = false) { + flatbuffers::FlatBufferBuilder &_fbb, + int32_t ngram_size = 0, + int32_t max_skip_size = 0, + bool include_all_ngrams = false) { SkipGramOptionsBuilder builder_(_fbb); builder_.add_max_skip_size(max_skip_size); builder_.add_ngram_size(ngram_size); @@ -3026,33 +2904,32 @@ inline flatbuffers::Offset CreateSkipGramOptions( return builder_.Finish(); } -flatbuffers::Offset CreateSkipGramOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSkipGramOptions(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SpaceToDepthOptionsT : public flatbuffers::NativeTable { typedef SpaceToDepthOptions TableType; int32_t block_size; - SpaceToDepthOptionsT() : block_size(0) {} + SpaceToDepthOptionsT() + : block_size(0) { + } }; -struct SpaceToDepthOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct SpaceToDepthOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SpaceToDepthOptionsT NativeTableType; - enum { VT_BLOCK_SIZE = 4 }; - int32_t block_size() const { return GetField(VT_BLOCK_SIZE, 0); } + enum { + VT_BLOCK_SIZE = 4 + }; + int32_t block_size() const { + return GetField(VT_BLOCK_SIZE, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_BLOCK_SIZE) && verifier.EndTable(); + VerifyField(verifier, VT_BLOCK_SIZE) && + verifier.EndTable(); } - SpaceToDepthOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SpaceToDepthOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SpaceToDepthOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SpaceToDepthOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SpaceToDepthOptionsBuilder { @@ -3062,7 +2939,7 @@ struct SpaceToDepthOptionsBuilder { fbb_.AddElement(SpaceToDepthOptions::VT_BLOCK_SIZE, block_size, 0); } explicit SpaceToDepthOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SpaceToDepthOptionsBuilder &operator=(const SpaceToDepthOptionsBuilder &); @@ -3074,54 +2951,49 @@ struct SpaceToDepthOptionsBuilder { }; inline flatbuffers::Offset CreateSpaceToDepthOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t block_size = 0) { + flatbuffers::FlatBufferBuilder &_fbb, + int32_t block_size = 0) { SpaceToDepthOptionsBuilder builder_(_fbb); builder_.add_block_size(block_size); return builder_.Finish(); } -flatbuffers::Offset CreateSpaceToDepthOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSpaceToDepthOptions(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SubOptionsT : public flatbuffers::NativeTable { typedef SubOptions TableType; ActivationFunctionType fused_activation_function; - SubOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + SubOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct SubOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SubOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - SubOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SubOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SubOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SubOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SubOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(SubOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SubOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit SubOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SubOptionsBuilder &operator=(const SubOptionsBuilder &); @@ -3134,55 +3006,48 @@ struct SubOptionsBuilder { inline flatbuffers::Offset CreateSubOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { SubOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateSubOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSubOptions(flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct DivOptionsT : public flatbuffers::NativeTable { typedef DivOptions TableType; ActivationFunctionType fused_activation_function; - DivOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + DivOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct DivOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef DivOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - DivOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - DivOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + DivOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct DivOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(DivOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(DivOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit DivOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } DivOptionsBuilder &operator=(const DivOptionsBuilder &); @@ -3195,59 +3060,91 @@ struct DivOptionsBuilder { inline flatbuffers::Offset CreateDivOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { DivOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateDivOptions( - flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TopKV2OptionsT : public flatbuffers::NativeTable { + typedef TopKV2Options TableType; + TopKV2OptionsT() { + } +}; + +struct TopKV2Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TopKV2OptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + TopKV2OptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TopKV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TopKV2OptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit TopKV2OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TopKV2OptionsBuilder &operator=(const TopKV2OptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTopKV2Options( + flatbuffers::FlatBufferBuilder &_fbb) { + TopKV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateTopKV2Options(flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct EmbeddingLookupSparseOptionsT : public flatbuffers::NativeTable { typedef EmbeddingLookupSparseOptions TableType; CombinerType combiner; - EmbeddingLookupSparseOptionsT() : combiner(CombinerType_SUM) {} + EmbeddingLookupSparseOptionsT() + : combiner(CombinerType_SUM) { + } }; -struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef EmbeddingLookupSparseOptionsT NativeTableType; - enum { VT_COMBINER = 4 }; + enum { + VT_COMBINER = 4 + }; CombinerType combiner() const { return static_cast(GetField(VT_COMBINER, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_COMBINER) && verifier.EndTable(); + VerifyField(verifier, VT_COMBINER) && + verifier.EndTable(); } - EmbeddingLookupSparseOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + EmbeddingLookupSparseOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EmbeddingLookupSparseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct EmbeddingLookupSparseOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_combiner(CombinerType combiner) { - fbb_.AddElement(EmbeddingLookupSparseOptions::VT_COMBINER, - static_cast(combiner), 0); + fbb_.AddElement(EmbeddingLookupSparseOptions::VT_COMBINER, static_cast(combiner), 0); } - explicit EmbeddingLookupSparseOptionsBuilder( - flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit EmbeddingLookupSparseOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - EmbeddingLookupSparseOptionsBuilder &operator=( - const EmbeddingLookupSparseOptionsBuilder &); + EmbeddingLookupSparseOptionsBuilder &operator=(const EmbeddingLookupSparseOptionsBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -3255,42 +3152,40 @@ struct EmbeddingLookupSparseOptionsBuilder { } }; -inline flatbuffers::Offset -CreateEmbeddingLookupSparseOptions(flatbuffers::FlatBufferBuilder &_fbb, - CombinerType combiner = CombinerType_SUM) { +inline flatbuffers::Offset CreateEmbeddingLookupSparseOptions( + flatbuffers::FlatBufferBuilder &_fbb, + CombinerType combiner = CombinerType_SUM) { EmbeddingLookupSparseOptionsBuilder builder_(_fbb); builder_.add_combiner(combiner); return builder_.Finish(); } -flatbuffers::Offset -CreateEmbeddingLookupSparseOptions( - flatbuffers::FlatBufferBuilder &_fbb, - const EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateEmbeddingLookupSparseOptions(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct GatherOptionsT : public flatbuffers::NativeTable { typedef GatherOptions TableType; int32_t axis; - GatherOptionsT() : axis(0) {} + GatherOptionsT() + : axis(0) { + } }; struct GatherOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef GatherOptionsT NativeTableType; - enum { VT_AXIS = 4 }; - int32_t axis() const { return GetField(VT_AXIS, 0); } + enum { + VT_AXIS = 4 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_AXIS) && verifier.EndTable(); + VerifyField(verifier, VT_AXIS) && + verifier.EndTable(); } - GatherOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - GatherOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + GatherOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GatherOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct GatherOptionsBuilder { @@ -3300,7 +3195,7 @@ struct GatherOptionsBuilder { fbb_.AddElement(GatherOptions::VT_AXIS, axis, 0); } explicit GatherOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } GatherOptionsBuilder &operator=(const GatherOptionsBuilder &); @@ -3312,50 +3207,37 @@ struct GatherOptionsBuilder { }; inline flatbuffers::Offset CreateGatherOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0) { + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0) { GatherOptionsBuilder builder_(_fbb); builder_.add_axis(axis); return builder_.Finish(); } -flatbuffers::Offset CreateGatherOptions( - flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateGatherOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct TransposeOptionsT : public flatbuffers::NativeTable { typedef TransposeOptions TableType; - std::vector perm; - TransposeOptionsT() {} + TransposeOptionsT() { + } }; struct TransposeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef TransposeOptionsT NativeTableType; - enum { VT_PERM = 4 }; - const flatbuffers::Vector *perm() const { - return GetPointer *>(VT_PERM); - } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PERM) && - verifier.Verify(perm()) && verifier.EndTable(); + return VerifyTableStart(verifier) && + verifier.EndTable(); } - TransposeOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - TransposeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + TransposeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TransposeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct TransposeOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_perm(flatbuffers::Offset> perm) { - fbb_.AddOffset(TransposeOptions::VT_PERM, perm); - } explicit TransposeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } TransposeOptionsBuilder &operator=(const TransposeOptionsBuilder &); @@ -3367,65 +3249,87 @@ struct TransposeOptionsBuilder { }; inline flatbuffers::Offset CreateTransposeOptions( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset> perm = 0) { + flatbuffers::FlatBufferBuilder &_fbb) { TransposeOptionsBuilder builder_(_fbb); - builder_.add_perm(perm); return builder_.Finish(); } -inline flatbuffers::Offset CreateTransposeOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *perm = nullptr) { - return tflite::CreateTransposeOptions( - _fbb, perm ? _fbb.CreateVector(*perm) : 0); +flatbuffers::Offset CreateTransposeOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ExpOptionsT : public flatbuffers::NativeTable { + typedef ExpOptions TableType; + ExpOptionsT() { + } +}; + +struct ExpOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ExpOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ExpOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExpOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExpOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit ExpOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ExpOptionsBuilder &operator=(const ExpOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateExpOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + ExpOptionsBuilder builder_(_fbb); + return builder_.Finish(); } -flatbuffers::Offset CreateTransposeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct MeanOptionsT : public flatbuffers::NativeTable { typedef MeanOptions TableType; - std::vector axis; bool keep_dims; - MeanOptionsT() : keep_dims(false) {} + MeanOptionsT() + : keep_dims(false) { + } }; struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef MeanOptionsT NativeTableType; - enum { VT_AXIS = 4, VT_KEEP_DIMS = 6 }; - const flatbuffers::Vector *axis() const { - return GetPointer *>(VT_AXIS); + enum { + VT_KEEP_DIMS = 4 + }; + bool keep_dims() const { + return GetField(VT_KEEP_DIMS, 0) != 0; } - bool keep_dims() const { return GetField(VT_KEEP_DIMS, 0) != 0; } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_AXIS) && - verifier.Verify(axis()) && - VerifyField(verifier, VT_KEEP_DIMS) && verifier.EndTable(); + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_KEEP_DIMS) && + verifier.EndTable(); } - MeanOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - MeanOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + MeanOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct MeanOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_axis(flatbuffers::Offset> axis) { - fbb_.AddOffset(MeanOptions::VT_AXIS, axis); - } void add_keep_dims(bool keep_dims) { - fbb_.AddElement(MeanOptions::VT_KEEP_DIMS, - static_cast(keep_dims), 0); + fbb_.AddElement(MeanOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); } explicit MeanOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } MeanOptionsBuilder &operator=(const MeanOptionsBuilder &); @@ -3438,61 +3342,48 @@ struct MeanOptionsBuilder { inline flatbuffers::Offset CreateMeanOptions( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset> axis = 0, bool keep_dims = false) { MeanOptionsBuilder builder_(_fbb); - builder_.add_axis(axis); builder_.add_keep_dims(keep_dims); return builder_.Finish(); } -inline flatbuffers::Offset CreateMeanOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *axis = nullptr, bool keep_dims = false) { - return tflite::CreateMeanOptions( - _fbb, axis ? _fbb.CreateVector(*axis) : 0, keep_dims); -} - -flatbuffers::Offset CreateMeanOptions( - flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SqueezeOptionsT : public flatbuffers::NativeTable { typedef SqueezeOptions TableType; std::vector squeeze_dims; - SqueezeOptionsT() {} + SqueezeOptionsT() { + } }; struct SqueezeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SqueezeOptionsT NativeTableType; - enum { VT_SQUEEZE_DIMS = 4 }; + enum { + VT_SQUEEZE_DIMS = 4 + }; const flatbuffers::Vector *squeeze_dims() const { return GetPointer *>(VT_SQUEEZE_DIMS); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SQUEEZE_DIMS) && - verifier.Verify(squeeze_dims()) && verifier.EndTable(); + verifier.Verify(squeeze_dims()) && + verifier.EndTable(); } - SqueezeOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SqueezeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SqueezeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SqueezeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SqueezeOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_squeeze_dims( - flatbuffers::Offset> squeeze_dims) { + void add_squeeze_dims(flatbuffers::Offset> squeeze_dims) { fbb_.AddOffset(SqueezeOptions::VT_SQUEEZE_DIMS, squeeze_dims); } explicit SqueezeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SqueezeOptionsBuilder &operator=(const SqueezeOptionsBuilder &); @@ -3515,12 +3406,65 @@ inline flatbuffers::Offset CreateSqueezeOptionsDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *squeeze_dims = nullptr) { return tflite::CreateSqueezeOptions( - _fbb, squeeze_dims ? _fbb.CreateVector(*squeeze_dims) : 0); + _fbb, + squeeze_dims ? _fbb.CreateVector(*squeeze_dims) : 0); +} + +flatbuffers::Offset CreateSqueezeOptions(flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SplitOptionsT : public flatbuffers::NativeTable { + typedef SplitOptions TableType; + int32_t num_splits; + SplitOptionsT() + : num_splits(0) { + } +}; + +struct SplitOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SplitOptionsT NativeTableType; + enum { + VT_NUM_SPLITS = 4 + }; + int32_t num_splits() const { + return GetField(VT_NUM_SPLITS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NUM_SPLITS) && + verifier.EndTable(); + } + SplitOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SplitOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SplitOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_num_splits(int32_t num_splits) { + fbb_.AddElement(SplitOptions::VT_NUM_SPLITS, num_splits, 0); + } + explicit SplitOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SplitOptionsBuilder &operator=(const SplitOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSplitOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_splits = 0) { + SplitOptionsBuilder builder_(_fbb); + builder_.add_num_splits(num_splits); + return builder_.Finish(); } -flatbuffers::Offset CreateSqueezeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSplitOptions(flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct StridedSliceOptionsT : public flatbuffers::NativeTable { typedef StridedSliceOptions TableType; @@ -3534,11 +3478,11 @@ struct StridedSliceOptionsT : public flatbuffers::NativeTable { end_mask(0), ellipsis_mask(0), new_axis_mask(0), - shrink_axis_mask(0) {} + shrink_axis_mask(0) { + } }; -struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef StridedSliceOptionsT NativeTableType; enum { VT_BEGIN_MASK = 4, @@ -3547,8 +3491,12 @@ struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS VT_NEW_AXIS_MASK = 10, VT_SHRINK_AXIS_MASK = 12 }; - int32_t begin_mask() const { return GetField(VT_BEGIN_MASK, 0); } - int32_t end_mask() const { return GetField(VT_END_MASK, 0); } + int32_t begin_mask() const { + return GetField(VT_BEGIN_MASK, 0); + } + int32_t end_mask() const { + return GetField(VT_END_MASK, 0); + } int32_t ellipsis_mask() const { return GetField(VT_ELLIPSIS_MASK, 0); } @@ -3567,14 +3515,9 @@ struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS VerifyField(verifier, VT_SHRINK_AXIS_MASK) && verifier.EndTable(); } - StridedSliceOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - StridedSliceOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + StridedSliceOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StridedSliceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct StridedSliceOptionsBuilder { @@ -3587,19 +3530,16 @@ struct StridedSliceOptionsBuilder { fbb_.AddElement(StridedSliceOptions::VT_END_MASK, end_mask, 0); } void add_ellipsis_mask(int32_t ellipsis_mask) { - fbb_.AddElement(StridedSliceOptions::VT_ELLIPSIS_MASK, - ellipsis_mask, 0); + fbb_.AddElement(StridedSliceOptions::VT_ELLIPSIS_MASK, ellipsis_mask, 0); } void add_new_axis_mask(int32_t new_axis_mask) { - fbb_.AddElement(StridedSliceOptions::VT_NEW_AXIS_MASK, - new_axis_mask, 0); + fbb_.AddElement(StridedSliceOptions::VT_NEW_AXIS_MASK, new_axis_mask, 0); } void add_shrink_axis_mask(int32_t shrink_axis_mask) { - fbb_.AddElement(StridedSliceOptions::VT_SHRINK_AXIS_MASK, - shrink_axis_mask, 0); + fbb_.AddElement(StridedSliceOptions::VT_SHRINK_AXIS_MASK, shrink_axis_mask, 0); } explicit StridedSliceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } StridedSliceOptionsBuilder &operator=(const StridedSliceOptionsBuilder &); @@ -3611,8 +3551,11 @@ struct StridedSliceOptionsBuilder { }; inline flatbuffers::Offset CreateStridedSliceOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t begin_mask = 0, - int32_t end_mask = 0, int32_t ellipsis_mask = 0, int32_t new_axis_mask = 0, + flatbuffers::FlatBufferBuilder &_fbb, + int32_t begin_mask = 0, + int32_t end_mask = 0, + int32_t ellipsis_mask = 0, + int32_t new_axis_mask = 0, int32_t shrink_axis_mask = 0) { StridedSliceOptionsBuilder builder_(_fbb); builder_.add_shrink_axis_mask(shrink_axis_mask); @@ -3623,20 +3566,23 @@ inline flatbuffers::Offset CreateStridedSliceOptions( return builder_.Finish(); } -flatbuffers::Offset CreateStridedSliceOptions( - flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateStridedSliceOptions(flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; std::string custom_code; - OperatorCodeT() : builtin_code(BuiltinOperator_ADD) {} + OperatorCodeT() + : builtin_code(BuiltinOperator_ADD) { + } }; struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef OperatorCodeT NativeTableType; - enum { VT_BUILTIN_CODE = 4, VT_CUSTOM_CODE = 6 }; + enum { + VT_BUILTIN_CODE = 4, + VT_CUSTOM_CODE = 6 + }; BuiltinOperator builtin_code() const { return static_cast(GetField(VT_BUILTIN_CODE, 0)); } @@ -3647,30 +3593,25 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyField(verifier, VT_BUILTIN_CODE) && VerifyOffset(verifier, VT_CUSTOM_CODE) && - verifier.Verify(custom_code()) && verifier.EndTable(); + verifier.Verify(custom_code()) && + verifier.EndTable(); } - OperatorCodeT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - OperatorCodeT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + OperatorCodeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OperatorCodeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct OperatorCodeBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_builtin_code(BuiltinOperator builtin_code) { - fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, - static_cast(builtin_code), 0); + fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, static_cast(builtin_code), 0); } void add_custom_code(flatbuffers::Offset custom_code) { fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code); } explicit OperatorCodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } OperatorCodeBuilder &operator=(const OperatorCodeBuilder &); @@ -3696,12 +3637,12 @@ inline flatbuffers::Offset CreateOperatorCodeDirect( BuiltinOperator builtin_code = BuiltinOperator_ADD, const char *custom_code = nullptr) { return tflite::CreateOperatorCode( - _fbb, builtin_code, custom_code ? _fbb.CreateString(custom_code) : 0); + _fbb, + builtin_code, + custom_code ? _fbb.CreateString(custom_code) : 0); } -flatbuffers::Offset CreateOperatorCode( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct OperatorT : public flatbuffers::NativeTable { typedef Operator TableType; @@ -3713,7 +3654,8 @@ struct OperatorT : public flatbuffers::NativeTable { CustomOptionsFormat custom_options_format; OperatorT() : opcode_index(0), - custom_options_format(CustomOptionsFormat_FLEXBUFFERS) {} + custom_options_format(CustomOptionsFormat_FLEXBUFFERS) { + } }; struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -3737,398 +3679,283 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer *>(VT_OUTPUTS); } BuiltinOptions builtin_options_type() const { - return static_cast( - GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); + return static_cast(GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); } const void *builtin_options() const { return GetPointer(VT_BUILTIN_OPTIONS); } - template - const T *builtin_options_as() const; + template const T *builtin_options_as() const; const Conv2DOptions *builtin_options_as_Conv2DOptions() const { - return builtin_options_type() == BuiltinOptions_Conv2DOptions - ? static_cast(builtin_options()) - : nullptr; - } - const DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() - const { - return builtin_options_type() == BuiltinOptions_DepthwiseConv2DOptions - ? static_cast(builtin_options()) - : nullptr; - } - const ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() - const { - return builtin_options_type() == BuiltinOptions_ConcatEmbeddingsOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_Conv2DOptions ? static_cast(builtin_options()) : nullptr; + } + const DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() const { + return builtin_options_type() == BuiltinOptions_DepthwiseConv2DOptions ? static_cast(builtin_options()) : nullptr; + } + const ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() const { + return builtin_options_type() == BuiltinOptions_ConcatEmbeddingsOptions ? static_cast(builtin_options()) : nullptr; } const LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const { - return builtin_options_type() == BuiltinOptions_LSHProjectionOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_LSHProjectionOptions ? static_cast(builtin_options()) : nullptr; } const Pool2DOptions *builtin_options_as_Pool2DOptions() const { - return builtin_options_type() == BuiltinOptions_Pool2DOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_Pool2DOptions ? static_cast(builtin_options()) : nullptr; } const SVDFOptions *builtin_options_as_SVDFOptions() const { - return builtin_options_type() == BuiltinOptions_SVDFOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SVDFOptions ? static_cast(builtin_options()) : nullptr; } const RNNOptions *builtin_options_as_RNNOptions() const { - return builtin_options_type() == BuiltinOptions_RNNOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_RNNOptions ? static_cast(builtin_options()) : nullptr; } - const FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() - const { - return builtin_options_type() == BuiltinOptions_FullyConnectedOptions - ? static_cast(builtin_options()) - : nullptr; + const FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() const { + return builtin_options_type() == BuiltinOptions_FullyConnectedOptions ? static_cast(builtin_options()) : nullptr; } const SoftmaxOptions *builtin_options_as_SoftmaxOptions() const { - return builtin_options_type() == BuiltinOptions_SoftmaxOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SoftmaxOptions ? static_cast(builtin_options()) : nullptr; } const ConcatenationOptions *builtin_options_as_ConcatenationOptions() const { - return builtin_options_type() == BuiltinOptions_ConcatenationOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_ConcatenationOptions ? static_cast(builtin_options()) : nullptr; } const AddOptions *builtin_options_as_AddOptions() const { - return builtin_options_type() == BuiltinOptions_AddOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_AddOptions ? static_cast(builtin_options()) : nullptr; } const L2NormOptions *builtin_options_as_L2NormOptions() const { - return builtin_options_type() == BuiltinOptions_L2NormOptions - ? static_cast(builtin_options()) - : nullptr; - } - const LocalResponseNormalizationOptions * - builtin_options_as_LocalResponseNormalizationOptions() const { - return builtin_options_type() == - BuiltinOptions_LocalResponseNormalizationOptions - ? static_cast( - builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_L2NormOptions ? static_cast(builtin_options()) : nullptr; + } + const LocalResponseNormalizationOptions *builtin_options_as_LocalResponseNormalizationOptions() const { + return builtin_options_type() == BuiltinOptions_LocalResponseNormalizationOptions ? static_cast(builtin_options()) : nullptr; } const LSTMOptions *builtin_options_as_LSTMOptions() const { - return builtin_options_type() == BuiltinOptions_LSTMOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_LSTMOptions ? static_cast(builtin_options()) : nullptr; } - const ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() - const { - return builtin_options_type() == BuiltinOptions_ResizeBilinearOptions - ? static_cast(builtin_options()) - : nullptr; + const ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() const { + return builtin_options_type() == BuiltinOptions_ResizeBilinearOptions ? static_cast(builtin_options()) : nullptr; } const CallOptions *builtin_options_as_CallOptions() const { - return builtin_options_type() == BuiltinOptions_CallOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_CallOptions ? static_cast(builtin_options()) : nullptr; } const ReshapeOptions *builtin_options_as_ReshapeOptions() const { - return builtin_options_type() == BuiltinOptions_ReshapeOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_ReshapeOptions ? static_cast(builtin_options()) : nullptr; } const SkipGramOptions *builtin_options_as_SkipGramOptions() const { - return builtin_options_type() == BuiltinOptions_SkipGramOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SkipGramOptions ? static_cast(builtin_options()) : nullptr; } const SpaceToDepthOptions *builtin_options_as_SpaceToDepthOptions() const { - return builtin_options_type() == BuiltinOptions_SpaceToDepthOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SpaceToDepthOptions ? static_cast(builtin_options()) : nullptr; } - const EmbeddingLookupSparseOptions * - builtin_options_as_EmbeddingLookupSparseOptions() const { - return builtin_options_type() == BuiltinOptions_EmbeddingLookupSparseOptions - ? static_cast( - builtin_options()) - : nullptr; + const EmbeddingLookupSparseOptions *builtin_options_as_EmbeddingLookupSparseOptions() const { + return builtin_options_type() == BuiltinOptions_EmbeddingLookupSparseOptions ? static_cast(builtin_options()) : nullptr; } const MulOptions *builtin_options_as_MulOptions() const { - return builtin_options_type() == BuiltinOptions_MulOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_MulOptions ? static_cast(builtin_options()) : nullptr; } const PadOptions *builtin_options_as_PadOptions() const { - return builtin_options_type() == BuiltinOptions_PadOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_PadOptions ? static_cast(builtin_options()) : nullptr; } const GatherOptions *builtin_options_as_GatherOptions() const { - return builtin_options_type() == BuiltinOptions_GatherOptions - ? static_cast(builtin_options()) - : nullptr; - } - const BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions() - const { - return builtin_options_type() == BuiltinOptions_BatchToSpaceNDOptions - ? static_cast(builtin_options()) - : nullptr; - } - const SpaceToBatchNDOptions *builtin_options_as_SpaceToBatchNDOptions() - const { - return builtin_options_type() == BuiltinOptions_SpaceToBatchNDOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_GatherOptions ? static_cast(builtin_options()) : nullptr; + } + const BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions() const { + return builtin_options_type() == BuiltinOptions_BatchToSpaceNDOptions ? static_cast(builtin_options()) : nullptr; + } + const SpaceToBatchNDOptions *builtin_options_as_SpaceToBatchNDOptions() const { + return builtin_options_type() == BuiltinOptions_SpaceToBatchNDOptions ? static_cast(builtin_options()) : nullptr; } const TransposeOptions *builtin_options_as_TransposeOptions() const { - return builtin_options_type() == BuiltinOptions_TransposeOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_TransposeOptions ? static_cast(builtin_options()) : nullptr; } const MeanOptions *builtin_options_as_MeanOptions() const { - return builtin_options_type() == BuiltinOptions_MeanOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_MeanOptions ? static_cast(builtin_options()) : nullptr; } const SubOptions *builtin_options_as_SubOptions() const { - return builtin_options_type() == BuiltinOptions_SubOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SubOptions ? static_cast(builtin_options()) : nullptr; } const DivOptions *builtin_options_as_DivOptions() const { - return builtin_options_type() == BuiltinOptions_DivOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_DivOptions ? static_cast(builtin_options()) : nullptr; } const SqueezeOptions *builtin_options_as_SqueezeOptions() const { - return builtin_options_type() == BuiltinOptions_SqueezeOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SqueezeOptions ? static_cast(builtin_options()) : nullptr; } const SequenceRNNOptions *builtin_options_as_SequenceRNNOptions() const { - return builtin_options_type() == BuiltinOptions_SequenceRNNOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SequenceRNNOptions ? static_cast(builtin_options()) : nullptr; } const StridedSliceOptions *builtin_options_as_StridedSliceOptions() const { - return builtin_options_type() == BuiltinOptions_StridedSliceOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_StridedSliceOptions ? static_cast(builtin_options()) : nullptr; + } + const ExpOptions *builtin_options_as_ExpOptions() const { + return builtin_options_type() == BuiltinOptions_ExpOptions ? static_cast(builtin_options()) : nullptr; + } + const TopKV2Options *builtin_options_as_TopKV2Options() const { + return builtin_options_type() == BuiltinOptions_TopKV2Options ? static_cast(builtin_options()) : nullptr; + } + const SplitOptions *builtin_options_as_SplitOptions() const { + return builtin_options_type() == BuiltinOptions_SplitOptions ? static_cast(builtin_options()) : nullptr; } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } CustomOptionsFormat custom_options_format() const { - return static_cast( - GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); + return static_cast(GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_OPCODE_INDEX) && - VerifyOffset(verifier, VT_INPUTS) && verifier.Verify(inputs()) && - VerifyOffset(verifier, VT_OUTPUTS) && verifier.Verify(outputs()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.Verify(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.Verify(outputs()) && VerifyField(verifier, VT_BUILTIN_OPTIONS_TYPE) && VerifyOffset(verifier, VT_BUILTIN_OPTIONS) && - VerifyBuiltinOptions(verifier, builtin_options(), - builtin_options_type()) && + VerifyBuiltinOptions(verifier, builtin_options(), builtin_options_type()) && VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && verifier.Verify(custom_options()) && VerifyField(verifier, VT_CUSTOM_OPTIONS_FORMAT) && verifier.EndTable(); } - OperatorT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - OperatorT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OperatorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -template <> -inline const Conv2DOptions *Operator::builtin_options_as() - const { +template<> inline const Conv2DOptions *Operator::builtin_options_as() const { return builtin_options_as_Conv2DOptions(); } -template <> -inline const DepthwiseConv2DOptions * -Operator::builtin_options_as() const { +template<> inline const DepthwiseConv2DOptions *Operator::builtin_options_as() const { return builtin_options_as_DepthwiseConv2DOptions(); } -template <> -inline const ConcatEmbeddingsOptions * -Operator::builtin_options_as() const { +template<> inline const ConcatEmbeddingsOptions *Operator::builtin_options_as() const { return builtin_options_as_ConcatEmbeddingsOptions(); } -template <> -inline const LSHProjectionOptions * -Operator::builtin_options_as() const { +template<> inline const LSHProjectionOptions *Operator::builtin_options_as() const { return builtin_options_as_LSHProjectionOptions(); } -template <> -inline const Pool2DOptions *Operator::builtin_options_as() - const { +template<> inline const Pool2DOptions *Operator::builtin_options_as() const { return builtin_options_as_Pool2DOptions(); } -template <> -inline const SVDFOptions *Operator::builtin_options_as() const { +template<> inline const SVDFOptions *Operator::builtin_options_as() const { return builtin_options_as_SVDFOptions(); } -template <> -inline const RNNOptions *Operator::builtin_options_as() const { +template<> inline const RNNOptions *Operator::builtin_options_as() const { return builtin_options_as_RNNOptions(); } -template <> -inline const FullyConnectedOptions * -Operator::builtin_options_as() const { +template<> inline const FullyConnectedOptions *Operator::builtin_options_as() const { return builtin_options_as_FullyConnectedOptions(); } -template <> -inline const SoftmaxOptions *Operator::builtin_options_as() - const { +template<> inline const SoftmaxOptions *Operator::builtin_options_as() const { return builtin_options_as_SoftmaxOptions(); } -template <> -inline const ConcatenationOptions * -Operator::builtin_options_as() const { +template<> inline const ConcatenationOptions *Operator::builtin_options_as() const { return builtin_options_as_ConcatenationOptions(); } -template <> -inline const AddOptions *Operator::builtin_options_as() const { +template<> inline const AddOptions *Operator::builtin_options_as() const { return builtin_options_as_AddOptions(); } -template <> -inline const L2NormOptions *Operator::builtin_options_as() - const { +template<> inline const L2NormOptions *Operator::builtin_options_as() const { return builtin_options_as_L2NormOptions(); } -template <> -inline const LocalResponseNormalizationOptions * -Operator::builtin_options_as() const { +template<> inline const LocalResponseNormalizationOptions *Operator::builtin_options_as() const { return builtin_options_as_LocalResponseNormalizationOptions(); } -template <> -inline const LSTMOptions *Operator::builtin_options_as() const { +template<> inline const LSTMOptions *Operator::builtin_options_as() const { return builtin_options_as_LSTMOptions(); } -template <> -inline const ResizeBilinearOptions * -Operator::builtin_options_as() const { +template<> inline const ResizeBilinearOptions *Operator::builtin_options_as() const { return builtin_options_as_ResizeBilinearOptions(); } -template <> -inline const CallOptions *Operator::builtin_options_as() const { +template<> inline const CallOptions *Operator::builtin_options_as() const { return builtin_options_as_CallOptions(); } -template <> -inline const ReshapeOptions *Operator::builtin_options_as() - const { +template<> inline const ReshapeOptions *Operator::builtin_options_as() const { return builtin_options_as_ReshapeOptions(); } -template <> -inline const SkipGramOptions *Operator::builtin_options_as() - const { +template<> inline const SkipGramOptions *Operator::builtin_options_as() const { return builtin_options_as_SkipGramOptions(); } -template <> -inline const SpaceToDepthOptions * -Operator::builtin_options_as() const { +template<> inline const SpaceToDepthOptions *Operator::builtin_options_as() const { return builtin_options_as_SpaceToDepthOptions(); } -template <> -inline const EmbeddingLookupSparseOptions * -Operator::builtin_options_as() const { +template<> inline const EmbeddingLookupSparseOptions *Operator::builtin_options_as() const { return builtin_options_as_EmbeddingLookupSparseOptions(); } -template <> -inline const MulOptions *Operator::builtin_options_as() const { +template<> inline const MulOptions *Operator::builtin_options_as() const { return builtin_options_as_MulOptions(); } -template <> -inline const PadOptions *Operator::builtin_options_as() const { +template<> inline const PadOptions *Operator::builtin_options_as() const { return builtin_options_as_PadOptions(); } -template <> -inline const GatherOptions *Operator::builtin_options_as() - const { +template<> inline const GatherOptions *Operator::builtin_options_as() const { return builtin_options_as_GatherOptions(); } -template <> -inline const BatchToSpaceNDOptions * -Operator::builtin_options_as() const { +template<> inline const BatchToSpaceNDOptions *Operator::builtin_options_as() const { return builtin_options_as_BatchToSpaceNDOptions(); } -template <> -inline const SpaceToBatchNDOptions * -Operator::builtin_options_as() const { +template<> inline const SpaceToBatchNDOptions *Operator::builtin_options_as() const { return builtin_options_as_SpaceToBatchNDOptions(); } -template <> -inline const TransposeOptions *Operator::builtin_options_as() - const { +template<> inline const TransposeOptions *Operator::builtin_options_as() const { return builtin_options_as_TransposeOptions(); } -template <> -inline const MeanOptions *Operator::builtin_options_as() const { +template<> inline const MeanOptions *Operator::builtin_options_as() const { return builtin_options_as_MeanOptions(); } -template <> -inline const SubOptions *Operator::builtin_options_as() const { +template<> inline const SubOptions *Operator::builtin_options_as() const { return builtin_options_as_SubOptions(); } -template <> -inline const DivOptions *Operator::builtin_options_as() const { +template<> inline const DivOptions *Operator::builtin_options_as() const { return builtin_options_as_DivOptions(); } -template <> -inline const SqueezeOptions *Operator::builtin_options_as() - const { +template<> inline const SqueezeOptions *Operator::builtin_options_as() const { return builtin_options_as_SqueezeOptions(); } -template <> -inline const SequenceRNNOptions * -Operator::builtin_options_as() const { +template<> inline const SequenceRNNOptions *Operator::builtin_options_as() const { return builtin_options_as_SequenceRNNOptions(); } -template <> -inline const StridedSliceOptions * -Operator::builtin_options_as() const { +template<> inline const StridedSliceOptions *Operator::builtin_options_as() const { return builtin_options_as_StridedSliceOptions(); } +template<> inline const ExpOptions *Operator::builtin_options_as() const { + return builtin_options_as_ExpOptions(); +} + +template<> inline const TopKV2Options *Operator::builtin_options_as() const { + return builtin_options_as_TopKV2Options(); +} + +template<> inline const SplitOptions *Operator::builtin_options_as() const { + return builtin_options_as_SplitOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -4142,21 +3969,19 @@ struct OperatorBuilder { fbb_.AddOffset(Operator::VT_OUTPUTS, outputs); } void add_builtin_options_type(BuiltinOptions builtin_options_type) { - fbb_.AddElement(Operator::VT_BUILTIN_OPTIONS_TYPE, - static_cast(builtin_options_type), 0); + fbb_.AddElement(Operator::VT_BUILTIN_OPTIONS_TYPE, static_cast(builtin_options_type), 0); } void add_builtin_options(flatbuffers::Offset builtin_options) { fbb_.AddOffset(Operator::VT_BUILTIN_OPTIONS, builtin_options); } - void add_custom_options( - flatbuffers::Offset> custom_options) { + void add_custom_options(flatbuffers::Offset> custom_options) { fbb_.AddOffset(Operator::VT_CUSTOM_OPTIONS, custom_options); } void add_custom_options_format(CustomOptionsFormat custom_options_format) { - fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, - static_cast(custom_options_format), 0); + fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast(custom_options_format), 0); } - explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } OperatorBuilder &operator=(const OperatorBuilder &); @@ -4168,14 +3993,14 @@ struct OperatorBuilder { }; inline flatbuffers::Offset CreateOperator( - flatbuffers::FlatBufferBuilder &_fbb, uint32_t opcode_index = 0, + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t opcode_index = 0, flatbuffers::Offset> inputs = 0, flatbuffers::Offset> outputs = 0, BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, flatbuffers::Offset> custom_options = 0, - CustomOptionsFormat custom_options_format = - CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { OperatorBuilder builder_(_fbb); builder_.add_custom_options(custom_options); builder_.add_builtin_options(builtin_options); @@ -4188,25 +4013,26 @@ inline flatbuffers::Offset CreateOperator( } inline flatbuffers::Offset CreateOperatorDirect( - flatbuffers::FlatBufferBuilder &_fbb, uint32_t opcode_index = 0, + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t opcode_index = 0, const std::vector *inputs = nullptr, const std::vector *outputs = nullptr, BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, const std::vector *custom_options = nullptr, - CustomOptionsFormat custom_options_format = - CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { return tflite::CreateOperator( - _fbb, opcode_index, inputs ? _fbb.CreateVector(*inputs) : 0, - outputs ? _fbb.CreateVector(*outputs) : 0, builtin_options_type, + _fbb, + opcode_index, + inputs ? _fbb.CreateVector(*inputs) : 0, + outputs ? _fbb.CreateVector(*outputs) : 0, + builtin_options_type, builtin_options, custom_options ? _fbb.CreateVector(*custom_options) : 0, custom_options_format); } -flatbuffers::Offset CreateOperator( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SubGraphT : public flatbuffers::NativeTable { typedef SubGraph TableType; @@ -4215,7 +4041,8 @@ struct SubGraphT : public flatbuffers::NativeTable { std::vector outputs; std::vector> operators; std::string name; - SubGraphT() {} + SubGraphT() { + } }; struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -4228,8 +4055,7 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_NAME = 12 }; const flatbuffers::Vector> *tensors() const { - return GetPointer> *>( - VT_TENSORS); + return GetPointer> *>(VT_TENSORS); } const flatbuffers::Vector *inputs() const { return GetPointer *>(VT_INPUTS); @@ -4238,41 +4064,36 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer *>(VT_OUTPUTS); } const flatbuffers::Vector> *operators() const { - return GetPointer< - const flatbuffers::Vector> *>( - VT_OPERATORS); + return GetPointer> *>(VT_OPERATORS); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_TENSORS) && + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TENSORS) && verifier.Verify(tensors()) && verifier.VerifyVectorOfTables(tensors()) && - VerifyOffset(verifier, VT_INPUTS) && verifier.Verify(inputs()) && - VerifyOffset(verifier, VT_OUTPUTS) && verifier.Verify(outputs()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.Verify(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.Verify(outputs()) && VerifyOffset(verifier, VT_OPERATORS) && verifier.Verify(operators()) && verifier.VerifyVectorOfTables(operators()) && - VerifyOffset(verifier, VT_NAME) && verifier.Verify(name()) && + VerifyOffset(verifier, VT_NAME) && + verifier.Verify(name()) && verifier.EndTable(); } - SubGraphT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SubGraphT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SubGraphT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SubGraphT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SubGraphBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_tensors( - flatbuffers::Offset>> - tensors) { + void add_tensors(flatbuffers::Offset>> tensors) { fbb_.AddOffset(SubGraph::VT_TENSORS, tensors); } void add_inputs(flatbuffers::Offset> inputs) { @@ -4281,15 +4102,14 @@ struct SubGraphBuilder { void add_outputs(flatbuffers::Offset> outputs) { fbb_.AddOffset(SubGraph::VT_OUTPUTS, outputs); } - void add_operators( - flatbuffers::Offset>> - operators) { + void add_operators(flatbuffers::Offset>> operators) { fbb_.AddOffset(SubGraph::VT_OPERATORS, operators); } void add_name(flatbuffers::Offset name) { fbb_.AddOffset(SubGraph::VT_NAME, name); } - explicit SubGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + explicit SubGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SubGraphBuilder &operator=(const SubGraphBuilder &); @@ -4302,12 +4122,10 @@ struct SubGraphBuilder { inline flatbuffers::Offset CreateSubGraph( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset>> - tensors = 0, + flatbuffers::Offset>> tensors = 0, flatbuffers::Offset> inputs = 0, flatbuffers::Offset> outputs = 0, - flatbuffers::Offset>> - operators = 0, + flatbuffers::Offset>> operators = 0, flatbuffers::Offset name = 0) { SubGraphBuilder builder_(_fbb); builder_.add_name(name); @@ -4330,38 +4148,36 @@ inline flatbuffers::Offset CreateSubGraphDirect( tensors ? _fbb.CreateVector>(*tensors) : 0, inputs ? _fbb.CreateVector(*inputs) : 0, outputs ? _fbb.CreateVector(*outputs) : 0, - operators ? _fbb.CreateVector>(*operators) - : 0, + operators ? _fbb.CreateVector>(*operators) : 0, name ? _fbb.CreateString(name) : 0); } -flatbuffers::Offset CreateSubGraph( - flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSubGraph(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct BufferT : public flatbuffers::NativeTable { typedef Buffer TableType; std::vector data; - BufferT() {} + BufferT() { + } }; struct Buffer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef BufferT NativeTableType; - enum { VT_DATA = 4 }; + enum { + VT_DATA = 4 + }; const flatbuffers::Vector *data() const { return GetPointer *>(VT_DATA); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DATA) && - verifier.Verify(data()) && verifier.EndTable(); + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.Verify(data()) && + verifier.EndTable(); } - BufferT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(BufferT *_o, const flatbuffers::resolver_function_t *_resolver = - nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + BufferT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BufferT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct BufferBuilder { @@ -4370,7 +4186,8 @@ struct BufferBuilder { void add_data(flatbuffers::Offset> data) { fbb_.AddOffset(Buffer::VT_DATA, data); } - explicit BufferBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + explicit BufferBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } BufferBuilder &operator=(const BufferBuilder &); @@ -4392,13 +4209,12 @@ inline flatbuffers::Offset CreateBuffer( inline flatbuffers::Offset CreateBufferDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *data = nullptr) { - return tflite::CreateBuffer(_fbb, - data ? _fbb.CreateVector(*data) : 0); + return tflite::CreateBuffer( + _fbb, + data ? _fbb.CreateVector(*data) : 0); } -flatbuffers::Offset CreateBuffer( - flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateBuffer(flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct ModelT : public flatbuffers::NativeTable { typedef Model TableType; @@ -4407,7 +4223,9 @@ struct ModelT : public flatbuffers::NativeTable { std::vector> subgraphs; std::string description; std::vector> buffers; - ModelT() : version(0) {} + ModelT() + : version(0) { + } }; struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -4419,24 +4237,20 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_DESCRIPTION = 10, VT_BUFFERS = 12 }; - uint32_t version() const { return GetField(VT_VERSION, 0); } - const flatbuffers::Vector> *operator_codes() - const { - return GetPointer< - const flatbuffers::Vector> *>( - VT_OPERATOR_CODES); + uint32_t version() const { + return GetField(VT_VERSION, 0); + } + const flatbuffers::Vector> *operator_codes() const { + return GetPointer> *>(VT_OPERATOR_CODES); } const flatbuffers::Vector> *subgraphs() const { - return GetPointer< - const flatbuffers::Vector> *>( - VT_SUBGRAPHS); + return GetPointer> *>(VT_SUBGRAPHS); } const flatbuffers::String *description() const { return GetPointer(VT_DESCRIPTION); } const flatbuffers::Vector> *buffers() const { - return GetPointer> *>( - VT_BUFFERS); + return GetPointer> *>(VT_BUFFERS); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -4449,16 +4263,14 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVectorOfTables(subgraphs()) && VerifyOffset(verifier, VT_DESCRIPTION) && verifier.Verify(description()) && - VerifyOffset(verifier, VT_BUFFERS) && verifier.Verify(buffers()) && - verifier.VerifyVectorOfTables(buffers()) && verifier.EndTable(); + VerifyOffset(verifier, VT_BUFFERS) && + verifier.Verify(buffers()) && + verifier.VerifyVectorOfTables(buffers()) && + verifier.EndTable(); } - ModelT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver = - nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ModelBuilder { @@ -4467,26 +4279,20 @@ struct ModelBuilder { void add_version(uint32_t version) { fbb_.AddElement(Model::VT_VERSION, version, 0); } - void add_operator_codes( - flatbuffers::Offset< - flatbuffers::Vector>> - operator_codes) { + void add_operator_codes(flatbuffers::Offset>> operator_codes) { fbb_.AddOffset(Model::VT_OPERATOR_CODES, operator_codes); } - void add_subgraphs( - flatbuffers::Offset>> - subgraphs) { + void add_subgraphs(flatbuffers::Offset>> subgraphs) { fbb_.AddOffset(Model::VT_SUBGRAPHS, subgraphs); } void add_description(flatbuffers::Offset description) { fbb_.AddOffset(Model::VT_DESCRIPTION, description); } - void add_buffers( - flatbuffers::Offset>> - buffers) { + void add_buffers(flatbuffers::Offset>> buffers) { fbb_.AddOffset(Model::VT_BUFFERS, buffers); } - explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ModelBuilder &operator=(const ModelBuilder &); @@ -4498,14 +4304,12 @@ struct ModelBuilder { }; inline flatbuffers::Offset CreateModel( - flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, - flatbuffers::Offset>> - operator_codes = 0, - flatbuffers::Offset>> - subgraphs = 0, + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t version = 0, + flatbuffers::Offset>> operator_codes = 0, + flatbuffers::Offset>> subgraphs = 0, flatbuffers::Offset description = 0, - flatbuffers::Offset>> - buffers = 0) { + flatbuffers::Offset>> buffers = 0) { ModelBuilder builder_(_fbb); builder_.add_buffers(buffers); builder_.add_description(description); @@ -4516,2058 +4320,1277 @@ inline flatbuffers::Offset CreateModel( } inline flatbuffers::Offset CreateModelDirect( - flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, - const std::vector> *operator_codes = - nullptr, + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t version = 0, + const std::vector> *operator_codes = nullptr, const std::vector> *subgraphs = nullptr, const char *description = nullptr, const std::vector> *buffers = nullptr) { return tflite::CreateModel( - _fbb, version, - operator_codes ? _fbb.CreateVector>( - *operator_codes) - : 0, - subgraphs ? _fbb.CreateVector>(*subgraphs) - : 0, + _fbb, + version, + operator_codes ? _fbb.CreateVector>(*operator_codes) : 0, + subgraphs ? _fbb.CreateVector>(*subgraphs) : 0, description ? _fbb.CreateString(description) : 0, buffers ? _fbb.CreateVector>(*buffers) : 0); } -flatbuffers::Offset CreateModel( - flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -inline QuantizationParametersT *QuantizationParameters::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline QuantizationParametersT *QuantizationParameters::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new QuantizationParametersT(); UnPackTo(_o, _resolver); return _o; } -inline void QuantizationParameters::UnPackTo( - QuantizationParametersT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void QuantizationParameters::UnPackTo(QuantizationParametersT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = min(); - if (_e) { - _o->min.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->min[_i] = _e->Get(_i); - } - } - }; - { - auto _e = max(); - if (_e) { - _o->max.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->max[_i] = _e->Get(_i); - } - } - }; - { - auto _e = scale(); - if (_e) { - _o->scale.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->scale[_i] = _e->Get(_i); - } - } - }; - { - auto _e = zero_point(); - if (_e) { - _o->zero_point.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->zero_point[_i] = _e->Get(_i); - } - } - }; + { auto _e = min(); if (_e) { _o->min.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->min[_i] = _e->Get(_i); } } }; + { auto _e = max(); if (_e) { _o->max.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->max[_i] = _e->Get(_i); } } }; + { auto _e = scale(); if (_e) { _o->scale.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scale[_i] = _e->Get(_i); } } }; + { auto _e = zero_point(); if (_e) { _o->zero_point.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zero_point[_i] = _e->Get(_i); } } }; } -inline flatbuffers::Offset QuantizationParameters::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset QuantizationParameters::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateQuantizationParameters(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateQuantizationParameters( - flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateQuantizationParameters(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const QuantizationParametersT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const QuantizationParametersT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _min = _o->min.size() ? _fbb.CreateVector(_o->min) : 0; auto _max = _o->max.size() ? _fbb.CreateVector(_o->max) : 0; auto _scale = _o->scale.size() ? _fbb.CreateVector(_o->scale) : 0; - auto _zero_point = - _o->zero_point.size() ? _fbb.CreateVector(_o->zero_point) : 0; - return tflite::CreateQuantizationParameters(_fbb, _min, _max, _scale, - _zero_point); + auto _zero_point = _o->zero_point.size() ? _fbb.CreateVector(_o->zero_point) : 0; + return tflite::CreateQuantizationParameters( + _fbb, + _min, + _max, + _scale, + _zero_point); } -inline TensorT *Tensor::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline TensorT *Tensor::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new TensorT(); UnPackTo(_o, _resolver); return _o; } -inline void Tensor::UnPackTo( - TensorT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = shape(); - if (_e) { - _o->shape.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->shape[_i] = _e->Get(_i); - } - } - }; - { - auto _e = type(); - _o->type = _e; - }; - { - auto _e = buffer(); - _o->buffer = _e; - }; - { - auto _e = name(); - if (_e) _o->name = _e->str(); - }; - { - auto _e = quantization(); - if (_e) - _o->quantization = - std::unique_ptr(_e->UnPack(_resolver)); - }; + { auto _e = shape(); if (_e) { _o->shape.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape[_i] = _e->Get(_i); } } }; + { auto _e = type(); _o->type = _e; }; + { auto _e = buffer(); _o->buffer = _e; }; + { auto _e = name(); if (_e) _o->name = _e->str(); }; + { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr(_e->UnPack(_resolver)); }; } -inline flatbuffers::Offset Tensor::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateTensor(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateTensor( - flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const TensorT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TensorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _shape = _o->shape.size() ? _fbb.CreateVector(_o->shape) : 0; auto _type = _o->type; auto _buffer = _o->buffer; auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); - auto _quantization = _o->quantization - ? CreateQuantizationParameters( - _fbb, _o->quantization.get(), _rehasher) - : 0; - return tflite::CreateTensor(_fbb, _shape, _type, _buffer, _name, - _quantization); + auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0; + return tflite::CreateTensor( + _fbb, + _shape, + _type, + _buffer, + _name, + _quantization); } -inline Conv2DOptionsT *Conv2DOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new Conv2DOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void Conv2DOptions::UnPackTo( - Conv2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void Conv2DOptions::UnPackTo(Conv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = padding(); - _o->padding = _e; - }; - { - auto _e = stride_w(); - _o->stride_w = _e; - }; - { - auto _e = stride_h(); - _o->stride_h = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = padding(); _o->padding = _e; }; + { auto _e = stride_w(); _o->stride_w = _e; }; + { auto _e = stride_h(); _o->stride_h = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset Conv2DOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Conv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateConv2DOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const Conv2DOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Conv2DOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _padding = _o->padding; auto _stride_w = _o->stride_w; auto _stride_h = _o->stride_h; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateConv2DOptions(_fbb, _padding, _stride_w, _stride_h, - _fused_activation_function); + return tflite::CreateConv2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _fused_activation_function); } -inline Pool2DOptionsT *Pool2DOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline Pool2DOptionsT *Pool2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new Pool2DOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void Pool2DOptions::UnPackTo( - Pool2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void Pool2DOptions::UnPackTo(Pool2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = padding(); - _o->padding = _e; - }; - { - auto _e = stride_w(); - _o->stride_w = _e; - }; - { - auto _e = stride_h(); - _o->stride_h = _e; - }; - { - auto _e = filter_width(); - _o->filter_width = _e; - }; - { - auto _e = filter_height(); - _o->filter_height = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = padding(); _o->padding = _e; }; + { auto _e = stride_w(); _o->stride_w = _e; }; + { auto _e = stride_h(); _o->stride_h = _e; }; + { auto _e = filter_width(); _o->filter_width = _e; }; + { auto _e = filter_height(); _o->filter_height = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset Pool2DOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Pool2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreatePool2DOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreatePool2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreatePool2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const Pool2DOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Pool2DOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _padding = _o->padding; auto _stride_w = _o->stride_w; auto _stride_h = _o->stride_h; auto _filter_width = _o->filter_width; auto _filter_height = _o->filter_height; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreatePool2DOptions(_fbb, _padding, _stride_w, _stride_h, - _filter_width, _filter_height, - _fused_activation_function); + return tflite::CreatePool2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _filter_width, + _filter_height, + _fused_activation_function); } -inline DepthwiseConv2DOptionsT *DepthwiseConv2DOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline DepthwiseConv2DOptionsT *DepthwiseConv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new DepthwiseConv2DOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void DepthwiseConv2DOptions::UnPackTo( - DepthwiseConv2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void DepthwiseConv2DOptions::UnPackTo(DepthwiseConv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = padding(); - _o->padding = _e; - }; - { - auto _e = stride_w(); - _o->stride_w = _e; - }; - { - auto _e = stride_h(); - _o->stride_h = _e; - }; - { - auto _e = depth_multiplier(); - _o->depth_multiplier = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = padding(); _o->padding = _e; }; + { auto _e = stride_w(); _o->stride_w = _e; }; + { auto _e = stride_h(); _o->stride_h = _e; }; + { auto _e = depth_multiplier(); _o->depth_multiplier = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset DepthwiseConv2DOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset DepthwiseConv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateDepthwiseConv2DOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateDepthwiseConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateDepthwiseConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const DepthwiseConv2DOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DepthwiseConv2DOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _padding = _o->padding; auto _stride_w = _o->stride_w; auto _stride_h = _o->stride_h; auto _depth_multiplier = _o->depth_multiplier; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateDepthwiseConv2DOptions(_fbb, _padding, _stride_w, - _stride_h, _depth_multiplier, - _fused_activation_function); + return tflite::CreateDepthwiseConv2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _depth_multiplier, + _fused_activation_function); } -inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ConcatEmbeddingsOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void ConcatEmbeddingsOptions::UnPackTo( - ConcatEmbeddingsOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void ConcatEmbeddingsOptions::UnPackTo(ConcatEmbeddingsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = num_channels(); - _o->num_channels = _e; - }; - { - auto _e = num_columns_per_channel(); - if (_e) { - _o->num_columns_per_channel.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->num_columns_per_channel[_i] = _e->Get(_i); - } - } - }; - { - auto _e = embedding_dim_per_channel(); - if (_e) { - _o->embedding_dim_per_channel.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->embedding_dim_per_channel[_i] = _e->Get(_i); - } - } - }; + { auto _e = num_channels(); _o->num_channels = _e; }; + { auto _e = num_columns_per_channel(); if (_e) { _o->num_columns_per_channel.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->num_columns_per_channel[_i] = _e->Get(_i); } } }; + { auto _e = embedding_dim_per_channel(); if (_e) { _o->embedding_dim_per_channel.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_dim_per_channel[_i] = _e->Get(_i); } } }; } -inline flatbuffers::Offset -ConcatEmbeddingsOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset ConcatEmbeddingsOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateConcatEmbeddingsOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset -CreateConcatEmbeddingsOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const ConcatEmbeddingsOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ConcatEmbeddingsOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _num_channels = _o->num_channels; - auto _num_columns_per_channel = - _o->num_columns_per_channel.size() - ? _fbb.CreateVector(_o->num_columns_per_channel) - : 0; - auto _embedding_dim_per_channel = - _o->embedding_dim_per_channel.size() - ? _fbb.CreateVector(_o->embedding_dim_per_channel) - : 0; - return tflite::CreateConcatEmbeddingsOptions(_fbb, _num_channels, - _num_columns_per_channel, - _embedding_dim_per_channel); -} - -inline LSHProjectionOptionsT *LSHProjectionOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { + auto _num_columns_per_channel = _o->num_columns_per_channel.size() ? _fbb.CreateVector(_o->num_columns_per_channel) : 0; + auto _embedding_dim_per_channel = _o->embedding_dim_per_channel.size() ? _fbb.CreateVector(_o->embedding_dim_per_channel) : 0; + return tflite::CreateConcatEmbeddingsOptions( + _fbb, + _num_channels, + _num_columns_per_channel, + _embedding_dim_per_channel); +} + +inline LSHProjectionOptionsT *LSHProjectionOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new LSHProjectionOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void LSHProjectionOptions::UnPackTo( - LSHProjectionOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void LSHProjectionOptions::UnPackTo(LSHProjectionOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = type(); - _o->type = _e; - }; + { auto _e = type(); _o->type = _e; }; } -inline flatbuffers::Offset LSHProjectionOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset LSHProjectionOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateLSHProjectionOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateLSHProjectionOptions( - flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateLSHProjectionOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const LSHProjectionOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LSHProjectionOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _type = _o->type; - return tflite::CreateLSHProjectionOptions(_fbb, _type); + return tflite::CreateLSHProjectionOptions( + _fbb, + _type); } -inline SVDFOptionsT *SVDFOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SVDFOptionsT *SVDFOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SVDFOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SVDFOptions::UnPackTo( - SVDFOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void SVDFOptions::UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = rank(); - _o->rank = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = rank(); _o->rank = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset SVDFOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSVDFOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSVDFOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SVDFOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _rank = _o->rank; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateSVDFOptions(_fbb, _rank, _fused_activation_function); + return tflite::CreateSVDFOptions( + _fbb, + _rank, + _fused_activation_function); } -inline RNNOptionsT *RNNOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new RNNOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void RNNOptions::UnPackTo( - RNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void RNNOptions::UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset RNNOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateRNNOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const RNNOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateRNNOptions(_fbb, _fused_activation_function); + return tflite::CreateRNNOptions( + _fbb, + _fused_activation_function); } -inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SequenceRNNOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SequenceRNNOptions::UnPackTo( - SequenceRNNOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SequenceRNNOptions::UnPackTo(SequenceRNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = time_major(); - _o->time_major = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = time_major(); _o->time_major = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset SequenceRNNOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSequenceRNNOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSequenceRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSequenceRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _time_major = _o->time_major; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateSequenceRNNOptions( + _fbb, + _time_major, + _fused_activation_function); +} + +inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new BidirectionalSequenceRNNOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = time_major(); _o->time_major = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateBidirectionalSequenceRNNOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateBidirectionalSequenceRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SequenceRNNOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BidirectionalSequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _time_major = _o->time_major; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateSequenceRNNOptions(_fbb, _time_major, - _fused_activation_function); + return tflite::CreateBidirectionalSequenceRNNOptions( + _fbb, + _time_major, + _fused_activation_function); } -inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new FullyConnectedOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void FullyConnectedOptions::UnPackTo( - FullyConnectedOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset FullyConnectedOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateFullyConnectedOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateFullyConnectedOptions( - flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateFullyConnectedOptions(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const FullyConnectedOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FullyConnectedOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateFullyConnectedOptions(_fbb, _fused_activation_function); + return tflite::CreateFullyConnectedOptions( + _fbb, + _fused_activation_function); } -inline SoftmaxOptionsT *SoftmaxOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SoftmaxOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SoftmaxOptions::UnPackTo( - SoftmaxOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SoftmaxOptions::UnPackTo(SoftmaxOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = beta(); - _o->beta = _e; - }; + { auto _e = beta(); _o->beta = _e; }; } -inline flatbuffers::Offset SoftmaxOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SoftmaxOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSoftmaxOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSoftmaxOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSoftmaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SoftmaxOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SoftmaxOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _beta = _o->beta; - return tflite::CreateSoftmaxOptions(_fbb, _beta); + return tflite::CreateSoftmaxOptions( + _fbb, + _beta); } -inline ConcatenationOptionsT *ConcatenationOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ConcatenationOptionsT *ConcatenationOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ConcatenationOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void ConcatenationOptions::UnPackTo( - ConcatenationOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void ConcatenationOptions::UnPackTo(ConcatenationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = axis(); - _o->axis = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = axis(); _o->axis = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset ConcatenationOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset ConcatenationOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateConcatenationOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateConcatenationOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateConcatenationOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const ConcatenationOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ConcatenationOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _axis = _o->axis; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateConcatenationOptions(_fbb, _axis, - _fused_activation_function); + return tflite::CreateConcatenationOptions( + _fbb, + _axis, + _fused_activation_function); } -inline AddOptionsT *AddOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline AddOptionsT *AddOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new AddOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void AddOptions::UnPackTo( - AddOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void AddOptions::UnPackTo(AddOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset AddOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset AddOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateAddOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateAddOptions( - flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateAddOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const AddOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const AddOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateAddOptions(_fbb, _fused_activation_function); + return tflite::CreateAddOptions( + _fbb, + _fused_activation_function); } -inline MulOptionsT *MulOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline MulOptionsT *MulOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new MulOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void MulOptions::UnPackTo( - MulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void MulOptions::UnPackTo(MulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset MulOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset MulOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateMulOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateMulOptions( - flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const MulOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MulOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateMulOptions(_fbb, _fused_activation_function); + return tflite::CreateMulOptions( + _fbb, + _fused_activation_function); } -inline L2NormOptionsT *L2NormOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline L2NormOptionsT *L2NormOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new L2NormOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void L2NormOptions::UnPackTo( - L2NormOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void L2NormOptions::UnPackTo(L2NormOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset L2NormOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset L2NormOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateL2NormOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateL2NormOptions( - flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateL2NormOptions(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const L2NormOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const L2NormOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateL2NormOptions(_fbb, _fused_activation_function); + return tflite::CreateL2NormOptions( + _fbb, + _fused_activation_function); } -inline LocalResponseNormalizationOptionsT * -LocalResponseNormalizationOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline LocalResponseNormalizationOptionsT *LocalResponseNormalizationOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new LocalResponseNormalizationOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void LocalResponseNormalizationOptions::UnPackTo( - LocalResponseNormalizationOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void LocalResponseNormalizationOptions::UnPackTo(LocalResponseNormalizationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = radius(); - _o->radius = _e; - }; - { - auto _e = bias(); - _o->bias = _e; - }; - { - auto _e = alpha(); - _o->alpha = _e; - }; - { - auto _e = beta(); - _o->beta = _e; - }; + { auto _e = radius(); _o->radius = _e; }; + { auto _e = bias(); _o->bias = _e; }; + { auto _e = alpha(); _o->alpha = _e; }; + { auto _e = beta(); _o->beta = _e; }; } -inline flatbuffers::Offset -LocalResponseNormalizationOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const LocalResponseNormalizationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset LocalResponseNormalizationOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateLocalResponseNormalizationOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset -CreateLocalResponseNormalizationOptions( - flatbuffers::FlatBufferBuilder &_fbb, - const LocalResponseNormalizationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const LocalResponseNormalizationOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LocalResponseNormalizationOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _radius = _o->radius; auto _bias = _o->bias; auto _alpha = _o->alpha; auto _beta = _o->beta; - return tflite::CreateLocalResponseNormalizationOptions(_fbb, _radius, _bias, - _alpha, _beta); + return tflite::CreateLocalResponseNormalizationOptions( + _fbb, + _radius, + _bias, + _alpha, + _beta); } -inline LSTMOptionsT *LSTMOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline LSTMOptionsT *LSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new LSTMOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void LSTMOptions::UnPackTo( - LSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; - { - auto _e = cell_clip(); - _o->cell_clip = _e; - }; - { - auto _e = proj_clip(); - _o->proj_clip = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = cell_clip(); _o->cell_clip = _e; }; + { auto _e = proj_clip(); _o->proj_clip = _e; }; } -inline flatbuffers::Offset LSTMOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateLSTMOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateLSTMOptions( - flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const LSTMOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LSTMOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; - return tflite::CreateLSTMOptions(_fbb, _fused_activation_function, _cell_clip, - _proj_clip); + return tflite::CreateLSTMOptions( + _fbb, + _fused_activation_function, + _cell_clip, + _proj_clip); } -inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ResizeBilinearOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void ResizeBilinearOptions::UnPackTo( - ResizeBilinearOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void ResizeBilinearOptions::UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = new_height(); - _o->new_height = _e; - }; - { - auto _e = new_width(); - _o->new_width = _e; - }; + { auto _e = align_corners(); _o->align_corners = _e; }; } -inline flatbuffers::Offset ResizeBilinearOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset ResizeBilinearOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateResizeBilinearOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateResizeBilinearOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const ResizeBilinearOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - auto _new_height = _o->new_height; - auto _new_width = _o->new_width; - return tflite::CreateResizeBilinearOptions(_fbb, _new_height, _new_width); -} - -inline CallOptionsT *CallOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ResizeBilinearOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _align_corners = _o->align_corners; + return tflite::CreateResizeBilinearOptions( + _fbb, + _align_corners); +} + +inline CallOptionsT *CallOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new CallOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void CallOptions::UnPackTo( - CallOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void CallOptions::UnPackTo(CallOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = subgraph(); - _o->subgraph = _e; - }; + { auto _e = subgraph(); _o->subgraph = _e; }; } -inline flatbuffers::Offset CallOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CallOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateCallOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateCallOptions( - flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateCallOptions(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const CallOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CallOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _subgraph = _o->subgraph; - return tflite::CreateCallOptions(_fbb, _subgraph); + return tflite::CreateCallOptions( + _fbb, + _subgraph); } -inline PadOptionsT *PadOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline PadOptionsT *PadOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new PadOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void PadOptions::UnPackTo( - PadOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void PadOptions::UnPackTo(PadOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; } -inline flatbuffers::Offset PadOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset PadOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreatePadOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreatePadOptions( - flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreatePadOptions(flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const PadOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - return tflite::CreatePadOptions(_fbb); + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PadOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreatePadOptions( + _fbb); } -inline ReshapeOptionsT *ReshapeOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ReshapeOptionsT *ReshapeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ReshapeOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void ReshapeOptions::UnPackTo( - ReshapeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void ReshapeOptions::UnPackTo(ReshapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = new_shape(); - if (_e) { - _o->new_shape.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->new_shape[_i] = _e->Get(_i); - } - } - }; + { auto _e = new_shape(); if (_e) { _o->new_shape.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->new_shape[_i] = _e->Get(_i); } } }; } -inline flatbuffers::Offset ReshapeOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset ReshapeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateReshapeOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateReshapeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateReshapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const ReshapeOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReshapeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _new_shape = _o->new_shape.size() ? _fbb.CreateVector(_o->new_shape) : 0; - return tflite::CreateReshapeOptions(_fbb, _new_shape); + return tflite::CreateReshapeOptions( + _fbb, + _new_shape); } -inline SpaceToBatchNDOptionsT *SpaceToBatchNDOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SpaceToBatchNDOptionsT *SpaceToBatchNDOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SpaceToBatchNDOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SpaceToBatchNDOptions::UnPackTo( - SpaceToBatchNDOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SpaceToBatchNDOptions::UnPackTo(SpaceToBatchNDOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = block_shape(); - if (_e) { - _o->block_shape.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->block_shape[_i] = _e->Get(_i); - } - } - }; - { - auto _e = before_paddings(); - if (_e) { - _o->before_paddings.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->before_paddings[_i] = _e->Get(_i); - } - } - }; - { - auto _e = after_paddings(); - if (_e) { - _o->after_paddings.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->after_paddings[_i] = _e->Get(_i); - } - } - }; } -inline flatbuffers::Offset SpaceToBatchNDOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SpaceToBatchNDOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSpaceToBatchNDOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSpaceToBatchNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSpaceToBatchNDOptions(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SpaceToBatchNDOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - auto _block_shape = - _o->block_shape.size() ? _fbb.CreateVector(_o->block_shape) : 0; - auto _before_paddings = - _o->before_paddings.size() ? _fbb.CreateVector(_o->before_paddings) : 0; - auto _after_paddings = - _o->after_paddings.size() ? _fbb.CreateVector(_o->after_paddings) : 0; - return tflite::CreateSpaceToBatchNDOptions(_fbb, _block_shape, - _before_paddings, _after_paddings); -} - -inline BatchToSpaceNDOptionsT *BatchToSpaceNDOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SpaceToBatchNDOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSpaceToBatchNDOptions( + _fbb); +} + +inline BatchToSpaceNDOptionsT *BatchToSpaceNDOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new BatchToSpaceNDOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void BatchToSpaceNDOptions::UnPackTo( - BatchToSpaceNDOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void BatchToSpaceNDOptions::UnPackTo(BatchToSpaceNDOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = block_shape(); - if (_e) { - _o->block_shape.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->block_shape[_i] = _e->Get(_i); - } - } - }; - { - auto _e = before_crops(); - if (_e) { - _o->before_crops.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->before_crops[_i] = _e->Get(_i); - } - } - }; - { - auto _e = after_crops(); - if (_e) { - _o->after_crops.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->after_crops[_i] = _e->Get(_i); - } - } - }; } -inline flatbuffers::Offset BatchToSpaceNDOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset BatchToSpaceNDOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateBatchToSpaceNDOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateBatchToSpaceNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateBatchToSpaceNDOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const BatchToSpaceNDOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - auto _block_shape = - _o->block_shape.size() ? _fbb.CreateVector(_o->block_shape) : 0; - auto _before_crops = - _o->before_crops.size() ? _fbb.CreateVector(_o->before_crops) : 0; - auto _after_crops = - _o->after_crops.size() ? _fbb.CreateVector(_o->after_crops) : 0; - return tflite::CreateBatchToSpaceNDOptions(_fbb, _block_shape, _before_crops, - _after_crops); -} - -inline SkipGramOptionsT *SkipGramOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BatchToSpaceNDOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateBatchToSpaceNDOptions( + _fbb); +} + +inline SkipGramOptionsT *SkipGramOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SkipGramOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SkipGramOptions::UnPackTo( - SkipGramOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SkipGramOptions::UnPackTo(SkipGramOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = ngram_size(); - _o->ngram_size = _e; - }; - { - auto _e = max_skip_size(); - _o->max_skip_size = _e; - }; - { - auto _e = include_all_ngrams(); - _o->include_all_ngrams = _e; - }; + { auto _e = ngram_size(); _o->ngram_size = _e; }; + { auto _e = max_skip_size(); _o->max_skip_size = _e; }; + { auto _e = include_all_ngrams(); _o->include_all_ngrams = _e; }; } -inline flatbuffers::Offset SkipGramOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SkipGramOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSkipGramOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSkipGramOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSkipGramOptions(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SkipGramOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SkipGramOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _ngram_size = _o->ngram_size; auto _max_skip_size = _o->max_skip_size; auto _include_all_ngrams = _o->include_all_ngrams; - return tflite::CreateSkipGramOptions(_fbb, _ngram_size, _max_skip_size, - _include_all_ngrams); + return tflite::CreateSkipGramOptions( + _fbb, + _ngram_size, + _max_skip_size, + _include_all_ngrams); } -inline SpaceToDepthOptionsT *SpaceToDepthOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SpaceToDepthOptionsT *SpaceToDepthOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SpaceToDepthOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SpaceToDepthOptions::UnPackTo( - SpaceToDepthOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SpaceToDepthOptions::UnPackTo(SpaceToDepthOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = block_size(); - _o->block_size = _e; - }; + { auto _e = block_size(); _o->block_size = _e; }; } -inline flatbuffers::Offset SpaceToDepthOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SpaceToDepthOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSpaceToDepthOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSpaceToDepthOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSpaceToDepthOptions(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SpaceToDepthOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SpaceToDepthOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _block_size = _o->block_size; - return tflite::CreateSpaceToDepthOptions(_fbb, _block_size); + return tflite::CreateSpaceToDepthOptions( + _fbb, + _block_size); } -inline SubOptionsT *SubOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SubOptionsT *SubOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SubOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SubOptions::UnPackTo( - SubOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void SubOptions::UnPackTo(SubOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset SubOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SubOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSubOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSubOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSubOptions(flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SubOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SubOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateSubOptions(_fbb, _fused_activation_function); + return tflite::CreateSubOptions( + _fbb, + _fused_activation_function); } -inline DivOptionsT *DivOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline DivOptionsT *DivOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new DivOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void DivOptions::UnPackTo( - DivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void DivOptions::UnPackTo(DivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset DivOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset DivOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateDivOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateDivOptions( - flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const DivOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DivOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateDivOptions(_fbb, _fused_activation_function); + return tflite::CreateDivOptions( + _fbb, + _fused_activation_function); +} + +inline TopKV2OptionsT *TopKV2Options::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TopKV2OptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void TopKV2Options::UnPackTo(TopKV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset TopKV2Options::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateTopKV2Options(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateTopKV2Options(flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TopKV2OptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateTopKV2Options( + _fbb); } -inline EmbeddingLookupSparseOptionsT *EmbeddingLookupSparseOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline EmbeddingLookupSparseOptionsT *EmbeddingLookupSparseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new EmbeddingLookupSparseOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void EmbeddingLookupSparseOptions::UnPackTo( - EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void EmbeddingLookupSparseOptions::UnPackTo(EmbeddingLookupSparseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = combiner(); - _o->combiner = _e; - }; + { auto _e = combiner(); _o->combiner = _e; }; } -inline flatbuffers::Offset -EmbeddingLookupSparseOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset EmbeddingLookupSparseOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateEmbeddingLookupSparseOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset -CreateEmbeddingLookupSparseOptions( - flatbuffers::FlatBufferBuilder &_fbb, - const EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateEmbeddingLookupSparseOptions(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const EmbeddingLookupSparseOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EmbeddingLookupSparseOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _combiner = _o->combiner; - return tflite::CreateEmbeddingLookupSparseOptions(_fbb, _combiner); + return tflite::CreateEmbeddingLookupSparseOptions( + _fbb, + _combiner); } -inline GatherOptionsT *GatherOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline GatherOptionsT *GatherOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new GatherOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void GatherOptions::UnPackTo( - GatherOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void GatherOptions::UnPackTo(GatherOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = axis(); - _o->axis = _e; - }; + { auto _e = axis(); _o->axis = _e; }; } -inline flatbuffers::Offset GatherOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset GatherOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateGatherOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateGatherOptions( - flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateGatherOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const GatherOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GatherOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _axis = _o->axis; - return tflite::CreateGatherOptions(_fbb, _axis); + return tflite::CreateGatherOptions( + _fbb, + _axis); } -inline TransposeOptionsT *TransposeOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline TransposeOptionsT *TransposeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new TransposeOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void TransposeOptions::UnPackTo( - TransposeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void TransposeOptions::UnPackTo(TransposeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = perm(); - if (_e) { - _o->perm.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->perm[_i] = _e->Get(_i); - } - } - }; } -inline flatbuffers::Offset TransposeOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset TransposeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateTransposeOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateTransposeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateTransposeOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TransposeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateTransposeOptions( + _fbb); +} + +inline ExpOptionsT *ExpOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ExpOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ExpOptions::UnPackTo(ExpOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset ExpOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateExpOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const TransposeOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - auto _perm = _o->perm.size() ? _fbb.CreateVector(_o->perm) : 0; - return tflite::CreateTransposeOptions(_fbb, _perm); -} - -inline MeanOptionsT *MeanOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ExpOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateExpOptions( + _fbb); +} + +inline MeanOptionsT *MeanOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new MeanOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void MeanOptions::UnPackTo( - MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void MeanOptions::UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = axis(); - if (_e) { - _o->axis.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->axis[_i] = _e->Get(_i); - } - } - }; - { - auto _e = keep_dims(); - _o->keep_dims = _e; - }; + { auto _e = keep_dims(); _o->keep_dims = _e; }; } -inline flatbuffers::Offset MeanOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset MeanOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateMeanOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateMeanOptions( - flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const MeanOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - auto _axis = _o->axis.size() ? _fbb.CreateVector(_o->axis) : 0; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MeanOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _keep_dims = _o->keep_dims; - return tflite::CreateMeanOptions(_fbb, _axis, _keep_dims); + return tflite::CreateMeanOptions( + _fbb, + _keep_dims); } -inline SqueezeOptionsT *SqueezeOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SqueezeOptionsT *SqueezeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SqueezeOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SqueezeOptions::UnPackTo( - SqueezeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SqueezeOptions::UnPackTo(SqueezeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = squeeze_dims(); - if (_e) { - _o->squeeze_dims.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->squeeze_dims[_i] = _e->Get(_i); - } - } - }; + { auto _e = squeeze_dims(); if (_e) { _o->squeeze_dims.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->squeeze_dims[_i] = _e->Get(_i); } } }; } -inline flatbuffers::Offset SqueezeOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SqueezeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSqueezeOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSqueezeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSqueezeOptions(flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SqueezeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _squeeze_dims = _o->squeeze_dims.size() ? _fbb.CreateVector(_o->squeeze_dims) : 0; + return tflite::CreateSqueezeOptions( + _fbb, + _squeeze_dims); +} + +inline SplitOptionsT *SplitOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SplitOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SplitOptions::UnPackTo(SplitOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = num_splits(); _o->num_splits = _e; }; +} + +inline flatbuffers::Offset SplitOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSplitOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSplitOptions(flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SqueezeOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - auto _squeeze_dims = - _o->squeeze_dims.size() ? _fbb.CreateVector(_o->squeeze_dims) : 0; - return tflite::CreateSqueezeOptions(_fbb, _squeeze_dims); -} - -inline StridedSliceOptionsT *StridedSliceOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SplitOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _num_splits = _o->num_splits; + return tflite::CreateSplitOptions( + _fbb, + _num_splits); +} + +inline StridedSliceOptionsT *StridedSliceOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new StridedSliceOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void StridedSliceOptions::UnPackTo( - StridedSliceOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void StridedSliceOptions::UnPackTo(StridedSliceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = begin_mask(); - _o->begin_mask = _e; - }; - { - auto _e = end_mask(); - _o->end_mask = _e; - }; - { - auto _e = ellipsis_mask(); - _o->ellipsis_mask = _e; - }; - { - auto _e = new_axis_mask(); - _o->new_axis_mask = _e; - }; - { - auto _e = shrink_axis_mask(); - _o->shrink_axis_mask = _e; - }; + { auto _e = begin_mask(); _o->begin_mask = _e; }; + { auto _e = end_mask(); _o->end_mask = _e; }; + { auto _e = ellipsis_mask(); _o->ellipsis_mask = _e; }; + { auto _e = new_axis_mask(); _o->new_axis_mask = _e; }; + { auto _e = shrink_axis_mask(); _o->shrink_axis_mask = _e; }; } -inline flatbuffers::Offset StridedSliceOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset StridedSliceOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateStridedSliceOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateStridedSliceOptions( - flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateStridedSliceOptions(flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const StridedSliceOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const StridedSliceOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _begin_mask = _o->begin_mask; auto _end_mask = _o->end_mask; auto _ellipsis_mask = _o->ellipsis_mask; auto _new_axis_mask = _o->new_axis_mask; auto _shrink_axis_mask = _o->shrink_axis_mask; - return tflite::CreateStridedSliceOptions(_fbb, _begin_mask, _end_mask, - _ellipsis_mask, _new_axis_mask, - _shrink_axis_mask); + return tflite::CreateStridedSliceOptions( + _fbb, + _begin_mask, + _end_mask, + _ellipsis_mask, + _new_axis_mask, + _shrink_axis_mask); } -inline OperatorCodeT *OperatorCode::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); return _o; } -inline void OperatorCode::UnPackTo( - OperatorCodeT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void OperatorCode::UnPackTo(OperatorCodeT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = builtin_code(); - _o->builtin_code = _e; - }; - { - auto _e = custom_code(); - if (_e) _o->custom_code = _e->str(); - }; + { auto _e = builtin_code(); _o->builtin_code = _e; }; + { auto _e = custom_code(); if (_e) _o->custom_code = _e->str(); }; } -inline flatbuffers::Offset OperatorCode::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset OperatorCode::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateOperatorCode(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateOperatorCode( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const OperatorCodeT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OperatorCodeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _builtin_code = _o->builtin_code; - auto _custom_code = - _o->custom_code.empty() ? 0 : _fbb.CreateString(_o->custom_code); - return tflite::CreateOperatorCode(_fbb, _builtin_code, _custom_code); + auto _custom_code = _o->custom_code.empty() ? 0 : _fbb.CreateString(_o->custom_code); + return tflite::CreateOperatorCode( + _fbb, + _builtin_code, + _custom_code); } -inline OperatorT *Operator::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline OperatorT *Operator::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorT(); UnPackTo(_o, _resolver); return _o; } -inline void Operator::UnPackTo( - OperatorT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void Operator::UnPackTo(OperatorT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = opcode_index(); - _o->opcode_index = _e; - }; - { - auto _e = inputs(); - if (_e) { - _o->inputs.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->inputs[_i] = _e->Get(_i); - } - } - }; - { - auto _e = outputs(); - if (_e) { - _o->outputs.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->outputs[_i] = _e->Get(_i); - } - } - }; - { - auto _e = builtin_options_type(); - _o->builtin_options.type = _e; - }; - { - auto _e = builtin_options(); - if (_e) - _o->builtin_options.value = - BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); - }; - { - auto _e = custom_options(); - if (_e) { - _o->custom_options.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->custom_options[_i] = _e->Get(_i); - } - } - }; - { - auto _e = custom_options_format(); - _o->custom_options_format = _e; - }; + { auto _e = opcode_index(); _o->opcode_index = _e; }; + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = _e->Get(_i); } } }; + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } }; + { auto _e = builtin_options_type(); _o->builtin_options.type = _e; }; + { auto _e = builtin_options(); if (_e) _o->builtin_options.value = BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); }; + { auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom_options[_i] = _e->Get(_i); } } }; + { auto _e = custom_options_format(); _o->custom_options_format = _e; }; } -inline flatbuffers::Offset Operator::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Operator::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateOperator(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateOperator( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const OperatorT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OperatorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _opcode_index = _o->opcode_index; auto _inputs = _o->inputs.size() ? _fbb.CreateVector(_o->inputs) : 0; auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; auto _builtin_options_type = _o->builtin_options.type; auto _builtin_options = _o->builtin_options.Pack(_fbb); - auto _custom_options = - _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; + auto _custom_options = _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; auto _custom_options_format = _o->custom_options_format; - return tflite::CreateOperator(_fbb, _opcode_index, _inputs, _outputs, - _builtin_options_type, _builtin_options, - _custom_options, _custom_options_format); + return tflite::CreateOperator( + _fbb, + _opcode_index, + _inputs, + _outputs, + _builtin_options_type, + _builtin_options, + _custom_options, + _custom_options_format); } -inline SubGraphT *SubGraph::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SubGraphT *SubGraph::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SubGraphT(); UnPackTo(_o, _resolver); return _o; } -inline void SubGraph::UnPackTo( - SubGraphT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void SubGraph::UnPackTo(SubGraphT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = tensors(); - if (_e) { - _o->tensors.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->tensors[_i] = - std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); - } - } - }; - { - auto _e = inputs(); - if (_e) { - _o->inputs.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->inputs[_i] = _e->Get(_i); - } - } - }; - { - auto _e = outputs(); - if (_e) { - _o->outputs.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->outputs[_i] = _e->Get(_i); - } - } - }; - { - auto _e = operators(); - if (_e) { - _o->operators.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->operators[_i] = - std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); - } - } - }; - { - auto _e = name(); - if (_e) _o->name = _e->str(); - }; + { auto _e = tensors(); if (_e) { _o->tensors.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = _e->Get(_i); } } }; + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } }; + { auto _e = operators(); if (_e) { _o->operators.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->operators[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = name(); if (_e) _o->name = _e->str(); }; } -inline flatbuffers::Offset SubGraph::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SubGraph::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSubGraph(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSubGraph( - flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSubGraph(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SubGraphT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - auto _tensors = - _o->tensors.size() - ? _fbb.CreateVector>( - _o->tensors.size(), - [](size_t i, _VectorArgs *__va) { - return CreateTensor(*__va->__fbb, __va->__o->tensors[i].get(), - __va->__rehasher); - }, - &_va) - : 0; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SubGraphT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _tensors = _o->tensors.size() ? _fbb.CreateVector> (_o->tensors.size(), [](size_t i, _VectorArgs *__va) { return CreateTensor(*__va->__fbb, __va->__o->tensors[i].get(), __va->__rehasher); }, &_va ) : 0; auto _inputs = _o->inputs.size() ? _fbb.CreateVector(_o->inputs) : 0; auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; - auto _operators = _o->operators.size() - ? _fbb.CreateVector>( - _o->operators.size(), - [](size_t i, _VectorArgs *__va) { - return CreateOperator( - *__va->__fbb, __va->__o->operators[i].get(), - __va->__rehasher); - }, - &_va) - : 0; + auto _operators = _o->operators.size() ? _fbb.CreateVector> (_o->operators.size(), [](size_t i, _VectorArgs *__va) { return CreateOperator(*__va->__fbb, __va->__o->operators[i].get(), __va->__rehasher); }, &_va ) : 0; auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); - return tflite::CreateSubGraph(_fbb, _tensors, _inputs, _outputs, _operators, - _name); + return tflite::CreateSubGraph( + _fbb, + _tensors, + _inputs, + _outputs, + _operators, + _name); } -inline BufferT *Buffer::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline BufferT *Buffer::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new BufferT(); UnPackTo(_o, _resolver); return _o; } -inline void Buffer::UnPackTo( - BufferT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void Buffer::UnPackTo(BufferT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = data(); - if (_e) { - _o->data.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->data[_i] = _e->Get(_i); - } - } - }; + { auto _e = data(); if (_e) { _o->data.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->data[_i] = _e->Get(_i); } } }; } -inline flatbuffers::Offset Buffer::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Buffer::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateBuffer(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateBuffer( - flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateBuffer(flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const BufferT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BufferT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _data = _o->data.size() ? _fbb.CreateVector(_o->data) : 0; - return tflite::CreateBuffer(_fbb, _data); + return tflite::CreateBuffer( + _fbb, + _data); } -inline ModelT *Model::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ModelT *Model::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ModelT(); UnPackTo(_o, _resolver); return _o; } -inline void Model::UnPackTo( - ModelT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void Model::UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = version(); - _o->version = _e; - }; - { - auto _e = operator_codes(); - if (_e) { - _o->operator_codes.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->operator_codes[_i] = - std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); - } - } - }; - { - auto _e = subgraphs(); - if (_e) { - _o->subgraphs.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->subgraphs[_i] = - std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); - } - } - }; - { - auto _e = description(); - if (_e) _o->description = _e->str(); - }; - { - auto _e = buffers(); - if (_e) { - _o->buffers.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->buffers[_i] = - std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); - } - } - }; + { auto _e = version(); _o->version = _e; }; + { auto _e = operator_codes(); if (_e) { _o->operator_codes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->operator_codes[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = subgraphs(); if (_e) { _o->subgraphs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->subgraphs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = description(); if (_e) _o->description = _e->str(); }; + { auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; } -inline flatbuffers::Offset Model::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateModel(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateModel( - flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const ModelT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _version = _o->version; - auto _operator_codes = - _o->operator_codes.size() - ? _fbb.CreateVector>( - _o->operator_codes.size(), - [](size_t i, _VectorArgs *__va) { - return CreateOperatorCode(*__va->__fbb, - __va->__o->operator_codes[i].get(), - __va->__rehasher); - }, - &_va) - : 0; - auto _subgraphs = _o->subgraphs.size() - ? _fbb.CreateVector>( - _o->subgraphs.size(), - [](size_t i, _VectorArgs *__va) { - return CreateSubGraph( - *__va->__fbb, __va->__o->subgraphs[i].get(), - __va->__rehasher); - }, - &_va) - : 0; - auto _description = - _o->description.empty() ? 0 : _fbb.CreateString(_o->description); - auto _buffers = - _o->buffers.size() - ? _fbb.CreateVector>( - _o->buffers.size(), - [](size_t i, _VectorArgs *__va) { - return CreateBuffer(*__va->__fbb, __va->__o->buffers[i].get(), - __va->__rehasher); - }, - &_va) - : 0; - return tflite::CreateModel(_fbb, _version, _operator_codes, _subgraphs, - _description, _buffers); -} - -inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, - const void *obj, BuiltinOptions type) { + auto _operator_codes = _o->operator_codes.size() ? _fbb.CreateVector> (_o->operator_codes.size(), [](size_t i, _VectorArgs *__va) { return CreateOperatorCode(*__va->__fbb, __va->__o->operator_codes[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _subgraphs = _o->subgraphs.size() ? _fbb.CreateVector> (_o->subgraphs.size(), [](size_t i, _VectorArgs *__va) { return CreateSubGraph(*__va->__fbb, __va->__o->subgraphs[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _description = _o->description.empty() ? 0 : _fbb.CreateString(_o->description); + auto _buffers = _o->buffers.size() ? _fbb.CreateVector> (_o->buffers.size(), [](size_t i, _VectorArgs *__va) { return CreateBuffer(*__va->__fbb, __va->__o->buffers[i].get(), __va->__rehasher); }, &_va ) : 0; + return tflite::CreateModel( + _fbb, + _version, + _operator_codes, + _subgraphs, + _description, + _buffers); +} + +inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type) { switch (type) { case BuiltinOptions_NONE: { return true; @@ -6621,8 +5644,7 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, return verifier.VerifyTable(ptr); } case BuiltinOptions_LocalResponseNormalizationOptions: { - auto ptr = - reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case BuiltinOptions_LSTMOptions: { @@ -6701,28 +5723,35 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - default: - return false; + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return false; } } -inline bool VerifyBuiltinOptionsVector( - flatbuffers::Verifier &verifier, - const flatbuffers::Vector> *values, - const flatbuffers::Vector *types) { +inline bool VerifyBuiltinOptionsVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; if (values->size() != types->size()) return false; for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { - if (!VerifyBuiltinOptions(verifier, values->Get(i), - types->GetEnum(i))) { + if (!VerifyBuiltinOptions( + verifier, values->Get(i), types->GetEnum(i))) { return false; } } return true; } -inline void *BuiltinOptionsUnion::UnPack( - const void *obj, BuiltinOptions type, - const flatbuffers::resolver_function_t *resolver) { +inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, const flatbuffers::resolver_function_t *resolver) { switch (type) { case BuiltinOptions_Conv2DOptions: { auto ptr = reinterpret_cast(obj); @@ -6773,8 +5802,7 @@ inline void *BuiltinOptionsUnion::UnPack( return ptr->UnPack(resolver); } case BuiltinOptions_LocalResponseNormalizationOptions: { - auto ptr = - reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_LSTMOptions: { @@ -6853,14 +5881,23 @@ inline void *BuiltinOptionsUnion::UnPack( auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } - default: - return nullptr; + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + default: return nullptr; } } -inline flatbuffers::Offset BuiltinOptionsUnion::Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const flatbuffers::rehasher_function_t *_rehasher) const { +inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher) const { switch (type) { case BuiltinOptions_Conv2DOptions: { auto ptr = reinterpret_cast(value); @@ -6911,10 +5948,8 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack( return CreateL2NormOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LocalResponseNormalizationOptions: { - auto ptr = - reinterpret_cast(value); - return CreateLocalResponseNormalizationOptions(_fbb, ptr, _rehasher) - .Union(); + auto ptr = reinterpret_cast(value); + return CreateLocalResponseNormalizationOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LSTMOptions: { auto ptr = reinterpret_cast(value); @@ -6992,32 +6027,38 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack( auto ptr = reinterpret_cast(value); return CreateStridedSliceOptions(_fbb, ptr, _rehasher).Union(); } - default: - return 0; + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(value); + return CreateExpOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(value); + return CreateTopKV2Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(value); + return CreateSplitOptions(_fbb, ptr, _rehasher).Union(); + } + default: return 0; } } -inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) - FLATBUFFERS_NOEXCEPT : type(u.type), - value(nullptr) { +inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FLATBUFFERS_NOEXCEPT : type(u.type), value(nullptr) { switch (type) { case BuiltinOptions_Conv2DOptions: { value = new Conv2DOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_DepthwiseConv2DOptions: { - value = new DepthwiseConv2DOptionsT( - *reinterpret_cast(u.value)); + value = new DepthwiseConv2DOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ConcatEmbeddingsOptions: { - value = new ConcatEmbeddingsOptionsT( - *reinterpret_cast(u.value)); + value = new ConcatEmbeddingsOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LSHProjectionOptions: { - value = new LSHProjectionOptionsT( - *reinterpret_cast(u.value)); + value = new LSHProjectionOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_Pool2DOptions: { @@ -7033,18 +6074,15 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_FullyConnectedOptions: { - value = new FullyConnectedOptionsT( - *reinterpret_cast(u.value)); + value = new FullyConnectedOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SoftmaxOptions: { - value = - new SoftmaxOptionsT(*reinterpret_cast(u.value)); + value = new SoftmaxOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ConcatenationOptions: { - value = new ConcatenationOptionsT( - *reinterpret_cast(u.value)); + value = new ConcatenationOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_AddOptions: { @@ -7056,8 +6094,7 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_LocalResponseNormalizationOptions: { - value = new LocalResponseNormalizationOptionsT( - *reinterpret_cast(u.value)); + value = new LocalResponseNormalizationOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LSTMOptions: { @@ -7065,8 +6102,7 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_ResizeBilinearOptions: { - value = new ResizeBilinearOptionsT( - *reinterpret_cast(u.value)); + value = new ResizeBilinearOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_CallOptions: { @@ -7074,23 +6110,19 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_ReshapeOptions: { - value = - new ReshapeOptionsT(*reinterpret_cast(u.value)); + value = new ReshapeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SkipGramOptions: { - value = - new SkipGramOptionsT(*reinterpret_cast(u.value)); + value = new SkipGramOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SpaceToDepthOptions: { - value = new SpaceToDepthOptionsT( - *reinterpret_cast(u.value)); + value = new SpaceToDepthOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_EmbeddingLookupSparseOptions: { - value = new EmbeddingLookupSparseOptionsT( - *reinterpret_cast(u.value)); + value = new EmbeddingLookupSparseOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_MulOptions: { @@ -7106,18 +6138,15 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_BatchToSpaceNDOptions: { - value = new BatchToSpaceNDOptionsT( - *reinterpret_cast(u.value)); + value = new BatchToSpaceNDOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SpaceToBatchNDOptions: { - value = new SpaceToBatchNDOptionsT( - *reinterpret_cast(u.value)); + value = new SpaceToBatchNDOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_TransposeOptions: { - value = new TransposeOptionsT( - *reinterpret_cast(u.value)); + value = new TransposeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_MeanOptions: { @@ -7133,18 +6162,27 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_SqueezeOptions: { - value = - new SqueezeOptionsT(*reinterpret_cast(u.value)); + value = new SqueezeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SequenceRNNOptions: { - value = new SequenceRNNOptionsT( - *reinterpret_cast(u.value)); + value = new SequenceRNNOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_StridedSliceOptions: { - value = new StridedSliceOptionsT( - *reinterpret_cast(u.value)); + value = new StridedSliceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ExpOptions: { + value = new ExpOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_TopKV2Options: { + value = new TopKV2OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SplitOptions: { + value = new SplitOptionsT(*reinterpret_cast(u.value)); break; } default: @@ -7314,8 +6352,22 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } - default: + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(value); + delete ptr; break; + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + default: break; } value = nullptr; type = BuiltinOptions_NONE; @@ -7325,25 +6377,33 @@ inline const tflite::Model *GetModel(const void *buf) { return flatbuffers::GetRoot(buf); } -inline const char *ModelIdentifier() { return "TFL3"; } +inline const char *ModelIdentifier() { + return "TFL3"; +} inline bool ModelBufferHasIdentifier(const void *buf) { - return flatbuffers::BufferHasIdentifier(buf, ModelIdentifier()); + return flatbuffers::BufferHasIdentifier( + buf, ModelIdentifier()); } -inline bool VerifyModelBuffer(flatbuffers::Verifier &verifier) { +inline bool VerifyModelBuffer( + flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(ModelIdentifier()); } -inline const char *ModelExtension() { return "tflite"; } +inline const char *ModelExtension() { + return "tflite"; +} -inline void FinishModelBuffer(flatbuffers::FlatBufferBuilder &fbb, - flatbuffers::Offset root) { +inline void FinishModelBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { fbb.Finish(root, ModelIdentifier()); } inline std::unique_ptr UnPackModel( - const void *buf, const flatbuffers::resolver_function_t *res = nullptr) { + const void *buf, + const flatbuffers::resolver_function_t *res = nullptr) { return std::unique_ptr(GetModel(buf)->UnPack(res)); } diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h index 0c5e00a1f2e6a3303556ec54d8e50e8398644bf5..0535522374c63459d029c252ebe94628cf3122d5 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.h +++ b/tensorflow/contrib/lite/simple_memory_arena.h @@ -36,9 +36,9 @@ struct ArenaAlloc { } }; -// This small class is responsible for allocating, dealocating and reusing +// This small class is responsible for allocating, deallocating and reusing // dynamic memory from a common underlying buffer. The arena can be used in -// scenarios when the pattern of memory allocations and dealocations is +// scenarios when the pattern of memory allocations and deallocations is // repetitive, e.g. running NN inference in multiple iterations. class SimpleMemoryArena { public: diff --git a/tensorflow/contrib/lite/testdata/multi_add.pb b/tensorflow/contrib/lite/testdata/multi_add.pb new file mode 100644 index 0000000000000000000000000000000000000000..e95a20841fb2b320bd77994d9dda157d79311dd6 --- /dev/null +++ b/tensorflow/contrib/lite/testdata/multi_add.pb @@ -0,0 +1,26 @@ + +I +a Placeholder" /device:CPU:0* +shape:* +dtype0 +I +b Placeholder" /device:CPU:0* +dtype0* +shape: +I +c Placeholder" /device:CPU:0* +dtype0* +shape: +I +d Placeholder" /device:CPU:0* +dtype0* +shape: +& +iAddbc" /device:CPU:0* +T0 +& +xAddai" /device:CPU:0* +T0 +& +yAdddi" /device:CPU:0* +T0" \ No newline at end of file diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 50e8ca75f8efd600d4773b83cd2c8de11c9d13ca..06570ae9aa3d10c3cb73ab362e30244ec0b78a35 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -25,6 +25,7 @@ gen_zipped_test_files( "conv.zip", "depthwiseconv.zip", "div.zip", + "exp.zip", "fully_connected.zip", "fused_batch_norm.zip", "gather.zip", @@ -45,9 +46,11 @@ gen_zipped_test_files( "softmax.zip", "space_to_batch_nd.zip", "space_to_depth.zip", + "split.zip", "squeeze.zip", "strided_slice.zip", "sub.zip", + "topk.zip", "transpose.zip", ], ) @@ -121,6 +124,21 @@ cc_test( ], ) +cc_library( + name = "join", + hdrs = ["join.h"], +) + +cc_test( + name = "join_test", + size = "small", + srcs = ["join_test.cc"], + deps = [ + ":join", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "tflite_driver", srcs = ["tflite_driver.cc"], @@ -195,9 +213,36 @@ cc_binary( ], ) +cc_library( + name = "tf_driver", + srcs = ["tf_driver.cc"], + hdrs = ["tf_driver.h"], + deps = [ + ":join", + ":split", + ":test_runner", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +cc_test( + name = "tf_driver_test", + size = "small", + srcs = ["tf_driver_test.cc"], + data = ["//tensorflow/contrib/lite:testdata/multi_add.pb"], + deps = [ + ":tf_driver", + "@com_google_googletest//:gtest_main", + ], +) + tf_cc_test( name = "generated_examples_zip_test", - size = "medium", + size = "large", srcs = ["generated_examples_zip_test.cc"], args = [ "--zip_files_dir=tensorflow/contrib/lite/testing/optest", @@ -206,7 +251,7 @@ tf_cc_test( "--unzip_binary_path=/usr/bin/unzip", ], data = [":optest"], - shard_count = 10, + shard_count = 20, tags = ["no_oss"], deps = [ ":parse_testdata_lib", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index a639351657835a1e7d17466e70277e8bf40bc0f9..b6c09306d6adb8e54d5108dac850f0249ffcb838 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -94,12 +94,17 @@ KNOWN_BUGS = { r"softmax.*input_shape=\[1,3,4,3\]": "67749831", # SpaceToDepth only supports float32. r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", - # BatchToSpaceND doesn't support cropping. + # BatchToSpaceND doesn't support cropping. This catches test cases with + # const tensors as crops. r"batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\]": "70594634", # BatchToSpaceND only supports 4D tensors. r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", - # Div will use floordiv - r"div.*int32": "72051395" + # Div will use floordiv. + r"div.*int32": "72051395", + # TOCO require matching dimensions in strided_slice. + r"strided_slice.*begin=\[0\].*end=\[1\].*": "73170889", + # No support for SplitV + r"split.*num_or_size_splits=\[2,2\]": "73377559", } @@ -240,7 +245,7 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): if dtype in (tf.float32, tf.float16): value = (max_value-min_value)*np.random.random_sample(shape)+min_value elif dtype in (tf.int32, tf.uint8, tf.int64): - value = np.random.random_integers(min_value, max_value, shape) + value = np.random.randint(min_value, max_value+1, shape) return value.astype(dtype) @@ -326,6 +331,12 @@ def toco_convert(graph_def_str, input_tensors, output_tensors, return (None if exit_code != 0 else output_file.read()), log +def normalize_output_name(output_name): + """Remove :0 suffix from tensor names.""" + return output_name.split(":")[0] if output_name.endswith( + ":0") else output_name + + def make_zip_of_tests(zip_path, test_parameters, make_graph, @@ -414,8 +425,8 @@ def make_zip_of_tests(zip_path, sess.graph_def.SerializeToString(), [(input_tensor.name.split(":")[0], input_tensor.get_shape(), input_tensor.dtype) for input_tensor in inputs], - [out.name.split(":")[0] - for out in outputs], drop_control_dependency) + [normalize_output_name(out.name) for out in outputs], + drop_control_dependency) report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None else report_lib.FAILED) report["toco_log"] = toco_log @@ -618,7 +629,7 @@ def make_constant_tests(zip_path): def build_graph(parameters): # Since Toco & Tflite can't have a single constant op in the entire graph, - # this test adds a zero tesnor with a constant op tensor. + # this test adds a zero tensor with a constant op tensor. input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", shape=parameters["input_shape"]) out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1 @@ -694,6 +705,7 @@ def make_mean_tests(zip_path): [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] ], + "const_axis": [True, False], "keep_dims": [True, False], }, { "input_dtype": [tf.float32, tf.int32, tf.int64], @@ -704,6 +716,7 @@ def make_mean_tests(zip_path): -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], [2, 2, 3], [-3, -3, -4], [-3, 2, 1] ], + "const_axis": [True, False], "keep_dims": [True, False], }] @@ -713,17 +726,59 @@ def make_mean_tests(zip_path): dtype=parameters["input_dtype"], name="input", shape=parameters["input_shape"]) + + # Get axis as either a placeholder or constants. + if parameters["const_axis"]: + axis = parameters["axis"] + input_tensors = [input_tensor] + else: + if isinstance(parameters["axis"], list): + shape = [len(parameters["axis"])] + else: + shape = [0] # shape for None or integers. + axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) + input_tensors = [input_tensor, axis] + out = tf.reduce_mean( - input_tensor, - axis=parameters["axis"], - keep_dims=parameters["keep_dims"]) + input_tensor, axis=axis, keep_dims=parameters["keep_dims"]) + return input_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) + ] + if not parameters["const_axis"]: + if parameters["axis"]: + values.append(np.array(parameters["axis"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_exp_tests(zip_path): + """Make a set of tests to do exp.""" + + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + }] + + def build_graph(parameters): + """Build the exp op testing graph.""" + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + + out = tf.exp(input_tensor) return [input_tensor], [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["input_dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["input_dtype"], parameters["input_shape"], + min_value=-100, max_value=9) + ] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -978,19 +1033,44 @@ def make_depthwiseconv_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_split_tests(zip_path): + """Make a set of tests to do tf.split.""" + + test_parameters = [{ + "input_shape": [[1, 3, 4, 6], [2, 4, 1], [6, 4], [8]], + "num_or_size_splits": [1, 2, 3, 4, 5, [2, 2]], + "axis": [0, 1, 2, 3, -4, -3, -2, -1], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = tf.split( + input_tensor, parameters["num_or_size_splits"], parameters["axis"]) + return [input_tensor], out + + def build_inputs(parameters, sess, inputs, outputs): + values = [create_tensor_data(np.float32, parameters["input_shape"])] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_concatenation_tests(zip_path): - """Make a set of tests to do concatenatinon.""" + """Make a set of tests to do concatenation.""" test_parameters = [{ "base_shape": [[1, 3, 4, 3], [3, 4]], "num_tensors": [1, 2, 3, 4, 5, 6], - "axis": [0, 1, 2, 3], + "axis": [0, 1, 2, 3, -3, -2, -1], }] def get_shape(parameters, delta): """Return a tweaked version of 'base_shape'.""" axis = parameters["axis"] shape = parameters["base_shape"][:] + if axis < 0: + axis += len(shape) if axis < len(shape): shape[axis] += delta return shape @@ -1318,12 +1398,16 @@ def make_space_to_batch_nd_tests(zip_path): "input_shape": [[1, 2, 2, 3], [2, 2, 4, 1]], "block_shape": [[1, 3], [2, 2]], "paddings": [[[0, 0], [0, 0]], [[0, 0], [2, 0]], [[1, 1], [1, 1]]], + "constant_block_shape": [True, False], + "constant_paddings": [True, False], }, { "dtype": [tf.float32], "input_shape": [[2, 3, 7, 3]], "block_shape": [[1, 3], [2, 2]], "paddings": [[[0, 0], [2, 0]], [[1, 0], [1, 0]]], + "constant_block_shape": [True, False], + "constant_paddings": [True, False], }, # Non-4D use case: 1 bath dimension, 3 spatial dimensions, 2 others. { @@ -1331,23 +1415,47 @@ def make_space_to_batch_nd_tests(zip_path): "input_shape": [[1, 4, 4, 4, 1, 1]], "block_shape": [[2, 2, 2]], "paddings": [[[0, 0], [0, 0], [0, 0]]], + "constant_block_shape": [True, False], + "constant_paddings": [True, False], }, ] def build_graph(parameters): + """Build a space_to_batch graph given `parameters`.""" input_tensor = tf.placeholder( dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - out = tf.space_to_batch_nd(input_tensor, parameters["block_shape"], - parameters["paddings"]) - return [input_tensor], [out] + input_tensors = [input_tensor] + + # Get block_shape either as a const or as a placeholder (tensor). + if parameters["constant_block_shape"]: + block_shape = parameters["block_shape"] + else: + shape = [len(parameters["block_shape"])] + block_shape = tf.placeholder(dtype=tf.int32, name="shape", shape=shape) + input_tensors.append(block_shape) + + # Get paddings either as a const or as a placeholder (tensor). + if parameters["constant_paddings"]: + paddings = parameters["paddings"] + else: + shape = [len(parameters["paddings"]), 2] + paddings = tf.placeholder(dtype=tf.int32, name="paddings", shape=shape) + input_tensors.append(paddings) + + out = tf.space_to_batch_nd(input_tensor, block_shape, paddings) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_block_shape"]: + values.append(np.array(parameters["block_shape"])) + if not parameters["constant_paddings"]: + values.append(np.array(parameters["paddings"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -1361,6 +1469,8 @@ def make_batch_to_space_nd_tests(zip_path): "input_shape": [[12, 2, 2, 1]], "block_shape": [[1, 4], [2, 2], [3, 4]], "crops": [[[0, 0], [0, 0]], [[1, 1], [1, 1]]], + "constant_block_shape": [True, False], + "constant_crops": [True, False], }, # Non-4D use case: 1 bath dimension, 3 spatial dimensions, 2 others. { @@ -1368,23 +1478,47 @@ def make_batch_to_space_nd_tests(zip_path): "input_shape": [[8, 2, 2, 2, 1, 1]], "block_shape": [[2, 2, 2]], "crops": [[[0, 0], [0, 0], [0, 0]]], + "constant_block_shape": [True, False], + "constant_crops": [True, False], }, ] def build_graph(parameters): + """Build a batch_to_space graph given `parameters`.""" input_tensor = tf.placeholder( dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - out = tf.batch_to_space_nd(input_tensor, parameters["block_shape"], - parameters["crops"]) - return [input_tensor], [out] + input_tensors = [input_tensor] + + # Get block_shape either as a const or as a placeholder (tensor). + if parameters["constant_block_shape"]: + block_shape = parameters["block_shape"] + else: + shape = [len(parameters["block_shape"])] + block_shape = tf.placeholder(dtype=tf.int32, name="shape", shape=shape) + input_tensors.append(block_shape) + + # Get crops either as a const or as a placeholder (tensor). + if parameters["constant_crops"]: + crops = parameters["crops"] + else: + shape = [len(parameters["crops"]), 2] + crops = tf.placeholder(dtype=tf.int32, name="crops", shape=shape) + input_tensors.append(crops) + + out = tf.batch_to_space_nd(input_tensor, block_shape, crops) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_block_shape"]: + values.append(np.array(parameters["block_shape"])) + if not parameters["constant_crops"]: + values.append(np.array(parameters["crops"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -1397,29 +1531,44 @@ def make_transpose_tests(zip_path): "dtype": [tf.int32, tf.int64, tf.float32], "input_shape": [[2, 2, 3]], "perm": [[0, 1, 2], [0, 2, 1]], + "constant_perm": [True, False], }, { "dtype": [tf.float32], "input_shape": [[1, 2, 3, 4]], "perm": [[0, 1, 2, 3], [3, 0, 1, 2]], + "constant_perm": [True, False], }, { "dtype": [tf.float32], "input_shape": [[1, 2, 3, 4, 5]], "perm": [[0, 1, 2, 3, 4]], + "constant_perm": [True, False], }] def build_graph(parameters): + """Build a transpose graph given `parameters`.""" input_tensor = tf.placeholder( dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - out = tf.transpose(input_tensor, perm=parameters["perm"]) - return [input_tensor], [out] + + if parameters["constant_perm"]: + perm = parameters["perm"] + input_tensors = [input_tensor] + else: + shape = [len(parameters["perm"]), 2] + perm = tf.placeholder(dtype=tf.int32, name="perm", shape=shape) + input_tensors = [input_tensor, perm] + + out = tf.transpose(input_tensor, perm=perm) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_perm"]: + values.append(np.array(parameters["perm"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -1474,9 +1623,24 @@ def make_strided_slice_tests(zip_path): "input_shape": [[12, 2, 2, 5]], "begin": [[0, 0, 0, 0], [1, 0, 1, 0]], "end": [[8, 2, 2, 3], [12, 2, 2, 5]], - "strides": [None, [1, 1, 1, 1], [2, 1, 3, 1]], - "begin_mask": [None, 0, 1, 2, 8], - "end_mask": [None, 0, 1, 2, 8], + "strides": [None, [2, 1, 3, 1]], + "begin_mask": [None, 1, 8], + "end_mask": [None, 1, 8], + "shrink_axis_mask": [None, 1, 8, 11, 15, -1], + "constant_indices": [False, True], + }, + # + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0]], + "end": [[1]], + "strides": [[1]], + "begin_mask": [0], + "end_mask": [0], + "shrink_axis_mask": [1], + "constant_indices": [True], }, # 2-D { @@ -1485,20 +1649,24 @@ def make_strided_slice_tests(zip_path): "input_shape": [[2, 3]], "begin": [[0, 0], [1, 0]], "end": [[2, 3], [2, 2]], - "strides": [None, [1, 1], [2, 2]], - "begin_mask": [None, 0, 1, 2], - "end_mask": [None, 0, 1, 2], + "strides": [None, [2, 2]], + "begin_mask": [None, 1, 2], + "end_mask": [None, 1, 2], + "shrink_axis_mask": [None, 1, 2, 3, -1], + "constant_indices": [False, True], }, # Negative strides { - "dtype": [tf.float32, tf.int32, tf.int64], + "dtype": [tf.float32], "index_type": [tf.int32], "input_shape": [[2, 3]], "begin": [[0, -1]], "end": [[2, -3]], "strides": [[1, -1]], - "begin_mask": [None, 0, 1, 2], - "end_mask": [None, 0, 1, 2], + "begin_mask": [None, 1, 2], + "end_mask": [None, 1, 2], + "shrink_axis_mask": [None, 1, 2, 3, -1], + "constant_indices": [False], }, ] @@ -1508,23 +1676,29 @@ def make_strided_slice_tests(zip_path): dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - begin = tf.placeholder( - dtype=parameters["index_type"], - name="begin", - shape=[len(parameters["input_shape"])]) - end = tf.placeholder( - dtype=parameters["index_type"], - name="end", - shape=[len(parameters["input_shape"])]) - strides = ( - tf.placeholder( - dtype=parameters["index_type"], - name="strides", - shape=[len(parameters["input_shape"])]) - if parameters["strides"] is not None else None) - tensors = [input_tensor, begin, end] - if strides is not None: - tensors.append(strides) + if parameters["constant_indices"]: + begin = parameters["begin"] + end = parameters["end"] + strides = parameters["strides"] + tensors = [input_tensor] + else: + begin = tf.placeholder( + dtype=parameters["index_type"], + name="begin", + shape=[len(parameters["input_shape"])]) + end = tf.placeholder( + dtype=parameters["index_type"], + name="end", + shape=[len(parameters["input_shape"])]) + strides = ( + tf.placeholder( + dtype=parameters["index_type"], + name="strides", + shape=[len(parameters["input_shape"])]) + if parameters["strides"] is not None else None) + tensors = [input_tensor, begin, end] + if strides is not None: + tensors.append(strides) out = tf.strided_slice( input_tensor, begin, @@ -1539,14 +1713,17 @@ def make_strided_slice_tests(zip_path): input_values = create_tensor_data(parameters["dtype"], parameters["input_shape"]) index_type = _TF_TYPE_INFO[parameters["index_type"]][0] - begin_values = np.array(parameters["begin"]).astype(index_type) - end_values = np.array(parameters["end"]).astype(index_type) - stride_values = ( - np.array(parameters["strides"]).astype(index_type) - if parameters["strides"] is not None else None) - values = [input_values, begin_values, end_values] - if stride_values is not None: - values.append(stride_values) + values = [input_values] + if not parameters["constant_indices"]: + begin_values = np.array(parameters["begin"]).astype(index_type) + end_values = np.array(parameters["end"]).astype(index_type) + stride_values = ( + np.array(parameters["strides"]).astype(index_type) + if parameters["strides"] is not None else None) + values.append(begin_values) + values.append(end_values) + if stride_values is not None: + values.append(stride_values) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) @@ -1560,6 +1737,32 @@ def make_l2_pool(input_tensor, ksize, strides, padding, data_format): padding=padding, data_format=data_format)) +def make_topk_tests(zip_path): + """Make a set of tests to do gather.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[10], [5, 20]], + }] + + def build_graph(parameters): + """Build the gather op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + k = tf.constant(3, name="k") + out = tf.nn.top_k(input_value, k) + return [input_value], [out[1]] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + # Toco binary path provided by the generate rule. bin_path = None @@ -1608,10 +1811,13 @@ def main(unused_args): "sigmoid.zip": make_sigmoid_tests, "softmax.zip": make_softmax_tests, "space_to_depth.zip": make_space_to_depth_tests, + "topk.zip": make_topk_tests, + "split.zip": make_split_tests, "transpose.zip": make_transpose_tests, "mean.zip": make_mean_tests, "squeeze.zip": make_squeeze_tests, "strided_slice.zip": make_strided_slice_tests, + "exp.zip": make_exp_tests, } out = FLAGS.zip_to_output bin_path = FLAGS.toco diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 41652a07d21fbf022cb66a4022706cfee02d2c09..49766cedac8d1acd96f9b38665119e99f8bb9ac0 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -47,9 +47,7 @@ tensorflow::Env* env = tensorflow::Env::Default(); // Key is a substring of the test name and value is a bug number. // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { - // Add doesn't support broadcasting. - {R"(^\/adda.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, - {R"(^\/mula.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + // Sub and Div don't support broadcasting. {R"(^\/diva.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, {R"(^\/suba.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, @@ -67,7 +65,11 @@ std::map kBrokenTests = { // L2Norm only supports tensors with 4D or fewer. {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, - // SpaceToBatch only supports 4D tensors. + // BatchToSpaceND doesn't support cropping. This catches test cases with + // non-const tensors as crops. + {R"(^\/batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\])", "70594634"}, + + // SpaceToBatchND only supports 4D tensors. {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, // L2Norm only works for dim=-1. @@ -87,12 +89,9 @@ std::map kBrokenTests = { // ResizeBilinear looks completely incompatible with Tensorflow {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"}, - {R"(^\/resize_bilinearalign_corners=True,.*,size=\[2,2\])", "72401483"}, - {R"(^\/resize_bilinearalign_corners=True,.*,size=\[4,3\])", "72401483"}, - {R"(^\/resize_bilinearalign_corners=True,.*,size=\[5,6\])", "72401483"}, // Transpose only supports 1D-4D input tensors. - {R"(^\/transposedtype=.*,input_shape=\[.,.,.,.,.\],perm=.*)", "71545879"}, + {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, }; // Allows test data to be unzipped into a temporary directory and makes @@ -239,11 +238,11 @@ INSTANTIATE_TESTS(avg_pool) INSTANTIATE_TESTS(space_to_batch_nd) INSTANTIATE_TESTS(batch_to_space_nd) INSTANTIATE_TESTS(concat) -// TODO(b/71642435) re-enable this test -// INSTANTIATE_TESTS(constant) +INSTANTIATE_TESTS(constant) INSTANTIATE_TESTS(control_dep) INSTANTIATE_TESTS(conv) INSTANTIATE_TESTS(depthwiseconv) +INSTANTIATE_TESTS(exp) INSTANTIATE_TESTS(fully_connected) INSTANTIATE_TESTS(fused_batch_norm) INSTANTIATE_TESTS(gather) @@ -263,6 +262,7 @@ INSTANTIATE_TESTS(sigmoid) INSTANTIATE_TESTS(softmax) INSTANTIATE_TESTS(space_to_depth) INSTANTIATE_TESTS(sub) +INSTANTIATE_TESTS(split) INSTANTIATE_TESTS(div) INSTANTIATE_TESTS(transpose) INSTANTIATE_TESTS(mean) diff --git a/tensorflow/contrib/lite/testing/join.h b/tensorflow/contrib/lite/testing/join.h new file mode 100644 index 0000000000000000000000000000000000000000..ce8c072a21c6e61e8ab8ae12ba52418e6144009a --- /dev/null +++ b/tensorflow/contrib/lite/testing/join.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ + +#include +#include +#include + +namespace tflite { +namespace testing { + +// Join a list of data separated by delimieter. +template +string Join(T* data, size_t len, const string& delimiter) { + if (len == 0 || data == nullptr) { + return ""; + } + std::stringstream result; + result << data[0]; + for (int i = 1; i < len; i++) { + result << delimiter << data[i]; + } + return result.str(); +} + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ diff --git a/tensorflow/contrib/lite/testing/join_test.cc b/tensorflow/contrib/lite/testing/join_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd04528381f6d31164728a5cabbf8753e9b8d2b8 --- /dev/null +++ b/tensorflow/contrib/lite/testing/join_test.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/join.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +TEST(JoinTest, JoinInt) { + std::vector data = {1, 2, 3}; + EXPECT_EQ(Join(data.data(), data.size(), ","), "1,2,3"); +} + +TEST(JoinTest, JoinFloat) { + float data[] = {1.0, -3, 2.3, 1e-5}; + EXPECT_EQ(Join(data, 4, " "), "1 -3 2.3 1e-05"); +} + +TEST(JoinTest, JoinNullData) { EXPECT_THAT(Join(nullptr, 3, ","), ""); } + +TEST(JoinTest, JoinZeroData) { + std::vector data; + EXPECT_THAT(Join(data.data(), 0, ","), ""); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h index 60eaafa474a01887bee12b031b1f59cc5c91f173..05770beee23275ebe210606dbfd2b33eea17612d 100644 --- a/tensorflow/contrib/lite/testing/test_runner.h +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -68,6 +68,10 @@ class TestRunner { // satisfied. virtual bool CheckResults() = 0; + // Read contents of tensor into csv format. + // The given 'id' is guaranteed to be one of the ids returned by GetOutputs(). + virtual string ReadOutput(int id) = 0; + // Set the base path for loading models. void SetModelBaseDir(const string& path) { model_base_dir_ = path; diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/contrib/lite/testing/test_runner_test.cc index f712a5347a042990ae5adb9d44325dd683193168..3f04aa20bd7de813f0acd3f5897d5ab2df6c0fd7 100644 --- a/tensorflow/contrib/lite/testing/test_runner_test.cc +++ b/tensorflow/contrib/lite/testing/test_runner_test.cc @@ -31,6 +31,7 @@ class ConcreteTestRunner : public TestRunner { void ResetTensor(int id) override {} void SetInput(int id, const string& csv_values) override {} void SetExpectation(int id, const string& csv_values) override {} + string ReadOutput(int id) override { return ""; } void Invoke() override {} bool CheckResults() override { return true; } bool CheckFloatSizes(size_t bytes, size_t values) { diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c253bb1983e5ddc5bc12858c929585d1bcee710 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tf_driver.cc @@ -0,0 +1,182 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tf_driver.h" + +#include +#include + +#include "tensorflow/contrib/lite/testing/join.h" +#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tflite { +namespace testing { + +namespace { + +tensorflow::Tensor CreateTensor(const tensorflow::DataType type, + const std::vector& dim) { + tensorflow::TensorShape shape{gtl::ArraySlice{ + reinterpret_cast(dim.data()), dim.size()}}; + return {type, shape}; +} + +template +void FillTensorWithData(tensorflow::Tensor* tensor, const string& csv_values) { + auto data = tensor->flat(); + + const auto& values = testing::Split(csv_values, ","); + for (int i = 0; i < values.size(); i++) { + data(i) = values[i]; + } +} + +template +void FillTensorWithZeros(tensorflow::Tensor* tensor) { + auto data = tensor->flat(); + for (int i = 0; i < tensor->NumElements(); i++) { + data(i) = 0; + } +} + +template +string TensorDataToCsvString(const tensorflow::Tensor& tensor) { + const auto& data = tensor.flat(); + return Join(data.data(), data.size(), ","); +} + +} // namespace + +TfDriver::TfDriver(const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer) + : input_names_(input_layer), output_names_(output_layer) { + CHECK_EQ(input_layer.size(), input_layer_type.size()); + CHECK_EQ(input_layer.size(), input_layer_shape.size()); + + input_ids_.resize(input_layer.size()); + input_tensors_.reserve(input_layer.size()); + input_types_.resize(input_layer.size()); + input_shapes_.resize(input_layer.size()); + for (int i = 0; i < input_layer.size(); i++) { + input_ids_[i] = i; + input_tensors_[input_layer[i]] = {}; + CHECK(DataTypeFromString(input_layer_type[i], &input_types_[i])); + input_shapes_[i] = Split(input_layer_shape[i], ","); + } + + output_ids_.resize(output_layer.size()); + output_tensors_.reserve(output_layer.size()); + for (int i = 0; i < output_layer.size(); i++) { + output_ids_[i] = i; + } +} + +void TfDriver::LoadModel(const string& bin_file_path) { + if (!IsValid()) return; + std::cout << std::endl << "Loading model: " << bin_file_path << std::endl; + std::ifstream model(bin_file_path); + if (model.fail()) { + Invalidate("Failed to find the model"); + return; + } + + tensorflow::GraphDef graphdef; + if (!graphdef.ParseFromIstream(&model)) { + Invalidate("Failed to parse tensorflow graphdef"); + return; + } + + tensorflow::SessionOptions options; + session_.reset(tensorflow::NewSession(options)); + auto status = session_->Create(graphdef); + if (!status.ok()) { + Invalidate("Failed to create session"); + } +} + +void TfDriver::SetInput(int id, const string& csv_values) { + if (!IsValid()) return; + + auto tensor = CreateTensor(input_types_[id], input_shapes_[id]); + switch (input_types_[id]) { + case tensorflow::DT_FLOAT: { + FillTensorWithData(&tensor, csv_values); + break; + } + case tensorflow::DT_INT32: { + FillTensorWithData(&tensor, csv_values); + break; + } + default: + fprintf(stderr, "Unsupported type %d in SetInput\n", input_types_[id]); + Invalidate("Unsupported tensor data type"); + return; + } + input_tensors_[input_names_[id]] = tensor; +} + +void TfDriver::ResetTensor(int id) { + if (!IsValid()) return; + auto tensor = input_tensors_[input_names_[id]]; + switch (input_types_[id]) { + case tensorflow::DT_FLOAT: { + FillTensorWithZeros(&tensor); + break; + } + case tensorflow::DT_INT32: { + FillTensorWithZeros(&tensor); + break; + } + default: + fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]); + Invalidate("Unsupported tensor data type"); + return; + } +} + +void TfDriver::ReshapeTensor(int id, const string& csv_values) { + input_shapes_[id] = Split(csv_values, ","); + input_tensors_[input_names_[id]] = + CreateTensor(input_types_[id], input_shapes_[id]); + ResetTensor(id); +} + +string TfDriver::ReadOutput(int id) { + if (!IsValid()) return ""; + switch (output_tensors_[id].dtype()) { + case tensorflow::DT_FLOAT: + return TensorDataToCsvString(output_tensors_[id]); + case tensorflow::DT_INT32: + return TensorDataToCsvString(output_tensors_[id]); + default: + fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]); + Invalidate("Unsupported tensor data type"); + return ""; + } +} + +void TfDriver::Invoke() { + if (!IsValid()) return; + auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()}, + output_names_, {}, &output_tensors_); + if (!status.ok()) { + Invalidate("Failed to invoke interpreter"); + } +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tf_driver.h b/tensorflow/contrib/lite/testing/tf_driver.h new file mode 100644 index 0000000000000000000000000000000000000000..b766f85c4ddee9fb7b1513c264d4159e694770ca --- /dev/null +++ b/tensorflow/contrib/lite/testing/tf_driver.h @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ + +#include +#include + +#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/contrib/lite/testing/test_runner.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session.h" + +namespace tflite { +namespace testing { + +// A test runner that feeds inputs into Tensorflow and generates outputs. +class TfDriver : public TestRunner { + public: + explicit TfDriver(const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer); + ~TfDriver() override {} + + void LoadModel(const string& bin_file_path) override; + void SetInput(int id, const string& csv_values) override; + void Invoke() override; + string ReadOutput(int id) override; + + const std::vector& GetInputs() override { return input_ids_; } + const std::vector& GetOutputs() override { return output_ids_; } + void ReshapeTensor(int id, const string& csv_values) override; + // Note: ResetTensor only works for input tensor. + void ResetTensor(int id) override; + + // no-op. SetInput will overwrite existing data . + void AllocateTensors() override {} + // no-op. Tf driver is not supposed to check the results. + void SetExpectation(int id, const string& csv_values) override {} + // tf driver is not supposed to check the results. + bool CheckResults() override { return false; } + + private: + std::unique_ptr session_; + std::vector input_ids_; + std::vector input_names_; + std::vector> input_shapes_; + std::vector input_types_; + std::unordered_map input_tensors_; + + std::vector output_ids_; + std::vector output_names_; + std::vector<::tensorflow::Tensor> output_tensors_; +}; + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ diff --git a/tensorflow/contrib/lite/testing/tf_driver_test.cc b/tensorflow/contrib/lite/testing/tf_driver_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c0faa4676adc3e846ad398bb203b77b99a2ba360 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tf_driver_test.cc @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tf_driver.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; + +TEST(TfDriverTest, SimpleTest) { + std::unique_ptr runner( + new TfDriver({"a", "b", "c", "d"}, {"float", "float", "float", "float"}, + {"1,8,8,3", "1,8,8,3", "1,8,8,3", "1,8,8,3"}, {"x", "y"})); + + runner->LoadModel( + "third_party/tensorflow/contrib/lite/testdata/multi_add.pb"); + EXPECT_TRUE(runner->IsValid()) << runner->GetErrorMessage(); + + ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3)); + ASSERT_THAT(runner->GetOutputs(), ElementsAre(0, 1)); + + for (int i : {0, 1, 2, 3}) { + runner->ReshapeTensor(i, "1,2,2,1"); + } + ASSERT_TRUE(runner->IsValid()); + + runner->SetInput(0, "0.1,0.2,0.3,0.4"); + runner->SetInput(1, "0.001,0.002,0.003,0.004"); + runner->SetInput(2, "0.001,0.002,0.003,0.004"); + runner->SetInput(3, "0.01,0.02,0.03,0.04"); + runner->ResetTensor(2); + runner->Invoke(); + + ASSERT_EQ(runner->ReadOutput(0), "0.101,0.202,0.303,0.404"); + ASSERT_EQ(runner->ReadOutput(1), "0.011,0.022,0.033,0.044"); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index bae639ea95318a16c963269de5e55afcb681d4c5..613223f3d4ff212cb8672494243b2d7a1d06b3db 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -106,8 +106,8 @@ class TfLiteDriver::Expectation { if (error_is_large) { good_output = false; if (verbose) { - std::cerr << " index " << i << ": " << reference - << " != " << computed << std::endl; + std::cerr << " index " << i << ": got " << computed + << ", but expected " << reference << std::endl; } } } @@ -203,6 +203,10 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) { void TfLiteDriver::SetExpectation(int id, const string& csv_values) { if (!IsValid()) return; auto* tensor = interpreter_->tensor(id); + if (expected_output_.count(id) != 0) { + fprintf(stderr, "Overriden expectation for tensor %d\n", id); + Invalidate("Overriden expectation"); + } expected_output_[id].reset(new Expectation); switch (tensor->type) { case kTfLiteFloat32: diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h index 25689a9fb42c06fa3f8f2f92064cf59e8c331637..02b7de1534e648734d7bc53154afa42f2ef256b4 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.h +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -45,6 +45,7 @@ class TfLiteDriver : public TestRunner { void SetExpectation(int id, const string& csv_values) override; void Invoke() override; bool CheckResults() override; + string ReadOutput(int id) override { return "no-op"; } private: class Expectation; diff --git a/tensorflow/contrib/lite/tflite_static.bp b/tensorflow/contrib/lite/tflite_static.bp index 3db884de660475e995040da40b8cb3ee67de0fba..9b7d0adbbfa691e326abd73542c1c422cbb30ec1 100644 --- a/tensorflow/contrib/lite/tflite_static.bp +++ b/tensorflow/contrib/lite/tflite_static.bp @@ -22,6 +22,7 @@ cc_library_static { "arena_planner.cc", "context.c", "error_reporter.cc", + "graph_info.cc", "interpreter.cc", "model.cc", "nnapi_delegate.cc", @@ -32,10 +33,12 @@ cc_library_static { "kernels/add.cc", "kernels/basic_rnn.cc", "kernels/batch_to_space_nd.cc", + "kernels/bidirectional_sequence_rnn.cc", "kernels/concatenation.cc", "kernels/conv.cc", "kernels/depthwise_conv.cc", "kernels/div.cc", + "kernels/exp.cc", "kernels/embedding_lookup.cc", "kernels/embedding_lookup_sparse.cc", "kernels/fully_connected.cc", @@ -57,13 +60,16 @@ cc_library_static { "kernels/skip_gram.cc", "kernels/space_to_batch_nd.cc", "kernels/space_to_depth.cc", + "kernels/split.cc", "kernels/squeeze.cc", "kernels/strided_slice.cc", "kernels/sub.cc", "kernels/svdf.cc", + "kernels/topk_v2.cc", "kernels/transpose.cc", "kernels/unidirectional_sequence_lstm.cc", "kernels/unidirectional_sequence_rnn.cc", + "kernels/internal/kernel_utils.cc", "kernels/internal/tensor_utils.cc", "kernels/internal/quantization_util.cc", "kernels/internal/reference/portable_tensor_utils.cc", @@ -90,6 +96,7 @@ cc_library_static { "-Wno-mismatched-tags", "-Wno-missing-field-initializers", "-Wno-sign-compare", + "-Wno-typedef-redefinition", "-Wno-unused-lambda-capture", "-Wno-unused-parameter", "-Wno-unused-variable", diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 041e2487903c63572a7acda17f2f3ebc701be0c7..e2879fad327d965799d84da5c9092a12a36aa65b 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -160,6 +160,7 @@ cc_library( ], deps = [ # Placeholder for internal file dependency. + "@protobuf_archive//:protobuf_headers", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -173,6 +174,7 @@ cc_library( "graph_transformations/convert_pure_conv_to_depthwise.cc", "graph_transformations/convert_reorder_axes.cc", "graph_transformations/convert_trivial_addn_to_add.cc", + "graph_transformations/convert_trivial_stack_to_reshape.cc", "graph_transformations/convert_trivial_transpose_to_reshape.cc", "graph_transformations/create_im2col_arrays.cc", "graph_transformations/dequantize.cc", @@ -187,7 +189,10 @@ cc_library( "graph_transformations/identify_l2_normalization.cc", "graph_transformations/identify_l2_pool.cc", "graph_transformations/identify_lstm.cc", + "graph_transformations/identify_lstm_merge_inputs.cc", + "graph_transformations/identify_lstm_split_inputs.cc", "graph_transformations/identify_relu1.cc", + "graph_transformations/lstm_utils.cc", "graph_transformations/make_initial_dequantize_operator.cc", "graph_transformations/propagate_array_data_types.cc", "graph_transformations/propagate_fixed_sizes.cc", @@ -203,7 +208,9 @@ cc_library( "graph_transformations/remove_trivial_passthrough.h", "graph_transformations/remove_trivial_quantized_activation_func.cc", "graph_transformations/remove_trivial_reshape.cc", + "graph_transformations/remove_trivial_slice.cc", "graph_transformations/remove_unused_op.cc", + "graph_transformations/reorder_activation_functions.cc", "graph_transformations/resolve_batch_normalization.cc", "graph_transformations/resolve_batch_to_space_nd_attributes.cc", "graph_transformations/resolve_constant_binary.cc", @@ -214,8 +221,10 @@ cc_library( "graph_transformations/resolve_constant_shape_or_rank.cc", "graph_transformations/resolve_constant_stack.cc", "graph_transformations/resolve_constant_strided_slice.cc", + "graph_transformations/resolve_constant_transpose.cc", "graph_transformations/resolve_constant_unary.cc", "graph_transformations/resolve_mean_attributes.cc", + "graph_transformations/resolve_multiply_by_zero.cc", "graph_transformations/resolve_pad_attributes.cc", "graph_transformations/resolve_reorder_axes.cc", "graph_transformations/resolve_reshape_attributes.cc", @@ -230,9 +239,11 @@ cc_library( "graph_transformations/resolve_tensorflow_tile.cc", "graph_transformations/resolve_transpose_attributes.cc", "graph_transformations/unfuse_activation_functions.cc", + "graph_transformations/unroll_batch_matmul.cc", ], hdrs = [ "graph_transformations/graph_transformations.h", + "graph_transformations/lstm_utils.h", ], visibility = ["//visibility:public"], deps = [ @@ -243,6 +254,7 @@ cc_library( ":tooling_util", ":types_proto_cc", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc index 5961d30bf5403df7fa6228e05124479d118dd279..49cc1fc2aa365925cde86ceb658ff2b354d06911 100644 --- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc @@ -158,9 +158,7 @@ std::size_t TransientArraySize(const Model& model, const string& array_name, LOG(FATAL) << "A RNN state array, " << array_name << ", still does not " << "have a known data type after all graph transformations have " - << "run. That's mostly a toco bug --- sorry. For now, you can " - << "work around this issue by adding manually_create:true in the " - << "--rnn_state description of this RNN state."; + << "run."; } } LOG(FATAL) << "An array, " << array_name << ", still does not " diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 8004a1a37ae48468e9bf22785ec02f8de54bf236..b97a4720a7c4e69f8b69574475d19e0522cfe86d 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -208,6 +208,7 @@ struct ParsedModelFlags { Arg dump_graphviz_video = Arg(false); Arg allow_nonexistent_arrays = Arg(false); Arg allow_nonascii_arrays = Arg(false); + Arg arrays_extra_info_file; }; // Flags that describe the operation you would like to do (what conversion diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 529df3cd2e56f1888f3d431ddcd7dc7051a98355..570cc7943b2926136f6fdb21d20b8aa6acf8cd26 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -46,6 +46,32 @@ using tensorflow::TensorProto; namespace toco { namespace { +tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type) { + switch (data_type) { + case ArrayDataType::kBool: + return tensorflow::DT_BOOL; + case ArrayDataType::kFloat: + return tensorflow::DT_FLOAT; + case ArrayDataType::kUint8: + return tensorflow::DT_UINT8; + case ArrayDataType::kInt32: + return tensorflow::DT_INT32; + case ArrayDataType::kInt64: + return tensorflow::DT_INT64; + case ArrayDataType::kString: + return tensorflow::DT_STRING; + default: + case ArrayDataType::kNone: + LOG(FATAL) << "Unsupported data type: " << static_cast(data_type); + return tensorflow::DT_INVALID; + } +} + +tensorflow::DataType GetTensorFlowDataType(const Model& model, + const string& array_name) { + return GetTensorFlowDataType(model.GetArray(array_name).data_type); +} + // TensorFlow sometimes forbids what it calls "legacy scalars", // which are 1-D shapes where the unique shape size is 1. // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars. @@ -212,6 +238,24 @@ void ConvertIntTensorConst(const Model& model, const string& name, } } +void CreateIntTensorConst(const string& name, const std::vector& data, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + for (auto index : data) { + tensor->add_int_val(index); + } + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(data.size()); +} + void CreateMatrixShapeTensorConst(const string& name, int rows, int cols, GraphDef* tensorflow_graph) { if (HasAlreadyExportedConst(name, *tensorflow_graph)) { @@ -445,14 +489,23 @@ void ConvertSpaceToDepthOperator(const Model& model, void ConvertFullyConnectedOperator(const Model& model, const FullyConnectedOperator& src_op, GraphDef* tensorflow_graph) { - const string reshape_output = src_op.outputs[0] + "/reshape"; - const string reshape_shape = src_op.outputs[0] + "/reshape/shape"; + // Reshape input activations to have the shape expected by the MatMul. + const string reshape_output = + AvailableArrayName(model, src_op.outputs[0] + "/reshape"); + const string reshape_shape = + AvailableArrayName(model, reshape_output + "/shape"); + const auto& fc_weights_array = model.GetArray(src_op.inputs[1]); + const auto& fc_weights_shape = fc_weights_array.shape(); + CHECK_EQ(fc_weights_shape.dimensions_count(), 2); + CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, + tensorflow_graph); auto* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); reshape_op->add_input(src_op.inputs[0]); reshape_op->add_input(reshape_shape); - (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*reshape_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); const bool has_bias = src_op.inputs.size() >= 3; string matmul_output = src_op.outputs[0]; @@ -460,38 +513,43 @@ void ConvertFullyConnectedOperator(const Model& model, matmul_output += "/matmul"; } + // Transpose the RHS input from column-major to row-major to match TensorFlow + // expectations. This is the inverse of the transpose we do during + // ResolveTensorFlowMatMul. + const string transpose_output = + AvailableArrayName(model, matmul_output + "/transpose_weights"); + const string transpose_perm = + AvailableArrayName(model, transpose_output + "/perm"); + CreateIntTensorConst(transpose_perm, {1, 0}, tensorflow_graph); + auto transpose_op = tensorflow_graph->add_node(); + transpose_op->set_op("Transpose"); + transpose_op->set_name(transpose_output); + *transpose_op->add_input() = src_op.inputs[1]; + *transpose_op->add_input() = transpose_perm; + (*transpose_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); + (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32); + auto* matmul_op = tensorflow_graph->add_node(); matmul_op->set_op("MatMul"); - matmul_op->set_name(matmul_output); *matmul_op->add_input() = reshape_output; - *matmul_op->add_input() = src_op.inputs[1]; - (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT); + *matmul_op->add_input() = transpose_op->name(); + (*matmul_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); (*matmul_op->mutable_attr())["transpose_a"].set_b(false); (*matmul_op->mutable_attr())["transpose_b"].set_b(false); CHECK(model.HasArray(src_op.inputs[1])); - const string& fc_weights_name = - WalkUpToConstantArray(model, src_op.inputs[1]); - const auto& fc_weights_array = model.GetArray(fc_weights_name); - const auto& fc_weights_shape = fc_weights_array.shape(); - CHECK_EQ(fc_weights_shape.dimensions_count(), 2); - CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, - tensorflow_graph); - - CHECK(fc_weights_array.buffer); - CHECK(fc_weights_array.buffer->type == ArrayDataType::kFloat); - const float* fc_weights_data = - fc_weights_array.GetBuffer().data.data(); - ConvertFloatTensorConst(fc_weights_name, fc_weights_shape, fc_weights_data, - AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph); + // Add the bias, if it exists. if (has_bias) { auto* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(matmul_output); biasadd_op->add_input(src_op.inputs[2]); - (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*biasadd_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); CHECK(model.HasArray(src_op.inputs[2])); const auto& bias_array = model.GetArray(src_op.inputs[2]); // TODO(b/62904716) Bias arrays should be 1-D, and used directly. @@ -621,7 +679,8 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && + providing_op->type == OperatorType::kTensorFlowReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -656,6 +715,45 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT); } +void ConvertLogSoftmaxOperator(const Model& model, + const LogSoftmaxOperator& src_op, + GraphDef* tensorflow_graph) { + string softmax_input; + Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); + if (providing_op->type == OperatorType::kTensorFlowReshape) { + softmax_input = src_op.inputs[0]; + } else { + // Insert a reshape operator that reduces the dimensions down to the 2 that + // are required for TensorFlow Logits. + const string reshape_output = + src_op.outputs[0] + "/log_softmax_insert_reshape"; + const string softmax_size = src_op.outputs[0] + "/log_softmax_insert_size"; + softmax_input = reshape_output; + + auto* reshape_op = tensorflow_graph->add_node(); + reshape_op->set_op("Reshape"); + reshape_op->set_name(reshape_output); + *reshape_op->add_input() = src_op.inputs[0]; + *reshape_op->add_input() = softmax_size; + (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const auto& input_shape = model.GetArray(src_op.inputs[0]).shape(); + int32 flattened_size = 1; + for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) { + flattened_size *= input_shape.dims(i); + } + const std::vector shape_data = { + flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)}; + CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); + } + + auto* log_softmax_op = tensorflow_graph->add_node(); + log_softmax_op->set_op("LogSoftmax"); + log_softmax_op->set_name(src_op.outputs[0]); + *log_softmax_op->add_input() = softmax_input; + (*log_softmax_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, GraphDef* tensorflow_graph) { const string square_output = src_op.outputs[0] + "/square"; @@ -798,7 +896,8 @@ void ConvertConcatenationOperator(const Model& model, *dc_op->add_input() = input; } *dc_op->add_input() = dummy_axis; - (*dc_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*dc_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32); (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size()); } @@ -812,7 +911,8 @@ void ConvertTensorFlowReshapeOperator(const Model& model, CHECK_EQ(src_op.inputs.size(), 2); *reshape_op->add_input() = src_op.inputs[0]; *reshape_op->add_input() = src_op.inputs[1]; - (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*reshape_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); const auto& shape_array = model.GetArray(src_op.inputs[1]); QCHECK(shape_array.data_type == ArrayDataType::kInt32) << "Only int32 shape is supported."; @@ -909,24 +1009,6 @@ void ConvertSplitOperator(const Model& model, tensorflow_graph); } -tensorflow::DataType GetTensorFlowDataType(const Model& model, - const string& array_name) { - auto& dtype = model.GetArray(array_name).data_type; - CHECK(dtype == ArrayDataType::kFloat || dtype == ArrayDataType::kInt32 || - dtype == ArrayDataType::kUint8 || dtype == ArrayDataType::kInt64); - if (dtype == ArrayDataType::kFloat) { - return tensorflow::DT_FLOAT; - } else if (dtype == ArrayDataType::kInt32) { - return tensorflow::DT_INT32; - } else if (dtype == ArrayDataType::kUint8) { - return tensorflow::DT_UINT8; - } else if (dtype == ArrayDataType::kInt64) { - return tensorflow::DT_INT64; - } else { - LOG(FATAL) << "Wrong data type"; - } -} - void ConvertCastOperator(const Model& model, const CastOperator& src_op, GraphDef* tensorflow_graph) { auto* cast_op = tensorflow_graph->add_node(); @@ -981,6 +1063,113 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, GetTensorFlowDataType(model, src_op.outputs[0])); } +void ConvertTransposeOperator(const Model& model, + const TransposeOperator& src_op, + GraphDef* tensorflow_graph) { + auto* transpose_op = tensorflow_graph->add_node(); + transpose_op->set_op("Transpose"); + transpose_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *transpose_op->add_input() = src_op.inputs[0]; + *transpose_op->add_input() = src_op.inputs[1]; + (*transpose_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*transpose_op->mutable_attr())["Tperm"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); +} + +void ConvertTensorFlowShapeOperator(const Model& model, + const TensorFlowShapeOperator& src_op, + GraphDef* tensorflow_graph) { + auto* shape_op = tensorflow_graph->add_node(); + shape_op->set_op("Shape"); + shape_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *shape_op->add_input() = src_op.inputs[0]; + (*shape_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*shape_op->mutable_attr())["out_type"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); +} + +void ConvertRankOperator(const Model& model, const RankOperator& src_op, + GraphDef* tensorflow_graph) { + auto* rank_op = tensorflow_graph->add_node(); + rank_op->set_op("Rank"); + rank_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *rank_op->add_input() = src_op.inputs[0]; + (*rank_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); +} + +void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, + GraphDef* tensorflow_graph) { + auto* range_op = tensorflow_graph->add_node(); + range_op->set_op("Range"); + range_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *range_op->add_input() = src_op.inputs[0]; + *range_op->add_input() = src_op.inputs[1]; + *range_op->add_input() = src_op.inputs[2]; + (*range_op->mutable_attr())["Tidx"].set_type( + GetTensorFlowDataType(src_op.dtype)); +} + +void ConvertStackOperator(const Model& model, const StackOperator& src_op, + GraphDef* tensorflow_graph) { + auto* stack_op = tensorflow_graph->add_node(); + stack_op->set_op("Stack"); + stack_op->set_name(src_op.outputs[0]); + for (const auto& input : src_op.inputs) { + *stack_op->add_input() = input; + } + (*stack_op->mutable_attr())["elem_type"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); + (*stack_op->mutable_attr())["axis"].set_i(src_op.axis); +} + +void ConvertFillOperator(const Model& model, const FillOperator& src_op, + GraphDef* tensorflow_graph) { + auto* fill_op = tensorflow_graph->add_node(); + fill_op->set_op("Fill"); + fill_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *fill_op->add_input() = src_op.inputs[0]; + *fill_op->add_input() = src_op.inputs[1]; + (*fill_op->mutable_attr())["index_type"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*fill_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); +} + +void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op, + GraphDef* tensorflow_graph) { + auto* floor_div_op = tensorflow_graph->add_node(); + floor_div_op->set_op("FloorDiv"); + floor_div_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *floor_div_op->add_input() = src_op.inputs[0]; + *floor_div_op->add_input() = src_op.inputs[1]; + (*floor_div_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); +} + +void ConvertExpandDimsOperator(const Model& model, + const ExpandDimsOperator& src_op, + GraphDef* tensorflow_graph) { + auto* expand_dims_op = tensorflow_graph->add_node(); + expand_dims_op->set_op("ExpandDims"); + expand_dims_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *expand_dims_op->add_input() = src_op.inputs[0]; + *expand_dims_op->add_input() = src_op.inputs[1]; + (*expand_dims_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*expand_dims_op->mutable_attr())["Tdim"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); +} + void ConvertResizeBilinearOperator(const Model& model, const ResizeBilinearOperator& src_op, GraphDef* tensorflow_graph) { @@ -991,6 +1180,7 @@ void ConvertResizeBilinearOperator(const Model& model, *resize_op->add_input() = src_op.inputs[0]; *resize_op->add_input() = src_op.inputs[1]; (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners); } namespace { @@ -1046,8 +1236,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Write weights const string weights_output = base + "weights"; CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT])); - const auto& weights_array = - model.GetArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]); + const string weights_name = WalkUpToConstantArray( + model, src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]); + const auto& weights_array = model.GetArray(weights_name); // Convert 4D FullyConnected weights into 2D matrix const auto& weights_shape = weights_array.shape(); CHECK_EQ(weights_shape.dimensions_count(), 2); @@ -1072,8 +1263,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Write biases const string biases_output = base + "biases"; CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT])); - const auto& bias_array = - model.GetArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]); + const string bias_name = WalkUpToConstantArray( + model, src_op.inputs[LstmCellOperator::BIASES_INPUT]); + const auto& bias_array = model.GetArray(bias_name); // TODO(b/62904716) Bias arrays should be 1-D, and used directly. Shape bias_shape_1d = bias_array.shape(); UnextendShape(&bias_shape_1d, 1); @@ -1389,6 +1581,17 @@ void ConvertTensorFlowMaximumOperator(const Model& model, (*sub_op->mutable_attr())["T"].set_type(data_type); } +void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, + GraphDef* tensorflow_graph) { + auto* topk_op = tensorflow_graph->add_node(); + topk_op->set_op("TOPKV2"); + topk_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *topk_op->add_input() = src_op.inputs[0]; + *topk_op->add_input() = src_op.inputs[1]; + (*topk_op->mutable_attr())["sorted"].set_b(true); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1445,6 +1648,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kSoftmax) { ConvertSoftmaxOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kLogSoftmax) { + ConvertLogSoftmaxOperator(model, + static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kLocalResponseNormalization) { ConvertLocalResponseNormalizationOperator( static_cast(src_op), @@ -1533,6 +1740,35 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kArgMax) { ConvertArgMaxOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTopK_V2) { + ConvertTopKV2Operator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kTranspose) { + ConvertTransposeOperator( + model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowShape) { + ConvertTensorFlowShapeOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kRank) { + ConvertRankOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kRange) { + ConvertRangeOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kStack) { + ConvertStackOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kFill) { + ConvertFillOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kFloorDiv) { + ConvertFloorDivOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kExpandDims) { + ConvertExpandDimsOperator(model, + static_cast(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } @@ -1622,6 +1858,30 @@ void ExportTensorFlowGraphDefImplementation(const Model& model, } } // namespace +void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model) { + for (const auto& array_kv : model->GetArrayMap()) { + const string& array_name = array_kv.first; + Array& array = *array_kv.second; + if (!array.buffer || !array.minmax) { + continue; + } + const string& wrapped_array_name = + AvailableArrayName(*model, array_name + "/data"); + Array& wrapped_array = model->GetOrCreateArray(wrapped_array_name); + wrapped_array.data_type = array.data_type; + wrapped_array.copy_shape(array.shape()); + wrapped_array.buffer = std::move(array.buffer); + FakeQuantOperator* fakequant_op = new FakeQuantOperator; + fakequant_op->inputs = {wrapped_array_name}; + fakequant_op->outputs = {array_name}; + fakequant_op->minmax.reset(new MinMax); + *fakequant_op->minmax = *array.minmax; + const auto& it = FindOpWithInput(*model, array_name); + model->operators.emplace(it, fakequant_op); + } + CheckInvariants(*model); +} + void ExportTensorFlowGraphDef(const Model& model, string* output_file_contents) { CHECK(output_file_contents->empty()); diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.h b/tensorflow/contrib/lite/toco/export_tensorflow.h index 79682153a8fd143c4934095567764b886bd776af..d7310bb75f258cde25236da2a9269f18234784e4 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.h +++ b/tensorflow/contrib/lite/toco/export_tensorflow.h @@ -22,6 +22,8 @@ namespace toco { void ExportTensorFlowGraphDef(const Model& model, string* output_file_contents); +void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model); + } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md index 7e152f5ba887088c98055596f8245b82fbc86eaa..372c52558973f4aacc180ac44b9e95a5e9b199ef 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md @@ -23,7 +23,7 @@ curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_ bazel run --config=opt \ //tensorflow/contrib/lite/toco:toco -- \ --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ - --output_file=/tmp/foo.lite \ + --output_file=/tmp/foo.tflite \ --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --inference_type=FLOAT \ @@ -101,7 +101,7 @@ direction, let us just give an example of that: ``` bazel run --config=opt \ //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/foo.lite \ + --input_file=/tmp/foo.tflite \ --output_file=/tmp/foo.pb \ --input_format=TFLITE \ --output_format=TENSORFLOW_GRAPHDEF \ @@ -130,7 +130,7 @@ flatbuffer is done like this: bazel run --config=opt \ //tensorflow/contrib/lite/toco:toco -- \ --input_file=/tmp/some_quantized_graph.pb \ - --output_file=/tmp/foo.lite \ + --output_file=/tmp/foo.tflite \ --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --inference_type=QUANTIZED_UINT8 \ @@ -207,7 +207,7 @@ curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_ bazel run --config=opt \ //tensorflow/contrib/lite/toco:toco -- \ --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ - --output_file=/tmp/foo.lite \ + --output_file=/tmp/foo.tflite \ --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --inference_type=FLOAT \ @@ -235,7 +235,7 @@ curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_ bazel run --config=opt \ //tensorflow/contrib/lite/toco:toco -- \ --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ - --output_file=/tmp/foo.lite \ + --output_file=/tmp/foo.tflite \ --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --inference_type=FLOAT \ @@ -308,7 +308,7 @@ curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_ bazel run --config=opt \ //tensorflow/contrib/lite/toco:toco -- \ --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ - --output_file=/tmp/foo.lite \ + --output_file=/tmp/foo.tflite \ --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --inference_type=FLOAT \ @@ -415,7 +415,7 @@ curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_ bazel run --config=opt \ //tensorflow/contrib/lite/toco:toco -- \ --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ - --output_file=/tmp/foo.lite \ + --output_file=/tmp/foo.tflite \ --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --inference_type=FLOAT \ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc new file mode 100644 index 0000000000000000000000000000000000000000..0615b5e6c6db910ee847188427b416fd812aa141 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ConvertTrivialStackToReshape::Run(Model* model, std::size_t op_index) { + auto stack_it = model->operators.begin() + op_index; + if (stack_it->get()->type != OperatorType::kStack) { + return false; + } + auto* stack_op = static_cast(stack_it->get()); + if (stack_op->inputs.size() > 1) { + // Not trivial. + return false; + } + CHECK_EQ(stack_op->outputs.size(), 1); + + const auto& input_array = model->GetArray(stack_op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return false; + } + if (input_array.shape().dimensions_count() == 0) { + // Input array cannot be 0-D. + // (Unsure if this is TF behavior, but was required to get a test to pass.) + return false; + } + + AddMessageF("Converting trivial %s to a reshape", LogName(*stack_op)); + + // Note that we could convert to ExpandDims but toco prefers reshapes. + auto* reshape_op = new TensorFlowReshapeOperator; + reshape_op->inputs = {stack_op->inputs[0]}; + reshape_op->outputs = stack_op->outputs; + + // Create shape param. + string shape_array_name = + AvailableArrayName(*model, stack_op->outputs[0] + "_shape"); + Array& shape_array = model->GetOrCreateArray(shape_array_name); + *(shape_array.mutable_shape()->mutable_dims()) = { + 1 + input_array.shape().dimensions_count()}; + reshape_op->inputs.push_back(shape_array_name); + shape_array.data_type = ArrayDataType::kInt32; + auto& shape_buffer = shape_array.GetMutableBuffer(); + shape_buffer.data.push_back(1); + for (int dim : input_array.shape().dims()) { + shape_buffer.data.push_back(dim); + } + + // Replace the operator in the graph. + const auto reshape_it = model->operators.emplace(stack_it, reshape_op); + stack_it = reshape_it + 1; + CHECK_EQ(stack_it->get(), stack_op); + model->operators.erase(stack_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc index 88e59664ec427841df6f20686238feacef6a47e9..ab943f72d1dd87ae9ff4bd53a807cd4923a88c38 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -68,12 +68,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { return false; } - // TODO(b/72172404): Great many ops don't support activation function - // fusing. Switch to a categorizing function instead. - if (op->type == OperatorType::kConcatenation || - op->type == OperatorType::kSlice || - op->type == OperatorType::kTensorFlowReshape || - op->type == OperatorType::kTensorFlowSplit) { + if (!OperatorSupportsFusedActivation(op->type)) { AddMessageF( "Not fusing activation function because the %s op doesn't support it", LogName(*op)); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index e11bebcd4e0f66faf63290e3af0c72c39811cebe..616bdac268c41d29135368d685729c961f44132b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -115,6 +115,7 @@ void RunGraphTransformations(Model* model, const string& message, DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd) +DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) @@ -124,6 +125,8 @@ DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) +DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs) +DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs) DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) @@ -136,6 +139,7 @@ DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity) DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator) DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation) DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialSlice) DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc) DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp) DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization) @@ -144,6 +148,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator) DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays) DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays) DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax) +DECLARE_GRAPH_TRANSFORMATION(ReorderActivationFunctions) DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) @@ -153,8 +158,10 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose) DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant) DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions) +DECLARE_GRAPH_TRANSFORMATION(UnrollBatchMatMul) DECLARE_GRAPH_TRANSFORMATION(ResolveSpaceToBatchNDAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes) @@ -167,6 +174,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) +DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero) DECLARE_GRAPH_TRANSFORMATION(Dequantize) class ResolveReshapeAttributes : public GraphTransformation { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 9689b205cd137904504d87906cb691d0ed8235bf..1b0be858107b54f5a6ecd2a1cb87c9dbde1c06bb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -177,6 +177,106 @@ bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min, output_minmax.max = max; return true; } + +// Propagates MinMax from any of the listed arrays, to all others. +// If multiple of these arrays have MinMax, then these are required +// to agree with each other. +bool PropagateMinMaxAmongArrays(Model* model, + const std::vector array_names) { + string reference_array_name; + MinMax* reference_minmax = nullptr; + for (const string& array_name : array_names) { + if (model->GetArray(array_name).minmax) { + reference_array_name = array_name; + reference_minmax = model->GetArray(array_name).minmax.get(); + break; + } + } + // No MinMax info is available to propagate. + if (!reference_minmax) { + return false; + } + bool changed = false; + for (const string& array_name : array_names) { + auto& array = model->GetArray(array_name); + if (array.minmax) { + CHECK(*array.minmax == *reference_minmax) + << "Both the following arrays have minmax, and they disagree: " + << reference_array_name << " and " << array_name + << ". Expected that either only one of them would have minmax, or at " + "least that they would agree."; + } else { + array.GetOrCreateMinMax() = *reference_minmax; + changed = true; + } + } + return changed; +} + +bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) { + CHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS); + CHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS); + + bool changed = false; + changed |= PropagateMinMaxAmongArrays( + model, {op->inputs[LstmCellOperator::PREV_STATE_INPUT], + op->outputs[LstmCellOperator::STATE_OUTPUT]}); + + auto& input_activations = + model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]); + if (!input_activations.minmax) { + auto& minmax = input_activations.GetOrCreateMinMax(); + minmax.min = -1; + minmax.max = 127. / 128.; + changed = true; + } + + auto& prev_output_activations = + model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]); + if (!prev_output_activations.minmax) { + auto& minmax = prev_output_activations.GetOrCreateMinMax(); + minmax.min = -1; + minmax.max = 127. / 128.; + changed = true; + } + + auto& output_concat_temp = + model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP]); + if (!output_concat_temp.minmax) { + auto& minmax = output_concat_temp.GetOrCreateMinMax(); + minmax.min = -1; + minmax.max = 127. / 128.; + changed = true; + } + + auto& output_activations = + model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT]); + if (!output_activations.minmax) { + auto& minmax = output_activations.GetOrCreateMinMax(); + minmax.min = -1; + minmax.max = 127. / 128.; + changed = true; + } + + // (This comment should morph into proper documentation for + // quantization of LSTM models. It isn't just a local implementation detail, + // the training code for LSTM models needs to be adjusted to that.) + // + // Finally, output_activations_temp holds the output of the fully-connected + // node inside the LSTM cell. For it, we hardcode a minmax of [-8, 8]. + // The rationale for that is given in a lengthy comment on the LstmCell + // quantized runtime implementation in reference_ops.h. + auto& output_activations_temp = + model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP]); + if (!output_activations_temp.minmax) { + auto& minmax = output_activations_temp.GetOrCreateMinMax(); + minmax.min = -8; + minmax.max = 8 * 32767. / 32768.; + changed = true; + } + + return changed; +} } // namespace bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { @@ -219,6 +319,16 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.); break; + case OperatorType::kTanh: + // We hardcode quantization_params to: zero_point=127, scale=1/128. + // This choice of minmax is the one that is equivalent to that. + changed = HardcodeMinMaxForOutput(model, op, -127. / 128., 1.0); + break; + + case OperatorType::kLstmCell: + changed = HardcodeMinMaxForLstmCell(model, op); + break; + default: break; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc index 082820fddcf137238867239bbc4d4eed8158e307..c363b93394f0af7bcfc37c1e8be5f98aca6667ae 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "absl/strings/string_view.h" #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" @@ -202,23 +201,6 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, return true; } -absl::string_view FindLongestCommonPrefix(absl::string_view a, - absl::string_view b) { - if (a.empty() || b.empty()) return absl::string_view(); - - const char* pa = a.data(); - const char* pb = b.data(); - size_t count = 0; - const ssize_t limit = std::min(a.size(), b.size()); - while (count < limit && *pa == *pb) { - ++pa; - ++pb; - ++count; - } - - return absl::string_view(a.data(), count); -} - } // namespace bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc new file mode 100644 index 0000000000000000000000000000000000000000..45335fd78c99a577d535770d78acf4fcd6c04531 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { + // Find lstm cell. + auto op_it = model->operators.begin() + op_index; + auto src_op = op_it->get(); + if (src_op->type != OperatorType::kLstmCell) { + return false; + } + + // Already a compact LstmCell with LstmCellOperator::NUM_INPUTS of inputs, + // do not need to merge cell inputs. + if (src_op->inputs.size() == LstmCellOperator::NUM_INPUTS) { + return false; + } + + // Identify prev_activ_input, prev_state_input as required Op inputs, + // using the rnn_states in the model flag. + string prev_activ_input; + if (!GetMatchingRnnArray(model, src_op->outputs[kOutputTensor], + &prev_activ_input)) { + return false; + } + string prev_state_input; + if (!GetMatchingRnnArray(model, src_op->outputs[kCellStateTensor], + &prev_state_input)) { + return false; + } + + // Get LstmCell's cell, input, output size. + int num_cell = model->GetArray(src_op->inputs[kInputToInputWeightsTensor]) + .shape() + .dims(0); + int num_input = model->GetArray(src_op->inputs[kInputToInputWeightsTensor]) + .shape() + .dims(1); + int num_output = + model->GetArray(src_op->inputs[kRecurrentToInputWeightsTensor]) + .shape() + .dims(1); + + // Make sure n_cell and n_output are equal as there is no projection. + CHECK_EQ(num_cell, num_output); + + // Create tensorflow_graphdef style's one big weight tensor. + const string base_name(FindLongestCommonPrefix( + src_op->outputs[kOutputTensor], src_op->outputs[kCellStateTensor])); + string merged_weights = AvailableArrayName(*model, base_name + "weights"); + auto& array = model->GetOrCreateArray(merged_weights); + array.data_type = ArrayDataType::kFloat; + int weights_dim1 = 4 * num_cell; + int weights_dim2 = num_input + num_output; + Shape shape = Shape({weights_dim1, weights_dim2}); + array.copy_shape(shape); + auto& buffer = array.GetMutableBuffer(); + buffer.data.resize(weights_dim1 * weights_dim2); + + // Merge 8 small weight tensors to 1 weight tensor. + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kInputToInputWeightsTensor]), 0, 0); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kInputToCellWeightsTensor]), num_cell, 0); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kInputToForgetWeightsTensor]), + num_cell * 2, 0); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kInputToOutputWeightsTensor]), + num_cell * 3, 0); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kRecurrentToInputWeightsTensor]), 0, + num_input); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kRecurrentToCellWeightsTensor]), num_cell, + num_input); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kRecurrentToForgetWeightsTensor]), + num_cell * 2, num_input); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kRecurrentToOutputWeightsTensor]), + num_cell * 3, num_input); + + // Create tensorflow_graphdef style's one big bias tensor. + string merged_biases = AvailableArrayName(*model, base_name + "biases"); + auto& bias_array = model->GetOrCreateArray(merged_biases); + bias_array.data_type = ArrayDataType::kFloat; + bias_array.copy_shape(Shape({weights_dim1})); + auto& bias_buffer = bias_array.GetMutableBuffer(); + bias_buffer.data.resize(weights_dim1); + + // Merge 4 small bias tensors into a big one. + CopyArrayToSubArray(bias_buffer, weights_dim2, + model->GetArray(src_op->inputs[kInputGateBiasTensor]), 0, + 0); + CopyArrayToSubArray(bias_buffer, weights_dim2, + model->GetArray(src_op->inputs[kCellGateBiasTensor]), + num_cell, 0); + CopyArrayToSubArray(bias_buffer, weights_dim2, + model->GetArray(src_op->inputs[kForgetGateBiasTensor]), + num_cell * 2, 0); + CopyArrayToSubArray(bias_buffer, weights_dim2, + model->GetArray(src_op->inputs[kOutputGateBiasTensor]), + num_cell * 3, 0); + + // Emplace a new LSTM cell operator (use basic 5 inputs kernel). + auto lstm_cell_op = absl::make_unique(); + + // Compact LstmCell's 5 inputs. + lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS); + lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] = + src_op->inputs[kInputTensor]; + lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] = merged_weights; + lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] = merged_biases; + lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = prev_activ_input; + lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state_input; + + // Reorder LstmCell's 4 outputs. + lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS); + lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] = + src_op->outputs[kOutputTensor]; + lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] = + src_op->outputs[kCellStateTensor]; + lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = + src_op->outputs[kScratchBufferTensor]; + lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = + src_op->outputs[kOutputStateTensor]; + + // Add the op into model. + model->operators.emplace(op_it, std::move(lstm_cell_op)); + AddMessageF("Creating compact LstmCell replacing previous lstm cell"); + + // Delete arrays and operators replaced by the LSTM cell operator. Order is + // important - DeleteArrayIfUnused() only succeeds if dependent operators + // have been removed first. Start at the output and work towards the input. + // Erase curr lstm op being replaced. + DeleteArrayIfUnused(src_op->inputs[kInputToInputWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kInputToForgetWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kInputToCellWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kInputToOutputWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kRecurrentToInputWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kRecurrentToForgetWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kRecurrentToCellWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kRecurrentToOutputWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kInputGateBiasTensor], model); + DeleteArrayIfUnused(src_op->inputs[kForgetGateBiasTensor], model); + DeleteArrayIfUnused(src_op->inputs[kCellGateBiasTensor], model); + DeleteArrayIfUnused(src_op->inputs[kOutputGateBiasTensor], model); + model->operators.erase(FindOp(*model, src_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc new file mode 100644 index 0000000000000000000000000000000000000000..eca717680af281018b919c27068ba5d9f5699d69 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -0,0 +1,171 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { + // Find lstm cell. + auto op_it = model->operators.begin() + op_index; + auto curr_op = op_it->get(); + if (curr_op->type != OperatorType::kLstmCell) { + return false; + } + + // Already an extended LstmCell with kExtendedLstmInputCount of inputs, + // do not need to split cell inputs. + if (curr_op->inputs.size() == kExtendedLstmInputCount) { + return false; + } + + // Make sure the WEIGHTS_INPUT and BIASES_INPUT are constant arrays, + // that are able to be split into smaller weight and bias tensors. + if (!IsConstantParameterArray( + *model, curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]) || + !IsConstantParameterArray( + *model, curr_op->inputs[LstmCellOperator::BIASES_INPUT])) { + return false; + } + + // Make sure propagate_fixed_sizes has defined the size of the output. + if (!model->GetArray(curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT]) + .has_shape()) { + return false; + } + + // Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc). + auto lstm_cell_op = absl::make_unique(); + lstm_cell_op->inputs.resize(kExtendedLstmInputCount); + int num_input = model->GetArray(curr_op->inputs[LstmCellOperator::DATA_INPUT]) + .shape() + .dims(1); + + // n_cell and n_output have the same size when there is no projection. + int num_cell = + model->GetArray(curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT]) + .shape() + .dims(1); + int num_output = num_cell; + + // Data input. + lstm_cell_op->inputs[kInputTensor] = + curr_op->inputs[LstmCellOperator::ACTIV_OUTPUT]; + + // Get original weight tensor and decompose 1 tensor to 8 sub tensors. + Array& kernel = + model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]); + const string base_name(FindLongestCommonPrefix( + curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT], + curr_op->outputs[LstmCellOperator::STATE_OUTPUT])); + + // Input weight tensors of size {n_cell, n_input}. + CopySubArrayToArray( + model, &(lstm_cell_op->inputs[kInputToInputWeightsTensor]), + base_name + "weight_i_i", num_cell, num_input, kernel, 0, 0); + CopySubArrayToArray(model, &(lstm_cell_op->inputs[kInputToCellWeightsTensor]), + base_name + "weight_c_i", num_cell, num_input, kernel, + num_cell, 0); + CopySubArrayToArray( + model, &(lstm_cell_op->inputs[kInputToForgetWeightsTensor]), + base_name + "weight_f_i", num_cell, num_input, kernel, num_cell * 2, 0); + CopySubArrayToArray( + model, &(lstm_cell_op->inputs[kInputToOutputWeightsTensor]), + base_name + "weight_o_i", num_cell, num_input, kernel, num_cell * 3, 0); + + // Recurrent weight tensors of size {n_cell, n_output}. + CopySubArrayToArray( + model, &(lstm_cell_op->inputs[kRecurrentToInputWeightsTensor]), + base_name + "weight_i_r", num_cell, num_output, kernel, 0, num_input); + CopySubArrayToArray(model, + &(lstm_cell_op->inputs[kRecurrentToCellWeightsTensor]), + base_name + "weight_c_r", num_cell, num_output, kernel, + num_cell, num_input); + CopySubArrayToArray(model, + &(lstm_cell_op->inputs[kRecurrentToForgetWeightsTensor]), + base_name + "weight_f_r", num_cell, num_output, kernel, + num_cell * 2, num_input); + CopySubArrayToArray(model, + &(lstm_cell_op->inputs[kRecurrentToOutputWeightsTensor]), + base_name + "weight_o_r", num_cell, num_output, kernel, + num_cell * 3, num_input); + + // Peephole (optional). + CreateOptionalArray(model, &(lstm_cell_op->inputs[kCellToInputWeightsTensor]), + base_name + "peephole_c_i"); + CreateOptionalArray(model, + &(lstm_cell_op->inputs[kCellToForgetWeightsTensor]), + base_name + "peephole_c_f"); + CreateOptionalArray(model, + &(lstm_cell_op->inputs[kCellToOutputWeightsTensor]), + base_name + "peephole_c_o"); + + // Get original bias tensor and decompose 1 tensor to 4 sub tensors + Array& bias = + model->GetArray(curr_op->inputs[LstmCellOperator::BIASES_INPUT]); + CopySubArrayToArray(model, &(lstm_cell_op->inputs[kInputGateBiasTensor]), + base_name + "bias_i", num_cell, 1, bias, 0, 0); + CopySubArrayToArray(model, &(lstm_cell_op->inputs[kCellGateBiasTensor]), + base_name + "bias_c", num_cell, 1, bias, num_cell, 0); + CopySubArrayToArray(model, &(lstm_cell_op->inputs[kForgetGateBiasTensor]), + base_name + "bias_f", num_cell, 1, bias, num_cell * 2, 0); + CopySubArrayToArray(model, &(lstm_cell_op->inputs[kOutputGateBiasTensor]), + base_name + "bias_o", num_cell, 1, bias, num_cell * 3, 0); + + // Projection (optional). + CreateOptionalArray(model, &(lstm_cell_op->inputs[kProjectionWeightsTensor]), + base_name + "proj_weight"); + CreateOptionalArray(model, &(lstm_cell_op->inputs[kProjectionBiasTensor]), + base_name + "proj_bias"); + + // Reorder LstmCell's outputs. + lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS); + lstm_cell_op->outputs[kScratchBufferTensor] = + curr_op->outputs[LstmCellOperator::CONCAT_TEMP]; + lstm_cell_op->outputs[kOutputStateTensor] = + curr_op->outputs[LstmCellOperator::ACTIV_TEMP]; + lstm_cell_op->outputs[kCellStateTensor] = + curr_op->outputs[LstmCellOperator::STATE_OUTPUT]; + lstm_cell_op->outputs[kOutputTensor] = + curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT]; + + // Add the op into model. + model->operators.emplace(op_it, std::move(lstm_cell_op)); + AddMessageF("Creating extended LstmCell replacing previous lstm cell"); + + // Delete arrays and operators replaced by the LSTM cell operator. Order is + // important - DeleteArrayIfUnused() only succeeds if dependent operators + // have been removed first. Start at the output and work towards the input. + // Erase curr lstm op being replaced. + DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model); + DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model); + DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT], + model); + DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT], + model); + model->operators.erase(FindOp(*model, curr_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc index d36e95060937d6af0789766bcb29ae70cef4569d..de6d8889fb4ccdb56e9639ab0dd7d093bfa4b908 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -57,45 +57,60 @@ int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op, } // namespace bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { - const auto maximum_it = model->operators.begin() + op_index; - const auto* maximum_op = maximum_it->get(); - if (maximum_op->type != OperatorType::kTensorFlowMaximum) { + // Follow sequences of min+max and max+min. First get the leading op. + const auto op_it = model->operators.begin() + op_index; + const auto* op_0 = op_it->get(); + if (op_0->type != OperatorType::kTensorFlowMinimum && + op_0->type != OperatorType::kTensorFlowMaximum) { return false; } - CHECK_EQ(maximum_op->inputs.size(), 2); - if (maximum_op->outputs.size() != 1) { - return false; - } - int scalar_input_index = - GetSingleScalarInputIndexOfBinaryOp(model, maximum_op, -1.0f); - if (scalar_input_index == -1) { + + // Get the paired op and ensure it's the counter to the first. + const auto* op_1 = GetOpWithInput(*model, op_0->outputs[0]); + if (!op_1 || + (op_1->type != OperatorType::kTensorFlowMinimum && + op_1->type != OperatorType::kTensorFlowMaximum) || + op_0->type == op_1->type) { return false; } - const auto* minimum_op = GetOpWithInput(*model, maximum_op->outputs[0]); - if (!minimum_op || minimum_op->type != OperatorType::kTensorFlowMinimum) { + + const auto* min_op = + op_0->type == OperatorType::kTensorFlowMinimum ? op_0 : op_1; + const auto* max_op = + op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1; + + CHECK_EQ(min_op->inputs.size(), 2); + CHECK_EQ(max_op->inputs.size(), 2); + if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) { return false; } - if (GetSingleScalarInputIndexOfBinaryOp(model, minimum_op, 1.0f) == -1) { + + // Get the original input to the min+max pair. + int min_scalar_input_index = + GetSingleScalarInputIndexOfBinaryOp(model, min_op, 1.0f); + int max_scalar_input_index = + GetSingleScalarInputIndexOfBinaryOp(model, max_op, -1.0f); + if (min_scalar_input_index == -1 || max_scalar_input_index == -1) { return false; } - CHECK_EQ(minimum_op->inputs.size(), 2); + int op_0_scalar_input_index = + op_0 == min_op ? min_scalar_input_index : max_scalar_input_index; - // Create and emplace Relu1 node + // Create and emplace Relu1 node. auto* relu1_op = new Relu1Operator; - relu1_op->inputs = {maximum_op->inputs[!scalar_input_index]}; - relu1_op->outputs = minimum_op->outputs; - model->operators.emplace(maximum_it, relu1_op); + relu1_op->inputs = {op_0->inputs[!op_0_scalar_input_index]}; + relu1_op->outputs = op_1->outputs; + model->operators.emplace(op_it, relu1_op); AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op)); - // Erase Maximum scalar input & operator - model->EraseArray(maximum_op->inputs[scalar_input_index]); - model->operators.erase(FindOperator(model, maximum_op)); - - // Erase Minimum inputs & operator - model->EraseArray(minimum_op->inputs[0]); - model->EraseArray(minimum_op->inputs[1]); - model->operators.erase(FindOperator(model, minimum_op)); + // Erase op scalar inputs & operators. Note that we preserve the non-scalar + // input to the first op as that's been redirected to the relu1_op. + DeleteArrayIfUsedOnce(op_0->inputs[op_0_scalar_input_index], model); + DeleteArrayIfUsedOnce(op_1->inputs[0], model); + DeleteArrayIfUsedOnce(op_1->inputs[1], model); + model->operators.erase(FindOperator(model, op_0)); + model->operators.erase(FindOperator(model, op_1)); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..910a96058979887972b41f27b2e570e8cb5b4f4c --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" + +namespace toco { + +void CreateOptionalArray(Model* model, string* input_array_buffer, + const string& array_name) { + *input_array_buffer = array_name; + model->CreateOptionalArray(array_name); +} + +void CopyArrayData(const Buffer& src_buffer, + int src_stride, int src_start_idx1, int src_start_idx2, + Buffer* dst_buffer, int dst_stride, + int dst_start_idx1, int dst_start_idx2, int dim1_copy_size, + int dim2_copy_size) { + int src_offset = src_start_idx1 * src_stride + src_start_idx2; + int dst_offset = dst_start_idx1 * dst_stride + dst_start_idx2; + for (int i = 0; i < dim1_copy_size; i++) { + for (int j = 0; j < dim2_copy_size; j++) { + int idx_src = src_offset + i * src_stride + j; + int idx_dst = dst_offset + i * dst_stride + j; + dst_buffer->data[idx_dst] = src_buffer.data[idx_src]; + } + } +} + +Buffer* CreateFloatArrayBuffer(Model* model, + string* array_name, + const Shape& shape) { + *array_name = AvailableArrayName(*model, *array_name); + auto& array = model->GetOrCreateArray(*array_name); + array.data_type = ArrayDataType::kFloat; + array.copy_shape(shape); + Buffer* buffer = + &(array.GetMutableBuffer()); + buffer->data.resize(RequiredBufferSizeForShape(shape)); + return buffer; +} + +void CopySubArrayToArray(Model* model, string* array_name, + const string& tensor_name, int dim1_size, + int dim2_size, const Array& original_array, + int start_idx1, int start_idx2) { + // Determine whether it's bias or not, create shape, buffer. + bool is_bias = dim2_size == 1; + Shape shape = is_bias ? Shape({dim1_size}) : Shape({dim1_size, dim2_size}); + Buffer* buffer = + CreateFloatArrayBuffer(model, array_name, shape); + auto& orig_buffer = original_array.GetBuffer(); + + // Copy data from big tensor. + CopyArrayData(orig_buffer, is_bias ? 1 : original_array.shape().dims(1), + start_idx1, start_idx2, buffer, dim2_size, 0, 0, dim1_size, + dim2_size); +} + +void CopyArrayToSubArray(Buffer& tensor_buffer, + int tensor_stride, const Array& sub_array, + int start_idx1, int start_idx2) { + // Get tensor data. + bool is_bias = sub_array.shape().dims().size() == 1; + int dim1_copy_size = sub_array.shape().dims()[0]; + int dim2_copy_size = is_bias ? 1 : sub_array.shape().dims(1); + auto& sub_buffer = sub_array.GetBuffer(); + + // Copy data from sub tensor. + CopyArrayData(sub_buffer, dim2_copy_size, 0, 0, &tensor_buffer, + is_bias ? 1 : tensor_stride, start_idx1, start_idx2, + dim1_copy_size, dim2_copy_size); +} + +bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array, + string* rnn_array) { + for (const auto& rnn_state : model->flags.rnn_states()) { + if (rnn_state.back_edge_source_array() == back_edge_source_array) { + *rnn_array = rnn_state.state_array(); + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..881c2d4dc892625d4640cac867a2f49c24b638f5 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +// For consistency with the parameters defined in extended LstmCell's kernel +// (tensorflow/contrib/lite/kernels/lstm.cc), +// use lowercase for these constants. + +enum ExtendedLstmCellInputs { + kInputTensor = 0, + kInputToInputWeightsTensor = 1, // Optional + kInputToForgetWeightsTensor = 2, + kInputToCellWeightsTensor = 3, + kInputToOutputWeightsTensor = 4, + kRecurrentToInputWeightsTensor = 5, // Optional + kRecurrentToForgetWeightsTensor = 6, + kRecurrentToCellWeightsTensor = 7, + kRecurrentToOutputWeightsTensor = 8, + kCellToInputWeightsTensor = 9, // Optional + kCellToForgetWeightsTensor = 10, // Optional + kCellToOutputWeightsTensor = 11, // Optional + kInputGateBiasTensor = 12, // Optional + kForgetGateBiasTensor = 13, + kCellGateBiasTensor = 14, + kOutputGateBiasTensor = 15, + kProjectionWeightsTensor = 16, // Optional + kProjectionBiasTensor = 17, // Optional + kExtendedLstmInputCount = 18 +}; + +enum ExtendedLstmCellOutputs { + kScratchBufferTensor = 0, + kOutputStateTensor = 1, + kCellStateTensor = 2, + kOutputTensor = 3 +}; + +// Create optional array used for optional tensor in ExtendedLstmCell inputs. +void CreateOptionalArray(Model* model, string* input_array_buffer, + const string& array_name); + +// Create float array and get its buffer. +Buffer* CreateFloatArrayBuffer(Model* model, + string* array_name, + const Shape& shape); + +// Copy data from one array to the other one (supports 1D and 2D array), +// for 1D array, the 2nd dim's size is 1. +// Arguments: +// src_buffer: the source buffer +// src_stride: the stride of source buffer, i.e., 2nd dim's size +// src_start_idx1: the 1st dim index of start point in src matrix +// src_start_idx2: the 2nd dim index of start point in src matrix +// dst_buffer: the destination buffer +// dst_stride: the stride of destination buffer, i.e., 2nd dim's size +// dst_start_idx1: the 1st dim index of start point in dst matrix +// dst_start_idx2: the 2nd dim index of start point in dst matrix +// dim1_copy_size: 1st dim size of copy data +// dim2_copy_size: 2nd dim size of copy data +void CopyArrayData(const Buffer& src_buffer, + int src_stride, int src_start_idx1, int src_start_idx2, + Buffer* dst_buffer, int dst_stride, + int dst_start_idx1, int dst_start_idx2, int dim1_copy_size, + int dim2_copy_size); + +// Copy a subset of array data and create a smaller array, +// mostly used for spliting weights and bias for Lstm cell. +void CopySubArrayToArray(Model* model, string* array_name, + const string& tensor_name, int dim1_size, + int dim2_size, const Array& original_array, + int start_idx1, int start_idx2); + +// Copy array data to a large array's submatrix, +// mostly used for merging weights and bias for Lstm cell. +void CopyArrayToSubArray(Buffer& tensor_buffer, + int tensor_stride, const Array& sub_array, + int start_idx1, int start_idx2); + +// Get mating rnn array inputs using rnn_states flag. +bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array, + string* rnn_array); + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 4fb3b6ae7a5fc5bfc2719b978331c67ae799eb54..0cf0994b43bb048616bab1abe79db1aae2223d37 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -61,23 +61,42 @@ void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth, output_shape->ReplaceDims({batch, output_height, output_width, output_depth}); } -void ComputeBinaryOperatorOutputSize(const Shape& input_shape1, - const Shape& input_shape2, +void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x, + const Shape& input_shape_y, Array* output_array) { - const int size1 = RequiredBufferSizeForShape(input_shape1); - const int size2 = RequiredBufferSizeForShape(input_shape2); - if (size1 > size2) { - output_array->copy_shape(input_shape1); - } else if (size2 > size1) { - output_array->copy_shape(input_shape2); - } else { - CHECK_EQ(size1, size2); - const int dims1 = input_shape1.dimensions_count(); - const int dims2 = input_shape2.dimensions_count(); - if (dims1 >= dims2) { - output_array->copy_shape(input_shape1); + // This matches the code in BroadcastBinaryOpShapeFn from tensorflow. + // It zips together the two input shapes and pads with 1 to make them the + // same length. For each dimension we broadcast if either dimension is 1 and + // otherwise expect them to match. + int rank_x = input_shape_x.dimensions_count(); + int rank_y = input_shape_y.dimensions_count(); + int rank_out = std::max(rank_x, rank_y); + std::vector* dims_out = output_array->mutable_shape()->mutable_dims(); + dims_out->clear(); + dims_out->reserve(rank_out); + for (int i = 0; i < rank_out; ++i) { + int dim_x = i < (rank_out - rank_x) + ? 1 + : input_shape_x.dims(i - (rank_out - rank_x)); + bool dim_y_is_one = i < (rank_out - rank_y); + int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y)); + if (dim_x == -1 || dim_y == -1) { + // One or both dimensions is unknown. + QCHECK(false) << "Shapes must be specified"; + } else if (dim_x == 1 || dim_y == 1) { + // Broadcast one dimension to the other that is 1. + if (dim_x == 1 && !dim_y_is_one) { + // Broadcast dim_y to dim_x (1). + dims_out->push_back(dim_y); + } else { + // Broadcast dim_x to dim_y (1). + DCHECK_EQ(dim_y, 1); + dims_out->push_back(dim_x); + } } else { - output_array->copy_shape(input_shape2); + // Expect the dimensions to match. + CHECK_EQ(dim_x, dim_y) << "Dimensions must match"; + dims_out->push_back(dim_x); } } CHECK(output_array->has_shape()); @@ -546,6 +565,9 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { // Use 0 input as basis for output dimensions. const auto& first_input_array = model->GetArray(op->inputs[0]); output_array.copy_shape(first_input_array.shape()); + // Negative axis means the count starts at the back of the dims(). + int axis = op->axis; + if (axis < 0) axis += first_input_array.shape().dims().size(); // Determine the concat size, and enfore that all inputs have // the same dimensions count. int concat_size = 0; @@ -558,14 +580,14 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { CHECK_EQ(input_array.shape().dimensions_count(), output_array.shape().dimensions_count()); const std::vector& input_dims = input_array.shape().dims(); - CHECK_LT(op->axis, input_dims.size()); - concat_size += input_dims[op->axis]; + CHECK_LT(axis, input_dims.size()); + concat_size += input_dims[axis]; } // Write out the concat_size on the output array shape. auto& output_shape = *output_array.mutable_shape(); auto& output_dims = *output_shape.mutable_dims(); - CHECK_LT(op->axis, output_shape.dimensions_count()); - output_dims[op->axis] = concat_size; + CHECK_LT(axis, output_shape.dimensions_count()); + output_dims[axis] = concat_size; } void ProcessRangeOperator(Model* model, RangeOperator* op) { @@ -628,15 +650,34 @@ void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) { } const Shape& input_shape = input_array.shape(); - // This code is slightly suspect. The TensorFlow docs say that the axis - // selection defaults to 0, but we are splitting across the final axis. - const int input_dims_count = input_shape.dimensions_count(); - const int input_depth = input_shape.dims(input_dims_count - 1); - CHECK_EQ(input_depth % op->num_split, 0); - const int split_depth = input_depth / op->num_split; + // Yield until axis is constant. + if (!IsConstantParameterArray(*model, op->inputs[0])) { + return; + } + + const auto& axis_array = model->GetArray(op->inputs[0]); + + // Yield until axis dims have been resolved. + if (!axis_array.has_shape()) { + return; + } + + CHECK(axis_array.data_type == ArrayDataType::kInt32) + << "Axis array must be int32."; + CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1) + << "Axis array must be scalar."; + + int axis = axis_array.GetBuffer().data[0]; + if (axis < 0) { + axis += input_shape.dimensions_count(); + } + + const int split_dim = input_shape.dims(axis); + CHECK_EQ(split_dim % op->num_split, 0); + const int split_depth = split_dim / op->num_split; Shape output_shape = input_shape; - (*output_shape.mutable_dims())[input_dims_count - 1] = split_depth; + (*output_shape.mutable_dims())[axis] = split_depth; CHECK_EQ(op->outputs.size(), op->num_split); for (const auto& output : op->outputs) { @@ -725,9 +766,8 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) { } void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { - // I/O arrays should be allocated on creation of op. - QCHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS); - QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS); + // Only required for compact LstmCell with default NUM_INPUTS of inputs. + if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return; const auto& input_array = model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]); @@ -942,6 +982,43 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { } } +void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) { + const auto& input_values = model->GetArray(op->inputs[0]); + const auto& input_k = model->GetArray(op->inputs[1]); + auto& output_indexes = model->GetArray(op->outputs[0]); + auto& output_values = model->GetArray(op->outputs[1]); + + // Bail if we already know the output shape. + if (output_indexes.has_shape()) { + QCHECK(output_values.has_shape()); + return; + } + + // Yield until input dims have been resolved. + if (!input_values.has_shape()) { + return; + } + + const auto& input_values_shape = input_values.shape(); + auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims(); + auto output_values_dims = output_values.mutable_shape()->mutable_dims(); + for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) { + output_indexes_dims->push_back(input_values_shape.dims(dim)); + output_values_dims->push_back(input_values_shape.dims(dim)); + } + // If the value is initialized, we can specify the last dimension, otherwise + // unknown. + if (input_k.buffer) { + const int32_t k_value = input_k.GetBuffer().data[0]; + output_indexes_dims->push_back(k_value); + output_values_dims->push_back(k_value); + + } else { + output_indexes_dims->push_back(0); + output_values_dims->push_back(0); + } +} + void ProcessPadOperator(Model* model, PadOperator* op) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); @@ -1120,7 +1197,8 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { stop += input_array.shape().dims(i); } - int dim_size = (stop - start) / op->strides[i]; + int dim_size = ceil((stop - start) / static_cast(op->strides[i])); + dim_size = dim_size < 0 ? 0 : dim_size; if (op->shrink_axis_mask & mask) { CHECK_EQ(dim_size, 1) << "Output size for an axis must compute to 1 when " "shrinking that axis"; @@ -1214,7 +1292,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) { std::vector const& perm = perm_array.GetBuffer().data; CHECK_EQ(perm.size(), input_shape.dimensions_count()) - << "Transpose permutation input must be same length as input dimensions"; + << "Transpose permutation input " << op->inputs[0] + << " must be same length as input dimensions"; std::vector* output_dims = output_array.mutable_shape()->mutable_dims(); for (int i = 0; i < perm.size(); i++) { int axis = perm[i]; @@ -1271,6 +1350,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kRelu1: case OperatorType::kRelu6: case OperatorType::kSoftmax: + case OperatorType::kLogSoftmax: case OperatorType::kLogistic: case OperatorType::kTanh: case OperatorType::kLocalResponseNormalization: @@ -1284,12 +1364,15 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTensorFlowAssert: case OperatorType::kCast: case OperatorType::kFloor: + case OperatorType::kExp: ProcessSimpleOperator(model, op); break; case OperatorType::kGather: ProcessGatherOperator(model, static_cast(op)); break; - + case OperatorType::kTopK_V2: + ProcessTopkV2Operator(model, static_cast(op)); + break; case OperatorType::kAdd: case OperatorType::kSub: case OperatorType::kMul: @@ -1420,6 +1503,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kLstmCell: ProcessLstmCellOperator(model, static_cast(op)); break; + case OperatorType::kBatchMatMul: case OperatorType::kTensorFlowMatMul: // MatMul operators are converted to FullyConnected, after which their // shapes are propagated. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index b973b2b813147cc580d2e87cea7d395f180f5aa1..d7f804ee432598cafe6b6c05d03219aa7d2783fa 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -41,11 +41,15 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kConcatenation || type == OperatorType::kL2Normalization || type == OperatorType::kAdd || type == OperatorType::kAveragePool || type == OperatorType::kMaxPool || + type == OperatorType::kTensorFlowMinimum || + type == OperatorType::kTensorFlowMaximum || type == OperatorType::kLogistic || type == OperatorType::kSoftmax || + type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub || type == OperatorType::kSqueeze || type == OperatorType::kPad || type == OperatorType::kTensorFlowReshape || - type == OperatorType::kMul || type == OperatorType::kSpaceToDepth || - type == OperatorType::kDepthToSpace; + type == OperatorType::kTanh || type == OperatorType::kMul || + type == OperatorType::kSpaceToDepth || + type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell; } template @@ -100,6 +104,9 @@ void QuantizeArray(GraphTransformation* transformation, Model* model, case ArrayDataType::kUint8: return QuantizeArray(transformation, model, name, quantization_params); + case ArrayDataType::kInt16: + return QuantizeArray(transformation, model, name, + quantization_params); case ArrayDataType::kInt32: return QuantizeArray(transformation, model, name, quantization_params); @@ -168,36 +175,62 @@ bool ChooseQuantizationForOperatorInput( if (array.data_type != ArrayDataType::kFloat) { return false; } + + // Quantization of bias vectors + bool is_bias_vector = false; + int activations_input_index; + int weights_input_index; if (op.type == OperatorType::kConv || op.type == OperatorType::kDepthwiseConv || op.type == OperatorType::kFullyConnected) { if (input_index == 2) { - // Quantization of bias vector. - // We need both of the mandatory inputs (input activations and weights) to - // have - // been already quantized. - const auto& input_activations = model->GetArray(op.inputs[0]); - const auto& input_weights = model->GetArray(op.inputs[1]); - if (!input_activations.quantization_params || - !input_weights.quantization_params) { - return false; - } - const auto input_activations_scale = - input_activations.quantization_params->scale; - const auto input_weights_scale = input_weights.quantization_params->scale; - quantization_params->scale = - input_activations_scale * input_weights_scale; - quantization_params->zero_point = 0; - *quantized_data_type = ArrayDataType::kInt32; - transformation->AddMessageF( - "Input array %s is a bias vector. Choosing quantization params " - "accordingly.", - input); - return true; + is_bias_vector = true; + activations_input_index = 0; + weights_input_index = 1; } } + if (op.type == OperatorType::kLstmCell) { + if (input_index == LstmCellOperator::BIASES_INPUT) { + is_bias_vector = true; + activations_input_index = LstmCellOperator::DATA_INPUT; + weights_input_index = LstmCellOperator::WEIGHTS_INPUT; + } + } + if (is_bias_vector) { + // Quantization of bias vector. + // We need both of the mandatory inputs (input activations and weights) to + // have been already quantized. + const auto& input_activations = + model->GetArray(op.inputs[activations_input_index]); + const auto& input_weights = model->GetArray(op.inputs[weights_input_index]); + if (!input_activations.quantization_params || + !input_weights.quantization_params) { + return false; + } + const auto input_activations_scale = + input_activations.quantization_params->scale; + const auto input_weights_scale = input_weights.quantization_params->scale; + quantization_params->scale = input_activations_scale * input_weights_scale; + quantization_params->zero_point = 0; + *quantized_data_type = ArrayDataType::kInt32; + transformation->AddMessageF( + "Input array %s is a bias vector. Choosing quantization params " + "accordingly.", + input); + return true; + } const MinMax& minmax = GetOrComputeMinMax(model, input); + + if (op.type == OperatorType::kLstmCell) { + if (input_index == LstmCellOperator::PREV_STATE_INPUT) { + GetQuantizationParamsFromMinMax( + model->flags, minmax, quantization_params); + *quantized_data_type = ArrayDataType::kInt16; + return true; + } + } + GetQuantizationParamsFromMinMax(model->flags, minmax, quantization_params); transformation->AddMessageF( @@ -258,6 +291,17 @@ bool ChooseHardcodedQuantizationForOperatorOutput( *quantization_params)); return true; } + if (op.type == OperatorType::kTanh) { + // Tanh has the range: [-1, 1]. + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = 128; + quantization_params->scale = 1. / 128.; + // 0 should be exactly representable, as values will typically be centered + // around 0, with many values near 0. + CHECK( + IsExactlyRepresentable(0., *quantized_data_type, *quantization_params)); + return true; + } return false; } @@ -295,6 +339,15 @@ bool ChooseQuantizationForOperatorOutput( return true; } const MinMax& minmax = GetOrComputeMinMax(model, output); + if (op.type == OperatorType::kLstmCell) { + if (output_index == LstmCellOperator::STATE_OUTPUT || + output_index == LstmCellOperator::ACTIV_TEMP) { + GetQuantizationParamsFromMinMax( + model->flags, minmax, quantization_params); + *quantized_data_type = ArrayDataType::kInt16; + return true; + } + } GetQuantizationParamsFromMinMax(model->flags, minmax, quantization_params); *quantized_data_type = ArrayDataType::kUint8; @@ -390,30 +443,52 @@ bool Quantize::Run(Model* model, std::size_t op_index) { if (ChooseQuantizationForOperatorInput(this, model, op, input_index, &quantized_data_type, &quantization_params)) { - changed = true; const auto& input = op.inputs[input_index]; if (IsConstantParameterArray(*model, input)) { QuantizeArray(this, model, input, quantized_data_type, quantization_params); + changed = true; } else { auto dequantize_it = FindOpWithOutput(*model, input); - CHECK(dequantize_it != model->operators.end()); - auto* dequantize_op = dequantize_it->get(); - CHECK(dequantize_op->type == OperatorType::kDequantize); - op.inputs[input_index] = dequantize_op->inputs[0]; - // Check if the output of that Dequantize op was not used by any - // other operator. We will then erase that Dequantize op. - if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { - // If any of the model's output_arrays was pointing to the - // Dequantize op's output, let it point to the Dequantize op's - // input instead. - for (int i = 0; i < model->flags.output_arrays_size(); i++) { - if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) { - model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + if (dequantize_it != model->operators.end()) { + auto* dequantize_op = dequantize_it->get(); + CHECK(dequantize_op->type == OperatorType::kDequantize); + op.inputs[input_index] = dequantize_op->inputs[0]; + // Check if the output of that Dequantize op was not used by any + // other operator. We will then erase that Dequantize op. + if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { + // If any of the model's output_arrays was pointing to the + // Dequantize op's output, let it point to the Dequantize op's + // input instead. + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) { + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + } + } + model->EraseArray(dequantize_op->outputs[0]); + model->operators.erase(dequantize_it); + } + changed = true; + } else { + // This input array is not produced by a Dequantize op. + // We have encountered this situation in RNN graphs, whose cyclic + // nature defeats the basic assumption underlying the quantization + // algorithm implemented here. For now, when we have seen this + // happening, the array in question was a RNN state array itself, + // so let us just implement this case here, and guard that assumption + // with a CHECK. A more general fix would involve revisiting the + // design of this whole Quantization transformation. + bool is_rnn_state_array = false; + for (const auto& rnn_state : model->flags.rnn_states()) { + if (rnn_state.state_array() == input) { + is_rnn_state_array = true; + break; } } - model->EraseArray(dequantize_op->outputs[0]); - model->operators.erase(dequantize_it); + CHECK(is_rnn_state_array); + QuantizeArray(this, model, input, quantized_data_type, + quantization_params); + changed = true; } } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc new file mode 100644 index 0000000000000000000000000000000000000000..0cbbcd7c814d38e32ee55e9d9271adf532d20924 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool IsSliceTrivial(const Model& model, const Operator& op, + RemoveTrivialSlice* transformation) { + CHECK(op.type == OperatorType::kSlice); + + // Slices are trivial if they are slicing the entire input contents. + const auto& input_array = model.GetArray(op.inputs[0]); + const auto& output_array = model.GetArray(op.outputs[0]); + if (input_array.has_shape() && output_array.has_shape()) { + if (input_array.shape() == output_array.shape()) { + transformation->AddMessageF( + "%s is trivial because its input and output shapes are equal", + LogName(op)); + return true; + } + } + + return false; +} + +} // namespace + +bool RemoveTrivialSlice::Run(Model* model, std::size_t op_index) { + const auto reshape_it = model->operators.begin() + op_index; + auto* slice_op = reshape_it->get(); + if (slice_op->type != OperatorType::kSlice) { + return false; + } + + if (!IsSliceTrivial(*model, *slice_op, this)) { + return false; + } + + AddMessageF("Removing trivial %s", LogName(*slice_op)); + + CHECK_EQ(slice_op->inputs.size(), 3); + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc new file mode 100644 index 0000000000000000000000000000000000000000..30a005c789bb12e880e8e4534088d99ebacba84a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ReorderActivationFunctions::Run(Model* model, std::size_t op_index) { + const auto ac_it = model->operators.begin() + op_index; + std::unique_ptr& ac_op = *ac_it; + DCHECK(ac_op); + + if (ac_op->type != OperatorType::kRelu6 && + ac_op->type != OperatorType::kRelu1 && + ac_op->type != OperatorType::kRelu) { + return false; + } + + auto exchange_it = FindOpWithOutput(*model, ac_op->inputs[0]); + if (exchange_it == model->operators.end()) return false; + // Find the op producing the array passed to this activation function + std::unique_ptr& exchange_op = *exchange_it; + DCHECK(exchange_op); + + if (exchange_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + + DCHECK_EQ(exchange_op->outputs[0], ac_op->inputs[0]); + const auto& exchange_op_input = exchange_op->inputs[0]; + const auto& intermediate_array = exchange_op->outputs[0]; + const auto& ac_op_output = ac_op->outputs[0]; + + int count_ops_consuming_output = + CountOpsWithInput(*model, intermediate_array); + DCHECK_GE(count_ops_consuming_output, 1); + if (count_ops_consuming_output > 1) { + AddMessageF( + "Not exchanging activation function with %s because it is consumed by " + "more than 1 other operator", + LogName(*exchange_op)); + return false; + } + + // If the ac_op was originally producing an output_array we can't reorder as + // otherwise the output array would change. It'd be nice to still be able to + // reorder but if code is relying on the fetch names instead of array indices + // this won't work. + for (int i = 0; i < model->flags.output_arrays_size(); ++i) { + if (model->flags.output_arrays(i) == ac_op->outputs[0]) { + AddMessageF( + "Not exchanging activation function with %s to preserve output array " + "name %s", + LogName(*exchange_op), ac_op->outputs[0]); + return false; + } + } + + // Rewire by changing inputs, including all consumers. + Operator* consumer = GetFirstOpWithInput(*model, ac_op_output); + while (consumer) { + for (int i = 0; i < consumer->inputs.size(); ++i) { + if (consumer->inputs[i] == ac_op_output) { + consumer->inputs[i] = intermediate_array; + } + } + consumer = GetFirstOpWithInput(*model, ac_op_output); + } + ac_op->inputs[0] = exchange_op_input; + exchange_op->inputs[0] = ac_op_output; + + // Clear shapes; this will allow shape propagation to fix the sizes for us. + model->GetOrCreateArray(ac_op->outputs[0]).clear_shape(); + model->GetOrCreateArray(exchange_op->outputs[0]).clear_shape(); + + // Finally, reorder operators. Note that this only works when there are no + // other direct descendents of the exchange_op. + ac_op.swap(exchange_op); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc index 5ac449749adbc9b5422f996eeccb72575dca8722..064810b53e7c3bee4601204c9dbd976c374a6a60 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -73,7 +73,7 @@ void CopyTensorSegments(const std::vector& input_arrays, // Receives a series of input arrays of type Array and an integer showing the // axis on which those arrays will be concatenated. It returns the concatenated -// arrray. +// array. template void ConcatenateTensorBuffers(const std::vector& input_arrays, int concatenation_axis, @@ -190,7 +190,7 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { // Remove all the resolved arrays. for (const string& input_name : concat_op->inputs) { // Check to prevent removal of shared tensors - if(CountOpsWithInput(*model, input_name) == 1) { + if (CountOpsWithInput(*model, input_name) == 1) { model->EraseArray(input_name); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index 81fe37d7e017c6e2440de34cc2daedf7fb2a422e..944901ece77430708013ea4ca340a30511ba0174 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -50,6 +50,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { output_array.data_type = ArrayDataType::kFloat; CHECK(!output_array.buffer); const auto& input_buffer = input_array.GetBuffer(); + output_array.GetOrCreateMinMax() = *fakequant_op->minmax; auto& output_buffer = output_array.GetMutableBuffer(); const int size = input_buffer.data.size(); output_buffer.data.resize(size); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f984bfde55b3457694bb411bbfdf30723c7066e --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc @@ -0,0 +1,180 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// Transposes an array up to rank 4. +// This is ShuffleArrayTemplate with non-enum permutation. +template +void Transpose(Model* model, const Array& input_array, + const std::vector& perm, Array* output_array) { + const Shape& input_shape = input_array.shape(); + const std::vector>& input_data = + input_array.GetBuffer().data; + + const Shape& output_shape = output_array->shape(); + std::vector>& output_data = + output_array->GetMutableBuffer().data; + output_data.resize(RequiredBufferSizeForShape(output_shape)); + + CHECK(input_shape.dimensions_count() == output_shape.dimensions_count()); + const int dim = input_shape.dimensions_count(); + CHECK_LE(dim, 4); + CHECK(perm.size() >= dim); + for (int i = 0; i < dim; i++) { + CHECK(perm[i] >= 0 && perm[i] < dim); + CHECK(input_shape.dims(perm[i]) == output_shape.dims(i)); + } + Shape extended_input_shape = input_shape; + ExtendShape(&extended_input_shape, 4); + Shape extended_output_shape = output_shape; + ExtendShape(&extended_output_shape, 4); + std::vector extended_perm; + ExtendShuffle(perm, 4, &extended_perm); + + const std::vector& extended_input_dims = extended_input_shape.dims(); + const std::vector& extended_output_dims = extended_output_shape.dims(); + + // TODO(starka): Rework to handle different numbers of dimensions. + int input_strides[4]; + input_strides[3] = 1; + input_strides[2] = extended_input_dims[3]; + input_strides[1] = input_strides[2] * extended_input_dims[2]; + input_strides[0] = input_strides[1] * extended_input_dims[1]; + const int input_stride_0 = input_strides[extended_perm[3]]; + const int input_stride_1 = input_strides[extended_perm[2]]; + const int input_stride_2 = input_strides[extended_perm[1]]; + const int input_stride_3 = input_strides[extended_perm[0]]; + + const int output_size_0 = extended_output_dims[3]; + const int output_size_1 = extended_output_dims[2]; + const int output_size_2 = extended_output_dims[1]; + const int output_size_3 = extended_output_dims[0]; + const int output_stride_0 = 1; + const int output_stride_1 = output_size_0; + const int output_stride_2 = output_stride_1 * output_size_1; + const int output_stride_3 = output_stride_2 * output_size_2; + + for (int i3 = 0; i3 < output_size_3; i3++) { + const DataType* const input_ptr_3 = + input_data.data() + i3 * input_stride_3; + DataType* const output_ptr_3 = + output_data.data() + i3 * output_stride_3; + for (int i2 = 0; i2 < output_size_2; i2++) { + const DataType* const input_ptr_2 = + input_ptr_3 + i2 * input_stride_2; + DataType* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2; + for (int i1 = 0; i1 < output_size_1; i1++) { + const DataType* input_ptr = input_ptr_2 + i1 * input_stride_1; + DataType* output_ptr = output_ptr_2 + i1 * output_stride_1; + DataType* const output_ptr_end = + output_ptr + output_size_0 * output_stride_0; + while (output_ptr != output_ptr_end) { + *output_ptr = *input_ptr; + input_ptr += input_stride_0; + output_ptr += output_stride_0; + } + } + } + } +} + +} // namespace + +bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kTranspose) { + return false; + } + const auto* op = static_cast(base_op); + + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes. + return false; + } + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes. + return false; + } + + // We require constant inputs. + if (!IsConstantParameterArray(*model, op->inputs[0]) || + !IsConstantParameterArray(*model, op->inputs[1])) { + return false; + } + const Array& input_array = model->GetArray(op->inputs[0]); + + if (input_array.minmax) { + output_array.GetOrCreateMinMax() = input_array.GetMinMax(); + } + + if (op->perm.empty()) { + // Yield until perm has been populated by ResolveTransposeAttributes. + return false; + } + + // We currently only support 1-4 dimensions. + CHECK_LE(op->perm.size(), 4); + + CHECK(!output_array.buffer); + switch (output_array.data_type) { + case ArrayDataType::kFloat: + Transpose(model, input_array, op->perm, + &output_array); + break; + case ArrayDataType::kUint8: + Transpose(model, input_array, op->perm, + &output_array); + break; + case ArrayDataType::kInt32: + Transpose(model, input_array, op->perm, + &output_array); + break; + case ArrayDataType::kInt64: + Transpose(model, input_array, op->perm, + &output_array); + break; + default: + LOG(FATAL) << "Unsupported data type given to Transpose op with output \"" + << op->outputs[0] << "\""; + break; + } + + // Erase input arrays if no longer used. + for (const auto& input : op->inputs) { + if (IsDiscardableArray(*model, input) && + CountOpsWithInput(*model, input) == 1) { + model->EraseArray(input); + } + } + + // Erase the operator. + model->operators.erase(it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index 1cd2aff28c68eaba4e9b18d8e2c2803834328696..f227554bc505efe6a758fdd9894fee43f2500641 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -139,14 +139,13 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { output_buffer_size * sizeof(output_float_data[0])); } else if (unary_op->type == OperatorType::kTensorFlowSum) { // At the moment only full reduction across all dimensions is supported. - for (int i = 0; i < output_dims_count; i++) { - CHECK_EQ(output_shape.dims(i), 1); - } float sum = 0.f; for (int i = 0; i < input_buffer_size; i++) { sum += (*input_float_data)[i]; } - output_float_data[0] = sum; + for (int i = 0; i < output_buffer_size; ++i) { + output_float_data[i] = sum; + } } else if (unary_op->type == OperatorType::kTensorFlowMin) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc new file mode 100644 index 0000000000000000000000000000000000000000..37beb41dfc5904fc6ace79ebea2420d2ab92fbfb --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc @@ -0,0 +1,152 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +template +bool AreAllBufferElementsZero(const std::vector& buffer_data) { + for (auto x : buffer_data) { + if (x != 0) { + return false; + } + } + return true; +} + +template +void FillArrayWithZeros(Array* array) { + CHECK(array->data_type == Type); + std::vector>& data = array->GetMutableBuffer().data; + data.resize(RequiredBufferSizeForShape(array->shape())); + for (size_t i = 0; i < data.size(); i++) { + data[i] = 0; + } +} + +} // namespace + +// Removes a multiplication by array of constant zeros by making the output +// array an array of constant zeros and removing the input arrays if they are no +// longer needed. +bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { + const auto mul_it = model->operators.begin() + op_index; + auto* mul_op = mul_it->get(); + if (mul_op->type != OperatorType::kMul) { + return false; + } + const auto& output_array_name = mul_op->outputs[0]; + auto& output_array = model->GetArray(output_array_name); + + // Yield if the output shape is not known yet. + if (!output_array.has_shape()) { + return false; + } + + // This transformation only handles the case where one operand is all 0's and + // the other is non-constant. Other cases are handled by constant propagation + // or the trivial binary removal pass. + const bool is_input_constant[2] = { + IsConstantParameterArray(*model, mul_op->inputs[0]), + IsConstantParameterArray(*model, mul_op->inputs[1]), + }; + if (!is_input_constant[0] && !is_input_constant[1]) { + // Neither input is constant, so nothing we can resolve here. + return false; + } + if (is_input_constant[0] && is_input_constant[1]) { + // Both inputs are constants. That's a job for constants propagation, not + // for us to handle here. + return false; + } + const int index_of_constant_input = is_input_constant[0] ? 0 : 1; + const int index_of_variable_input = is_input_constant[0] ? 1 : 0; + CHECK(is_input_constant[index_of_constant_input]); + CHECK(!is_input_constant[index_of_variable_input]); + + const auto& constant_input_array = + model->GetArray(mul_op->inputs[index_of_constant_input]); + + CHECK(constant_input_array.data_type == output_array.data_type); + switch (output_array.data_type) { + case ArrayDataType::kFloat: { + const auto& constant_input_data = + constant_input_array.GetBuffer().data; + if (!AreAllBufferElementsZero>( + constant_input_data)) { + return false; + } + FillArrayWithZeros(&output_array); + } break; + case ArrayDataType::kUint8: { + const auto& constant_input_data = + constant_input_array.GetBuffer().data; + if (!AreAllBufferElementsZero>( + constant_input_data)) { + return false; + } + FillArrayWithZeros(&output_array); + } break; + case ArrayDataType::kInt32: { + const auto& constant_input_data = + constant_input_array.GetBuffer().data; + if (!AreAllBufferElementsZero>( + constant_input_data)) { + return false; + } + FillArrayWithZeros(&output_array); + } break; + case ArrayDataType::kInt64: { + const auto& constant_input_data = + constant_input_array.GetBuffer().data; + if (!AreAllBufferElementsZero>( + constant_input_data)) { + return false; + } + FillArrayWithZeros(&output_array); + } break; + default: + AddMessageF( + "Cannot resolve multiply by 0 because of unsupported data type\n"); + return false; + } + + // Erase input arrays to the multiply if no longer used + if (IsDiscardableArray(*model, mul_op->inputs[0]) && + CountOpsWithInput(*model, mul_op->inputs[0]) == 1) { + model->EraseArray(mul_op->inputs[0]); + } + if (IsDiscardableArray(*model, mul_op->inputs[1]) && + CountOpsWithInput(*model, mul_op->inputs[1]) == 1) { + model->EraseArray(mul_op->inputs[1]); + } + + // Erase the multiply operator. + model->operators.erase(mul_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc index 5c68f87f6ccd912a94213c95a59a78076b0e768b..bc70db0bd8c26319fa140616de96452260a01058 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -60,16 +60,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { const auto& output_array_name = reorder_op->outputs[0]; auto& input_array = model->GetArray(input_array_name); auto& output_array = model->GetArray(output_array_name); - string constant_input_array_name = input_array_name; if (!input_array.buffer) { - const auto* op_producing_input = GetOpWithOutput(*model, input_array_name); - if (op_producing_input && - op_producing_input->type == OperatorType::kFakeQuant) { - constant_input_array_name = op_producing_input->inputs[0]; - } - } - auto& constant_input_array = model->GetArray(constant_input_array_name); - if (!constant_input_array.buffer) { return false; } // Yield until output dims have been resolved. @@ -77,14 +68,14 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { return false; } // Reorder the input array dims and buffer data - if (constant_input_array.buffer->type == ArrayDataType::kFloat) { - ReorderAxes( - reorder_op->input_axes_order, reorder_op->output_axes_order, - &constant_input_array, &output_array); - } else if (constant_input_array.buffer->type == ArrayDataType::kInt32) { - ReorderAxes( - reorder_op->input_axes_order, reorder_op->output_axes_order, - &constant_input_array, &output_array); + if (input_array.buffer->type == ArrayDataType::kFloat) { + ReorderAxes(reorder_op->input_axes_order, + reorder_op->output_axes_order, + &input_array, &output_array); + } else if (input_array.buffer->type == ArrayDataType::kInt32) { + ReorderAxes(reorder_op->input_axes_order, + reorder_op->output_axes_order, + &input_array, &output_array); } else { LOG(FATAL) << "Cannot ReorderAxes unless input buffer is float or uint8."; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index ad1e56888e53133c5a84cc0e3d5e76b7ef3b29b4..f38203c80fcb7ab8bc1639129fd98e4e342e5cb7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -29,7 +29,36 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) { return false; } - const auto* matmul_op = matmul_it->get(); + const auto* matmul_op = + static_cast(matmul_it->get()); + + // Reorder the axes on the second input. TensorFlow uses row-major ordering + // on both inputs, however this is inefficient for the FullyConnected + // operator. We'll transpose the second input to be in column-major order now + // and let constant propagation optimize things (if possible). + auto* transpose_op = new TransposeOperator; + transpose_op->inputs = { + matmul_op->inputs[1], + CreateInt32Array( + model, + AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose/perm"), + {1, 0})}; + transpose_op->outputs = { + AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")}; + model->GetOrCreateArray(transpose_op->outputs[0]); + model->operators.emplace(matmul_it, transpose_op); + + // Refresh iterator. + matmul_it = model->operators.begin(); + for (; matmul_it != model->operators.end(); ++matmul_it) { + if (matmul_it->get() == matmul_op) { + break; + } + } + DCHECK_EQ(matmul_it->get(), matmul_op); + + string input_lhs = matmul_op->inputs[0]; + string input_rhs = transpose_op->outputs[0]; // Find the op producing the array passed to this MatMul auto previous_op_it = model->operators.begin(); @@ -47,22 +76,26 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { } Operator* previous_op = (found) ? previous_op_it->get() : nullptr; - // construct the new FullyConnectedOperator + // Construct the new FullyConnectedOperator. auto* fc_op = new FullyConnectedOperator; fc_op->outputs = matmul_op->outputs; - // insert the newly constructed FullyConnectedOperator - auto fc_it = model->operators.emplace(matmul_it, fc_op); + // Insert the newly constructed FullyConnectedOperator. + model->operators.emplace(matmul_it, fc_op) + 1; - // refresh invalidated iterator - matmul_it = fc_it + 1; + // Refresh iterator. + matmul_it = model->operators.begin(); + for (; matmul_it != model->operators.end(); ++matmul_it) { + if (matmul_it->get() == matmul_op) { + break; + } + } DCHECK_EQ(matmul_it->get(), matmul_op); // The way that TensorFlow encodes FullyConnected ops is as a pair // (Reshape, MatMul), so we want to remove the Reshape op and rewrite the - // MatMul - // op as a FullyConnected. However, TensorFlow skips the Reshape ops if the - // input doesn't need reshaping, so we can't just match (Reshape, MatMul) + // MatMul op as a FullyConnected. However, TensorFlow skips the Reshape ops if + // the input doesn't need reshaping, so we can't just match (Reshape, MatMul) // pairs. if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) { AddMessageF("Combining %s and %s into %s", LogName(*previous_op), @@ -72,7 +105,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { model->EraseArray(previous_op_output); } CHECK_EQ(previous_op->inputs.size(), 2); - fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]}; + input_lhs = previous_op->inputs[0]; // Only remove Reshape node if no other node uses its output. if (CountOpsWithInput(*model, previous_op_output) == 1) { const auto& previous_op_shape = previous_op->inputs[1]; @@ -95,9 +128,10 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { } else { AddMessageF("Replacing %s by a FullyConnected operator", LogName(*matmul_op)); - fc_op->inputs = {matmul_op->inputs[0], matmul_op->inputs[1]}; } + fc_op->inputs = {input_lhs, input_rhs}; + // erase the MatMul operator model->operators.erase(matmul_it); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD index 893149878293c9ef2740effe331d3b6c51b49983..2f94f9cd8a9ab24809fb3d137b5d05ab12f43003 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -18,6 +18,17 @@ tf_cc_test( ], ) +tf_cc_test( + name = "lstm_utils_test", + srcs = ["lstm_utils_test.cc"], + deps = [ + "//tensorflow/contrib/lite/toco:graph_transformations", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_googletest//:gtest_main", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6aae0775d3445daf7d990bcce09d335c5f686601 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc @@ -0,0 +1,442 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +// A gmock matcher that check that elements of a float vector match to a given +// tolerance. +std::vector> ArrayFloatNear( + const std::vector& values, float max_abs_error = 1e-5) { + std::vector> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(testing::FloatNear(v, max_abs_error)); + } + return matchers; +} +} // namespace + +class CopyArrayDataTest : public ::testing::Test { + public: + CopyArrayDataTest() {} + + void PrepareBuffers(Model* model, std::initializer_list src_data, + int src_dim_1, int src_dim_2, + std::initializer_list dst_data, int dst_dim_1, + int dst_dim_2) { + string src_array = "src_array"; + src_buffer_ = CreateFloatArrayBuffer( + model, &src_array, + src_dim_2 == 1 ? Shape({src_dim_1}) : Shape({src_dim_1, src_dim_2})); + PopulateBuffer(src_buffer_, src_data); + string dst_array = "dst_array"; + dst_buffer_ = CreateFloatArrayBuffer( + model, &dst_array, + dst_dim_2 == 1 ? Shape({dst_dim_1}) : Shape({dst_dim_1, dst_dim_2})); + PopulateBuffer(dst_buffer_, dst_data); + } + + Buffer* GetSrcBuffer() { return src_buffer_; } + Buffer* GetDstBuffer() { return dst_buffer_; } + + void PopulateBuffer(Buffer* buffer, + const std::vector& init_data) { + for (int i = 0; i < init_data.size(); i++) { + buffer->data[i] = init_data[i]; + } + } + void UpdateBuffer(Buffer* buffer, + std::initializer_list data) { + buffer->data.resize(data.size()); + PopulateBuffer(buffer, data); + } + + private: + Buffer* src_buffer_; + Buffer* dst_buffer_; +}; + +// Copy from 1 big 2D array to 8 smaller ones. +TEST_F(CopyArrayDataTest, CopyFromBigArrayToSmallerArrayes2D) { + // Init src_buffer, dst_buffer. + Model model; + std::initializer_list large_tf_weight_data = { + -0.320407, -0.108683, 0.406358, -0.410811, -0.285786, -0.15769, + -0.194201, 0.170866, 0.084135, 0.201878, 0.21519, -0.284458, + 0.495906, -0.073818, 0.045578, 0.149816, -0.447073, -0.453578, + 0.116766, 0.21808, 0.047326, -0.001985, 0.402193, 0.315517, + 0.38258, 0.43599, 0.11986, 0.465195, 0.33548, -0.118789, + -0.414159, 0.049269, 0.156108, 0.093459, -0.129103, -0.086274, + 0.186188, -0.324923, 0.4117, -0.344439, 0.240465, -0.343331, + -0.463082, -0.231706, -0.487465, -0.186592, -0.020756, -0.239007, + 0.364817, 0.459106, -0.171447, -0.006542, 0.204032, -0.375317, + -0.041911, 0.051664, 0.320483, 0.155899, 0.156555, -0.249823, + -0.353107, 0.031563, -0.340771, -0.052532, 0.134631, -0.257957, + -0.50141, 0.486939, -0.43853, 0.268426, -0.08754, -0.109447, + -0.502462, -0.028055, -0.121838, -0.046016, 0.105309, -0.070774, + 0.495683, -0.475088, 0.048654, -0.38582, 0.411018, -0.315606, + 0.349628, 0.21698, 0.258989, -0.097902, 0.331218, 0.034602, + 0.418069, -0.089025, -0.417513, 0.07609, 0.393821, 0.404733, + -0.055418, -0.43903, -0.447049, 0.013125, 0.278503, 0.459869, + 0.143755, -0.177335, -0.162247, -0.432371, 0.153714, -0.047403, + -0.446775, -0.418363, 0.019743, 0.042025}; + std::initializer_list tflite_lstm_input_weight = {0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0}; + PrepareBuffers(&model, large_tf_weight_data, /*src_dim_1=*/16, + /*src_dim_2=*/7, tflite_lstm_input_weight, + /*dst_dim_1=*/4, /*dst_dim_2=*/3); + + // Copy src starts at (0,0), size (4,3). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + std::vector expected = {-0.320407, -0.108683, 0.406358, 0.170866, + 0.084135, 0.201878, 0.045578, 0.149816, + -0.447073, -0.001985, 0.402193, 0.315517}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (4,0), size (4,3). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/4, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + expected = {0.33548, -0.118789, -0.414159, -0.086274, 0.186188, -0.324923, + -0.463082, -0.231706, -0.487465, 0.459106, -0.171447, -0.006542}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (8,0), size (4,3). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/8, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + expected = {0.320483, 0.155899, 0.156555, -0.052532, 0.134631, -0.257957, + -0.08754, -0.109447, -0.502462, -0.070774, 0.495683, -0.475088}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (12,0), size (4,3). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/12, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + expected = {0.349628, 0.21698, 0.258989, -0.089025, -0.417513, 0.07609, + -0.447049, 0.013125, 0.278503, -0.432371, 0.153714, -0.047403}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // New dst_buffer with size 16. + std::initializer_list tflite_lstm_recurrent_weight = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + PrepareBuffers(&model, large_tf_weight_data, /*src_dim_1=*/16, + /*src_dim_2=*/7, tflite_lstm_recurrent_weight, + /*dst_dim_1=*/4, /*dst_dim_2=*/4); + + // Copy src starts at (0,3), size (4,4). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/0, + /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + expected = {-0.410811, -0.285786, -0.15769, -0.194201, 0.21519, -0.284458, + 0.495906, -0.073818, -0.453578, 0.116766, 0.21808, 0.047326, + 0.38258, 0.43599, 0.11986, 0.465195}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (4,3), size (4,4). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/4, + /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + expected = {0.049269, 0.156108, 0.093459, -0.129103, 0.4117, -0.344439, + 0.240465, -0.343331, -0.186592, -0.020756, -0.239007, 0.364817, + 0.204032, -0.375317, -0.041911, 0.051664}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (8,3), size (4,4). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/8, + /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + expected = {-0.249823, -0.353107, 0.031563, -0.340771, -0.50141, 0.486939, + -0.43853, 0.268426, -0.028055, -0.121838, -0.046016, 0.105309, + 0.048654, -0.38582, 0.411018, -0.315606}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (12,3), size (4,4). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/12, + /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + expected = {-0.097902, 0.331218, 0.034602, 0.418069, 0.393821, 0.404733, + -0.055418, -0.43903, 0.459869, 0.143755, -0.177335, -0.162247, + -0.446775, -0.418363, 0.019743, 0.042025}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); +} + +// Copy from 1 big 1D array to 4 small ones. +TEST_F(CopyArrayDataTest, CopyFromBigArrayToSmallerArrayes1D) { + // Init src_buffer, dst_buffer. + Model model; + std::initializer_list large_tf_bias_data = { + 0.980304, 0.419808, 0.080278, 0.728548, 0.581674, 0.672433, + 0.434190, 0.844357, 0.229587, 0.785629, 0.022065, 0.753082, + 0.422080, 0.539481, 0.878386, 0.168965}; + std::initializer_list tflite_lstm_i_bias = {0, 0, 0, 0}; + PrepareBuffers(&model, large_tf_bias_data, /*src_dim_1=*/16, + /*src_dim_2=*/1, tflite_lstm_i_bias, + /*dst_dim_1=*/4, /*dst_dim_2=*/1); + + // Copy starts at (0,), size (4,). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + std::vector expected = {0.980304, 0.419808, 0.080278, 0.728548}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy starts at (4,), size (4,). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/4, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + expected = {0.581674, 0.672433, 0.434190, 0.844357}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy starts at (8,), size (4,). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/8, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + expected = {0.229587, 0.785629, 0.022065, 0.753082}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy starts at (12,), size (4,). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/12, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + expected = {0.422080, 0.539481, 0.878386, 0.168965}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); +} + +// Copy from 8 small 2D arrayes to 1 big one. +TEST_F(CopyArrayDataTest, CopyFromSmallArrayesToBigArray2D) { + // Init src_buffer, dst_buffer. + Model model; + std::initializer_list large_tf_weights_data = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + // Copy dst starts (0, 0), size (4, 3). + std::initializer_list tflite_lstm_i2i_weight = { + -0.320407, -0.108683, 0.406358, 0.170866, 0.084135, 0.201878, + 0.045578, 0.149816, -0.447073, -0.001985, 0.402193, 0.315517}; + PrepareBuffers(&model, tflite_lstm_i2i_weight, /*src_dim_1=*/4, + /*src_dim_2=*/3, large_tf_weights_data, + /*dst_dim_1=*/16, /*dst_dim_2=*/7); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/3, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + + // Copy dst starts (4, 0), size (4, 3). + std::initializer_list tflite_lstm_i2c_weight = { + 0.33548, -0.118789, -0.414159, -0.086274, 0.186188, -0.324923, + -0.463082, -0.231706, -0.487465, 0.459106, -0.171447, -0.006542}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2c_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/3, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/4, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + + // Copy dst starts (8, 0), size (4, 3). + std::initializer_list tflite_lstm_i2f_weight = { + 0.320483, 0.155899, 0.156555, -0.052532, 0.134631, -0.257957, + -0.08754, -0.109447, -0.502462, -0.070774, 0.495683, -0.475088}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2f_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/3, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/8, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + + // Copy dst starts (12, 0), size (4, 3). + std::initializer_list tflite_lstm_i2o_weight = { + 0.349628, 0.21698, 0.258989, -0.089025, -0.417513, 0.07609, + -0.447049, 0.013125, 0.278503, -0.432371, 0.153714, -0.047403}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2o_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/3, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/12, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + + // Copy dst starts (0, 3), size (4, 4). + std::initializer_list tflite_lstm_i2r_weight = { + -0.410811, -0.285786, -0.15769, -0.194201, 0.21519, -0.284458, + 0.495906, -0.073818, -0.453578, 0.116766, 0.21808, 0.047326, + 0.38258, 0.43599, 0.11986, 0.465195}; + UpdateBuffer(GetSrcBuffer(), tflite_lstm_i2r_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/4, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/3, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + + // Copy dst starts (4, 3), size (4, 4). + std::initializer_list tflite_lstm_c2r_weight = { + 0.049269, 0.156108, 0.093459, -0.129103, 0.4117, -0.344439, + 0.240465, -0.343331, -0.186592, -0.020756, -0.239007, 0.364817, + 0.204032, -0.375317, -0.041911, 0.051664}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_c2r_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/4, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/4, /*dst_start_idx2=*/3, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + + // Copy dst starts (8, 3), size (4, 4). + std::initializer_list tflite_lstm_f2r_weight = { + -0.249823, -0.353107, 0.031563, -0.340771, -0.50141, 0.486939, + -0.43853, 0.268426, -0.028055, -0.121838, -0.046016, 0.105309, + 0.048654, -0.38582, 0.411018, -0.315606}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_f2r_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/4, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/8, /*dst_start_idx2=*/3, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + + // Copy dst starts (12, 3), size (4, 4). + std::initializer_list tflite_lstm_o2r_weight = { + -0.097902, 0.331218, 0.034602, 0.418069, 0.393821, 0.404733, + -0.055418, -0.43903, 0.459869, 0.143755, -0.177335, -0.162247, + -0.446775, -0.418363, 0.019743, 0.042025}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_o2r_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/4, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/12, /*dst_start_idx2=*/3, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + + std::vector expected = { + -0.320407, -0.108683, 0.406358, -0.410811, -0.285786, -0.15769, + -0.194201, 0.170866, 0.084135, 0.201878, 0.21519, -0.284458, + 0.495906, -0.073818, 0.045578, 0.149816, -0.447073, -0.453578, + 0.116766, 0.21808, 0.047326, -0.001985, 0.402193, 0.315517, + 0.38258, 0.43599, 0.11986, 0.465195, 0.33548, -0.118789, + -0.414159, 0.049269, 0.156108, 0.093459, -0.129103, -0.086274, + 0.186188, -0.324923, 0.4117, -0.344439, 0.240465, -0.343331, + -0.463082, -0.231706, -0.487465, -0.186592, -0.020756, -0.239007, + 0.364817, 0.459106, -0.171447, -0.006542, 0.204032, -0.375317, + -0.041911, 0.051664, 0.320483, 0.155899, 0.156555, -0.249823, + -0.353107, 0.031563, -0.340771, -0.052532, 0.134631, -0.257957, + -0.50141, 0.486939, -0.43853, 0.268426, -0.08754, -0.109447, + -0.502462, -0.028055, -0.121838, -0.046016, 0.105309, -0.070774, + 0.495683, -0.475088, 0.048654, -0.38582, 0.411018, -0.315606, + 0.349628, 0.21698, 0.258989, -0.097902, 0.331218, 0.034602, + 0.418069, -0.089025, -0.417513, 0.07609, 0.393821, 0.404733, + -0.055418, -0.43903, -0.447049, 0.013125, 0.278503, 0.459869, + 0.143755, -0.177335, -0.162247, -0.432371, 0.153714, -0.047403, + -0.446775, -0.418363, 0.019743, 0.042025}; + + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); +} + +// Copy from 4 small 1D arrayes to 1 big one. +TEST_F(CopyArrayDataTest, CopyFromSmallArrayesToBigArray1D) { + // Init src_buffer, dst_buffer. + Model model; + std::initializer_list large_tf_bias_data = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + + std::initializer_list tflite_lstm_i_bias = {0.980304, 0.419808, + 0.080278, 0.728548}; + + PrepareBuffers(&model, tflite_lstm_i_bias, /*src_dim_1=*/4, + /*src_dim_2=*/1, large_tf_bias_data, + /*dst_dim_1=*/16, /*dst_dim_2=*/1); + + // Copy starts at (0,), size (4,). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + + // Copy starts at (4,), size (4,). + std::initializer_list tflite_lstm_cell_bias = {0.581674, 0.672433, + 0.434190, 0.844357}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_cell_bias); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/4, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + + // Copy starts at (8,0), size (4,). + std::initializer_list tflite_lstm_forget_bias = {0.229587, 0.785629, + 0.022065, 0.753082}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_forget_bias); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/8, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + + // Copy starts at (12,), size (4,). + std::initializer_list tflite_lstm_output_bias = {0.422080, 0.539481, + 0.878386, 0.168965}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_output_bias); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/12, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + + std::vector expected = {0.980304, 0.419808, 0.080278, 0.728548, + 0.581674, 0.672433, 0.434190, 0.844357, + 0.229587, 0.785629, 0.022065, 0.753082, + 0.422080, 0.539481, 0.878386, 0.168965}; + + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc new file mode 100644 index 0000000000000000000000000000000000000000..da81ea2ff3b4ab0bee0550874a9c4ea1044a3579 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -0,0 +1,172 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +// Unrolls a BatchMatMul on the batch dimension. +// We need to slice each batch out of the inputs, matmul them individually, then +// stack them all back together at the end. +// +// This transform effectively looks like: +// result_slices = [] +// for bat in B: +// slice_a = tf.reshape(tf.slice(a, [bat, 0, 0], [1, M, N]), [M, N]) +// slice_b = tf.reshape(tf.slice(b, [bat, 0, 0], [1, M, N]), [M, N]) +// slice_c = tf.matmul(slice_a, slice_b) +// result_slices[bat] = slice_c +// result = tf.stack(result_slices) +bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { + auto batch_op_it = model->operators.begin() + op_index; + if (batch_op_it->get()->type != OperatorType::kBatchMatMul) { + return false; + } + const auto* batch_op = + static_cast(batch_op_it->get()); + + // We must have the shape of at least one input to know our batch size. + const auto& input_array_a = model->GetArray(batch_op->inputs[0]); + const auto& input_array_b = model->GetArray(batch_op->inputs[1]); + if (!input_array_a.has_shape() || !input_array_b.has_shape()) return false; + + // We only support the rank 3 case. If you are batching on rank > 3 you'll + // have to figure that out. + CHECK_EQ(input_array_a.shape().dimensions_count(), + input_array_b.shape().dimensions_count()) + << "Input dimensions must have the same rank"; + if (input_array_a.shape().dimensions_count() == 2) { + // This is really just a MatMul. This likely means that someone hand-crafted + // a graphdef with a BatchMatMul when they really wanted a MatMul. + AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator", + LogName(*batch_op)); + auto* matmul_op = new TensorFlowMatMulOperator; + matmul_op->inputs = batch_op->inputs; + matmul_op->outputs = batch_op->outputs; + const auto matmul_op_it = model->operators.emplace(batch_op_it, matmul_op); + batch_op_it = matmul_op_it + 1; + CHECK_EQ(batch_op_it->get(), batch_op); + model->operators.erase(batch_op_it); + return true; + } + CHECK_EQ(input_array_a.shape().dimensions_count(), 3) + << "Input arrays must have rank 3"; + + // Perform the matmul for each slice of the batch. + int batch_count = input_array_a.shape().dims(0); + AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op), + batch_count); + auto tail_it = batch_op_it; + std::vector stack_inputs; + for (int batch = 0; batch < batch_count; ++batch) { + std::string batch_name = + std::string(batch_op->outputs[0]) + "_b" + std::to_string(batch); + + // tf.slice(a, ...). + auto* slice_a_op = new SliceOperator; + slice_a_op->inputs = { + batch_op->inputs[0], + CreateInt32Array(model, batch_name + "/slice_a/slice/begin", + {batch, 0, 0}), + CreateInt32Array( + model, batch_name + "/slice_a/slice/size", + {1, input_array_a.shape().dims(1), input_array_a.shape().dims(2)}), + }; + slice_a_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_a")}; + auto& slice_a_op_output = model->GetOrCreateArray(slice_a_op->outputs[0]); + slice_a_op_output.data_type = input_array_a.data_type; + tail_it = model->operators.emplace(tail_it, slice_a_op) + 1; + + // Reshape to remove the first dimension ([1,M,N] -> [M,N]). + auto* slice_a_reshape_op = new TensorFlowReshapeOperator; + slice_a_reshape_op->inputs = { + slice_a_op->outputs[0], + CreateInt32Array(model, batch_name + "/slice_a/reshape/shape", + {-1, input_array_a.shape().dims(2)})}; + slice_a_reshape_op->outputs = { + AvailableArrayName(*model, batch_name + "/slice_a/reshape")}; + auto& slice_a_reshape_op_output = + model->GetOrCreateArray(slice_a_reshape_op->outputs[0]); + slice_a_reshape_op_output.data_type = input_array_a.data_type; + tail_it = model->operators.emplace(tail_it, slice_a_reshape_op) + 1; + + // tf.slice(b, ...). + auto* slice_b_op = new SliceOperator; + slice_b_op->inputs = { + batch_op->inputs[1], + CreateInt32Array(model, batch_name + "/slice_b/slice/begin", {0, 0, 0}), + CreateInt32Array( + model, batch_name + "/slice_b/slice/size", + {1, input_array_b.shape().dims(1), input_array_b.shape().dims(2)}), + }; + slice_b_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_b")}; + auto& slice_b_op_output = model->GetOrCreateArray(slice_b_op->outputs[0]); + slice_b_op_output.data_type = input_array_b.data_type; + tail_it = model->operators.emplace(tail_it, slice_b_op) + 1; + + // Reshape to remove the first dimension ([1,M,N] -> [M,N]). + auto* slice_b_reshape_op = new TensorFlowReshapeOperator; + slice_b_reshape_op->inputs = { + slice_b_op->outputs[0], + CreateInt32Array(model, batch_name + "/slice_b/reshape/shape", + {-1, input_array_b.shape().dims(2)})}; + slice_b_reshape_op->outputs = { + AvailableArrayName(*model, batch_name + "/slice_b/reshape")}; + auto& slice_b_reshape_op_output = + model->GetOrCreateArray(slice_b_reshape_op->outputs[0]); + slice_b_reshape_op_output.data_type = input_array_b.data_type; + tail_it = model->operators.emplace(tail_it, slice_b_reshape_op) + 1; + + // tf.matmul(slice_a, slice_b). + auto* matmul_op = new TensorFlowMatMulOperator; + matmul_op->inputs = {slice_a_reshape_op->outputs[0], + slice_b_reshape_op->outputs[0]}; + matmul_op->outputs = {AvailableArrayName(*model, batch_name)}; + auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]); + matmul_op_output.data_type = input_array_a.data_type; + tail_it = model->operators.emplace(tail_it, matmul_op) + 1; + + // Add to stack. + stack_inputs.push_back(matmul_op->outputs[0]); + } + + // The stack that will join all the individual matmul results together. + auto* stack_op = new StackOperator; + stack_op->inputs = stack_inputs; + stack_op->outputs = {batch_op->outputs[0]}; + stack_op->axis = 0; + model->operators.emplace(tail_it, stack_op); + + // Remove the old batch matmul now that we've unrolled. + batch_op_it = model->operators.begin(); + for (; batch_op_it != model->operators.end(); ++batch_op_it) { + if (batch_op_it->get() == batch_op) { + break; + } + } + CHECK(batch_op_it != model->operators.end()); + CHECK(batch_op_it->get() == batch_op); + model->operators.erase(batch_op_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index ca378af4c5c1e1b8cf42a10d3820db3feeb49a05..9c01b67420603a0d3c0e095dafe6a3359f2514b5 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -21,6 +21,7 @@ limitations under the License. #include "google/protobuf/map.h" #include "google/protobuf/text_format.h" +#include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -173,7 +174,8 @@ void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { } auto& output_float_data = output_array->GetMutableBuffer().data; - output_float_data.resize(input_flat_size); + output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()), + 0.f); if (input_tensor.float_val_size() == 1) { for (int i = 0; i < input_flat_size; i++) { output_float_data[i] = input_tensor.float_val(0); @@ -203,7 +205,7 @@ void ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { } auto& output_int_data = output_array->GetMutableBuffer().data; - output_int_data.resize(input_flat_size); + output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); if (input_tensor.int_val_size()) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); @@ -229,7 +231,7 @@ void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { } auto& output_int_data = output_array->GetMutableBuffer().data; - output_int_data.resize(input_flat_size); + output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); if (input_tensor.int_val_size()) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); @@ -255,7 +257,7 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { } auto& output_int_data = output_array->GetMutableBuffer().data; - output_int_data.resize(input_flat_size); + output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); if (input_tensor.int64_val_size()) { for (int i = 0; i < input_tensor.int64_val_size(); i++) { output_int_data[i] = input_tensor.int64_val(i); @@ -281,7 +283,7 @@ void ImportStringArray(const TensorProto& input_tensor, Array* output_array) { } auto& output_string_data = output_array->GetMutableBuffer().data; - output_string_data.resize(input_flat_size); + output_string_data.resize(RequiredBufferSizeForShape(output_array->shape())); if (input_flat_size != input_tensor.string_val_size()) { LOG(FATAL) << "Input_content string_val doesn't have the right " "dimensions for this string tensor."; @@ -838,6 +840,7 @@ void ConvertSwitchOperator(const NodeDef& node, op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op); } + void ConvertSoftmaxOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -853,6 +856,18 @@ void ConvertSoftmaxOperator(const NodeDef& node, model->operators.emplace_back(softmax); } +void ConvertLogSoftmaxOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "LogSoftmax"); + CheckInputsCount(node, tf_import_flags, 1); + const auto& input_name = node.input(0); + auto* log_softmax = new LogSoftmaxOperator; + log_softmax->inputs.push_back(input_name); + log_softmax->outputs.push_back(node.name()); + model->operators.emplace_back(log_softmax); +} + void ConvertLRNOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -961,49 +976,37 @@ void ConvertReshapeOperator(const NodeDef& node, model->operators.emplace_back(op); } +void ConvertBatchMatMulOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CheckInputsCount(node, tf_import_flags, 2); + + // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions + CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false)); + CHECK(!HasAttr(node, "adj_b") || (GetBoolAttr(node, "adj_b") == false)); + + auto* batch_matmul = new BatchMatMulOperator; + batch_matmul->inputs = {node.input(0), node.input(1)}; + batch_matmul->outputs = {node.name()}; + model->operators.emplace_back(batch_matmul); +} + void ConvertMatMulOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CheckInputsCount(node, tf_import_flags, 2); - if (node.op() == "MatMul") { - // Transpose flags should be easy to support, but we don't have a - // GraphDef with them to test on at the moment. - CHECK_EQ(GetBoolAttr(node, "transpose_a"), false); - CHECK_EQ(GetBoolAttr(node, "transpose_b"), false); - CHECK(!HasAttr(node, "adjoint_a") || - (GetBoolAttr(node, "adjoint_a") == false)); - CHECK(!HasAttr(node, "adjoint_b") || - (GetBoolAttr(node, "adjoint_b") == false)); - } else if (node.op() == "BatchMatMul") { - // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions - CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false)); - CHECK(!HasAttr(node, "adj_b") || (GetBoolAttr(node, "adj_b") == false)); - } else { - LOG(FATAL) << "op must be 'MatMul' or 'BatchMatMul'"; - } - const auto& input_name = node.input(0); - const auto& weights_name = node.input(1); - const auto& reordered_weights_name = weights_name + "_reordered"; - // Check if a ReorderAxesOperator was already created for these weights - // (that happens when multiple layers share the same weights). - const Operator* existing_reorder = - GetOpWithOutput(*model, reordered_weights_name); - if (existing_reorder) { - // Check that it is safe to rely on the _reordered naming of the output - // array! - CHECK(existing_reorder->type == OperatorType::kReorderAxes); - } else { - // Create a new ReorderAxesOperator - auto* reorder = new ReorderAxesOperator; - reorder->inputs = {weights_name}; - reorder->outputs = {reordered_weights_name}; - reorder->input_axes_order = AxesOrder::kRC; - reorder->output_axes_order = AxesOrder::kCR; - model->operators.emplace_back(reorder); - } + // Transpose flags should be easy to support, but we don't have a + // GraphDef with them to test on at the moment. + CHECK_EQ(GetBoolAttr(node, "transpose_a"), false); + CHECK_EQ(GetBoolAttr(node, "transpose_b"), false); + CHECK(!HasAttr(node, "adjoint_a") || + (GetBoolAttr(node, "adjoint_a") == false)); + CHECK(!HasAttr(node, "adjoint_b") || + (GetBoolAttr(node, "adjoint_b") == false)); + auto* matmul = new TensorFlowMatMulOperator; - matmul->inputs = {input_name, reordered_weights_name}; + matmul->inputs = {node.input(0), node.input(1)}; matmul->outputs = {node.name()}; model->operators.emplace_back(matmul); } @@ -1311,6 +1314,12 @@ void ConvertResizeBilinearOperator(const NodeDef& node, CHECK_EQ(node.op(), "ResizeBilinear"); CheckInputsCount(node, tf_import_flags, 2); auto* op = new ResizeBilinearOperator; + + op->align_corners = false; + if (HasAttr(node, "align_corners")) { + op->align_corners = GetBoolAttr(node, "align_corners"); + } + op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); @@ -1452,6 +1461,17 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node, model->operators.emplace_back(op); } +void ConvertExpOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Exp"); + CheckInputsCount(node, tf_import_flags, 1); + auto* op = new ExpOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + void ConvertMeanOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1562,7 +1582,7 @@ void ConvertFloorDivOperator(const NodeDef& node, void ConvertFloorModOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK(node.op() == "FloorMod"); + CHECK_EQ(node.op(), "FloorMod"); CheckInputsCount(node, tf_import_flags, 2); auto* op = new FloorModOperator; op->inputs.push_back(node.input(0)); @@ -1797,6 +1817,37 @@ bool InlineAllFunctions(GraphDef* graphdef) { } return graph_modified; } + +void ConvertTopKV2Operator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK((node.op() == "TopK") || (node.op() == "TopKV2")); + auto op = absl::make_unique(); + op->inputs.push_back(node.input(0)); + // K can be encoded as attr (TopK) convert it to a const. + if (HasAttr(node, "k")) { + // Convert attribute into const tensor. + const string array_name = node.name() + "k"; + auto& array = model->GetOrCreateArray(array_name); + array.data_type = ArrayDataType::kInt32; + // Size of array is always 1. + array.mutable_shape()->mutable_dims()->emplace_back(1); + + auto& output_int_data = + array.GetMutableBuffer().data; + output_int_data.resize(1); + output_int_data[0] = GetIntAttr(node, "k"); + op->inputs.push_back(array_name); + + } else { + CheckInputsCount(node, tf_import_flags, 2); + op->inputs.push_back(node.input(1)); + } + // The op has two outputs. + op->outputs.push_back(node.name() + ":0"); + op->outputs.push_back(node.name() + ":1"); + model->operators.emplace_back(op.release()); +} } // namespace std::unique_ptr ImportTensorFlowGraphDef( @@ -1852,7 +1903,9 @@ std::unique_ptr ImportTensorFlowGraphDef( ConvertAvgPoolOperator(node, tf_import_flags, model); } else if (node.op() == "Reshape") { ConvertReshapeOperator(node, tf_import_flags, model); - } else if (node.op() == "MatMul" || node.op() == "BatchMatMul") { + } else if (node.op() == "BatchMatMul") { + ConvertBatchMatMulOperator(node, tf_import_flags, model); + } else if (node.op() == "MatMul") { ConvertMatMulOperator(node, tf_import_flags, model); } else if (node.op() == "Div" || node.op() == "RealDiv") { ConvertDivOperator(node, tf_import_flags, model); @@ -1891,6 +1944,8 @@ std::unique_ptr ImportTensorFlowGraphDef( ConvertLRNOperator(node, tf_import_flags, model); } else if (node.op() == "Softmax") { ConvertSoftmaxOperator(node, tf_import_flags, model); + } else if (node.op() == "LogSoftmax") { + ConvertLogSoftmaxOperator(node, tf_import_flags, model); } else if (node.op() == "All") { ConvertAllOperator(node, tf_import_flags, model); } else if (node.op() == "Assert") { @@ -1974,6 +2029,10 @@ std::unique_ptr ImportTensorFlowGraphDef( ConvertTransposeOperator(node, tf_import_flags, model); } else if (node.op() == "ArgMax") { ConvertArgMaxOperator(node, tf_import_flags, model); + } else if (node.op() == "Exp") { + ConvertExpOperator(node, tf_import_flags, model); + } else if (node.op() == "TopK" || node.op() == "TopKV2") { + ConvertTopKV2Operator(node, tf_import_flags, model); } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index d1af371fd4c43d7059bfd70597ea765c9c2e51fd..c55bf664f8b65c6eb53ff9ae926bed11adc7b183 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ #define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#include #include #include #include @@ -34,6 +35,7 @@ enum class OperatorType { kAdd, kAddN, kAveragePool, + kBatchMatMul, kBatchNormalization, kConv, kConcatenation, @@ -42,6 +44,7 @@ enum class OperatorType { kSpaceToDepth, kDequantize, kDiv, + kExp, kExpandDims, kFill, kFloorDiv, @@ -61,6 +64,7 @@ enum class OperatorType { kRelu1, kRelu6, kSoftmax, + kLogSoftmax, kSub, kTanh, kTransposeConv, @@ -110,6 +114,7 @@ enum class OperatorType { kTensorFlowSwitch, kTensorFlowTile, kTranspose, + kTopK_V2, // An unsupported TF operation. It's only needed to be able to represent TF // graph internally and is expected to be dropped by graph transformations. kTensorFlowUnsupported, @@ -155,12 +160,17 @@ enum class AxesOrder { // may be involved only in debug-only subgraphs that we may not be interested // in actually supporting). enum class ArrayDataType { - kNone, + kNone, // 0 kBool, kFloat, + kInt8, kUint8, + kInt16, // 5 + kUint16, kInt32, + kUint32, kInt64, + kUint64, // 10 kString }; @@ -180,18 +190,38 @@ struct DataTypeImpl { typedef float Type; }; template <> +struct DataTypeImpl { + typedef int8 Type; +}; +template <> struct DataTypeImpl { typedef uint8 Type; }; template <> +struct DataTypeImpl { + typedef int16 Type; +}; +template <> +struct DataTypeImpl { + typedef uint16 Type; +}; +template <> struct DataTypeImpl { typedef int32 Type; }; template <> +struct DataTypeImpl { + typedef uint32 Type; +}; +template <> struct DataTypeImpl { typedef int64 Type; }; template <> +struct DataTypeImpl { + typedef uint64 Type; +}; +template <> struct DataTypeImpl { typedef string Type; }; @@ -711,6 +741,19 @@ struct TensorFlowIdentityOperator : Operator { TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {} }; +// Batch matrix multiplication operator. This comes from the (deprecated) +// tf.batch_matmul or a tf.matmul that has rank 3. dims(0) is the batch count +// and it can be trivially unrolled into a series of matmuls on each element. +// +// Inputs: +// inputs[0]: required: the left-hand side matrix +// inputs[1]: required: the right-hand side matrix +// +// TensorFlow equivalent: MatMul +struct BatchMatMulOperator : Operator { + BatchMatMulOperator() : Operator(OperatorType::kBatchMatMul) {} +}; + // General matrix multiplication operator. We don't want to support general // matrix multiplication at inference time, so we resolve it during tooling // to more specific operator types, namely, FullyConnected. @@ -811,6 +854,17 @@ struct TransposeConvOperator : Operator { int stride_height = 0; }; +// Given a tensor input, this operation calculates element-wise exponential +// (y = e^x). +// +// Inputs: +// inputs[0]: required: input tensor +// +// TensorFlow equivalent: Exp +struct ExpOperator : Operator { + ExpOperator() : Operator(OperatorType::kExp) {} +}; + // Given a tensor input, this operation inserts a dimension of 1 at the // dimension index axis of input's shape. The dimension index axis starts at // zero; if you specify a negative number for axis it is counted backward from @@ -1215,6 +1269,16 @@ struct SoftmaxOperator : Operator { float beta = 0.f; }; +// LogSoftmax activation function. +// +// Inputs: +// inputs[0]: required: the logits input array +// +// TensorFlow equivalent: LogSoftmax +struct LogSoftmaxOperator : Operator { + LogSoftmaxOperator() : Operator(OperatorType::kLogSoftmax) {} +}; + // Cast operator. // // Inputs: @@ -1272,6 +1336,8 @@ struct ArgMaxOperator : Operator { // TensorFlow equivalent: ResizeBilinear struct ResizeBilinearOperator : Operator { ResizeBilinearOperator() : Operator(OperatorType::kResizeBilinear) {} + + bool align_corners = false; }; // SpaceToBatchND operator. It divides spatial dimensions into a grid of @@ -1335,6 +1401,14 @@ struct SvdfOperator : Operator { int rank; }; +// TopKV2 operator. +// +// Inputs: +// input tensor and top_k scalar. +struct TopKV2Operator : Operator { + TopKV2Operator() : Operator(OperatorType::kTopK_V2) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are @@ -1541,7 +1615,7 @@ class Model { bool HasArray(const string& name) const { return arrays.count(name) > 0; } Array& GetArray(const string& name) const { - DCHECK(HasArray(name)); + DCHECK(HasArray(name)) << "Array not found: " << name; return *arrays.at(name); } Array& GetOrCreateArray(const string& name) { diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 790b3443cef1c577e19bafc5e087ca42e6fce60a..4e2dec15a534607ef9207149a2e6061069eabcb1 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -148,6 +148,12 @@ bool ParseModelFlagsFromCommandLineFlags( "ranging from 32 to 127. This is disallowed by default so as to " "catch common copy-and-paste issues where invisible unicode " "characters are unwittingly added to these strings."), + Flag( + "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(), + parsed_flags.arrays_extra_info_file.default_value(), + "Path to an optional file containing a serialized ArraysExtraInfo " + "proto allowing to pass extra information about arrays not specified " + "in the input model file, such as extra MinMax information."), }; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); @@ -327,9 +333,6 @@ void ReadModelFlagsFromCommandLineFlags( CHECK(absl::SimpleAtoi(value, &size)); CHECK_GT(size, 0); rnn_state_proto->set_size(size); - } else if (key == "manually_create") { - CHECK_EQ(absl::AsciiStrToLower(value), "true"); - rnn_state_proto->set_manually_create(true); } else { LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states"; } @@ -368,6 +371,15 @@ void ReadModelFlagsFromCommandLineFlags( parsed_model_flags.allow_nonascii_arrays.value()); model_flags->set_allow_nonexistent_arrays( parsed_model_flags.allow_nonexistent_arrays.value()); + + if (parsed_model_flags.arrays_extra_info_file.specified()) { + string arrays_extra_info_file_contents; + port::file::GetContents(parsed_model_flags.arrays_extra_info_file.value(), + &arrays_extra_info_file_contents, + port::file::Defaults()); + ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents, + model_flags->mutable_arrays_extra_info()); + } } ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) { diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto index 13fea29a07ed9ea75ebe1b9b046f2a68d814c649..e4b39b34e85e4d703c1b41cb68f8139abd1f6279 100644 --- a/tensorflow/contrib/lite/toco/model_flags.proto +++ b/tensorflow/contrib/lite/toco/model_flags.proto @@ -81,19 +81,26 @@ message RnnState { optional string state_array = 1; optional string back_edge_source_array = 2; optional bool discardable = 5; - // TODO(benoitjacob): drop the 'size' field. Should be redundant with - // --input_shapes and shapes propagation. + // size allows to specify a 1-D shape for the RNN state array. + // Will be expanded with 1's to fit the model. + // TODO(benoitjacob): should allow a generic, explicit shape. optional int32 size = 3; - // TODO(benoitjacob): manually_create is a temporary hack: - // due to discrepancies between the current toco dims tracking and - // TensorFlow shapes, for some models we need to manually create RNN state - // arrays with a specified shape. - // Maybe we should actually implement back-edges as operators of their own, - // which would remove the need for much special-casing, including here, - // we could probably consistently let PropagateFixedSizes handle state - // arrays. - // TODO(benoitjacob): should really drop manually_create now. - optional bool manually_create = 4; +} + +// An ArraysExtraInfo message stores a collection of additional Information +// about arrays in a model, complementing the information in the model itself. +// It is intentionally a separate message so that it may be serialized and +// passed separately from the model. See --arrays_extra_info_file. +// +// A typical use case is to manually specify MinMax for specific arrays in a +// model that does not already contain such MinMax information. +message ArraysExtraInfo { + message Entry { + optional string name = 1; + optional float min = 2; + optional float max = 3; + } + repeated Entry entries = 1; } // ModelFlags encodes properties of a model that, depending on the file @@ -117,7 +124,7 @@ message RnnState { // optional int32 input_dims = 11 [ default = 4]; // repeated int32 input_shape = 13; // -// Next ID to USE: 18. +// Next ID to USE: 19. message ModelFlags { // Information about the input arrays, i.e. the arrays from which input // activations will be read. @@ -160,4 +167,8 @@ message ModelFlags { // catch common copy-and-paste issues where invisible unicode // characters are unwittingly added to these strings. optional bool allow_nonascii_arrays = 17; + + // If set, this ArraysExtraInfo allows to pass extra information about arrays + // not specified in the input model file, such as extra MinMax information. + optional ArraysExtraInfo arrays_extra_info = 18; } diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc index 664e828c19dca1117b81113f723416541f48d621..646d048496c27955aa641fd01a35d8acfbd8dd90 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc @@ -103,11 +103,11 @@ class ResolveSvdfTest : public ::testing::Test { // Add the float vector as an attribute to the node. (*node->mutable_attr())["dtype"].set_type(tensorflow::DT_FLOAT); tensorflow::TensorProto* allocated_tensor = new tensorflow::TensorProto; - tensorflow::TensorShapeProto* allocated_tesnor_shape = + tensorflow::TensorShapeProto* allocated_tensor_shape = new tensorflow::TensorShapeProto; - auto tensor_shape_dim0 = allocated_tesnor_shape->add_dim(); + auto tensor_shape_dim0 = allocated_tensor_shape->add_dim(); tensor_shape_dim0->set_size(values.size()); - allocated_tensor->set_allocated_tensor_shape(allocated_tesnor_shape); + allocated_tensor->set_allocated_tensor_shape(allocated_tensor_shape); allocated_tensor->set_tensor_content( string(reinterpret_cast(values.data()), values.size() * sizeof(float))); @@ -122,11 +122,11 @@ class ResolveSvdfTest : public ::testing::Test { // Add the float vector as an attribute to the node. (*node->mutable_attr())["dtype"].set_type(tensorflow::DT_INT32); tensorflow::TensorProto* allocated_tensor = new tensorflow::TensorProto; - tensorflow::TensorShapeProto* allocated_tesnor_shape = + tensorflow::TensorShapeProto* allocated_tensor_shape = new tensorflow::TensorShapeProto; - auto tensor_shape_dim0 = allocated_tesnor_shape->add_dim(); + auto tensor_shape_dim0 = allocated_tensor_shape->add_dim(); tensor_shape_dim0->set_size(values.size()); - allocated_tensor->set_allocated_tensor_shape(allocated_tesnor_shape); + allocated_tensor->set_allocated_tensor_shape(allocated_tensor_shape); allocated_tensor->set_tensor_content( string(reinterpret_cast(values.data()), values.size() * sizeof(int))); diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.cc b/tensorflow/contrib/lite/toco/tensorflow_util.cc index 82e2800ca2f5bb017f91b5bf43d8d3cd05e97b83..0e7e9c41a066581b14fe1b78f83d8d57b916be6c 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_util.cc +++ b/tensorflow/contrib/lite/toco/tensorflow_util.cc @@ -51,7 +51,8 @@ void LogDumpGraphDef(int log_level, const string& message, BEGIN DUMP OF TENSORFLOW GRAPHDEF (%s) There are %d nodes. There are %zu different op types: -)MSG", message, tf_graph.node_size(), ops.size()); +)MSG", + message, tf_graph.node_size(), ops.size()); for (const auto& op : ops) { toco::port::AppendF(&dump, " %s\n", op); } @@ -63,7 +64,8 @@ PROTO DUMP BEGIN NODE: name = %s op = %s inputs = [ -)MSG", node.name(), node.op()); +)MSG", + node.name(), node.op()); for (const auto& input : node.input()) { toco::port::AppendF(&dump, " %s\n", input); } diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index 72c926656449da981abf6c11c03cd7c00a634ce7..a2b8145a67278c3ac0065f9551da6ffd1de60772 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -117,6 +117,7 @@ cc_library( ":types", "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", "@flatbuffers", ], ) diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 391ef87029d019ab52af2716f72883f5f82f94d9..27719599708a7eb14f72a82f8e5d76b3b8af9dc4 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -26,6 +26,9 @@ namespace toco { namespace tflite { +using flatbuffers::FlatBufferBuilder; +using flatbuffers::Offset; +using flatbuffers::Vector; using ::tflite::Buffer; using ::tflite::BuiltinOperator; using ::tflite::BuiltinOperator_CUSTOM; @@ -39,9 +42,6 @@ using ::tflite::Operator; using ::tflite::OperatorCode; using ::tflite::SubGraph; using ::tflite::Tensor; -using flatbuffers::FlatBufferBuilder; -using flatbuffers::Offset; -using flatbuffers::Vector; namespace { diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index bbf201fd288140d990b8f739adcd9244e1196072..5b1ab514b23248cd98e66847185d0e8b9fe2d6aa 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/toco/tflite/operator.h" #include "tensorflow/contrib/lite/toco/tflite/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" namespace toco { @@ -119,8 +120,16 @@ void ImportOperators( auto inputs = input_op->inputs(); for (int i = 0; i < inputs->Length(); i++) { auto input_index = inputs->Get(i); - const string& input_name = tensors_table.at(input_index); - op->inputs.push_back(input_name); + // input_index == -1 indicates optional tensor. + if (input_index != -1) { + const string& input_name = tensors_table.at(input_index); + op->inputs.push_back(input_name); + } else { + const string& tensor_name = + toco::AvailableArrayName(*model, "OptionalTensor"); + model->CreateOptionalArray(tensor_name); + op->inputs.push_back(tensor_name); + } } auto outputs = input_op->outputs(); for (int i = 0; i < outputs->Length(); i++) { diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 298f49025f9dc8b636dc76a04b8e2e5f11d27db7..aabc7c5109ddc205a7862c3ee2253390dae25095 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -140,25 +140,11 @@ class SpaceToBatchND flatbuffers::Offset WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - auto block_shape = builder->CreateVector(op.block_shape); - auto before_paddings = builder->CreateVector(op.before_paddings); - auto after_paddings = builder->CreateVector(op.after_paddings); - return ::tflite::CreateSpaceToBatchNDOptions( - *builder, block_shape, before_paddings, after_paddings); + return ::tflite::CreateSpaceToBatchNDOptions(*builder); } void ReadOptions(const TfLiteOptions& options, - TocoOperator* op) const override { - op->block_shape.insert(op->block_shape.end(), - options.block_shape()->begin(), - options.block_shape()->end()); - op->before_paddings.insert(op->before_paddings.end(), - options.before_paddings()->begin(), - options.before_paddings()->end()); - op->after_paddings.insert(op->after_paddings.end(), - options.after_paddings()->begin(), - options.after_paddings()->end()); - } + TocoOperator* op) const override {} }; class Sub : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - auto block_shape = builder->CreateVector(op.block_shape); - auto before_crops = builder->CreateVector(op.before_crops); - auto after_crops = builder->CreateVector(op.after_crops); - return ::tflite::CreateBatchToSpaceNDOptions(*builder, block_shape, - before_crops, after_crops); + return ::tflite::CreateBatchToSpaceNDOptions(*builder); } void ReadOptions(const TfLiteOptions& options, - TocoOperator* op) const override { - op->block_shape.insert(op->block_shape.end(), - options.block_shape()->begin(), - options.block_shape()->end()); - op->before_crops.insert(op->before_crops.end(), - options.before_crops()->begin(), - options.before_crops()->end()); - op->after_crops.insert(op->after_crops.end(), - options.after_crops()->begin(), - options.after_crops()->end()); - } + TocoOperator* op) const override {} }; class Cast : public CustomOperator { @@ -478,8 +450,7 @@ class Pad : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateTransposeOptions(*builder, - builder->CreateVector(op.perm)); + return ::tflite::CreateTransposeOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} +}; + +class Lstm : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + // Current toco converter only supports tanh, no clip. + return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/ + ::tflite::ActivationFunctionType_TANH, + /*cell_clip=*/0.0, + /*proj_clip=*/0.0); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - op->perm.insert(op->perm.end(), options.perm()->begin(), - options.perm()->end()); + // Only support tanh activation, so check that tflite type is tanh. + CHECK(options.fused_activation_function() == + ::tflite::ActivationFunctionType_TANH); } }; @@ -564,18 +553,33 @@ class Mean : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - auto axis = builder->CreateVector(op.axis); - return ::tflite::CreateMeanOptions(*builder, axis, op.keep_dims); + return ::tflite::CreateMeanOptions(*builder, op.keep_dims); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - op->axis.insert(op->axis.end(), options.axis()->begin(), - options.axis()->end()); op->keep_dims = options.keep_dims(); } }; +class ResizeBilinear + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->align_corners = options.align_corners(); + } +}; + class Squeeze : public BuiltinOperator { @@ -597,15 +601,21 @@ class Squeeze } }; -class Split : public CustomOperator { +class Split + : public BuiltinOperator { public: - using CustomOperator::CustomOperator; - void WriteOptions(const TocoOperator& op, - flexbuffers::Builder* fbb) const override { - fbb->Int("num_split", op.num_split); + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSplitOptions(*builder, op.num_split); } - void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { - op->num_split = m["num_split"].AsInt64(); + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->num_split = options.num_splits(); } }; @@ -633,6 +643,20 @@ class StridedSlice } }; +class TopK_V2 : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateTopKV2Options(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -791,17 +815,24 @@ std::vector> BuildOperatorList() { OperatorType::kTranspose)); ops.emplace_back( new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); + ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, + OperatorType::kResizeBilinear)); ops.emplace_back( new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); + ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT, + OperatorType::kTensorFlowSplit)); ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice)); + ops.emplace_back( + new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2)); + ops.emplace_back( + new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell)); // Custom Operators. ops.emplace_back(new Cast("CAST", OperatorType::kCast)); ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); - ops.emplace_back(new Split("SPLIT", OperatorType::kTensorFlowSplit)); ops.emplace_back(new TensorFlowUnsupported( "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported)); @@ -823,12 +854,11 @@ std::vector> BuildOperatorList() { new SimpleOperator("RELU_N1_TO_1", OperatorType::kRelu1)); ops.emplace_back( new SimpleOperator("RELU6", OperatorType::kRelu6)); - ops.emplace_back(new SimpleOperator( - "RESIZE_BILINEAR", OperatorType::kResizeBilinear)); ops.emplace_back(new SimpleOperator( "LOGISTIC", OperatorType::kLogistic)); ops.emplace_back( new SimpleOperator("TANH", OperatorType::kTanh)); + ops.emplace_back(new SimpleOperator("EXP", OperatorType::kExp)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 9036a16d1c928702a71ccbe3fdad826fb037fcaf..5c486f72ade9ec5f366f075fcc39274bb7b12679 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -104,10 +104,9 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("RELU", OperatorType::kRelu); CheckSimpleOperator("RELU_N1_TO_1", OperatorType::kRelu1); CheckSimpleOperator("RELU6", OperatorType::kRelu6); - CheckSimpleOperator("RESIZE_BILINEAR", - OperatorType::kResizeBilinear); CheckSimpleOperator("LOGISTIC", OperatorType::kLogistic); CheckSimpleOperator("TANH", OperatorType::kTanh); + CheckSimpleOperator("EXP", OperatorType::kExp); } TEST_F(OperatorTest, BuiltinAdd) { @@ -119,40 +118,12 @@ TEST_F(OperatorTest, BuiltinAdd) { output_toco_op->fused_activation_function); } -TEST_F(OperatorTest, BuiltinSpaceToBatchND) { - SpaceToBatchNDOperator op; - op.block_shape = {2, 2}; - op.before_paddings = {1, 2}; - op.after_paddings = {3, 4}; - - auto output_toco_op = SerializeAndDeserialize( - GetOperator("SPACE_TO_BATCH_ND", OperatorType::kSpaceToBatchND), op); - EXPECT_EQ(op.block_shape, output_toco_op->block_shape); - EXPECT_EQ(op.before_paddings, output_toco_op->before_paddings); - EXPECT_EQ(op.after_paddings, output_toco_op->after_paddings); -} - -TEST_F(OperatorTest, BuiltinBatchToSpaceND) { - BatchToSpaceNDOperator op; - op.block_shape = {2, 2}; - op.before_crops = {1, 2}; - op.after_crops = {3, 4}; - - auto output_toco_op = SerializeAndDeserialize( - GetOperator("BATCH_TO_SPACE_ND", OperatorType::kBatchToSpaceND), op); - EXPECT_EQ(op.block_shape, output_toco_op->block_shape); - EXPECT_EQ(op.before_crops, output_toco_op->before_crops); - EXPECT_EQ(op.after_crops, output_toco_op->after_crops); -} - TEST_F(OperatorTest, BuiltinMean) { MeanOperator op; - op.axis = {1, 2}; op.keep_dims = false; auto output_toco_op = SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op); - EXPECT_EQ(op.axis, output_toco_op->axis); EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims); } @@ -359,6 +330,14 @@ TEST_F(OperatorTest, BuiltinMul) { output_toco_op->fused_activation_function); } +TEST_F(OperatorTest, ResizeBilinear) { + ResizeBilinearOperator op; + op.align_corners = true; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op); + EXPECT_EQ(op.align_corners, output_toco_op->align_corners); +} + TEST_F(OperatorTest, Svdf) { SvdfOperator op; op.fused_activation_function = FusedActivationFunctionType::kRelu; @@ -370,15 +349,6 @@ TEST_F(OperatorTest, Svdf) { EXPECT_EQ(op.rank, output_toco_op->rank); } -TEST_F(OperatorTest, Transpose) { - TransposeOperator op; - op.perm = {0, 1, 2, 3}; - - auto output_toco_op = SerializeAndDeserialize( - GetOperator("TRANSPOSE", OperatorType::kTranspose), op); - EXPECT_EQ(op.perm, output_toco_op->perm); -} - TEST_F(OperatorTest, Squeeze) { SqueezeOperator op; op.squeeze_dims = {-2, -3, 4, 1, 4}; @@ -410,6 +380,13 @@ TEST_F(OperatorTest, StridedSlice) { EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask); } +TEST_F(OperatorTest, BuiltinTopKV2) { + TopKV2Operator op; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TOPK_V2", OperatorType::kTopK_V2), op); + ASSERT_NE(nullptr, output_toco_op.get()); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index f8281f3a5725283d472e5e1a36e4d904b4dc1c49..c5a62fdb620ee7d6b7195f6e8e2bc3cb208feb10 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -44,9 +44,11 @@ bool ParseTocoFlagsFromCommandLineFlags( "For Protobuf formats, the binary format will be used."), Flag("input_format", parsed_flags.input_format.bind(), parsed_flags.input_format.default_value(), - "Input file format. One of: tensorflow_graphdef, "), + "Input file format. One of: TENSORFLOW_GRAPHDEF, TFLITE."), Flag("output_format", parsed_flags.output_format.bind(), - parsed_flags.output_format.default_value(), "Output file format."), + parsed_flags.output_format.default_value(), + "Output file format. " + "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."), Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), parsed_flags.default_ranges_min.default_value(), "If defined, will be used as the default value for the min bound " @@ -58,11 +60,13 @@ bool ParseTocoFlagsFromCommandLineFlags( Flag("inference_type", parsed_flags.inference_type.bind(), parsed_flags.inference_type.default_value(), "Target data type of arrays in the output file (for input_arrays, " - "this may be overridden by inference_input_type)."), + "this may be overridden by inference_input_type). " + "One of FLOAT, QUANTIZED_UINT8."), Flag("inference_input_type", parsed_flags.inference_input_type.bind(), parsed_flags.inference_input_type.default_value(), - "Target data type of input arrays. If not specified, inference_type " - "is used."), + "Target data type of input arrays. " + "If not specified, inference_type is used. " + "One of FLOAT, QUANTIZED_UINT8."), Flag("input_type", parsed_flags.input_type.bind(), parsed_flags.input_type.default_value(), "Deprecated ambiguous flag that set both --input_data_types and " @@ -76,35 +80,31 @@ bool ParseTocoFlagsFromCommandLineFlags( Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(), parsed_flags.drop_fake_quant.default_value(), - "Ignore and discard FakeQuant nodes. For instance, that can be used " - "to " + "Ignore and discard FakeQuant nodes. For instance, to " "generate plain float code without fake-quantization from a " - "quantized " - "graph."), + "quantized graph."), Flag( "reorder_across_fake_quant", parsed_flags.reorder_across_fake_quant.bind(), parsed_flags.reorder_across_fake_quant.default_value(), "Normally, FakeQuant nodes must be strict boundaries for graph " "transformations, in order to ensure that quantized inference has " - "the " - "exact same arithmetic behavior as quantized training --- which is " - "the " - "whole point of quantized training and of FakeQuant nodes in the " - "first " - "place. However, that entails subtle requirements on where exactly " + "the exact same arithmetic behavior as quantized training --- which " + "is the whole point of quantized training and of FakeQuant nodes in " + "the first place. " + "However, that entails subtle requirements on where exactly " "FakeQuant nodes must be placed in the graph. Some quantized graphs " "have FakeQuant nodes at unexpected locations, that prevent graph " "transformations that are necessary in order to generate inference " "code for these graphs. Such graphs should be fixed, but as a " "temporary work-around, setting this reorder_across_fake_quant flag " - "allows toco to perform necessary graph transformaitons on them, " + "allows TOCO to perform necessary graph transformaitons on them, " "at the cost of no longer faithfully matching inference and training " "arithmetic."), Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(), parsed_flags.allow_custom_ops.default_value(), - "If true, allow TOCO to create TF Lite Custom operators for all the" - "unsupported Tensorflow ops."), + "If true, allow TOCO to create TF Lite Custom operators for all the " + "unsupported TensorFlow ops."), Flag( "drop_control_dependency", parsed_flags.drop_control_dependency.bind(), diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h index 0572848cb5a998457cd669a2b0bce5fe8a0e15a2..4be3b5a0bf00ed204a1218545d9e66f7685a50d7 100644 --- a/tensorflow/contrib/lite/toco/toco_port.h +++ b/tensorflow/contrib/lite/toco/toco_port.h @@ -19,6 +19,7 @@ limitations under the License. // can build and use on google internal environments and on OSX. #include +#include "google/protobuf/text_format.h" #include "tensorflow/contrib/lite/toco/format_port.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/platform.h" @@ -75,6 +76,26 @@ void CopyToBuffer(const ::Cord& src, char* dest); #endif // PLATFORM_GOOGLE void CopyToBuffer(const string& src, char* dest); } // namespace port + +inline bool ParseFromStringOverload(const std::string& in, + TFLITE_PROTO_NS::Message* proto) { + return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto); +} + +template +bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents, + Proto* proto) { + if (proto->ParseFromString(input_file_contents)) { + return true; + } + + if (ParseFromStringOverload(input_file_contents, proto)) { + return true; + } + + return false; +} + } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 720c33777d707994c6e1003bb1210eadd96bc8a8..1b836fbc151db2141ad64d5370f15a43246fdd8b 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -53,33 +53,40 @@ void MakeGeneralGraphTransformationsSet( CHECK(transformations->empty()); transformations->Add(new ConvertExpandDimsToReshape); transformations->Add(new ConvertTrivialAddNToAdd); + transformations->Add(new ConvertTrivialStackToReshape); transformations->Add(new ConvertTrivialTransposeToReshape); transformations->Add(new ConvertReorderAxes); transformations->Add(new ResolveReshapeAttributes); + transformations->Add(new ResolveTransposeAttributes); transformations->Add(new PropagateArrayDataTypes); transformations->Add(new PropagateFixedSizes); transformations->Add(new RemoveTensorFlowAssert); transformations->Add(new RemoveTensorFlowIdentity); transformations->Add(new RemoveTrivialConcatenation); transformations->Add(new RemoveTrivialConcatenationInput); + transformations->Add(new RemoveTrivialSlice); transformations->Add(new RemoveUnusedOp); transformations->Add(new EnsureBiasVectors); transformations->Add(new ResolveReorderAxes); + transformations->Add(new UnrollBatchMatMul); transformations->Add(new ResolveTensorFlowMatMul); transformations->Add(new FuseBinaryIntoPrecedingAffine); transformations->Add(new FuseBinaryIntoFollowingAffine); + transformations->Add(new ReorderActivationFunctions); transformations->Add(new ResolveBatchNormalization); transformations->Add(new ResolveConstantBinaryOperator); transformations->Add(new ResolveConstantFill); transformations->Add(new ResolveConstantRange); transformations->Add(new ResolveConstantStack); transformations->Add(new ResolveConstantStridedSlice); + transformations->Add(new ResolveConstantTranspose); transformations->Add(new ResolveConstantUnaryOperator); transformations->Add(new ResolveTensorFlowMerge); transformations->Add(new ResolveSqueezeAttributes); transformations->Add(new ResolveTensorFlowSwitch); transformations->Add(new ResolveTensorFlowTile); transformations->Add(new ResolveTensorFlowConcat); + transformations->Add(new ResolveMultiplyByZero); transformations->Add(new IdentifyL2Normalization); transformations->Add(new IdentifyL2Pool); transformations->Add(new IdentifyRelu1); @@ -91,9 +98,9 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveStridedSliceAttributes); transformations->Add(new ResolveSliceAttributes); transformations->Add(new ResolveMeanAttributes); - transformations->Add(new ResolveTransposeAttributes); transformations->Add(new ResolveConstantShapeOrRank); transformations->Add(new MakeInitialDequantizeOperator); + transformations->Add(new ResolveConstantFakeQuant); } bool SupportsQuantization(FileFormat format) { @@ -105,7 +112,8 @@ bool SupportsFusedActivationFunction(FileFormat format) { } bool SupportsLstmCell(FileFormat format) { - return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT); + return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT || + format == TFLITE); } bool SupportsPreallocatedWorkspace(FileFormat format) { @@ -181,6 +189,11 @@ std::unique_ptr Import(const TocoFlags& toco_flags, } void Transform(const TocoFlags& toco_flags, Model* model) { + // Clean up after import. + SetFinalDataTypeOnInputs(toco_flags, model); + UseArraysExtraInfo(model); + FinishBuildingRNNStates(model); + const FileFormat output_format = toco_flags.output_format(); const IODataType inference_type = toco_flags.inference_type(); @@ -192,8 +205,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) { << "Quantized inference is not allowed with float inputs."; } - SetFinalDataTypeOnInputs(toco_flags, model); - // Remove unused ops before performing any other optimizations. This is to // stop optimizations from crossing the input/output boundaries. For example // this will stop BatchNorm fusing if the output node is in between a conv @@ -210,9 +221,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) { } else { transformations.Add(new UnfuseActivationFunctions); } - if (output_format != TENSORFLOW_GRAPHDEF) { - transformations.Add(new ResolveConstantFakeQuant); - } if (toco_flags.drop_fake_quant()) { transformations.Add(new DropFakeQuant); } else { @@ -225,13 +233,18 @@ void Transform(const TocoFlags& toco_flags, Model* model) { } } transformations.Add(new ConvertPureConvToDepthwise); - // TFLite export does not yet support fused LSTM cell. if (SupportsLstmCell(output_format)) { transformations.Add(new IdentifyLstmCell); + if (output_format == TFLITE) { + transformations.Add(new toco::SplitLstmCellInputs); + } else { + transformations.Add(new toco::MergeLstmCellInputs); + } } transformations.Add(new ResolveConstantConcatenation); RunGraphTransformations(model, "general graph transformations", transformations); + if (quantize_output) { RunGraphTransformations(model, "pre-quantization graph transformations", {new HardcodeMinMax, new DropFakeQuant}); @@ -264,6 +277,10 @@ void Transform(const TocoFlags& toco_flags, Model* model) { dequantization_transformations); } + if (output_format == TENSORFLOW_GRAPHDEF) { + EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model); + } + LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model); if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) { diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 6577bb778184ef774c4102aa0a22153a428d5c61..dcb409c84d8d80790f3c5e41a6eb7bce1b1efd2e 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" #include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" @@ -33,6 +34,24 @@ limitations under the License. namespace toco { +// Find the longest common prefix of two strings. +absl::string_view FindLongestCommonPrefix(absl::string_view a, + absl::string_view b) { + if (a.empty() || b.empty()) return absl::string_view(); + + const char* pa = a.data(); + const char* pb = b.data(); + size_t count = 0; + const size_t limit = std::min(a.size(), b.size()); + while (count < limit && *pa == *pb) { + ++pa; + ++pb; + ++count; + } + + return absl::string_view(a.data(), count); +} + string LogName(const Operator& op) { const string& opname = HelpfulOperatorTypeName(op); if (op.outputs.empty()) { @@ -92,7 +111,17 @@ int CountOpsWithInput(const Model& model, const string& array_name) { } bool DeleteArrayIfUnused(const string& array_name, Model* model) { - if (CountOpsWithInput(*model, array_name) == 0) { + if (IsDiscardableArray(*model, array_name) && + CountOpsWithInput(*model, array_name) == 0) { + model->EraseArray(array_name); + return true; + } + return false; +} + +bool DeleteArrayIfUsedOnce(const string& array_name, Model* model) { + if (IsDiscardableArray(*model, array_name) && + CountOpsWithInput(*model, array_name) == 1) { model->EraseArray(array_name); return true; } @@ -141,6 +170,18 @@ std::vector>::const_iterator FindOpWithInput( return model.operators.end(); } +std::vector>::iterator FindOpWithInput( + Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& input : it->get()->inputs) { + if (input == array_name) { + return it; + } + } + } + return model.operators.end(); +} + std::vector>::const_iterator FindOp( const Model& model, const Operator* op) { for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { @@ -199,6 +240,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Add) HANDLE_OPERATORTYPENAME_CASE(AddN) HANDLE_OPERATORTYPENAME_CASE(AveragePool) + HANDLE_OPERATORTYPENAME_CASE(BatchMatMul) HANDLE_OPERATORTYPENAME_CASE(BatchNormalization) HANDLE_OPERATORTYPENAME_CASE(Conv) HANDLE_OPERATORTYPENAME_CASE(Concatenation) @@ -220,6 +262,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Relu6) HANDLE_OPERATORTYPENAME_CASE(ReorderAxes) HANDLE_OPERATORTYPENAME_CASE(Softmax) + HANDLE_OPERATORTYPENAME_CASE(LogSoftmax) HANDLE_OPERATORTYPENAME_CASE(Div) HANDLE_OPERATORTYPENAME_CASE(Tanh) HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll) @@ -270,7 +313,9 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Mean) HANDLE_OPERATORTYPENAME_CASE(Svdf) HANDLE_OPERATORTYPENAME_CASE(ArgMax) + HANDLE_OPERATORTYPENAME_CASE(TopK_V2) HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported) + HANDLE_OPERATORTYPENAME_CASE(Exp) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -286,6 +331,20 @@ string HelpfulOperatorTypeName(const Operator& op) { return OperatorTypeName(op.type); } +bool OperatorSupportsFusedActivation(OperatorType type) { + switch (type) { + case OperatorType::kConcatenation: + case OperatorType::kGather: + case OperatorType::kSlice: + case OperatorType::kSqueeze: + case OperatorType::kTensorFlowReshape: + case OperatorType::kTensorFlowSplit: + return false; + default: + return true; + } +} + void LogSummary(int log_level, const Model& model) { VLOG(log_level) << "Operators summary (" << model.operators.size() << " operators):"; @@ -375,7 +434,7 @@ void LogArray(int log_level, const Model& model, const string& name) { } if (array.quantization_params) { VLOG(log_level) << " QuantizationParams: zero_point=" - << array.quantization_params->zero_point + << static_cast(array.quantization_params->zero_point) << ", scale=" << array.quantization_params->scale; } } @@ -575,6 +634,14 @@ bool IsConstantParameterArray(const Model& model, const string& name) { } namespace { +// Take an array name, which may be something like "name:3_5" and make it +// acceptable as a TF node name, say "name_3_5"; +string SanitizeNameForTFNode(const string& array_name) { + auto node_name = array_name; + std::replace(node_name.begin(), node_name.end(), ':', '_'); + return node_name; +} + void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) { for (const auto& input_array : model_flags.input_arrays()) { for (const string& output_array : model_flags.output_arrays()) { @@ -654,12 +721,10 @@ void CheckNoMissingArray(const Model& model) { for (const auto& op : model.operators) { for (const auto& input : op->inputs) { CHECK(model.HasArray(input) || model.optional_arrays.count(input)) - << "Input: " << input << " missing for op: " - << op->outputs[0] << "."; + << "Input: " << input << " missing for op: " << op->outputs[0] << "."; } for (const auto& output : op->outputs) { - CHECK(model.HasArray(output)) << "Output: " << output - << " missing."; + CHECK(model.HasArray(output)) << "Output: " << output << " missing."; } } CheckNonExistentIOArrays(model); @@ -740,7 +805,10 @@ void FixNoOrphanedArray(Model* model) { } } -void CheckArrayFieldsConsistent(const Model& model) { +// Apply checks to arrays individually (for-each fashion). +// +// Check consistency of array fields, check name. +void CheckEachArray(const Model& model) { for (const auto& array_entry : model.GetArrayMap()) { const auto& array = array_entry.second; if (array->has_shape()) { @@ -755,6 +823,18 @@ void CheckArrayFieldsConsistent(const Model& model) { if (array->buffer) { CHECK(array->buffer->type == array->data_type); } + + // Check name. Either "name_with_suffix_8", "name_with_port:3", but not + // "name_with_both:3_8". + const string& name = array_entry.first; + auto colon_pos = name.find_first_of(":"); + if (colon_pos != string::npos) { + CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"), + string::npos) + << "Array name must only have digits after colon"; + } + CHECK_GT(colon_pos, 0) + << "First character of array name must not be a colon."; } } @@ -903,7 +983,7 @@ void CheckInvariants(const Model& model) { CheckNonAsciiIOArrays(model.flags); CheckNoMissingArray(model); CheckNoOrphanedArray(model); - CheckArrayFieldsConsistent(model); + CheckEachArray(model); CheckOperatorOrdering(model); } @@ -961,7 +1041,9 @@ void CheckModelCounts(const Model& model) { void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, std::vector* out_dims) { CHECK(out_dims->empty()); - if (num_dims == 1) { + if (num_dims == 0) { + return; + } else if (num_dims == 1) { CHECK_EQ(batch, 1); *out_dims = {depth}; } else if (num_dims == 2) { @@ -993,13 +1075,10 @@ void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) { if (array.has_shape()) { num_dims = array.shape().dimensions_count(); } - std::vector dims; - MakeArrayDims(num_dims, batch, 1, 1, size, &dims); - CHECK(array.data_type == ArrayDataType::kFloat || - array.data_type == ArrayDataType::kNone); - array.data_type = ArrayDataType::kFloat; - if (!array.has_shape()) { + if (!array.has_shape() && num_dims >= 0) { Shape* shape = array.mutable_shape(); + std::vector dims; + MakeArrayDims(num_dims, batch, 1, 1, size, &dims); *shape->mutable_dims() = dims; } } @@ -1019,7 +1098,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { } } if (!dst_input_array) { - // specified_input_array from model_flags is not found in model->flags. + // Specified_input_array from model_flags is not found in model->flags. // Match a name-less specified input array when there can be no ambiguity // as there is only 1 input array. if (model->flags.input_arrays_size() == 1 && @@ -1188,9 +1267,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { } // Creation of the RNN state arrays for (const auto& rnn_state : model->flags.rnn_states()) { - if (!rnn_state.manually_create()) { - continue; - } CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(), model); } @@ -1204,6 +1280,9 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays()); model->flags.set_allow_nonexistent_arrays( model_flags.allow_nonexistent_arrays()); + + CHECK(!model->flags.has_arrays_extra_info()); + *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info(); } void CheckIsReadyForQuantization(const Model& model) { @@ -1263,12 +1342,23 @@ int ElementSize(ArrayDataType data_type) { switch (data_type) { case ArrayDataType::kFloat: return 4; - case ArrayDataType::kInt32: - return 4; + case ArrayDataType::kInt8: + return 1; case ArrayDataType::kUint8: return 1; + case ArrayDataType::kInt16: + return 2; + case ArrayDataType::kUint16: + return 2; + case ArrayDataType::kInt32: + return 4; + case ArrayDataType::kUint32: + return 4; case ArrayDataType::kInt64: return 8; + case ArrayDataType::kUint64: + return 8; + // Usually not critical limitation because strings are only input and/or // output. case ArrayDataType::kString: @@ -1315,18 +1405,23 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) { } string AvailableArrayName(const Model& model, const string& name) { - if (!model.HasArray(name) && !model.optional_arrays.count(name)) { - return name; + string sanitized_name = SanitizeNameForTFNode(name); + if (!model.HasArray(sanitized_name) && + !model.IsOptionalArray(sanitized_name)) { + return sanitized_name; } const int kNumSuffixesToTry = 1000; for (int i = 0; i < kNumSuffixesToTry; i++) { - const string& name_with_suffix = toco::port::StringF("%s_%d", name, i); - if (!model.HasArray(name_with_suffix)) { + const string& name_with_suffix = + toco::port::StringF("%s_%d", sanitized_name, i); + if (!model.HasArray(name_with_suffix) && + !model.IsOptionalArray(name_with_suffix)) { return name_with_suffix; } } - LOG(FATAL) << "Could not find an available array name starting with " << name - << ". Tried " << kNumSuffixesToTry << " suffixes, all were taken!"; + LOG(FATAL) << "Could not find an available array name starting with " + << sanitized_name << ". Tried " << kNumSuffixesToTry + << " suffixes, all were taken!"; return ""; } @@ -1365,6 +1460,21 @@ bool IsArrayFullyConnectedWeights(const Model& model, const string& name) { return is_fc_weights; } +string CreateInt32Array(Model* model, const string& param_name, + const std::vector& value) { + auto param_array_name = AvailableArrayName(*model, param_name); + auto& param_array = model->GetOrCreateArray(param_array_name); + param_array.mutable_shape()->ReplaceDims({static_cast(value.size())}); + param_array.data_type = ArrayDataType::kInt32; + auto& param_array_data = + param_array.GetMutableBuffer().data; + param_array_data.resize(RequiredBufferSizeForShape(param_array.shape())); + for (int i = 0; i < value.size(); ++i) { + param_array_data[i] = value[i]; + } + return param_array_name; +} + bool EstimateArithmeticOpsCount(const Model& model, int64* result) { int64 total = 0; for (const auto& op : model.operators) { @@ -1412,6 +1522,7 @@ bool EstimateArithmeticOpsCount(const Model& model, int64* result) { } case OperatorType::kLogistic: case OperatorType::kSoftmax: + case OperatorType::kLogSoftmax: case OperatorType::kTanh: { const auto& output_array = model.GetArray(op->outputs[0]); if (!output_array.has_shape()) { @@ -1511,10 +1622,6 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, } } -namespace { - -// Extend shuffle is designed to match ExtendShape, which pads the shape with -// unit dimensions at the beginning. void ExtendShuffle(const std::vector& input_shuffle, int newdim, std::vector* extended_shuffle) { *extended_shuffle = input_shuffle; @@ -1529,8 +1636,6 @@ void ExtendShuffle(const std::vector& input_shuffle, int newdim, } } -} // end anonymous namespace - void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order, AxesOrder output_axes_order, Shape* output_shape) { if (input_axes_order == AxesOrder::kHWIM && @@ -1715,4 +1820,32 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) { } } +void FinishBuildingRNNStates(Model* model) { + for (const auto& rnn_state : model->flags.rnn_states()) { + if (!model->HasArray(rnn_state.back_edge_source_array()) || + !model->HasArray(rnn_state.state_array())) { + CHECK(model->HasArray(rnn_state.back_edge_source_array())); + CHECK(model->HasArray(rnn_state.state_array())); + continue; + } + const auto& src_array = model->GetArray(rnn_state.back_edge_source_array()); + auto& dst_array = model->GetArray(rnn_state.state_array()); + if (src_array.data_type == ArrayDataType::kNone && + dst_array.data_type == ArrayDataType::kNone) { + dst_array.data_type = ArrayDataType::kFloat; + } + } +} + +void UseArraysExtraInfo(Model* model) { + for (const auto& entry : model->flags.arrays_extra_info().entries()) { + QCHECK(model->HasArray(entry.name())) + << "ArraysExtraInfo refers to non-existent array name: " + << entry.name(); + auto& minmax = model->GetArray(entry.name()).GetOrCreateMinMax(); + minmax.min = entry.min(); + minmax.max = entry.max(); + } +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 5986d6364939e0f01b057ce3fb653b19fe8040cd..0aaa0f6a215288430dfdde7d5042012730c3be4c 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -23,7 +23,7 @@ limitations under the License. #include #include -#include "google/protobuf/text_format.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/platform/logging.h" #if TOCO_SUPPORT_PORTABLE_PROTOS #include "third_party/protobuf/src/google/protobuf/text_format.h" @@ -50,6 +50,8 @@ namespace toco { constexpr int kLogLevelModelChanged = 1; constexpr int kLogLevelModelUnchanged = 2; +absl::string_view FindLongestCommonPrefix(absl::string_view a, + absl::string_view b); string LogName(const Operator& op); bool IsInputArray(const Model& model, const string& name); @@ -58,6 +60,7 @@ int CountTrueOutputs(const Model& model, const Operator& op); int CountOpsWithInput(const Model& model, const string& array_name); bool DeleteArrayIfUnused(const string& array_name, Model* model); +bool DeleteArrayIfUsedOnce(const string& array_name, Model* model); std::vector>::const_iterator FindOpWithOutput( const Model& model, const string& array_name); @@ -65,10 +68,15 @@ Operator* GetOpWithOutput(const Model& model, const string& array_name); std::vector>::iterator FindOpWithOutput( Model& model, const string& array_name); + Operator* GetOpWithOutput(const Model& model, const string& array_name); std::vector>::const_iterator FindOpWithInput( const Model& model, const string& array_name); + +std::vector>::iterator FindOpWithInput( + Model& model, const string& array_name); + Operator* GetOpWithInput(const Model& model, const string& array_name); Operator* GetFirstOpWithInput(const Model& model, const string& array_name); @@ -80,29 +88,12 @@ std::vector>::iterator FindOp(Model& model, const char* OperatorTypeName(OperatorType type); string HelpfulOperatorTypeName(const Operator& op); +bool OperatorSupportsFusedActivation(OperatorType type); + void DumpGraphvizVideoFrame(const Model& model); void LogDump(int log_level, const string& message, const Model& model); void LogSummary(int log_level, const string& message, const Model& model); -inline bool ParseFromStringOverload(const std::string& in, - TFLITE_PROTO_NS::Message* proto) { - return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto); -} - -template -bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents, - Proto* proto) { - if (proto->ParseFromString(input_file_contents)) { - return true; - } - - if (ParseFromStringOverload(input_file_contents, proto)) { - return true; - } - - return false; -} - // TODO(b/36075966): Clean up when dims superseded by array shape. void ExtendShape(Shape* shape, int new_shape_size); @@ -270,6 +261,11 @@ void PrintArrayShape(Model* model, const string& name); void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, std::vector* out_dims); +// Defines a constant int32 array with the provided values formatted for use +// as op parameters. +string CreateInt32Array(Model* model, const string& param_name, + const std::vector& value); + bool EstimateArithmeticOpsCount(const Model& model, int64* result); int AxesCount(AxesOrder axes_order); @@ -279,6 +275,11 @@ int AxesCount(AxesOrder axes_order); void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, std::vector* shuffle); +// Extend shuffle is designed to match ExtendShape, which pads the shape with +// unit dimensions at the beginning. +void ExtendShuffle(const std::vector& input_shuffle, int newdim, + std::vector* extended_shuffle); + void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order, AxesOrder output_axes_order, Shape* output_shape); void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, @@ -298,6 +299,25 @@ void CheckFinalDataTypesSatisfied(const Model& model); ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type); +// The process of building models varies according to the import format. +// +// (a) In some cases, such as model-proto format, the model should be fully +// specified. In these cases, no extra action should be taken by this function. +// (b) In other cases, such as TF graphdef format, the desired types of RNN +// arrays are not specified directly in the model, neither can they be inferred. +// However, we can set the types of RNN destination arrays to float. This breaks +// any cycles such as when resolution of the type of an RNN source array depends +// on the type of its destination array. +// +// This function is applied after the main import, after resolution of flags and +// after application of ArraysExtraInfo. It only defaults destination RNN arrays +// to float. If the model is subsequently quantized, it is assumed that the +// model contains sufficient information for that to be completed. If it is +// already quantized, then case (a) should hold. +void FinishBuildingRNNStates(Model* model); + +void UseArraysExtraInfo(Model* model); + } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 20df905270b0692e2bc9b78fc020447108282d01..999ccf2ebc009b6b7c50a9a2d1667d69a3f690e7 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -93,3 +93,34 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +cc_library( + name = "verifier", + srcs = ["verifier.cc"], + hdrs = ["verifier.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/schema:schema_fbs", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_test( + name = "verifier_test", + size = "small", + srcs = ["verifier_test.cc"], + deps = [ + ":mutable_op_resolver", + ":verifier", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/testing:util", + "//tensorflow/core:framework_lite", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/contrib/lite/tools/verifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..59c74205f0a311ec12ff87f46622041605fb493b --- /dev/null +++ b/tensorflow/contrib/lite/tools/verifier.cc @@ -0,0 +1,234 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/verifier.h" +#include +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { + +namespace { + +// Reports error message when the reporter is set. +void ReportError(ErrorReporter* error_reporter, const char* format, ...) { + if (error_reporter) { + va_list args; + va_start(args, format); + error_reporter->Report(format, args); + va_end(args); + } +} + +// Returns the int32_t value pointed by ptr. +const uint32_t* GetIntPtr(const char* ptr) { + return reinterpret_cast(ptr); +} + +// Verifies flatbuffer format of the model contents and returns the in-memory +// model. +const Model* VerifyFlatbufferAndGetModel(const void* buf, size_t len) { + ::flatbuffers::Verifier verifier(static_cast(buf), len); + if (VerifyModelBuffer(verifier)) { + return ::tflite::GetModel(buf); + } else { + return nullptr; + } +} + +const uint32_t kMaxNumString = UINT_MAX / sizeof(int32_t) - 2; + +// Verifies string tensor has legit buffer contents that follow the schema +// defined in lite/string_util.h +bool VerifyStringTensorBuffer(const Buffer& buffer, + ErrorReporter* error_reporter) { + uint32_t buffer_size = buffer.data()->size(); + const char* buffer_ptr = reinterpret_cast(buffer.data()->data()); + + uint32_t num_strings = *GetIntPtr(buffer_ptr); + if (num_strings > kMaxNumString) { + ReportError(error_reporter, + "String tensor has invalid num of string set: %d", num_strings); + return false; + } + uint32_t header_offsets = + static_cast(num_strings + 2) * sizeof(int32_t); + + if (buffer_size < header_offsets) { + ReportError(error_reporter, + "String tensor buffer requires at least %d bytes, but is " + "allocated with %d bytes", + header_offsets, buffer_size); + return false; + } + + uint32_t prev_ptr = header_offsets; + uint32_t offset = sizeof(int32_t); + + if (*GetIntPtr(buffer_ptr + offset) != header_offsets) { + ReportError(error_reporter, + "String tensor buffer initial offset must be: %d", + header_offsets); + return false; + } + offset += sizeof(int32_t); + for (int i = 1; i <= num_strings; i++, offset += sizeof(int32_t)) { + int string_offset = *GetIntPtr(buffer_ptr + offset); + if (string_offset < prev_ptr || string_offset > buffer_size) { + ReportError(error_reporter, "String tensor buffer is invalid: index %d", + i); + return false; + } + } + if (*GetIntPtr(buffer_ptr + offset - sizeof(int32_t)) != buffer_size) { + ReportError(error_reporter, "String tensor buffer last offset must be %d", + buffer_size); + return false; + } + return true; +} + +// Verifies numeric tensor has legit buffer. +bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer, + ErrorReporter* error_reporter) { + uint64_t bytes_required = 1; + for (int dim : *tensor.shape()) { + bytes_required *= dim; + if (bytes_required > UINT_MAX) { + ReportError(error_reporter, "Tensor dimension overflow"); + return false; + } + } + switch (tensor.type()) { + case TensorType_FLOAT32: + bytes_required *= sizeof(float); + break; + case TensorType_INT32: + bytes_required *= sizeof(int32_t); + break; + case TensorType_UINT8: + bytes_required *= sizeof(uint8_t); + break; + case TensorType_INT64: + bytes_required *= sizeof(int64_t); + break; + case TensorType_FLOAT16: + // FALLTHROUGH_INTENDED; + default: + ReportError(error_reporter, "Invalid tensor type: %d", tensor.type()); + return false; + } + if (bytes_required > UINT_MAX) { + ReportError(error_reporter, "Tensor dimension overflow"); + return false; + } + + if (bytes_required != buffer.data()->size()) { + ReportError( + error_reporter, + "Tensor requires %d bytes, but is allocated with %d bytes buffer", + bytes_required, buffer.data()->size()); + return false; + } + return true; + + // TODO(yichengfan): verify quantized tensors. +} + +// Verifies tensors have valid properties and legit buffer if set. +bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) { + if (!model.subgraphs()) { + return true; + } + for (const auto& subgraph : *model.subgraphs()) { + if (!subgraph->tensors()) { + continue; + } + for (const auto& tensor : *subgraph->tensors()) { + if (!tensor->buffer()) { + continue; + } + if (tensor->buffer() >= model.buffers()->size()) { + ReportError(error_reporter, "Invalid tensor buffer index: %d", + tensor->buffer()); + return false; + } + auto* buffer = model.buffers()->Get(tensor->buffer()); + if (!buffer || !buffer->data()) { + ReportError(error_reporter, "Tensor buffer %d not set", + tensor->buffer()); + return false; + } + + if (tensor->type() == TensorType_STRING) { + if (!VerifyStringTensorBuffer(*buffer, error_reporter)) { + return false; + } + } else { + if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) { + return false; + } + } + } + } + return true; +} + +bool VerifyOps(const Model& model, const OpResolver& resolver, + ErrorReporter* error_reporter) { + if (!model.operator_codes()) { + return true; + } + for (const auto& opcode : *model.operator_codes()) { + if (opcode->builtin_code() == BuiltinOperator_CUSTOM) { + if (!resolver.FindOp(opcode->custom_code()->c_str())) { + ReportError(error_reporter, "Unsupported custom op: %s", + opcode->custom_code()->c_str()); + return false; + } + } else { + if (!resolver.FindOp(opcode->builtin_code())) { + ReportError(error_reporter, "Unsupported builtin op: %s", + EnumNameBuiltinOperator(opcode->builtin_code())); + return false; + } + } + } + return true; +} + +} // namespace + +bool Verify(const void* buf, size_t len, const OpResolver& resolver, + ErrorReporter* error_reporter) { + const Model* model = VerifyFlatbufferAndGetModel(buf, len); + if (model == nullptr) { + ReportError(error_reporter, "Invalid flatbuffer format"); + return false; + } + if (model->version() != TFLITE_SCHEMA_VERSION) { + ReportError(error_reporter, "Invalid model version %d", model->version()); + return false; + } + if (!VerifyTensors(*model, error_reporter)) { + return false; + } + if (!VerifyOps(*model, resolver, error_reporter)) { + return false; + } + return true; +} +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/contrib/lite/tools/verifier.h new file mode 100644 index 0000000000000000000000000000000000000000..c2ee11215c861ed7b27696a8d786bb6e2a48e930 --- /dev/null +++ b/tensorflow/contrib/lite/tools/verifier.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_ + +#include + +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { + +// Verifies the integrity of a Tensorflow Lite flatbuffer model file. +// Currently, it verifies: +// * The file is following a legit flatbuffer schema. +// * The model is in supported version. +// * All ops used in the model are supported by OpResolver. +bool Verify(const void* buf, size_t len, const OpResolver& resolver, + ErrorReporter* error_reporter); + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_ diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b3e611f999b2837efbf8876bd989db44c408b8c7 --- /dev/null +++ b/tensorflow/contrib/lite/tools/verifier_test.cc @@ -0,0 +1,288 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "flatbuffers/flatbuffers.h" +#include "flatbuffers/util.h" +#include +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" +#include "tensorflow/contrib/lite/tools/verifier.h" +#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/core/framework/numeric_types.h" + +namespace tflite { + +using flatbuffers::FlatBufferBuilder; +using flatbuffers::Offset; +using flatbuffers::Vector; + +// Build single subgraph model. +class TfLiteFlatbufferModelBuilder { + public: + TfLiteFlatbufferModelBuilder() { + buffers_.push_back( + CreateBuffer(builder_, builder_.CreateVector(std::vector{}))); + } + + TfLiteFlatbufferModelBuilder(const std::vector& builtin_ops, + const std::vector& custom_ops) { + buffers_.push_back( + CreateBuffer(builder_, builder_.CreateVector(std::vector{}))); + + for (const auto& iter : builtin_ops) { + resolver_.AddBuiltin(iter, &fake_op_); + } + for (const auto& iter : custom_ops) { + resolver_.AddCustom(iter.data(), &fake_op_); + } + } + + void AddTensor(const std::vector& shape, tflite::TensorType type, + const std::vector& buffer, const char* name) { + int buffer_index = 0; + if (!buffer.empty()) { + buffer_index = buffers_.size(); + buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector(buffer))); + } + tensors_.push_back(CreateTensorDirect(builder_, &shape, type, buffer_index, + name, /*quantization=*/0)); + } + + void AddOperator(const std::vector& inputs, + const std::vector& outputs, + tflite::BuiltinOperator builtin_op, const char* custom_op) { + operator_codes_.push_back( + CreateOperatorCodeDirect(builder_, builtin_op, custom_op)); + operators_.push_back(CreateOperator( + builder_, operator_codes_.size() - 1, builder_.CreateVector(inputs), + builder_.CreateVector(outputs), BuiltinOptions_NONE, + /*builtin_options=*/0, + /*custom_options=*/0, tflite::CustomOptionsFormat_FLEXBUFFERS)); + } + + void FinishModel(const std::vector& inputs, + const std::vector& outputs) { + auto subgraph = std::vector>({CreateSubGraph( + builder_, builder_.CreateVector(tensors_), + builder_.CreateVector(inputs), builder_.CreateVector(outputs), + builder_.CreateVector(operators_), + builder_.CreateString("test_subgraph"))}); + auto result = CreateModel( + builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(operator_codes_), + builder_.CreateVector(subgraph), builder_.CreateString("test_model"), + builder_.CreateVector(buffers_)); + tflite::FinishModelBuffer(builder_, result); + } + + bool Verify() { + return tflite::Verify(builder_.GetBufferPointer(), builder_.GetSize(), + resolver_, DefaultErrorReporter()); + } + + private: + FlatBufferBuilder builder_; + MutableOpResolver resolver_; + TfLiteRegistration fake_op_; + std::vector> operators_; + std::vector> operator_codes_; + std::vector> tensors_; + std::vector> buffers_; +}; + +TEST(VerifyModel, TestEmptyModel) { + FlatBufferBuilder builder; + auto model = CreateModel(builder, /*version=*/TFLITE_SCHEMA_VERSION, + /*operator_codes=*/0, /*subgraphs=*/0, + /*description=*/0, /*buffers=*/0); + ::tflite::FinishModelBuffer(builder, model); + + ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(), + MutableOpResolver{}, DefaultErrorReporter())); +} + +TEST(VerifyModel, TestSimpleModel) { + TfLiteFlatbufferModelBuilder builder({}, {"test"}); + builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "test"); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4, 5, 6}, "input"); + builder.AddTensor( + {2}, TensorType_STRING, + {2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 19, 0, 0, 0, 'A', 'B', 'C'}, + "data"); + builder.AddTensor({2, 3}, TensorType_INT32, {}, "output"); + builder.FinishModel({0, 1}, {2}); + ASSERT_TRUE(builder.Verify()); +} + +TEST(VerifyModel, TestCorruptedData) { + std::string model = "123"; + ASSERT_FALSE(Verify(model.data(), model.size(), MutableOpResolver{}, + /*error_reporter=*/nullptr)); +} + +TEST(VerifyModel, TestUnsupportedVersion) { + FlatBufferBuilder builder; + auto model = CreateModel(builder, /*version=*/1, /*operator_codes=*/0, + /*subgraphs=*/0, /*description=*/0, /*buffers=*/0); + ::tflite::FinishModelBuffer(builder, model); + ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), + MutableOpResolver{}, DefaultErrorReporter())); +} + +TEST(VerifyModel, TestRandomModificationIsNotAllowed) { + FlatBufferBuilder builder; + auto model = CreateModel(builder, /*version=*/TFLITE_SCHEMA_VERSION, + /*operator_codes=*/0, + /*subgraphs=*/0, /*description=*/0, /*buffers=*/0); + ::tflite::FinishModelBuffer(builder, model); + + std::string model_content(reinterpret_cast(builder.GetBufferPointer()), + builder.GetSize()); + for (int i = 0; i < model_content.size(); i++) { + model_content[i] = (model_content[i] + 137) % 255; + EXPECT_FALSE(Verify(model_content.data(), model_content.size(), + MutableOpResolver{}, DefaultErrorReporter())) + << "Fail at position: " << i; + } +} + +TEST(VerifyModel, TestIntTensorShapeIsGreaterThanBuffer) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, TestIntTensorShapeIsSmallerThanBuffer) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor({2, 1}, TensorType_UINT8, {1, 2, 3, 4}, "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, TestIntTensorShapeOverflow) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor({1024, 2048, 4096}, TensorType_UINT8, {1, 2, 3, 4}, + "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, TensorBufferIsNotValid) { + FlatBufferBuilder builder; + std::vector shape = {2, 3}; + auto tensors = builder.CreateVector(std::vector>{ + CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/2, + "input", /*quantization=*/0)}); + auto subgraph = std::vector>( + {CreateSubGraph(builder, tensors, /*inputs=*/0, /*outputs=*/0, + /*operators=*/0, builder.CreateString("Main"))}); + + auto buffers = builder.CreateVector(std::vector>{ + CreateBuffer(builder, + builder.CreateVector(std::vector{1, 2, 3, 4, 5, 6})), + }); + + auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, /*operator_codes=*/0, + builder.CreateVector(subgraph), + builder.CreateString("SmartReply"), buffers); + + ::tflite::FinishModelBuffer(builder, model); + ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), + MutableOpResolver{}, DefaultErrorReporter())); +} + +TEST(VerifyModel, StringTensorHasInvalidNumString) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor( + {2}, TensorType_STRING, + {0x00, 0x00, 0x00, 0x20, 16, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 'A', 'B'}, + "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, StringTensorOffsetTooSmall) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor( + {2}, TensorType_STRING, + {2, 0, 0, 0, 12, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 'A', 'B'}, "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, StringTensorOffsetOutOfRange) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor( + {2}, TensorType_STRING, + {2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 22, 0, 0, 0, 'A', 'B'}, "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, StringTensorIsLargerThanRequired) { + TfLiteFlatbufferModelBuilder builder; + builder.AddTensor( + {2}, TensorType_STRING, + {2, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 'A', 'B', 'C'}, + "input"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, AllOpsAreSupported) { + TfLiteFlatbufferModelBuilder builder({BuiltinOperator_ADD}, {"CustomOp"}); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1"); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2"); + builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output"); + builder.AddOperator({0, 1}, {2}, BuiltinOperator_ADD, nullptr); + builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "CustomOp"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, UseUnsupportedBuiltinOps) { + TfLiteFlatbufferModelBuilder builder({BuiltinOperator_SUB}, {"CustomOp"}); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1"); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2"); + builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output"); + builder.AddOperator({0, 1}, {2}, BuiltinOperator_ADD, nullptr); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, UseUnsupportedCustomOps) { + TfLiteFlatbufferModelBuilder builder({BuiltinOperator_ADD}, {"NewOp"}); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1"); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2"); + builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output"); + builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "Not supported"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +// TODO(yichengfan): make up malicious files to test with. + +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py index d0d78e3afab7d89f216bb8ceb42e4429ca4f1759..f571dd59da0a3f4aff264b48fba3e41f75b50404 100644 --- a/tensorflow/contrib/lite/tools/visualize.py +++ b/tensorflow/contrib/lite/tools/visualize.py @@ -198,10 +198,13 @@ class TensorMapper(object): def GenerateGraph(subgraph_idx, g, opcode_mapper): """Produces the HTML required to have a d3 visualization of the dag.""" + def TensorName(idx): - return "t%d"%idx + return "t%d" % idx + def OpName(idx): - return "o%d"%idx + return "o%d" % idx + edges = [] nodes = [] first = {} @@ -210,27 +213,35 @@ def GenerateGraph(subgraph_idx, g, opcode_mapper): for tensor_input_position, tensor_index in enumerate(op["inputs"]): if tensor_index not in first: first[tensor_index] = ( - op_index*pixel_mult, - tensor_input_position*pixel_mult - pixel_mult/2) - edges.append( - {"source": TensorName(tensor_index), "target": OpName(op_index)}) + op_index * pixel_mult, + tensor_input_position * pixel_mult - pixel_mult / 2) + edges.append({ + "source": TensorName(tensor_index), + "target": OpName(op_index) + }) for tensor_index in op["outputs"]: - edges.append( - {"target": TensorName(tensor_index), "source": OpName(op_index)}) - nodes.append({"id": OpName(op_index), - "name": opcode_mapper(op["opcode_index"]), - "group": 2, - "x": pixel_mult, - "y": op_index * pixel_mult}) + edges.append({ + "target": TensorName(tensor_index), + "source": OpName(op_index) + }) + nodes.append({ + "id": OpName(op_index), + "name": opcode_mapper(op["opcode_index"]), + "group": 2, + "x": pixel_mult, + "y": op_index * pixel_mult + }) for tensor_index, tensor in enumerate(g["tensors"]): - initial_y = (first[tensor_index] if tensor_index in first - else len(g["operators"])) - - nodes.append({"id": TensorName(tensor_index), - "name": "%s (%d)" % (tensor["name"], tensor_index), - "group": 1, - "x": 2, - "y": initial_y}) + initial_y = ( + first[tensor_index] if tensor_index in first else len(g["operators"])) + + nodes.append({ + "id": TensorName(tensor_index), + "name": "%s (%d)" % (tensor["name"], tensor_index), + "group": 1, + "x": 2, + "y": initial_y + }) graph_str = json.dumps({"nodes": nodes, "edges": edges}) html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx) @@ -267,7 +278,7 @@ def GenerateTableHtml(items, keys_to_print, display_index=True): for h, mapper in keys_to_print: val = tensor[h] if h in tensor else None val = val if mapper is None else mapper(val) - html += "%s\n"%val + html += "%s\n" % val html += "\n" html += "\n" @@ -279,18 +290,19 @@ def CreateHtmlFile(tflite_input, html_output): # Convert the model into a JSON flatbuffer using flatc (build if doesn't # exist. - if not os.path.exists(tflite_input): + if not os.path.exists(tflite_input): raise RuntimeError("Invalid filename %r" % tflite_input) if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"): # Run convert - cmd = (_BINARY + " -t " - "--strict-json --defaults-json -o /tmp {schema} -- {input}".format( - input=tflite_input, schema=_SCHEMA)) + cmd = ( + _BINARY + " -t " + "--strict-json --defaults-json -o /tmp {schema} -- {input}".format( + input=tflite_input, schema=_SCHEMA)) print(cmd) os.system(cmd) - real_output = ("/tmp/"+ os.path.splitext(os.path.split(tflite_input)[-1])[0] - + ".json") + real_output = ("/tmp/" + os.path.splitext( + os.path.split(tflite_input)[-1])[0] + ".json") data = json.load(open(real_output)) elif tflite_input.endswith(".json"): @@ -302,12 +314,13 @@ def CreateHtmlFile(tflite_input, html_output): html += "

TensorFlow Lite Model

" data["filename"] = tflite_input # Avoid special case - toplevel_stuff = [("filename", None), ("version", None), - ("description", None)] + toplevel_stuff = [("filename", None), ("version", None), ("description", + None)] html += "\n" for key, mapping in toplevel_stuff: - if not mapping: mapping = lambda x: x + if not mapping: + mapping = lambda x: x html += "\n" % (key, mapping(data[key])) html += "
%s%s
\n" @@ -320,22 +333,22 @@ def CreateHtmlFile(tflite_input, html_output): html += "
" tensor_mapper = TensorMapper(g) opcode_mapper = OpCodeMapper(data) - op_keys_to_display = [ - ("inputs", tensor_mapper), ("outputs", tensor_mapper), - ("builtin_options", None), ("opcode_index", opcode_mapper)] - tensor_keys_to_display = [ - ("name", None), ("type", None), ("shape", None), ("buffer", None), - ("quantization", None)] + op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper), + ("builtin_options", None), ("opcode_index", + opcode_mapper)] + tensor_keys_to_display = [("name", None), ("type", None), ("shape", None), + ("buffer", None), ("quantization", None)] html += "

Subgraph %d

\n" % subgraph_idx # Inputs and outputs. html += "

Inputs/Outputs

\n" - html += GenerateTableHtml([{"inputs": g["inputs"], - "outputs": g["outputs"]}], - [("inputs", tensor_mapper), - ("outputs", tensor_mapper)], - display_index=False) + html += GenerateTableHtml( + [{ + "inputs": g["inputs"], + "outputs": g["outputs"] + }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)], + display_index=False) # Print the tensors. html += "

Tensors

\n" @@ -357,8 +370,7 @@ def CreateHtmlFile(tflite_input, html_output): # Operator codes html += "

Operator Codes

\n" - html += GenerateTableHtml(data["operator_codes"], - operator_keys_to_display) + html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display) html += "\n" @@ -370,10 +382,10 @@ def main(argv): tflite_input = argv[1] html_output = argv[2] except IndexError: - print ("Usage: %s " % (argv[0])) + print("Usage: %s " % (argv[0])) else: CreateHtmlFile(tflite_input, html_output) + if __name__ == "__main__": main(sys.argv) - diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 7c523ad49265aaf32c8d5a8ae04d3e93262a1b55..8c3a8afe7a0f6f5ad9ceae566288ba60be73d339 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -30,20 +30,13 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.deprecation import deprecated_args -__all__ = ["absolute_difference", - "add_loss", - "cosine_distance", - "compute_weighted_loss", - "get_losses", - "get_regularization_losses", - "get_total_loss", - "hinge_loss", - "log_loss", - "mean_pairwise_squared_error", - "mean_squared_error", - "sigmoid_cross_entropy", - "softmax_cross_entropy", - "sparse_softmax_cross_entropy"] +__all__ = [ + "absolute_difference", "add_loss", "cosine_distance", + "compute_weighted_loss", "get_losses", "get_regularization_losses", + "get_total_loss", "hinge_loss", "log_loss", "mean_pairwise_squared_error", + "mean_squared_error", "sigmoid_cross_entropy", "softmax_cross_entropy", + "sparse_softmax_cross_entropy" +] def _scale_losses(losses, weights): @@ -66,8 +59,8 @@ def _scale_losses(losses, weights): # First, compute the sum of the losses over all elements: start_index = max(0, weights.get_shape().ndims) reduction_indices = list(range(start_index, losses.get_shape().ndims)) - reduced_losses = math_ops.reduce_sum(losses, - reduction_indices=reduction_indices) + reduced_losses = math_ops.reduce_sum( + losses, reduction_indices=reduction_indices) reduced_losses = math_ops.multiply(reduced_losses, weights) return math_ops.reduce_sum(reduced_losses) @@ -90,9 +83,10 @@ def _safe_div(numerator, denominator, name="value"): """ return array_ops.where( math_ops.greater(denominator, 0), - math_ops.div(numerator, array_ops.where( - math_ops.equal(denominator, 0), - array_ops.ones_like(denominator), denominator)), + math_ops.div(numerator, + array_ops.where( + math_ops.equal(denominator, 0), + array_ops.ones_like(denominator), denominator)), array_ops.zeros_like(numerator), name=name) @@ -176,14 +170,15 @@ def _num_present(losses, weights, per_batch=False): """ # If weights is a scalar, its easy to compute: if weights.get_shape().ndims == 0: - batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses), - [0], [1]), []) - num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)), - math_ops.to_float(batch_size)) - num_per_batch = array_ops.where(math_ops.equal(weights, 0), - 0.0, num_per_batch) - num_per_batch = math_ops.multiply(array_ops.ones( - array_ops.reshape(batch_size, [1])), num_per_batch) + batch_size = array_ops.reshape( + array_ops.slice(array_ops.shape(losses), [0], [1]), []) + num_per_batch = math_ops.div( + math_ops.to_float(array_ops.size(losses)), + math_ops.to_float(batch_size)) + num_per_batch = array_ops.where( + math_ops.equal(weights, 0), 0.0, num_per_batch) + num_per_batch = math_ops.multiply( + array_ops.ones(array_ops.reshape(batch_size, [1])), num_per_batch) return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch) # First, count the number of nonzero weights: @@ -194,8 +189,8 @@ def _num_present(losses, weights, per_batch=False): reduction_indices=reduction_indices) # Next, determine the number of elements that weights would broadcast to: - broadcast_dims = array_ops.slice(array_ops.shape(losses), - [weights.get_shape().ndims], [-1]) + broadcast_dims = array_ops.slice( + array_ops.shape(losses), [weights.get_shape().ndims], [-1]) num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims)) num_per_batch = math_ops.multiply(num_nonzero_per_batch, num_to_broadcast) @@ -303,8 +298,11 @@ def absolute_difference(predictions, labels=None, weights=1.0, scope=None): @deprecated("2016-12-30", "Use tf.losses.sigmoid_cross_entropy instead. Note that the order " "of the predictions and labels arguments has been changed.") -def sigmoid_cross_entropy( - logits, multi_class_labels, weights=1.0, label_smoothing=0, scope=None): +def sigmoid_cross_entropy(logits, + multi_class_labels, + weights=1.0, + label_smoothing=0, + scope=None): """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits. `weights` acts as a coefficient for the loss. If a scalar is provided, @@ -340,20 +338,22 @@ def sigmoid_cross_entropy( multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype) if label_smoothing > 0: - multi_class_labels = (multi_class_labels * (1 - label_smoothing) + - 0.5 * label_smoothing) + multi_class_labels = ( + multi_class_labels * (1 - label_smoothing) + 0.5 * label_smoothing) - losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels, - logits=logits, - name="xentropy") + losses = nn.sigmoid_cross_entropy_with_logits( + labels=multi_class_labels, logits=logits, name="xentropy") return compute_weighted_loss(losses, weights, scope=scope) @deprecated("2016-12-30", "Use tf.losses.softmax_cross_entropy instead. Note that the order " "of the logits and labels arguments has been changed.") -def softmax_cross_entropy( - logits, onehot_labels, weights=1.0, label_smoothing=0, scope=None): +def softmax_cross_entropy(logits, + onehot_labels, + weights=1.0, + label_smoothing=0, + scope=None): """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits. `weights` acts as a coefficient for the loss. If a scalar is provided, @@ -393,9 +393,8 @@ def softmax_cross_entropy( smooth_negatives = label_smoothing / num_classes onehot_labels = onehot_labels * smooth_positives + smooth_negatives - losses = nn.softmax_cross_entropy_with_logits(labels=onehot_labels, - logits=logits, - name="xentropy") + losses = nn.softmax_cross_entropy_with_logits( + labels=onehot_labels, logits=logits, name="xentropy") return compute_weighted_loss(losses, weights, scope=scope) @@ -429,9 +428,8 @@ def sparse_softmax_cross_entropy(logits, labels, weights=1.0, scope=None): [logits, labels, weights]) as scope: labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]]) - losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels, - logits=logits, - name="xentropy") + losses = nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=logits, name="xentropy") return compute_weighted_loss(losses, weights, scope=scope) @@ -470,8 +468,7 @@ def log_loss(predictions, labels=None, weights=1.0, epsilon=1e-7, scope=None): predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) losses = -math_ops.multiply( - labels, - math_ops.log(predictions + epsilon)) - math_ops.multiply( + labels, math_ops.log(predictions + epsilon)) - math_ops.multiply( (1 - labels), math_ops.log(1 - predictions + epsilon)) return compute_weighted_loss(losses, weights, scope=scope) @@ -490,7 +487,8 @@ def hinge_loss(logits, labels=None, scope=None): scope: The scope for the operations performed in computing the loss. Returns: - An unweighted `Tensor` of same shape as `logits` and `labels` representing the + An unweighted `Tensor` of same shape as `logits` and `labels` representing + the loss values across the batch. Raises: @@ -544,8 +542,10 @@ def mean_squared_error(predictions, labels=None, weights=1.0, scope=None): @deprecated("2016-12-30", "Use tf.losses.mean_pairwise_squared_error instead. Note that the " "order of the predictions and labels arguments has been changed.") -def mean_pairwise_squared_error( - predictions, labels=None, weights=1.0, scope=None): +def mean_pairwise_squared_error(predictions, + labels=None, + weights=1.0, + scope=None): """Adds a pairwise-errors-squared loss to the training procedure. Unlike `mean_squared_error`, which is a measure of the differences between @@ -602,31 +602,34 @@ def mean_pairwise_squared_error( reduction_indices = list(range(1, diffs.get_shape().ndims)) sum_squares_diff_per_batch = math_ops.reduce_sum( - math_ops.square(diffs), - reduction_indices=reduction_indices) + math_ops.square(diffs), reduction_indices=reduction_indices) num_present_per_batch = _num_present(diffs, weights, per_batch=True) - term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, - num_present_per_batch) + term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch) sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices) - term2 = 2.0 * _safe_div(math_ops.square(sum_diff), - math_ops.square(num_present_per_batch)) + term2 = 2.0 * _safe_div( + math_ops.square(sum_diff), math_ops.square(num_present_per_batch)) loss = _scale_losses(term1 - term2, weights) - mean_loss = array_ops.where(math_ops.reduce_sum(num_present_per_batch) > 0, - loss, - array_ops.zeros_like(loss), - name="value") + mean_loss = array_ops.where( + math_ops.reduce_sum(num_present_per_batch) > 0, + loss, + array_ops.zeros_like(loss), + name="value") add_loss(mean_loss) return mean_loss @deprecated("2016-12-30", "Use tf.losses.cosine_distance instead.") @deprecated_args(None, "dim is deprecated, use axis instead", "dim") -def cosine_distance( - predictions, labels=None, axis=None, weights=1.0, scope=None, dim=None): +def cosine_distance(predictions, + labels=None, + axis=None, + weights=1.0, + scope=None, + dim=None): """Adds a cosine-distance loss to the training procedure. Note that the function assumes that `predictions` and `labels` are already @@ -662,5 +665,8 @@ def cosine_distance( labels = math_ops.to_float(labels) radial_diffs = math_ops.multiply(predictions, labels) - losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[axis,]) + losses = 1 - math_ops.reduce_sum( + radial_diffs, reduction_indices=[ + axis, + ]) return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py index 9d0f95e6f3e7fa9666a99e31578b38d52e0b6b4a..1417772e0496cb571488e5b30bd4f3fb1b591730 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -274,6 +275,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase): self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3) +@test_util.with_c_api class SparseSoftmaxCrossEntropyLossTest(test.TestCase): def testNoneWeightRaisesValueError(self): @@ -471,7 +473,11 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): labels = constant_op.constant([[0, 1], [2, 3]]) weights = constant_op.constant([1.2, 3.4, 5.6, 7.8]) - with self.assertRaises(errors_impl.InvalidArgumentError): + if ops._USE_C_API: + error_type = ValueError + else: + error_type = errors_impl.InvalidArgumentError + with self.assertRaises(error_type): loss_ops.sparse_softmax_cross_entropy( logits, labels, weights=weights).eval() diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py index c3a57ba51bcf0a292490dfaa9e556f6e5811ed66..6842bc38eb108b46cc3eff715c9cbc74f991308b 100644 --- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py @@ -53,12 +53,12 @@ def pairwise_distance(feature, squared=False): math_ops.reduce_sum( math_ops.square(feature), axis=[1], - keep_dims=True), + keepdims=True), math_ops.reduce_sum( math_ops.square( array_ops.transpose(feature)), axis=[0], - keep_dims=True)) - 2.0 * math_ops.matmul( + keepdims=True)) - 2.0 * math_ops.matmul( feature, array_ops.transpose(feature)) # Deal with numerical inaccuracies. Set small negatives to zero. @@ -132,10 +132,10 @@ def masked_maximum(data, mask, dim=1): masked_maximums: N-D `Tensor`. The maximized dimension is of size 1 after the operation. """ - axis_minimums = math_ops.reduce_min(data, dim, keep_dims=True) + axis_minimums = math_ops.reduce_min(data, dim, keepdims=True) masked_maximums = math_ops.reduce_max( math_ops.multiply( - data - axis_minimums, mask), dim, keep_dims=True) + axis_minimums + data - axis_minimums, mask), dim, keepdims=True) + axis_minimums return masked_maximums @@ -151,10 +151,10 @@ def masked_minimum(data, mask, dim=1): masked_minimums: N-D `Tensor`. The minimized dimension is of size 1 after the operation. """ - axis_maximums = math_ops.reduce_max(data, dim, keep_dims=True) + axis_maximums = math_ops.reduce_max(data, dim, keepdims=True) masked_minimums = math_ops.reduce_min( math_ops.multiply( - data - axis_maximums, mask), dim, keep_dims=True) + axis_maximums + data - axis_maximums, mask), dim, keepdims=True) + axis_maximums return masked_minimums @@ -203,7 +203,7 @@ def triplet_semihard_loss(labels, embeddings, margin=1.0): math_ops.greater( math_ops.reduce_sum( math_ops.cast( - mask, dtype=dtypes.float32), 1, keep_dims=True), + mask, dtype=dtypes.float32), 1, keepdims=True), 0.0), [batch_size, batch_size]) mask_final = array_ops.transpose(mask_final) @@ -290,7 +290,7 @@ def npairs_loss(labels, embeddings_anchor, embeddings_positive, labels_remapped = math_ops.to_float( math_ops.equal(labels, array_ops.transpose(labels))) - labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True) + labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True) # Add the softmax loss. xent_loss = nn.softmax_cross_entropy_with_logits( @@ -395,7 +395,7 @@ def npairs_loss_multilabel(sparse_labels, embeddings_anchor, multilabel_adjacency_matrix = _build_multilabel_adjacency(sparse_labels) labels_remapped = math_ops.to_float(multilabel_adjacency_matrix) - labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True) + labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True) # Add the softmax loss. xent_loss = nn.softmax_cross_entropy_with_logits( @@ -448,10 +448,10 @@ def lifted_struct_loss(labels, embeddings, margin=1.0): # Safe maximum: Temporarily shift negative distances # above zero before taking max. # this is to take the max only among negatives. - row_minimums = math_ops.reduce_min(diff, 1, keep_dims=True) + row_minimums = math_ops.reduce_min(diff, 1, keepdims=True) row_negative_maximums = math_ops.reduce_max( math_ops.multiply( - diff - row_minimums, mask), 1, keep_dims=True) + row_minimums + diff - row_minimums, mask), 1, keepdims=True) + row_minimums # Compute the loss. # Keep track of matrix of maximums where M_ij = max(m_i, m_j) @@ -470,7 +470,7 @@ def lifted_struct_loss(labels, embeddings, margin=1.0): math_ops.reduce_sum(math_ops.multiply( math_ops.exp( diff_tiled - max_elements_vect), - mask_tiled), 1, keep_dims=True), [batch_size, batch_size]) + mask_tiled), 1, keepdims=True), [batch_size, batch_size]) loss_mat = max_elements + math_ops.log( loss_exp_left + array_ops.transpose(loss_exp_left)) @@ -686,7 +686,7 @@ def _find_loss_augmented_facility_idx(pairwise_distances, labels, chosen_ids, array_ops.reshape(pairwise_distances_candidate, [1, -1]) ], 0), axis=0, - keep_dims=True), [num_candidates, -1]), + keepdims=True), [num_candidates, -1]), axis=1) nmi_scores = array_ops.zeros([num_candidates]) diff --git a/tensorflow/contrib/makefile/BUILD b/tensorflow/contrib/makefile/BUILD index a8dd59f32a7f3b27993a7ee48ee7cc07ada59a4c..701eeb44fe3f814cb3fb1cedd8618753946cc3e5 100644 --- a/tensorflow/contrib/makefile/BUILD +++ b/tensorflow/contrib/makefile/BUILD @@ -12,20 +12,3 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) - -sh_test( - name = "build_all_linux", - size = "enormous", - srcs = ["build_all_linux.sh"], - data = [ - "//tensorflow:all_opensource_files", - "//third_party/eigen3:all_files", - "//third_party/fft2d:all_files", - ], - tags = [ - "manual", - "no_gpu", - "no_oss", - "notap", - ], -) diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index c50f8ceec0a634010a7f04dbb47f267be1d7074b..81327407d44b4317b7aecb964a689a35aa35c163 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -407,7 +407,7 @@ $(MARCH_OPTION) \ -I$(JETPACK)/cuda/extras/CUPTI/include - LIBS += \ + CUDA_LIBS := \ -ltfcuda \ -lcudart_static \ -lcudnn \ @@ -420,10 +420,10 @@ $(MARCH_OPTION) \ -lculibos \ -lcurand_static - OBJDIR := $(OBJDIR)Tegra/ - LIBDIR := $(LIBDIR)Tegra/ - BINDIR := $(BINDIR)Tegra/ - DEPDIR := $(DEPDIR)Tegra/ + OBJDIR := $(OBJDIR)android_arm64-v8a/ + LIBDIR := $(LIBDIR)android_arm64-v8a/ + BINDIR := $(BINDIR)android_arm64-v8a/ + DEPDIR := $(DEPDIR)android_arm64-v8a/ TEGRA_LIBS := \ -L$(JETPACK)/cuda/targets/aarch64-linux-androideabi/lib \ @@ -606,7 +606,8 @@ $(wildcard tensorflow/core/util/*/*.cc) \ tensorflow/core/util/version_info.cc # Remove duplicates (for version_info.cc) CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) -CORE_CC_EXCLUDE_SRCS := \ + +CORE_CC_EXCLUDE_SRCS_NON_GPU := \ $(wildcard tensorflow/core/*/*test.cc) \ $(wildcard tensorflow/core/*/*testutil*) \ $(wildcard tensorflow/core/*/*testlib*) \ @@ -626,49 +627,31 @@ $(wildcard tensorflow/core/lib/jpeg/*) \ $(wildcard tensorflow/core/lib/png/*) \ $(wildcard tensorflow/core/util/events_writer.*) \ $(wildcard tensorflow/core/util/reporter.*) \ -$(wildcard tensorflow/core/platform/default/cuda_libdevice_path.*) \ -$(wildcard tensorflow/core/platform/default/stream_executor.*) \ $(wildcard tensorflow/core/platform/default/test_benchmark.*) \ -$(wildcard tensorflow/core/platform/cuda.h) \ -$(wildcard tensorflow/core/platform/cuda_libdevice_path.*) \ $(wildcard tensorflow/core/platform/cloud/*) \ $(wildcard tensorflow/core/platform/google/*) \ $(wildcard tensorflow/core/platform/google/*/*) \ $(wildcard tensorflow/core/platform/jpeg.*) \ $(wildcard tensorflow/core/platform/png.*) \ $(wildcard tensorflow/core/platform/s3/*) \ -$(wildcard tensorflow/core/platform/stream_executor.*) \ $(wildcard tensorflow/core/platform/windows/*) \ -$(wildcard tensorflow/core/user_ops/*.cu.cc) \ -$(wildcard tensorflow/core/common_runtime/gpu/*) \ -$(wildcard tensorflow/core/common_runtime/gpu_device_factory.*) \ $(wildcard tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.*) \ $(wildcard tensorflow/core/grappler/inputs/file_input_yielder.*) \ -$(wildcard tensorflow/core/grappler/clusters/single_machine.*) +$(wildcard tensorflow/core/grappler/clusters/single_machine.*) \ +tensorflow/core/util/cuda_kernel_helper_test.cu.cc + +CORE_CC_EXCLUDE_SRCS := \ +$(CORE_CC_EXCLUDE_SRCS_NON_GPU) \ +$(wildcard tensorflow/core/platform/stream_executor.*) \ +$(wildcard tensorflow/core/platform/default/cuda_libdevice_path.*) \ +$(wildcard tensorflow/core/platform/cuda.h) \ +$(wildcard tensorflow/core/platform/cuda_libdevice_path.*) \ +$(wildcard tensorflow/core/user_ops/*.cu.cc) \ +$(wildcard tensorflow/core/common_runtime/gpu/*) \ +$(wildcard tensorflow/core/common_runtime/gpu_device_factory.*) ifeq ($(BUILD_FOR_TEGRA),1) -CORE_CC_ALL_SRCS := \ -$(wildcard tensorflow/core/*.cc) \ -$(wildcard tensorflow/core/common_runtime/*.cc) \ -$(wildcard tensorflow/core/common_runtime/gpu/*.cc) \ -$(wildcard tensorflow/core/framework/*.cc) \ -$(wildcard tensorflow/core/graph/*.cc) \ -$(wildcard tensorflow/core/platform/*.cc) \ -$(wildcard tensorflow/core/platform/*/*.cc) \ -$(wildcard tensorflow/core/platform/*/*/*.cc) \ -$(wildcard tensorflow/core/util/*.cc) \ -$(wildcard tensorflow/core/util/*/*.cc) \ -$(wildcard tensorflow/cc/training/*.cc) \ -$(wildcard tensorflow/stream_executor/*.cc) \ -$(wildcard tensorflow/stream_executor/*/*.cc) \ -$(wildcard tensorflow/core/grappler/optimizers/*.cc) \ -$(wildcard tensorflow/core/grappler/*.cc) \ -$(wildcard tensorflow/core/grappler/costs/*.cc) \ -$(wildcard tensorflow/core/grappler/clusters/*.cc) \ -$(wildcard tensorflow/core/grappler/utils/*.cc) \ -$(wildcard tensorflow/core/lib/core/*.cc) \ -$(wildcard tensorflow/core/lib/*/*.cc) \ -tensorflow/core/grappler/inputs/utils.cc \ +CORE_CC_ALL_SRCS := $(CORE_CC_ALL_SRCS) \ tensorflow/core/kernels/concat_lib_gpu.cc \ tensorflow/core/kernels/cuda_solvers.cc \ tensorflow/core/kernels/cudnn_pooling_gpu.cc \ @@ -677,28 +660,14 @@ tensorflow/core/kernels/fractional_avg_pool_op.cc \ tensorflow/core/kernels/fractional_max_pool_op.cc \ tensorflow/core/kernels/fractional_pool_common.cc \ tensorflow/core/kernels/pooling_ops_3d.cc \ -tensorflow/core/kernels/sparse_fill_empty_rows_op.cc +tensorflow/core/kernels/sparse_fill_empty_rows_op.cc \ +tensorflow/core/kernels/list_kernels.cc \ +$(wildcard tensorflow/core/common_runtime/gpu/*.cc) \ +$(wildcard tensorflow/stream_executor/*.cc) \ +$(wildcard tensorflow/stream_executor/*/*.cc) CORE_CC_EXCLUDE_SRCS := \ -$(wildcard tensorflow/core/*/*test.cc) \ -$(wildcard tensorflow/core/*/*testutil*) \ -$(wildcard tensorflow/core/*/*testlib*) \ -$(wildcard tensorflow/core/*/*/*test.cc) \ -$(wildcard tensorflow/core/*/*/*testutil*) \ -$(wildcard tensorflow/core/framework/op_gen_lib.cc) \ -$(wildcard tensorflow/core/lib/gif/*) \ -$(wildcard tensorflow/core/lib/jpeg/*) \ -$(wildcard tensorflow/core/lib/png/*) \ -$(wildcard tensorflow/core/lib/db/*) \ -$(wildcard tensorflow/core/platform/jpeg.*) \ -$(wildcard tensorflow/core/platform/png.*) \ -$(wildcard tensorflow/core/platform/cloud/*) \ -$(wildcard tensorflow/core/platform/s3/*) \ -$(wildcard tensorflow/core/platform/windows/*) \ -$(wildcard tensorflow/core/*/*/*testlib*) \ -$(wildcard tensorflow/cc/training/*test.cc) \ -tensorflow/core/lib/io/record_reader.cc \ -tensorflow/core/util/cuda_kernel_helper_test.cu.cc +$(CORE_CC_EXCLUDE_SRCS_NON_GPU) CUDA_CC_SRCS := $(wildcard tensorflow/core/kernels/*.cu.cc) CUDA_CC_OBJS := $(addprefix $(OBJDIR), $(CUDA_CC_SRCS:.cc=.o)) @@ -760,7 +729,7 @@ $(BENCHMARK_NAME): $(BENCHMARK_OBJS) $(LIB_PATH) $(CUDA_LIB_DEPS) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ -o $(BENCHMARK_NAME) $(BENCHMARK_OBJS) \ - $(LIBFLAGS) $(TEGRA_LIBS) $(LIB_PATH) $(LDFLAGS) $(LIBS) + $(LIBFLAGS) $(TEGRA_LIBS) $(LIB_PATH) $(LDFLAGS) $(LIBS) $(CUDA_LIBS) # NVCC compilation rules for Tegra ifeq ($(BUILD_FOR_TEGRA),1) diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md index 0613de2cabe2065f1e4a816f2295d41b69159c10..995230dfa848532dc2a50b85f58d19ba264f293e 100644 --- a/tensorflow/contrib/makefile/README.md +++ b/tensorflow/contrib/makefile/README.md @@ -130,6 +130,105 @@ adb shell '/data/local/tmp/benchmark \ For more details, see the [benchmark documentation](../../tools/benchmark). +## CUDA support for Tegra devices running Android (Nvidia Shield TV, etc) + +With the release of TF 1.6 and JetPack for Android 3.2 (currently pending), you can now build a version of TensorFlow for compatible devices according to the following instructions which will receive the full benefits of GPU acceleration. + +#### Environment setup: + +First, download and install JetPack for Android version 3.2 or greater from [Nvidia](https://developers.nvidia.com). Note that as of the TF 1.6 release the JetPack for Android 3.2 release is still pending, and regular JetPack for L4T will not work. + +```bash +git clone https://github.com/tensorflow/tensorflow.git +cd tensorflow +JETPACK=$HOME/JetPack_Android_3.2 +TEGRA_LIBS="$JETPACK/cuDNN/aarch64/cuda/lib64/libcudnn.so $JETPACK/cuda-9.0/extras/CUPTI/lib64/libcupti.so $JETPACK/cuda/targets/aarch64-linux-androideabi/lib64/libcufft.so" +``` + +#### Building all CUDA-enabled native binaries: +This will build CUDA-enabled versions of libtensorflow_inference.so and the benchmark binary. (libtensorflow_demo.so will also be built incidentally, but it does not support CUDA) + +```bash +NDK_ROOT=$JETPACK/android-ndk-r13b +CC_PREFIX=ccache tensorflow/contrib/makefile/build_all_android.sh -s tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in -t "libtensorflow_inference.so libtensorflow_demo.so all" -a tegra +``` +(add -T on subsequent builds to skip protobuf downloading/building) + + +#### Testing the CUDA-enabled benchmark via adb: +Build binaries first as above, then run: + +```bash +adb shell mkdir -p /data/local/tmp/lib64 +adb push $TEGRA_LIBS /data/local/tmp/lib64 +adb push tensorflow/contrib/makefile/gen/bin/android_arm64-v8a/benchmark /data/local/tmp +wget https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk +unzip tensorflow_demo.apk -d /tmp/tensorflow_demo +adb push /tmp/tensorflow_demo/assets/*.pb /data/local/tmp +adb shell "LD_LIBRARY_PATH=/data/local/tmp/lib64 /data/local/tmp/benchmark --graph=/data/local/tmp/tensorflow_inception_graph.pb" +``` + +#### Building the CUDA-enabled TensorFlow AAR with Bazel: +Build the native binaries first as above. Then, build the aar and package the native libs by executing the following: +```bash +mkdir -p /tmp/tf/jni/arm64-v8a +cp tensorflow/contrib/makefile/gen/lib/android_tegra/libtensorflow_*.so /tmp/tf/jni/arm64-v8a/ +cp $TEGRA_LIBS /tmp/tf/jni/arm64-v8a +bazel build //tensorflow/contrib/android:android_tensorflow_inference_java.aar +cp bazel-bin/tensorflow/contrib/android/android_tensorflow_inference_java.aar /tmp/tf/tensorflow.aar +cd /tmp/tf +chmod +w tensorflow.aar +zip -ur tensorflow.aar $(find jni -name *.so) +``` + +#### Building the CUDA-enabled TensorFlow Android demo with Bazel: +Build binaries first as above, then edit tensorflow/examples/android/BUILD and replace: +``` + srcs = [ + ":libtensorflow_demo.so", + "//tensorflow/contrib/android:libtensorflow_inference.so", + ], +``` +with: +``` +srcs = glob(["libs/arm64-v8a/*.so"]), +``` + +Then run: +```bash +# Create dir for native libs +mkdir -p tensorflow/examples/android/libs/arm64-v8a + +# Copy JetPack libs +cp $TEGRA_LIBS tensorflow/examples/android/libs/arm64-v8a + +# Copy native TensorFlow libraries +cp tensorflow/contrib/makefile/gen/lib/android_arm64-v8a/libtensorflow_*.so tensorflow/examples/android/libs/arm64-v8a/ + +# Build APK +bazel build -c opt --fat_apk_cpu=arm64-v8a tensorflow/android:tensorflow_demo + +# Install +adb install -r -f bazel-bin/tensorflow/examples/android/tensorflow_demo.apk +``` + +#### Building the CUDA-enabled Android demo with gradle/Android Studio: + +Add tensorflow/examples/android as an Android project in Android Studio as normal. + +Edit build.gradle and: +* set nativeBuildSystem = 'makefile' +* set cpuType = 'arm64-v8a' +* in "buildNativeMake", replace cpuType with 'tegra' (optional speedups like -T and ccache also work) +* set the environment "NDK_ROOT" var to $JETPACK/android-ndk-r13b + +Click "build apk" to build. + +Install: +```bash +adb install -r -f tensorflow/examples/android/gradleBuild/outputs/apk/debug/android-debug.apk +``` + ## iOS _Note: To use this library in an iOS application, see related instructions in @@ -268,7 +367,7 @@ selectively register only for the operators used in your graph. ```bash tensorflow/contrib/makefile/build_all_ios.sh -a arm64 -g $HOME/graphs/inception/tensorflow_inception_graph.pb ``` -Please note this is an aggresive optimization of the operators and the resulting library may not work with other graphs but will reduce the size of the final library. +Please note this is an aggressive optimization of the operators and the resulting library may not work with other graphs but will reduce the size of the final library. The `compile_ios_tensorflow.sh` script can take optional command-line arguments. The first argument will be passed as a C++ optimization flag and defaults to diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh index 980a44a5952a098da8a00e666d37a6d1642f4095..fc88f59e0948e1d3ed7cce9b809bf30ba280af12 100755 --- a/tensorflow/contrib/makefile/build_all_android.sh +++ b/tensorflow/contrib/makefile/build_all_android.sh @@ -18,7 +18,7 @@ set -e usage() { - echo "Usage: NDK_ROOT= $(basename "$0") [-Es:t:Tx:a:X]" + echo "Usage: NDK_ROOT= $(basename "$0") [-Es:t:Tx:a]" echo "-E enable experimental hexnn ops" echo "-s [sub_makefiles] sub makefiles separated by white space" echo "-t [build_target] build target for Android makefile [default=all]" @@ -52,7 +52,7 @@ shift $((OPTIND - 1)) if [ "$ARCH" == "tegra" ]; then if [[ -z "${JETPACK}" ]]; then - export JETPACK="$HOME/JetPack_Android_3.0" + export JETPACK="$HOME/JetPack_Android_3.2" fi if [ ! -d ${JETPACK} ]; then echo "Can't find Jetpack at ${JETPACK}" diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh index a18df256f976c3c0ac4cefe1c884d951e63ef823..2d9979183975e6a17527b40ef5ee1795ced44a7b 100755 --- a/tensorflow/contrib/makefile/build_all_ios.sh +++ b/tensorflow/contrib/makefile/build_all_ios.sh @@ -96,7 +96,7 @@ if [[ "${ONLY_MAKE_TENSORFLOW}" != "true" ]]; then if [[ -z "${BUILD_ARCH}" ]]; then # Compile protobuf for the target iOS device architectures. - tensorflow/contrib/makefile/compile_ios_protobuf.sh -a ${DEFAULT_ARCH} + tensorflow/contrib/makefile/compile_ios_protobuf.sh else # Compile protobuf for the target iOS device architectures. tensorflow/contrib/makefile/compile_ios_protobuf.sh -a ${BUILD_ARCH} diff --git a/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh b/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh index 861bb885c7031b996b48dbc50887cfce55c638f3..421ddd210fd5b1ac6487918d5797eab5953316df 100755 --- a/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh +++ b/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh @@ -36,7 +36,7 @@ while getopts "bc:Eps" opt_name; do b) BUILD_ONLY="true";; c) TEST_COUNT="${OPTARG}";; E) ENABLE_EXPERIMENTAL_HEXNN_OPS="true";; - p) USE_PREBUILT_HEXAOGON_BINARIES="true";; + p) USE_PREBUILT_HEXAGON_BINARIES="true";; s) SKIP_DOWNLOAD_IF_EXIST="true";; *) usage;; esac @@ -49,7 +49,7 @@ if [[ -z "${NDK_ROOT}" ]]; then exit 1 fi -if [[ "${USE_PREBUILT_HEXAOGON_BINARIES}" != "true" && +if [[ "${USE_PREBUILT_HEXAGON_BINARIES}" != "true" && -z "${QUALCOMM_SDK}" ]]; then echo "QUALCOMM_SDK is empty" 1>&2 usage @@ -76,13 +76,15 @@ GEN_LIBS_DIR="${GEN_DIR}/libs" GEN_DOWNLOAD_DIR="${GEN_DIR}/downloads" URL_BASE="https://storage.googleapis.com/download.tensorflow.org" +ARCH="armeabi-v7a" + source "${SCRIPT_DIR}/../build_helper.subr" rm -rf "${GEN_DIR}" mkdir -p "${GEN_LIBS_DIR}" mkdir -p "${GEN_DOWNLOAD_DIR}" -if [[ "${USE_PREBUILT_HEXAOGON_BINARIES}" == "true" ]]; then +if [[ "${USE_PREBUILT_HEXAGON_BINARIES}" == "true" ]]; then echo "Download prebuilt hexagon binaries" if [[ "${BUILD_ONLY}" != "true" ]]; then CONTROLLER_PUSH_DEST="/data/local/tmp" @@ -219,7 +221,7 @@ if [[ "${BUILD_ONLY}" != "true" ]]; then adb push "${GEN_LIBS_DIR}/libhexagon_nn_skel.so" "/vendor/lib/rfsa/adsp" adb push -p \ - "${TF_ROOT_DIR}/tensorflow/contrib/makefile/gen/bin/hexagon_graph_execution" \ + "${TF_ROOT_DIR}/tensorflow/contrib/makefile/gen/bin/android_${ARCH}/hexagon_graph_execution" \ "/data/local/tmp/" adb wait-for-device adb shell chmod "${ANDROID_EXEC_FILE_MODE}" \ diff --git a/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in b/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in index d9277ed60cb456208572ca1ad8df530648faef82..3081084ee76e41de801f49a67c1fec07f4ff03b9 100644 --- a/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in +++ b/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in @@ -54,7 +54,7 @@ $(INFERENCE_SO_PATH): $(LIB_OBJS) $(INFERENCE_OBJS) $(CUDA_LIB_DEPS) -o $@ $(INFERENCE_OBJS) $(LIB_OBJS) $(TEGRA_LIBS) \ $(LIBFLAGS) $(LDFLAGS) \ -shared -Wl,-soname,$(INFERENCE_SO_NAME) \ - $(LIBS) + $(LIBS) $(CUDA_LIBS) $(INFERENCE_SO_NAME): $(INFERENCE_SO_PATH) diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 5f275663986f9d480659880ab601eeb5c41037be..5a812af4e95fe7a05b9c2634b0cc1d860fb7f619 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -91,6 +91,7 @@ tensorflow/core/kernels/reduction_ops_max.cc tensorflow/core/kernels/reduction_ops_common.cc tensorflow/core/kernels/reduction_ops_any.cc tensorflow/core/kernels/reduction_ops_all.cc +tensorflow/core/kernels/roll_op.cc tensorflow/core/kernels/queue_ops.cc tensorflow/core/kernels/queue_base.cc tensorflow/core/kernels/pooling_ops_common.cc @@ -270,6 +271,7 @@ tensorflow/core/ops/parsing_ops.cc tensorflow/core/ops/no_op.cc tensorflow/core/ops/nn_ops.cc tensorflow/core/ops/nn_grad.cc +tensorflow/core/ops/manip_ops.cc tensorflow/core/ops/math_ops.cc tensorflow/core/ops/math_grad.cc tensorflow/core/ops/logging_ops.cc @@ -291,3 +293,4 @@ tensorflow/core/kernels/batchtospace_op.cc tensorflow/core/kernels/warn_about_ints.cc tensorflow/core/kernels/segment_reduction_ops.cc tensorflow/core/kernels/batch_util.cc +tensorflow/core/ops/audio_ops.cc diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc index 39c0d5af45b4a81fa4dde0b5deac14a3af372cbb..974fb537499c5ea4591a0a128f53d2dea67b9e57 100644 --- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc +++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc @@ -80,9 +80,9 @@ REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_GPU).HostMemory("out"), BytesLimitOp); #ifdef TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("BytesLimit").Device(DEVICE_SYCL).HostMemory("out"), - BytesLimitOp); -#endif // TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("BytesLimit").Device(DEVICE_SYCL).HostMemory("out"), BytesLimitOp); +#endif // TENSORFLOW_USE_SYCL // Op that measures the peak memory in bytes. class MaxBytesInUseOp : public MemoryStatsOp { @@ -107,6 +107,6 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER( Name("MaxBytesInUse").Device(DEVICE_SYCL).HostMemory("out"), MaxBytesInUseOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py index 2932ae1c8df32cd936cff932b061571c513fda79..ff88b4fa841673fc52b9f6fdc5ca43d30c44bbfd 100644 --- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py +++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py @@ -171,7 +171,14 @@ def _clean_save_and_restore(graph_def, op, removed_op_names): shape_op_value_tensor.tensor_shape.dim[0].size = len(shapes) op.attr['dtypes'].list.type[:] = dtypes + if not name_op.attr['_output_shapes'].list.shape: + name_op.attr['_output_shapes'].list.shape.add() + name_op.attr['_output_shapes'].list.shape[0].dim.add() name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(names) + + if not shape_op.attr['_output_shapes'].list.shape: + shape_op.attr['_output_shapes'].list.shape.add() + shape_op.attr['_output_shapes'].list.shape[0].dim.add() shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(shapes) diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index 9de664c822bf7a9abf7b8082f444c61dfa45f499..e90c525113348532a3ebdadde7e712bf2d98cee9 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -43,6 +43,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:weights_broadcast_ops", + "//tensorflow/python/ops/distributions", ], ) diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index d3dce46bfb6e9c77cc7ae107b323a9bc7074c47e..de02dc8f457364450929776035829d86035d706b 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -16,6 +16,7 @@ See the @{$python/contrib.metrics} guide. +@@auc_with_confidence_intervals @@streaming_accuracy @@streaming_mean @@streaming_recall @@ -83,6 +84,7 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics +from tensorflow.contrib.metrics.python.ops.metric_ops import auc_with_confidence_intervals from tensorflow.contrib.metrics.python.ops.metric_ops import cohen_kappa from tensorflow.contrib.metrics.python.ops.metric_ops import count from tensorflow.contrib.metrics.python.ops.metric_ops import precision_recall_at_equal_thresholds diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index c3de1c4c62f04c7ef3d85f36662805c0c0ec4b4c..31e274c5fd7c670458b1b40a4f58c668a23776c7 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.ops.distributions.normal import Normal from tensorflow.python.util.deprecation import deprecated # Epsilon constant used to represent extremely small quantity. @@ -339,9 +340,9 @@ def streaming_mean_tensor(values, name=name) -@deprecated( - None, 'Please switch to tf.metrics.accuracy. Note that the order of the ' - 'labels and predictions arguments has been switched.') +@deprecated(None, + 'Please switch to tf.metrics.accuracy. Note that the order of the ' + 'labels and predictions arguments has been switched.') def streaming_accuracy(predictions, labels, weights=None, @@ -739,7 +740,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, else: for include in includes: if include not in all_includes: - raise ValueError('Invaild key: %s.' % include) + raise ValueError('Invalid key: %s.' % include) predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) @@ -936,8 +937,9 @@ def streaming_curve_points(labels=None, if curve != 'ROC' and curve != 'PR': raise ValueError('curve must be either ROC or PR, %s unknown' % (curve)) kepsilon = _EPSILON # to account for floating point imprecisions - thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) - for i in range(num_thresholds - 2)] + thresholds = [ + (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) + ] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] values, update_ops = _streaming_confusion_matrix_at_thresholds( @@ -973,9 +975,8 @@ def streaming_curve_points(labels=None, return points, update_op -@deprecated( - None, 'Please switch to tf.metrics.auc. Note that the order of the ' - 'labels and predictions arguments has been switched.') +@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of the ' + 'labels and predictions arguments has been switched.') def streaming_auc(predictions, labels, weights=None, @@ -1105,8 +1106,7 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): # For conformance, set precision to 1 when the number of positive # classifications is 0. y_axis_values = array_ops.where( - math_ops.greater(splits, 0), - math_ops.truediv(true_positives, splits), + math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits), array_ops.ones_like(true_positives, dtype=dtypes.float64)) # Calculate trapezoid areas. @@ -1119,9 +1119,8 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): # exception seems excessive) so we return 0, otherwise we finish computing. return control_flow_ops.cond( math_ops.logical_or( - math_ops.equal(total_positive, 0), - math_ops.equal(total_positive, size) - ), + math_ops.equal(total_positive, 0), math_ops.equal( + total_positive, size)), true_fn=lambda: array_ops.constant(0, dtypes.float64), false_fn=continue_computing_dynamic_auc) @@ -1185,10 +1184,10 @@ def streaming_dynamic_auc(labels, array_ops.ones_like(labels, dtypes.int64), message='labels must be 0 or 1, at least one is >1') ]): - preds_accum, update_preds = streaming_concat(predictions, - name='concat_preds') - labels_accum, update_labels = streaming_concat(labels, - name='concat_labels') + preds_accum, update_preds = streaming_concat( + predictions, name='concat_preds') + labels_accum, update_labels = streaming_concat( + labels, name='concat_labels') update_op = control_flow_ops.group(update_labels, update_preds) auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve) if updates_collections: @@ -1198,6 +1197,295 @@ def streaming_dynamic_auc(labels, return auc, update_op +def _compute_placement_auc(labels, predictions, weights, alpha, + logit_transformation, is_valid): + """Computes the AUC and asymptotic normally distributed confidence interval. + + The calculations are achieved using the fact that AUC = P(Y_1>Y_0) and the + concept of placement values for each labeled group, as presented by Delong and + Delong (1988). The actual algorithm used is a more computationally efficient + approach presented by Sun and Xu (2014). This could be slow for large batches, + but has the advantage of not having its results degrade depending on the + distribution of predictions. + + Args: + labels: A `Tensor` of ground truth labels with the same shape as + `predictions` with values of 0 or 1 and type `int64`. + predictions: A 1-D `Tensor` of predictions whose values are `float64`. + weights: `Tensor` whose rank is either 0, or the same rank as `labels`. + alpha: Confidence interval level desired. + logit_transformation: A boolean value indicating whether the estimate should + be logit transformed prior to calculating the confidence interval. Doing + so enforces the restriction that the AUC should never be outside the + interval [0,1]. + is_valid: A bool tensor describing whether the input is valid. + + Returns: + A 1-D `Tensor` containing the area-under-curve, lower, and upper confidence + interval values. + """ + # Disable the invalid-name checker so that we can capitalize the name. + # pylint: disable=invalid-name + AucData = collections_lib.namedtuple('AucData', ['auc', 'lower', 'upper']) + # pylint: enable=invalid-name + + # If all the labels are the same or if number of observations are too few, + # AUC isn't well-defined + size = array_ops.size(predictions, out_type=dtypes.int32) + + # Count the total number of positive and negative labels in the input. + total_0 = math_ops.reduce_sum( + math_ops.cast(1 - labels, weights.dtype) * weights) + total_1 = math_ops.reduce_sum( + math_ops.cast(labels, weights.dtype) * weights) + + # Sort the predictions ascending, as well as + # (i) the corresponding labels and + # (ii) the corresponding weights. + ordered_predictions, indices = nn.top_k(predictions, k=size, sorted=True) + ordered_predictions = array_ops.reverse( + ordered_predictions, axis=array_ops.zeros(1, dtypes.int32)) + indices = array_ops.reverse(indices, axis=array_ops.zeros(1, dtypes.int32)) + ordered_labels = array_ops.gather(labels, indices) + ordered_weights = array_ops.gather(weights, indices) + + # We now compute values required for computing placement values. + + # We generate a list of indices (segmented_indices) of increasing order. An + # index is assigned for each unique prediction float value. Prediction + # values that are the same share the same index. + _, segmented_indices = array_ops.unique(ordered_predictions) + + # We create 2 tensors of weights. weights_for_true is non-zero for true + # labels. weights_for_false is non-zero for false labels. + float_labels_for_true = math_ops.cast(ordered_labels, dtypes.float32) + float_labels_for_false = 1.0 - float_labels_for_true + weights_for_true = ordered_weights * float_labels_for_true + weights_for_false = ordered_weights * float_labels_for_false + + # For each set of weights with the same segmented indices, we add up the + # weight values. Note that for each label, we deliberately rely on weights + # for the opposite label. + weight_totals_for_true = math_ops.segment_sum(weights_for_false, + segmented_indices) + weight_totals_for_false = math_ops.segment_sum(weights_for_true, + segmented_indices) + + # These cumulative sums of weights importantly exclude the current weight + # sums. + cum_weight_totals_for_true = math_ops.cumsum(weight_totals_for_true, + exclusive=True) + cum_weight_totals_for_false = math_ops.cumsum(weight_totals_for_false, + exclusive=True) + + # Compute placement values using the formula. Values with the same segmented + # indices and labels share the same placement values. + placements_for_true = ( + (cum_weight_totals_for_true + weight_totals_for_true / 2.0) / + (math_ops.reduce_sum(weight_totals_for_true) + _EPSILON)) + placements_for_false = ( + (cum_weight_totals_for_false + weight_totals_for_false / 2.0) / + (math_ops.reduce_sum(weight_totals_for_false) + _EPSILON)) + + # We expand the tensors of placement values (for each label) so that their + # shapes match that of predictions. + placements_for_true = array_ops.gather(placements_for_true, segmented_indices) + placements_for_false = array_ops.gather(placements_for_false, + segmented_indices) + + # Select placement values based on the label for each index. + placement_values = ( + placements_for_true * float_labels_for_true + + placements_for_false * float_labels_for_false) + + # Split placement values by labeled groups. + placement_values_0 = placement_values * math_ops.cast( + 1 - ordered_labels, weights.dtype) + weights_0 = ordered_weights * math_ops.cast( + 1 - ordered_labels, weights.dtype) + placement_values_1 = placement_values * math_ops.cast( + ordered_labels, weights.dtype) + weights_1 = ordered_weights * math_ops.cast( + ordered_labels, weights.dtype) + + # Calculate AUC using placement values + auc_0 = (math_ops.reduce_sum(weights_0 * (1. - placement_values_0)) / + (total_0 + _EPSILON)) + auc_1 = (math_ops.reduce_sum(weights_1 * (placement_values_1)) / + (total_1 + _EPSILON)) + auc = array_ops.where(math_ops.less(total_0, total_1), auc_1, auc_0) + + # Calculate variance and standard error using the placement values. + var_0 = ( + math_ops.reduce_sum( + weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) / + (total_0 - 1. + _EPSILON)) + var_1 = ( + math_ops.reduce_sum( + weights_1 * math_ops.square(placement_values_1 - auc_1)) / + (total_1 - 1. + _EPSILON)) + auc_std_err = math_ops.sqrt( + (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON))) + + # Calculate asymptotic normal confidence intervals + std_norm_dist = Normal(loc=0., scale=1.) + z_value = std_norm_dist.quantile((1.0 - alpha) / 2.0) + if logit_transformation: + estimate = math_ops.log(auc / (1. - auc + _EPSILON)) + std_err = auc_std_err / (auc * (1. - auc + _EPSILON)) + transformed_auc_lower = estimate + (z_value * std_err) + transformed_auc_upper = estimate - (z_value * std_err) + def inverse_logit_transformation(x): + exp_negative = math_ops.exp(math_ops.negative(x)) + return 1. / (1. + exp_negative + _EPSILON) + + auc_lower = inverse_logit_transformation(transformed_auc_lower) + auc_upper = inverse_logit_transformation(transformed_auc_upper) + else: + estimate = auc + std_err = auc_std_err + auc_lower = estimate + (z_value * std_err) + auc_upper = estimate - (z_value * std_err) + + ## If estimate is 1 or 0, no variance is present so CI = 1 + ## n.b. This can be misleading, since number obs can just be too low. + lower = array_ops.where( + math_ops.logical_or( + math_ops.equal(auc, array_ops.ones_like(auc)), + math_ops.equal(auc, array_ops.zeros_like(auc))), + auc, auc_lower) + upper = array_ops.where( + math_ops.logical_or( + math_ops.equal(auc, array_ops.ones_like(auc)), + math_ops.equal(auc, array_ops.zeros_like(auc))), + auc, auc_upper) + + # If all the labels are the same, AUC isn't well-defined (but raising an + # exception seems excessive) so we return 0, otherwise we finish computing. + trivial_value = array_ops.constant(0.0) + + return AucData(*control_flow_ops.cond( + is_valid, lambda: [auc, lower, upper], lambda: [trivial_value]*3)) + + +def auc_with_confidence_intervals(labels, + predictions, + weights=None, + alpha=0.95, + logit_transformation=True, + metrics_collections=(), + updates_collections=(), + name=None): + """Computes the AUC and asymptotic normally distributed confidence interval. + + USAGE NOTE: this approach requires storing all of the predictions and labels + for a single evaluation in memory, so it may not be usable when the evaluation + batch size and/or the number of evaluation steps is very large. + + Computes the area under the ROC curve and its confidence interval using + placement values. This has the advantage of being resilient to the + distribution of predictions by aggregating across batches, accumulating labels + and predictions and performing the final calculation using all of the + concatenated values. + + Args: + labels: A `Tensor` of ground truth labels with the same shape as `labels` + and with values of 0 or 1 whose values are castable to `int64`. + predictions: A `Tensor` of predictions whose values are castable to + `float64`. Will be flattened into a 1-D `Tensor`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`. + alpha: Confidence interval level desired. + logit_transformation: A boolean value indicating whether the estimate should + be logit transformed prior to calculating the confidence interval. Doing + so enforces the restriction that the AUC should never be outside the + interval [0,1]. + metrics_collections: An optional iterable of collections that `auc` should + be added to. + updates_collections: An optional iterable of collections that `update_op` + should be added to. + name: An optional name for the variable_scope that contains the metric + variables. + + Returns: + auc: A 1-D `Tensor` containing the current area-under-curve, lower, and + upper confidence interval values. + update_op: An operation that concatenates the input labels and predictions + to the accumulated values. + + Raises: + ValueError: If `labels`, `predictions`, and `weights` have mismatched shapes + or if `alpha` isn't in the range (0,1). + """ + if not (alpha > 0 and alpha < 1): + raise ValueError('alpha must be between 0 and 1; currently %.02f' % alpha) + + if weights is None: + weights = array_ops.ones_like(predictions) + + with variable_scope.variable_scope( + name, + default_name='auc_with_confidence_intervals', + values=[labels, predictions, weights]): + + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=predictions, + labels=labels, + weights=weights) + + total_weight = math_ops.reduce_sum(weights) + + weights = array_ops.reshape(weights, [-1]) + predictions = array_ops.reshape( + math_ops.cast(predictions, dtypes.float64), [-1]) + labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1]) + + with ops.control_dependencies([ + check_ops.assert_greater_equal( + labels, + array_ops.zeros_like(labels, dtypes.int64), + message='labels must be 0 or 1, at least one is <0'), + check_ops.assert_less_equal( + labels, + array_ops.ones_like(labels, dtypes.int64), + message='labels must be 0 or 1, at least one is >1'), + ]): + preds_accum, update_preds = streaming_concat( + predictions, name='concat_preds') + labels_accum, update_labels = streaming_concat(labels, + name='concat_labels') + weights_accum, update_weights = streaming_concat( + weights, name='concat_weights') + update_op_for_valid_case = control_flow_ops.group( + update_labels, update_preds, update_weights) + + # Only perform updates if this case is valid. + all_labels_positive_or_0 = math_ops.logical_and( + math_ops.equal(math_ops.reduce_min(labels), 0), + math_ops.equal(math_ops.reduce_max(labels), 1)) + sums_of_weights_at_least_1 = math_ops.greater_equal(total_weight, 1.0) + is_valid = math_ops.logical_and(all_labels_positive_or_0, + sums_of_weights_at_least_1) + + update_op = control_flow_ops.cond( + sums_of_weights_at_least_1, + lambda: update_op_for_valid_case, control_flow_ops.no_op) + + auc = _compute_placement_auc( + labels_accum, + preds_accum, + weights_accum, + alpha=alpha, + logit_transformation=logit_transformation, + is_valid=is_valid) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + if metrics_collections: + ops.add_to_collections(metrics_collections, auc) + return auc, update_op + + def precision_recall_at_equal_thresholds(labels, predictions, weights=None, @@ -1228,7 +1516,7 @@ def precision_recall_at_equal_thresholds(labels, predictions: A floating point `Tensor` of arbitrary shape and whose values are in the range `[0, 1]`. weights: Optional; If provided, a `Tensor` that has the same dtype as, - and broadcastable to, `predictions`. This tensor is multplied by counts. + and broadcastable to, `predictions`. This tensor is multiplied by counts. num_thresholds: Optional; Number of thresholds, evenly distributed in `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins is 1 less than `num_thresholds`. Using an even `num_thresholds` value @@ -1571,9 +1859,9 @@ def streaming_precision_at_thresholds(predictions, name=name) -@deprecated( - None, 'Please switch to tf.metrics.recall_at_thresholds. Note that the ' - 'order of the labels and predictions arguments has been switched.') +@deprecated(None, + 'Please switch to tf.metrics.recall_at_thresholds. Note that the ' + 'order of the labels and predictions arguments has been switched.') def streaming_recall_at_thresholds(predictions, labels, thresholds, @@ -3299,8 +3587,13 @@ def count(values, return count_, update_op -def cohen_kappa(labels, predictions_idx, num_classes, weights=None, - metrics_collections=None, updates_collections=None, name=None): +def cohen_kappa(labels, + predictions_idx, + num_classes, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Calculates Cohen's kappa. [Cohen's kappa](https://en.wikipedia.org/wiki/Cohen's_kappa) is a statistic @@ -3367,14 +3660,15 @@ def cohen_kappa(labels, predictions_idx, num_classes, weights=None, labels = array_ops.squeeze(labels, axis=[-1]) predictions_idx, labels, weights = ( metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access - predictions=predictions_idx, labels=labels, weights=weights)) + predictions=predictions_idx, + labels=labels, + weights=weights)) predictions_idx.get_shape().assert_is_compatible_with(labels.get_shape()) - stat_dtype = (dtypes.int64 - if weights is None or weights.dtype.is_integer - else dtypes.float32) - po = metrics_impl.metric_variable( - (num_classes,), stat_dtype, name='po') + stat_dtype = ( + dtypes.int64 + if weights is None or weights.dtype.is_integer else dtypes.float32) + po = metrics_impl.metric_variable((num_classes,), stat_dtype, name='po') pe_row = metrics_impl.metric_variable( (num_classes,), stat_dtype, name='pe_row') pe_col = metrics_impl.metric_variable( @@ -3382,9 +3676,12 @@ def cohen_kappa(labels, predictions_idx, num_classes, weights=None, # Table of the counts of agreement: counts_in_table = confusion_matrix.confusion_matrix( - labels, predictions_idx, - num_classes=num_classes, weights=weights, - dtype=stat_dtype, name="counts_in_table") + labels, + predictions_idx, + num_classes=num_classes, + weights=weights, + dtype=stat_dtype, + name='counts_in_table') po_t = array_ops.diag_part(counts_in_table) pe_row_t = math_ops.reduce_sum(counts_in_table, axis=0) @@ -3404,12 +3701,14 @@ def cohen_kappa(labels, predictions_idx, num_classes, weights=None, math_ops.to_double(total)) # kappa = (po - pe) / (N - pe) k = metrics_impl._safe_scalar_div( # pylint: disable=protected-access - po_sum - pe_sum, total - pe_sum, name=name) + po_sum - pe_sum, + total - pe_sum, + name=name) return k kappa = _calculate_k(po, pe_row, pe_col, name='value') - update_op = _calculate_k(update_po, update_pe_row, update_pe_col, - name='update_op') + update_op = _calculate_k( + update_po, update_pe_row, update_pe_col, name='update_op') if metrics_collections: ops.add_to_collections(metrics_collections, kappa) @@ -3421,6 +3720,7 @@ def cohen_kappa(labels, predictions_idx, num_classes, weights=None, __all__ = [ + 'auc_with_confidence_intervals', 'aggregate_metric_map', 'aggregate_metrics', 'cohen_kappa', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 89aa29f711e3b0114a5d776b258f77214cb349bc..b387f26c0195432fb972dac450d2919bdaa702a1 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -46,8 +46,7 @@ def _enqueue_vector(sess, queue, values, shape=None): shape = (1, len(values)) dtype = queue.dtypes[0] sess.run( - queue.enqueue(constant_op.constant( - values, dtype=dtype, shape=shape))) + queue.enqueue(constant_op.constant(values, dtype=dtype, shape=shape))) def _binary_2d_label_to_sparse_value(labels): @@ -79,8 +78,8 @@ def _binary_2d_label_to_sparse_value(labels): batch += 1 shape = [len(labels), len(labels[0])] return sparse_tensor.SparseTensorValue( - np.array(indices, np.int64), - np.array(values, np.int64), np.array(shape, np.int64)) + np.array(indices, np.int64), np.array(values, np.int64), + np.array(shape, np.int64)) def _binary_2d_label_to_sparse(labels): @@ -125,8 +124,8 @@ def _binary_3d_label_to_sparse_value(labels): assert label == 0 shape = [len(labels), len(labels[0]), len(labels[0][0])] return sparse_tensor.SparseTensorValue( - np.array(indices, np.int64), - np.array(values, np.int64), np.array(shape, np.int64)) + np.array(indices, np.int64), np.array(values, np.int64), + np.array(shape, np.int64)) def _binary_3d_label_to_sparse(labels): @@ -669,20 +668,18 @@ class StreamingTruePositivesTest(test.TestCase): for expand_predictions in [True, False]: for expand_labels in [True, False]: for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) + predictions = math_ops.cast( + constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))), + dtype=dtype) if expand_predictions: predictions = array_ops.expand_dims(predictions, 2) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) + labels = math_ops.cast( + constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))), + dtype=dtype) if expand_labels: labels = array_ops.expand_dims(labels, 2) - tp, tp_update_op = metrics.streaming_true_positives(predictions, - labels) + tp, tp_update_op = metrics.streaming_true_positives( + predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -692,14 +689,12 @@ class StreamingTruePositivesTest(test.TestCase): def testWeighted(self): for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) + predictions = math_ops.cast( + constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))), + dtype=dtype) + labels = math_ops.cast( + constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))), + dtype=dtype) tp, tp_update_op = metrics.streaming_true_positives( predictions, labels, weights=37.0) @@ -717,28 +712,25 @@ class StreamingFalseNegativesTest(test.TestCase): ops.reset_default_graph() def testVars(self): - metrics.streaming_false_negatives((0, 1, 0), - (0, 1, 1)) + metrics.streaming_false_negatives((0, 1, 0), (0, 1, 1)) _assert_metric_variables(self, ('false_negatives/count:0',)) def testUnweighted(self): for expand_predictions in [True, False]: for expand_labels in [True, False]: for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) + predictions = math_ops.cast( + constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))), + dtype=dtype) if expand_predictions: predictions = array_ops.expand_dims(predictions, 2) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) + labels = math_ops.cast( + constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))), + dtype=dtype) if expand_labels: labels = array_ops.expand_dims(labels, 2) - fn, fn_update_op = metrics.streaming_false_negatives(predictions, - labels) + fn, fn_update_op = metrics.streaming_false_negatives( + predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -748,14 +740,12 @@ class StreamingFalseNegativesTest(test.TestCase): def testWeighted(self): for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) + predictions = math_ops.cast( + constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))), + dtype=dtype) + labels = math_ops.cast( + constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))), + dtype=dtype) fn, fn_update_op = metrics.streaming_false_negatives( predictions, labels, weights=((3.0,), (5.0,), (7.0,))) @@ -773,28 +763,25 @@ class StreamingFalsePositivesTest(test.TestCase): ops.reset_default_graph() def testVars(self): - metrics.streaming_false_positives((0, 1, 0), - (0, 1, 1)) + metrics.streaming_false_positives((0, 1, 0), (0, 1, 1)) _assert_metric_variables(self, ('false_positives/count:0',)) def testUnweighted(self): for expand_predictions in [True, False]: for expand_labels in [True, False]: for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) + predictions = math_ops.cast( + constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))), + dtype=dtype) if expand_predictions: predictions = array_ops.expand_dims(predictions, 2) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) + labels = math_ops.cast( + constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))), + dtype=dtype) if expand_labels: labels = array_ops.expand_dims(labels, 2) - fp, fp_update_op = metrics.streaming_false_positives(predictions, - labels) + fp, fp_update_op = metrics.streaming_false_positives( + predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -804,20 +791,17 @@ class StreamingFalsePositivesTest(test.TestCase): def testWeighted(self): for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) + predictions = math_ops.cast( + constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))), + dtype=dtype) + labels = math_ops.cast( + constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))), + dtype=dtype) fp, fp_update_op = metrics.streaming_false_positives( predictions, labels, - weights=((1.0, 2.0, 3.0, 5.0), - (7.0, 11.0, 13.0, 17.0), - (19.0, 23.0, 29.0, 31.0))) + weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0, + 29.0, 31.0))) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -833,28 +817,25 @@ class StreamingTrueNegativesTest(test.TestCase): ops.reset_default_graph() def testVars(self): - metrics.streaming_true_negatives((0, 1, 0), - (0, 1, 1)) + metrics.streaming_true_negatives((0, 1, 0), (0, 1, 1)) _assert_metric_variables(self, ('true_negatives/count:0',)) def testUnweighted(self): for expand_predictions in [True, False]: for expand_labels in [True, False]: for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) + predictions = math_ops.cast( + constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))), + dtype=dtype) if expand_predictions: predictions = array_ops.expand_dims(predictions, 2) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) + labels = math_ops.cast( + constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))), + dtype=dtype) if expand_labels: labels = array_ops.expand_dims(labels, 2) - tn, tn_update_op = metrics.streaming_true_negatives(predictions, - labels) + tn, tn_update_op = metrics.streaming_true_negatives( + predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -864,14 +845,12 @@ class StreamingTrueNegativesTest(test.TestCase): def testWeighted(self): for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): - predictions = math_ops.cast(constant_op.constant( - ((1, 0, 1, 0), - (0, 1, 1, 1), - (0, 0, 0, 0))), dtype=dtype) - labels = math_ops.cast(constant_op.constant( - ((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))), dtype=dtype) + predictions = math_ops.cast( + constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))), + dtype=dtype) + labels = math_ops.cast( + constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))), + dtype=dtype) tn, tn_update_op = metrics.streaming_true_negatives( predictions, labels, weights=((0.0, 2.0, 3.0, 5.0),)) @@ -894,12 +873,9 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase): _assert_metric_variables(self, ('true_positives:0',)) def testUnweighted(self): - predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), - (0.2, 0.9, 0.7, 0.6), - (0.1, 0.2, 0.4, 0.3))) - labels = constant_op.constant(((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))) + predictions = constant_op.constant( + ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3))) + labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))) tp, tp_update_op = metrics.streaming_true_positives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) @@ -910,12 +886,9 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase): self.assertAllEqual((3, 1, 0), tp.eval()) def testWeighted(self): - predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), - (0.2, 0.9, 0.7, 0.6), - (0.1, 0.2, 0.4, 0.3))) - labels = constant_op.constant(((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))) + predictions = constant_op.constant( + ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3))) + labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))) tp, tp_update_op = metrics.streaming_true_positives_at_thresholds( predictions, labels, weights=37.0, thresholds=(0.15, 0.5, 0.85)) @@ -937,16 +910,14 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase): (0.0, 1.0, 0.0), (0, 1, 1), thresholds=( 0.15, 0.5, - 0.85,)) + 0.85, + )) _assert_metric_variables(self, ('false_negatives:0',)) def testUnweighted(self): - predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), - (0.2, 0.9, 0.7, 0.6), - (0.1, 0.2, 0.4, 0.3))) - labels = constant_op.constant(((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))) + predictions = constant_op.constant( + ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3))) + labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))) fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) @@ -957,12 +928,9 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase): self.assertAllEqual((0, 2, 3), fn.eval()) def testWeighted(self): - predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), - (0.2, 0.9, 0.7, 0.6), - (0.1, 0.2, 0.4, 0.3))) - labels = constant_op.constant(((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))) + predictions = constant_op.constant( + ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3))) + labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))) fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds( predictions, labels, @@ -988,12 +956,9 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase): _assert_metric_variables(self, ('false_positives:0',)) def testUnweighted(self): - predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), - (0.2, 0.9, 0.7, 0.6), - (0.1, 0.2, 0.4, 0.3))) - labels = constant_op.constant(((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))) + predictions = constant_op.constant( + ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3))) + labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))) fp, fp_update_op = metrics.streaming_false_positives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) @@ -1004,18 +969,14 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase): self.assertAllEqual((7, 4, 2), fp.eval()) def testWeighted(self): - predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), - (0.2, 0.9, 0.7, 0.6), - (0.1, 0.2, 0.4, 0.3))) - labels = constant_op.constant(((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))) + predictions = constant_op.constant( + ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3))) + labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))) fp, fp_update_op = metrics.streaming_false_positives_at_thresholds( predictions, labels, - weights=((1.0, 2.0, 3.0, 5.0), - (7.0, 11.0, 13.0, 17.0), - (19.0, 23.0, 29.0, 31.0)), + weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0, + 29.0, 31.0)), thresholds=(0.15, 0.5, 0.85)) with self.test_session() as sess: @@ -1037,12 +998,9 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase): _assert_metric_variables(self, ('true_negatives:0',)) def testUnweighted(self): - predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), - (0.2, 0.9, 0.7, 0.6), - (0.1, 0.2, 0.4, 0.3))) - labels = constant_op.constant(((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))) + predictions = constant_op.constant( + ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3))) + labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))) tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds( predictions, labels, thresholds=(0.15, 0.5, 0.85)) @@ -1053,12 +1011,9 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase): self.assertAllEqual((2, 5, 7), tn.eval()) def testWeighted(self): - predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), - (0.2, 0.9, 0.7, 0.6), - (0.1, 0.2, 0.4, 0.3))) - labels = constant_op.constant(((0, 1, 1, 0), - (1, 0, 0, 0), - (0, 0, 0, 0))) + predictions = constant_op.constant( + ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3))) + labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))) tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds( predictions, labels, @@ -1393,8 +1348,7 @@ class StreamingFPRTest(test.TestCase): (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) - fpr, update_op = metrics.streaming_false_positive_rate( - predictions, labels) + fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -1413,8 +1367,7 @@ class StreamingFPRTest(test.TestCase): predictions = constant_op.constant(np_inputs) labels = constant_op.constant(np_inputs) - fpr, update_op = metrics.streaming_false_positive_rate( - predictions, labels) + fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -1424,8 +1377,7 @@ class StreamingFPRTest(test.TestCase): def testSomeCorrect(self): predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4)) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) - fpr, update_op = metrics.streaming_false_positive_rate( - predictions, labels) + fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -1467,8 +1419,7 @@ class StreamingFPRTest(test.TestCase): predictions = constant_op.constant(np_inputs) labels = constant_op.constant(1 - np_inputs) - fpr, update_op = metrics.streaming_false_positive_rate( - predictions, labels) + fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -1478,8 +1429,7 @@ class StreamingFPRTest(test.TestCase): def testZeroFalsePositivesAndTrueNegativesGivesZeroFPR(self): predictions = array_ops.ones((1, 4)) labels = array_ops.ones((1, 4)) - fpr, update_op = metrics.streaming_false_positive_rate( - predictions, labels) + fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -1521,8 +1471,7 @@ class StreamingFNRTest(test.TestCase): (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) - fnr, update_op = metrics.streaming_false_negative_rate( - predictions, labels) + fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -1541,8 +1490,7 @@ class StreamingFNRTest(test.TestCase): predictions = constant_op.constant(np_inputs) labels = constant_op.constant(np_inputs) - fnr, update_op = metrics.streaming_false_negative_rate( - predictions, labels) + fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -1552,8 +1500,7 @@ class StreamingFNRTest(test.TestCase): def testSomeCorrect(self): predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4)) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) - fnr, update_op = metrics.streaming_false_negative_rate( - predictions, labels) + fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -1595,8 +1542,7 @@ class StreamingFNRTest(test.TestCase): predictions = constant_op.constant(np_inputs) labels = constant_op.constant(1 - np_inputs) - fnr, update_op = metrics.streaming_false_negative_rate( - predictions, labels) + fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -1606,8 +1552,7 @@ class StreamingFNRTest(test.TestCase): def testZeroFalseNegativesAndTruePositivesGivesZeroFNR(self): predictions = array_ops.zeros((1, 4)) labels = array_ops.zeros((1, 4)) - fnr, update_op = metrics.streaming_false_negative_rate( - predictions, labels) + fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -1857,9 +1802,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.54166603, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.54166603, auc.eval(), delta=1e-3) def testAnotherAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1871,9 +1816,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3) def testThirdAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1885,9 +1830,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3) def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -1920,9 +1865,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(1, sess.run(update_op), 6) + self.assertAlmostEqual(0.49999976, sess.run(update_op), 6) - self.assertAlmostEqual(1, auc.eval(), 6) + self.assertAlmostEqual(0.49999976, auc.eval(), 6) def testWithMultipleUpdates(self): num_samples = 1000 @@ -1944,16 +1889,17 @@ class StreamingAUCTest(test.TestCase): enqueue_ops[i].append(x_queue.enqueue(x_batches[i, :])) return x_queue.dequeue() - for weights in (None, np.ones(num_samples), np.random.exponential( - scale=1.0, size=num_samples)): + for weights in (None, np.ones(num_samples), + np.random.exponential(scale=1.0, size=num_samples)): expected_auc = _np_auc(predictions, labels, weights) with self.test_session() as sess: enqueue_ops = [[] for i in range(num_batches)] tf_predictions = _enqueue_as_batches(predictions, enqueue_ops) tf_labels = _enqueue_as_batches(labels, enqueue_ops) - tf_weights = (_enqueue_as_batches(weights, enqueue_ops) if - weights is not None else None) + tf_weights = ( + _enqueue_as_batches(weights, enqueue_ops) + if weights is not None else None) for i in range(num_batches): sess.run(enqueue_ops[i]) @@ -1985,17 +1931,18 @@ class StreamingDynamicAUCTest(test.TestCase): def testUnknownCurve(self): with self.assertRaisesRegexp( ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'): - metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)), - predictions=array_ops.ones((10, 1)), - curve='TEST_CURVE') + metrics.streaming_dynamic_auc( + labels=array_ops.ones((10, 1)), + predictions=array_ops.ones((10, 1)), + curve='TEST_CURVE') def testVars(self): metrics.streaming_dynamic_auc( labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1))) - _assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0', - 'dynamic_auc/concat_labels/size:0', - 'dynamic_auc/concat_preds/array:0', - 'dynamic_auc/concat_preds/size:0']) + _assert_metric_variables(self, [ + 'dynamic_auc/concat_labels/array:0', 'dynamic_auc/concat_labels/size:0', + 'dynamic_auc/concat_preds/array:0', 'dynamic_auc/concat_preds/size:0' + ]) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -2049,8 +1996,8 @@ class StreamingDynamicAUCTest(test.TestCase): def testNonZeroOnePredictions(self): with self.test_session() as sess: - predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5], - dtype=dtypes_lib.float32) + predictions = constant_op.constant( + [2.5, -2.5, 2.5, -2.5], dtype=dtypes_lib.float32) labels = constant_op.constant([1, 0, 1, 0]) auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) sess.run(variables.local_variables_initializer()) @@ -2122,9 +2069,10 @@ class StreamingDynamicAUCTest(test.TestCase): num_batches = 100 labels = np.array([]) predictions = np.array([]) - tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32), - collections=[ops.GraphKeys.LOCAL_VARIABLES], - dtype=dtypes_lib.int32) + tf_labels = variables.Variable( + array_ops.ones(batch_size, dtypes_lib.int32), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.int32) tf_predictions = variables.Variable( array_ops.ones(batch_size), collections=[ops.GraphKeys.LOCAL_VARIABLES], @@ -2180,7 +2128,7 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5) -class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): +class AucWithConfidenceIntervalsTest(test.TestCase): def setUp(self): np.random.seed(1) @@ -2192,7 +2140,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): Args: expected_dict: A dictionary with keys that are the names of properties of PrecisionRecallData and whose values are lists of floats. - gotten_result: A PrecisionRecallData object. + gotten_result: A AucWithConfidenceIntervalData object. """ gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()} self.assertItemsEqual( @@ -2201,6 +2149,204 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): for key, expected_values in expected_dict.items(): self.assertAllClose(expected_values, gotten_dict[key]) + def _testCase(self, predictions, labels, expected_result, weights=None): + """Performs a test given a certain scenario of labels, predictions, weights. + + Args: + predictions: The predictions tensor. Of type float32. + labels: The labels tensor. Of type bool. + expected_result: The expected result (dict) that maps to tensors. + weights: Optional weights tensor. + """ + with self.test_session() as sess: + predictions_tensor = constant_op.constant( + predictions, dtype=dtypes_lib.float32) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.int64) + weights_tensor = None + if weights: + weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32) + gotten_result, update_op = ( + metric_ops.auc_with_confidence_intervals( + labels=labels_tensor, + predictions=predictions_tensor, + weights=weights_tensor)) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self._testResultsEqual(expected_result, gotten_result) + + def testAucAllCorrect(self): + self._testCase( + predictions=[0., 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0], + labels=[0, 0, 1, 0, 0, 1, 0, 1, 1, 0], + expected_result={ + 'auc': 0.66666667, + 'lower': 0.27826795, + 'upper': 0.91208512, + }) + + def testAucUnorderedInput(self): + self._testCase( + predictions=[1.0, 0.6, 0., 0.3, 0.4, 0.2, 0.5, 0.3, 0.6, 0.8], + labels=[0, 1, 0, 1, 0, 0, 1, 0, 0, 1], + expected_result={ + 'auc': 0.66666667, + 'lower': 0.27826795, + 'upper': 0.91208512, + }) + + def testAucWithWeights(self): + self._testCase( + predictions=[0., 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0], + labels=[0, 0, 1, 0, 0, 1, 0, 1, 1, 0], + weights=[0.5, 0.6, 1.2, 1.5, 2.0, 2.0, 1.5, 1.2, 0.6, 0.5], + expected_result={ + 'auc': 0.65151515, + 'lower': 0.28918604, + 'upper': 0.89573906, + }) + + def testAucEqualOne(self): + self._testCase( + predictions=[0, 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0], + labels=[0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + expected_result={ + 'auc': 1.0, + 'lower': 1.0, + 'upper': 1.0, + }) + + def testAucEqualZero(self): + self._testCase( + predictions=[0, 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0], + labels=[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + expected_result={ + 'auc': 0.0, + 'lower': 0.0, + 'upper': 0.0, + }) + + def testNonZeroOnePredictions(self): + self._testCase( + predictions=[2.5, -2.5, .5, -.5, 1], + labels=[1, 0, 1, 0, 0], + expected_result={ + 'auc': 0.83333333, + 'lower': 0.15229267, + 'upper': 0.99286517, + }) + + def testAllLabelsOnes(self): + self._testCase( + predictions=[1., 1., 1., 1., 1.], + labels=[1, 1, 1, 1, 1], + expected_result={ + 'auc': 0., + 'lower': 0., + 'upper': 0., + }) + + def testAllLabelsZeros(self): + self._testCase( + predictions=[0., 0., 0., 0., 0.], + labels=[0, 0, 0, 0, 0], + expected_result={ + 'auc': 0., + 'lower': 0., + 'upper': 0., + }) + + def testWeightSumLessThanOneAll(self): + self._testCase( + predictions=[1., 1., 0., 1., 0., 0.], + labels=[1, 1, 1, 0, 0, 0], + weights=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + expected_result={ + 'auc': 0., + 'lower': 0., + 'upper': 0., + }) + + def testWithMultipleUpdates(self): + batch_size = 50 + num_batches = 100 + labels = np.array([]) + predictions = np.array([]) + tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.int32) + tf_predictions = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + auc, update_op = metrics.auc_with_confidence_intervals(tf_labels, + tf_predictions) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_batches): + new_labels = np.random.randint(0, 2, size=batch_size) + noise = np.random.normal(0.0, scale=0.2, size=batch_size) + new_predictions = 0.4 + 0.2 * new_labels + noise + labels = np.concatenate([labels, new_labels]) + predictions = np.concatenate([predictions, new_predictions]) + sess.run(tf_labels.assign(new_labels)) + sess.run(tf_predictions.assign(new_predictions)) + sess.run(update_op) + expected_auc = _np_auc(predictions, labels) + self.assertAllClose(expected_auc, auc.auc.eval()) + + def testExceptionOnFloatLabels(self): + with self.test_session() as sess: + predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) + labels = constant_op.constant([0.7, 0, 1, 0, 1]) + _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) + sess.run(variables.local_variables_initializer()) + self.assertRaises(TypeError, sess.run(update_op)) + + def testExceptionOnGreaterThanOneLabel(self): + with self.test_session() as sess: + predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) + labels = constant_op.constant([2, 1, 0, 1, 0]) + _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) + sess.run(variables.local_variables_initializer()) + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + '.*labels must be 0 or 1, at least one is >1.*'): + sess.run(update_op) + + def testExceptionOnNegativeLabel(self): + with self.test_session() as sess: + predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) + labels = constant_op.constant([1, 0, -1, 1, 0]) + _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) + sess.run(variables.local_variables_initializer()) + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + '.*labels must be 0 or 1, at least one is <0.*'): + sess.run(update_op) + + +class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def _testResultsEqual(self, expected_dict, gotten_result): + """Tests that 2 results (dicts) represent the same data. + + Args: + expected_dict: A dictionary with keys that are the names of properties + of PrecisionRecallData and whose values are lists of floats. + gotten_result: A PrecisionRecallData object. + """ + gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()} + self.assertItemsEqual(list(expected_dict.keys()), list(gotten_dict.keys())) + + for key, expected_values in expected_dict.items(): + self.assertAllClose(expected_values, gotten_dict[key]) + def _testCase(self, predictions, labels, expected_result, weights=None): """Performs a test given a certain scenario of labels, predictions, weights. @@ -2261,60 +2407,65 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): sess.run(update_op) # Then verify idempotency. - initial_result = {k: value.eval().tolist() for k, value in - result._asdict().items()} + initial_result = { + k: value.eval().tolist() + for k, value in result._asdict().items() + } for _ in range(3): self._testResultsEqual(initial_result, result) def testAllTruePositives(self): - self._testCase([[1]], [[True]], { - 'tp': [1, 1, 1], - 'fp': [0, 0, 0], - 'tn': [0, 0, 0], - 'fn': [0, 0, 0], - 'precision': [1.0, 1.0, 1.0], - 'recall': [1.0, 1.0, 1.0], - 'thresholds': [0.0, 0.5, 1.0], - }) + self._testCase( + [[1]], [[True]], { + 'tp': [1, 1, 1], + 'fp': [0, 0, 0], + 'tn': [0, 0, 0], + 'fn': [0, 0, 0], + 'precision': [1.0, 1.0, 1.0], + 'recall': [1.0, 1.0, 1.0], + 'thresholds': [0.0, 0.5, 1.0], + }) def testAllTrueNegatives(self): - self._testCase([[0]], [[False]], { - 'tp': [0, 0, 0], - 'fp': [1, 0, 0], - 'tn': [0, 1, 1], - 'fn': [0, 0, 0], - 'precision': [0.0, 0.0, 0.0], - 'recall': [0.0, 0.0, 0.0], - 'thresholds': [0.0, 0.5, 1.0], - }) + self._testCase( + [[0]], [[False]], { + 'tp': [0, 0, 0], + 'fp': [1, 0, 0], + 'tn': [0, 1, 1], + 'fn': [0, 0, 0], + 'precision': [0.0, 0.0, 0.0], + 'recall': [0.0, 0.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) def testAllFalsePositives(self): - self._testCase([[1]], [[False]], { - 'tp': [0, 0, 0], - 'fp': [1, 1, 1], - 'tn': [0, 0, 0], - 'fn': [0, 0, 0], - 'precision': [0.0, 0.0, 0.0], - 'recall': [0.0, 0.0, 0.0], - 'thresholds': [0.0, 0.5, 1.0], - }) + self._testCase( + [[1]], [[False]], { + 'tp': [0, 0, 0], + 'fp': [1, 1, 1], + 'tn': [0, 0, 0], + 'fn': [0, 0, 0], + 'precision': [0.0, 0.0, 0.0], + 'recall': [0.0, 0.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) def testAllFalseNegatives(self): - self._testCase([[0]], [[True]], { - 'tp': [1, 0, 0], - 'fp': [0, 0, 0], - 'tn': [0, 0, 0], - 'fn': [0, 1, 1], - 'precision': [1.0, 0.0, 0.0], - 'recall': [1.0, 0.0, 0.0], - 'thresholds': [0.0, 0.5, 1.0], - }) + self._testCase( + [[0]], [[True]], { + 'tp': [1, 0, 0], + 'fp': [0, 0, 0], + 'tn': [0, 0, 0], + 'fn': [0, 1, 1], + 'precision': [1.0, 0.0, 0.0], + 'recall': [1.0, 0.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) def testManyValues(self): self._testCase( [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], - [[True, False, False, True, True, True]], - { + [[True, False, False, True, True, True]], { 'tp': [4, 3, 0], 'fp': [2, 0, 0], 'tn': [0, 2, 2], @@ -2327,8 +2478,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): def testManyValuesWithWeights(self): self._testCase( [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], - [[True, False, False, True, True, True]], - { + [[True, False, False, True, True, True]], { 'tp': [1.5, 1.5, 0.0], 'fp': [2.5, 0.0, 0.0], 'tn': [0.0, 2.5, 2.5], @@ -2644,11 +2794,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): labels = random_ops.random_uniform( (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -2672,11 +2821,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2690,11 +2838,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2709,11 +2856,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2779,11 +2925,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) thresholds = [-1.0, 2.0] # lower/higher than any values - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds) prec_low = prec[0] prec_high = prec[1] @@ -2803,11 +2948,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2872,12 +3016,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): tf_predictions = predictions_queue.dequeue() tf_labels = labels_queue.dequeue() - prec, prec_op = metrics.streaming_precision_at_thresholds(tf_predictions, - tf_labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(tf_predictions, - tf_labels, - thresholds) + prec, prec_op = metrics.streaming_precision_at_thresholds( + tf_predictions, tf_labels, thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds( + tf_predictions, tf_labels, thresholds) sess.run(variables.local_variables_initializer()) for _ in range(int(num_samples / batch_size)): @@ -2921,8 +3063,7 @@ class StreamingFPRThresholdsTest(test.TestCase): labels=array_ops.ones((10, 1)), thresholds=[0, 0.5, 1.0], updates_collections=[my_collection_name]) - self.assertListEqual( - ops.get_collection(my_collection_name), [update_op]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) def testValueTensorIsIdempotent(self): predictions = random_ops.random_uniform( @@ -3271,8 +3412,7 @@ class StreamingFNRThresholdsTest(test.TestCase): labels=array_ops.ones((10, 1)), thresholds=[0, 0.5, 1.0], updates_collections=[my_collection_name]) - self.assertListEqual( - ops.get_collection(my_collection_name), [update_op]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) def testValueTensorIsIdempotent(self): predictions = random_ops.random_uniform( @@ -3492,8 +3632,7 @@ class StreamingRecallAtKTest(test.TestCase): def testVars(self): metrics.streaming_recall_at_k( predictions=array_ops.ones((self._batch_size, self._num_classes)), - labels=array_ops.ones( - (self._batch_size,), dtype=dtypes_lib.int32), + labels=array_ops.ones((self._batch_size,), dtype=dtypes_lib.int32), k=1) _assert_metric_variables(self, ('recall_at_1/count:0', 'recall_at_1/total:0')) @@ -3502,8 +3641,7 @@ class StreamingRecallAtKTest(test.TestCase): my_collection_name = '__metrics__' mean, _ = metrics.streaming_recall_at_k( predictions=array_ops.ones((self._batch_size, self._num_classes)), - labels=array_ops.ones( - (self._batch_size,), dtype=dtypes_lib.int32), + labels=array_ops.ones((self._batch_size,), dtype=dtypes_lib.int32), k=1, metrics_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [mean]) @@ -3512,8 +3650,7 @@ class StreamingRecallAtKTest(test.TestCase): my_collection_name = '__updates__' _, update_op = metrics.streaming_recall_at_k( predictions=array_ops.ones((self._batch_size, self._num_classes)), - labels=array_ops.ones( - (self._batch_size,), dtype=dtypes_lib.int32), + labels=array_ops.ones((self._batch_size,), dtype=dtypes_lib.int32), k=1, updates_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) @@ -3715,9 +3852,17 @@ class StreamingSparsePrecisionTest(test.TestCase): # top_k_predictions has rank < 2. top_k_predictions = [9, 4, 6, 2, 0] sp_labels = sparse_tensor.SparseTensorValue( - indices=np.array([[0,], [1,], [2,]], np.int64), + indices=np.array([[ + 0, + ], [ + 1, + ], [ + 2, + ]], np.int64), values=np.array([2, 7, 8], np.int64), - dense_shape=np.array([10,], np.int64)) + dense_shape=np.array([ + 10, + ], np.int64)) with self.assertRaises(ValueError): precision, _ = metrics.streaming_sparse_precision_at_top_k( @@ -3774,8 +3919,9 @@ class StreamingSparsePrecisionTest(test.TestCase): # average of the 2 examples. labels = np.array([labels_ex1, labels_ex2], dtype=np.int64) predictions = (predictions_ex1, predictions_ex2) - streaming_precision = [(ex1 + ex2) / 2 - for ex1, ex2 in zip(precision_ex1, precision_ex2)] + streaming_precision = [ + (ex1 + ex2) / 2 for ex1, ex2 in zip(precision_ex1, precision_ex2) + ] streaming_average_precision = [ (ex1 + ex2) / 2 for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2) @@ -3835,29 +3981,29 @@ class StreamingSparsePrecisionTest(test.TestCase): (predictions_top_k_ex1[:k],), labels, expected=avg_precision_ex1[i]) def test_average_precision_at_top_k_static_shape_check(self): - predictions_top_k = array_ops.placeholder(shape=(2, None), - dtype=dtypes_lib.int64) + predictions_top_k = array_ops.placeholder( + shape=(2, None), dtype=dtypes_lib.int64) labels = np.array(((1,), (2,)), dtype=np.int64) # Fails due to non-static predictions_idx shape. with self.assertRaises(ValueError): - metric_ops.streaming_sparse_average_precision_at_top_k(predictions_top_k, - labels) + metric_ops.streaming_sparse_average_precision_at_top_k( + predictions_top_k, labels) predictions_top_k = (2, 1) # Fails since rank of predictions_idx is less than one. with self.assertRaises(ValueError): - metric_ops.streaming_sparse_average_precision_at_top_k(predictions_top_k, - labels) + metric_ops.streaming_sparse_average_precision_at_top_k( + predictions_top_k, labels) predictions_top_k = ((2,), (1,)) # Valid static shape. - metric_ops.streaming_sparse_average_precision_at_top_k(predictions_top_k, - labels) + metric_ops.streaming_sparse_average_precision_at_top_k( + predictions_top_k, labels) def test_one_label_at_k1_nan(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] top_k_predictions = [[3], [3]] - sparse_labels = _binary_2d_label_to_sparse_value( - [[0, 0, 0, 1], [0, 0, 1, 0]]) + sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], + [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) for labels in (sparse_labels, dense_labels): @@ -3871,8 +4017,8 @@ class StreamingSparsePrecisionTest(test.TestCase): def test_one_label_at_k1(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] top_k_predictions = [[3], [3]] - sparse_labels = _binary_2d_label_to_sparse_value( - [[0, 0, 0, 1], [0, 0, 1, 0]]) + sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], + [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) for labels in (sparse_labels, dense_labels): @@ -3971,8 +4117,8 @@ class StreamingSparsePrecisionTest(test.TestCase): [5, 7, 2, 9, 6], ] sp_labels = sparse_tensor.SparseTensorValue( - indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], - [1, 3]], + indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], [1, + 3]], # values -1 and 10 are outside the [0, n_classes) range and are ignored. values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64), dense_shape=[2, 4]) @@ -4324,8 +4470,8 @@ class StreamingSparseRecallTest(test.TestCase): def test_one_label_at_k1_nan(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] top_k_predictions = [[3], [3]] - sparse_labels = _binary_2d_label_to_sparse_value( - [[0, 0, 0, 1], [0, 0, 1, 0]]) + sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], + [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) # Classes 0,1 have 0 labels, 0 predictions, classes -1 and 4 are out of @@ -4340,8 +4486,8 @@ class StreamingSparseRecallTest(test.TestCase): def test_one_label_at_k1_no_predictions(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] top_k_predictions = [[3], [3]] - sparse_labels = _binary_2d_label_to_sparse_value( - [[0, 0, 0, 1], [0, 0, 1, 0]]) + sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], + [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) for labels in (sparse_labels, dense_labels): @@ -4354,8 +4500,8 @@ class StreamingSparseRecallTest(test.TestCase): def test_one_label_at_k1(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] top_k_predictions = [[3], [3]] - sparse_labels = _binary_2d_label_to_sparse_value( - [[0, 0, 0, 1], [0, 0, 1, 0]]) + sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], + [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) for labels in (sparse_labels, dense_labels): @@ -4374,8 +4520,8 @@ class StreamingSparseRecallTest(test.TestCase): def test_one_label_at_k1_weighted(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] top_k_predictions = [[3], [3]] - sparse_labels = _binary_2d_label_to_sparse_value( - [[0, 0, 0, 1], [0, 0, 1, 0]]) + sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], + [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) for labels in (sparse_labels, dense_labels): @@ -4647,8 +4793,8 @@ class StreamingSparseRecallTest(test.TestCase): [5, 7, 2, 9, 6], ] sp_labels = sparse_tensor.SparseTensorValue( - indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], - [1, 3]], + indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], [1, + 3]], # values -1 and 10 are outside the [0, n_classes) range. values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64), dense_shape=[2, 4]) @@ -4661,10 +4807,7 @@ class StreamingSparseRecallTest(test.TestCase): expected=2.0 / 2, class_id=2) self._test_sparse_recall_at_top_k( - sp_labels, - top_k_predictions, - expected=2.0 / 2, - class_id=2) + sp_labels, top_k_predictions, expected=2.0 / 2, class_id=2) # Class 5: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( @@ -4674,10 +4817,7 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1, class_id=5) self._test_sparse_recall_at_top_k( - sp_labels, - top_k_predictions, - expected=1.0 / 1, - class_id=5) + sp_labels, top_k_predictions, expected=1.0 / 1, class_id=5) # Class 7: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( @@ -4687,10 +4827,7 @@ class StreamingSparseRecallTest(test.TestCase): expected=0.0 / 1, class_id=7) self._test_sparse_recall_at_top_k( - sp_labels, - top_k_predictions, - expected=0.0 / 1, - class_id=7) + sp_labels, top_k_predictions, expected=0.0 / 1, class_id=7) # All classes: 8 labels, 3 correct. self._test_streaming_sparse_recall_at_k( @@ -4740,10 +4877,8 @@ class StreamingSparseRecallTest(test.TestCase): [9, 4, 6, 2, 0], ]] sparse_labels = _binary_3d_label_to_sparse_value( - [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], - [[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]]) + [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], + [[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]]) dense_labels = np.array( [[[2, 7, 8], [1, 2, 5]], [ [1, 2, 5], @@ -4771,10 +4906,8 @@ class StreamingSparseRecallTest(test.TestCase): [9, 4, 6, 2, 0], ]] labels = _binary_3d_label_to_sparse_value( - [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], - [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]]) + [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], + [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]]) # Class 2: 4 labels, all correct. self._test_streaming_sparse_recall_at_k( @@ -4813,10 +4946,8 @@ class StreamingSparseRecallTest(test.TestCase): [9, 4, 6, 2, 0], ]] labels = _binary_3d_label_to_sparse_value( - [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], - [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]]) + [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], + [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]]) for class_id in xrange(10): self._test_streaming_sparse_recall_at_k( @@ -4867,10 +4998,8 @@ class StreamingSparseRecallTest(test.TestCase): [9, 4, 6, 2, 0], ]] labels = _binary_3d_label_to_sparse_value( - [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], - [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]]) + [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], + [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]]) # Class 2: 2 labels, both correct. self._test_streaming_sparse_recall_at_k( @@ -4963,10 +5092,8 @@ class StreamingSparseRecallTest(test.TestCase): weights=[[0, 1], [0, 1]]) def test_sparse_tensor_value(self): - predictions = [[0.1, 0.3, 0.2, 0.4], - [0.1, 0.2, 0.3, 0.4]] - labels = [[0, 0, 1, 0], - [0, 0, 0, 1]] + predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + labels = [[0, 0, 1, 0], [0, 0, 0, 1]] expected_recall = 0.5 with self.test_session(): _, recall = metrics.streaming_sparse_recall_at_k( @@ -5009,8 +5136,8 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase): def testValueTensorIsIdempotent(self): predictions = random_ops.random_normal((10, 3), seed=1) labels = random_ops.random_normal((10, 3), seed=2) - error, update_op = metrics.streaming_mean_absolute_error(predictions, - labels) + error, update_op = metrics.streaming_mean_absolute_error( + predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -5031,8 +5158,8 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase): [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4)) - error, update_op = metrics.streaming_mean_absolute_error(predictions, - labels, weights) + error, update_op = metrics.streaming_mean_absolute_error( + predictions, labels, weights) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -5075,8 +5202,8 @@ class StreamingMeanRelativeErrorTest(test.TestCase): predictions = random_ops.random_normal((10, 3), seed=1) labels = random_ops.random_normal((10, 3), seed=2) normalizer = random_ops.random_normal((10, 3), seed=3) - error, update_op = metrics.streaming_mean_relative_error(predictions, - labels, normalizer) + error, update_op = metrics.streaming_mean_relative_error( + predictions, labels, normalizer) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -5200,8 +5327,8 @@ class StreamingMeanSquaredErrorTest(test.TestCase): [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4)) - error, update_op = metrics.streaming_mean_squared_error(predictions, labels, - weights) + error, update_op = metrics.streaming_mean_squared_error( + predictions, labels, weights) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -5224,8 +5351,8 @@ class StreamingMeanSquaredErrorTest(test.TestCase): _enqueue_vector(sess, labels_queue, [2, 4, 6]) labels = labels_queue.dequeue() - error, update_op = metrics.streaming_mean_squared_error(predictions, - labels) + error, update_op = metrics.streaming_mean_squared_error( + predictions, labels) sess.run(variables.local_variables_initializer()) sess.run(update_op) @@ -5292,10 +5419,10 @@ class StreamingMeanSquaredErrorTest(test.TestCase): _enqueue_vector(sess, labels_queue, [2, 4, 6]) labels = labels_queue.dequeue() - mae, ma_update_op = metrics.streaming_mean_absolute_error(predictions, - labels) - mse, ms_update_op = metrics.streaming_mean_squared_error(predictions, - labels) + mae, ma_update_op = metrics.streaming_mean_absolute_error( + predictions, labels) + mse, ms_update_op = metrics.streaming_mean_squared_error( + predictions, labels) sess.run(variables.local_variables_initializer()) sess.run([ma_update_op, ms_update_op]) @@ -5336,8 +5463,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): def testValueTensorIsIdempotent(self): predictions = random_ops.random_normal((10, 3), seed=1) labels = random_ops.random_normal((10, 3), seed=2) - error, update_op = metrics.streaming_root_mean_squared_error(predictions, - labels) + error, update_op = metrics.streaming_root_mean_squared_error( + predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -5357,8 +5484,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): 0.0, shape=(1, 3), dtype=dtypes_lib.float32) labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32) - rmse, update_op = metrics.streaming_root_mean_squared_error(predictions, - labels) + rmse, update_op = metrics.streaming_root_mean_squared_error( + predictions, labels) sess.run(variables.local_variables_initializer()) self.assertEqual(0, sess.run(update_op)) @@ -5372,8 +5499,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): labels = constant_op.constant( [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32) - rmse, update_op = metrics.streaming_root_mean_squared_error(predictions, - labels) + rmse, update_op = metrics.streaming_root_mean_squared_error( + predictions, labels) sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(math.sqrt(6), update_op.eval(), 5) @@ -5387,9 +5514,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4)) - rmse, update_op = metrics.streaming_root_mean_squared_error(predictions, - labels, - weights) + rmse, update_op = metrics.streaming_root_mean_squared_error( + predictions, labels, weights) sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(math.sqrt(13), sess.run(update_op)) @@ -5404,8 +5530,8 @@ class StreamingCovarianceTest(test.TestCase): def testVars(self): metrics.streaming_covariance( - predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones( - [10, 10]), + predictions=math_ops.to_float(math_ops.range(10)) + + array_ops.ones([10, 10]), labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10])) _assert_metric_variables(self, ( 'covariance/comoment:0', @@ -5417,8 +5543,8 @@ class StreamingCovarianceTest(test.TestCase): def testMetricsCollection(self): my_collection_name = '__metrics__' cov, _ = metrics.streaming_covariance( - predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones( - [10, 10]), + predictions=math_ops.to_float(math_ops.range(10)) + + array_ops.ones([10, 10]), labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), metrics_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [cov]) @@ -5426,8 +5552,8 @@ class StreamingCovarianceTest(test.TestCase): def testUpdatesCollection(self): my_collection_name = '__updates__' _, update_op = metrics.streaming_covariance( - predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones( - [10, 10]), + predictions=math_ops.to_float(math_ops.range(10)) + + array_ops.ones([10, 10]), labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), updates_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) @@ -5487,9 +5613,8 @@ class StreamingCovarianceTest(test.TestCase): cov, update_op = metrics.streaming_covariance( predictions, labels, weights=weights) - expected_cov = np.cov([2, 4, 6, 8], - [1, 3, 2, 7], - fweights=[0, 1, 3, 1])[0, 1] + expected_cov = np.cov( + [2, 4, 6, 8], [1, 3, 2, 7], fweights=[0, 1, 3, 1])[0, 1] sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(expected_cov, sess.run(update_op)) self.assertAlmostEqual(expected_cov, cov.eval()) @@ -5514,17 +5639,18 @@ class StreamingCovarianceTest(test.TestCase): predictions_t: predictions[stride * i:stride * (i + 1)], labels_t: labels[stride * i:stride * (i + 1)] } - self.assertEqual(np.isnan(prev_expected_cov), - np.isnan(sess.run(cov, feed_dict=feed_dict))) + self.assertEqual( + np.isnan(prev_expected_cov), + np.isnan(sess.run(cov, feed_dict=feed_dict))) if not np.isnan(prev_expected_cov): - self.assertAlmostEqual( - prev_expected_cov, sess.run(cov, feed_dict=feed_dict), 5) + self.assertAlmostEqual(prev_expected_cov, + sess.run(cov, feed_dict=feed_dict), 5) expected_cov = np.cov(predictions[:stride * (i + 1)], labels[:stride * (i + 1)])[0, 1] - self.assertAlmostEqual( - expected_cov, sess.run(update_op, feed_dict=feed_dict), 5) - self.assertAlmostEqual( - expected_cov, sess.run(cov, feed_dict=feed_dict), 5) + self.assertAlmostEqual(expected_cov, + sess.run(update_op, feed_dict=feed_dict), 5) + self.assertAlmostEqual(expected_cov, sess.run(cov, feed_dict=feed_dict), + 5) prev_expected_cov = expected_cov def testMultiUpdateWithErrorAndWeights(self): @@ -5552,18 +5678,20 @@ class StreamingCovarianceTest(test.TestCase): labels_t: labels[stride * i:stride * (i + 1)], weights_t: weights[stride * i:stride * (i + 1)] } - self.assertEqual(np.isnan(prev_expected_cov), - np.isnan(sess.run(cov, feed_dict=feed_dict))) + self.assertEqual( + np.isnan(prev_expected_cov), + np.isnan(sess.run(cov, feed_dict=feed_dict))) if not np.isnan(prev_expected_cov): - self.assertAlmostEqual( - prev_expected_cov, sess.run(cov, feed_dict=feed_dict), 5) - expected_cov = np.cov(predictions[:stride * (i + 1)], - labels[:stride * (i + 1)], - fweights=weights[:stride * (i + 1)])[0, 1] - self.assertAlmostEqual( - expected_cov, sess.run(update_op, feed_dict=feed_dict), 5) - self.assertAlmostEqual( - expected_cov, sess.run(cov, feed_dict=feed_dict), 5) + self.assertAlmostEqual(prev_expected_cov, + sess.run(cov, feed_dict=feed_dict), 5) + expected_cov = np.cov( + predictions[:stride * (i + 1)], + labels[:stride * (i + 1)], + fweights=weights[:stride * (i + 1)])[0, 1] + self.assertAlmostEqual(expected_cov, + sess.run(update_op, feed_dict=feed_dict), 5) + self.assertAlmostEqual(expected_cov, sess.run(cov, feed_dict=feed_dict), + 5) prev_expected_cov = expected_cov @@ -5574,8 +5702,8 @@ class StreamingPearsonRTest(test.TestCase): def testVars(self): metrics.streaming_pearson_correlation( - predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones( - [10, 10]), + predictions=math_ops.to_float(math_ops.range(10)) + + array_ops.ones([10, 10]), labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10])) _assert_metric_variables(self, ( 'pearson_r/covariance/comoment:0', @@ -5595,8 +5723,8 @@ class StreamingPearsonRTest(test.TestCase): def testMetricsCollection(self): my_collection_name = '__metrics__' pearson_r, _ = metrics.streaming_pearson_correlation( - predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones( - [10, 10]), + predictions=math_ops.to_float(math_ops.range(10)) + + array_ops.ones([10, 10]), labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), metrics_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [pearson_r]) @@ -5604,8 +5732,8 @@ class StreamingPearsonRTest(test.TestCase): def testUpdatesCollection(self): my_collection_name = '__updates__' _, update_op = metrics.streaming_pearson_correlation( - predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones( - [10, 10]), + predictions=math_ops.to_float(math_ops.range(10)) + + array_ops.ones([10, 10]), labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), updates_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) @@ -5613,8 +5741,8 @@ class StreamingPearsonRTest(test.TestCase): def testValueTensorIsIdempotent(self): labels = random_ops.random_normal((10, 3), seed=2) predictions = labels * 0.5 + random_ops.random_normal((10, 3), seed=1) * 0.5 - pearson_r, update_op = metrics.streaming_pearson_correlation(predictions, - labels) + pearson_r, update_op = metrics.streaming_pearson_correlation( + predictions, labels) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) @@ -5633,8 +5761,8 @@ class StreamingPearsonRTest(test.TestCase): predictions = math_ops.to_float(math_ops.range(10)) labels = math_ops.to_float(math_ops.range(10)) - pearson_r, update_op = metrics.streaming_pearson_correlation(predictions, - labels) + pearson_r, update_op = metrics.streaming_pearson_correlation( + predictions, labels) expected_r = np.corrcoef(np.arange(10), np.arange(10))[0, 1] sess.run(variables.local_variables_initializer()) @@ -5648,8 +5776,8 @@ class StreamingPearsonRTest(test.TestCase): labels = constant_op.constant( [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32) - pearson_r, update_op = metrics.streaming_pearson_correlation(predictions, - labels) + pearson_r, update_op = metrics.streaming_pearson_correlation( + predictions, labels) expected_r = np.corrcoef([2, 4, 6], [1, 3, 2])[0, 1] sess.run(variables.local_variables_initializer()) @@ -5698,17 +5826,18 @@ class StreamingPearsonRTest(test.TestCase): predictions_t: predictions[stride * i:stride * (i + 1)], labels_t: labels[stride * i:stride * (i + 1)] } - self.assertEqual(np.isnan(prev_expected_r), - np.isnan(sess.run(pearson_r, feed_dict=feed_dict))) + self.assertEqual( + np.isnan(prev_expected_r), + np.isnan(sess.run(pearson_r, feed_dict=feed_dict))) if not np.isnan(prev_expected_r): - self.assertAlmostEqual( - prev_expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5) + self.assertAlmostEqual(prev_expected_r, + sess.run(pearson_r, feed_dict=feed_dict), 5) expected_r = np.corrcoef(predictions[:stride * (i + 1)], labels[:stride * (i + 1)])[0, 1] - self.assertAlmostEqual( - expected_r, sess.run(update_op, feed_dict=feed_dict), 5) - self.assertAlmostEqual( - expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5) + self.assertAlmostEqual(expected_r, + sess.run(update_op, feed_dict=feed_dict), 5) + self.assertAlmostEqual(expected_r, + sess.run(pearson_r, feed_dict=feed_dict), 5) prev_expected_r = expected_r def testMultiUpdateWithErrorAndWeights(self): @@ -5736,19 +5865,21 @@ class StreamingPearsonRTest(test.TestCase): labels_t: labels[stride * i:stride * (i + 1)], weights_t: weights[stride * i:stride * (i + 1)] } - self.assertEqual(np.isnan(prev_expected_r), - np.isnan(sess.run(pearson_r, feed_dict=feed_dict))) + self.assertEqual( + np.isnan(prev_expected_r), + np.isnan(sess.run(pearson_r, feed_dict=feed_dict))) if not np.isnan(prev_expected_r): - self.assertAlmostEqual( - prev_expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5) - cmat = np.cov(predictions[:stride * (i + 1)], - labels[:stride * (i + 1)], - fweights=weights[:stride * (i + 1)]) + self.assertAlmostEqual(prev_expected_r, + sess.run(pearson_r, feed_dict=feed_dict), 5) + cmat = np.cov( + predictions[:stride * (i + 1)], + labels[:stride * (i + 1)], + fweights=weights[:stride * (i + 1)]) expected_r = cmat[0, 1] / np.sqrt(cmat[0, 0] * cmat[1, 1]) - self.assertAlmostEqual( - expected_r, sess.run(update_op, feed_dict=feed_dict), 5) - self.assertAlmostEqual( - expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5) + self.assertAlmostEqual(expected_r, + sess.run(update_op, feed_dict=feed_dict), 5) + self.assertAlmostEqual(expected_r, + sess.run(pearson_r, feed_dict=feed_dict), 5) prev_expected_r = expected_r def testMultiUpdateWithErrorAndSingletonBatches(self): @@ -5758,7 +5889,7 @@ class StreamingPearsonRTest(test.TestCase): predictions = np.random.randn(n) labels = 0.5 * predictions + np.random.randn(n) stride = 10 - weights = (np.arange(n).reshape(n//stride, stride) % stride == 0) + weights = (np.arange(n).reshape(n // stride, stride) % stride == 0) for row in weights: np.random.shuffle(row) # Now, weights is one-hot by row - one item per batch has non-zero weight. @@ -5778,19 +5909,20 @@ class StreamingPearsonRTest(test.TestCase): labels_t: labels[stride * i:stride * (i + 1)], weights_t: weights[stride * i:stride * (i + 1)] } - cmat = np.cov(predictions[:stride * (i + 1)], - labels[:stride * (i + 1)], - fweights=weights[:stride * (i + 1)]) + cmat = np.cov( + predictions[:stride * (i + 1)], + labels[:stride * (i + 1)], + fweights=weights[:stride * (i + 1)]) expected_r = cmat[0, 1] / np.sqrt(cmat[0, 0] * cmat[1, 1]) actual_r = sess.run(update_op, feed_dict=feed_dict) self.assertEqual(np.isnan(expected_r), np.isnan(actual_r)) - self.assertEqual(np.isnan(expected_r), - np.isnan(sess.run(pearson_r, feed_dict=feed_dict))) + self.assertEqual( + np.isnan(expected_r), + np.isnan(sess.run(pearson_r, feed_dict=feed_dict))) if not np.isnan(expected_r): - self.assertAlmostEqual( - expected_r, actual_r, 5) - self.assertAlmostEqual( - expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5) + self.assertAlmostEqual(expected_r, actual_r, 5) + self.assertAlmostEqual(expected_r, + sess.run(pearson_r, feed_dict=feed_dict), 5) class StreamingMeanCosineDistanceTest(test.TestCase): @@ -6191,20 +6323,14 @@ class StreamingMeanIOUTest(test.TestCase): self.assertAlmostEqual(desired_output, miou.eval()) def testUpdateOpEvalIsAccumulatedConfusionMatrix(self): - predictions = array_ops.concat( - [ - constant_op.constant( - 0, shape=[5]), constant_op.constant( - 1, shape=[5]) - ], - 0) - labels = array_ops.concat( - [ - constant_op.constant( - 0, shape=[3]), constant_op.constant( - 1, shape=[7]) - ], - 0) + predictions = array_ops.concat([ + constant_op.constant(0, shape=[5]), + constant_op.constant(1, shape=[5]) + ], 0) + labels = array_ops.concat([ + constant_op.constant(0, shape=[3]), + constant_op.constant(1, shape=[7]) + ], 0) num_classes = 2 with self.test_session() as sess: miou, update_op = metrics.streaming_mean_iou(predictions, labels, @@ -6238,29 +6364,20 @@ class StreamingMeanIOUTest(test.TestCase): self.assertEqual(0., miou.eval()) def testResultsWithSomeMissing(self): - predictions = array_ops.concat( - [ - constant_op.constant( - 0, shape=[5]), constant_op.constant( - 1, shape=[5]) - ], - 0) - labels = array_ops.concat( - [ - constant_op.constant( - 0, shape=[3]), constant_op.constant( - 1, shape=[7]) - ], - 0) + predictions = array_ops.concat([ + constant_op.constant(0, shape=[5]), + constant_op.constant(1, shape=[5]) + ], 0) + labels = array_ops.concat([ + constant_op.constant(0, shape=[3]), + constant_op.constant(1, shape=[7]) + ], 0) num_classes = 2 - weights = array_ops.concat( - [ - constant_op.constant( - 0, shape=[1]), constant_op.constant( - 1, shape=[8]), constant_op.constant( - 0, shape=[1]) - ], - 0) + weights = array_ops.concat([ + constant_op.constant(0, shape=[1]), + constant_op.constant(1, shape=[8]), + constant_op.constant(0, shape=[1]) + ], 0) with self.test_session() as sess: miou, update_op = metrics.streaming_mean_iou( predictions, labels, num_classes, weights=weights) @@ -6270,56 +6387,45 @@ class StreamingMeanIOUTest(test.TestCase): self.assertAlmostEqual(desired_miou, miou.eval()) def testMissingClassInLabels(self): - labels = constant_op.constant([ - [[0, 0, 1, 1, 0, 0], - [1, 0, 0, 0, 0, 1]], - [[1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0]]]) - predictions = constant_op.constant([ - [[0, 0, 2, 1, 1, 0], - [0, 1, 2, 2, 0, 1]], - [[0, 0, 2, 1, 1, 1], - [1, 1, 2, 0, 0, 0]]]) + labels = constant_op.constant([[[0, 0, 1, 1, 0, 0], [1, 0, 0, 0, 0, 1]], + [[1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0]]]) + predictions = constant_op.constant( + [[[0, 0, 2, 1, 1, 0], [0, 1, 2, 2, 0, 1]], [[0, 0, 2, 1, 1, 1], + [1, 1, 2, 0, 0, 0]]]) num_classes = 3 with self.test_session() as sess: - miou, update_op = metrics.streaming_mean_iou( - predictions, labels, num_classes) + miou, update_op = metrics.streaming_mean_iou(predictions, labels, + num_classes) sess.run(variables.local_variables_initializer()) self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval()) - self.assertAlmostEqual( - 1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)), - miou.eval()) + self.assertAlmostEqual(1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / + (0 + 5 + 0)), miou.eval()) def testMissingClassOverallSmall(self): labels = constant_op.constant([0]) predictions = constant_op.constant([0]) num_classes = 2 with self.test_session() as sess: - miou, update_op = metrics.streaming_mean_iou( - predictions, labels, num_classes) + miou, update_op = metrics.streaming_mean_iou(predictions, labels, + num_classes) sess.run(variables.local_variables_initializer()) self.assertAllEqual([[1, 0], [0, 0]], update_op.eval()) self.assertAlmostEqual(1, miou.eval()) def testMissingClassOverallLarge(self): - labels = constant_op.constant([ - [[0, 0, 1, 1, 0, 0], - [1, 0, 0, 0, 0, 1]], - [[1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0]]]) - predictions = constant_op.constant([ - [[0, 0, 1, 1, 0, 0], - [1, 1, 0, 0, 1, 1]], - [[0, 0, 0, 1, 1, 1], - [1, 1, 1, 0, 0, 0]]]) + labels = constant_op.constant([[[0, 0, 1, 1, 0, 0], [1, 0, 0, 0, 0, 1]], + [[1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0]]]) + predictions = constant_op.constant( + [[[0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1]], [[0, 0, 0, 1, 1, 1], + [1, 1, 1, 0, 0, 0]]]) num_classes = 3 with self.test_session() as sess: - miou, update_op = metrics.streaming_mean_iou( - predictions, labels, num_classes) + miou, update_op = metrics.streaming_mean_iou(predictions, labels, + num_classes) sess.run(variables.local_variables_initializer()) self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval()) - self.assertAlmostEqual( - 1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), miou.eval()) + self.assertAlmostEqual(1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), + miou.eval()) class StreamingConcatTest(test.TestCase): @@ -6683,7 +6789,8 @@ class CohenKappaTest(test.TestCase): _assert_metric_variables(self, ( 'cohen_kappa/po:0', 'cohen_kappa/pe_row:0', - 'cohen_kappa/pe_col:0',)) + 'cohen_kappa/pe_col:0', + )) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -6705,9 +6812,9 @@ class CohenKappaTest(test.TestCase): def testValueTensorIsIdempotent(self): predictions = random_ops.random_uniform( - (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=1) + (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2) + (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2) kappa, update_op = metrics.cohen_kappa(labels, predictions, 3) with self.test_session() as sess: @@ -6723,10 +6830,7 @@ class CohenKappaTest(test.TestCase): self.assertAlmostEqual(initial_kappa, kappa.eval(), 5) def testBasic(self): - confusion_matrix = np.array([ - [9, 3, 1], - [4, 8, 2], - [2, 1, 6]]) + confusion_matrix = np.array([[9, 3, 1], [4, 8, 2], [2, 1, 6]]) # overall total = 36 # po = [9, 8, 6], sum(po) = 23 # pe_row = [15, 12, 9], pe_col = [13, 14, 9], so pe = [5.42, 4.67, 2.25] @@ -6738,8 +6842,10 @@ class CohenKappaTest(test.TestCase): labels, predictions = self._confusion_matrix_to_samples(confusion_matrix) dtypes = [dtypes_lib.int16, dtypes_lib.int32, dtypes_lib.int64] - shapes = [(len(labels,)), # 1-dim - (len(labels), 1)] # 2-dim + shapes = [ + (len(labels,)), # 1-dim + (len(labels), 1) + ] # 2-dim weights = [None, np.ones_like(labels)] for dtype in dtypes: @@ -6782,7 +6888,8 @@ class CohenKappaTest(test.TestCase): # [[0, 25, 0], # [0, 0, 25], # [25, 0, 0]] - # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions) + # Calculated by v0.19: sklearn.metrics.cohen_kappa_score( + # labels, predictions) expect = -0.333333333333 with self.test_session() as sess: @@ -6795,10 +6902,7 @@ class CohenKappaTest(test.TestCase): self.assertAlmostEqual(expect, kappa.eval(), 5) def testWeighted(self): - confusion_matrix = np.array([ - [9, 3, 1], - [4, 8, 2], - [2, 1, 6]]) + confusion_matrix = np.array([[9, 3, 1], [4, 8, 2], [2, 1, 6]]) labels, predictions = self._confusion_matrix_to_samples(confusion_matrix) num_samples = np.sum(confusion_matrix, dtype=np.int32) weights = (np.arange(0, num_samples) % 5) / 5.0 @@ -6809,31 +6913,26 @@ class CohenKappaTest(test.TestCase): with self.test_session() as sess: predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32) labels = constant_op.constant(labels) - kappa, update_op = metrics.cohen_kappa(labels, predictions, 4, - weights=weights) + kappa, update_op = metrics.cohen_kappa( + labels, predictions, 4, weights=weights) sess.run(variables.local_variables_initializer()) self.assertAlmostEqual(expect, sess.run(update_op), 5) self.assertAlmostEqual(expect, kappa.eval(), 5) def testWithMultipleUpdates(self): - confusion_matrix = np.array([ - [90, 30, 10, 20], - [40, 80, 20, 30], - [20, 10, 60, 35], - [15, 25, 30, 25]]) + confusion_matrix = np.array([[90, 30, 10, 20], [40, 80, 20, 30], + [20, 10, 60, 35], [15, 25, 30, 25]]) labels, predictions = self._confusion_matrix_to_samples(confusion_matrix) num_samples = np.sum(confusion_matrix, dtype=np.int32) weights = (np.arange(0, num_samples) % 5) / 5.0 num_classes = confusion_matrix.shape[0] batch_size = num_samples // 10 - predictions_t = array_ops.placeholder(dtypes_lib.float32, - shape=(batch_size,)) - labels_t = array_ops.placeholder(dtypes_lib.int32, - shape=(batch_size,)) - weights_t = array_ops.placeholder(dtypes_lib.float32, - shape=(batch_size,)) + predictions_t = array_ops.placeholder( + dtypes_lib.float32, shape=(batch_size,)) + labels_t = array_ops.placeholder(dtypes_lib.int32, shape=(batch_size,)) + weights_t = array_ops.placeholder(dtypes_lib.float32, shape=(batch_size,)) kappa, update_op = metrics.cohen_kappa( labels_t, predictions_t, num_classes, weights=weights_t) with self.test_session() as sess: @@ -6841,12 +6940,16 @@ class CohenKappaTest(test.TestCase): for idx in range(0, num_samples, batch_size): batch_start, batch_end = idx, idx + batch_size - sess.run(update_op, - feed_dict={labels_t: labels[batch_start:batch_end], - predictions_t: predictions[batch_start:batch_end], - weights_t: weights[batch_start:batch_end]}) + sess.run( + update_op, + feed_dict={ + labels_t: labels[batch_start:batch_end], + predictions_t: predictions[batch_start:batch_end], + weights_t: weights[batch_start:batch_end] + }) # Calculated by v0.19: sklearn.metrics.cohen_kappa_score( - # labels_np, predictions_np, sample_weight=weights_np) + # labels_np, predictions_np, + # sample_weight=weights_np) expect = 0.289965397924 self.assertAlmostEqual(expect, kappa.eval(), 5) @@ -6862,7 +6965,8 @@ class CohenKappaTest(test.TestCase): with self.assertRaises(ValueError): metrics.cohen_kappa(invalid_labels, predictions, 3) - invalid_predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 2)) + invalid_predictions = array_ops.placeholder( + dtypes_lib.float32, shape=(4, 2)) labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1)) with self.assertRaises(ValueError): metrics.cohen_kappa(labels, invalid_predictions, 3) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py index d07fece4bc668612d517e8dcaab1a35451a0238e..6a3b535eb447dd80f8e39d1d005f8f1d4f503549 100644 --- a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py @@ -58,6 +58,7 @@ def read_cifar10(filename_queue): class CIFAR10Record(object): pass + result = CIFAR10Record() # Dimensions of the images in the CIFAR-10 dataset. @@ -147,8 +148,9 @@ def distorted_inputs(data_dir, batch_size): images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. labels: Labels. 1D tensor of [batch_size] size. """ - filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) - for i in xrange(1, 6)] + filenames = [ + os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6) + ] for f in filenames: if not tf.gfile.Exists(f): raise ValueError('Failed to find file: ' + f) @@ -174,10 +176,9 @@ def distorted_inputs(data_dir, batch_size): # Because these operations are not commutative, consider randomizing # the order their operation. - distorted_image = tf.image.random_brightness(distorted_image, - max_delta=63) - distorted_image = tf.image.random_contrast(distorted_image, - lower=0.2, upper=1.8) + distorted_image = tf.image.random_brightness(distorted_image, max_delta=63) + distorted_image = tf.image.random_contrast( + distorted_image, lower=0.2, upper=1.8) # Subtract off the mean and divide by the variance of the pixels. float_image = tf.image.per_image_standardization(distorted_image) @@ -188,15 +189,18 @@ def distorted_inputs(data_dir, batch_size): # Ensure that the random shuffling has good mixing properties. min_fraction_of_examples_in_queue = 0.4 - min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * - min_fraction_of_examples_in_queue) - print ('Filling queue with %d CIFAR images before starting to train. ' - 'This will take a few minutes.' % min_queue_examples) + min_queue_examples = int( + NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue) + print('Filling queue with %d CIFAR images before starting to train. ' + 'This will take a few minutes.' % min_queue_examples) # Generate a batch of images and labels by building up a queue of examples. - return _generate_image_and_label_batch(float_image, read_input.label, - min_queue_examples, batch_size, - shuffle=True) + return _generate_image_and_label_batch( + float_image, + read_input.label, + min_queue_examples, + batch_size, + shuffle=True) def inputs(eval_data, data_dir, batch_size): @@ -212,8 +216,9 @@ def inputs(eval_data, data_dir, batch_size): labels: Labels. 1D tensor of [batch_size] size. """ if not eval_data: - filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) - for i in xrange(1, 6)] + filenames = [ + os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6) + ] num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN else: filenames = [os.path.join(data_dir, 'test_batch.bin')] @@ -235,8 +240,8 @@ def inputs(eval_data, data_dir, batch_size): # Image processing for evaluation. # Crop the central [height, width] of the image. - resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, - width, height) + resized_image = tf.image.resize_image_with_crop_or_pad( + reshaped_image, width, height) # Subtract off the mean and divide by the variance of the pixels. float_image = tf.image.per_image_standardization(resized_image) @@ -247,10 +252,13 @@ def inputs(eval_data, data_dir, batch_size): # Ensure that the random shuffling has good mixing properties. min_fraction_of_examples_in_queue = 0.4 - min_queue_examples = int(num_examples_per_epoch * - min_fraction_of_examples_in_queue) + min_queue_examples = int( + num_examples_per_epoch * min_fraction_of_examples_in_queue) # Generate a batch of images and labels by building up a queue of examples. - return _generate_image_and_label_batch(float_image, read_input.label, - min_queue_examples, batch_size, - shuffle=False) + return _generate_image_and_label_batch( + float_image, + read_input.label, + min_queue_examples, + batch_size, + shuffle=False) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py index 0d1de869f6ef91791a235cfe545b3b3a9b734e72..660f0168b10aa1e5b320cb476b051918804d2bde 100644 --- a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py @@ -48,16 +48,16 @@ from tensorflow.contrib.model_pruning.python import pruning # Global constants describing the CIFAR-10 data set. IMAGE_SIZE = cifar10_input.IMAGE_SIZE NUM_CLASSES = cifar10_input.NUM_CLASSES -NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN +NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN # pylint: disable=line-too-long NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL BATCH_SIZE = 128 DATA_DIR = '/tmp/cifar10_data' # Constants describing the training process. -MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average. -NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays. +MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average. +NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays. LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor. -INITIAL_LEARNING_RATE = 0.1 # Initial learning rate. +INITIAL_LEARNING_RATE = 0.1 # Initial learning rate. # If a model is trained with multiple GPUs, prefix all Op names with tower_name # to differentiate the operations. Note that this prefix is removed from the @@ -82,8 +82,7 @@ def _activation_summary(x): # session. This helps the clarity of presentation on tensorboard. tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name) tf.summary.histogram(tensor_name + '/activations', x) - tf.summary.scalar(tensor_name + '/sparsity', - tf.nn.zero_fraction(x)) + tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x)) def _variable_on_cpu(name, shape, initializer): @@ -120,10 +119,9 @@ def _variable_with_weight_decay(name, shape, stddev, wd): Variable Tensor """ dtype = tf.float32 - var = _variable_on_cpu( - name, - shape, - tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)) + var = _variable_on_cpu(name, shape, + tf.truncated_normal_initializer( + stddev=stddev, dtype=dtype)) if wd is not None: weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') tf.add_to_collection('losses', weight_decay) @@ -188,10 +186,8 @@ def inference(images): # Note that the masks are applied only to the weight tensors # conv1 with tf.variable_scope('conv1') as scope: - kernel = _variable_with_weight_decay('weights', - shape=[5, 5, 3, 64], - stddev=5e-2, - wd=0.0) + kernel = _variable_with_weight_decay( + 'weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0) conv = tf.nn.conv2d( images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME') @@ -201,18 +197,20 @@ def inference(images): _activation_summary(conv1) # pool1 - pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], - padding='SAME', name='pool1') + pool1 = tf.nn.max_pool( + conv1, + ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], + padding='SAME', + name='pool1') # norm1 - norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, - name='norm1') + norm1 = tf.nn.lrn( + pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1') # conv2 with tf.variable_scope('conv2') as scope: - kernel = _variable_with_weight_decay('weights', - shape=[5, 5, 64, 64], - stddev=5e-2, - wd=0.0) + kernel = _variable_with_weight_decay( + 'weights', shape=[5, 5, 64, 64], stddev=5e-2, wd=0.0) conv = tf.nn.conv2d( norm1, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME') biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1)) @@ -221,19 +219,23 @@ def inference(images): _activation_summary(conv2) # norm2 - norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, - name='norm2') + norm2 = tf.nn.lrn( + conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2') # pool2 - pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], - strides=[1, 2, 2, 1], padding='SAME', name='pool2') + pool2 = tf.nn.max_pool( + norm2, + ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], + padding='SAME', + name='pool2') # local3 with tf.variable_scope('local3') as scope: # Move everything into depth so we can perform a single matrix multiply. reshape = tf.reshape(pool2, [BATCH_SIZE, -1]) dim = reshape.get_shape()[1].value - weights = _variable_with_weight_decay('weights', shape=[dim, 384], - stddev=0.04, wd=0.004) + weights = _variable_with_weight_decay( + 'weights', shape=[dim, 384], stddev=0.04, wd=0.004) biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1)) local3 = tf.nn.relu( tf.matmul(reshape, pruning.apply_mask(weights, scope)) + biases, @@ -242,8 +244,8 @@ def inference(images): # local4 with tf.variable_scope('local4') as scope: - weights = _variable_with_weight_decay('weights', shape=[384, 192], - stddev=0.04, wd=0.004) + weights = _variable_with_weight_decay( + 'weights', shape=[384, 192], stddev=0.04, wd=0.004) biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1)) local4 = tf.nn.relu( tf.matmul(local3, pruning.apply_mask(weights, scope)) + biases, @@ -255,8 +257,8 @@ def inference(images): # tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits # and performs the softmax internally for efficiency. with tf.variable_scope('softmax_linear') as scope: - weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES], - stddev=1/192.0, wd=0.0) + weights = _variable_with_weight_decay( + 'weights', [192, NUM_CLASSES], stddev=1 / 192.0, wd=0.0) biases = _variable_on_cpu('biases', [NUM_CLASSES], tf.constant_initializer(0.0)) softmax_linear = tf.add( @@ -337,11 +339,12 @@ def train(total_loss, global_step): decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY) # Decay the learning rate exponentially based on the number of steps. - lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, - global_step, - decay_steps, - LEARNING_RATE_DECAY_FACTOR, - staircase=True) + lr = tf.train.exponential_decay( + INITIAL_LEARNING_RATE, + global_step, + decay_steps, + LEARNING_RATE_DECAY_FACTOR, + staircase=True) tf.summary.scalar('learning_rate', lr) # Generate moving averages of all losses and associated summaries. @@ -365,8 +368,8 @@ def train(total_loss, global_step): tf.summary.histogram(var.op.name + '/gradients', grad) # Track the moving averages of all trainable variables. - variable_averages = tf.train.ExponentialMovingAverage( - MOVING_AVERAGE_DECAY, global_step) + variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, + global_step) variables_averages_op = variable_averages.apply(tf.trainable_variables()) with tf.control_dependencies([apply_gradient_op, variables_averages_op]): @@ -383,10 +386,13 @@ def maybe_download_and_extract(): filename = DATA_URL.split('/')[-1] filepath = os.path.join(dest_directory, filename) if not os.path.exists(filepath): + def _progress(count, block_size, total_size): - sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, - float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.write('\r>> Downloading %s %.1f%%' % + (filename, + float(count * block_size) / float(total_size) * 100.0)) sys.stdout.flush() + filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) print() statinfo = os.stat(filepath) diff --git a/tensorflow/contrib/model_pruning/python/layers/layers.py b/tensorflow/contrib/model_pruning/python/layers/layers.py index dfebb9a6794056dd43b0699ccbcc5797f2f172f7..988748ad75bdf72f1da3f4e1c6e85aabb04a5954 100644 --- a/tensorflow/contrib/model_pruning/python/layers/layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/layers.py @@ -21,7 +21,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np import six from tensorflow.contrib.framework.python.ops import add_arg_scope diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc index 8d14a3ef0404e727c47ad2ab39a69838fe1588aa..6a7f5efecdb4062874a09df227d139ad20d59f3f 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc @@ -24,11 +24,11 @@ limitations under the License. #include #include -#include "tensorflow/core/distributed_runtime/tensor_coding.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" +#include "tensorflow/core/distributed_runtime/tensor_coding.h" namespace tensorflow { @@ -62,7 +62,6 @@ BaseRemoteRendezvous* MPIRendezvousMgr::Create(int64 step_id, void MPIRemoteRendezvous::RecvFromRemoteAsync( const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { - Status s = Status::OK(); MPIRequestTensorCall* rendezvous_call = new MPIRequestTensorCall(); @@ -103,37 +102,37 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync( // Create the function which is called when the Tensor is send by remote const int64 temp1 = step_id_; rendezvous_call->recv_call_ = - [this, parsed, recv_args, done, dst, temp1, rendezvous_call]( - MPIRecvTensorResponse mpi_response) { - Status s; - Device* dst_device; - if (s.ok()) { - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); - CHECK(s.ok()) << "Device lookup failed"; - } - - VLOG(3) << "MPI Received tensor " << parsed.FullKey() - << " @ step: " << temp1 - << " single-send: " << mpi_response.singlesend(); - - Tensor val; - if (mpi_response.singlesend()) { - dst_device->MakeTensorFromProto(mpi_response.response().tensor(), - recv_args.alloc_attrs, &val); - } else { - TensorResponse tr; - tr.InitAlloc(dst_device, recv_args.alloc_attrs); - tr.InitPartial(mpi_response.response()); - const size_t nBytes = tr.tensor().TotalBytes(); - void* data = const_cast(DMAHelper::base(&tr.tensor())); - MPI_Status status; - MPI_CHECK(MPI_Recv(data, static_cast(nBytes), MPI_BYTE, dst, - TAG_SENDTENSOR2, MPI_COMM_WORLD, &status)); - val = std::move(tr.tensor()); - } - - done(s, Args(), recv_args, val, mpi_response.response().is_dead()); - }; + [this, parsed, recv_args, done, dst, temp1, + rendezvous_call](MPIRecvTensorResponse mpi_response) { + Status s; + Device* dst_device; + if (s.ok()) { + s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + CHECK(s.ok()) << "Device lookup failed"; + } + + VLOG(3) << "MPI Received tensor " << parsed.FullKey() + << " @ step: " << temp1 + << " single-send: " << mpi_response.singlesend(); + + Tensor val; + if (mpi_response.singlesend()) { + dst_device->MakeTensorFromProto(mpi_response.response().tensor(), + recv_args.alloc_attrs, &val); + } else { + TensorResponse tr; + tr.InitAlloc(dst_device, recv_args.alloc_attrs); + tr.InitPartial(mpi_response.response()); + const size_t nBytes = tr.tensor().TotalBytes(); + void* data = const_cast(DMAHelper::base(&tr.tensor())); + MPI_Status status; + MPI_CHECK(MPI_Recv(data, static_cast(nBytes), MPI_BYTE, dst, + TAG_SENDTENSOR2, MPI_COMM_WORLD, &status)); + val = std::move(tr.tensor()); + } + + done(s, Args(), recv_args, val, mpi_response.response().is_dead()); + }; MPIRendezvousMgr* mgr = reinterpret_cast(this->rendezvous_mgr_); @@ -152,16 +151,18 @@ MPIRemoteRendezvous::~MPIRemoteRendezvous() {} void MPIRendezvousMgr::AddRequest(RecvTensorRequest request, const int mpi_dst) { TF_CHECK_OK(recv_tensor_recent_request_ids_.TrackUnique( - req.request_id(), "RecvTensor (MPIRendezvousMgr)", req)); + request.request_id(), "RecvTensor (MPIRendezvousMgr)", request)); const int64 step_id = request.step_id(); const std::string& key = request.rendezvous_key(); Rendezvous::ParsedKey parsed; TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed)); MPIRecvTensorCallBack send_cb = [this, mpi_dst, parsed]( - const Status& status, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead, - MPISendTensorCall* mpi_send_call) { + const Status& status, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& val, bool is_dead, + MPISendTensorCall* mpi_send_call) { // TODO(jbedorf) this should be a loop over max size CHECK(mpi_send_call->mRes_.ByteSize() < INT_MAX) << "Buffer too large for single transfer"; @@ -194,74 +195,78 @@ void MPIRendezvousMgr::AddRequest(RecvTensorRequest request, }; // Wrapper around the read callback to place the callback on our queue - Rendezvous::DoneCallback done_cb = [this, parsed, step_id, send_cb]( - const Status& status, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) { - if (!status.ok()) { - CHECK(status.ok()) << "RecvLocalAsync was not ok, key: " - << parsed.FullKey() << " step: " << step_id - << " error message: " << status.error_message(); - return; - } - - VLOG(3) << "MPI Sending tensor " << parsed.FullKey() - << " @ step: " << step_id << std::endl; - - auto mpi_send_call = new MPISendTensorCall(); - mpi_send_call->Init(parsed, step_id, is_dead); - - Device* src_dev = nullptr; - Status s = this->worker_env_2->device_mgr->LookupDevice(parsed.src_device, - &src_dev); - CHECK(s.ok()) << "src device not found"; - - // Control if shape and data should be send together or if we can optimize - // it in two different transfers, thereby reducing memory copies - bool doOptimalTransfer = true; - if (!DataTypeCanUseMemcpy(val.dtype())) doOptimalTransfer = false; - if (val.TotalBytes() < 1024) doOptimalTransfer = false; - - doOptimalTransfer = doOptimalTransfer && use_optimal_transfer_; - - if (doOptimalTransfer) { - // First send the Tensor description and in a follow up transfer the data - mpi_send_call->mRes_.mutable_response()->mutable_tensor()->set_dtype( - val.dtype()); - val.shape().AsProto(mpi_send_call->mRes_.mutable_response() - ->mutable_tensor() - ->mutable_tensor_shape()); - mpi_send_call->mRes_.set_singlesend(false); - } else { - // Send the Tensor description and data in a single transfer - if (src_dev->tensorflow_gpu_device_info() && - (!send_args.alloc_attrs.on_host())) { - Notification n; - GPUUtil::SetProtoFromGPU( - val, src_dev, send_args.device_context, - mpi_send_call->mRes_.mutable_response()->mutable_tensor(), is_dead, - [&n, &s](const Status& s_) { - s = s_; - n.Notify(); - }); - n.WaitForNotification(); - } else { - val.AsProtoTensorContent( - mpi_send_call->mRes_.mutable_response()->mutable_tensor()); - } - } - - std::function res = std::bind( - send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call); - - SendQueueEntry req(parsed.FullKey().ToString().c_str(), std::move(res)); - - this->QueueSendRequest(req); - - // Wait for the notification that indicates the tensor has been - // successfully transmitted to the remote process. Only needed if we - // have not parsed the tensor to proto - if (doOptimalTransfer) mpi_send_call->n_.WaitForNotification(); - }; // done_cb + Rendezvous::DoneCallback done_cb = + [this, parsed, step_id, send_cb]( + const Status& status, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) { + if (!status.ok()) { + CHECK(status.ok()) + << "RecvLocalAsync was not ok, key: " << parsed.FullKey() + << " step: " << step_id + << " error message: " << status.error_message(); + return; + } + + VLOG(3) << "MPI Sending tensor " << parsed.FullKey() + << " @ step: " << step_id << std::endl; + + auto mpi_send_call = new MPISendTensorCall(); + mpi_send_call->Init(parsed, step_id, is_dead); + + Device* src_dev = nullptr; + Status s = this->worker_env_2->device_mgr->LookupDevice( + parsed.src_device, &src_dev); + CHECK(s.ok()) << "src device not found"; + + // Control if shape and data should be send together or if we can + // optimize it in two different transfers, thereby reducing memory + // copies + bool doOptimalTransfer = true; + if (!DataTypeCanUseMemcpy(val.dtype())) doOptimalTransfer = false; + if (val.TotalBytes() < 1024) doOptimalTransfer = false; + + doOptimalTransfer = doOptimalTransfer && use_optimal_transfer_; + + if (doOptimalTransfer) { + // First send the Tensor description and in a follow up transfer the + // data + mpi_send_call->mRes_.mutable_response()->mutable_tensor()->set_dtype( + val.dtype()); + val.shape().AsProto(mpi_send_call->mRes_.mutable_response() + ->mutable_tensor() + ->mutable_tensor_shape()); + mpi_send_call->mRes_.set_singlesend(false); + } else { + // Send the Tensor description and data in a single transfer + if (src_dev->tensorflow_gpu_device_info() && + (!send_args.alloc_attrs.on_host())) { + Notification n; + GPUUtil::SetProtoFromGPU( + val, src_dev, send_args.device_context, + mpi_send_call->mRes_.mutable_response()->mutable_tensor(), + is_dead, [&n, &s](const Status& s_) { + s = s_; + n.Notify(); + }); + n.WaitForNotification(); + } else { + val.AsProtoTensorContent( + mpi_send_call->mRes_.mutable_response()->mutable_tensor()); + } + } + + std::function res = std::bind( + send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call); + + SendQueueEntry req(parsed.FullKey().ToString().c_str(), std::move(res)); + + this->QueueSendRequest(req); + + // Wait for the notification that indicates the tensor has been + // successfully transmitted to the remote process. Only needed if we + // have not parsed the tensor to proto + if (doOptimalTransfer) mpi_send_call->n_.WaitForNotification(); + }; // done_cb worker_env_2->compute_pool->Schedule([this, step_id, parsed, done_cb]() { this->RecvLocalAsync(step_id, parsed, done_cb); @@ -293,9 +298,8 @@ void MPIRendezvousMgr::MPIBackgroundThread() { } // Remove sends that have been completed - active_sends.remove_if([](std::unique_ptr& i) { - return i->IsFinished(); - }); + active_sends.remove_if( + [](std::unique_ptr& i) { return i->IsFinished(); }); // send a Tensor request RequestQueueEntry req; diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h index ca42ee2f6d246f67f5c4c668fe27b16722bc6130..5596601ddb9846c0e4f5be4bf33114fc19c0a59d 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h @@ -18,12 +18,12 @@ limitations under the License. #ifdef TENSORFLOW_USE_MPI -#include -#include #include -#include -#include #include +#include +#include +#include +#include #include #include #include @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/contrib/mpi/mpi_msg.pb.h" #include "tensorflow/contrib/mpi/mpi_utils.h" #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/recent_request_ids.h" #include "tensorflow/core/distributed_runtime/request_id.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/protobuf/worker.pb.h" @@ -160,7 +161,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr { private: typedef std::function MPIRecvTensorCallBack; + const Tensor&, const bool, MPISendTensorCall*)> + MPIRecvTensorCallBack; typedef std::pair> RequestQueueEntry; typedef std::pair> diff --git a/tensorflow/contrib/mpi/mpi_server_lib.cc b/tensorflow/contrib/mpi/mpi_server_lib.cc index d585c0565eb234655e7a1bbc92df5741e18c8f33..a31fa9ce0b3110d875689d74a41ca9f9cc85f532 100644 --- a/tensorflow/contrib/mpi/mpi_server_lib.cc +++ b/tensorflow/contrib/mpi/mpi_server_lib.cc @@ -22,8 +22,8 @@ limitations under the License. #include "grpc/support/alloc.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/contrib/mpi/mpi_utils.h b/tensorflow/contrib/mpi/mpi_utils.h index 45e21f2b25ab4897641ffec776eb1b3c32ab9a2e..fa297c28cb47d43ba927ab941854bd472d90b465 100644 --- a/tensorflow/contrib/mpi/mpi_utils.h +++ b/tensorflow/contrib/mpi/mpi_utils.h @@ -18,8 +18,8 @@ limitations under the License. #ifdef TENSORFLOW_USE_MPI -#include #include +#include #include #include "tensorflow/core/lib/strings/str_util.h" diff --git a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc index 2d5b98022c3aafb627e986a2764ee60184014945..8dca90a1e34d6a234c2b1479ca5594e88afcc194 100644 --- a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc @@ -35,8 +35,8 @@ limitations under the License. #define OMPI_SKIP_MPICXX #include "third_party/mpi/mpi.h" -#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h" #include "tensorflow/contrib/mpi_collectives/kernels/ring.h" +#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h" /* * MPI Allreduce and Allgather Ops for TensorFlow. diff --git a/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py b/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py index f0a116239d6f4f7271c2a8f68806ff1ccaae80ae..2fbefef0d36f6a1507827427ebbafe5e81e35ea3 100644 --- a/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py +++ b/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py @@ -26,7 +26,8 @@ from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader _mpi_ops_so = loader.load_op_library( - resource_loader.get_path_to_datafile("_mpi_ops.so")) + resource_loader.get_path_to_datafile('_mpi_ops.so')) + def size(name=None): """An op which returns the number of MPI processes. @@ -120,15 +121,14 @@ def allgather(tensor, name=None): """ # Specify that first allgather is to collect the tensor gather sizes, # indicated by passing in a scalar (0-D tensor) of value 0 - sizes_flag = tf.constant(0, dtype=tf.int64, name="size_flag_const") - my_size = tf.slice(tf.shape(tensor, out_type=tf.int64), [0], [1], name="size_slice") + sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const') + my_size = tf.slice( + tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice') if name is None: - name = "allgather" - sizing_name = "{}_sizing".format(name) + name = 'allgather' + sizing_name = '{}_sizing'.format(name) sizes = gen_mpi_ops.mpi_allgather(my_size, sizes_flag, name=sizing_name) return gen_mpi_ops.mpi_allgather(tensor, sizes, name=name) ops.NotDifferentiable('MPIAllgather') - - diff --git a/tensorflow/contrib/ndlstm/BUILD b/tensorflow/contrib/ndlstm/BUILD deleted file mode 100644 index 8403f841884d4640ce8156ff4db46868dbe1788c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/ndlstm/BUILD +++ /dev/null @@ -1,92 +0,0 @@ -# Description: -# Contains classes implementing 1D and 2D LSTMs for image and signal -# processing problems. - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -package(default_visibility = ["//tensorflow:__subpackages__"]) - -load("//tensorflow:tensorflow.bzl", "tf_py_test") - -py_library( - name = "ndlstm", - srcs = [ - "__init__.py", - "python/__init__.py", - "python/lstm1d.py", - "python/lstm2d.py", - "python/misc.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/rnn:rnn_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - "//tensorflow/python:ops", - "//tensorflow/python:platform", - "//tensorflow/python:random_ops", - "//tensorflow/python:rnn", - "//tensorflow/python:rnn_cell", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - ], -) - -tf_py_test( - name = "lstm1d_test", - srcs = ["python/lstm1d_test.py"], - additional_deps = [ - ":ndlstm", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:variables", - ], -) - -tf_py_test( - name = "lstm2d_test", - srcs = ["python/lstm2d_test.py"], - additional_deps = [ - ":ndlstm", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:variables", - ], -) - -tf_py_test( - name = "misc_test", - srcs = ["python/misc_test.py"], - additional_deps = [ - ":ndlstm", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:variables", - ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/ndlstm/README.md b/tensorflow/contrib/ndlstm/README.md deleted file mode 100644 index 7ccb57f1b34a24af7d776f7dbb12a2a00bb5ca30..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/ndlstm/README.md +++ /dev/null @@ -1,31 +0,0 @@ -Library of multidimensional LSTM models and related code. - -# 2D LSTM code - -The 2D LSTM layers take tensors of the form (batch_size, height, width, -depth), compatible with convolutional layers, as inputs. The library -transposes and reshapes these tensors in a way that allows batches of -images to be processed by LSTMs. - -The library currently provides: - - - a separable 2D LSTM layer - - a simple 2D convolutional layer that can be swapped out against 2D LSTM - - layers to reduce images to sequences and images to final state vectors - - layers for sequence classification, pixel-wise classification - -# Other Dimensions - -There is 1D LSTM code in `lstm1d.py`. This code implements 1D LSTM versions -suitable as a basis for higher dimensional LSTMs. It is intended for constant -batch size and uses a different layout. Although the code is perfectly fine for -1D use, you may find other 1D LSTM implementations to be more convenient if you -are interested in sequence problems. - -# Upcoming Changes - - - PyramidLSTM - - support for 3D and 4D - - optional use of native fused LSTM op - - easy-to-use command line drivers and examples - - operators for patch-wise processing diff --git a/tensorflow/contrib/ndlstm/python/lstm1d.py b/tensorflow/contrib/ndlstm/python/lstm1d.py deleted file mode 100644 index b24e332e4aea7f0ef981909558dcd6d730ca08a7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/ndlstm/python/lstm1d.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""LSTM layers for sequences.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.contrib.framework.python.ops import variables -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import rnn -from tensorflow.python.ops import rnn_cell -from tensorflow.python.ops import variable_scope - - -def _shape(tensor): - return tensor.get_shape().as_list() - - -def ndlstm_base_unrolled(inputs, noutput, scope=None, reverse=False): - """Run an LSTM, either forward or backward. - - This is a 1D LSTM implementation using unrolling and the TensorFlow - LSTM op. - - Args: - inputs: input sequence (length, batch_size, ninput) - noutput: depth of output - scope: optional scope name - reverse: run LSTM in reverse - - Returns: - Output sequence (length, batch_size, noutput) - - """ - with variable_scope.variable_scope(scope, "SeqLstmUnrolled", [inputs]): - length, batch_size, _ = _shape(inputs) - lstm_cell = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False) - state = array_ops.zeros([batch_size, lstm_cell.state_size]) - output_u = [] - inputs_u = array_ops.unstack(inputs) - if reverse: - inputs_u = list(reversed(inputs_u)) - for i in xrange(length): - if i > 0: - variable_scope.get_variable_scope().reuse_variables() - output, state = lstm_cell(inputs_u[i], state) - output_u += [output] - if reverse: - output_u = list(reversed(output_u)) - outputs = array_ops.stack(output_u) - return outputs - - -def ndlstm_base_dynamic(inputs, noutput, scope=None, reverse=False): - """Run an LSTM, either forward or backward. - - This is a 1D LSTM implementation using dynamic_rnn and - the TensorFlow LSTM op. - - Args: - inputs: input sequence (length, batch_size, ninput) - noutput: depth of output - scope: optional scope name - reverse: run LSTM in reverse - - Returns: - Output sequence (length, batch_size, noutput) - """ - with variable_scope.variable_scope(scope, "SeqLstm", [inputs]): - lstm_cell = rnn_cell.BasicLSTMCell(noutput) - if reverse: - inputs = array_ops.reverse_v2(inputs, [0]) - outputs, _ = rnn.dynamic_rnn( - lstm_cell, inputs, time_major=True, dtype=inputs.dtype) - if reverse: - outputs = array_ops.reverse_v2(outputs, [0]) - return outputs - - -def ndlstm_base(inputs, noutput, scope=None, reverse=False, dynamic=True): - """Implements a 1D LSTM, either forward or backward. - - This is a base case for multidimensional LSTM implementations, which - tend to be used differently from sequence-to-sequence - implementations. For general 1D sequence to sequence - transformations, you may want to consider another implementation - from TF slim. - - Args: - inputs: input sequence (length, batch_size, ninput) - noutput: depth of output - scope: optional scope name - reverse: run LSTM in reverse - dynamic: use dynamic_rnn - - Returns: - Output sequence (length, batch_size, noutput) - - """ - # TODO(tmb) maybe add option for other LSTM implementations, like - # slim.rnn.basic_lstm_cell - if dynamic: - return ndlstm_base_dynamic(inputs, noutput, scope=scope, reverse=reverse) - else: - return ndlstm_base_unrolled(inputs, noutput, scope=scope, reverse=reverse) - - -def sequence_to_final(inputs, noutput, scope=None, name=None, reverse=False): - """Run an LSTM across all steps and returns only the final state. - - Args: - inputs: (length, batch_size, depth) tensor - noutput: size of output vector - scope: optional scope name - name: optional name for output tensor - reverse: run in reverse - - Returns: - Batch of size (batch_size, noutput). - """ - with variable_scope.variable_scope(scope, "SequenceToFinal", [inputs]): - length, batch_size, _ = _shape(inputs) - lstm = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False) - state = array_ops.zeros([batch_size, lstm.state_size]) - inputs_u = array_ops.unstack(inputs) - if reverse: - inputs_u = list(reversed(inputs_u)) - for i in xrange(length): - if i > 0: - variable_scope.get_variable_scope().reuse_variables() - output, state = lstm(inputs_u[i], state) - outputs = array_ops.reshape(output, [batch_size, noutput], name=name) - return outputs - - -def sequence_softmax(inputs, noutput, scope=None, name=None, linear_name=None): - """Run a softmax layer over all the time steps of an input sequence. - - Args: - inputs: (length, batch_size, depth) tensor - noutput: output depth - scope: optional scope name - name: optional name for output tensor - linear_name: name for linear (pre-softmax) output - - Returns: - A tensor of size (length, batch_size, noutput). - - """ - length, _, ninputs = _shape(inputs) - inputs_u = array_ops.unstack(inputs) - output_u = [] - with variable_scope.variable_scope(scope, "SequenceSoftmax", [inputs]): - initial_w = random_ops.truncated_normal([0 + ninputs, noutput], stddev=0.1) - initial_b = constant_op.constant(0.1, shape=[noutput]) - w = variables.model_variable("weights", initializer=initial_w) - b = variables.model_variable("biases", initializer=initial_b) - for i in xrange(length): - with variable_scope.variable_scope(scope, "SequenceSoftmaxStep", - [inputs_u[i]]): - # TODO(tmb) consider using slim.fully_connected(..., - # activation_fn=tf.nn.softmax) - linear = nn_ops.xw_plus_b(inputs_u[i], w, b, name=linear_name) - output = nn_ops.softmax(linear) - output_u += [output] - outputs = array_ops.stack(output_u, name=name) - return outputs diff --git a/tensorflow/contrib/ndlstm/python/lstm1d_test.py b/tensorflow/contrib/ndlstm/python/lstm1d_test.py deleted file mode 100644 index 49b15cc814cc54aaea7c67c4e509e5aa144e063e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/ndlstm/python/lstm1d_test.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for 1D LSTM.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.ndlstm.python import lstm1d as lstm1d_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import gradient_checker -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -lstm1d = lstm1d_lib - - -def _rand(*size): - return np.random.uniform(size=size).astype("f") - - -class Lstm1DTest(test.TestCase): - - def testSequenceToSequenceDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(17, 1, 5)) - outputs = lstm1d.ndlstm_base(inputs, 8) - variables.global_variables_initializer().run() - names = [v.name for v in variables.trainable_variables()] - self.assertEqual(len(names), 2) - result = outputs.eval() - self.assertEqual(tuple(result.shape), (17, 1, 8)) - - def testSequenceToSequenceGradient(self): - with self.test_session(): - size = (17, 1, 15) - output_size = (17, 1, 8) - inputs = constant_op.constant(_rand(*size)) - outputs = lstm1d.ndlstm_base(inputs, 8, dynamic=False) - variables.global_variables_initializer().run() - gradients = gradients_impl.gradients(outputs, inputs) - if 1: # pylint: disable=using-constant-test - gradients = gradients_impl.gradients(outputs, inputs)[0].eval() - self.assertEqual(gradients.shape, size) - else: - # TODO(tmb) tf.test.compute_gradient error is currently broken - # with dynamic_rnn. Enable this test case eventually. - err = gradient_checker.compute_gradient_error( - inputs, size, outputs, output_size, delta=1e-4) - self.assert_(not np.isnan(err)) - self.assert_(err < 0.1) - - def testSequenceToSequenceGradientReverse(self): - with self.test_session(): - size = (17, 1, 15) - output_size = (17, 1, 8) - inputs = constant_op.constant(_rand(*size)) - outputs = lstm1d.ndlstm_base(inputs, 8, reverse=1, dynamic=False) - variables.global_variables_initializer().run() - if 1: # pylint: disable=using-constant-test - gradients = gradients_impl.gradients(outputs, inputs)[0].eval() - self.assertEqual(gradients.shape, size) - else: - # TODO(tmb) tf.test.compute_gradient error is currently broken - # with dynamic_rnn. Enable this test case eventually. - err = gradient_checker.compute_gradient_error( - inputs, size, outputs, output_size, delta=1e-4) - self.assert_(not np.isnan(err)) - self.assert_(err < 0.1) - - def testSequenceToFinalDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(17, 6, 5)) - outputs = lstm1d.sequence_to_final(inputs, 8) - variables.global_variables_initializer().run() - names = [v.name for v in variables.trainable_variables()] - self.assertEqual(len(names), 2) - result = outputs.eval() - self.assertEqual(tuple(result.shape), (6, 8)) - - def testSequenceSoftmaxDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(17, 1, 5)) - outputs = lstm1d.sequence_softmax(inputs, 8) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (17, 1, 8)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/ndlstm/python/lstm2d.py b/tensorflow/contrib/ndlstm/python/lstm2d.py deleted file mode 100644 index ebbb4ccf11b219e86578d05e99a7a02ebe08271e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/ndlstm/python/lstm2d.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A small library of functions dealing with LSTMs applied to images. - -Tensors in this library generally have the shape (num_images, height, width, -depth). -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.ndlstm.python import lstm1d -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variable_scope - - -def _shape(tensor): - """Get the shape of a tensor as an int list.""" - return tensor.get_shape().as_list() - - -def images_to_sequence(tensor): - """Convert a batch of images into a batch of sequences. - - Args: - tensor: a (num_images, height, width, depth) tensor - - Returns: - (width, num_images*height, depth) sequence tensor - """ - - num_image_batches, height, width, depth = _shape(tensor) - transposed = array_ops.transpose(tensor, [2, 0, 1, 3]) - return array_ops.reshape(transposed, - [width, num_image_batches * height, depth]) - - -def sequence_to_images(tensor, num_image_batches): - """Convert a batch of sequences into a batch of images. - - Args: - tensor: (num_steps, num_batches, depth) sequence tensor - num_image_batches: the number of image batches - - Returns: - (num_images, height, width, depth) tensor - """ - - width, num_batches, depth = _shape(tensor) - height = num_batches // num_image_batches - reshaped = array_ops.reshape(tensor, - [width, num_image_batches, height, depth]) - return array_ops.transpose(reshaped, [1, 2, 0, 3]) - - -def horizontal_lstm(images, num_filters_out, scope=None): - """Run an LSTM bidirectionally over all the rows of each image. - - Args: - images: (num_images, height, width, depth) tensor - num_filters_out: output depth - scope: optional scope name - - Returns: - (num_images, height, width, num_filters_out) tensor, where - num_steps is width and new num_batches is num_image_batches * height - """ - with variable_scope.variable_scope(scope, "HorizontalLstm", [images]): - batch_size, _, _, _ = _shape(images) - sequence = images_to_sequence(images) - with variable_scope.variable_scope("lr"): - hidden_sequence_lr = lstm1d.ndlstm_base(sequence, num_filters_out // 2) - with variable_scope.variable_scope("rl"): - hidden_sequence_rl = (lstm1d.ndlstm_base( - sequence, num_filters_out - num_filters_out // 2, reverse=1)) - output_sequence = array_ops.concat([hidden_sequence_lr, hidden_sequence_rl], - 2) - output = sequence_to_images(output_sequence, batch_size) - return output - - -def get_blocks(images, kernel_size): - """Split images in blocks - - Args: - images: (num_images, height, width, depth) tensor - kernel_size: A list of length 2 holding the [kernel_height, kernel_width] of - of the pooling. Can be an int if both values are the same. - - Returns: - (num_images, height/kernel_height, width/kernel_width, - depth*kernel_height*kernel_width) tensor - """ - with variable_scope.variable_scope("image_blocks"): - batch_size, height, width, chanels = _shape(images) - - if height % kernel_size[0] != 0: - offset = array_ops.zeros([batch_size, - kernel_size[0] - (height % kernel_size[0]), - width, - chanels]) - images = array_ops.concat([images, offset], 1) - batch_size, height, width, chanels = _shape(images) - if width % kernel_size[1] != 0: - offset = array_ops.zeros([batch_size, - height, - kernel_size[1] - (width % kernel_size[1]), - chanels]) - images = array_ops.concat([images, offset], 2) - batch_size, height, width, chanels = _shape(images) - - h, w = int(height / kernel_size[0]), int(width / kernel_size[1]) - features = kernel_size[1] * kernel_size[0] * chanels - - lines = array_ops.split(images, h, axis=1) - line_blocks = [] - for line in lines: - line = array_ops.transpose(line, [0, 2, 3, 1]) - line = array_ops.reshape(line, [batch_size, w, features]) - line_blocks.append(line) - - return array_ops.stack(line_blocks, axis=1) - - -def separable_lstm(images, num_filters_out, - kernel_size=None, nhidden=None, scope=None): - """Run bidirectional LSTMs first horizontally then vertically. - - Args: - images: (num_images, height, width, depth) tensor - num_filters_out: output layer depth - kernel_size: A list of length 2 holding the [kernel_height, kernel_width] of - of the pooling. Can be an int if both values are the same. Set to None for - not using blocks - nhidden: hidden layer depth - scope: optional scope name - - Returns: - (num_images, height/kernel_height, width/kernel_width, - num_filters_out) tensor - """ - with variable_scope.variable_scope(scope, "SeparableLstm", [images]): - if nhidden is None: - nhidden = num_filters_out - if kernel_size is not None: - images = get_blocks(images, kernel_size) - hidden = horizontal_lstm(images, nhidden) - with variable_scope.variable_scope("vertical"): - transposed = array_ops.transpose(hidden, [0, 2, 1, 3]) - output_transposed = horizontal_lstm(transposed, num_filters_out) - output = array_ops.transpose(output_transposed, [0, 2, 1, 3]) - return output - - -def reduce_to_sequence(images, num_filters_out, scope=None): - """Reduce an image to a sequence by scanning an LSTM vertically. - - Args: - images: (num_images, height, width, depth) tensor - num_filters_out: output layer depth - scope: optional scope name - - Returns: - A (width, num_images, num_filters_out) sequence. - """ - with variable_scope.variable_scope(scope, "ReduceToSequence", [images]): - batch_size, height, width, depth = _shape(images) - transposed = array_ops.transpose(images, [1, 0, 2, 3]) - reshaped = array_ops.reshape(transposed, - [height, batch_size * width, depth]) - reduced = lstm1d.sequence_to_final(reshaped, num_filters_out) - output = array_ops.reshape(reduced, [batch_size, width, num_filters_out]) - return output - - -def reduce_to_final(images, num_filters_out, nhidden=None, scope=None): - """Reduce an image to a final state by running two LSTMs. - - Args: - images: (num_images, height, width, depth) tensor - num_filters_out: output layer depth - nhidden: hidden layer depth (defaults to num_filters_out) - scope: optional scope name - - Returns: - A (num_images, num_filters_out) batch. - """ - with variable_scope.variable_scope(scope, "ReduceToFinal", [images]): - nhidden = nhidden or num_filters_out - batch_size, height, width, depth = _shape(images) - transposed = array_ops.transpose(images, [1, 0, 2, 3]) - reshaped = array_ops.reshape(transposed, - [height, batch_size * width, depth]) - with variable_scope.variable_scope("reduce1"): - reduced = lstm1d.sequence_to_final(reshaped, nhidden) - transposed_hidden = array_ops.reshape(reduced, - [batch_size, width, nhidden]) - hidden = array_ops.transpose(transposed_hidden, [1, 0, 2]) - with variable_scope.variable_scope("reduce2"): - output = lstm1d.sequence_to_final(hidden, num_filters_out) - return output diff --git a/tensorflow/contrib/ndlstm/python/lstm2d_test.py b/tensorflow/contrib/ndlstm/python/lstm2d_test.py deleted file mode 100644 index f1b37d701b868438dcbac4e713ccc2136dacd983..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/ndlstm/python/lstm2d_test.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for 2D LSTMs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.ndlstm.python import lstm2d as lstm2d_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -lstm2d = lstm2d_lib - - -def _rand(*size): - return np.random.uniform(size=size).astype("f") - - -class Lstm2DTest(test_util.TensorFlowTestCase): - - def testImagesToSequenceDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = lstm2d.images_to_sequence(inputs) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (11, 14, 5)) - - def testSequenceToImagesDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(11, 14, 5)) - outputs = lstm2d.sequence_to_images(inputs, 2) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 7, 11, 5)) - - def testImagesAndSequenceDims(self): - with self.test_session(): - size = (2, 7, 11, 5) - inputs = constant_op.constant(_rand(*size)) - sequence = lstm2d.images_to_sequence(inputs) - outputs = lstm2d.sequence_to_images(sequence, size[0]) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), size) - - def testSeparableLstmDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = lstm2d.separable_lstm(inputs, 8) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 7, 11, 8)) - - def testSeparableLstmDimsBlocks(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = lstm2d.separable_lstm(inputs, 8, kernel_size=[2, 2]) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 4, 6, 8)) - - def testReduceToSequenceDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = lstm2d.reduce_to_sequence(inputs, 8) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 11, 8)) - - def testReduceToFinalDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = lstm2d.reduce_to_final(inputs, 8, 12) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 8)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/ndlstm/python/misc.py b/tensorflow/contrib/ndlstm/python/misc.py deleted file mode 100644 index 38eeff84ca4e5afbe45d6c9e0c52af9ae86de24f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/ndlstm/python/misc.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Miscellaneous functions useful for nD-LSTM models. - -Some of these functions duplicate functionality in tfslim with -slightly different interfaces. - -Tensors in this library generally have the shape (num_images, height, width, -depth). -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.layers.python.layers import layers -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import sparse_ops - - -def _shape(tensor): - """Get the shape of a tensor as an int list.""" - return tensor.get_shape().as_list() - - -def pixels_as_vector(images, scope=None): - """Reduce images to vectors by combining all pixels.""" - with ops.name_scope(scope, "PixelsAsVector", [images]): - batch_size, height, width, depth = _shape(images) - return array_ops.reshape(images, [batch_size, height * width * depth]) - - -def pool_as_vector(images, scope=None): - """Reduce images to vectors by averaging all pixels.""" - with ops.name_scope(scope, "PoolAsVector", [images]): - return math_ops.reduce_mean(images, [1, 2]) - - -def one_hot_planes(labels, num_classes, scope=None): - """Compute 1-hot encodings for planes. - - Given a label, this computes a label image that contains - 1 at all pixels in the plane corresponding to the target - class and 0 in all other planes. - - Args: - labels: (batch_size,) tensor - num_classes: number of classes - scope: optional scope name - - Returns: - Tensor of shape (batch_size, 1, 1, num_classes) with a 1-hot encoding. - """ - with ops.name_scope(scope, "OneHotPlanes", [labels]): - batch_size, = _shape(labels) - batched = layers.one_hot_encoding(labels, num_classes) - return array_ops.reshape(batched, [batch_size, 1, 1, num_classes]) - - -def one_hot_mask(labels, num_classes, scope=None): - """Compute 1-hot encodings for masks. - - Given a label image, this computes the one hot encoding at - each pixel. - - Args: - labels: (batch_size, width, height, 1) tensor containing labels. - num_classes: number of classes - scope: optional scope name - - Returns: - Tensor of shape (batch_size, width, height, num_classes) with - a 1-hot encoding. - """ - with ops.name_scope(scope, "OneHotMask", [labels]): - height, width, depth = _shape(labels) - assert depth == 1 - sparse_labels = math_ops.to_int32(array_ops.reshape(labels, [-1, 1])) - sparse_size, _ = _shape(sparse_labels) - indices = array_ops.reshape(math_ops.range(0, sparse_size, 1), [-1, 1]) - concated = array_ops.concat([indices, sparse_labels], 1) - dense_result = sparse_ops.sparse_to_dense(concated, - [sparse_size, num_classes], 1.0, - 0.0) - result = array_ops.reshape(dense_result, [height, width, num_classes]) - return result diff --git a/tensorflow/contrib/ndlstm/python/misc_test.py b/tensorflow/contrib/ndlstm/python/misc_test.py deleted file mode 100644 index fac9023da3b23b89a5494358c6e7ad82c12f9bdf..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/ndlstm/python/misc_test.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Miscellaneous tests.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.ndlstm.python import misc as misc_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -misc = misc_lib - - -def _rand(*size): - return np.random.uniform(size=size).astype("f") - - -class LstmMiscTest(test_util.TensorFlowTestCase): - - def testPixelsAsVectorDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = misc.pixels_as_vector(inputs) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 7 * 11 * 5)) - - def testPoolAsVectorDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = misc.pool_as_vector(inputs) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 5)) - - def testOneHotPlanes(self): - with self.test_session(): - inputs = constant_op.constant([0, 1, 3]) - outputs = misc.one_hot_planes(inputs, 4) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (3, 1, 1, 4)) - target = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) - self.assertAllClose(result.reshape(-1), target.reshape(-1)) - - def testOneHotMask(self): - with self.test_session(): - data = np.array([[0, 1, 2], [2, 0, 1]]).reshape(2, 3, 1) - inputs = constant_op.constant(data) - outputs = misc.one_hot_mask(inputs, 3) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 3, 3)) - target = np.array([[[1, 0, 0], [0, 1, 0]], [[0, 1, 0], [0, 0, 1]], - [[0, 0, 1], [1, 0, 0]]]).transpose(1, 2, 0) - self.assertAllClose(result.reshape(-1), target.reshape(-1)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/nearest_neighbor/kernels/heap.h b/tensorflow/contrib/nearest_neighbor/kernels/heap.h index 32925569a82c43be75a0b6e93d7d781cda3d53f4..a2dbb8052bfa1634d27c8b38a9bb6ca27fae42a2 100644 --- a/tensorflow/contrib/nearest_neighbor/kernels/heap.h +++ b/tensorflow/contrib/nearest_neighbor/kernels/heap.h @@ -56,7 +56,7 @@ class HeapBase { // This method adds an element at the end of the internal array without // "heapifying" the array afterwards. This is useful for setting up a heap - // where a single call to heapify at the end of the inital insertion + // where a single call to heapify at the end of the initial insertion // operations suffices. void InsertUnsorted(const KeyType& key, const DataType& data) { if (v_.size() == static_cast(num_elements_)) { diff --git a/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc b/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc index 2b412fac9a621f01bd21c6b4391da3c462dd78b3..13db6f62f525b6318687e3bf4b6499eee2c61ea8 100644 --- a/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc +++ b/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc @@ -75,7 +75,8 @@ class HyperplaneLSHProbesOp : public OpKernel { num_hyperplanes_per_table, ".")); OP_REQUIRES(context, num_hyperplanes_per_table <= 30, InvalidArgument("Need num_hyperplanes_per_table <= 30, got ", - num_hyperplanes_per_table, ". " + num_hyperplanes_per_table, + ". " "If you need more hyperplanes, change this Op" " to work for larger integer types (int64).")); @@ -88,12 +89,13 @@ class HyperplaneLSHProbesOp : public OpKernel { InvalidArgument("num_probes must be at least 1.")); int expected_num_hyperplanes = num_tables * num_hyperplanes_per_table; - OP_REQUIRES( - context, products_tensor.dim_size(1) == expected_num_hyperplanes, - InvalidArgument("Expected number of hyperplanes is ", - expected_num_hyperplanes, " but received ", - products_tensor.dim_size(1), " inner products per " - "point.")); + OP_REQUIRES(context, + products_tensor.dim_size(1) == expected_num_hyperplanes, + InvalidArgument("Expected number of hyperplanes is ", + expected_num_hyperplanes, " but received ", + products_tensor.dim_size(1), + " inner products per " + "point.")); auto products_eigen_tensor = products_tensor.matrix(); ConstMatrixMap products_matrix(products_eigen_tensor.data(), @@ -116,13 +118,11 @@ class HyperplaneLSHProbesOp : public OpKernel { // lschmidt's workstation. int64 cost_per_unit = 21 * num_hyperplanes_per_table * num_tables; if (num_probes > num_tables) { - cost_per_unit += 110 * num_hyperplanes_per_table - * (num_probes - num_tables); + cost_per_unit += + 110 * num_hyperplanes_per_table * (num_probes - num_tables); } context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( - batch_size, - cost_per_unit, - [&](int64 start, int64 end) { + batch_size, cost_per_unit, [&](int64 start, int64 end) { HyperplaneMultiprobe multiprobe( num_hyperplanes_per_table, num_tables); diff --git a/tensorflow/contrib/nn/python/ops/alpha_dropout.py b/tensorflow/contrib/nn/python/ops/alpha_dropout.py index d7b61a584478f701726248a41c4992382189223d..2f92d05ba81f30a91f68f3c3ec51b6695d3d0371 100644 --- a/tensorflow/contrib/nn/python/ops/alpha_dropout.py +++ b/tensorflow/contrib/nn/python/ops/alpha_dropout.py @@ -18,7 +18,6 @@ from __future__ import print_function import numbers -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -26,7 +25,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_impl def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name diff --git a/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py b/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py index 2ff978ab89727c0ba2a8654013466838732377e4..54a98e6f142b7ba58c9418a8ac88269d38944aab 100644 --- a/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py +++ b/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py @@ -21,7 +21,6 @@ from __future__ import print_function from tensorflow.contrib.nn.python.ops.alpha_dropout import alpha_dropout from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import nn_impl diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 827279bd476f9666a972f43ad557fde6d0b6c59a..86ceda71b703a021fd3a71371b5d56ab82c42ee2 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -70,6 +70,7 @@ py_test( srcs = ["python/training/moving_average_optimizer_test.py"], srcs_version = "PY2AND3", tags = [ + "no_oss", # b/73507407 "notsan", # b/31055119 ], deps = [ diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 6132cba1f5aecbafd8ca820ecda39355dd768847..5763593b81497f5d6945ff1e5d000042d295c093 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Wrapper optimizer for Elastic Average SGD """ from __future__ import absolute_import from __future__ import division @@ -78,23 +77,24 @@ class ElasticAverageCustomGetter(object): def __call__(self, getter, name, trainable, collections, *args, **kwargs): if trainable: with ops.device(self._worker_device): - local_var = getter(name, trainable=True, - collections=[ops.GraphKeys.LOCAL_VARIABLES], - *args, **kwargs) + local_var = getter( + name, + trainable=True, + collections=[ops.GraphKeys.LOCAL_VARIABLES], + *args, + **kwargs) global_center_variable = variable_scope.variable( - name='%s/%s' % - (GLOBAL_VARIABLE_NAME, - name), - initial_value=local_var.initialized_value(), - trainable=False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + name='%s/%s' % (GLOBAL_VARIABLE_NAME, name), + initial_value=local_var.initialized_value(), + trainable=False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) with ops.device(self._worker_device): local_center_variable = variable_scope.variable( - name='%s/%s' % (LOCAL_VARIABLE_NAME, name), - initial_value=local_var.initialized_value(), - trainable=False, - collections=[ops.GraphKeys.LOCAL_VARIABLES]) + name='%s/%s' % (LOCAL_VARIABLE_NAME, name), + initial_value=local_var.initialized_value(), + trainable=False, + collections=[ops.GraphKeys.LOCAL_VARIABLES]) self._local_map[local_var] = local_center_variable self._global_map[local_var] = global_center_variable @@ -117,16 +117,15 @@ class ElasticAverageOptimizer(optimizer.Optimizer): # Default value as paper described BETA = 0.9 - def __init__( - self, - opt, - num_worker, - ea_custom_getter, - communication_period=10, - moving_rate=None, - rho=None, - use_locking=True, - name="ElasticAverageOptimizer"): + def __init__(self, + opt, + num_worker, + ea_custom_getter, + communication_period=10, + moving_rate=None, + rho=None, + use_locking=True, + name='ElasticAverageOptimizer'): """Construct a new gradient descent optimizer. Args: @@ -151,7 +150,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): self._global_map = ea_custom_getter._global_map if moving_rate is None: - self._moving_rate = BETA / communication_period / num_worker + self._moving_rate = self.BETA / communication_period / num_worker else: self._moving_rate = moving_rate if rho is None: @@ -160,13 +159,15 @@ class ElasticAverageOptimizer(optimizer.Optimizer): self._rho = rho self._local_step = variable_scope.get_variable( - initializer=0, - trainable=False, - collections=[ops.GraphKeys.LOCAL_VARIABLES], - name="local_step") + initializer=0, + trainable=False, + collections=[ops.GraphKeys.LOCAL_VARIABLES], + name='local_step') self._opt._prepare() - def compute_gradients(self, loss, var_list=None, + def compute_gradients(self, + loss, + var_list=None, gate_gradients=optimizer.Optimizer.GATE_OP, aggregation_method=None, colocate_gradients_with_ops=False, @@ -204,16 +205,18 @@ class ElasticAverageOptimizer(optimizer.Optimizer): if not var_list: var_list = variables.trainable_variables() - elastic_difference = [math_ops.subtract(v, lv) for v, lv in zip( - variables.trainable_variables(), - [self._local_map[var] for var in var_list])] + elastic_difference = [ + math_ops.subtract(v, lv) + for v, lv in zip(variables.trainable_variables(), + [self._local_map[var] for var in var_list]) + ] distance_loss = self._rho * math_ops.add_n( - [gen_nn_ops.l2_loss(ed) for ed in elastic_difference]) + [gen_nn_ops.l2_loss(ed) for ed in elastic_difference]) total_loss = loss + distance_loss - return self._opt.compute_gradients(total_loss, var_list, - gate_gradients, aggregation_method, + return self._opt.compute_gradients(total_loss, var_list, gate_gradients, + aggregation_method, colocate_gradients_with_ops, grad_loss) def apply_gradients(self, grads_and_vars, global_step=None, name=None): @@ -241,7 +244,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): apply_updates = self._opt.apply_gradients(grads_and_vars) with ops.control_dependencies([apply_updates]): local_update = state_ops.assign_add( - self._local_step, 1, name='local_step_update').op + self._local_step, 1, name='local_step_update').op # update global variables. def _Update_global_variables(): @@ -259,12 +262,16 @@ class ElasticAverageOptimizer(optimizer.Optimizer): differences.append(math_ops.subtract(v, lv)) for lvar, diff in zip(local_vars, differences): with ops.device(lvar.device): - update_ops.append(state_ops.assign_sub(lvar, math_ops.multiply( - self._moving_rate, diff))) + update_ops.append( + state_ops.assign_sub(lvar, + math_ops.multiply(self._moving_rate, + diff))) for var, diff in zip(global_center_vars, differences): with ops.device(var.device): - update_ops.append(state_ops.assign_add(var, math_ops.multiply( - self._moving_rate, diff))) + update_ops.append( + state_ops.assign_add(var, + math_ops.multiply(self._moving_rate, + diff))) if global_step: with ops.colocate_with(global_step): update_ops.append(state_ops.assign_add(global_step, 1)) @@ -272,10 +279,10 @@ class ElasticAverageOptimizer(optimizer.Optimizer): return variable_update with ops.control_dependencies([local_update]): - condition = math_ops.equal(math_ops.mod( - self._local_step, self._period), 0) + condition = math_ops.equal( + math_ops.mod(self._local_step, self._period), 0) conditional_update = control_flow_ops.cond( - condition, _Update_global_variables, control_flow_ops.no_op) + condition, _Update_global_variables, control_flow_ops.no_op) return conditional_update def get_init_op(self, task_index): @@ -285,10 +292,12 @@ class ElasticAverageOptimizer(optimizer.Optimizer): def _Add_sync_queues_and_barrier(enqueue_after_list): """Adds ops to enqueu on all worker queues""" sync_queues = [ - data_flow_ops.FIFOQueue(self._num_worker, [dtypes.bool], shapes=[[]], - shared_name='%s%s' % ( - 'variable_init_sync_queue', i)) for i in - range(self._num_worker)] + data_flow_ops.FIFOQueue( + self._num_worker, [dtypes.bool], + shapes=[[]], + shared_name='%s%s' % ('variable_init_sync_queue', i)) + for i in range(self._num_worker) + ] queue_ops = [] # For each other worker, add an entry in a queue token = constant_op.constant(False) @@ -299,7 +308,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): else: queue_ops.append(q.enqueue(token)) queue_ops.append( - sync_queues[task_index].dequeue_many(len(sync_queues) - 1)) + sync_queues[task_index].dequeue_many(len(sync_queues) - 1)) return control_flow_ops.group(*queue_ops) init_ops = [] @@ -307,11 +316,10 @@ class ElasticAverageOptimizer(optimizer.Optimizer): global_center_vars = [self._global_map[var] for var in local_vars] local_center_vars = [self._local_map[var] for var in local_vars] if not (local_vars and global_center_vars and local_center_vars): - raise ValueError( - 'The lists of local_variables, global_center_variables, ' - 'local_center_variables should not be empty ') - for lvar, gc_var, lc_var in zip( - local_vars, global_center_vars, local_center_vars): + raise ValueError('The lists of local_variables, global_center_variables, ' + 'local_center_variables should not be empty ') + for lvar, gc_var, lc_var in zip(local_vars, global_center_vars, + local_center_vars): init_ops.append(state_ops.assign(lvar, gc_var)) init_ops.append(state_ops.assign(lc_var, gc_var)) @@ -325,6 +333,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook): + def __init__(self, ea_optimizer, is_chief, task_index): """Creates hook to handle ElasticAverageOptimizer initialization ops. diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py index 446e91018d477d75116f6b78a2443ed79ed3b3ef..37539b959959b5cf1f7b2c8e8d2b6b05191565ad 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py @@ -38,20 +38,20 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"): worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] cluster_dict = { - "worker": ["localhost:%s" % port for port in worker_ports], - "ps": ["localhost:%s" % port for port in ps_ports] + "worker": ["localhost:%s" % port for port in worker_ports], + "ps": ["localhost:%s" % port for port in ps_ports] } cs = server_lib.ClusterSpec(cluster_dict) workers = [ - server_lib.Server( - cs, job_name="worker", protocol=protocol, task_index=ix, start=True) - for ix in range(num_workers) + server_lib.Server( + cs, job_name="worker", protocol=protocol, task_index=ix, start=True) + for ix in range(num_workers) ] ps_servers = [ - server_lib.Server( - cs, job_name="ps", protocol=protocol, task_index=ix, start=True) - for ix in range(num_ps) + server_lib.Server( + cs, job_name="ps", protocol=protocol, task_index=ix, start=True) + for ix in range(num_ps) ] return cluster_dict, workers, ps_servers @@ -68,15 +68,14 @@ def _get_workers(num_workers, period, workers, moving_rate): is_chief = (worker_id == 0) with graph.as_default(): worker_device = "/job:worker/task:%d/cpu:0" % (worker_id) - ea_coustom = ElasticAverageCustomGetter( - worker_device=worker_device) - with variable_scope.variable_scope('', - custom_getter=ea_coustom), ops.device( - device_setter.replica_device_setter(worker_device=worker_device, - ps_device="/job:ps/task:0/cpu:0", - ps_tasks=1)): - global_step = variables.Variable(0, name='global_step', - trainable=False) + ea_coustom = ElasticAverageCustomGetter(worker_device=worker_device) + with variable_scope.variable_scope( + "", custom_getter=ea_coustom), ops.device( + device_setter.replica_device_setter( + worker_device=worker_device, + ps_device="/job:ps/task:0/cpu:0", + ps_tasks=1)): + global_step = variables.Variable(0, name="global_step", trainable=False) var_0 = variable_scope.get_variable(initializer=0.0, name="v0") var_1 = variable_scope.get_variable(initializer=1.0, name="v1") @@ -86,21 +85,19 @@ def _get_workers(num_workers, period, workers, moving_rate): sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) opt = ElasticAverageOptimizer( - opt=sgd_opt, - num_worker=num_workers, - moving_rate=moving_rate, - communication_period=period, - ea_custom_getter=ea_coustom - ) + opt=sgd_opt, + num_worker=num_workers, + moving_rate=moving_rate, + communication_period=period, + ea_custom_getter=ea_coustom) train_op = [ - opt.apply_gradients( - ([grads_0, var_0], - [grads_1, var_1]), global_step) + opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]), + global_step) ] easgd_hook = opt.make_session_run_hook(is_chief, worker_id) # Creates MonitoredSession - sess = training.MonitoredTrainingSession(workers[worker_id].target, - hooks=[easgd_hook]) + sess = training.MonitoredTrainingSession( + workers[worker_id].target, hooks=[easgd_hook]) sessions.append(sess) graphs.append(graph) @@ -110,6 +107,7 @@ def _get_workers(num_workers, period, workers, moving_rate): class ElasticAverageOptimizerTest(test.TestCase): + def _run(self, train_op, sess): sess.run(train_op) @@ -117,15 +115,14 @@ class ElasticAverageOptimizerTest(test.TestCase): num_workers = 1 communication_period = 2 num_ps = 1 - cluster, workers, _ = create_local_cluster(num_workers=num_workers, - num_ps=num_ps) + cluster, workers, _ = create_local_cluster( + num_workers=num_workers, num_ps=num_ps) - sessions, graphs, train_ops = _get_workers(num_workers, - communication_period, - workers, 1.0) + sessions, graphs, train_ops = _get_workers( + num_workers, communication_period, workers, 1.0) - var_0 = graphs[0].get_tensor_by_name('v0:0') - var_1 = graphs[0].get_tensor_by_name('v1:0') + var_0 = graphs[0].get_tensor_by_name("v0:0") + var_1 = graphs[0].get_tensor_by_name("v1:0") global_step = training_util.get_global_step(graphs[0]) var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0") var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0") @@ -166,18 +163,17 @@ class ElasticAverageOptimizerTest(test.TestCase): num_workers = 2 communication_period = 1 num_ps = 2 - cluster, workers, _ = create_local_cluster(num_workers=num_workers, - num_ps=num_ps) + cluster, workers, _ = create_local_cluster( + num_workers=num_workers, num_ps=num_ps) - sessions, graphs, train_ops = _get_workers(num_workers, - communication_period, - workers, 0.5) + sessions, graphs, train_ops = _get_workers( + num_workers, communication_period, workers, 0.5) - var_0 = graphs[0].get_tensor_by_name('v0:0') - var_1 = graphs[0].get_tensor_by_name('v1:0') + var_0 = graphs[0].get_tensor_by_name("v0:0") + var_1 = graphs[0].get_tensor_by_name("v1:0") - var_0_1 = graphs[1].get_tensor_by_name('v0:0') - var_1_1 = graphs[1].get_tensor_by_name('v1:0') + var_0_1 = graphs[1].get_tensor_by_name("v0:0") + var_1_1 = graphs[1].get_tensor_by_name("v1:0") var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0") var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0") @@ -201,25 +197,24 @@ class ElasticAverageOptimizerTest(test.TestCase): def testPS2TasksWithClusterSpecClass(self): cluster_spec = server_lib.ClusterSpec({ - "ps": ["ps0:2222", "ps1:2222"], - "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] }) - ea_coustom = ElasticAverageCustomGetter( - worker_device="/job:worker/task:0") + ea_coustom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0") from tensorflow.python.training import device_setter with ops.device( device_setter.replica_device_setter(cluster=cluster_spec, worker_device="/job:worker/task:0", ps_device="/job:ps")), \ - variable_scope.variable_scope('', custom_getter=ea_coustom): + variable_scope.variable_scope("", custom_getter=ea_coustom): v = variable_scope.get_variable(initializer=[1, 2], name="v") - w = variable_scope.get_variable(initializer=[2, 1], name='w') - v_g, w_g = ea_coustom._global_map[v],ea_coustom._global_map[w] + w = variable_scope.get_variable(initializer=[2, 1], name="w") + v_g, w_g = ea_coustom._global_map[v], ea_coustom._global_map[w] self.assertDeviceEqual("/job:worker/task:0", v.device) self.assertDeviceEqual("job:ps/task:0", v_g.device) self.assertDeviceEqual("/job:worker/task:0", w.device) self.assertDeviceEqual("job:ps/task:1", w_g.device) -if __name__ == '__main__': +if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/opt/python/training/external_optimizer.py b/tensorflow/contrib/opt/python/training/external_optimizer.py index f243317f1df2ec8d93d44ad534f3fa58527f3217..82ebca7f20306e5658c8321716e39f9c7f8b8970 100644 --- a/tensorflow/contrib/opt/python/training/external_optimizer.py +++ b/tensorflow/contrib/opt/python/training/external_optimizer.py @@ -397,10 +397,6 @@ class ScipyOptimizerInterface(ExternalOptimizerInterface): 'automatically and cannot be injected manually'.format(kwarg)) minimize_kwargs.update(optimizer_kwargs) - if method == 'SLSQP': - # SLSQP doesn't support step callbacks. Obviate associated warning - # message. - del minimize_kwargs['callback'] import scipy.optimize # pylint: disable=g-import-not-at-top result = scipy.optimize.minimize(*minimize_args, **minimize_kwargs) diff --git a/tensorflow/contrib/opt/python/training/external_optimizer_test.py b/tensorflow/contrib/opt/python/training/external_optimizer_test.py index 0f597d0a246a53892d72939edd1499a86c01017d..953586ee70cd4137295dd254bfb2d37cab0bcfe4 100644 --- a/tensorflow/contrib/opt/python/training/external_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/external_optimizer_test.py @@ -299,6 +299,45 @@ class ScipyOptimizerInterfaceTest(TestCase): method = optimizer.optimizer_kwargs.get('method') self.assertEqual('SLSQP', method) + def test_callbacks(self): + vector_val = np.array([7., -2.], dtype=np.float32) + vector = variables.Variable(vector_val, 'vector') + + minimum_location_val = np.arange(2) + minimum_location = constant_op.constant( + minimum_location_val, dtype=dtypes.float32) + + loss = math_ops.reduce_sum(math_ops.square(vector - minimum_location)) / 2. + loss_val_first = ((vector_val - minimum_location_val)**2).sum() / 2. + + optimizer = external_optimizer.ScipyOptimizerInterface(loss, method='SLSQP') + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + + initial_vector_val = sess.run(vector) + + extra_fetches = [loss] + + step_callback = test.mock.Mock() + loss_callback = test.mock.Mock() + + optimizer.minimize( + sess, + fetches=extra_fetches, + loss_callback=loss_callback, + step_callback=step_callback) + + loss_val_last = sess.run(loss) + + call_first = test.mock.call(loss_val_first) + call_last = test.mock.call(loss_val_last) + loss_calls = [call_first, call_last] + loss_callback.assert_has_calls(loss_calls, any_order=True) + + args, _ = step_callback.call_args + self.assertAllClose(minimum_location_val, args[0]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index d68ad23d65500cc2348459cdc53030c2ea08373a..9ce50bfe1054072b315adecb87f1ba729dfe0d83 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -83,7 +83,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): self._optimizer = opt self._ema = moving_averages.ExponentialMovingAverage( average_decay, num_updates=num_updates) - self._variable_map = None + self._swapped_variable_name_map = None self._sequential_update = sequential_update def compute_gradients(self, *args, **kwargs): @@ -93,7 +93,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): train_op = self._optimizer.apply_gradients( grads_and_vars, global_step=global_step, name=name) var_list = [x[1] for x in grads_and_vars if x[0] is not None] - self._variable_map = {} + self._swapped_variable_name_map = {} if self._sequential_update: with ops.control_dependencies([train_op]): ma_op = self._ema.apply(var_list) @@ -102,9 +102,9 @@ class MovingAverageOptimizer(optimizer.Optimizer): for v in var_list: v_avg = self._ema.average(v) - self._variable_map[v.op.name] = v_avg - self._variable_map[v_avg.op.name] = v - return control_flow_ops.group(train_op, ma_op, name="train_with_avg") + self._swapped_variable_name_map[v.op.name] = v_avg.op.name + self._swapped_variable_name_map[v_avg.op.name] = v.op.name + return control_flow_ops.group(train_op, ma_op, name='train_with_avg') def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs): """Create a saver swapping moving averages and variables. @@ -129,22 +129,45 @@ class MovingAverageOptimizer(optimizer.Optimizer): Raises: RuntimeError: If apply_gradients or minimize has not been called before. + ValueError: If var_list is provided and contains some variables but not + their moving average counterpart. """ - if self._variable_map is None: + if self._swapped_variable_name_map is None: raise RuntimeError('Must call apply_gradients or minimize before ' 'creating the swapping_saver') if var_list is None: var_list = variables.global_variables() if not isinstance(var_list, dict): var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + + # OpListToDict converts variables to tensors. We make sure we can get + # the unique variable name for normal and resource vaiables. + def get_v_name(tensor): + if tensor.op.type == 'ReadVariableOp': + return tensor.op.inputs[0].op.name + else: + return tensor.op.name + + v_name_to_tensor = {} + for tensor in six.itervalues(var_list): + v_name = get_v_name(tensor) + v_name_to_tensor[v_name] = tensor + # Now swap variables and moving averages swapped_var_list = {} - for k, v in six.iteritems(var_list): - v_swap = self._variable_map.get(v.op.name, None) - if v_swap: - swapped_var_list[k] = v_swap - else: - swapped_var_list[k] = v + for k, tensor in six.iteritems(var_list): + v_name = get_v_name(tensor) + swapped_v_name = self._swapped_variable_name_map.get(v_name, None) + tensor_to_save = tensor + if swapped_v_name is not None: + if swapped_v_name in v_name_to_tensor: + tensor_to_save = v_name_to_tensor[swapped_v_name] + else: + raise ValueError( + ('Variable to swap %s is not part of variables to save. ' + 'This breaks MovingAverageOptimizer.') % swapped_v_name) + swapped_var_list[k] = tensor_to_save + # Build the swapping saver. return saver.Saver(swapped_var_list, name=name, **kwargs) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py index 60929add198f2e69b5acc2eb5516dafc82b1f3ba..85e3e8d3791f2331ed249c0b7f67a3dbde4fca08 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py @@ -24,6 +24,10 @@ import six from tensorflow.contrib.opt.python.training import moving_average_optimizer from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent @@ -33,13 +37,26 @@ from tensorflow.python.training import saver class MovingAverageOptimizerTest(test.TestCase): def testRun(self): + self._helpTestRun(use_resource=False) + + def testRunUseResource(self): + # Test that MovingAverageOptimizer works with resource variables. + self._helpTestRun(use_resource=True) + + def _helpTestRun(self, use_resource=False): for sequential_update in [True, False]: for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session() as sess: + with self.test_session(graph=ops.Graph()) as sess: orig_val0 = [1.0, 2.0] orig_val1 = [3.0, 4.0] - var0 = variables.Variable(orig_val0, name='var0', dtype=dtype) - var1 = variables.Variable(orig_val1, name='var1', dtype=dtype) + var0 = variable_scope.get_variable( + 'var0', + initializer=constant_op.constant(orig_val0, dtype=dtype), + use_resource=use_resource) + var1 = variable_scope.get_variable( + 'var1', + initializer=constant_op.constant(orig_val1, dtype=dtype), + use_resource=use_resource) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) @@ -52,22 +69,63 @@ class MovingAverageOptimizerTest(test.TestCase): save_path = os.path.join(save_dir, 'model') update = opt.apply_gradients( list(six.moves.zip([grads0, grads1], [var0, var1]))) + global_vars = variables.global_variables() + ema_var0 = [ + v for v in global_vars + if v.op.name == 'var0/ExponentialMovingAverage' + ][0] + ema_var1 = [ + v for v in global_vars + if v.op.name == 'var1/ExponentialMovingAverage' + ][0] + perturb = control_flow_ops.group([ + state_ops.assign_add(var0, [1.0, 1.0]), + state_ops.assign_add(var1, [2.0, 2.0]), + state_ops.assign_add(ema_var0, [3.0, 3.0]), + state_ops.assign_add(ema_var1, [4.0, 4.0]) + ]) + + # Test taht saver with missing ema variables will fail. + with self.assertRaisesRegexp(ValueError, r'Variable to swap'): + opt.swapping_saver(var_list=[var0]) + train_saver = opt.swapping_saver() + train_saver_subset = opt.swapping_saver(var_list=[var0, ema_var0]) inference_saver = saver.Saver() variables.global_variables_initializer().run() # Step 1. update.run() - val0 = var0.eval() - val1 = var1.eval() self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) + if sequential_update: + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) # Test that the swapping saver save/restore operation is identity. train_saver.save(sess, save_path) train_saver.restore(sess, save_path) - val0 = var0.eval() - val1 = var1.eval() self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) + if sequential_update: + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) + # Test that the subset saver saves the EMA variable as well. + if sequential_update: + subset_save_path = save_path + '_subset' + train_saver_subset.save(sess, subset_save_path) + perturb.run() + self.assertAllCloseAccordingToType([1.8, 2.8], var0.eval()) + self.assertAllCloseAccordingToType([3.9, 4.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) + self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) + # Restoring should only restore var0 and ema_var0. + train_saver_subset.restore(sess, subset_save_path) + self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) + self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) + # Restore back to previou state. + train_saver.restore(sess, save_path) + # If updates are parallel, this is not always true after the 1st step. if sequential_update: # Test that the normal saver will have the averaged variables. diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py index b0a257d264f83ae0a54cdc0e9265d6e7098b7b56..825c08a09a05894df1656a9bb6981f1862195244 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py @@ -21,12 +21,9 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.opt.python.training import nadam_optimizer -from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc index 9cee405cef25f54fd064f8002265c42016c4fa50..e18923c8aae74c66ce78f98eb5e615e99463af74 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc @@ -14,13 +14,12 @@ // limitations under the License. // ============================================================================= -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h" +#include "tensorflow/core/framework/register_types.h" namespace tensorflow { -REGISTER_KERNEL_BUILDER(Name("PeriodicResample") - .Device(DEVICE_CPU), +REGISTER_KERNEL_BUILDER(Name("PeriodicResample").Device(DEVICE_CPU), PeriodicResampleOp); } // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h index ba410f025d497178cfc1666ae231e75bad55b05e..3ab588c45881c8f93b4c1bcdf7ccde39086a1ed7 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h @@ -118,9 +118,9 @@ template #include -#include #include #include #include #include #include #include -#include -#include #include +#include +#include +#include #include #include "tensorflow/core/framework/graph.pb.h" @@ -46,10 +46,10 @@ limitations under the License. // These are all common classes it's handy to reference with no namespace. using tensorflow::Flag; -using tensorflow::Tensor; +using tensorflow::int32; using tensorflow::Status; using tensorflow::string; -using tensorflow::int32; +using tensorflow::Tensor; // Used to store the memory-mapped buffers we use for capture. struct CameraBuffer { diff --git a/tensorflow/contrib/pi_examples/label_image/label_image.cc b/tensorflow/contrib/pi_examples/label_image/label_image.cc index 0b18045789f3a87ceb228033407d6b696bdb33f6..c6935a093f728353caeeb79a9ed85c957d87f066 100644 --- a/tensorflow/contrib/pi_examples/label_image/label_image.cc +++ b/tensorflow/contrib/pi_examples/label_image/label_image.cc @@ -23,9 +23,9 @@ limitations under the License. // // Full build instructions are at tensorflow/contrib/pi_examples/README.md. -#include #include #include +#include #include #include @@ -46,10 +46,10 @@ limitations under the License. // These are all common classes it's handy to reference with no namespace. using tensorflow::Flag; -using tensorflow::Tensor; +using tensorflow::int32; using tensorflow::Status; using tensorflow::string; -using tensorflow::int32; +using tensorflow::Tensor; // Takes a file name, and loads a list of labels from it, one per line, and // returns a vector of the strings. It pads with empty strings so the length @@ -77,23 +77,22 @@ Status ReadLabelsFile(string file_name, std::vector* result, // Error handling for JPEG decoding. void CatchError(j_common_ptr cinfo) { (*cinfo->err->output_message)(cinfo); - jmp_buf *jpeg_jmpbuf = reinterpret_cast(cinfo->client_data); + jmp_buf* jpeg_jmpbuf = reinterpret_cast(cinfo->client_data); jpeg_destroy(cinfo); longjmp(*jpeg_jmpbuf, 1); } // Decompresses a JPEG file from disk. Status LoadJpegFile(string file_name, std::vector* data, - int* width, int* height, int* channels) { + int* width, int* height, int* channels) { struct jpeg_decompress_struct cinfo; - FILE * infile; + FILE* infile; JSAMPARRAY buffer; int row_stride; if ((infile = fopen(file_name.c_str(), "rb")) == NULL) { LOG(ERROR) << "Can't open " << file_name; - return tensorflow::errors::NotFound("JPEG file ", file_name, - " not found"); + return tensorflow::errors::NotFound("JPEG file ", file_name, " not found"); } struct jpeg_error_mgr jerr; @@ -116,10 +115,11 @@ Status LoadJpegFile(string file_name, std::vector* data, data->resize((*height) * (*width) * (*channels)); row_stride = cinfo.output_width * cinfo.output_components; - buffer = (*cinfo.mem->alloc_sarray) - ((j_common_ptr) &cinfo, JPOOL_IMAGE, row_stride, 1); + buffer = (*cinfo.mem->alloc_sarray)((j_common_ptr)&cinfo, JPOOL_IMAGE, + row_stride, 1); while (cinfo.output_scanline < cinfo.output_height) { - tensorflow::uint8* row_address = &((*data)[cinfo.output_scanline * row_stride]); + tensorflow::uint8* row_address = + &((*data)[cinfo.output_scanline * row_stride]); jpeg_read_scanlines(&cinfo, buffer, 1); memcpy(row_address, buffer[0], row_stride); } @@ -141,24 +141,25 @@ Status ReadTensorFromImageFile(string file_name, const int wanted_height, int image_height; int image_channels; TF_RETURN_IF_ERROR(LoadJpegFile(file_name, &image_data, &image_width, - &image_height, &image_channels)); - LOG(INFO) << "Loaded JPEG: " << image_width << "x" << image_height - << "x" << image_channels; + &image_height, &image_channels)); + LOG(INFO) << "Loaded JPEG: " << image_width << "x" << image_height << "x" + << image_channels; const int wanted_channels = 3; if (image_channels < wanted_channels) { - return tensorflow::errors::FailedPrecondition("Image needs to have at least ", - wanted_channels, " but only has ", - image_channels); + return tensorflow::errors::FailedPrecondition( + "Image needs to have at least ", wanted_channels, " but only has ", + image_channels); } - // In these loops, we convert the eight-bit data in the image into float, resize - // it using bilinear filtering, and scale it numerically to the float range that - // the model expects (given by input_mean and input_std). + // In these loops, we convert the eight-bit data in the image into float, + // resize it using bilinear filtering, and scale it numerically to the float + // range that the model expects (given by input_mean and input_std). tensorflow::Tensor image_tensor( - tensorflow::DT_FLOAT, tensorflow::TensorShape( - {1, wanted_height, wanted_width, wanted_channels})); + tensorflow::DT_FLOAT, + tensorflow::TensorShape( + {1, wanted_height, wanted_width, wanted_channels})); auto image_tensor_mapped = image_tensor.tensor(); tensorflow::uint8* in = image_data.data(); - float *out = image_tensor_mapped.data(); + float* out = image_tensor_mapped.data(); const size_t image_rowlen = image_width * image_channels; const float width_scale = static_cast(image_width) / wanted_width; const float height_scale = static_cast(image_height) / wanted_height; @@ -166,35 +167,37 @@ Status ReadTensorFromImageFile(string file_name, const int wanted_height, const float in_y = y * height_scale; const int top_y_index = static_cast(floorf(in_y)); const int bottom_y_index = - std::min(static_cast(ceilf(in_y)), (image_height - 1)); + std::min(static_cast(ceilf(in_y)), (image_height - 1)); const float y_lerp = in_y - top_y_index; tensorflow::uint8* in_top_row = in + (top_y_index * image_rowlen); tensorflow::uint8* in_bottom_row = in + (bottom_y_index * image_rowlen); - float *out_row = out + (y * wanted_width * wanted_channels); + float* out_row = out + (y * wanted_width * wanted_channels); for (int x = 0; x < wanted_width; ++x) { const float in_x = x * width_scale; const int left_x_index = static_cast(floorf(in_x)); const int right_x_index = - std::min(static_cast(ceilf(in_x)), (image_width - 1)); + std::min(static_cast(ceilf(in_x)), (image_width - 1)); tensorflow::uint8* in_top_left_pixel = - in_top_row + (left_x_index * wanted_channels); + in_top_row + (left_x_index * wanted_channels); tensorflow::uint8* in_top_right_pixel = - in_top_row + (right_x_index * wanted_channels); + in_top_row + (right_x_index * wanted_channels); tensorflow::uint8* in_bottom_left_pixel = - in_bottom_row + (left_x_index * wanted_channels); + in_bottom_row + (left_x_index * wanted_channels); tensorflow::uint8* in_bottom_right_pixel = - in_bottom_row + (right_x_index * wanted_channels); + in_bottom_row + (right_x_index * wanted_channels); const float x_lerp = in_x - left_x_index; - float *out_pixel = out_row + (x * wanted_channels); + float* out_pixel = out_row + (x * wanted_channels); for (int c = 0; c < wanted_channels; ++c) { - const float top_left((in_top_left_pixel[c] - input_mean) / input_std); - const float top_right((in_top_right_pixel[c] - input_mean) / input_std); - const float bottom_left((in_bottom_left_pixel[c] - input_mean) / input_std); - const float bottom_right((in_bottom_right_pixel[c] - input_mean) / input_std); - const float top = top_left + (top_right - top_left) * x_lerp; - const float bottom = - bottom_left + (bottom_right - bottom_left) * x_lerp; - out_pixel[c] = top + (bottom - top) * y_lerp; + const float top_left((in_top_left_pixel[c] - input_mean) / input_std); + const float top_right((in_top_right_pixel[c] - input_mean) / input_std); + const float bottom_left((in_bottom_left_pixel[c] - input_mean) / + input_std); + const float bottom_right((in_bottom_right_pixel[c] - input_mean) / + input_std); + const float top = top_left + (top_right - top_left) * x_lerp; + const float bottom = + bottom_left + (bottom_right - bottom_left) * x_lerp; + out_pixel[c] = top + (bottom - top) * y_lerp; } } } @@ -233,10 +236,10 @@ Status GetTopLabels(const std::vector& outputs, int how_many_labels, scores.push_back(std::pair({i, unsorted_scores_flat(i)})); } std::sort(scores.begin(), scores.end(), - [](const std::pair &left, - const std::pair &right) { - return left.second > right.second; - }); + [](const std::pair& left, + const std::pair& right) { + return left.second > right.second; + }); scores.resize(how_many_labels); Tensor sorted_indices(tensorflow::DT_INT32, {scores.size()}); Tensor sorted_scores(tensorflow::DT_FLOAT, {scores.size()}); diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py index e8443e718d1e81a88b752eb639dcee9c89aa56dc..578d9424b25dd38f1d77a267d1fdf1ff9ff2da88 100644 --- a/tensorflow/contrib/predictor/predictor_factories_test.py +++ b/tensorflow/contrib/predictor/predictor_factories_test.py @@ -50,8 +50,8 @@ class PredictorFactoriesTest(test.TestCase): def testFromContribEstimator(self): estimator = testing_common.get_arithmetic_estimator(core=False) input_fn = testing_common.get_arithmetic_input_fn(core=False) - predictor_factories.from_contrib_estimator(estimator, input_fn, - output_alternative_key='sum') + predictor_factories.from_contrib_estimator( + estimator, input_fn, output_alternative_key='sum') def testFromContribEstimatorWithCoreEstimatorRaises(self): estimator = testing_common.get_arithmetic_estimator(core=True) diff --git a/tensorflow/contrib/py2tf/BUILD b/tensorflow/contrib/py2tf/BUILD index d395de986d2364f1f6567e1ecbf0a873cbb0aa8c..d91220f6ddb859ff52d4e5853948cb667981009b 100644 --- a/tensorflow/contrib/py2tf/BUILD +++ b/tensorflow/contrib/py2tf/BUILD @@ -18,66 +18,14 @@ py_library( name = "py2tf", srcs = [ "__init__.py", - "api.py", - "config.py", - "conversion.py", - "naming.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/converters", + "//tensorflow/contrib/py2tf/impl", "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", + "//tensorflow/contrib/py2tf/utils", "@gast_archive//:gast", "@six_archive//:six", ], ) - -# Separate target that allows access to internal symbols for testing. -py_library( - name = "py2tf_internal", - srcs = [ - "api.py", - "config.py", - "conversion.py", - "naming.py", - ], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:__subpackages__"], - deps = [ - "//tensorflow/contrib/py2tf/converters", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", - "@gast_archive//:gast", - "@six_archive//:six", - ], -) - -py_test( - name = "api_test", - srcs = ["api_test.py"], - deps = [ - ":py2tf_internal", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "conversion_test", - srcs = ["conversion_test.py"], - deps = [ - ":py2tf_internal", - "//tensorflow/python:client_testlib", - "@gast_archive//:gast", - ], -) - -py_test( - name = "naming_test", - srcs = ["naming_test.py"], - deps = [ - ":py2tf_internal", - "//tensorflow/python:client_testlib", - ], -) diff --git a/tensorflow/contrib/py2tf/__init__.py b/tensorflow/contrib/py2tf/__init__.py index d187da99e065cb2d31ae4e45a9570378f9d1bf27..379fa7fd5c2a22b5b16a21cca8c2ea8afdcaeefa 100644 --- a/tensorflow/contrib/py2tf/__init__.py +++ b/tensorflow/contrib/py2tf/__init__.py @@ -21,11 +21,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.api import to_code -from tensorflow.contrib.py2tf.api import to_graph +from tensorflow.contrib.py2tf import utils +from tensorflow.contrib.py2tf.impl.api import convert +from tensorflow.contrib.py2tf.impl.api import graph_ready +from tensorflow.contrib.py2tf.impl.api import to_code +from tensorflow.contrib.py2tf.impl.api import to_graph +from tensorflow.contrib.py2tf.pyct.transformer import PyFlowParseError from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ['to_graph', 'to_code'] +_allowed_symbols = [ + 'to_graph', 'to_code', 'convert', 'graph_ready', 'utils', 'PyFlowParseError' +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/py2tf/converters/BUILD b/tensorflow/contrib/py2tf/converters/BUILD index 2b0a1234e6934c8a0ee73316a2fb7bfdb991f7e9..e9a96ec8d1dfc01ff6bc3b1fcaaef8e9b71a14a8 100644 --- a/tensorflow/contrib/py2tf/converters/BUILD +++ b/tensorflow/contrib/py2tf/converters/BUILD @@ -17,15 +17,16 @@ filegroup( py_library( name = "converters", srcs = [ - "break_canonicalization.py", + "asserts.py", + "break_statements.py", "builtin_functions.py", "call_trees.py", - "continue_canonicalization.py", + "continue_statements.py", "control_flow.py", "decorators.py", - "for_canonicalization.py", + "for_loops.py", + "list_comprehension.py", "logical_expressions.py", - "print_functions.py", "side_effect_guards.py", ], srcs_version = "PY2AND3", @@ -45,13 +46,38 @@ py_library( deps = [ ":converters", "//tensorflow/contrib/py2tf/pyct/static_analysis", + "//tensorflow/contrib/py2tf/utils", "@gast_archive//:gast", + "@six_archive//:six", ], ) py_test( - name = "break_canonicalization_test", - srcs = ["break_canonicalization_test.py"], + name = "asserts_test", + srcs = ["asserts_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "break_statements_test", + srcs = ["break_statements_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "builtin_functions_test", + srcs = ["builtin_functions_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", @@ -62,6 +88,7 @@ py_test( py_test( name = "call_trees_test", srcs = ["call_trees_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", @@ -70,8 +97,9 @@ py_test( ) py_test( - name = "continue_canonicalization_test", - srcs = ["continue_canonicalization_test.py"], + name = "continue_statements_test", + srcs = ["continue_statements_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", @@ -82,6 +110,7 @@ py_test( py_test( name = "control_flow_test", srcs = ["control_flow_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", @@ -90,8 +119,9 @@ py_test( ) py_test( - name = "builtin_functions_test", - srcs = ["builtin_functions_test.py"], + name = "decorators_test", + srcs = ["decorators_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", @@ -100,8 +130,9 @@ py_test( ) py_test( - name = "for_canonicalization_test", - srcs = ["for_canonicalization_test.py"], + name = "for_loops_test", + srcs = ["for_loops_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", @@ -110,8 +141,9 @@ py_test( ) py_test( - name = "logical_expressions_test", - srcs = ["logical_expressions_test.py"], + name = "list_comprehension_test", + srcs = ["list_comprehension_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", @@ -120,19 +152,20 @@ py_test( ) py_test( - name = "print_functions_test", - srcs = ["print_functions_test.py"], + name = "logical_expressions_test", + srcs = ["logical_expressions_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", - "@gast_archive//:gast", ], ) py_test( name = "side_effect_guards_test", srcs = ["side_effect_guards_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", diff --git a/tensorflow/contrib/py2tf/converters/print_functions.py b/tensorflow/contrib/py2tf/converters/asserts.py similarity index 54% rename from tensorflow/contrib/py2tf/converters/print_functions.py rename to tensorflow/contrib/py2tf/converters/asserts.py index 5da738c4954fb628212562b73641e1fc27032168..5b9b8e772bed82df2429fd6cb94dbf7b565e22b3 100644 --- a/tensorflow/contrib/py2tf/converters/print_functions.py +++ b/tensorflow/contrib/py2tf/converters/asserts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Compatibility support. Converts Print nodes to function calls.""" +"""Converts Assert statements to their corresponding TF calls.""" from __future__ import absolute_import from __future__ import division @@ -20,32 +20,34 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer -class PrintFunctionTransformer(gast.NodeTransformer): +class AssertsTransformer(transformer.Base): """Transforms Print nodes to Call so they can be handled as functions.""" # pylint:disable=invalid-name - def visit_Print(self, node): + def visit_Assert(self, node): self.generic_visit(node) - for n in node.values: - n.ctx = gast.Param() - call_node = gast.Call( - func=gast.Name('print', gast.Load(), None), - args=node.values, - keywords=[]) - anno.setanno(call_node.func, 'live_val', print) - anno.setanno(call_node.func, 'fqn', 'print') - anno.setanno(call_node, 'args_scope', anno.getanno(node, 'args_scope')) - node = gast.Expr(call_node) - return node + + # Note: The lone tf.Assert call will be wrapped with control_dependencies + # by side_effect_guards. + template = """ + tf.Assert(test, [msg]) + """ + + if node.msg is None: + return templates.replace( + template, test=node.test, msg=gast.Str('Assertion error')) + elif isinstance(node.msg, gast.Str): + return templates.replace(template, test=node.test, msg=node.msg) + else: + raise NotImplementedError('Can only convert string messages for now.') # pylint:enable=invalid-name -def transform(node): - transformer = PrintFunctionTransformer() - node = transformer.visit(node) - return node +def transform(node, context): + return AssertsTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/print_functions_test.py b/tensorflow/contrib/py2tf/converters/asserts_test.py similarity index 73% rename from tensorflow/contrib/py2tf/converters/print_functions_test.py rename to tensorflow/contrib/py2tf/converters/asserts_test.py index 475196ce102955b350acf9bf94255997f875f62c..6611f2777a93a7e819c8becfa06a09b27f4e6aaf 100644 --- a/tensorflow/contrib/py2tf/converters/print_functions_test.py +++ b/tensorflow/contrib/py2tf/converters/asserts_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for print_functions module.""" +"""Tests for asserts module.""" from __future__ import absolute_import from __future__ import division @@ -20,24 +20,21 @@ from __future__ import print_function import gast +from tensorflow.contrib.py2tf.converters import asserts from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import print_functions -from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.python.platform import test -class PrintFunctionsTest(converter_test_base.TestCase): +class AssertsTest(converter_test_base.TestCase): def test_transform(self): def test_fn(a): - print(a) + assert a > 0 - node = self.parse_and_analyze(test_fn, {'print': print}) - node = print_functions.transform(node) - result = compiler.ast_to_object(node) + node = self.parse_and_analyze(test_fn, {}) + node = asserts.transform(node, self.ctx) - result.test_fn('a') self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call)) diff --git a/tensorflow/contrib/py2tf/converters/break_canonicalization.py b/tensorflow/contrib/py2tf/converters/break_statements.py similarity index 73% rename from tensorflow/contrib/py2tf/converters/break_canonicalization.py rename to tensorflow/contrib/py2tf/converters/break_statements.py index ef585734454db1aa1ffdb798d93978fb09752f05..bfb709c5e32c6f19dc0fd109df61ece925d701a3 100644 --- a/tensorflow/contrib/py2tf/converters/break_canonicalization.py +++ b/tensorflow/contrib/py2tf/converters/break_statements.py @@ -22,42 +22,38 @@ import gast from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno -class BreakCanonicalizationTransformer(gast.NodeTransformer): +class BreakCanonicalizationTransformer(transformer.Base): """Canonicalizes continue statements into additional conditionals.""" - def __init__(self, namer): - self.namer = namer + def __init__(self, context): + super(BreakCanonicalizationTransformer, self).__init__(context) # This is a stack structure, to correctly process nested loops. self.break_uses = [] def _create_break_check(self): - - def template(var_name): - (not var_name) # pylint:disable=pointless-statement - - expr, = templates.replace( - template, var_name=gast.Name(self.break_uses[-1][1], None, None)) + template = """ + (not var_name) + """ + expr, = templates.replace(template, var_name=self.break_uses[-1][1]) return expr.value def _create_break_trigger(self): - - def template(var_name): # pylint:disable=unused-argument + template = """ var_name = True - - block = templates.replace( - template, var_name=gast.Name(self.break_uses[-1][1], None, None)) + """ + block = templates.replace(template, var_name=self.break_uses[-1][1]) block.append(gast.Continue()) return block def _create_break_init(self): - - def template(var_name): # pylint:disable=unused-argument + template = """ var_name = False - - assign, = templates.replace( - template, var_name=gast.Name(self.break_uses[-1][1], None, None)) + """ + assign, = templates.replace(template, var_name=self.break_uses[-1][1]) return assign # TODO(mdan): Surely the transformer supports this better? @@ -73,9 +69,10 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer): def visit_While(self, node): self.generic_visit(node.test) - scope = anno.getanno(node, 'body_scope') + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.namer.new_symbol('break_requested', scope.referenced) + break_var = self.context.namer.new_symbol('break_requested', + scope.referenced) self.break_uses.append([False, break_var]) node.body = self._manual_visit_list(node.body) if self.break_uses[-1][0]: @@ -95,9 +92,10 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer): def visit_For(self, node): self.generic_visit(node.target) self.generic_visit(node.iter) - scope = anno.getanno(node, 'body_scope') + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.namer.new_symbol('break_requested', scope.referenced) + break_var = self.context.namer.new_symbol('break_requested', + scope.referenced) self.break_uses.append([False, break_var]) node.body = self._manual_visit_list(node.body) if self.break_uses[-1][0]: @@ -118,7 +116,5 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer): return self._create_break_trigger() -def transform(node, namer): - transformer = BreakCanonicalizationTransformer(namer) - node = transformer.visit(node) - return node +def transform(node, context): + return BreakCanonicalizationTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/break_statements_test.py similarity index 52% rename from tensorflow/contrib/py2tf/converters/break_canonicalization_test.py rename to tensorflow/contrib/py2tf/converters/break_statements_test.py index b5ba2ad923dfeb73b38169494f6c7ea16ee815f1..095fcdff07d44ecc6b9bb7f8d3e2c7c43df72a02 100644 --- a/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/break_statements_test.py @@ -12,25 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for break_canonicalization module.""" +"""Tests for break_statements module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import break_canonicalization -from tensorflow.contrib.py2tf.converters import control_flow +from tensorflow.contrib.py2tf.converters import break_statements from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.python.platform import test -class TestNamer(control_flow.SymbolNamer): - - def new_symbol(self, name_root, _): - return name_root - - class BreakCanonicalizationTest(converter_test_base.TestCase): def test_basic_break(self): @@ -44,15 +36,15 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = break_canonicalization.transform(node, TestNamer()) - result = compiler.ast_to_object(node) + node = self.parse_and_analyze(test_fn, {}) + node = break_statements.transform(node, self.ctx) - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + with self.compiled(node) as result: + self.assertEqual(test_fn(0), result.test_fn(0)) + self.assertEqual(test_fn(1), result.test_fn(1)) + self.assertEqual(test_fn(2), result.test_fn(2)) + self.assertEqual(test_fn(3), result.test_fn(3)) + self.assertEqual(test_fn(4), result.test_fn(4)) def test_basic_break_for_loop(self): @@ -76,16 +68,17 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = break_canonicalization.transform(node, TestNamer()) - result = compiler.ast_to_object(node) + node = self.parse_and_analyze(test_fn, {}) + node = break_statements.transform(node, self.ctx) - # The break is incompletely canonicalized. Everything is in place, but - # the loop does not break. - self.assertEqual(test_equiv_fn([]), result.test_fn([])) - self.assertEqual(test_equiv_fn([1]), result.test_fn([1])) - self.assertEqual(test_equiv_fn([2]), result.test_fn([2])) - self.assertEqual(test_equiv_fn([1, 2, 3, 4]), result.test_fn([1, 2, 3, 4])) + with self.compiled(node) as result: + # The break is incompletely canonicalized. Everything is in place, but + # the loop does not break. + self.assertEqual(test_equiv_fn([]), result.test_fn([])) + self.assertEqual(test_equiv_fn([1]), result.test_fn([1])) + self.assertEqual(test_equiv_fn([2]), result.test_fn([2])) + self.assertEqual( + test_equiv_fn([1, 2, 3, 4]), result.test_fn([1, 2, 3, 4])) def test_continue_deeply_nested(self): @@ -104,15 +97,15 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v, u, w - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = break_canonicalization.transform(node, TestNamer()) - result = compiler.ast_to_object(node) + node = self.parse_and_analyze(test_fn, {}) + node = break_statements.transform(node, self.ctx) - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + with self.compiled(node) as result: + self.assertEqual(test_fn(0), result.test_fn(0)) + self.assertEqual(test_fn(1), result.test_fn(1)) + self.assertEqual(test_fn(2), result.test_fn(2)) + self.assertEqual(test_fn(3), result.test_fn(3)) + self.assertEqual(test_fn(4), result.test_fn(4)) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/py2tf/converters/builtin_functions.py index b80c96c97ac0c55f449a83bd43f2b65cdbdba390..2eb00f90575920ac948e799b0e97a9cfccb42fad 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions.py +++ b/tensorflow/contrib/py2tf/converters/builtin_functions.py @@ -21,34 +21,56 @@ from __future__ import print_function import gast from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer -class BuiltinFunctionTransformer(gast.NodeTransformer): - """Transforms Print nodes to Call so they can be handled as functions.""" +class BuiltinFunctionTransformer(transformer.Base): + """Handles builtin functions. - # TODO(mdan): Bring print_functions in here. + This transformer only covers functions that are translated into a + TF equivalent, like `len`. + """ - def _convert_len(self, node): + def __init__(self, context): + super(BuiltinFunctionTransformer, self).__init__(context) - def template(args): - tf.shape(args)[0] # pylint:disable=undefined-variable,expression-not-assigned + # pylint:disable=invalid-name - new_call = templates.replace(template, args=node.args)[0].value - return new_call + def _convert_len(self, node): + template = """ + tf.shape(args)[0] + """ + return templates.replace(template, args=node.args)[0].value - # pylint:disable=invalid-name + def _convert_print(self, node): + template = """ + py2tf_utils.call_print(args) + """ + return templates.replace(template, args=node.args)[0].value def visit_Call(self, node): self.generic_visit(node) # TODO(mdan): This won't work if the function was hidden. if isinstance(node.func, gast.Name) and node.func.id == 'len': return self._convert_len(node) + if isinstance(node.func, gast.Name) and node.func.id == 'print': + return self._convert_print(node) return node + def visit_Print(self, node): + self.generic_visit(node) + args = node.values + # Following is the case when calling print(a, b) + if len(args) == 1 and isinstance(args[0], gast.Tuple): + args = args[0].elts + template = """ + fname(args) + """ + function_call = templates.replace(template, fname='print', args=args)[0] + return self.visit(function_call) + # pylint:enable=invalid-name -def transform(node): - transformer = BuiltinFunctionTransformer() - node = transformer.visit(node) - return node +def transform(node, context): + return BuiltinFunctionTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py b/tensorflow/contrib/py2tf/converters/builtin_functions_test.py index b5358da6bc0be06ec1f59d0ef58d926289b5b78f..b279ff77ef10b96586d3d68585adb0d5424afb90 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py +++ b/tensorflow/contrib/py2tf/converters/builtin_functions_test.py @@ -18,11 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + +import six + from tensorflow.contrib.py2tf.converters import builtin_functions from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -34,14 +39,76 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): return len(a) node = self.parse_and_analyze(test_fn, {'len': len}) - node = builtin_functions.transform(node) - result = compiler.ast_to_object(node) - setattr(result, 'tf', array_ops) - - with self.test_session() as sess: - self.assertEqual(3, - sess.run( - result.test_fn(constant_op.constant([0, 0, 0])))) + node = builtin_functions.transform(node, self.ctx) + + with self.compiled(node, array_ops.shape) as result: + with self.test_session() as sess: + self.assertEqual(3, + sess.run( + result.test_fn(constant_op.constant([0, 0, 0])))) + + def test_print_with_op(self): + + def test_fn(a): + print(a) + + node = self.parse_and_analyze(test_fn, {'print': print}) + node = builtin_functions.transform(node, self.ctx) + + # Note: it's relevant not to include script_ops.py_func here, to verify + # that tf.Print is used. + with self.compiled(node, logging_ops.Print) as result: + with self.test_session() as sess: + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + result.test_fn('a') + sess.run(sess.graph.get_operations()) + self.assertEqual(out_capturer.getvalue(), 'a\n') + finally: + sys.stdout = sys.__stdout__ + + def test_print_with_op_multiple_values(self): + + def test_fn(a, b): + print(a, b) + + node = self.parse_and_analyze(test_fn, {'print': print}) + node = builtin_functions.transform(node, self.ctx) + + # Note: it's relevant not to include script_ops.py_func here, to verify + # that tf.Print is used. + with self.compiled(node, logging_ops.Print) as result: + with self.test_session() as sess: + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + result.test_fn('a', 1) + sess.run(sess.graph.get_operations()) + self.assertEqual(out_capturer.getvalue(), 'a 1\n') + finally: + sys.stdout = sys.__stdout__ + + def test_print_with_py_func(self): + + def test_fn(a, b, c): + print(a, b, c) + + node = self.parse_and_analyze(test_fn, {'print': print}) + node = builtin_functions.transform(node, self.ctx) + + # Note: it's relevant not to include logging_ops.Print here, to verify + # that py_func is used. + with self.compiled(node, script_ops.py_func) as result: + with self.test_session() as sess: + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + result.test_fn('a', 1, [2, 3]) + sess.run(sess.graph.get_operations()) + self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n') + finally: + sys.stdout = sys.__stdout__ if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/py2tf/converters/call_trees.py index df071f596fc31502a98182f27bb66c54f71d2572..1050ba654c63bb52c1c5e71c981a6a0baa3fc987 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees.py +++ b/tensorflow/contrib/py2tf/converters/call_trees.py @@ -29,46 +29,46 @@ import gast from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import parser from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.python.util import tf_inspect class FunctionNamer(object): """Describes the interface for CallTreeTransformer's namer.""" def compiled_function_name(self, - original_name, - live_object=None, + original_fqn, + live_entity=None, owner_type=None): """Generate the name corresponding to the compiled version of a function. Args: - original_name: String - live_object: Callable, the actual target function, if known. + original_fqn: string or tuple(string) + live_entity: Callable, the actual target function, if known. owner_type: Optional object. If present, it indicates that the function is a member of the given type. Returns: - String. + string, bool """ raise NotImplementedError() - def compiled_class_name(self, original_name, live_object=None): + def compiled_class_name(self, original_fqn, live_entity=None): """Generate the name corresponding to the compiled version of a class. Args: - original_name: String - live_object: The actual target class, if known. + original_fqn: string or tuple(string) + live_entity: The actual target class, if known. Returns: - String. + string """ raise NotImplementedError() -class CallTreeTransformer(gast.NodeTransformer): +class CallTreeTransformer(transformer.Base): """Transforms the call tree by renaming transformed symbols.""" - def __init__(self, namer, namespace, uncompiled_modules, - nocompile_decorators): - self.namer = namer - self.namespace = namespace + def __init__(self, context, uncompiled_modules, nocompile_decorators): + super(CallTreeTransformer, self).__init__(context) self.uncompiled_modules = uncompiled_modules self.nocompile_decorators = nocompile_decorators @@ -78,7 +78,7 @@ class CallTreeTransformer(gast.NodeTransformer): if isinstance(node, gast.Call): return self._resolve_name(node.func) if isinstance(node, gast.Name): - return self.namespace.get(node.id) + return self.context.namespace.get(node.id) if isinstance(node, gast.Attribute): parent = self._resolve_name(node.value) if parent is not None: @@ -91,8 +91,12 @@ class CallTreeTransformer(gast.NodeTransformer): if anno.hasanno(node, 'live_val'): return anno.getanno(node, 'live_val') if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'): - member = getattr(anno.getanno(node, 'type'), node.attr) - return member + owner_type = anno.getanno(node, 'type') + if hasattr(owner_type, node.attr): + return getattr(owner_type, node.attr) + else: + raise ValueError('Type "%s" has not attribute "%s". Is it dynamic?' % + (owner_type, node.attr)) return None def _should_compile(self, node, fqn): @@ -106,14 +110,14 @@ class CallTreeTransformer(gast.NodeTransformer): # The decorators themselves are not to be converted. # If present, the decorators should appear as static functions. - target_obj = self._try_resolve_target(node.func) - if target_obj is not None: + target_entity = self._try_resolve_target(node.func) + if target_entity is not None: # This attribute is set by the decorator itself. # TODO(mdan): This may not play nicely with other wrapping decorators. - if hasattr(target_obj, '__pyct_is_compile_decorator'): + if hasattr(target_entity, '__pyct_is_compile_decorator'): return False - if target_obj in self.nocompile_decorators: + if target_entity in self.nocompile_decorators: return False # Inspect the target function decorators. If any include a @convert @@ -122,7 +126,8 @@ class CallTreeTransformer(gast.NodeTransformer): # To parse and re-analize each function for every call site could be quite # wasteful. Maybe we could cache the parsed AST? try: - target_node = parser.parse_object(target_obj).body[0] + target_node, _ = parser.parse_entity(target_entity) + target_node = target_node.body[0] except TypeError: # Functions whose source we cannot access are compilable (e.g. wrapped # to py_func). @@ -136,89 +141,76 @@ class CallTreeTransformer(gast.NodeTransformer): return True + def _determine_function_owner(self, m): + # TODO(mdan): The parent type should be known at analysis. Use that instead. + if hasattr(m, 'im_class'): # Python 2 + return m.im_class + if hasattr(m, '__qualname__'): # Python 3 + # Object attributes: should be bound to "self". + if hasattr(m, '__self__'): + return type(m.__self__) + + # Class attributes: should have the owner name in their namespace. + qn = m.__qualname__.split('.') + if len(qn) < 2: + return None + owner_name, func_name = qn[-2:] + if func_name != m.__name__: + raise ValueError('Inconsistent names detected ' + '(__qualname__[1] = "%s", __name__ = "%s") for %s.' % + (func_name, m.__name__, m)) + if owner_name == '': + return None + if owner_name not in self.context.namespace: + raise ValueError( + 'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' % + (owner_name, m, self.context.namespace)) + return self.context.namespace[owner_name] + return None + def _rename_compilable_function(self, node): assert anno.hasanno(node.func, 'live_val') assert anno.hasanno(node.func, 'fqn') - target_obj = anno.getanno(node.func, 'live_val') + target_entity = anno.getanno(node.func, 'live_val') target_fqn = anno.getanno(node.func, 'fqn') if not self._should_compile(node, target_fqn): return node if anno.hasanno(node, 'is_constructor'): - new_name = self.namer.compiled_class_name( - '__'.join(target_fqn), live_object=target_obj) + new_name = self.context.namer.compiled_class_name( + target_fqn, live_entity=target_entity) + do_rename = True else: - new_name = self.namer.compiled_function_name( - '__'.join(target_fqn), live_object=target_obj) - node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None) - return node - - def _rename_member_function_of_known_type(self, node): - assert isinstance(node.func, gast.Attribute) - - type_fqn = anno.getanno(node.func, 'type_fqn') - assert anno.hasanno(node.func, 'type') - target_type = anno.getanno(node.func, 'type') - - if not self._should_compile(node, type_fqn): - return node - - # TODO(mdan): We should not assume that the namer only needs the - # member function name. - method_name = node.func.attr - method_object = getattr(target_type, method_name) - new_name = self.namer.compiled_function_name( - method_name, live_object=method_object, owner_type=target_type) - if new_name != node.func.attr: - # If a member function call is renamed, then the new function is no - # longer bound to the target object. We then refactor the call from: - # foo.bar(...) - # to: - # renamed_foo(bar, ...) - # TODO(mdan): This risks causing duplication, if target_type is renamed. - node.args = [node.func.value] + node.args - node.func = gast.Name(new_name, gast.Load(), None) + owner_type = self._determine_function_owner(target_entity) + new_name, do_rename = self.context.namer.compiled_function_name( + target_fqn, live_entity=target_entity, owner_type=owner_type) + + if do_rename: + if target_entity is not None: + if tf_inspect.ismethod(target_entity): + # The renaming process will transform it into a regular function. + # TODO(mdan): Is this complete? How does it work with nested members? + node.args = [node.func.value] + node.args + node.func = templates.replace('func_name', func_name=new_name)[0] return node def _wrap_to_py_func_no_return(self, node): - args_scope = anno.getanno(node, 'args_scope') # TODO(mdan): Properly handle varargs, kwargs, etc. - args = tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used) - - # pylint:disable=undefined-variable,unused-argument,function-redefined - - def template(call, wrapper, args): - - def wrapper(args): - call(args) - return 1 - - tf.py_func(wrapper, [args], [tf.int64]) - - # pylint:enable=undefined-variable,unused-argument,function-redefined - - wrapper_name = self.namer.compiled_function_name(node.func.id) - wrapper_def, call_expr = templates.replace( - template, - call=node.func, - wrapper=gast.Name(wrapper_name, gast.Load(), None), - args=args) - anno.setanno(call_expr.value, 'args_scope', args_scope) - # TODO(mdan): Rename this annotation to 'graph_ready' - anno.setanno(wrapper_def, 'skip_processing', True) - - return (wrapper_def, call_expr) + template = """ + py2tf_utils.wrap_py_func(func, None, (original_args,), True) + """ + return templates.replace(template, func=node.func, original_args=node.args) - def _function_is_compilable(self, target_obj): + def _function_is_compilable(self, target_entity): # TODO(mdan): This is just a placeholder. Implement. - return not isinstance(target_obj, types.BuiltinFunctionType) + return not isinstance(target_entity, types.BuiltinFunctionType) def visit_Expr(self, node): if isinstance(node.value, gast.Call): if anno.hasanno(node.value.func, 'live_val'): - target_obj = anno.getanno(node.value.func, 'live_val') - if not self._function_is_compilable(target_obj): + target_entity = anno.getanno(node.value.func, 'live_val') + if not self._function_is_compilable(target_entity): if anno.hasanno(node.value.func, 'fqn'): target_fqn = anno.getanno(node.value.func, 'fqn') if not self._should_compile(node.value, target_fqn): @@ -236,8 +228,8 @@ class CallTreeTransformer(gast.NodeTransformer): # If the function is wrapped by one of the marker decorators, # consider it graph ready. if anno.hasanno(node.func, 'live_val'): - target_obj = anno.getanno(node.func, 'live_val') - if target_obj in self.nocompile_decorators: + target_entity = anno.getanno(node.func, 'live_val') + if target_entity in self.nocompile_decorators: if len(node.args) < 1: raise ValueError( 'Found call to decorator function "%s", but it had no arguments. ' @@ -246,28 +238,28 @@ class CallTreeTransformer(gast.NodeTransformer): self.generic_visit(node) if anno.hasanno(node.func, 'live_val'): - target_obj = anno.getanno(node.func, 'live_val') - if self._function_is_compilable(target_obj): + target_entity = anno.getanno(node.func, 'live_val') + if self._function_is_compilable(target_entity): node = self._rename_compilable_function(node) else: raise NotImplementedError('py_func with return values') - elif anno.hasanno(node.func, 'type_fqn'): - node = self._rename_member_function_of_known_type(node) else: - raise NotImplementedError( - 'Member function call (of unknown type): %s.' % node.func.id) + if self.context.recursive: + raise NotImplementedError('Could not resolve target function.') + else: + # TODO(mdan): Double check. Is this reachable code? + pass return node # pylint:enable=invalid-name -def transform(node, namer, namespace, uncompiled_modules, nocompile_decorators): +def transform(node, context, uncompiled_modules, nocompile_decorators): """Transform function call to the compiled counterparts. Args: node: AST to transform. - namer: FunctionNamer-like. - namespace: Dict mapping symbol names to their corresponding live objects. + context: An EntityContext object. uncompiled_modules: set of string tuples, each tuple represents the fully qualified name of a package containing functions that will not be compiled. @@ -278,7 +270,6 @@ def transform(node, namer, namespace, uncompiled_modules, nocompile_decorators): node: The transformed AST new_names: set(string), containing any newly-generated names """ - transformer = CallTreeTransformer(namer, namespace, uncompiled_modules, - nocompile_decorators) - node = transformer.visit(node) + t = CallTreeTransformer(context, uncompiled_modules, nocompile_decorators) + node = t.visit(node) return node diff --git a/tensorflow/contrib/py2tf/converters/call_trees_test.py b/tensorflow/contrib/py2tf/converters/call_trees_test.py index 8cb8d7be0f122ed124b0fda69c745a349543a16d..777648dc0b31863227262fbf931aba680bb4ed98 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees_test.py +++ b/tensorflow/contrib/py2tf/converters/call_trees_test.py @@ -20,18 +20,11 @@ from __future__ import print_function from tensorflow.contrib.py2tf.converters import call_trees from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.python.framework import constant_op from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class TestNamer(call_trees.FunctionNamer): - - def compiled_function_name(self, original_name, live_object=None): - return 'renamed_%s' % original_name - - class CallTreesTest(converter_test_base.TestCase): def test_basic(self): @@ -46,12 +39,55 @@ class CallTreesTest(converter_test_base.TestCase): return test_fn_1(a) + 1 node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1}) - node = call_trees.transform(node, TestNamer(), {}, (), ()) - result = compiler.ast_to_object(node) - # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually. - setattr(result, 'renamed_test_fn_1', renamed_test_fn_1) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node) as result: + # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 + # manually. + result.renamed_test_fn_1 = renamed_test_fn_1 + self.assertEquals(3, result.test_fn_2(1)) + + def test_simple_methods(self): + + class TestClass(object): + + def test_fn_1(self, a): + return a + 1 + + def test_fn_2(self, a): + return self.test_fn_1(a) + 1 + + node = self.parse_and_analyze( + TestClass.test_fn_2, {'TestClass': TestClass}, + arg_types={'self': (TestClass.__name__, TestClass)}) + node = call_trees.transform(node, self.ctx, (), ()) - self.assertEquals(3, result.test_fn_2(1)) + with self.compiled(node) as result: + tc = TestClass() + self.assertEquals(3, result.test_fn_2(tc, 1)) + + def test_py_func_wrap_no_retval(self): + + def test_fn(a): + setattr(a, 'foo', 'bar') + + node = self.parse_and_analyze(test_fn, {'setattr': setattr}) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node) as result: + with self.test_session() as sess: + # The function has no return value, so we do some tricks to grab the + # generated py_func node and ensure its effect only happens at graph + # execution. + + class Dummy(object): + pass + + a = Dummy() + result.test_fn(a) + self.assertFalse(hasattr(a, 'foo')) + sess.run(sess.graph.get_operations()[0]) + self.assertEquals('bar', a.foo) def test_uncompiled_modules(self): @@ -64,20 +100,18 @@ class CallTreesTest(converter_test_base.TestCase): 'math_ops': math_ops, 'constant_op': constant_op }) - node = call_trees.transform(node, TestNamer(), {}, + node = call_trees.transform(node, self.ctx, set(((math_ops.__name__,), (constant_op.__name__,))), ()) - result = compiler.ast_to_object(node) - setattr(result, 'math_ops', math_ops) - setattr(result, 'constant_op', constant_op) - - with self.test_session() as sess: - # Not renamed, because the converter doesn't rename the definition itself. - # (the caller is responsible for that). - result_tensor = result.test_fn(constant_op.constant(1)) - result_val = sess.run(result_tensor) - self.assertEquals(3, result_val) + with self.compiled(node) as result: + result.math_ops = math_ops + result.constant_op = constant_op + with self.test_session() as sess: + # Not renamed, because the converter doesn't rename the definition + # itself (the caller is responsible for that). + result_tensor = result.test_fn(constant_op.constant(1)) + self.assertEquals(3, sess.run(result_tensor)) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/converters/continue_canonicalization.py b/tensorflow/contrib/py2tf/converters/continue_statements.py similarity index 78% rename from tensorflow/contrib/py2tf/converters/continue_canonicalization.py rename to tensorflow/contrib/py2tf/converters/continue_statements.py index 7f8ace77a830ebcc4d49fcf2190e4bac920b1cde..4069a678b118b56b59d2e5491bb80cf52efd8143 100644 --- a/tensorflow/contrib/py2tf/converters/continue_canonicalization.py +++ b/tensorflow/contrib/py2tf/converters/continue_statements.py @@ -18,47 +18,43 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gast - from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno -class ContinueCanonicalizationTransformer(gast.NodeTransformer): +class ContinueCanonicalizationTransformer(transformer.Base): """Canonicalizes continue statements into additional conditionals.""" - def __init__(self, namer): - self.namer = namer + def __init__(self, context): + super(ContinueCanonicalizationTransformer, self).__init__(context) # This is a stack structure, to correctly process nested loops. self.continuation_uses = [] def _create_continuation_check(self): - - def template(var_name): + template = """ if not var_name: pass - - cond, = templates.replace( - template, var_name=gast.Name(self.continuation_uses[-1][1], None, None)) + """ + cond, = templates.replace(template, var_name=self.continuation_uses[-1][1]) cond.body = [] return cond def _create_continuation_trigger(self): - - def template(var_name): # pylint:disable=unused-argument + template = """ var_name = True - + """ assign, = templates.replace( - template, var_name=gast.Name(self.continuation_uses[-1][1], None, None)) + template, var_name=self.continuation_uses[-1][1]) return assign def _create_continuation_init(self): - - def template(var_name): # pylint:disable=unused-argument + template = """ var_name = False - + """ assign, = templates.replace( - template, var_name=gast.Name(self.continuation_uses[-1][1], None, None)) + template, var_name=self.continuation_uses[-1][1]) return assign def _visit_and_reindent_if_necessary(self, nodes): @@ -80,7 +76,7 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer): return reorganized_nodes def _process_loop_block(self, block, scope): - cont_var = self.namer.new_symbol('cont_requested', scope.referenced) + cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced) self.continuation_uses.append([False, cont_var]) block = self._visit_and_reindent_if_necessary(block) if self.continuation_uses[-1][0]: @@ -91,7 +87,8 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer): def visit_While(self, node): self.generic_visit(node.test) node.body = self._process_loop_block(node.body, - anno.getanno(node, 'body_scope')) + anno.getanno(node, + NodeAnno.BODY_SCOPE)) for n in node.orelse: self.generic_visit(n) return node @@ -100,7 +97,8 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer): self.generic_visit(node.target) self.generic_visit(node.iter) node.body = self._process_loop_block(node.body, - anno.getanno(node, 'body_scope')) + anno.getanno(node, + NodeAnno.BODY_SCOPE)) for n in node.orelse: self.generic_visit(n) return node @@ -126,6 +124,4 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer): def transform(node, namer): - transformer = ContinueCanonicalizationTransformer(namer) - node = transformer.visit(node) - return node + return ContinueCanonicalizationTransformer(namer).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/continue_statements_test.py similarity index 50% rename from tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py rename to tensorflow/contrib/py2tf/converters/continue_statements_test.py index c1fe903a2dd332626c8e64826652723c30ac412a..a598dcd1aed29478b7e3fe27e3c1b20010247dd9 100644 --- a/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/continue_statements_test.py @@ -12,25 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for continue_canonicalization module.""" +"""Tests for continue_statements module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import continue_canonicalization -from tensorflow.contrib.py2tf.converters import control_flow +from tensorflow.contrib.py2tf.converters import continue_statements from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.python.platform import test -class TestNamer(control_flow.SymbolNamer): - - def new_symbol(self, name_root, _): - return name_root - - class ContinueCanonicalizationTest(converter_test_base.TestCase): def test_basic_continue(self): @@ -44,15 +36,15 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = continue_canonicalization.transform(node, TestNamer()) - result = compiler.ast_to_object(node) + node = self.parse_and_analyze(test_fn, {}) + node = continue_statements.transform(node, self.ctx) - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + with self.compiled(node) as result: + self.assertEqual(test_fn(0), result.test_fn(0)) + self.assertEqual(test_fn(1), result.test_fn(1)) + self.assertEqual(test_fn(2), result.test_fn(2)) + self.assertEqual(test_fn(3), result.test_fn(3)) + self.assertEqual(test_fn(4), result.test_fn(4)) def test_basic_continue_for_loop(self): @@ -65,14 +57,14 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = continue_canonicalization.transform(node, TestNamer()) - result = compiler.ast_to_object(node) + node = self.parse_and_analyze(test_fn, {}) + node = continue_statements.transform(node, self.ctx) - self.assertEqual(test_fn([]), result.test_fn([])) - self.assertEqual(test_fn([1]), result.test_fn([1])) - self.assertEqual(test_fn([2]), result.test_fn([2])) - self.assertEqual(test_fn([1, 2, 3]), result.test_fn([1, 2, 3])) + with self.compiled(node) as result: + self.assertEqual(test_fn([]), result.test_fn([])) + self.assertEqual(test_fn([1]), result.test_fn([1])) + self.assertEqual(test_fn([2]), result.test_fn([2])) + self.assertEqual(test_fn([1, 2, 3]), result.test_fn([1, 2, 3])) def test_continue_deeply_nested(self): @@ -91,15 +83,15 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v, u, w - node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False) - node = continue_canonicalization.transform(node, TestNamer()) - result = compiler.ast_to_object(node) + node = self.parse_and_analyze(test_fn, {}) + node = continue_statements.transform(node, self.ctx) - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + with self.compiled(node) as result: + self.assertEqual(test_fn(0), result.test_fn(0)) + self.assertEqual(test_fn(1), result.test_fn(1)) + self.assertEqual(test_fn(2), result.test_fn(2)) + self.assertEqual(test_fn(3), result.test_fn(3)) + self.assertEqual(test_fn(4), result.test_fn(4)) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/converters/control_flow.py b/tensorflow/contrib/py2tf/converters/control_flow.py index 8ebd9ad93dbc17814d1d7f53c3eac2e078030141..d53e3e4fd6d87004cbe55bd430346ad263e898ea 100644 --- a/tensorflow/contrib/py2tf/converters/control_flow.py +++ b/tensorflow/contrib/py2tf/converters/control_flow.py @@ -21,7 +21,10 @@ from __future__ import print_function import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import ast_util from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -40,33 +43,65 @@ class SymbolNamer(object): raise NotImplementedError() -class SymbolRenamer(gast.NodeTransformer): - - def __init__(self, name_map): - self.name_map = name_map - - def visit_Name(self, node): - if node.id in self.name_map: - node.id = self.name_map[node.id] - return node - - -class ControlFlowTransformer(gast.NodeTransformer): +class ControlFlowTransformer(transformer.Base): """Transforms control flow structures like loops an conditionals.""" - def __init__(self, namer): - self.namer = namer + def __init__(self, context): + super(ControlFlowTransformer, self).__init__(context) # pylint:disable=invalid-name def visit_For(self, node): assert False, 'for statement should have been canonicalized at this point' + def _create_cond_branch(self, body_name, aliased_orig_names, + aliased_new_names, body, returns): + if aliased_orig_names: + template = """ + def body_name(): + aliased_new_names, = aliased_orig_names, + body + return (returns,) + """ + return templates.replace( + template, + body_name=body_name, + body=body, + aliased_orig_names=aliased_orig_names, + aliased_new_names=aliased_new_names, + returns=returns) + else: + template = """ + def body_name(): + body + return (returns,) + """ + return templates.replace( + template, body_name=body_name, body=body, returns=returns) + + def _create_cond_expr(self, results, test, body_name, orelse_name): + if results is not None: + template = """ + results = py2tf_utils.run_cond(test, body_name, orelse_name) + """ + return templates.replace( + template, + test=test, + results=results, + body_name=body_name, + orelse_name=orelse_name) + else: + template = """ + py2tf_utils.run_cond(test, body_name, orelse_name) + """ + return templates.replace( + template, test=test, body_name=body_name, orelse_name=orelse_name) + def visit_If(self, node): self.generic_visit(node) - body_scope = anno.getanno(node, 'body_scope') - orelse_scope = anno.getanno(node, 'orelse_scope') + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE) if body_scope.created - orelse_scope.created: raise ValueError( @@ -75,30 +110,7 @@ class ControlFlowTransformer(gast.NodeTransformer): raise ValueError( 'The else branch creates new symbols that the if branch does not.') - def template( # pylint:disable=missing-docstring - test, - body_name, - body, - orelse_name, - orelse, - aliased, - aliases, # pylint:disable=unused-argument - aliased_results, - results): # pylint:disable=unused-argument - - def body_name(): # pylint:disable=function-redefined - aliases, = aliased, # pylint:disable=unused-variable - body # pylint:disable=pointless-statement - return (aliased_results,) - - def orelse_name(): # pylint:disable=function-redefined - aliases, = aliased, # pylint:disable=unused-variable - orelse # pylint:disable=pointless-statement - return (aliased_results,) - - results = tf.cond(test, body_name, orelse_name) # pylint:disable=undefined-variable - - all_modified = tuple(body_scope.modified | orelse_scope.modified) + modified = tuple(body_scope.modified | orelse_scope.modified) all_referenced = body_scope.referenced | orelse_scope.referenced # Alias the closure variables inside the conditional functions @@ -107,83 +119,103 @@ class ControlFlowTransformer(gast.NodeTransformer): need_alias = ( (body_scope.modified | orelse_scope.modified) - (body_scope.created | orelse_scope.created)) - aliased = tuple(need_alias) - aliases = tuple( - self.namer.new_symbol(s, all_referenced) for s in aliased) - alias_map = dict(zip(aliased, aliases)) - node_body = node.body - node_body = [SymbolRenamer(alias_map).visit(n) for n in node_body] - node_orelse = node.orelse - node_orelse = [SymbolRenamer(alias_map).visit(n) for n in node_orelse] - - if len(all_modified) == 1: - results = gast.Name(all_modified[0], None, None) + aliased_orig_names = tuple(need_alias) + aliased_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), all_referenced) + for s in aliased_orig_names) + alias_map = dict(zip(aliased_orig_names, aliased_new_names)) + node_body = ast_util.rename_symbols(node.body, alias_map) + node_orelse = ast_util.rename_symbols(node.orelse, alias_map) + + if not modified: + # When the cond would return no value, we leave the cond called without + # results. That in turn should trigger the side effect guards. The + # branch functions will return a dummy value that ensures cond + # actually has some return value as well. + results = None + elif len(modified) == 1: + results = modified[0] else: - results = gast.Tuple( - tuple(gast.Name(s, None, None) for s in all_modified), None) + results = gast.Tuple([s.ast() for s in modified], None) - return templates.replace( - template, - test=node.test, - body_name=gast.Name( - self.namer.new_symbol('if_true', all_referenced), None, None), + body_name = self.context.namer.new_symbol('if_true', all_referenced) + orelse_name = self.context.namer.new_symbol('if_false', all_referenced) + if modified: + body_returns = tuple( + alias_map[s] if s in aliased_orig_names else s for s in modified) + else: + body_returns = templates.replace('tf.ones(())')[0].value + + body_def = self._create_cond_branch( + body_name, + aliased_orig_names=tuple(aliased_orig_names), + aliased_new_names=tuple(aliased_new_names), body=node_body, - orelse_name=gast.Name( - self.namer.new_symbol('if_false', all_referenced), None, None), - orelse=node_orelse, - aliased=tuple(gast.Name(s, None, None) for s in aliased), - aliases=tuple(gast.Name(s, None, None) for s in aliases), - aliased_results=tuple( - gast.Name(alias_map[s] if s in aliased else s, None, None) - for s in all_modified), - results=results) + returns=body_returns) + orelse_def = self._create_cond_branch( + orelse_name, + aliased_orig_names=tuple(aliased_orig_names), + aliased_new_names=tuple(aliased_new_names), + body=node_orelse, + returns=body_returns) + cond_expr = self._create_cond_expr(results, node.test, body_name, + orelse_name) + + return body_def + orelse_def + cond_expr def visit_While(self, node): self.generic_visit(node) - body_scope = anno.getanno(node, 'body_scope') - body_closure = tuple(body_scope.modified - body_scope.created) + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + body_closure = body_scope.modified - body_scope.created + all_referenced = body_scope.referenced + + state = list(body_closure) + state_ssf = [ + self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + ] + ssf_map = { + name: ssf + for name, ssf in zip(state, state_ssf) + if str(name) != ssf + } + + if len(state) == 1: + state = state[0] + state_ssf = state_ssf[0] + state_ast_tuple = state + else: + state_ast_tuple = gast.Tuple([n.ast() for n in state], None) - def template( - state, # pylint:disable=unused-argument - state_ast_tuple, # pylint:disable=unused-argument - test_name, - test, # pylint:disable=unused-argument - body_name, - body): + node_body = ast_util.rename_symbols(node.body, ssf_map) + test = ast_util.rename_symbols(node.test, ssf_map) - def test_name(state): # pylint:disable=function-redefined,unused-argument + template = """ + def test_name(state_ssf): return test - - def body_name(state): # pylint:disable=function-redefined,unused-argument - body # pylint:disable=pointless-statement - return state, - - state_ast_tuple = tf.while_loop(test_name, body_name, [state]) # pylint:disable=undefined-variable - - test_name = self.namer.new_symbol('loop_test', body_scope.referenced) - body_name = self.namer.new_symbol('loop_body', body_scope.referenced) - if len(body_closure) == 1: - state = gast.Name(body_closure[0], None, None) - state_ast_tuple = state - else: - state = tuple(gast.Name(n, None, None) for n in body_closure) - state_ast_tuple = gast.Tuple(state, None) + def body_name(state_ssf): + body + return state_ssf, + state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state]) + """ node = templates.replace( template, state=state, + state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, - test_name=gast.Name(test_name, gast.Load(), None), - test=node.test, - body_name=gast.Name(body_name, gast.Load(), None), - body=node.body) + test_name=self.context.namer.new_symbol('loop_test', + body_scope.referenced), + test=test, + body_name=self.context.namer.new_symbol('loop_body', + body_scope.referenced), + body=node_body) return node # pylint:enable=invalid-name -def transform(node, namer): - transformer = ControlFlowTransformer(namer) - node = transformer.visit(node) +def transform(node, context): + t = ControlFlowTransformer(context) + node = t.visit(node) return node diff --git a/tensorflow/contrib/py2tf/converters/control_flow_test.py b/tensorflow/contrib/py2tf/converters/control_flow_test.py index 054e33750dbae86559a9575dfecde64132b9a2cd..b785b284a7fb7a0257551326c88b44a341b295ba 100644 --- a/tensorflow/contrib/py2tf/converters/control_flow_test.py +++ b/tensorflow/contrib/py2tf/converters/control_flow_test.py @@ -20,23 +20,11 @@ from __future__ import print_function from tensorflow.contrib.py2tf.converters import control_flow from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.python.framework import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test -class TestNamer(control_flow.SymbolNamer): - - def new_symbol(self, name_root, used): - i = 0 - while True: - name = '%s%d' % (name_root, i) - if name not in used: - return name - i += 1 - - class ControlFlowTest(converter_test_base.TestCase): def test_simple_while(self): @@ -50,13 +38,12 @@ class ControlFlowTest(converter_test_base.TestCase): return s, i, n node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, TestNamer()) - result = compiler.ast_to_object(node) - setattr(result, 'tf', control_flow_ops) + node = control_flow.transform(node, self.ctx) - with self.test_session() as sess: - self.assertEqual((10, 5, 5), - sess.run(result.test_fn(constant_op.constant(5)))) + with self.compiled(node, control_flow_ops.while_loop) as result: + with self.test_session() as sess: + self.assertEqual((10, 5, 5), + sess.run(result.test_fn(constant_op.constant(5)))) def test_while_single_var(self): @@ -66,12 +53,11 @@ class ControlFlowTest(converter_test_base.TestCase): return n node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, TestNamer()) - result = compiler.ast_to_object(node) - setattr(result, 'tf', control_flow_ops) + node = control_flow.transform(node, self.ctx) - with self.test_session() as sess: - self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) + with self.compiled(node, control_flow_ops.while_loop) as result: + with self.test_session() as sess: + self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) def test_simple_if(self): @@ -85,15 +71,14 @@ class ControlFlowTest(converter_test_base.TestCase): return a, b node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, TestNamer()) - result = compiler.ast_to_object(node) - setattr(result, 'tf', control_flow_ops) + node = control_flow.transform(node, self.ctx) - with self.test_session() as sess: - self.assertEqual((-1, 0), sess.run( - result.test_fn(constant_op.constant(1)))) - self.assertEqual((0, -2), - sess.run(result.test_fn(constant_op.constant(-1)))) + with self.compiled(node, control_flow_ops.cond) as result: + with self.test_session() as sess: + self.assertEqual((-1, 0), + sess.run(result.test_fn(constant_op.constant(1)))) + self.assertEqual((0, -2), + sess.run(result.test_fn(constant_op.constant(-1)))) def test_if_single_var(self): @@ -103,12 +88,11 @@ class ControlFlowTest(converter_test_base.TestCase): return n node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, TestNamer()) - result = compiler.ast_to_object(node) - setattr(result, 'tf', control_flow_ops) + node = control_flow.transform(node, self.ctx) - with self.test_session() as sess: - self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) + with self.compiled(node, control_flow_ops.cond) as result: + with self.test_session() as sess: + self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/converters/converter_test_base.py b/tensorflow/contrib/py2tf/converters/converter_test_base.py index ed006bad6d833b3682f819e87aa8b9c279372e51..67747183dd323a799a04943ce4c7fe8c4093d002 100644 --- a/tensorflow/contrib/py2tf/converters/converter_test_base.py +++ b/tensorflow/contrib/py2tf/converters/converter_test_base.py @@ -18,31 +18,86 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib +import imp + +from tensorflow.contrib.py2tf import utils +from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.py2tf.pyct.static_analysis import activity from tensorflow.contrib.py2tf.pyct.static_analysis import live_values from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.platform import test +class FakeNamer(object): + + def new_symbol(self, name_root, used): + i = 0 + while True: + name = '%s%d' % (name_root, i) + if name not in used: + return name + i += 1 + + def compiled_function_name(self, + original_fqn, + live_entity=None, + owner_type=None): + del live_entity + if owner_type is not None: + return None, False + return ('renamed_%s' % '_'.join(original_fqn)), True + + class TestCase(test.TestCase): + """Base class for unit tests in this module. Contains relevant utilities.""" + + @contextlib.contextmanager + def compiled(self, node, *symbols): + source = '' + try: + result, source = compiler.ast_to_object(node) + result.tf = self.make_fake_tf(*symbols) + result.py2tf_utils = utils + yield result + except Exception: # pylint:disable=broad-except + print('Offending compiled code:\n%s' % source) + raise + + def make_fake_tf(self, *symbols): + fake_tf = imp.new_module('fake_tf') + for s in symbols: + setattr(fake_tf, s.__name__, s) + return fake_tf + + def attach_namespace(self, module, **ns): + for k, v in ns.items(): + setattr(module, k, v) def parse_and_analyze(self, test_fn, namespace, + namer=None, arg_types=None, - include_type_analysis=True): + include_type_analysis=True, + recursive=True): + node, source = parser.parse_entity(test_fn) ctx = context.EntityContext( - namer=None, - source_code=None, + namer=namer or FakeNamer(), + source_code=source, source_file=None, namespace=namespace, arg_values=None, - arg_types=arg_types) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, namespace, {}) + arg_types=arg_types, + recursive=recursive) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + node = live_values.resolve(node, ctx, {}) if include_type_analysis: node = type_info.resolve(node, ctx) + node = live_values.resolve(node, ctx, {}) + self.ctx = ctx return node diff --git a/tensorflow/contrib/py2tf/converters/decorators.py b/tensorflow/contrib/py2tf/converters/decorators.py index a4313bfa510a81463a218cd21b41d9a7f43d1892..3f620c1cd2d9b75f82410754a7e812e13eabe3ae 100644 --- a/tensorflow/contrib/py2tf/converters/decorators.py +++ b/tensorflow/contrib/py2tf/converters/decorators.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Handles decorators.""" +"""Handles decorators. + +Note: this module only deals with functions whose decorators are still recorded +in the AST. This does not always happen. See the unit test for an example. +""" from __future__ import absolute_import from __future__ import division @@ -34,17 +38,19 @@ class DecoratorsTransformer(gast.NodeTransformer): def visit_FunctionDef(self, node): self.generic_visit(node) + kept_decorators = [] for dec in node.decorator_list: if isinstance(dec, gast.Call): - dec = dec.func - if not anno.hasanno(dec, 'live_val'): + dec_func = dec.func + else: + dec_func = dec + if not anno.hasanno(dec_func, 'live_val'): raise ValueError( - 'Could not resolve decorator: %s' % pretty_printer.fmt(dec)) - dec_value = anno.getanno(dec, 'live_val') - if dec_value in self.remove_decorators: - continue - raise ValueError('Dont know how to convert decorators for now.') - node.decorator_list = [] + 'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func)) + dec_value = anno.getanno(dec_func, 'live_val') + if dec_value not in self.remove_decorators: + kept_decorators.append(dec) + node.decorator_list = kept_decorators return node # pylint:enable=invalid-name diff --git a/tensorflow/contrib/py2tf/converters/decorators_test.py b/tensorflow/contrib/py2tf/converters/decorators_test.py new file mode 100644 index 0000000000000000000000000000000000000000..402fa0dda28e696f70d0354ca4abf3a6c83506d9 --- /dev/null +++ b/tensorflow/contrib/py2tf/converters/decorators_test.py @@ -0,0 +1,102 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for decorators module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import textwrap + +from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.py2tf.converters import decorators +from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect + + +class DecoratorsTest(converter_test_base.TestCase): + + def test_function_decorator(self): + + def function_decorator(): + + def decorator(f): + return lambda a: f(a) + 1 + + return decorator + + # The Python parser does capture decorators into the AST. + # However, the interpreter desugars them on load, and refering to the + # decorated function at runtime usually loses any trace of the decorator. + # Below is an example when that doesn't happen. + def static_wrapper(): + + @function_decorator() + def test_fn(a): # pylint:disable=unused-variable + return a + + node = self.parse_and_analyze(static_wrapper, + {'function_decorator': function_decorator}) + node = node.body[0].body[0] + + node = decorators.transform(node, remove_decorators=()) + # Since the decorator is not removed, we need to include its source + # code. We cannot do it after the fact because decorators are executed + # on load. + result, _ = compiler.ast_to_object( + node, + source_prefix=textwrap.dedent(tf_inspect.getsource(function_decorator))) + self.assertEqual(2, result.test_fn(1)) + + node = decorators.transform(node, remove_decorators=(function_decorator,)) + with self.compiled(node) as result: + self.assertEqual(1, result.test_fn(1)) + + def test_simple_decorator(self): + + def simple_decorator(f): + return lambda a: f(a) + 1 + + # The Python parser does capture decorators into the AST. + # However, the interpreter desugars them upon load, and refering to the + # decorated function at runtime usually loses any trace of the decorator. + # Below is an example when that doesn't happen. + def static_wrapper(): + + @simple_decorator + def test_fn(a): # pylint:disable=unused-variable + return a + + node = self.parse_and_analyze(static_wrapper, + {'simple_decorator': simple_decorator}) + node = node.body[0].body[0] + + node = decorators.transform(node, remove_decorators=()) + # Since the decorator is not removed, we need to include its source + # code. We cannot do it after the fact because decorators are executed + # on load. + result, _ = compiler.ast_to_object( + node, + source_prefix=textwrap.dedent(tf_inspect.getsource(simple_decorator))) + self.assertEqual(2, result.test_fn(1)) + + node = decorators.transform(node, remove_decorators=(simple_decorator,)) + with self.compiled(node) as result: + self.assertEqual(1, result.test_fn(1)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/for_canonicalization.py b/tensorflow/contrib/py2tf/converters/for_loops.py similarity index 62% rename from tensorflow/contrib/py2tf/converters/for_canonicalization.py rename to tensorflow/contrib/py2tf/converters/for_loops.py index 52360789cdc25528d925092e3e269c9968f2022f..935dade0ed30975dd29c8ffe5be875993936d241 100644 --- a/tensorflow/contrib/py2tf/converters/for_canonicalization.py +++ b/tensorflow/contrib/py2tf/converters/for_loops.py @@ -22,66 +22,58 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gast - from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno -class ForLoopCanonicalizationTransformer(gast.NodeTransformer): +class ForLoopCanonicalizationTransformer(transformer.Base): """Canonicalizes for loops (e.g. into while loops).""" - def __init__(self, namer): - self.namer = namer + def __init__(self, context): + super(ForLoopCanonicalizationTransformer, self).__init__(context) def visit_For(self, node): self.generic_visit(node) - body_scope = anno.getanno(node, 'body_scope') - - # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)` - # Or maybe we should replace range with tf.range? + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) if anno.hasanno(node, 'extra_cond'): - - def template(loop_iter, target, body, i, n, extra_cond): # pylint:disable=unused-argument + template = """ i = 0 - n = len(loop_iter) # pylint:disable=undefined-variable + n = len(loop_iter) while i < n and extra_cond: # TODO(mdan): Use TensorListFromTensor(loop_iter) here. target = loop_iter[i] - body # pylint:disable=pointless-statement + body i += 1 - + """ return templates.replace( template, loop_iter=node.iter, target=node.target, body=node.body, - i=gast.Name( - self.namer.new_symbol('i', body_scope.referenced), None, None), - n=gast.Name( - self.namer.new_symbol('n', body_scope.referenced), None, None), + i=self.context.namer.new_symbol('i', body_scope.referenced), + n=self.context.namer.new_symbol('n', body_scope.referenced), extra_cond=anno.getanno(node, 'extra_cond')) else: - - def template(loop_iter, target, body, i, n): # pylint:disable=unused-argument + template = """ i = 0 - n = len(loop_iter) # pylint:disable=undefined-variable + n = len(loop_iter) while i < n: # TODO(mdan): Use TensorListFromTensor(loop_iter) here. target = loop_iter[i] body # pylint:disable=pointless-statement i += 1 - - return templates.replace( + """ + repl = templates.replace( template, loop_iter=node.iter, target=node.target, body=node.body, - i=gast.Name( - self.namer.new_symbol('i', body_scope.referenced), None, None), - n=gast.Name( - self.namer.new_symbol('n', body_scope.referenced), None, None)) + i=self.context.namer.new_symbol('i', body_scope.referenced), + n=self.context.namer.new_symbol('n', body_scope.referenced)) + return repl def visit_Continue(self, node): assert False, 'continue statement should be desugared at this point' @@ -90,7 +82,5 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer): assert False, 'break statement should be desugared at this point' -def transform(node, namer): - transformer = ForLoopCanonicalizationTransformer(namer) - node = transformer.visit(node) - return node +def transform(node, context): + return ForLoopCanonicalizationTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/for_loops_test.py similarity index 67% rename from tensorflow/contrib/py2tf/converters/for_canonicalization_test.py rename to tensorflow/contrib/py2tf/converters/for_loops_test.py index a6e6350fd45e9c9575af9c12d3d0c4e9b89bee41..70a367d3b517e528b67f260d607431d324d2ab7d 100644 --- a/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/for_loops_test.py @@ -12,25 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for for_canonicalization module.""" +"""Tests for for_loops module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import control_flow from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import for_canonicalization -from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.contrib.py2tf.converters import for_loops from tensorflow.python.platform import test -class TestNamer(control_flow.SymbolNamer): - - def new_symbol(self, name_root, _): - return name_root - - class ControlFlowTest(converter_test_base.TestCase): def test_basic_for(self): @@ -42,13 +34,13 @@ class ControlFlowTest(converter_test_base.TestCase): return s node = self.parse_and_analyze(test_fn, {}) - node = for_canonicalization.transform(node, TestNamer()) - result = compiler.ast_to_object(node) + node = for_loops.transform(node, self.ctx) - l = [1, 2, 3] - self.assertEqual(test_fn(l), result.test_fn(l)) - l = [] - self.assertEqual(test_fn(l), result.test_fn(l)) + with self.compiled(node) as result: + l = [1, 2, 3] + self.assertEqual(test_fn(l), result.test_fn(l)) + l = [] + self.assertEqual(test_fn(l), result.test_fn(l)) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension.py b/tensorflow/contrib/py2tf/converters/list_comprehension.py new file mode 100644 index 0000000000000000000000000000000000000000..e8744831100e4852919b5cd1253b74acea4d790d --- /dev/null +++ b/tensorflow/contrib/py2tf/converters/list_comprehension.py @@ -0,0 +1,80 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Canonicalizing list comprehensions into for and if statements. + +e.g. +result = [x * x for x in xs] + +becomes + +result = [] +for x in xs: + elt = x * x + result.append(elt) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer + + +class ListCompCanonicalizationTransformer(transformer.Base): + """NodeTransformer to canonicalize list comprehensions.""" + + def __init__(self, context): + super(ListCompCanonicalizationTransformer, self).__init__(context) + + def make_update_list_node(self, list_, elt): + return templates.replace('list_.append(elt)', list_=list_, elt=elt)[0] + + def instantiate_list_node(self): + return parser.parse_str('[]').body[0].value + + def visit_Assign(self, node): + if not isinstance(node.value, gast.ListComp): + return node + if len(node.targets) > 1: + raise ValueError('Only support single assignment.') + return self.canonicalize_listcomp(node.targets[0], node.value) + + def canonicalize_listcomp(self, result_node, list_comp_node): + + make_list = templates.replace( + 'list_ = create_list', + list_=result_node, + create_list=self.instantiate_list_node()) + loop_body = self.make_update_list_node(result_node, list_comp_node.elt) + + for gen in reversed(list_comp_node.generators): + for gen_if in reversed(gen.ifs): + loop_body = templates.replace( + 'if test: loop_body', test=gen_if, loop_body=loop_body) + loop_body = templates.replace( + 'for target in iter_: loop_body', + iter_=gen.iter, + target=gen.target, + loop_body=loop_body) + + return make_list + loop_body + + +def transform(node, context): + return ListCompCanonicalizationTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension_test.py b/tensorflow/contrib/py2tf/converters/list_comprehension_test.py new file mode 100644 index 0000000000000000000000000000000000000000..025fac11e41e6771fbb9b80ff3da70dc3ceec73e --- /dev/null +++ b/tensorflow/contrib/py2tf/converters/list_comprehension_test.py @@ -0,0 +1,75 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for list_comprehension module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.py2tf.converters import list_comprehension +from tensorflow.python.platform import test + + +class ListCompTest(converter_test_base.TestCase): + + def test_basic(self): + + def test_fn(l): + s = [e * e for e in l] + return s + + node = self.parse_and_analyze(test_fn, {}) + node = list_comprehension.transform(node, self.ctx) + + with self.compiled(node) as result: + l = [1, 2, 3] + self.assertEqual(test_fn(l), result.test_fn(l)) + l = [] + self.assertEqual(test_fn(l), result.test_fn(l)) + + def test_multiple_generators(self): + + def test_fn(l): + s = [e * e for sublist in l for e in sublist] + return s + + node = self.parse_and_analyze(test_fn, {}) + node = list_comprehension.transform(node, self.ctx) + + with self.compiled(node) as result: + l = [[1], [2], [3]] + self.assertEqual(test_fn(l), result.test_fn(l)) + l = [] + self.assertEqual(test_fn(l), result.test_fn(l)) + + def test_conds(self): + + def test_fn(l): + s = [e * e for e in l if e > 1] + return s + + node = self.parse_and_analyze(test_fn, {}) + node = list_comprehension.transform(node, self.ctx) + + with self.compiled(node) as result: + l = [1, 2, 3] + self.assertEqual(test_fn(l), result.test_fn(l)) + l = [] + self.assertEqual(test_fn(l), result.test_fn(l)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/logical_expressions_test.py b/tensorflow/contrib/py2tf/converters/logical_expressions_test.py index d711065099b24ad814104e6460e6ca551b31b3e6..a28326c517d468230f35e45f0fbfe5257d769895 100644 --- a/tensorflow/contrib/py2tf/converters/logical_expressions_test.py +++ b/tensorflow/contrib/py2tf/converters/logical_expressions_test.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.contrib.py2tf.converters import logical_expressions -from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -34,12 +33,11 @@ class GradientsFunctionTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = logical_expressions.transform(node) - result = compiler.ast_to_object(node) - setattr(result, 'tf', math_ops) - with self.test_session() as sess: - self.assertTrue(sess.run(result.test_fn(1, 1))) - self.assertFalse(sess.run(result.test_fn(1, 2))) + with self.compiled(node, math_ops.equal) as result: + with self.test_session() as sess: + self.assertTrue(sess.run(result.test_fn(1, 1))) + self.assertFalse(sess.run(result.test_fn(1, 2))) def test_bool_ops(self): @@ -48,11 +46,11 @@ class GradientsFunctionTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = logical_expressions.transform(node) - result = compiler.ast_to_object(node) - setattr(result, 'tf', math_ops) - with self.test_session() as sess: - self.assertTrue(sess.run(result.test_fn(True, False, True))) + with self.compiled(node, math_ops.logical_or, + math_ops.logical_and) as result: + with self.test_session() as sess: + self.assertTrue(sess.run(result.test_fn(True, False, True))) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards.py b/tensorflow/contrib/py2tf/converters/side_effect_guards.py index 1f25303fbac1184d016a63d629ba2ecf17d7e426..30976b3ec6db5a6607023ac804d9d54cfb296190 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards.py +++ b/tensorflow/contrib/py2tf/converters/side_effect_guards.py @@ -37,7 +37,11 @@ from __future__ import print_function import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import ast_util +from tensorflow.contrib.py2tf.pyct import qual_names from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -55,53 +59,61 @@ class SymbolNamer(object): raise NotImplementedError() -class SideEffectGuardTransformer(gast.NodeTransformer): +class SideEffectGuardTransformer(transformer.Base): """Adds control dependencies to functions with side effects.""" - def __init__(self, namer): - self.namer = namer - self.indent_next = False - self.next_indent_owner = None + def __init__(self, context): + super(SideEffectGuardTransformer, self).__init__(context) # pylint:disable=invalid-name def _visit_and_reindent(self, nodes): new_nodes = [] current_dest = new_nodes + alias_map = {} + reindent_requested = False for n in nodes: n = self.visit(n) + # NOTE: the order in which these statements execute is important; in + # particular, watch out for ending up with cycles in the AST. + if alias_map: + n = ast_util.rename_symbols(n, alias_map) if isinstance(n, (list, tuple)): current_dest.extend(n) else: current_dest.append(n) - if self.indent_next: - assert self.next_indent_owner is not None - current_dest.append(self.next_indent_owner) - current_dest = self.next_indent_owner.body - self.next_indent_owner = None - self.indent_next = False - if not current_dest: + if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER): + reindent_requested = True + new_dest, new_alias_map = anno.getanno( + n, anno.Basic.INDENT_BLOCK_REMAINDER) + anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER) + new_alias_map.update(alias_map) + alias_map = new_alias_map + current_dest = new_dest + if reindent_requested and not current_dest: # TODO(mdan): There may still be something that could be done. raise ValueError('Unable to insert statement into the computation flow: ' - 'it is not followed by any computation that can we can ' - 'condition on the statement.') + 'it is not followed by any computation which ' + 'the statement could gate.') return new_nodes def visit_FunctionDef(self, node): - if anno.hasanno(node, 'skip_processing'): - return node node.body = self._visit_and_reindent(node.body) return node - def _gate_symbols(self, guard_statement, guarded_args): + def visit_With(self, node): + node.body = self._visit_and_reindent(node.body) + return node - def template(args): # pylint:disable=unused-argument - (args,) = (tf.identity(a) for a in (args,)) # pylint:disable=undefined-variable + def visit_If(self, node): + node.body = self._visit_and_reindent(node.body) + node.orelse = self._visit_and_reindent(node.orelse) + return node - guards = templates.replace( - template, args=tuple(gast.Name(a, None, None) for a in guarded_args)) - guard_statement.body.extend(guards) - return guard_statement + def visit_While(self, node): + node.body = self._visit_and_reindent(node.body) + node.orelse = self._visit_and_reindent(node.orelse) + return node def visit_Expr(self, node): self.generic_visit(node) @@ -111,49 +123,68 @@ class SideEffectGuardTransformer(gast.NodeTransformer): # or: # tf.py_func(...) - args_scope = anno.getanno(node.value, 'args_scope') - temp_name = self.namer.new_symbol('temp', args_scope.parent.referenced) - # TODO(mdan): Unsafe reference modification! - args_scope.mark_write(temp_name) - - def template(call, temp_result): - temp_result = call - if temp_result is not None: - if not isinstance(temp_result, (list, tuple)): - temp_result = (temp_result,) - ctx = tf.control_dependencies(temp_result) # pylint:disable=undefined-variable - else: - ctx = contextmanager(lambda: (yield))() # pylint:disable=undefined-variable - with ctx: - # TODO(mdan): Also insert ops to re-fetch if variables are involved. - pass # Will be removed below. - - # TODO(mdan): This is brittle. Reorganize this mechanism. - statements = templates.replace( - template, - call=node.value, - temp_result=gast.Name(temp_name, None, None)) - control_deps_guard = statements[-1] - control_deps_guard.body = [] - # First, attempt to gate future evaluation of args. If that's not # possible, gate all remaining statements (and that may fail too, see # _visit_and_reindent. - guarded_args = tuple( - n for n in args_scope.used if n in args_scope.parent.modified) + args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE) + # NOTE: We can't guard object attributes because they may not be writable. + # In addition, avoid renaming well-known names. + # TODO(mdan): Move these names into config. + unguarded_names = (qual_names.QN('self'), qual_names.QN('tf')) + guarded_args = tuple(s for s in args_scope.used + if not s.is_composite() and s not in unguarded_names) + + # TODO(mdan): Include all arguments which depended on guarded_args too. + # For example, the following will still cause a race: + # tf.assign(a, a + 1) + # b = a + 1 + # tf.assign(a, a + 1) # Control deps here should include `b` + # c = b + 1 + # Or maybe we should just raise an "unsafe assign" error? + if guarded_args: - node = tuple(statements[:-1]) + ( - self._gate_symbols(control_deps_guard, guarded_args),) + # The aliases may need new names to avoid incorrectly making them local. + # TODO(mdan): This is brutal. It will even rename modules - any fix? + need_alias = tuple( + s for s in guarded_args if s not in args_scope.parent.modified) + aliased_new_names = tuple( + qual_names.QN( + self.context.namer.new_symbol( + s.ssf(), args_scope.parent.referenced)) for s in need_alias) + alias_map = dict(zip(need_alias, aliased_new_names)) + if len(guarded_args) == 1: + s, = guarded_args + aliased_guarded_args = alias_map.get(s, s) + else: + aliased_guarded_args = gast.Tuple( + [alias_map.get(s, s).ast() for s in guarded_args], None) + + template = """ + with py2tf_utils.control_dependency_on_returns(call): + aliased_guarded_args = py2tf_utils.alias_tensors(guarded_args) + """ + control_deps_guard = templates.replace( + template, + call=node.value, + aliased_guarded_args=aliased_guarded_args, + guarded_args=guarded_args)[-1] else: - node = tuple(statements[:-1]) - # The mechanism will insert the guard statement later. - self.indent_next = True - self.next_indent_owner = control_deps_guard + alias_map = {} + + template = """ + with py2tf_utils.control_dependency_on_returns(call): + pass + """ + control_deps_guard = templates.replace(template, call=node.value)[-1] + control_deps_guard.body = [] + + node = control_deps_guard + anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER, + (node.body, alias_map)) return node # pylint:enable=invalid-name -def transform(node, namer): - transformer = SideEffectGuardTransformer(namer) - return transformer.visit(node) +def transform(node, context): + return SideEffectGuardTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py b/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py index 5c56973dc2ae5d1976a68f040772e856cdaeabf5..463db2e770213ba9636d2537b095a77dece5d8f6 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py @@ -20,41 +20,145 @@ from __future__ import print_function from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.contrib.py2tf.converters import side_effect_guards -from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -class TestNamer(side_effect_guards.SymbolNamer): +class SideEffectGuardsTest(converter_test_base.TestCase): - def new_symbol(self, name_root, _): - return name_root + def test_side_effect_on_return_only_variable(self): + tf = None -class SideEffectGuardsTest(converter_test_base.TestCase): + def test_fn(a): + tf.assign(a, a + 1) + return a - def test_transform(self): + node = self.parse_and_analyze(test_fn, {}) + node = side_effect_guards.transform(node, self.ctx) + + with self.compiled(node, state_ops.assign) as result: + self.assertEqual(len(node.body[0].body), 1) + with self.test_session() as sess: + v = variables.Variable(2) + sess.run(v.initializer) + # NOTE: We don't expect the assignment to execute in this case, because + # variables cannot be reliably guarded. + self.assertEqual(2, sess.run(result.test_fn(v))) + + def test_side_effect_on_used_variable(self): + + tf = None def test_fn(a): - state_ops.assign(a, a + 1) + tf.assign(a, a + 1) + return a + 1 + + node = self.parse_and_analyze(test_fn, {}) + node = side_effect_guards.transform(node, self.ctx) + + with self.compiled(node, state_ops.assign) as result: + self.assertEqual(len(node.body[0].body), 1) + with self.test_session() as sess: + v = variables.Variable(2) + sess.run(v.initializer) + # NOTE: Unlike test_side_effect_on_return_only_variable, the variable + # was used in the local scope and so we could catch the assign's side + # effect. + self.assertEqual(4, sess.run(result.test_fn(v))) + + def test_side_effect_on_tensor(self): + + tf = None + + def test_fn(a): + tf.Assert(a > 0, ['expected in throw']) return a - node = self.parse_and_analyze(test_fn, {'state_ops': state_ops}) - node = side_effect_guards.transform(node, TestNamer()) - result = compiler.ast_to_object(node) - setattr(result, 'state_ops', state_ops) + node = self.parse_and_analyze(test_fn, {}) + node = side_effect_guards.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.Assert) as result: + self.assertEqual(len(node.body[0].body), 1) + with self.test_session() as sess: + # NOTE: In this case we can also capture the side effect because the + # argument is a tensor ans we can wrap it inside an identity. + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + 'expected in throw'): + sess.run(result.test_fn(constant_op.constant(-1))) + + def test_multiline_block(self): + + tf = None + + def test_fn(a): + tf.assign(a, a + 1) + b = a + 1 + tf.assign(a, b + 1) + c = b + 1 + d = c + 1 + return d + + node = self.parse_and_analyze(test_fn, {}) + node = side_effect_guards.transform(node, self.ctx) + + with self.compiled(node, state_ops.assign) as result: + self.assertEqual(len(node.body[0].body), 1) + with self.test_session() as sess: + v = variables.Variable(2) + sess.run(v.initializer) + self.assertEqual(6, sess.run(result.test_fn(v))) + + def test_multiline_nested_block(self): + + tf = None + + def test_fn(a): + with tf.name_scope('foo'): + tf.assign(a, a + 1) + b = a + 1 + c = b + 1 + d = c + 1 + return d + + node = self.parse_and_analyze(test_fn, {}) + node = side_effect_guards.transform(node, self.ctx) + + with self.compiled(node, state_ops.assign, ops.name_scope) as result: + self.assertEqual(len(node.body[0].body[0].body), 1) + with self.test_session() as sess: + v = variables.Variable(2) + sess.run(v.initializer) + self.assertEqual(6, sess.run(result.test_fn(v))) + + def test_multiline_block_unsafe(self): + + tf = None + + def test_fn(a): + tf.assign(a, a + 1) + b = a + 1 + tf.assign(a, a + 1) + c = b + 1 + d = c + 1 + return d - # TODO(mdan): Configure the namespaces instead of doing these hacks. - ops.identity = array_ops.identity - setattr(result, 'tf', ops) + node = self.parse_and_analyze(test_fn, {}) + node = side_effect_guards.transform(node, self.ctx) - with self.test_session() as sess: - v = variables.Variable(2) - sess.run(v.initializer) - self.assertEqual(3, sess.run(result.test_fn(v))) + with self.compiled(node, state_ops.assign) as result: + self.assertEqual(len(node.body[0].body), 1) + with self.test_session() as sess: + v = variables.Variable(2) + sess.run(v.initializer) + # NOTE: This intentionally highlights the flakiness. The test should be + # tightened down once that is solved. + self.assertTrue(sess.run(result.test_fn(v)) in (6, 7)) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/impl/BUILD b/tensorflow/contrib/py2tf/impl/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..90ffabbc9bf4524ec2ebf54b6dd847bd8768a486 --- /dev/null +++ b/tensorflow/contrib/py2tf/impl/BUILD @@ -0,0 +1,67 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "impl", + srcs = [ + "api.py", + "config.py", + "conversion.py", + "naming.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/py2tf/converters", + "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/py2tf/pyct/static_analysis", + "//tensorflow/contrib/py2tf/utils", + "@gast_archive//:gast", + "@six_archive//:six", + ], +) + +py_test( + name = "api_test", + srcs = ["api_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + "//tensorflow/contrib/py2tf/utils", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "conversion_test", + srcs = ["conversion_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + +py_test( + name = "naming_test", + srcs = ["naming_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/py2tf/api.py b/tensorflow/contrib/py2tf/impl/api.py similarity index 89% rename from tensorflow/contrib/py2tf/api.py rename to tensorflow/contrib/py2tf/impl/api.py index ca1f4e2645ee20fd78c0d837885823d2e199537a..29d2e038a73c8cac89c121ec65c32f0d4f68aff6 100644 --- a/tensorflow/contrib/py2tf/api.py +++ b/tensorflow/contrib/py2tf/impl/api.py @@ -23,10 +23,11 @@ from functools import wraps import gast import six -from tensorflow.contrib.py2tf import config -from tensorflow.contrib.py2tf import conversion +from tensorflow.contrib.py2tf.impl import config +from tensorflow.contrib.py2tf.impl import conversion from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect # TODO(mdan): Properly document the type hints. @@ -83,15 +84,16 @@ def convert_inline(f, *args, **kwargs): return convert(arg_value_hints)(f)(*args, **kwargs) -def convert(recursive=False, arg_types=None): +def convert(recursive=False, verbose=False, arg_types=None): """Decorator that compiles a function to graph mode. - The decorator is dynamic - invoking compilation whenever the decorated function - is called. This means the parameter values are known at compilation. + The decorator is dynamic - invoking compilation whenever the decorated + function is called. This means the parameter values are known at compilation. Args: recursive: Whether to recusrively convert any functions that the decorator function may call. + verbose: Whether to output the compiled code in the logs. arg_types: See to_graph. Returns: @@ -125,6 +127,7 @@ def convert(recursive=False, arg_types=None): wrapped = to_graph( f, recursive=recursive, + verbose=verbose, arg_values=arg_values, arg_types=arg_types, partial_types=partial_types) @@ -140,6 +143,7 @@ def convert(recursive=False, arg_types=None): def to_graph(e, recursive=True, + verbose=False, arg_values=None, arg_types=None, partial_types=None): @@ -155,6 +159,7 @@ def to_graph(e, e: A Python entity. recursive: Whether to recusrively convert any functions that the decorator function may call. + verbose: Whether to output the compiled code in the logs. arg_values: A dict containing value hints for symbols like function parameters. arg_types: A dict containing type hints for symbols like function @@ -170,7 +175,8 @@ def to_graph(e, conversion_map = conversion.ConversionMap( recursive=recursive, nocompile_decorators=(convert, graph_ready, convert_inline), - partial_types=partial_types) + partial_types=partial_types, + api_module=tf_inspect.getmodule(to_graph)) _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) module = gast.Module([]) @@ -178,14 +184,17 @@ def to_graph(e, module.body.append(parser.parse_str(import_line)) for dep in conversion_map.dependency_cache.values(): module.body.append(dep) - compiled_node = compiler.ast_to_object(module) + compiled_node, compiled_src = compiler.ast_to_object(module) # The compiled code should see everything the entry function saw. # TODO(mdan): This might not work well if the call tree spans modules? if tf_inspect.isfunction(e): compiled_node.__dict__.update(six.get_function_globals(e)) - compiled_fn = getattr(compiled_node, name) + + if verbose: + logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) + return compiled_fn @@ -213,7 +222,8 @@ def to_code(e, conversion_map = conversion.ConversionMap( recursive=recursive, nocompile_decorators=(convert, graph_ready, convert_inline), - partial_types=partial_types) + partial_types=partial_types, + api_module=tf_inspect.getmodule(to_graph)) conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS) diff --git a/tensorflow/contrib/py2tf/api_test.py b/tensorflow/contrib/py2tf/impl/api_test.py similarity index 94% rename from tensorflow/contrib/py2tf/api_test.py rename to tensorflow/contrib/py2tf/impl/api_test.py index 2384447708d7e0ab5dbfbeb592a47353f1909f50..02cd8ed2d0ffee8ef2d31ea65902d2b493df9d64 100644 --- a/tensorflow/contrib/py2tf/api_test.py +++ b/tensorflow/contrib/py2tf/impl/api_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import api -from tensorflow.contrib.py2tf import config +from tensorflow.contrib.py2tf.impl import api +from tensorflow.contrib.py2tf.impl import config from tensorflow.contrib.py2tf.pyct import parser from tensorflow.python.framework import constant_op from tensorflow.python.ops import math_ops @@ -32,7 +32,9 @@ class ApiTest(test.TestCase): config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) config.COMPILED_IMPORT_STATEMENTS = ( 'from tensorflow.python.ops ' - 'import control_flow_ops as tf',) + 'import control_flow_ops as tf', + 'from tensorflow.contrib.py2tf import utils as ' + 'py2tf_utils') def test_decorator_recurses(self): @@ -183,7 +185,7 @@ class ApiTest(test.TestCase): compiled_code = api.to_code(test_fn) # Just check for some key words and that it is parseable Python code. - self.assertRegexpMatches(compiled_code, 'tf\\.while_loop') + self.assertRegexpMatches(compiled_code, 'py2tf_utils\\.run_while') self.assertIsNotNone(parser.parse_str(compiled_code)) diff --git a/tensorflow/contrib/py2tf/config.py b/tensorflow/contrib/py2tf/impl/config.py similarity index 78% rename from tensorflow/contrib/py2tf/config.py rename to tensorflow/contrib/py2tf/impl/config.py index 8c502a7a9e546dd9b9b40d7cf6d3c9821038afb3..c90e85c96b690b7781358b173e5d83fe60e29c00 100644 --- a/tensorflow/contrib/py2tf/config.py +++ b/tensorflow/contrib/py2tf/impl/config.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.py2tf import utils + + PYTHON_LITERALS = { 'None': None, 'False': False, @@ -27,12 +30,17 @@ PYTHON_LITERALS = { DEFAULT_UNCOMPILED_MODULES = set(( ('tensorflow',), + (utils.__name__,), )) NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',)) # TODO(mdan): Also allow controlling the generated names (for testability). +# TODO(mdan): Make sure copybara renames the reference below. COMPILED_IMPORT_STATEMENTS = ( - 'from contextlib import contextmanager', + 'from __future__ import print_function', 'import tensorflow as tf', -) + 'from tensorflow.contrib.py2tf.impl import api as ' + 'py2tf_api', + 'from tensorflow.contrib.py2tf import utils as ' + 'py2tf_utils') diff --git a/tensorflow/contrib/py2tf/conversion.py b/tensorflow/contrib/py2tf/impl/conversion.py similarity index 73% rename from tensorflow/contrib/py2tf/conversion.py rename to tensorflow/contrib/py2tf/impl/conversion.py index b484eebbd58b955d1e783359269d16101d83cfd2..7610f0427be45832dcc12e8f32b65292254eadfd 100644 --- a/tensorflow/contrib/py2tf/conversion.py +++ b/tensorflow/contrib/py2tf/impl/conversion.py @@ -21,21 +21,23 @@ from __future__ import print_function import gast import six -from tensorflow.contrib.py2tf import config -from tensorflow.contrib.py2tf import naming -from tensorflow.contrib.py2tf.converters import break_canonicalization +from tensorflow.contrib.py2tf import utils +from tensorflow.contrib.py2tf.converters import asserts +from tensorflow.contrib.py2tf.converters import break_statements from tensorflow.contrib.py2tf.converters import builtin_functions from tensorflow.contrib.py2tf.converters import call_trees -from tensorflow.contrib.py2tf.converters import continue_canonicalization +from tensorflow.contrib.py2tf.converters import continue_statements from tensorflow.contrib.py2tf.converters import control_flow from tensorflow.contrib.py2tf.converters import decorators -from tensorflow.contrib.py2tf.converters import for_canonicalization +from tensorflow.contrib.py2tf.converters import for_loops from tensorflow.contrib.py2tf.converters import logical_expressions -from tensorflow.contrib.py2tf.converters import print_functions from tensorflow.contrib.py2tf.converters import side_effect_guards +from tensorflow.contrib.py2tf.impl import config +from tensorflow.contrib.py2tf.impl import naming from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.py2tf.pyct.static_analysis import activity from tensorflow.contrib.py2tf.pyct.static_analysis import live_values from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.util import tf_inspect @@ -56,16 +58,20 @@ class ConversionMap(object): converted AST name_map: dict[string]: string; maps original entities to the name of their converted counterparts + api_module: A reference to the api module. The reference needs to be passed + to avoid circular dependencies. """ # TODO(mdan): Rename to ConversionContext, and pull in additional flags. - def __init__(self, recursive, nocompile_decorators, partial_types): + def __init__(self, recursive, nocompile_decorators, partial_types, + api_module): self.recursive = recursive self.nocompile_decorators = nocompile_decorators self.partial_types = partial_types if partial_types else () self.dependency_cache = {} self.name_map = {} + self.api_module = api_module def new_namer(self, namespace): return naming.Namer(namespace, self.recursive, self.name_map, @@ -168,10 +174,29 @@ def class_to_graph(c, conversion_map): return node, class_name +def _add_self_references(namespace, api_module): + """Self refs are only required for analysis and are not used directly.""" + # Manually add the utils namespace which may be used from generated code. + if 'py2tf_util' not in namespace: + namespace['py2tf_utils'] = utils + elif namespace['py2tf_utils'] != utils: + raise ValueError( + 'The module name "py2tf_utils" is reserved and may not be used.') + + # We also make reference to the api module for dynamic conversion, but + # to avoid circular references we don't import it here. + if 'py2tf_api' not in namespace: + namespace['py2tf_api'] = api_module + elif namespace['py2tf_api'] != api_module: + raise ValueError( + 'The module name "py2tf_api" is reserved and may not be used.') + + def function_to_graph(f, conversion_map, arg_values, arg_types, owner_type=None): """Specialization of `entity_to_graph` for callable functions.""" - node = parser.parse_object(f).body[0] + node, source = parser.parse_entity(f) + node = node.body[0] namespace = six.get_function_globals(f) # This is needed for non-global functions. @@ -182,31 +207,35 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, fn = e.cell_contents namespace[fn.__name__] = fn + _add_self_references(namespace, conversion_map.api_module) + namer = conversion_map.new_namer(namespace) ctx = context.EntityContext( namer=namer, - source_code=tf_inspect.getsource(f), - source_file=tf_inspect.getfile(f), + source_code=source, + source_file='', namespace=namespace, arg_values=arg_values, - arg_types=arg_types) + arg_types=arg_types, + recursive=conversion_map.recursive) node = node_to_graph(node, ctx, conversion_map.nocompile_decorators) - # Simulate a rename to ensure the top level is in the name map. This is needed - # for top level functions, and it also helps the consistency verification made - # by update_name_map. - if owner_type is not None: - new_name = namer.compiled_function_name(f.__name__, f, owner_type) - else: - new_name = namer.compiled_function_name(f.__name__, f) + # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py + new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) + if not did_rename: + new_name = f.__name__ + if node.name != f.__name__: + raise NotImplementedError('Strange corner case. Send us offending code!') + node.name = new_name conversion_map.update_name_map(namer) - return node, conversion_map.name_map[f] + return node, new_name def _static_analysis_pass(node, ctx): - node = access.resolve(node) - node = live_values.resolve(node, ctx.namespace, config.PYTHON_LITERALS) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx, None) + node = live_values.resolve(node, ctx, config.PYTHON_LITERALS) node = type_info.resolve(node, ctx) return node @@ -230,10 +259,7 @@ def node_to_graph(node, ctx, nocompile_decorators): # TODO(mdan): Factor out common elements. # These include: - # * keeping track of symbols that have been created - # * marking nodes (e.g. py_func wrappers) to suppress further processing # * code move between blocks - # * insertion of new global references # * visiting blocks in transformers # Certain steps, especially canonicalization, insert new symbols into the @@ -241,29 +267,33 @@ def node_to_graph(node, ctx, nocompile_decorators): # to re-run the analysis. node = _static_analysis_pass(node, ctx) + # Past this point, line numbers are no longer accurate so we ignore the + # source. + # TODO(mdan): Is it feasible to reconstruct intermediate source code? + ctx.source_code = None node = decorators.transform(node, nocompile_decorators) - node = break_canonicalization.transform(node, ctx.namer) + node = break_statements.transform(node, ctx) + node = asserts.transform(node, ctx) # Note: sequencing continue canonicalization before for loop one avoids # dealing with the extra loop increment operation that the for # canonicalization creates. - node = continue_canonicalization.transform(node, ctx.namer) + node = continue_statements.transform(node, ctx) ctx.namespace['len'] = len node = _static_analysis_pass(node, ctx) - node = for_canonicalization.transform(node, ctx.namer) - # for_canonicalization may insert new global references. - node = builtin_functions.transform(node) - # builtin_functions may insert new global references. - ctx.namespace['print'] = print + node = for_loops.transform(node, ctx) + # for_loops may insert new global references. + node = builtin_functions.transform(node, ctx) node = _static_analysis_pass(node, ctx) - node = print_functions.transform(node) - node = call_trees.transform(node, ctx.namer, ctx.namespace, - config.DEFAULT_UNCOMPILED_MODULES, + node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES, nocompile_decorators) - node = control_flow.transform(node, ctx.namer) + node = control_flow.transform(node, ctx) + + # control_flow may create new symbols and change scopes. + node = _static_analysis_pass(node, ctx) node = logical_expressions.transform(node) - node = side_effect_guards.transform(node, ctx.namer) + node = side_effect_guards.transform(node, ctx) return node diff --git a/tensorflow/contrib/py2tf/conversion_test.py b/tensorflow/contrib/py2tf/impl/conversion_test.py similarity index 88% rename from tensorflow/contrib/py2tf/conversion_test.py rename to tensorflow/contrib/py2tf/impl/conversion_test.py index 26f915f4f46e54c9648ae6b35415c4e2639af774..75e95ed88883ac1d94ef8bc4edc232f97cd0b75b 100644 --- a/tensorflow/contrib/py2tf/conversion_test.py +++ b/tensorflow/contrib/py2tf/impl/conversion_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf import conversion +from tensorflow.contrib.py2tf.impl import conversion from tensorflow.python.platform import test @@ -28,7 +28,7 @@ class ConversionTest(test.TestCase): def test_entity_to_graph_unsupported_types(self): with self.assertRaises(ValueError): - conversion_map = conversion.ConversionMap(True, (), ()) + conversion_map = conversion.ConversionMap(True, (), (), None) conversion.entity_to_graph('dummy', conversion_map, None, None) def test_entity_to_graph_callable(self): @@ -36,7 +36,7 @@ class ConversionTest(test.TestCase): def f(a): return a - conversion_map = conversion.ConversionMap(True, (), ()) + conversion_map = conversion.ConversionMap(True, (), (), None) ast, new_name = conversion.entity_to_graph(f, conversion_map, None, None) self.assertTrue(isinstance(ast, gast.FunctionDef), ast) self.assertEqual('tf__f', new_name) @@ -49,7 +49,7 @@ class ConversionTest(test.TestCase): def f(a): return g(a) - conversion_map = conversion.ConversionMap(True, (), ()) + conversion_map = conversion.ConversionMap(True, (), (), None) conversion.entity_to_graph(f, conversion_map, None, None) self.assertTrue(f in conversion_map.dependency_cache) diff --git a/tensorflow/contrib/py2tf/naming.py b/tensorflow/contrib/py2tf/impl/naming.py similarity index 55% rename from tensorflow/contrib/py2tf/naming.py rename to tensorflow/contrib/py2tf/impl/naming.py index a90758962b83e1616f7d727440eb7481c49343ad..51326091de13715c32d0a79279f1d3274e48ad10 100644 --- a/tensorflow/contrib/py2tf/naming.py +++ b/tensorflow/contrib/py2tf/impl/naming.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.util import tf_inspect +from tensorflow.contrib.py2tf.pyct import qual_names class Namer(object): @@ -45,10 +45,15 @@ class Namer(object): self.generated_names = set() - def compiled_class_name(self, original_name, live_object=None): + def compiled_class_name(self, original_fqn, live_entity=None): """See call_trees.FunctionNamer.compiled_class_name.""" - if live_object is not None and live_object in self.renamed_calls: - return self.renamed_calls[live_object] + if live_entity is not None and live_entity in self.renamed_calls: + return self.renamed_calls[live_entity] + + if isinstance(original_fqn, tuple): + original_name = '__'.join(original_fqn) + else: + original_name = original_fqn new_name_root = 'Tf%s' % original_name new_name = new_name_root @@ -57,49 +62,69 @@ class Namer(object): n += 1 new_name = '%s_%d' % (new_name_root, n) - if live_object is not None: - self.renamed_calls[live_object] = new_name + if live_entity is not None: + self.renamed_calls[live_entity] = new_name self.generated_names.add(new_name) + if live_entity is not None: + self.renamed_calls[live_entity] = new_name return new_name def compiled_function_name(self, - original_name, - live_object=None, + original_fqn, + live_entity=None, owner_type=None): """See call_trees.FunctionNamer.compiled_function_name.""" - if live_object is not None and live_object in self.renamed_calls: - return self.renamed_calls[live_object] if not self.recursive: - new_name = original_name - elif owner_type is None or owner_type in self.partial_types: - # Top level functions: rename - new_name_root = 'tf__%s' % original_name - new_name = new_name_root - n = 0 - while new_name in self.global_namespace: - n += 1 - new_name = '%s_%d' % (new_name_root, n) + return None, False + + if owner_type is not None and owner_type not in self.partial_types: + # Members are not renamed when part of an entire converted class. + return None, False + + if isinstance(original_fqn, tuple): + original_name = '__'.join(original_fqn) else: - if tf_inspect.isclass(owner_type): - # Class members: do not rename (the entire class will be renamed) - new_name = original_name - else: - raise NotImplementedError('Member function "%s" of non-class type: %s' % - (original_name, owner_type)) + original_name = original_fqn - if live_object is not None: - self.renamed_calls[live_object] = new_name + if live_entity is not None and live_entity in self.renamed_calls: + return self.renamed_calls[live_entity], True + + new_name_root = 'tf__%s' % original_name + new_name = new_name_root + n = 0 + while new_name in self.global_namespace: + n += 1 + new_name = '%s_%d' % (new_name_root, n) + + if live_entity is not None: + self.renamed_calls[live_entity] = new_name self.generated_names.add(new_name) - return new_name + + return new_name, True def new_symbol(self, name_root, reserved_locals): """See control_flow.SymbolNamer.new_symbol.""" + # reserved_locals may contain QNs. + all_reserved_locals = set() + for s in reserved_locals: + if isinstance(s, qual_names.QN): + all_reserved_locals.update(s.qn) + elif isinstance(s, str): + all_reserved_locals.add(s) + else: + raise ValueError('Unexpected symbol type "%s"' % type(s)) + + pieces = name_root.split('_') + if pieces[-1].isdigit(): + name_root = '_'.join(pieces[:-1]) + n = int(pieces[-1]) + else: + n = 0 new_name = name_root - n = 0 - while (new_name in self.global_namespace - or new_name in reserved_locals - or new_name in self.generated_names): + + while (new_name in self.global_namespace or + new_name in all_reserved_locals or new_name in self.generated_names): n += 1 new_name = '%s_%d' % (name_root, n) diff --git a/tensorflow/contrib/py2tf/naming_test.py b/tensorflow/contrib/py2tf/impl/naming_test.py similarity index 82% rename from tensorflow/contrib/py2tf/naming_test.py rename to tensorflow/contrib/py2tf/impl/naming_test.py index 7bfc9b8733b6efc3ab440ae5a0614258ae395ad4..beb4e54937bbb91b19157c9b9e3c528353206c62 100644 --- a/tensorflow/contrib/py2tf/naming_test.py +++ b/tensorflow/contrib/py2tf/impl/naming_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import naming +from tensorflow.contrib.py2tf.impl import naming from tensorflow.python.platform import test @@ -29,8 +29,9 @@ class NamerTest(test.TestCase): pass namer = naming.Namer({}, True, None, ()) - self.assertEqual('tf__foo', namer.compiled_function_name('foo')) - self.assertEqual('tf__bar', namer.compiled_function_name('bar', bar)) + self.assertEqual(('tf__foo', True), namer.compiled_function_name('foo')) + self.assertEqual(('tf__bar', True), namer.compiled_function_name( + 'bar', bar)) self.assertEqual({bar: 'tf__bar'}, namer.renamed_calls) self.assertItemsEqual(('tf__bar', 'tf__foo'), namer.generated_names) @@ -39,15 +40,18 @@ class NamerTest(test.TestCase): pass namer = naming.Namer({}, True, None, ()) - self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo)) - self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo)) + self.assertEqual(('tf__foo', True), namer.compiled_function_name( + 'foo', foo)) + self.assertEqual(('tf__foo', True), namer.compiled_function_name( + 'foo', foo)) def test_compiled_function_name_avoids_global_conflicts(self): def foo(): pass namer = naming.Namer({'tf__foo': 1}, True, None, ()) - self.assertEqual('tf__foo_1', namer.compiled_function_name('foo', foo)) + self.assertEqual(('tf__foo_1', True), + namer.compiled_function_name('foo', foo)) def test_new_symbol_tracks_names(self): namer = naming.Namer({}, True, None, ()) diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/py2tf/pyct/BUILD index e0331dbc97c688ed34be362426b0b1f0d25931bc..e3c0da4b10f9ffbee1b2a906b64d4762f41d97b4 100644 --- a/tensorflow/contrib/py2tf/pyct/BUILD +++ b/tensorflow/contrib/py2tf/pyct/BUILD @@ -1,5 +1,7 @@ licenses(["notice"]) # Apache 2.0 +exports_files(["LICENSE"]) + load("//tensorflow:tensorflow.bzl", "py_test") filegroup( @@ -19,10 +21,12 @@ py_library( srcs = [ "__init__.py", "anno.py", + "ast_util.py", "compiler.py", "context.py", "parser.py", "pretty_printer.py", + "qual_names.py", "templates.py", "transformer.py", ], @@ -31,6 +35,7 @@ py_library( deps = [ "@astor_archive//:astor", "@gast_archive//:gast", + "@six_archive//:six", "@termcolor_archive//:termcolor", ], ) @@ -38,15 +43,28 @@ py_library( py_test( name = "anno_test", srcs = ["anno_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "ast_util_test", + srcs = ["ast_util_test.py"], + srcs_version = "PY2AND3", deps = [ ":pyct", "//tensorflow/python:client_testlib", + "@gast_archive//:gast", ], ) py_test( name = "compiler_test", srcs = ["compiler_test.py"], + srcs_version = "PY2AND3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -57,6 +75,7 @@ py_test( py_test( name = "parser_test", srcs = ["parser_test.py"], + srcs_version = "PY2AND3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -66,6 +85,17 @@ py_test( py_test( name = "pretty_printer_test", srcs = ["pretty_printer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "qual_names_test", + srcs = ["qual_names_test.py"], + srcs_version = "PY2AND3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -75,6 +105,7 @@ py_test( py_test( name = "templates_test", srcs = ["templates_test.py"], + srcs_version = "PY2AND3", deps = [ ":pyct", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/py2tf/pyct/anno.py b/tensorflow/contrib/py2tf/pyct/anno.py index 889e4ba4ffaed887faffb8736e4a59502da99e81..7a0528b6d0b65b6604930b7a13d8493af9d61f02 100644 --- a/tensorflow/contrib/py2tf/pyct/anno.py +++ b/tensorflow/contrib/py2tf/pyct/anno.py @@ -21,6 +21,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from enum import Enum + + +class NoValue(Enum): + + def __repr__(self): + return self.name + + +class Basic(NoValue): + """Container for annotation keys. + + The enum values are used strictly for documentation purposes. + """ + + QN = 'Qualified name, as it appeared in the code.' + SKIP_PROCESSING = ( + 'This node should be preserved as is and not processed any further.') + INDENT_BLOCK_REMAINDER = ( + 'When a node is annotated with this, the remainder of the block should ' + 'be indented below it. The annotation contains a tuple ' + '(new_body, name_map), where `new_body` is the new indented block and ' + '`name_map` allows renaming symbols.') + def getanno(node, key, field_name='___pyct_anno'): return getattr(node, field_name)[key] @@ -38,3 +62,11 @@ def setanno(node, key, value, field_name='___pyct_anno'): # So that the annotations survive gast_to_ast() and ast_to_gast() if field_name not in node._fields: node._fields += (field_name,) + + +def delanno(node, key, field_name='___pyct_anno'): + annotations = getattr(node, field_name) + del annotations[key] + if not annotations: + delattr(node, field_name) + node._fields = tuple(f for f in node._fields if f != field_name) diff --git a/tensorflow/contrib/py2tf/pyct/anno_test.py b/tensorflow/contrib/py2tf/pyct/anno_test.py index 19e3b4576210c3715620fc7002c91c5130b46ed0..ff40bfe1f50ae731648afdf509c26c3a70d3f6cb 100644 --- a/tensorflow/contrib/py2tf/pyct/anno_test.py +++ b/tensorflow/contrib/py2tf/pyct/anno_test.py @@ -37,6 +37,11 @@ class AnnoTest(test.TestCase): self.assertTrue(anno.hasanno(node, 'foo')) self.assertEqual(3, anno.getanno(node, 'foo')) + anno.delanno(node, 'foo') + self.assertFalse(anno.hasanno(node, 'foo')) + with self.assertRaises(AttributeError): + anno.getanno(node, 'foo') + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/ast_util.py b/tensorflow/contrib/py2tf/pyct/ast_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f916775b9cf3cec960ec2896c334f1d737862205 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/ast_util.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Copy an AST tree, discarding annotations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ast + +import gast + +from tensorflow.contrib.py2tf.pyct import anno + + +class CleanCopier(gast.NodeVisitor): + """Copy AST nodes. + + The copied nodes will ignore almost all fields that prefixed by '__'. + Exceptions make some annotations. + """ + + # TODO(mdan): Parametrize which annotations get carried over. + + def generic_visit(self, node): + new_fields = {} + for f in node._fields: + if f.startswith('__'): + continue + if not hasattr(node, f): + continue + v = getattr(node, f) + if isinstance(v, list): + v = [self.generic_visit(n) for n in v] + elif isinstance(v, tuple): + v = tuple(self.generic_visit(n) for n in v) + elif isinstance(v, (gast.AST, ast.AST)): + v = self.generic_visit(v) + else: + # Assume everything else is a value type. + pass + new_fields[f] = v + new_node = type(node)(**new_fields) + if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + anno.setanno(new_node, anno.Basic.SKIP_PROCESSING, True) + return new_node + + +def copy_clean(node): + copier = CleanCopier() + if isinstance(node, list): + return [copier.visit(n) for n in node] + elif isinstance(node, tuple): + return tuple(copier.visit(n) for n in node) + else: + return copier.visit(node) + + +class SymbolRenamer(gast.NodeTransformer): + """Transformer that can rename symbols to a simple names.""" + + def __init__(self, name_map): + self.name_map = name_map + + def _process(self, node): + qn = anno.getanno(node, anno.Basic.QN) + if qn in self.name_map: + return gast.Name(str(self.name_map[qn]), node.ctx, None) + return self.generic_visit(node) + + def visit_Name(self, node): + return self._process(node) + + def visit_Attribute(self, node): + return self._process(node) + + +def rename_symbols(node, name_map): + renamer = SymbolRenamer(name_map) + if isinstance(node, list): + return [renamer.visit(n) for n in node] + elif isinstance(node, tuple): + return tuple(renamer.visit(n) for n in node) + return renamer.visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/ast_util_test.py b/tensorflow/contrib/py2tf/pyct/ast_util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b00c178168f96e656c57cc75a76e6da8af1d8a --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/ast_util_test.py @@ -0,0 +1,79 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ast_util module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ast + +from tensorflow.contrib.py2tf.pyct import ast_util +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.python.platform import test + + +class AstUtilTest(test.TestCase): + + def test_rename_symbols(self): + node = ast.Tuple([ + ast.Name('a', ast.Load()), + ast.Name('b', ast.Load()), + ast.Attribute(ast.Name('b', None), 'c', ast.Store()), + ast.Attribute( + ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', + None) + ], None) + node = qual_names.resolve(node) + node = ast_util.rename_symbols( + node, + { + qual_names.QN('a'): qual_names.QN('renamed_a'), + qual_names.QN('b.c'): qual_names.QN('renamed_b_c'), + }) + + self.assertEqual(node.elts[0].id, 'renamed_a') + self.assertTrue(isinstance(node.elts[0].ctx, ast.Load)) + self.assertEqual(node.elts[1].id, 'b') + self.assertEqual(node.elts[2].id, 'renamed_b_c') + self.assertTrue(isinstance(node.elts[2].ctx, ast.Store)) + self.assertEqual(node.elts[3].value.id, 'renamed_b_c') + self.assertTrue(isinstance(node.elts[3].value.ctx, ast.Load)) + + def test_copy_clean(self): + ret = ast.Return( + ast.BinOp( + op=ast.Add(), + left=ast.Name(id='a', ctx=ast.Load()), + right=ast.Num(1))) + setattr(ret, '__foo', 'bar') + node = ast.FunctionDef( + name='f', + args=ast.arguments( + args=[ast.Name(id='a', ctx=ast.Param())], + vararg=None, + kwarg=None, + defaults=[]), + body=[ret], + decorator_list=[], + returns=None) + new_node = ast_util.copy_clean(node) + self.assertFalse(node is new_node) + self.assertFalse(ret is new_node.body[0]) + self.assertFalse(hasattr(new_node.body[0], '__foo')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/compiler.py b/tensorflow/contrib/py2tf/pyct/compiler.py index b09353cc72bd5f9d02a8973ebe880b92d39ac304..51cf6930e8bcb3728ee55bf5d4781f01a5ef73bd 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler.py +++ b/tensorflow/contrib/py2tf/pyct/compiler.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function # TODO(mdan): Use six for compatibility here. +import atexit import imp import os import tempfile @@ -41,7 +42,8 @@ def ast_to_source(node, indentation): return astor.source_repr.pretty_source(generator.result).lstrip() -def ast_to_object(node, indentation=' '): +def ast_to_object( + node, indentation=' ', source_prefix=None, delete_on_exit=True): """Return the Python objects represented by given AST. Compiling the AST code this way ensures that the source code is readable by @@ -50,6 +52,9 @@ def ast_to_object(node, indentation=' '): Args: node: The code to compile, as an AST object. indentation: The string to use for indentation. + source_prefix: Optional string to print as-is into the source file. + delete_on_exit: Whether to delete the temporary file used for compilation + on exit. Returns: A module object containing the compiled source code. @@ -58,5 +63,10 @@ def ast_to_object(node, indentation=' '): with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: module_name = os.path.basename(f.name[:-3]) + if source_prefix: + f.write(source_prefix) + f.write('\n') f.write(source) - return imp.load_source(module_name, f.name) + if delete_on_exit: + atexit.register(lambda: os.remove(f.name)) + return imp.load_source(module_name, f.name), source diff --git a/tensorflow/contrib/py2tf/pyct/compiler_test.py b/tensorflow/contrib/py2tf/pyct/compiler_test.py index e0cde43566310b99bac5035285154fde906fa127..c1f84238efa7dd6fc0748748a2cb4f074572b4c6 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler_test.py +++ b/tensorflow/contrib/py2tf/pyct/compiler_test.py @@ -41,6 +41,7 @@ class CompilerTest(test.TestCase): targets=[gast.Name('a', gast.Store(), None)], value=gast.Str('c')) ]) + self.assertEqual( textwrap.dedent(""" if 1: @@ -70,15 +71,19 @@ class CompilerTest(test.TestCase): decorator_list=[], returns=None) - mod = compiler.ast_to_object(node) + module, source = compiler.ast_to_object(node) - self.assertEqual(2, mod.f(1)) - with open(mod.__file__, 'r') as temp_output: + expected_source = """ + def f(a): + return a + 1 + """ + self.assertEqual( + textwrap.dedent(expected_source).strip(), + source.strip()) + self.assertEqual(2, module.f(1)) + with open(module.__file__, 'r') as temp_output: self.assertEqual( - textwrap.dedent(""" - def f(a): - return a + 1 - """).strip(), + textwrap.dedent(expected_source).strip(), temp_output.read().strip()) diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/py2tf/pyct/context.py index 73f3613d09d01e9e643cfb8ee3a8e67e5c126455..fef74ebefa290369c7310af6d7e4faeef44d9aee 100644 --- a/tensorflow/contrib/py2tf/pyct/context.py +++ b/tensorflow/contrib/py2tf/pyct/context.py @@ -33,10 +33,11 @@ class EntityContext(object): """ def __init__(self, namer, source_code, source_file, namespace, arg_values, - arg_types): + arg_types, recursive): self.namer = namer self.source_code = source_code self.source_file = source_file self.namespace = namespace self.arg_values = {} if arg_values is None else arg_values self.arg_types = {} if arg_types is None else arg_types + self.recursive = recursive diff --git a/tensorflow/contrib/py2tf/pyct/parser.py b/tensorflow/contrib/py2tf/pyct/parser.py index 3daa69b9ceff714c94c61134f6fb81f9927ea258..dc7df883b349becd860bb0dbceab22cb39c750b5 100644 --- a/tensorflow/contrib/py2tf/pyct/parser.py +++ b/tensorflow/contrib/py2tf/pyct/parser.py @@ -28,11 +28,13 @@ import gast from tensorflow.python.util import tf_inspect -def parse_object(obj): - """Return the AST of given object.""" - return parse_str(tf_inspect.getsource(obj)) +def parse_entity(entity): + """Return the AST of given entity.""" + source = tf_inspect.getsource(entity) + source = textwrap.dedent(source) + return parse_str(source), source def parse_str(src): """Return the AST of given piece of code.""" - return gast.parse(textwrap.dedent(src)) + return gast.parse(src) diff --git a/tensorflow/contrib/py2tf/pyct/parser_test.py b/tensorflow/contrib/py2tf/pyct/parser_test.py index 46f9aa82071efa98518810851b76761ff42751e5..f35dfa04c70dc191078248c32f9a04d28133129a 100644 --- a/tensorflow/contrib/py2tf/pyct/parser_test.py +++ b/tensorflow/contrib/py2tf/pyct/parser_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import textwrap + from tensorflow.contrib.py2tf.pyct import parser from tensorflow.python.platform import test @@ -28,15 +30,16 @@ def f(x): class ParserTest(test.TestCase): - def test_parse_object(self): - mod = parser.parse_object(f) + def test_parse_entity(self): + mod, _ = parser.parse_entity(f) self.assertEqual('f', mod.body[0].name) def test_parse_str(self): - mod = parser.parse_str(""" + mod = parser.parse_str( + textwrap.dedent(""" def f(x): return x + 1 - """) + """)) self.assertEqual('f', mod.body[0].name) diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer.py b/tensorflow/contrib/py2tf/pyct/pretty_printer.py index 5e70c0ed833c10012e6a5b4cb26e9e4198162693..bacc1e4a7774ec5b84495255042392fe089150d5 100644 --- a/tensorflow/contrib/py2tf/pyct/pretty_printer.py +++ b/tensorflow/contrib/py2tf/pyct/pretty_printer.py @@ -25,24 +25,30 @@ import termcolor class PrettyPrinter(gast.NodeVisitor): """Print AST nodes.""" - def __init__(self): + def __init__(self, color): self.indent_lvl = 0 self.result = '' + self.color = color + + def _color(self, string, color, attrs=None): + if self.color: + return termcolor.colored(string, color, attrs=attrs) + return string def _type(self, node): - return termcolor.colored(node.__class__.__name__, None, attrs=['bold']) + return self._color(node.__class__.__name__, None, ['bold']) def _field(self, name): - return termcolor.colored(name, 'blue') + return self._color(name, 'blue') def _value(self, name): - return termcolor.colored(name, 'magenta') + return self._color(name, 'magenta') def _warning(self, name): - return termcolor.colored(name, 'red') + return self._color(name, 'red') def _indent(self): - return termcolor.colored('| ' * self.indent_lvl, None, attrs=['dark']) + return self._color('| ' * self.indent_lvl, None, ['dark']) def _print(self, s): self.result += s @@ -76,6 +82,16 @@ class PrettyPrinter(gast.NodeVisitor): self._print('%s]' % (self._indent())) else: self._print('%s%s=[]' % (self._indent(), self._field(f))) + elif isinstance(v, tuple): + if v: + self._print('%s%s=(' % (self._indent(), self._field(f))) + self.indent_lvl += 1 + for n in v: + self.generic_visit(n) + self.indent_lvl -= 1 + self._print('%s)' % (self._indent())) + else: + self._print('%s%s=()' % (self._indent(), self._field(f))) elif isinstance(v, gast.AST): self.generic_visit(v, f) elif isinstance(v, str): @@ -87,8 +103,8 @@ class PrettyPrinter(gast.NodeVisitor): self.indent_lvl -= 1 -def fmt(node): - printer = PrettyPrinter() +def fmt(node, color=True): + printer = PrettyPrinter(color) if isinstance(node, (list, tuple)): for n in node: printer.visit(n) diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py b/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py index 65e5b1d9191749a0caeeda48df37690564a8fc1e..81e3f47b80b6cb3bb7ba9f4a1787d03df4151a99 100644 --- a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py +++ b/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py @@ -24,10 +24,6 @@ from tensorflow.contrib.py2tf.pyct import pretty_printer from tensorflow.python.platform import test -def f(x): - return x + 1 - - class PrettyPrinterTest(test.TestCase): def test_format(self): diff --git a/tensorflow/contrib/py2tf/pyct/qual_names.py b/tensorflow/contrib/py2tf/pyct/qual_names.py new file mode 100644 index 0000000000000000000000000000000000000000..8717ee6cff198ff31f6cbdb7213e5a8dd3df1149 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/qual_names.py @@ -0,0 +1,104 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for manipulating qualified names. + +A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite +(e.g. 'foo.bar') syntactic symbols. + +This is *not* related to the __qualname__ attribute used by inspect, which +refers to scopes. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import anno + + +class QN(object): + """Represents a qualified name.""" + + def __init__(self, base, attr=None): + if attr: + if not isinstance(base, QN): + raise ValueError('For attribute QNs, base must be a QN.') + self._parent = base + self.qn = base.qn + (attr,) + else: + if isinstance(base, QN): + if base.is_composite(): + self._parent = base.parent + else: + self._parent = None + self.qn = base.qn + else: + self._parent = None + self.qn = tuple(base.split('.')) + + def is_composite(self): + return len(self.qn) > 1 + + @property + def parent(self): + if self._parent is None: + raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0]) + return self._parent + + def __hash__(self): + return hash(self.qn) + + def __eq__(self, other): + return self.qn == other.qn + + def __str__(self): + return '.'.join(self.qn) + + def __repr__(self): + return str(self) + + def ssf(self): + """Simple symbol form.""" + return '_'.join(self.qn) + + def ast(self): + # The caller must adjust the context appropriately. + if self.is_composite(): + return gast.Attribute(self.parent.ast(), self.qn[-1], None) + return gast.Name(self.qn[0], None, None) + + +class QnResolver(gast.NodeTransformer): + """Annotates nodes with QN information. + + Note: Not using NodeAnnos to avoid circular dependencies. + """ + + def visit_Name(self, node): + self.generic_visit(node) + anno.setanno(node, anno.Basic.QN, QN(node.id)) + return node + + def visit_Attribute(self, node): + self.generic_visit(node) + anno.setanno(node, anno.Basic.QN, + QN(anno.getanno(node.value, anno.Basic.QN), node.attr)) + return node + + +def resolve(node): + return QnResolver().visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/qual_names_test.py b/tensorflow/contrib/py2tf/pyct/qual_names_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1b1eee2deca18bb0540c17d6ee85d421602aa2b7 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/qual_names_test.py @@ -0,0 +1,108 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for qual_names module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import textwrap + +from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.python.platform import test + + +class QNTest(test.TestCase): + + def test_basic(self): + a = qual_names.QN('a') + self.assertEqual(a.qn, ('a',)) + self.assertEqual(str(a), 'a') + self.assertEqual(a.ssf(), 'a') + self.assertEqual(a.ast().id, 'a') + self.assertFalse(a.is_composite()) + with self.assertRaises(ValueError): + _ = a.parent + + a_b = qual_names.QN(a, 'b') + self.assertEqual(a_b.qn, ('a', 'b')) + self.assertEqual(str(a_b), 'a.b') + self.assertEqual(a_b.ssf(), 'a_b') + self.assertEqual(a_b.ast().value.id, 'a') + self.assertEqual(a_b.ast().attr, 'b') + self.assertTrue(a_b.is_composite()) + self.assertEqual(a_b.parent.qn, ('a',)) + + a2 = qual_names.QN(a) + self.assertEqual(a2.qn, ('a',)) + with self.assertRaises(ValueError): + _ = a.parent + + a_b2 = qual_names.QN(a_b) + self.assertEqual(a_b2.qn, ('a', 'b')) + self.assertEqual(a_b2.parent.qn, ('a',)) + + self.assertTrue(a2 == a) + self.assertFalse(a2 is a) + + self.assertTrue(a_b.parent == a) + self.assertTrue(a_b2.parent == a) + + self.assertTrue(a_b2 == a_b) + self.assertFalse(a_b2 is a_b) + self.assertFalse(a_b2 == a) + + with self.assertRaises(ValueError): + qual_names.QN('a', 'b') + + def test_hashable(self): + d = {qual_names.QN('a'): 'a', qual_names.QN('b'): 'b'} + + self.assertEqual(d[qual_names.QN('a')], 'a') + self.assertEqual(d[qual_names.QN('b')], 'b') + self.assertTrue(qual_names.QN('c') not in d) + + +class QNResolverTest(test.TestCase): + + def assertQNStringIs(self, node, qn_str): + self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str) + + def test_resolve(self): + samples = """ + a + a.b + (c, d.e) + [f, (g.h.i)] + j(k, l) + """ + nodes = qual_names.resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + + self.assertQNStringIs(nodes[0], 'a') + self.assertQNStringIs(nodes[1], 'a.b') + self.assertQNStringIs(nodes[2].elts[0], 'c') + self.assertQNStringIs(nodes[2].elts[1], 'd.e') + self.assertQNStringIs(nodes[3].elts[0], 'f') + self.assertQNStringIs(nodes[3].elts[1], 'g.h.i') + self.assertQNStringIs(nodes[4].func, 'j') + self.assertQNStringIs(nodes[4].args[0], 'k') + self.assertQNStringIs(nodes[4].args[1], 'l') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD b/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD index abaf9536781efadea61b0da684020baeeed0597d..fbfce18c60cca4b105e7de3c3ea7b9c3438f6b2a 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD @@ -17,7 +17,8 @@ filegroup( py_library( name = "static_analysis", srcs = [ - "access.py", + "activity.py", + "annos.py", "live_values.py", "type_info.py", ], @@ -30,8 +31,9 @@ py_library( ) py_test( - name = "access_test", - srcs = ["access_test.py"], + name = "activity_test", + srcs = ["activity_test.py"], + srcs_version = "PY2AND3", deps = [ ":static_analysis", "//tensorflow/contrib/py2tf/pyct", @@ -43,6 +45,7 @@ py_test( py_test( name = "live_values_test", srcs = ["live_values_test.py"], + srcs_version = "PY2AND3", deps = [ ":static_analysis", "//tensorflow/contrib/py2tf/pyct", @@ -53,6 +56,7 @@ py_test( py_test( name = "type_info_test", srcs = ["type_info_test.py"], + srcs_version = "PY2AND3", deps = [ ":static_analysis", "//tensorflow/contrib/py2tf/pyct", diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py deleted file mode 100644 index 0912ebb4c355c2ae2563e13e36926a4b8e3599a1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for access module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import gast - -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access -from tensorflow.python.platform import test - - -class ScopeTest(test.TestCase): - - def test_basic(self): - scope = access.Scope(None) - self.assertFalse(scope.has('foo')) - - scope.mark_read('foo') - self.assertFalse(scope.has('foo')) - - scope.mark_write('foo') - self.assertTrue(scope.has('foo')) - - scope.mark_read('bar') - self.assertFalse(scope.has('bar')) - - def test_copy(self): - scope = access.Scope(None) - scope.mark_write('foo') - - other = access.Scope(None) - other.copy_from(scope) - - self.assertTrue('foo' in other.created) - - scope.mark_write('bar') - scope.copy_from(other) - - self.assertFalse('bar' in scope.created) - - scope.mark_write('bar') - scope.merge_from(other) - - self.assertTrue('bar' in scope.created) - self.assertFalse('bar' in other.created) - - def test_nesting(self): - scope = access.Scope(None) - scope.mark_write('foo') - scope.mark_read('bar') - - child = access.Scope(scope) - self.assertTrue(child.has('foo')) - self.assertTrue(scope.has('foo')) - - child.mark_write('bar') - self.assertTrue(child.has('bar')) - self.assertFalse(scope.has('bar')) - - def test_referenced(self): - scope = access.Scope(None) - scope.mark_read('a') - - child = access.Scope(scope) - child.mark_read('b') - - child2 = access.Scope(child, isolated=False) - child2.mark_read('c') - - self.assertTrue('c' in child2.referenced) - self.assertTrue('b' in child2.referenced) - self.assertFalse('a' in child2.referenced) - - self.assertTrue('c' in child.referenced) - self.assertTrue('b' in child.referenced) - self.assertFalse('a' in child.referenced) - - -class AccessResolverTest(test.TestCase): - - def test_local_markers(self): - - def test_fn(a): # pylint:disable=unused-argument - b = c # pylint:disable=undefined-variable - while b > 0: - b -= 1 - return b - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - self.assertFalse(anno.getanno(node.body[0].body[0].value, - 'is_local')) # c in b = c - self.assertTrue(anno.getanno(node.body[0].body[1].test.left, - 'is_local')) # b in b > 0 - self.assertTrue(anno.getanno(node.body[0].body[2].value, - 'is_local')) # b in return b - - def assertScopeIs(self, scope, used, modified, created): - self.assertItemsEqual(used, scope.used) - self.assertItemsEqual(modified, scope.modified) - self.assertItemsEqual(created, scope.created) - - def test_print_statement(self): - - def test_fn(a): - b = 0 - c = 1 - print(a, b) - return c - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - print_node = node.body[0].body[2] - if isinstance(print_node, gast.Print): - # Python 2 - print_args_scope = anno.getanno(print_node, 'args_scope') - else: - # Python 3 - assert isinstance(print_node, gast.Expr) - # The call node should be the one being annotated. - print_node = print_node.value - print_args_scope = anno.getanno(print_node, 'args_scope') - # We basically need to detect which variables are captured by the call - # arguments. - self.assertScopeIs(print_args_scope, ('a', 'b'), (), ()) - - def test_call(self): - - def test_fn(a): - b = 0 - c = 1 - foo(a, b) # pylint:disable=undefined-variable - return c - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - call_node = node.body[0].body[2].value - # We basically need to detect which variables are captured by the call - # arguments. - self.assertScopeIs( - anno.getanno(call_node, 'args_scope'), ('a', 'b'), (), ()) - - def test_while(self): - - def test_fn(a): - b = a - while b > 0: - c = b - b -= 1 - return b, c - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - while_node = node.body[0].body[1] - self.assertScopeIs( - anno.getanno(while_node, 'body_scope'), ('b',), ('b', 'c'), ('c',)) - self.assertScopeIs( - anno.getanno(while_node, 'body_parent_scope'), ('a', 'b', 'c'), - ('a', 'b', 'c'), ('a', 'b', 'c')) - - def test_for(self): - - def test_fn(a): - b = a - for _ in a: - c = b - b -= 1 - return b, c - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - for_node = node.body[0].body[1] - self.assertScopeIs( - anno.getanno(for_node, 'body_scope'), ('b',), ('b', 'c'), ('c',)) - self.assertScopeIs( - anno.getanno(for_node, 'body_parent_scope'), ('a', 'b', 'c'), - ('a', 'b', 'c', '_'), ('a', 'b', 'c', '_')) - - def test_if(self): - - def test_fn(x): - if x > 0: - x = -x - y = 2 * x - z = -y - else: - x = 2 * x - y = -x - u = -y - return z, u - - node = parser.parse_object(test_fn) - node = access.resolve(node) - - if_node = node.body[0].body[0] - self.assertScopeIs( - anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'), - ('y', 'z')) - # TODO(mdan): Double check: is it ok to not mark a local symbol as not read? - self.assertScopeIs( - anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'), - ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) - self.assertScopeIs( - anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'), - ('y', 'u')) - self.assertScopeIs( - anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'), - ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py b/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py similarity index 60% rename from tensorflow/contrib/py2tf/pyct/static_analysis/access.py rename to tensorflow/contrib/py2tf/pyct/static_analysis/activity.py index 8f3ac48b68c05256fbac4c4d8d86381755c8027c..1c93e1603113d48176af7a97a0f37321e6f67586 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Access information (reads, writes) resolution.""" +"""Activity analysis.""" from __future__ import absolute_import from __future__ import division @@ -23,6 +23,8 @@ import copy import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Add support for PY3 (e.g. Param vs arg). @@ -53,6 +55,8 @@ class Scope(object): self.modified = set() self.created = set() self.used = set() + self.params = set() + self.returned = set() # TODO(mdan): Rename to `locals` @property @@ -69,61 +73,116 @@ class Scope(object): self.modified = copy.copy(other.modified) self.created = copy.copy(other.created) self.used = copy.copy(other.used) + self.params = copy.copy(other.params) + self.returned = copy.copy(other.returned) def merge_from(self, other): self.modified |= other.modified self.created |= other.created self.used |= other.used + self.params |= other.params + self.returned |= other.returned def has(self, name): - if name in self.modified: + if name in self.modified or name in self.params: return True elif self.parent is not None: return self.parent.has(name) return False + def is_modified_since_entry(self, name): + if name in self.modified: + return True + elif self.parent is not None and not self.isolated: + return self.parent.is_modified_since_entry(name) + return False + + def is_param(self, name): + if name in self.params: + return True + elif self.parent is not None and not self.isolated: + return self.parent.is_param(name) + return False + def mark_read(self, name): self.used.add(name) if self.parent is not None and name not in self.created: self.parent.mark_read(name) + def mark_param(self, name): + self.params.add(name) + + def mark_creation(self, name): + if name.is_composite(): + parent = name.parent + if self.has(parent): + # This is considered mutation of the parent, not creation. + # TODO(mdan): Is that really so? + return + else: + raise ValueError('Unknown symbol "%s".' % parent) + self.created.add(name) + def mark_write(self, name): self.modified.add(name) if self.isolated: - self.created.add(name) + self.mark_creation(name) else: if self.parent is None: - self.created.add(name) + self.mark_creation(name) else: if not self.parent.has(name): - self.created.add(name) + self.mark_creation(name) self.parent.mark_write(name) + def mark_returned(self, name): + self.returned.add(name) + if not self.isolated and self.parent is not None: + self.parent.mark_returned(name) + -class AccessResolver(gast.NodeTransformer): +class ActivityAnalizer(transformer.Base): """Annotates nodes with local scope information. See Scope.""" - def __init__(self): - self.scope = Scope(None) + def __init__(self, context, parent_scope): + super(ActivityAnalizer, self).__init__(context) + self.scope = Scope(parent_scope) + self._in_return_statement = False + + def _track_symbol(self, node): + qn = anno.getanno(node, anno.Basic.QN) - def visit_Name(self, node): - # TODO(mdan): This is insufficient for object fields, e.g. hp.learning_rate. - self.generic_visit(node) if isinstance(node.ctx, gast.Store): - self.scope.mark_write(node.id) + self.scope.mark_write(qn) elif isinstance(node.ctx, gast.Load): - anno.setanno(node, 'is_local', self.scope.has(node.id)) - self.scope.mark_read(node.id) + self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Param): # Param contexts appear in function defs, so they have the meaning of # defining a variable. # TODO(mdan): This bay be incorrect with nested functions. # For nested functions, we'll have to add the notion of hiding args from # the parent scope, not writing to them. - self.scope.mark_write(node.id) + self.scope.mark_creation(qn) + self.scope.mark_param(qn) else: - raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), - node.id)) + raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn)) + + anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn)) + anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY, + self.scope.is_modified_since_entry(qn)) + anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn)) + + if self._in_return_statement: + self.scope.mark_returned(qn) + + def visit_Name(self, node): + self.generic_visit(node) + self._track_symbol(node) + return node + + def visit_Attribute(self, node): + self.generic_visit(node) + self._track_symbol(node) return node def visit_Print(self, node): @@ -132,20 +191,20 @@ class AccessResolver(gast.NodeTransformer): self.scope = args_scope for n in node.values: self.visit(n) - anno.setanno(node, 'args_scope', args_scope) + anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope) self.scope = current_scope return node def visit_Call(self, node): current_scope = self.scope - args_scope = Scope(current_scope) + args_scope = Scope(current_scope, isolated=False) self.scope = args_scope for n in node.args: self.visit(n) # TODO(mdan): Account starargs, kwargs for n in node.keywords: self.visit(n) - anno.setanno(node, 'args_scope', args_scope) + anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope) self.scope = current_scope self.visit(node.func) return node @@ -156,7 +215,7 @@ class AccessResolver(gast.NodeTransformer): self.scope = block_scope for n in block: self.visit(n) - anno.setanno(node, '%s_scope' % scope_name, block_scope) + anno.setanno(node, scope_name, block_scope) self.scope = current_scope return node @@ -168,38 +227,44 @@ class AccessResolver(gast.NodeTransformer): before_parent = Scope(None) before_parent.copy_from(self.scope) after_children = [] - for child, name in children: + for child, scope_name in children: self.scope.copy_from(before_parent) - parent = self._process_block_node(parent, child, name) + parent = self._process_block_node(parent, child, scope_name) after_child = Scope(None) after_child.copy_from(self.scope) after_children.append(after_child) for after_child in after_children: self.scope.merge_from(after_child) - for child, name in children: - # TODO(mdan): We don't need this - we have the parent link from scope. - anno.setanno(parent, '%s_parent_scope' % name, self.scope) return parent def visit_If(self, node): self.visit(node.test) - node = self._process_parallel_blocks( - node, ((node.body, 'body'), (node.orelse, 'orelse'))) + node = self._process_parallel_blocks(node, + ((node.body, NodeAnno.BODY_SCOPE), + (node.orelse, NodeAnno.ORELSE_SCOPE))) return node def visit_For(self, node): self.visit(node.target) self.visit(node.iter) - node = self._process_parallel_blocks( - node, ((node.body, 'body'), (node.orelse, 'orelse'))) + node = self._process_parallel_blocks(node, + ((node.body, NodeAnno.BODY_SCOPE), + (node.orelse, NodeAnno.ORELSE_SCOPE))) return node def visit_While(self, node): self.visit(node.test) - node = self._process_parallel_blocks( - node, ((node.body, 'body'), (node.orelse, 'orelse'))) + node = self._process_parallel_blocks(node, + ((node.body, NodeAnno.BODY_SCOPE), + (node.orelse, NodeAnno.ORELSE_SCOPE))) + return node + + def visit_Return(self, node): + self._in_return_statement = True + node = self.generic_visit(node) + self._in_return_statement = False return node -def resolve(node): - return AccessResolver().visit(node) +def resolve(node, context, parent_scope=None): + return ActivityAnalizer(context, parent_scope).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e1eb954a5efef4d6a00ac492e7c85394d54e28c9 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py @@ -0,0 +1,271 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for activity module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import context +from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.py2tf.pyct.qual_names import QN +from tensorflow.contrib.py2tf.pyct.static_analysis import activity +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.python.platform import test + + +class ScopeTest(test.TestCase): + + def test_basic(self): + scope = activity.Scope(None) + self.assertFalse(scope.has(QN('foo'))) + + scope.mark_read(QN('foo')) + self.assertFalse(scope.has(QN('foo'))) + + scope.mark_write(QN('foo')) + self.assertTrue(scope.has(QN('foo'))) + + scope.mark_read(QN('bar')) + self.assertFalse(scope.has(QN('bar'))) + + def test_copy(self): + scope = activity.Scope(None) + scope.mark_write(QN('foo')) + + other = activity.Scope(None) + other.copy_from(scope) + + self.assertTrue(QN('foo') in other.created) + + scope.mark_write(QN('bar')) + scope.copy_from(other) + + self.assertFalse(QN('bar') in scope.created) + + scope.mark_write(QN('bar')) + scope.merge_from(other) + + self.assertTrue(QN('bar') in scope.created) + self.assertFalse(QN('bar') in other.created) + + def test_nesting(self): + scope = activity.Scope(None) + scope.mark_write(QN('foo')) + scope.mark_read(QN('bar')) + + child = activity.Scope(scope) + self.assertTrue(child.has(QN('foo'))) + self.assertTrue(scope.has(QN('foo'))) + + child.mark_write(QN('bar')) + self.assertTrue(child.has(QN('bar'))) + self.assertFalse(scope.has(QN('bar'))) + + def test_referenced(self): + scope = activity.Scope(None) + scope.mark_read(QN('a')) + + child = activity.Scope(scope) + child.mark_read(QN('b')) + + child2 = activity.Scope(child, isolated=False) + child2.mark_read(QN('c')) + + self.assertTrue(QN('c') in child2.referenced) + self.assertTrue(QN('b') in child2.referenced) + self.assertFalse(QN('a') in child2.referenced) + + self.assertTrue(QN('c') in child.referenced) + self.assertTrue(QN('b') in child.referenced) + self.assertFalse(QN('a') in child.referenced) + + +class ActivityAnalizerTest(test.TestCase): + + def _parse_and_analyze(self, test_fn): + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace={}, + arg_values=None, + arg_types=None, + recursive=True) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + return node + + def test_local_markers(self): + + def test_fn(a): # pylint:disable=unused-argument + b = c # pylint:disable=undefined-variable + while b > 0: + b -= 1 + return b + + node = self._parse_and_analyze(test_fn) + self.assertFalse( + anno.getanno(node.body[0].body[0].value, + NodeAnno.IS_LOCAL)) # c in b = c + self.assertTrue( + anno.getanno(node.body[0].body[1].test.left, + NodeAnno.IS_LOCAL)) # b in b > 0 + self.assertTrue( + anno.getanno(node.body[0].body[2].value, + NodeAnno.IS_LOCAL)) # b in return b + + def assertScopeIs(self, scope, used, modified, created): + self.assertItemsEqual(used, tuple(str(s) for s in scope.used)) + self.assertItemsEqual(modified, tuple(str(s) for s in scope.modified)) + self.assertItemsEqual(created, tuple(str(s) for s in scope.created)) + + def test_print_statement(self): + + def test_fn(a): + b = 0 + c = 1 + print(a, b) + return c + + node = self._parse_and_analyze(test_fn) + print_node = node.body[0].body[2] + if isinstance(print_node, gast.Print): + # Python 2 + print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE) + else: + # Python 3 + assert isinstance(print_node, gast.Expr) + # The call node should be the one being annotated. + print_node = print_node.value + print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE) + # We basically need to detect which variables are captured by the call + # arguments. + self.assertScopeIs(print_args_scope, ('a', 'b'), (), ()) + + def test_call(self): + + def test_fn(a): + b = 0 + c = 1 + foo(a, b) # pylint:disable=undefined-variable + return c + + node = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[2].value + # We basically need to detect which variables are captured by the call + # arguments. + self.assertScopeIs( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ()) + + def test_while(self): + + def test_fn(a): + b = a + while b > 0: + c = b + b -= 1 + return b, c + + node = self._parse_and_analyze(test_fn) + while_node = node.body[0].body[1] + self.assertScopeIs( + anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), + ('c',)) + self.assertScopeIs( + anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), + ('b', 'c'), ('a', 'b', 'c')) + + def test_for(self): + + def test_fn(a): + b = a + for _ in a: + c = b + b -= 1 + return b, c + + node = self._parse_and_analyze(test_fn) + for_node = node.body[0].body[1] + self.assertScopeIs( + anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) + self.assertScopeIs( + anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), + ('b', 'c', '_'), ('a', 'b', 'c', '_')) + + def test_if(self): + + def test_fn(x): + if x > 0: + x = -x + y = 2 * x + z = -y + else: + x = 2 * x + y = -x + u = -y + return z, u + + node = self._parse_and_analyze(test_fn) + if_node = node.body[0].body[0] + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'), + ('y', 'z')) + # TODO(mdan): Double check: is it ok to not mark a local symbol as not read? + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'), + ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'), + ('x', 'y', 'u'), ('y', 'u')) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'), + ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) + + def test_call_with_composite_names(self): + + def foo(*_): + pass + + def test_fn(a): + foo(a.b, a.c) + if a > 0: + a.b = 2 + else: + d = 2 + d.e = a.c + f = d.e + 1 + a.c = f + + node = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[0].value + self.assertScopeIs( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), (), + ()) + if_node = node.body[0].body[1] + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ()) + self.assertScopeIs( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), + ('a', 'a.c', 'd', 'd.e', 'f'), ('a.c', 'd', 'd.e', 'f'), ('d', 'f')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py b/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py new file mode 100644 index 0000000000000000000000000000000000000000..2d8e49442364fdd4a4752c8a83a5f3b76117fe57 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Annotations used by the static analizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from enum import Enum + + +class NoValue(Enum): + + def __repr__(self): + return self.name + + +class NodeAnno(NoValue): + """Additionnal annotations used by the static analyzer. + + These are in addition to the basic annotations declared in anno.py. + """ + + # Symbols + + IS_LOCAL = 'Symbol is local to the function scope being analized.' + IS_PARAM = 'Symbol is a parameter to the function being analized.' + IS_MODIFIED_SINCE_ENTRY = ( + 'Symbol has been explicitly replaced in the current function scope.') + + # Scopes + ARGS_SCOPE = 'The scope for the argument list of a function call.' + BODY_SCOPE = ( + 'The scope for the main body of a statement (True branch for if ' + 'statements, main body for loops).') + ORELSE_SCOPE = ( + 'The scope for the orelse body of a statement (False branch for if ' + 'statements, orelse body for loops).') diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py index 242e544b5286c683ee4aa97bc586751932c73815..9c0a9a9e74eccb3d22840032e8f0c2b81e051e7e 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py @@ -16,7 +16,7 @@ Live values are extracted from the known execution context. -Requires annotations generated by AccessResolver. +Requires activity analysis annotations. """ from __future__ import absolute_import @@ -26,47 +26,56 @@ from __future__ import print_function import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno -class LiveValueResolver(gast.NodeTransformer): +class LiveValueResolver(transformer.Base): """Annotates nodes with live values.""" - def __init__(self, namespace, literals): - """Create a new resolver. - - Args: - namespace: A dict representing the namespace visible to the AST in the - intended execution context. - literals: A dict mapping literal lymbol names to their value. An example - literal is "None". - """ - self.namespace = namespace + def __init__(self, context, literals): + super(LiveValueResolver, self).__init__(context) self.literals = literals def visit_ClassDef(self, node): self.generic_visit(node) - anno.setanno(node, 'live_val', self.namespace[node.name]) + anno.setanno(node, 'live_val', self.context.namespace[node.name]) return node def visit_Name(self, node): self.generic_visit(node) if isinstance(node.ctx, gast.Load): - assert anno.hasanno(node, 'is_local'), node - symbol_is_local = anno.getanno(node, 'is_local') - if not symbol_is_local: + assert anno.hasanno(node, NodeAnno.IS_LOCAL), node + symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL) + assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node + symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY) + assert anno.hasanno(node, NodeAnno.IS_PARAM), node + symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM) + + if not symbol_is_local and not symbol_is_param: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) # TODO(mdan): Could live values have FQNs? i.e. 'a'.join() - elif node.id in self.namespace: - obj = self.namespace[node.id] + elif node.id in self.context.namespace: + obj = self.context.namespace[node.id] anno.setanno(node, 'live_val', obj) anno.setanno(node, 'fqn', (obj.__name__,)) else: - raise ValueError('Could not find global symbol %s.' % node.id) + pass + # TODO(mdan): Should we raise an error here? + # Can encounter this when: + # * a symbol truly lacks reference + # * a symbol is new, like the new name of a function we just renamed. else: pass # TODO(mdan): Attempt to trace its value through the local chain. # TODO(mdan): Use type annotations as fallback. + + if not symbol_is_modified: + if node.id in self.context.arg_values: + obj = self.context.arg_values[node.id] + anno.setanno(node, 'live_val', obj) + anno.setanno(node, 'fqn', (obj.__class__.__name__,)) return node def visit_Attribute(self, node): @@ -79,15 +88,25 @@ class LiveValueResolver(gast.NodeTransformer): node.attr)) anno.setanno(node, 'live_val', getattr(parent_object, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,)) + # TODO(mdan): Investigate the role built-in annotations can play here. + elif anno.hasanno(node.value, 'type'): + parent_type = anno.getanno(node.value, 'type') + if hasattr(parent_type, node.attr): + # This should hold for static members like methods. + # This would not hold for dynamic members like function attributes. + # For the dynamic case, we simply leave the node without an annotation, + # and let downstream consumers figure out what to do. + anno.setanno(node, 'live_val', getattr(parent_type, node.attr)) + anno.setanno(node, 'fqn', + anno.getanno(node.value, 'type_fqn') + (node.attr,)) elif isinstance(node.value, gast.Name): stem_name = node.value # All nonlocal symbols should be fully resolved. - assert anno.hasanno(stem_name, 'is_local'), stem_name - assert anno.getanno(stem_name, 'is_local'), stem_name + assert anno.hasanno(stem_name, NodeAnno.IS_LOCAL), stem_name # TODO(mdan): Figure out what to do when calling attribute on local object # Maybe just leave as-is? return node -def resolve(node, namespace, literals): - return LiveValueResolver(namespace, literals).visit(node) +def resolve(node, context, literals): + return LiveValueResolver(context, literals).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py index e77497654a0b3096422deef9a3f008eeb6c6be05..9f64689401e3594a77fbdd7b6f02880bd6e90492 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py @@ -19,24 +19,47 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.py2tf.pyct.static_analysis import activity from tensorflow.contrib.py2tf.pyct.static_analysis import live_values +from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class LiveValuesResolverTest(test.TestCase): + def _parse_and_analyze(self, + test_fn, + namespace, + literals=None, + arg_types=None): + literals = literals or {} + arg_types = arg_types or {} + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types, + recursive=True) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + node = live_values.resolve(node, ctx, literals) + node = type_info.resolve(node, ctx) + node = live_values.resolve(node, ctx, literals) + return node + def test_literals(self): def test_fn(): return Foo # pylint: disable=undefined-variable - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {}, {'Foo': 'bar'}) - + node = self._parse_and_analyze(test_fn, {}, {'Foo': 'bar'}) retval_node = node.body[0].body[0].value self.assertEquals('bar', anno.getanno(retval_node, 'live_val')) @@ -48,10 +71,7 @@ class LiveValuesResolverTest(test.TestCase): def test_fn(): return foo() - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'foo': foo}, {}) - + node = self._parse_and_analyze(test_fn, {'foo': foo}) func_node = node.body[0].body[0].value.func self.assertEquals(foo, anno.getanno(func_node, 'live_val')) self.assertEquals(('foo',), anno.getanno(func_node, 'fqn')) @@ -61,15 +81,29 @@ class LiveValuesResolverTest(test.TestCase): def test_fn(): return constant_op.constant(0) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, {'constant_op': constant_op}, {}) - + node = self._parse_and_analyze(test_fn, {'constant_op': constant_op}) func_node = node.body[0].body[0].value.func self.assertEquals(constant_op.constant, anno.getanno(func_node, 'live_val')) self.assertEquals((constant_op.__name__, 'constant'), anno.getanno(func_node, 'fqn')) + def test_attributes_with_type_hints(self): + + class TestClass(object): + + def member(self): + pass + + def test_fn(self): + return self.member() + + node = self._parse_and_analyze( + TestClass.test_fn, {'constant_op': constant_op}, + arg_types={'self': (TestClass.__name__, TestClass)}) + func_node = node.body[0].body[0].value.func + self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val')) + self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py index 0042aa90ed218d42aedc720c94d1a478bc9f18f5..8203bda0f9a792a5b24b9abb25d8f39b61625748 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py @@ -36,8 +36,6 @@ class Scope(object): most recently assigned to the symbol. """ - # TODO(mdan): Should rather use a CFG here? - def __init__(self, parent): """Create a new scope. @@ -117,18 +115,34 @@ class TypeInfoResolver(transformer.Base): node.orelse = self._visit_block(node.orelse) return node + def _process_function_arg(self, arg_name): + str_name = str(arg_name) + if self.function_level == 1 and str_name in self.context.arg_types: + # Forge a node to hold the type information, so that method calls on + # it can resolve the type. + type_holder = arg_name.ast() + type_string, type_obj = self.context.arg_types[str_name] + anno.setanno(type_holder, 'type', type_obj) + anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) + self.scope.setval(arg_name, type_holder) + + def visit_arg(self, node): + self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) + return node + def visit_Name(self, node): self.generic_visit(node) + qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Param): - self.scope.setval(node.id, gast.Name(node.id, gast.Load(), None)) - if self.function_level == 1 and node.id in self.context.arg_types: - # Forge a node to hold the type information, so that method calls on - # it can resolve the type. - type_holder = gast.Name(node.id, gast.Load(), None) - type_string, type_obj = self.context.arg_types[node.id] - anno.setanno(type_holder, 'type', type_obj) - anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) - self.scope.setval(node.id, type_holder) + self._process_function_arg(qn) + elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn): + # E.g. if we had + # a = b + # then for future references to `a` we should have traced_source = `b` + traced_source = self.scope.getval(qn) + if anno.hasanno(traced_source, 'type'): + anno.setanno(node, 'type', anno.getanno(traced_source, 'type')) + anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn')) return node def _process_variable_assignment(self, source, targets): @@ -147,16 +161,11 @@ class TypeInfoResolver(transformer.Base): for t in targets: if isinstance(t, gast.Tuple): for i, e in enumerate(t.elts): - self.scope.setval(e.id, - gast.Subscript( - source, gast.Index(i), ctx=gast.Store())) - elif isinstance(t, gast.Name): - self.scope.setval(t.id, source) - elif isinstance(t, gast.Attribute): - if not (isinstance(t.value, gast.Name) and t.value.id == 'self'): - raise ValueError( - 'Dont know how to handle assignment to attributes of objects' - ' other than "self": [%s].%s' % (t.value, t.attr)) + self.scope.setval( + anno.getanno(e, anno.Basic.QN), + gast.Subscript(source, gast.Index(i), ctx=gast.Store())) + elif isinstance(t, (gast.Name, gast.Attribute)): + self.scope.setval(anno.getanno(t, anno.Basic.QN), source) else: raise ValueError('Dont know how to handle assignment to %s' % t) @@ -172,38 +181,6 @@ class TypeInfoResolver(transformer.Base): self._process_variable_assignment(node.value, node.targets) return node - def visit_Call(self, node): - target = node.func - if not anno.hasanno(target, 'live_val'): - if not isinstance(target, gast.Attribute): - # Suspecting this pattern would reach here: - # foo = bar - # foo() - raise ValueError('Dont know how to handle dynamic functions.') - if not isinstance(target.value, gast.Name): - # Possible example of this kind: - # foo = module.Foo() - # foo.bar.baz() - # TODO(mdan): This should be doable by using the FQN. - raise ValueError('Dont know how to handle object properties yet.') - # In the example below, object_source is 'tr.train.Optimizer()': - # opt = tf.train.Optimizer() - # opt.foo() - if self.scope.hasval(target.value.id): - object_source = self.scope.getval(target.value.id) - if not anno.hasanno(object_source, 'type'): - raise ValueError('Could not determine type of "%s". Is it dynamic?' % - (target.value.id)) - anno.setanno(target, 'type', anno.getanno(object_source, 'type')) - anno.setanno(target, 'type_fqn', anno.getanno(object_source, - 'type_fqn')) - else: - # TODO(mdan): Figure out what could the user do to get past this. - raise ValueError('No info on "%s". Is it dynamically built?' % - (target.value.id)) - self.generic_visit(node) - return node - def resolve(node, context): return TypeInfoResolver(context).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py index a491f49ca3b87d1340fdd691431e127737abc006..3659f949db9910534870d8dd9e42fd4ee8297253 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py @@ -21,8 +21,8 @@ from __future__ import print_function from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.py2tf.pyct.static_analysis import activity from tensorflow.contrib.py2tf.pyct.static_analysis import live_values from tensorflow.contrib.py2tf.pyct.static_analysis import type_info from tensorflow.python.client import session @@ -57,17 +57,20 @@ class ScopeTest(test.TestCase): class TypeInfoResolverTest(test.TestCase): def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + node, source = parser.parse_entity(test_fn) ctx = context.EntityContext( namer=None, - source_code=None, + source_code=source, source_file=None, namespace=namespace, arg_values=None, - arg_types=arg_types) - node = parser.parse_object(test_fn) - node = access.resolve(node) - node = live_values.resolve(node, namespace, {}) + arg_types=arg_types, + recursive=True) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + node = live_values.resolve(node, ctx, {}) node = type_info.resolve(node, ctx) + node = live_values.resolve(node, ctx, {}) return node def test_constructor_detection(self): @@ -83,16 +86,16 @@ class TypeInfoResolverTest(test.TestCase): self.assertEquals((training.__name__, 'GradientDescentOptimizer'), anno.getanno(call_node, 'type_fqn')) - def test_class_members(self): + def test_class_members_of_detected_constructor(self): def test_fn(): opt = training.GradientDescentOptimizer(0.1) opt.minimize(0) node = self._parse_and_analyze(test_fn, {'training': training}) - attr_call_node = node.body[0].body[1].value.func - self.assertEquals((training.__name__, 'GradientDescentOptimizer'), - anno.getanno(attr_call_node, 'type_fqn')) + method_call = node.body[0].body[1].value.func + self.assertEquals(training.GradientDescentOptimizer.minimize, + anno.getanno(method_call, 'live_val')) def test_class_members_in_with_stmt(self): @@ -106,11 +109,11 @@ class TypeInfoResolverTest(test.TestCase): self.assertEquals((session.__name__, 'Session'), anno.getanno(constructor_call, 'type_fqn')) - member_call = node.body[0].body[0].body[0].value.func - self.assertEquals((session.__name__, 'Session'), - anno.getanno(member_call, 'type_fqn')) + method_call = node.body[0].body[0].body[0].value.func + self.assertEquals(session.Session.run, anno.getanno(method_call, + 'live_val')) - def test_constructor_deta_dependent(self): + def test_constructor_data_dependent(self): def test_fn(x): if x > 0: @@ -119,16 +122,18 @@ class TypeInfoResolverTest(test.TestCase): opt = training.GradientDescentOptimizer(0.01) opt.minimize(0) - with self.assertRaises(transformer.PyFlowParseError): - self._parse_and_analyze(test_fn, {'training': training}) + node = self._parse_and_analyze(test_fn, {'training': training}) + method_call = node.body[0].body[1].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) def test_parameter_class_members(self): def test_fn(opt): opt.minimize(0) - with self.assertRaises(transformer.PyFlowParseError): - self._parse_and_analyze(test_fn, {'training': training}) + node = self._parse_and_analyze(test_fn, {}) + method_call = node.body[0].body[0].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) def test_parameter_class_members_with_value_hints(self): @@ -138,14 +143,13 @@ class TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze( test_fn, {'training': training}, arg_types={ - 'opt': (('%s.GradientDescentOptimizer' % training.__name__), - training.GradientDescentOptimizer(0.1)) + 'opt': (training.GradientDescentOptimizer.__name__, + training.GradientDescentOptimizer) }) - attr_call_node = node.body[0].body[0].value.func - self.assertEquals( - tuple(training.__name__.split('.')) + ('GradientDescentOptimizer',), - anno.getanno(attr_call_node, 'type_fqn')) + method_call = node.body[0].body[0].value.func + self.assertEquals(training.GradientDescentOptimizer.minimize, + anno.getanno(method_call, 'live_val')) def test_function_variables(self): @@ -156,8 +160,9 @@ class TypeInfoResolverTest(test.TestCase): foo = bar foo() - with self.assertRaises(transformer.PyFlowParseError): - self._parse_and_analyze(test_fn, {'bar': bar}) + node = self._parse_and_analyze(test_fn, {'bar': bar}) + method_call = node.body[0].body[1].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) def test_nested_members(self): @@ -165,8 +170,9 @@ class TypeInfoResolverTest(test.TestCase): foo = training.GradientDescentOptimizer(0.1) foo.bar.baz() - with self.assertRaises(transformer.PyFlowParseError): - self._parse_and_analyze(test_fn, {'training': training}) + node = self._parse_and_analyze(test_fn, {'training': training}) + method_call = node.body[0].body[1].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py index 4fadc793e6d1dfa8ddabea1d607de68ac6ad9c85..6ee6c0c5ceb70d87779ee313670135cadc5214b5 100644 --- a/tensorflow/contrib/py2tf/pyct/templates.py +++ b/tensorflow/contrib/py2tf/pyct/templates.py @@ -22,11 +22,13 @@ from __future__ import division from __future__ import print_function import ast -import copy +import textwrap import gast +from tensorflow.contrib.py2tf.pyct import ast_util from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct import qual_names class ReplaceTransformer(gast.NodeTransformer): @@ -40,6 +42,7 @@ class ReplaceTransformer(gast.NodeTransformer): that these placeholders will be replaced by. """ self.replacements = replacements + self.in_replacements = False # TODO(mdan): Make a more detailed pass and clean up if needed. @@ -56,61 +59,98 @@ class ReplaceTransformer(gast.NodeTransformer): repl = self.replacements[node.name] if not isinstance(repl, (gast.Name, ast.Name)): raise ValueError( - 'A function name can only be replaced by a Name node. Found: %s', + 'A function name can only be replaced by a Name node. Found: %s' % repl) node.name = repl.id return node - def visit_Name(self, node): - if node.id in self.replacements: - # TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations. - new_nodes = copy.copy(self.replacements[node.id]) - if isinstance(new_nodes, gast.AST): - new_nodes = [new_nodes] - # Preserve the target context. - for n in new_nodes: - if isinstance(n, gast.Tuple): - for e in n.elts: - e.ctx = node.ctx - n.ctx = node.ctx - if len(new_nodes) == 1: - new_nodes, = new_nodes - return new_nodes + def _set_inner_child_context(self, node, ctx): + if isinstance(node, gast.Attribute): + self._set_inner_child_context(node.value, ctx) + node.ctx = gast.Load() + elif isinstance(node, gast.Tuple): + for e in node.elts: + self._set_inner_child_context(e, ctx) + node.ctx = ctx + elif isinstance(node, gast.Name): + node.ctx = ctx + elif isinstance(node, (gast.Str, gast.Num)): + pass else: + raise ValueError('unexpected node type "%s"' % node) + + def visit_Name(self, node): + if node.id not in self.replacements: return node + new_nodes = ast_util.copy_clean(self.replacements[node.id]) + if isinstance(new_nodes, gast.AST): + new_nodes = [new_nodes] + + # Preserve the target context. + for n in new_nodes: + if isinstance(n, gast.Tuple): + for e in n.elts: + self._set_inner_child_context(e, node.ctx) + if isinstance(n, gast.Attribute): + # For attributes, the inner Name node receives the context, while the + # outer ones have it set to Load. + self._set_inner_child_context(n, node.ctx) + else: + n.ctx = node.ctx + + if len(new_nodes) == 1: + new_nodes, = new_nodes + + return new_nodes + + +def _convert_to_ast(n): + """Convert from a known data type to AST.""" + if isinstance(n, str): + # Note: the node will receive the ctx value from the template, see + # ReplaceTransformer.visit_Name. + return gast.Name(id=n, ctx=None, annotation=None) + if isinstance(n, qual_names.QN): + return n.ast() + if isinstance(n, list): + return [_convert_to_ast(e) for e in n] + if isinstance(n, tuple): + return tuple(_convert_to_ast(e) for e in n) + return n + def replace(template, **replacements): """Replace placeholders in a Python template. + AST Name and Tuple nodes always receive the context that inferred from + the template. However, when replacing more complex nodes (that can potentially + contain Name children), then the caller is responsible for setting the + appropriate context. + Args: - template: A function to be used as a template. Any placeholder is expected - to also be a function argument. + template: A string representing Python code. Any symbol name can be used + that appears in the template code can be used as placeholder. **replacements: A mapping from placeholder names to (lists of) AST nodes - that these placeholders will be replaced by. + that these placeholders will be replaced by. String values are also + supported as a shorthand for AST Name nodes with the respective ID. Returns: - body: An AST node or list of AST nodes with the replacements made. If the - template was a function, a list will be returned. If the template was a - node, the same node will be returned. If the template was a string, an - AST node will be returned (a `Module` node in the case of a multi-line - string, an `Expr` node otherwise). + An AST node or list of AST nodes with the replacements made. If the + template was a function, a list will be returned. If the template was a + node, the same node will be returned. If the template was a string, an + AST node will be returned (a `Module` node in the case of a multi-line + string, an `Expr` node otherwise). Raises: - ValueError: If a function is used as a template and an incorrect set of - replacements was passed. + ValueError: if the arguments are incorrect. """ - tree = parser.parse_object(template).body[0] - placeholders = set(arg.id for arg in tree.args.args) - tree.args.args = [] - if tree.args.vararg: - placeholders.add(tree.args.vararg) - tree.args.vararg = None - if set(replacements.keys()) != placeholders: - raise ValueError( - 'too many or few replacements. replacements: %s; placeholders: %s' % - (replacements.keys(), placeholders)) - - # Perform the replacement, stripping the function into which the template was - # wrapped. - return ReplaceTransformer(replacements).visit(tree).body + if not isinstance(template, str): + raise ValueError('Expected string template, got %s' % type(template)) + tree = parser.parse_str(textwrap.dedent(template)) + for k in replacements: + replacements[k] = _convert_to_ast(replacements[k]) + results = ReplaceTransformer(replacements).visit(tree).body + if isinstance(results, list): + return [qual_names.resolve(r) for r in results] + return qual_names.resolve(results) diff --git a/tensorflow/contrib/py2tf/pyct/templates_test.py b/tensorflow/contrib/py2tf/pyct/templates_test.py index 2ad8b9317b67c7ae18a16efac745138e14101e6a..8ccfde8573724741b0bbe4eacb3c54beb381ee7e 100644 --- a/tensorflow/contrib/py2tf/pyct/templates_test.py +++ b/tensorflow/contrib/py2tf/pyct/templates_test.py @@ -27,49 +27,56 @@ from tensorflow.python.platform import test class TemplatesTest(test.TestCase): + def test_replace_tuple(self): + template = """ + def test_fn(a, c): + return b, + """ + + node = templates.replace(template, b=('a', 'c'))[0] + result, _ = compiler.ast_to_object(node) + + self.assertEquals((2, 3), result.test_fn(2, 3)) + def test_replace_variable(self): - def template(a): # pylint:disable=unused-argument - def test_fn(a): # pylint:disable=unused-variable + template = """ + def test_fn(a): a += 1 a = 2 * a + 1 - return b # pylint:disable=undefined-variable + return b + """ - node = templates.replace( - template, a=gast.Name('b', gast.Load(), None))[0] - result = compiler.ast_to_object(node) + node = templates.replace(template, a='b')[0] + result, _ = compiler.ast_to_object(node) self.assertEquals(7, result.test_fn(2)) def test_replace_function_name(self): - def template(fname): # pylint:disable=unused-argument - def fname(a): # pylint:disable=function-redefined + template = """ + def fname(a): a += 1 a = 2 * a + 1 return a + """ - node = templates.replace( - template, fname=gast.Name('test_fn', gast.Load(), None))[0] - result = compiler.ast_to_object(node) + node = templates.replace(template, fname='test_fn')[0] + result, _ = compiler.ast_to_object(node) self.assertEquals(7, result.test_fn(2)) def test_code_block(self): - def template(block): # pylint:disable=unused-argument - def test_fn(a): # pylint:disable=unused-variable - block # pylint:disable=pointless-statement + template = """ + def test_fn(a): + block return a + """ node = templates.replace( template, block=[ - gast.Assign( - [ - gast.Name('a', gast.Store(), None) - ], - gast.BinOp( - gast.Name('a', gast.Load(), None), - gast.Add(), - gast.Num(1))), + gast.Assign([ + gast.Name('a', None, None) + ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))), ] * 2)[0] - result = compiler.ast_to_object(node) + result, _ = compiler.ast_to_object(node) self.assertEquals(3, result.test_fn(1)) diff --git a/tensorflow/contrib/py2tf/pyct/transformer.py b/tensorflow/contrib/py2tf/pyct/transformer.py index d5aa23eaebbbf7540d52d9fa9cc5292e0f756e6d..877d52af016af720424c8a56257fec9ab64611cb 100644 --- a/tensorflow/contrib/py2tf/pyct/transformer.py +++ b/tensorflow/contrib/py2tf/pyct/transformer.py @@ -18,8 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + import gast +import six +from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import pretty_printer @@ -41,18 +45,25 @@ class Base(gast.NodeTransformer): self.context = context def visit(self, node): + source_code = self.context.source_code + source_file = self.context.source_file try: - source_code = self.context.source_code - source_file = self.context.source_file if source_code and hasattr(node, 'lineno'): self._lineno = node.lineno self._col_offset = node.col_offset + if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + return node return super(Base, self).visit(node) - except ValueError as e: - msg = '%s\nOccurred at node:\n%s' % (str(e), pretty_printer.fmt(node)) + except (ValueError, AttributeError, KeyError, NotImplementedError, + AssertionError) as e: + msg = '%s: %s\nOccurred at node:\n%s' % ( + e.__class__.__name__, str(e), pretty_printer.fmt(node, color=False)) if source_code: - line = self._source.splitlines()[self._lineno - 1] + line = source_code.splitlines()[self._lineno - 1] else: line = '' - raise PyFlowParseError( - msg, (source_file, self._lineno, self._col_offset + 1, line)) + six.reraise(PyFlowParseError, + PyFlowParseError( + msg, + (source_file, self._lineno, self._col_offset + 1, line)), + sys.exc_info()[2]) diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/py2tf/utils/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..c2fdd40707775783140390e4b5c0186c9c3e562e --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/BUILD @@ -0,0 +1,108 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "utils", + srcs = [ + "__init__.py", + "context_managers.py", + "misc.py", + "multiple_dispatch.py", + "printing.py", + "py_func.py", + "tensor_list.py", + "type_check.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/python:script_ops", + "@six_archive//:six", + ], +) + +py_test( + name = "context_managers_test", + srcs = ["context_managers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "misc_test", + srcs = ["misc_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "multiple_dispatch_test", + srcs = ["multiple_dispatch_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "py_func_test", + srcs = ["py_func_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "printing_test", + srcs = ["printing_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "type_check_test", + srcs = ["type_check_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "tensor_list_test", + srcs = ["tensor_list_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + "//tensorflow/python:list_ops", + ], +) diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1b993fd366e1317e5f7e01fe849d86c93b8fc2 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility module that contains APIs usable in the generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns +from tensorflow.contrib.py2tf.utils.misc import alias_tensors +from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond +from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while +from tensorflow.contrib.py2tf.utils.printing import call_print +from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func +from tensorflow.contrib.py2tf.utils.type_check import is_tensor diff --git a/tensorflow/contrib/py2tf/utils/context_managers.py b/tensorflow/contrib/py2tf/utils/context_managers.py new file mode 100644 index 0000000000000000000000000000000000000000..38d9e11fe9069722b9023fee848bf53e1f72de6a --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/context_managers.py @@ -0,0 +1,42 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various context managers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib + +from tensorflow.python.framework import ops + + +def control_dependency_on_returns(return_value): + """Create a TF control dependency on the return values of a function. + + If the function had no return value, a no-op context is returned. + + Args: + return_value: The return value to set as control dependency. + + Returns: + A context manager. + """ + if return_value is None: + return contextlib.contextmanager(lambda: (yield))() + # TODO(mdan): Filter to tensor objects. + if not isinstance(return_value, (list, tuple)): + return_value = (return_value,) + return ops.control_dependencies(return_value) diff --git a/tensorflow/contrib/py2tf/utils/context_managers_test.py b/tensorflow/contrib/py2tf/utils/context_managers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..633ba93540e696889a6b2b71b40b999da39d48ff --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/context_managers_test.py @@ -0,0 +1,42 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for context_managers module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils import context_managers +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + + +class ContextManagersTest(test.TestCase): + + def test_control_dependency_on_returns(self): + # Just dry run them. + with context_managers.control_dependency_on_returns(None): + pass + with context_managers.control_dependency_on_returns( + constant_op.constant(1)): + pass + with context_managers.control_dependency_on_returns( + [constant_op.constant(1), + constant_op.constant(2)]): + pass + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/misc.py b/tensorflow/contrib/py2tf/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..1b06caf0bdeb6f4a079e33f2e887d2dca017adc2 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/misc.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Miscellaneous utilities that don't fit anywhere else.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops + + +def alias_tensors(*args): + """Wrap any Tensor arguments with an identity op. + + Any other argument, including Variables, is returned unchanged. + + Args: + *args: Any arguments. Must contain at least one element. + + Returns: + Same as *args, with Tensor instances replaced as described. + + Raises: + ValueError: If args doesn't meet the requirements. + """ + + def alias_if_tensor(a): + return array_ops.identity(a) if isinstance(a, ops.Tensor) else a + + # TODO(mdan): Recurse into containers? + # TODO(mdan): Anything we can do about variables? Fake a scope reuse? + if len(args) > 1: + return (alias_if_tensor(a) for a in args) + elif len(args) == 1: + return alias_if_tensor(args[0]) + + raise ValueError('at least one argument required') diff --git a/tensorflow/contrib/py2tf/utils/misc_test.py b/tensorflow/contrib/py2tf/utils/misc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bfcb304c838df69e9e3961907362c7939c065117 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/misc_test.py @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for misc module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils import misc +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class ContextManagersTest(test.TestCase): + + def test_alias_single_tensor(self): + a = constant_op.constant(1) + + new_a = misc.alias_tensors(a) + self.assertFalse(new_a is a) + with self.test_session() as sess: + self.assertEqual(1, sess.run(new_a)) + + def test_alias_tensors(self): + a = constant_op.constant(1) + v = variables.Variable(2) + s = 'a' + l = [1, 2, 3] + + new_a, new_v, new_s, new_l = misc.alias_tensors(a, v, s, l) + + self.assertFalse(new_a is a) + self.assertTrue(new_v is v) + self.assertTrue(new_s is s) + self.assertTrue(new_l is l) + with self.test_session() as sess: + self.assertEqual(1, sess.run(new_a)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/multiple_dispatch.py b/tensorflow/contrib/py2tf/utils/multiple_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..a855fdc075941915035d1e3380846ff912803494 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/multiple_dispatch.py @@ -0,0 +1,90 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for type-dependent behavior used in py2tf-generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.py2tf.utils.type_check import is_tensor +from tensorflow.python.ops import control_flow_ops + + +def run_cond(condition, true_fn, false_fn): + """Type-dependent functional conditional. + + Args: + condition: A Tensor or Python bool. + true_fn: A Python callable implementing the true branch of the conditional. + false_fn: A Python callable implementing the false branch of the + conditional. + + Returns: + result: The result of calling the appropriate branch. If condition is a + Tensor, tf.cond will be used. Otherwise, a standard Python if statement will + be ran. + """ + if is_tensor(condition): + return control_flow_ops.cond(condition, true_fn, false_fn) + else: + return py_cond(condition, true_fn, false_fn) + + +def py_cond(condition, true_fn, false_fn): + if condition: + return true_fn() + else: + return false_fn() + + +def run_while(cond_fn, body_fn, init_args): + """Type-dependent functional while loop. + + Args: + cond_fn: A Python callable implementing the stop conditions of the loop. + body_fn: A Python callable implementing the body of the loop. + init_args: The initial values of the arguments that will be passed to both + cond_fn and body_fn. + + Returns: + result: A list of values with the same shape and type as init_args. If any + of the init_args, or any variables closed-over in cond_fn are Tensors, + tf.while_loop will be used, otherwise a Python while loop will be ran. + + Raises: + ValueError: if init_args is not a tuple or list with one or more elements. + """ + if not isinstance(init_args, (tuple, list)) or not init_args: + raise ValueError( + 'init_args must be a non-empty list or tuple, found %s' % init_args) + + # TODO(alexbw): statically determine all active variables in cond_fn, + # and pass them directly + closure_vars = tuple( + [c.cell_contents for c in six.get_function_closure(cond_fn) or []]) + possibly_tensors = tuple(init_args) + closure_vars + if is_tensor(*possibly_tensors): + return control_flow_ops.while_loop(cond_fn, body_fn, init_args) + else: + return py_while_loop(cond_fn, body_fn, init_args) + + +def py_while_loop(cond_fn, body_fn, init_args): + state = init_args + while cond_fn(*state): + state = body_fn(*state) + return state diff --git a/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py b/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb4d4086b002211eebb86783bb7212c707a1418 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for multiple_dispatch.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.contrib.py2tf.utils import multiple_dispatch +from tensorflow.python.client.session import Session +from tensorflow.python.framework.constant_op import constant +from tensorflow.python.platform import test + + +class MultipleDispatchTest(test.TestCase): + + def test_run_cond_python(self): + true_fn = lambda: 2.0 + false_fn = lambda: 3.0 + self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2.0) + self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3.0) + + def test_run_cond_tf(self): + + true_fn = lambda: constant([2.0]) + false_fn = lambda: constant([3.0]) + with Session() as sess: + out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn) + self.assertEqual(sess.run(out), 2.0) + out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn) + self.assertEqual(sess.run(out), 3.0) + + def test_run_while_python(self): + cond_fn = lambda x, t, s: x > t + body_fn = lambda x, t, s: (x * s, t, s) + + x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 1.0, 0.5]) + self.assertEqual(x, 0.75) + + x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 4.0, 0.5]) + self.assertEqual(x, 3.0) + + def test_run_while_tf(self): + cond_fn = lambda x, t, s: x > t + body_fn = lambda x, t, s: (x * s, t, s) + + with Session() as sess: + x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, + [constant(3.0), 1.0, 0.5]) + self.assertEqual(sess.run(x), 0.75) + + x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, + [constant(3.0), 4.0, 0.5]) + self.assertEqual(sess.run(x), 3.0) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/printing.py b/tensorflow/contrib/py2tf/utils/printing.py new file mode 100644 index 0000000000000000000000000000000000000000..95a62bd80b5f4854e6a062df18d882f7bd495555 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/printing.py @@ -0,0 +1,47 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow printing support utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils import py_func +from tensorflow.python.ops import logging_ops + + +def is_tf_print_compatible(value): + # TODO(mdan): Enable once we can reliably test this. + # This is currently disabled because we can't capture the output of + # op kernels from Python. + del value + return False + + +def call_print(*values): + """Compiled counterpart of the print builtin. + + The function attempts to use tf.Print if all the values are compatible. + Otherwise, it will fall back to py_func. + + Args: + *values: values to print + Returns: + A dummy value indicating the print completed. If tf. + """ + + if all(map(is_tf_print_compatible, values)): + return logging_ops.Print(1, values) + return py_func.wrap_py_func(print, None, values, use_dummy_return=True) diff --git a/tensorflow/contrib/py2tf/utils/printing_test.py b/tensorflow/contrib/py2tf/utils/printing_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2070deb304d8df2433fb9a95ae36d48415578482 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/printing_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for printing module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import six + +from tensorflow.contrib.py2tf.utils import printing +from tensorflow.python.platform import test + + +class ContextManagersTest(test.TestCase): + + def test_call_print_tf(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(printing.call_print('test message', 1)) + self.assertEqual(out_capturer.getvalue(), 'test message 1\n') + finally: + sys.stdout = sys.__stdout__ + + def test_call_print_py_func(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(printing.call_print('test message', [1, 2])) + self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') + finally: + sys.stdout = sys.__stdout__ + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/py_func.py b/tensorflow/contrib/py2tf/utils/py_func.py new file mode 100644 index 0000000000000000000000000000000000000000..838872d092a3ab07e965180eff4fec7ff6c4ccf9 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/py_func.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pyfunc creation utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import script_ops + + +def wrap_py_func(f, return_dtypes, arguments, use_dummy_return=False): + """Helper that wraps a callable to py_func. + + The helper passes tensor arguments through the py_func interface. Non-tensor + arguments are allowed, and will be passed to f directly. Note that non-tensor + arguments are captured by f will not update every time the wrapper is + called (this is consistent with its argument list, which only includes + the tensor arguments). In general, it's safest not to reuse this wrapper. + + Args: + f: Callable + return_dtypes: DType, tuple, list or None, the data type for each of f's + return value. None if f has no return values or use_dummy_return is + True. + arguments: Arguments for f + use_dummy_return: If True, the function will return a dummy value of 1 + and discard its actual return value. + Returns: + The return values of f converted to tensor. + Raises: + ValueError: if the arguments are incorrect. + """ + + if return_dtypes and use_dummy_return: + raise ValueError('if use_dummy_return is True, return_dtypes must be empty') + + n = len(arguments) + arg_is_tensor = tuple(map(tensor_util.is_tensor, arguments)) + index_in_tensor_list = [0] * n + i = 0 + for j in range(n): + index_in_tensor_list[j] = i + if arg_is_tensor[j]: + i += 1 + + def f_wrapper(*tensor_args): + f_args = tuple(tensor_args[index_in_tensor_list[i]] + if arg_is_tensor[i] else arguments[i] for i in range(n)) + retval = f(*f_args) + return 1 if use_dummy_return else retval + + return script_ops.py_func( + f_wrapper, tuple(arguments[i] for i in range(n) if arg_is_tensor[i]), + dtypes.int64 if use_dummy_return else return_dtypes) diff --git a/tensorflow/contrib/py2tf/utils/py_func_test.py b/tensorflow/contrib/py2tf/utils/py_func_test.py new file mode 100644 index 0000000000000000000000000000000000000000..776b5309c6f027bb2008aa83d48e4155e817ed97 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/py_func_test.py @@ -0,0 +1,91 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for wrap_py_func module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils import py_func +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class PyFuncTest(test.TestCase): + + def test_wrap_py_func_simple(self): + + def test_fn(a, b, c): + return a + b + c + + with self.test_session() as sess: + tensor_1 = constant_op.constant(1) + self.assertEqual(3, + sess.run( + py_func.wrap_py_func(test_fn, dtypes.int64, + (1, tensor_1, 1)))) + self.assertEqual(3, + sess.run( + py_func.wrap_py_func(test_fn, dtypes.int64, + (1, 1, 1)))) + self.assertEqual(3, + sess.run( + py_func.wrap_py_func(test_fn, dtypes.int64, + (tensor_1, 1, tensor_1)))) + + def test_wrap_py_func_complex_args(self): + + class TestClass(object): + + def __init__(self): + self.foo = 5 + + def test_fn(a, b): + return a * b.foo + + with self.test_session() as sess: + self.assertEqual(35, + sess.run( + py_func.wrap_py_func(test_fn, dtypes.int64, + (7, TestClass())))) + self.assertEqual( + 35, + sess.run( + py_func.wrap_py_func(test_fn, dtypes.int64, + (constant_op.constant(7), TestClass())))) + + def test_wrap_py_func_dummy_return(self): + + side_counter = [0] + + def test_fn(_): + side_counter[0] += 1 + + with self.test_session() as sess: + self.assertEqual(1, + sess.run( + py_func.wrap_py_func(test_fn, None, (5,), True))) + self.assertEqual([1], side_counter) + self.assertEqual(1, + sess.run( + py_func.wrap_py_func(test_fn, None, + (constant_op.constant(5),), + True))) + self.assertEqual([2], side_counter) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/tensor_list.py b/tensorflow/contrib/py2tf/utils/tensor_list.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ff49e2a0eff384f10903e12212ab929e267804 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/tensor_list.py @@ -0,0 +1,49 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A typed list in Python.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import list_ops + + +class TensorList(object): + """Tensor list wrapper API-compatible with Python built-in list.""" + + def __init__(self, shape, dtype): + self.dtype = dtype + self.shape = shape + self.clear() + + def append(self, value): + self.list_ = list_ops.tensor_list_push_back(self.list_, value) + + def pop(self): + self.list_, value = list_ops.tensor_list_pop_back(self.list_, self.dtype) + return value + + def clear(self): + self.list_ = list_ops.empty_tensor_list(self.shape, self.dtype) + + def count(self): + return list_ops.tensor_list_length(self.list_) + + def __getitem__(self, key): + return list_ops.tensor_list_get_item(self.list_, key, self.dtype) + + def __setitem__(self, key, value): + self.list_ = list_ops.tensor_list_set_item(self.list_, key, value) diff --git a/tensorflow/contrib/py2tf/utils/tensor_list_test.py b/tensorflow/contrib/py2tf/utils/tensor_list_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e554a162674e08da21785dcbe193c54647f128 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/tensor_list_test.py @@ -0,0 +1,89 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for PyFlow list.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils import tensor_list as tl +from tensorflow.python.client.session import Session +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.framework.constant_op import constant +from tensorflow.python.platform import test + + +class TensorListTest(test.TestCase): + + def test_list_append_python(self): + with context.eager_mode(): + a = constant(3.0) + l = tl.TensorList(a.shape, a.dtype) + l.append(a) + self.assertEqual(l.count().numpy(), 1) + l.append(a) + self.assertEqual(l.count().numpy(), 2) + _ = l.pop() + self.assertEqual(l.count().numpy(), 1) + a2 = l.pop() + self.assertEqual(l.count().numpy(), 0) + self.assertEqual(a.numpy(), a2.numpy()) + + def test_list_index_python(self): + with context.eager_mode(): + a = constant(3.0) + b = constant(2.0) + l = tl.TensorList(a.shape, a.dtype) + l.append(a) + self.assertEqual(l[0].numpy(), a.numpy()) + l[0] = ops.convert_to_tensor(b) + self.assertEqual(l[0].numpy(), b.numpy()) + + def test_list_append_tf(self): + a = constant(3.0) + l = tl.TensorList(a.shape, a.dtype) + l.append(a) + c1 = l.count() + l.append(a) + c2 = l.count() + _ = l.pop() + c3 = l.count() + a2 = l.pop() + c4 = l.count() + with Session() as sess: + c1, c2, c3, c4, a, a2 = sess.run([c1, c2, c3, c4, a, a2]) + self.assertEqual(c1, 1) + self.assertEqual(c2, 2) + self.assertEqual(c3, 1) + self.assertEqual(c4, 0) + self.assertEqual(a, a2) + + def test_list_index_tf(self): + a = constant(3.0) + b = constant(2.0) + l = tl.TensorList(a.shape, a.dtype) + l.append(a) + l0 = l[0] + l[0] = b + l1 = l[0] + with self.test_session() as sess: + l0, l1, a, b = sess.run([l0, l1, a, b]) + self.assertEqual(l0, a) + self.assertEqual(l1, b) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/type_check.py b/tensorflow/contrib/py2tf/utils/type_check.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca2dec872c8a9ca7bedaa8603f70e3214a3e24a --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/type_check.py @@ -0,0 +1,33 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities used in py2tf-generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import tensor_util + + +def is_tensor(*args): + """Check if all arguments are tensors. + + Args: + *args: Python objects that may or may not be tensors. + + Returns: + True if all *args are TensorFlow types, False if one or more are not. + """ + return any([tensor_util.is_tensor(a) for a in args]) diff --git a/tensorflow/contrib/py2tf/utils/type_check_test.py b/tensorflow/contrib/py2tf/utils/type_check_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0428e9cccecdc67511e236bc00655a055aea29 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/type_check_test.py @@ -0,0 +1,43 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for type_check.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy + +from tensorflow.contrib.py2tf.utils import type_check +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class TypeCheckTest(test.TestCase): + + def test_checks(self): + self.assertTrue(type_check.is_tensor(constant_op.constant([1, 2, 3]))) + self.assertTrue( + type_check.is_tensor(test_util.variables.Variable([1, 2, 3]))) + self.assertTrue( + type_check.is_tensor( + test_util.array_ops.placeholder(test_util.dtypes.float32))) + self.assertFalse(type_check.is_tensor(3)) + self.assertFalse(type_check.is_tensor(numpy.eye(3))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 3c5b34a0a6adb2f4e340a8e378c1eb51a2e2b534..aec9f47ccb20349c08bbe2fd813ee24a807f9fe3 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -13,6 +13,20 @@ py_library( deps = [], ) +py_test( + name = "common_test", + size = "small", + srcs = ["python/common_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":common", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + ], +) + py_library( name = "graph_matcher", srcs = [ @@ -75,11 +89,18 @@ py_library( ":graph_matcher", ":input_to_ops", "//tensorflow/contrib/graph_editor:graph_editor_py", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", + "//tensorflow/python:ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", ], ) @@ -88,7 +109,6 @@ py_test( srcs = ["python/fold_batch_norms_test.py"], srcs_version = "PY2AND3", deps = [ - ":copy_graph", ":fold_batch_norms", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", @@ -103,31 +123,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:session", - "//tensorflow/python:variables", - ], -) - -py_library( - name = "copy_graph", - srcs = ["python/copy_graph.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:framework_ops", "//tensorflow/python:training", - ], -) - -py_test( - name = "copy_graph_test", - size = "small", - srcs = ["python/copy_graph_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":copy_graph", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", "//tensorflow/python:variables", ], ) @@ -158,7 +154,6 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", "//tensorflow/python:session", "//tensorflow/python:variables", @@ -170,7 +165,7 @@ py_library( srcs = ["python/quantize.py"], srcs_version = "PY2AND3", deps = [ - ":common", + ":graph_matcher", ":input_to_ops", ":quant_ops", "//tensorflow/contrib/graph_editor:graph_editor_py", @@ -217,7 +212,6 @@ py_test( "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:training", ], ) @@ -229,12 +223,9 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":copy_graph", ":fold_batch_norms", ":quantize", - "//tensorflow/python:framework_ops", "//tensorflow/python:util", - "//tensorflow/python:variables", ], ) @@ -247,13 +238,11 @@ py_test( ":quantize_graph", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:init_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py index d0b0674c31239ee903f5ab7ef9ae0262bb20d189..3a1fa61e43986af1a1315d5a9e6f010e802ea157 100644 --- a/tensorflow/contrib/quantize/python/common.py +++ b/tensorflow/contrib/quantize/python/common.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Constants used across this package.""" +"""Common utilities used across this package.""" from __future__ import absolute_import from __future__ import division @@ -21,6 +21,12 @@ from __future__ import print_function import collections import re +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope + # Skip all operations that are backprop related or export summaries. SKIPPED_PREFIXES = ( 'gradients/', 'RMSProp/', 'Adagrad/', 'Const_', 'HistogramSummary', @@ -86,3 +92,31 @@ def _GetOperationByNameDontThrow(graph, name): return graph.get_operation_by_name(name) except KeyError: return None + + +def CreateOrGetQuantizationStep(): + """Returns a Tensor of the number of steps the quantized graph has run. + + Returns: + Quantization step Tensor. + """ + quantization_step_name = 'fake_quantization_step' + quantization_step_tensor_name = quantization_step_name + '/AssignAdd:0' + g = ops.get_default_graph() + try: + return g.get_tensor_by_name(quantization_step_tensor_name) + except KeyError: + # Create in proper graph and base name_scope. + with g.name_scope(None): + quantization_step_tensor = variable_scope.get_variable( + quantization_step_name, + shape=[], + dtype=dtypes.int64, + initializer=init_ops.zeros_initializer(), + trainable=False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + with g.name_scope(quantization_step_tensor.op.name + '/'): + # We return the incremented variable tensor. Since this is used in conds + # for quant_delay and freeze_bn_delay, it will run once per graph + # execution. + return state_ops.assign_add(quantization_step_tensor, 1) diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d6237fe5e38d905bf262d7be3746b9ee6046da47 --- /dev/null +++ b/tensorflow/contrib/quantize/python/common_test.py @@ -0,0 +1,59 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for common utilities in this package.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.quantize.python import common +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +class CommonTest(test_util.TensorFlowTestCase): + + def testCreateOrGetQuantizationStep(self): + g = ops.Graph() + with session.Session(graph=g) as sess: + quantization_step_tensor = common.CreateOrGetQuantizationStep() + + # Check that operations are added to the graph. + num_nodes = len(g.get_operations()) + self.assertGreater(num_nodes, 0) + + # Check that getting the quantization step doesn't change the graph. + get_quantization_step_tensor = common.CreateOrGetQuantizationStep() + self.assertEqual(quantization_step_tensor, get_quantization_step_tensor) + self.assertEqual(num_nodes, len(g.get_operations())) + + # Ensure that running the graph increments the quantization step. + sess.run(variables.global_variables_initializer()) + step_val = sess.run(quantization_step_tensor) + self.assertEqual(step_val, 1) + + # Ensure that even running a graph that depends on the quantization step + # multiple times only executes it once. + a = quantization_step_tensor + 1 + b = a + quantization_step_tensor + _, step_val = sess.run([b, quantization_step_tensor]) + self.assertEqual(step_val, 2) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/copy_graph_test.py b/tensorflow/contrib/quantize/python/copy_graph_test.py deleted file mode 100644 index 7ff9ad9f8412d7076bf12d6cf10772244444013f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/quantize/python/copy_graph_test.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for copy_graph.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.quantize.python import copy_graph -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variables -from tensorflow.python.platform import googletest - - -class CopyGraphTest(test_util.TensorFlowTestCase): - - def _CompareNodeInGraph(self, node, graph): - graph_node = graph.get_operation_by_name(node.name) - self.assertEqual(str(node.node_def), str(graph_node.node_def)) - - def testCopyGraph(self): - graph = ops.Graph() - with graph.as_default(): - a = constant_op.constant(1.0) - b = variables.Variable(2.0) - c = a + b - graph_copy = copy_graph.CopyGraph(graph) - # Ensure that the three original nodes are in the new graph. - # import_meta_graph also adds a saver node to the graph which we don't care - # about in this specific use case. - for tensor in [a, b, c]: - self._CompareNodeInGraph(tensor.op, graph_copy) - # Test that the graph collections are the same. - for key in graph.get_all_collection_keys(): - self.assertEqual( - len(graph.get_collection(key)), - len(graph_copy.get_collection(key)), 'Collection %s differs.') - - -if __name__ == '__main__': - googletest.main() diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index aa605e6caadf4d1e69a4a331b1e580797e4fdef8..75d9eb0e58d96e4bb2946684febd250e2e1a6b4a 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -26,6 +26,7 @@ from tensorflow.contrib.quantize.python import input_to_ops from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -33,7 +34,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.util import compat -def FoldBatchNorms(graph): +def FoldBatchNorms(graph, is_training, freeze_batch_norm_delay=None): """Finds batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise @@ -41,15 +42,22 @@ def FoldBatchNorms(graph): Args: graph: Graph to walk and modify. - + is_training: Bool, true if training. + freeze_batch_norm_delay: How many steps to wait before freezing moving mean + and variance and using them for batch normalization. This value is used + only when is_training is True. Raises: ValueError: When batch norm folding fails. """ - _FoldFusedBatchNorms(graph) - _FoldUnfusedBatchNorms(graph) + _FoldFusedBatchNorms( + graph, is_training, freeze_batch_norm_delay=freeze_batch_norm_delay) + _FoldUnfusedBatchNorms( + graph, + is_training=is_training, + freeze_batch_norm_delay=freeze_batch_norm_delay) -def _FoldFusedBatchNorms(graph): +def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): """Finds fused batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise @@ -57,6 +65,9 @@ def _FoldFusedBatchNorms(graph): Args: graph: Graph to walk and modify. + is_training: Bool, true if training. + freeze_batch_norm_delay: How many steps to wait before freezing moving mean + and variance and using them for batch normalization. Raises: ValueError: When batch norm folding fails. @@ -67,8 +78,7 @@ def _FoldFusedBatchNorms(graph): # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope # named `scope`. Otherwise, TF creates a unique scope whose name starts with # `scope`. - with graph.as_default(), graph.name_scope(scope + sep), ops.device( - match.bn_op.device): + with graph.as_default(), graph.name_scope(scope + sep): with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep): # new weights = old weights * gamma / sqrt(variance + epsilon) # new biases = -mean * gamma / sqrt(variance + epsilon) + beta @@ -79,9 +89,18 @@ def _FoldFusedBatchNorms(graph): match.mean_tensor * multiplier_tensor, name='bias') + correction_scale, correction_recip, correction_offset = None, None, None + if is_training: + correction_scale, correction_recip, correction_offset = ( + _ComputeBatchNormCorrections( + context='', + match=match, + freeze_batch_norm_delay=freeze_batch_norm_delay, + fused_batch_norm=True)) # The shape of depthwise weights is different, so we need to reshape the # multiplier_tensor to ensure that the scaled_weight_tensor has the # expected shape. + weights = match.weight_tensor if match.layer_op.type == 'DepthwiseConv2dNative': new_shape = [ match.weight_tensor.get_shape().as_list()[2], @@ -90,15 +109,25 @@ def _FoldFusedBatchNorms(graph): multiplier_tensor = array_ops.reshape( multiplier_tensor, new_shape, name='scale_reshape') - # TODO(suharshs): This naming of the following ops needs to carefully - # follow the naming expected by quantize.py. Generalize the quantize code - # to not require these delicate naming conventions. - scaled_weight_tensor = math_ops.multiply( - match.weight_tensor, multiplier_tensor, name='mul_fold') + if correction_scale is not None: + correction_scale = array_ops.reshape( + correction_scale, new_shape, name='correction_reshape') + + if correction_scale is not None: + weights = math_ops.multiply( + correction_scale, weights, name='correction_mult') + scaled_weight_tensor = math_ops.multiply( + weights, multiplier_tensor, name='mul_fold') new_layer_tensor = _CloneWithNewOperands( match.layer_op, match.input_tensor, scaled_weight_tensor) + if correction_recip is not None: + new_layer_tensor = math_ops.multiply( + correction_recip, new_layer_tensor, name='post_conv_mul') + new_layer_tensor = math_ops.add(new_layer_tensor, (correction_offset), + 'correction_add') + bias_add_tensor = math_ops.add( new_layer_tensor, bias_tensor, name='add_fold') @@ -109,46 +138,6 @@ def _FoldFusedBatchNorms(graph): 'Unexpected inputs to op: %s' % match.output_tensor.name) -def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): - """Clones layer_op with input_tensor and weight_tensor as new inputs.""" - new_layer_name = layer_op.name.split('/')[-1] + '_Fold' - if layer_op.type == 'Conv2D': - return nn_ops.conv2d( - input_tensor, - weight_tensor, - strides=layer_op.get_attr('strides'), - padding=layer_op.get_attr('padding'), - use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), - data_format=layer_op.get_attr('data_format'), - name=new_layer_name) - elif layer_op.type == 'MatMul': - return math_ops.matmul( - input_tensor, - weight_tensor, - transpose_a=layer_op.get_attr('transpose_a'), - transpose_b=layer_op.get_attr('transpose_b'), - name=new_layer_name) - elif layer_op.type == 'DepthwiseConv2dNative': - return nn.depthwise_conv2d( - input_tensor, - weight_tensor, - strides=layer_op.get_attr('strides'), - padding=layer_op.get_attr('padding'), - name=new_layer_name) - else: - raise ValueError('Cannot handle operation of type: %s' % layer_op.type) - - -@ops.RegisterGradient('FoldFusedBatchNormGrad') -def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1, - unused_2): - x = op.inputs[0] - n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements() - dmean_dx = grad_mean / n - dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1) - return (dmean_dx + dvar_dx), None, None, None, None - - def _FindFusedBatchNorms(graph): """Finds all ops and tensors related to found FusedBatchNorms. @@ -165,37 +154,59 @@ def _FindFusedBatchNorms(graph): mean_pattern = graph_matcher.OpTypePattern('*') variance_pattern = graph_matcher.OpTypePattern('*') - conv_pattern = graph_matcher.OpTypePattern( - 'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern]) + moving_average_pattern = graph_matcher.OpTypePattern('*') + bn_decay_pattern = graph_matcher.OpTypePattern('*') + layer_pattern = graph_matcher.OpTypePattern( + 'Conv2D|DepthwiseConv2dNative|MatMul', + inputs=[input_pattern, weight_pattern]) # MatMul has a Reshape between it and FusedBatchNorm. - matmul_pattern = graph_matcher.OpTypePattern( - 'MatMul', inputs=[input_pattern, weight_pattern]) matmul_reshape_pattern = graph_matcher.OpTypePattern( - 'Reshape', inputs=[matmul_pattern, + 'Reshape', inputs=[layer_pattern, graph_matcher.OpTypePattern('*')]) - conv_batch_norm_pattern = graph_matcher.OpTypePattern( + batch_norm_pattern = graph_matcher.OpTypePattern( 'FusedBatchNorm', inputs=[ - conv_pattern, gamma_pattern, beta_pattern, mean_pattern, - variance_pattern - ]) - matmul_batch_norm_pattern = graph_matcher.OpTypePattern( - 'FusedBatchNorm', - inputs=[ - matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern, - variance_pattern + graph_matcher.OneofPattern([matmul_reshape_pattern, layer_pattern]), + gamma_pattern, beta_pattern, mean_pattern, variance_pattern ]) matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern( - 'Reshape', - inputs=[matmul_batch_norm_pattern, - graph_matcher.OpTypePattern('*')]) + 'Reshape', inputs=[batch_norm_pattern, + graph_matcher.OpTypePattern('*')]) + + bn_matcher = graph_matcher.GraphMatcher( + graph_matcher.OneofPattern( + [matmul_bn_output_reshape_pattern, batch_norm_pattern])) - conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern) - matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern) + moving_average_sub_pattern = graph_matcher.OpTypePattern( + 'Sub', inputs=[moving_average_pattern, batch_norm_pattern]) + moving_average_mul_pattern = graph_matcher.OpTypePattern( + 'Mul', inputs=[moving_average_sub_pattern, bn_decay_pattern]) + + moving_avg_mul_matcher = graph_matcher.GraphMatcher( + moving_average_mul_pattern) + + for match_result in bn_matcher.match_graph(graph): + moving_mean_tensor = None + moving_variance_tensor = None + bn_decay_mean_tensor = None + bn_decay_var_tensor = None + layer_op = match_result.get_op(layer_pattern) + layer_tensor = match_result.get_tensor(layer_pattern) + bn_op = match_result.get_op(batch_norm_pattern) + batch_epsilon_tensor = bn_op.get_attr('epsilon') + + # In the MatMul case, the output of batch norm is reshaped back into a + # 2D tensor, so the output_tensor is the output of the Reshape op. + output_tensor = bn_op.outputs[0] + if layer_op.type == 'MatMul': + output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) + # If the matcher didn't match matmul_bn_output_reshape, there will be + # another match for this 'MatMul' later, so we can skip this one. + if output_reshape_op is None: + continue + output_tensor = output_reshape_op.outputs[0] - def _GetCommonTensors(match_result, bn_op, bn_input_tensor): - """Gets tensors needed for FusedBatchNormMatch from match_result.""" input_tensor = match_result.get_tensor(input_pattern) weight_tensor = match_result.get_tensor(weight_pattern) gamma_tensor = match_result.get_tensor(gamma_pattern) @@ -222,48 +233,30 @@ def _FindFusedBatchNorms(graph): # calculation, the variance is corrected by the term N/N-1 (Bessel's # correction). The variance tensor read from FuseBatchNorm has bessel's # correction applied, so we undo it here. - n = math_ops.cast( - array_ops.size(bn_input_tensor) / array_ops.size(mean_tensor), - dtypes.float32) - variance_tensor = bn_op.outputs[2] * (n - 1) / n + scope, sep, _ = bn_op.name.rpartition('/') + g = ops.get_default_graph() + with g.as_default(), g.name_scope(scope + sep): + n = math_ops.cast( + array_ops.size(layer_tensor) / array_ops.size(mean_tensor), + dtypes.float32) + variance_tensor = math_ops.multiply( + bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction') + # TODO(suharshs): Find a way to get rid of this inner match. + for mul_match_result in moving_avg_mul_matcher.match_graph(graph): + sub_op = mul_match_result.get_op(moving_average_sub_pattern) + if sub_op.inputs[1].name == bn_op.outputs[1].name: + # During training: Batch Mean is bn_op.outputs[1] + moving_mean_tensor = sub_op.inputs[0] + bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern) + if sub_op.inputs[1].name == bn_op.outputs[2].name: + # During training: Batch Var is bn_op.outputs[2] + moving_variance_tensor = sub_op.inputs[0] + bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern) else: mean_tensor = match_result.get_tensor(mean_pattern) variance_tensor = match_result.get_tensor(variance_pattern) - return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor) - - for match_result in conv_matcher.match_graph(graph): - layer_op = match_result.get_op(conv_pattern) - layer_tensor = match_result.get_tensor(conv_pattern) - bn_op = match_result.get_op(conv_batch_norm_pattern) - # In the case of convolution the output_tensor is the output of bn_op. - output_tensor = bn_op.outputs[0] - - (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor) = _GetCommonTensors(match_result, bn_op, layer_tensor) - yield _FusedBatchNormMatch( - layer_op=layer_op, - bn_op=bn_op, - output_tensor=output_tensor, - input_tensor=input_tensor, - weight_tensor=weight_tensor, - gamma_tensor=gamma_tensor, - beta_tensor=beta_tensor, - mean_tensor=mean_tensor, - variance_tensor=variance_tensor) - for match_result in matmul_matcher.match_graph(graph): - layer_op = match_result.get_op(matmul_pattern) - layer_tensor = match_result.get_tensor(matmul_pattern) - bn_op = match_result.get_op(matmul_batch_norm_pattern) - # In the MatMul case, the output of batch norm is reshaped back into a - # 2D tensor, so the output_tensor is the output of the Reshape op. - output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) - output_tensor = output_reshape_op.outputs[0] - - (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor) = _GetCommonTensors(match_result, bn_op, layer_tensor) - yield _FusedBatchNormMatch( + yield _BatchNormMatch( layer_op=layer_op, bn_op=bn_op, output_tensor=output_tensor, @@ -272,63 +265,156 @@ def _FindFusedBatchNorms(graph): gamma_tensor=gamma_tensor, beta_tensor=beta_tensor, mean_tensor=mean_tensor, - variance_tensor=variance_tensor) - - -class _FusedBatchNormMatch(object): - """Contains all information related to a found FusedBatchNorm.""" - - def __init__(self, layer_op, bn_op, output_tensor, input_tensor, - weight_tensor, gamma_tensor, beta_tensor, mean_tensor, - variance_tensor): - self._layer_op = layer_op - self._bn_op = bn_op - self._output_tensor = output_tensor - self._input_tensor = input_tensor - self._weight_tensor = weight_tensor - self._gamma_tensor = gamma_tensor - self._beta_tensor = beta_tensor - self._mean_tensor = mean_tensor - self._variance_tensor = variance_tensor + variance_tensor=variance_tensor, + moving_mean_tensor=moving_mean_tensor, + moving_variance_tensor=moving_variance_tensor, + bn_decay_mean_tensor=bn_decay_mean_tensor, + bn_decay_var_tensor=bn_decay_var_tensor, + batch_epsilon_tensor=batch_epsilon_tensor) + + +def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, + fused_batch_norm): + """Computes batch norm correction params. + + Before batch normalization is frozen: + We use batch statistics for batch norm. + correction_scale = sigma_b/sigma_mv + correction_recip = 1/correction_scale + correction_offset = 0 + + After batch normalization is frozen: + correction_scale = sigma_b/sigma_mv + correction_recip = 1 + correction_offset = gamma*(mu_b/sigma_b-mu_mv/sigma_mv). + + Batch norm is frozen if global_step > bn_freeze_delay. + The corrections ensure that: + a) The weights are quantized after scaling by gamma/sigma_mv. This enables + smoother training as the scaling on the weights changes slowly, rather than + jump across mini-batches + b) Changing the values of the corrections allows for one to switch between + using batch statistics to using moving mean and average, without requiring + changes to batch_norm - @property - def layer_op(self): - return self._layer_op - - @property - def bn_op(self): - return self._bn_op - @property - def output_tensor(self): - return self._output_tensor + Args: + context: The scope under which we look for batch norm params + match: Object containg required batch norm tensors for correction + computation. + freeze_batch_norm_delay: Delay in steps at which computation switches + from regular batch norm to frozen mean and variance. + fused_batch_norm: Bool, true if fused batch norm is used. - @property - def input_tensor(self): - return self._input_tensor + Returns: + A tuple of correction_scale, correction_recip, correction_offset + """ - @property - def weight_tensor(self): - return self._weight_tensor + g = ops.get_default_graph() + with g.name_scope(context + '/batch_norm_correction'): + recip_sigma_mv = math_ops.rsqrt( + match.moving_variance_tensor + match.batch_epsilon_tensor) + recip_sigma = math_ops.rsqrt( + match.variance_tensor + match.batch_epsilon_tensor) + correction_scale = math_ops.divide( + recip_sigma_mv, recip_sigma, name='scale_compute') + correction_scale = array_ops.identity( + correction_scale, name='correction_scale') + correction_recip = math_ops.reciprocal( + correction_scale, name='reciprocal_compute') + correction_offset = math_ops.multiply( + match.gamma_tensor, + match.mean_tensor * recip_sigma - + match.moving_mean_tensor * recip_sigma_mv, + name='offset_compute') + + if freeze_batch_norm_delay is not None: + use_mv_avg = math_ops.greater_equal( + common.CreateOrGetQuantizationStep(), + freeze_batch_norm_delay, + name='use_moving_average') + else: + use_mv_avg = False + + bn_decay_zero = 0.0 + bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers()) + bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers()) + + bn_decay_mean_out = utils.smart_cond( + use_mv_avg, + lambda: bn_decay_zero, + lambda: match.bn_decay_mean_tensor, + name='freeze_moving_mean') + graph_editor.reroute_ts( + [bn_decay_mean_out], [match.bn_decay_mean_tensor], + can_modify=bn_decay_mean_consumers) + + if fused_batch_norm is False: + bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) + bn_decay_var_out = utils.smart_cond( + use_mv_avg, + lambda: bn_decay_zero, + lambda: match.bn_decay_var_tensor, + name='freeze_moving_var') + graph_editor.reroute_ts( + [bn_decay_var_out], [match.bn_decay_var_tensor], + can_modify=bn_decay_var_consumers) + + correction_recip = utils.smart_cond( + use_mv_avg, + lambda: array_ops.ones(correction_scale.shape), + lambda: correction_recip, + name='correction_recip') + + correction_offset = utils.smart_cond( + use_mv_avg, + lambda: correction_offset, + lambda: array_ops.zeros(correction_offset.shape), + name='correction_offset') + return correction_scale, correction_recip, correction_offset - @property - def gamma_tensor(self): - return self._gamma_tensor - @property - def beta_tensor(self): - return self._beta_tensor +def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): + """Clones layer_op with input_tensor and weight_tensor as new inputs.""" + new_layer_name = layer_op.name.split('/')[-1] + '_Fold' + if layer_op.type == 'Conv2D': + return nn_ops.conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), + data_format=layer_op.get_attr('data_format'), + name=new_layer_name) + elif layer_op.type == 'MatMul': + return math_ops.matmul( + input_tensor, + weight_tensor, + transpose_a=layer_op.get_attr('transpose_a'), + transpose_b=layer_op.get_attr('transpose_b'), + name=new_layer_name) + elif layer_op.type == 'DepthwiseConv2dNative': + return nn.depthwise_conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + name=new_layer_name) + else: + raise ValueError('Cannot handle operation of type: %s' % layer_op.type) - @property - def mean_tensor(self): - return self._mean_tensor - @property - def variance_tensor(self): - return self._variance_tensor +@ops.RegisterGradient('FoldFusedBatchNormGrad') +def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1, + unused_2): + x = op.inputs[0] + n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements() + dmean_dx = grad_mean / n + dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1) + return (dmean_dx + dvar_dx), None, None, None, None -def _FoldUnfusedBatchNorms(graph): +def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): """Finds unfused batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise @@ -336,6 +422,9 @@ def _FoldUnfusedBatchNorms(graph): Args: graph: Graph to walk and modify. + is_training: Bool, True if training. + freeze_batch_norm_delay: How many steps to wait before freezing moving mean + and variance and using them for batch normalization. Raises: ValueError: When batch norm folding fails. @@ -346,7 +435,12 @@ def _FoldUnfusedBatchNorms(graph): has_scaling = _HasScaling(graph, input_to_ops_map, bn) # The mangling code intimately depends on BatchNorm node's internals. - original_op, folded_op = _CreateFoldedOp(graph, bn, has_scaling=has_scaling) + original_op, folded_op = _CreateFoldedOp( + graph, + bn, + has_scaling=has_scaling, + freeze_batch_norm_delay=freeze_batch_norm_delay, + is_training=is_training) activation = common.GetEndpointActivationOp(graph, bn) if activation: @@ -368,46 +462,84 @@ def _FoldUnfusedBatchNorms(graph): raise ValueError('Unexpected inputs to op: %s' % add_bypass.name) -def _HasScaling(graph, input_to_ops_map, bn): - r"""Checks if batch norm has scaling enabled. - - Difference between batch norm with scaling and without is that with scaling: - - Rsqrt -> mul -> mul_1 - \-> mul_2 - - where - mul multiplies gamma by inverse square root of EMA of batch variance, - mul_1 multiplies output of mul with output from the base operation - (convolution, FC or depthwise convolution), - mul_2 multiplies output of mul with EMA of batch mean, - and without scaling: - - Rsqrt -> mul - \-> mul_1 - - where - mul multiplies the inverse square root of EMA of batch variance with output - from the base operation, - mul_1 multiplies inverse square root of EMA of batch variance with EMA - of batch mean. +def _GetBatchNormParams(graph, context, has_scaling): + """Extracts relevant tensors for folding batch norms. Args: graph: Graph to inspect. - input_to_ops_map: InputToOps object containing mapping from tensor's name - to ops that take it as input. - bn: Batch norm layer prefix string. + context: The scope under which we look for batch norm params + has_scaling: Bool that specifies if scaling is done as part of batch norm. Returns: - A boolean indicating whether this batch norm layer has scaling enabled. + _BatchNormMatch containing all required batch norm parameters. """ - rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt') - rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) - - return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 - - -def _CreateFoldedOp(graph, context, has_scaling): + gamma_tensor = None + batch_mean_tensor = None + batch_variance_tensor = None + moving_mean_tensor = None + moving_variance_tensor = None + batch_epsilon_tensor = None + bn_decay_mean_tensor = None + bn_decay_var_tensor = None + + split_context = context.split('/') + base_context = split_context[-1] + + oplist = graph.get_operations() + op_suffix_gamma = base_context + '/BatchNorm/gamma' + op_suffix_mean = base_context + '/BatchNorm/moments/Squeeze' + op_suffix_variance = base_context + '/BatchNorm/moments/Squeeze_1' + op_suffix_moving_variance = base_context + '/BatchNorm/moving_variance/read' + op_suffix_moving_mean = base_context + '/BatchNorm/moving_mean/read' + op_suffix_epsilon = base_context + '/BatchNorm/batchnorm/add/y' + op_suffix_bn_decay_mean = base_context + '/BatchNorm/AssignMovingAvg/decay' + op_suffix_bn_decay_var = base_context + '/BatchNorm/AssignMovingAvg_1/decay' + + # Parse through list of ops to find relevant ops + for op in oplist: + if op.name.endswith(op_suffix_mean): + # This is an efficient way to check for two things: + # Is batch norm present and is it training mode? + # Batch statistics are computed only during batch norm in training + batch_mean_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_variance): + batch_variance_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_moving_mean): + moving_mean_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_moving_variance): + moving_variance_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_epsilon): + batch_epsilon_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_bn_decay_mean): + bn_decay_mean_tensor = graph.get_tensor_by_name(op.name + ':0') + if op.name.endswith(op_suffix_bn_decay_var): + bn_decay_var_tensor = graph.get_tensor_by_name(op.name + ':0') + if has_scaling: + if op.name.endswith(op_suffix_gamma): + gamma_tensor = graph.get_tensor_by_name(op.name + ':0') + + if not has_scaling: + gamma_tensor = array_ops.ones(batch_mean_tensor.shape) + + return _BatchNormMatch( + layer_op=None, + bn_op=None, + output_tensor=None, + input_tensor=None, + weight_tensor=None, + gamma_tensor=gamma_tensor, + beta_tensor=None, + mean_tensor=batch_mean_tensor, + variance_tensor=batch_variance_tensor, + moving_mean_tensor=moving_mean_tensor, + moving_variance_tensor=moving_variance_tensor, + bn_decay_mean_tensor=bn_decay_mean_tensor, + bn_decay_var_tensor=bn_decay_var_tensor, + batch_epsilon_tensor=batch_epsilon_tensor) + + +def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, + is_training): """Folds in batch norm layer into preceding convolution or FC layer. Creates 3 new nodes, connects their inputs and adds them to the graph: @@ -417,17 +549,20 @@ def _CreateFoldedOp(graph, context, has_scaling): Args: graph: Graph to modify. context: String, batch norm context, i.e. node into which BatchNorm is - nested. + nested. has_scaling: Whether the batch norm has scaling enabled. + freeze_batch_norm_delay: How many steps to wait before freezing moving mean + and variance and using them for batch normalization. + is_training: Bool, true if training. Raises: ValueError: When operation type is not supported, or input and output tensor - shapes mismatch for created operations: mul_fold, add_fold. + shapes mismatch for created operations: mul_fold, add_fold. Returns: A pair of Operations, the first is the original consumer node of the batch - norm (../BatchNorm/batchnorm/add_1), the second is the consumer node of - the folded graph (add_fold). + norm (../BatchNorm/batchnorm/add_1), the second is the consumer node of + the folded graph (add_fold). """ mul_scale_name = 'mul_1' if has_scaling else 'mul' mul_scale = graph.get_operation_by_name(context + @@ -435,19 +570,43 @@ def _CreateFoldedOp(graph, context, has_scaling): mul_scale_name) op_below = mul_scale.inputs[0].op weights = op_below.inputs[1] - + match = _GetBatchNormParams( + graph=graph, context=context, has_scaling=has_scaling) + correction_scale, correction_recip, correction_offset = None, None, None + if is_training: + correction_scale, correction_recip, correction_offset = ( + _ComputeBatchNormCorrections( + context=context, + match=match, + freeze_batch_norm_delay=freeze_batch_norm_delay, + fused_batch_norm=False)) # Special handling for weights of depthwise convolution. if op_below.type == 'DepthwiseConv2dNative': - new_shape = [weights.get_shape().as_list()[2], - weights.get_shape().as_list()[3]] + new_shape = [ + weights.get_shape().as_list()[2], + weights.get_shape().as_list()[3] + ] scale_name = 'mul' if has_scaling else 'Rsqrt' - scale = graph.get_operation_by_name(context + '/BatchNorm/batchnorm/' + - scale_name) + scale = graph.get_operation_by_name( + context + '/BatchNorm/batchnorm/' + scale_name) scale = array_ops.reshape(scale.outputs[0], new_shape, context + '/scale_reshape') - mul_fold = _CloneOp(mul_scale, context + '/mul_fold', - [(0, weights), (1, scale)]) + + if correction_scale is not None: + correction_scale = array_ops.reshape(correction_scale, new_shape, + context + '/correction_reshape') + with ops.device(mul_scale.device): + weights = math_ops.multiply(correction_scale, weights, + context + '/correction_mult') + + mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights), + (1, scale)]) elif op_below.type in ['Conv2D', 'MatMul']: + + if correction_scale is not None: + with ops.device(mul_scale.device): + weights = math_ops.multiply(correction_scale, weights, + context + '/correction_mult') mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)]) else: raise ValueError('Cannot handle operation of type: %s' % op_below.op) @@ -456,10 +615,17 @@ def _CreateFoldedOp(graph, context, has_scaling): conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold', [(1, mul_fold.outputs[0])]) - add_shift = graph.get_operation_by_name(context + - '/BatchNorm/batchnorm/add_1') - add_fold = _CloneOp(add_shift, context + '/add_fold', - [(0, conv_or_fc_folded.outputs[0])]) + add_shift = graph.get_operation_by_name( + context + '/BatchNorm/batchnorm/add_1') + + corrected_output = conv_or_fc_folded.outputs[0] + if correction_offset is not None: + with ops.device(conv_or_fc_folded.device): + corrected_output = math_ops.multiply(correction_recip, corrected_output, + context + '/post_conv_mul') + corrected_output = math_ops.add(corrected_output, (correction_offset), + context + '/correction_add') + add_fold = _CloneOp(add_shift, context + '/add_fold', [(0, corrected_output)]) _AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0]) return add_shift, add_fold @@ -471,7 +637,7 @@ def _CloneOp(op, new_name, new_inputs): op: Operation to modify. new_name: String, a new name to set on cloned op. new_inputs: A list of tuples (idx, tensor), each input with corresponding - index will be replaced by the given Tensor in the cloned op. + index will be replaced by the given Tensor in the cloned op. Returns: Operation, the cloned op. @@ -603,3 +769,121 @@ def _AssertShapesMatch(op_name, in_tensor, out_tensor): if not in_shape.is_compatible_with(out_shape): raise ValueError('%s should not change tensor shape: input %s, ' 'output %s' % (op_name, in_shape, out_shape)) + + +def _HasScaling(graph, input_to_ops_map, bn): + r"""Checks if batch norm has scaling enabled. + + Difference between batch norm with scaling and without is that with scaling: + + Rsqrt -> mul -> mul_1 + \-> mul_2 + + where + mul multiplies gamma by inverse square root of EMA of batch variance, + mul_1 multiplies output of mul with output from the base operation + (convolution, FC or depthwise convolution), + mul_2 multiplies output of mul with EMA of batch mean, + and without scaling: + + Rsqrt -> mul + \-> mul_1 + + where + mul multiplies the inverse square root of EMA of batch variance with output + from the base operation, + mul_1 multiplies inverse square root of EMA of batch variance with EMA + of batch mean. + + Args: + graph: Graph to inspect. + input_to_ops_map: InputToOps object containing mapping from tensor's name + to ops that take it as input. + bn: Batch norm layer prefix string. + + Returns: + A boolean indicating whether this batch norm layer has scaling enabled. + """ + rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt') + rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) + + return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 + + +class _BatchNormMatch(object): + """Contains all information related to a found Fused/UnfusedBatchNorm.""" + + def __init__(self, layer_op, bn_op, output_tensor, input_tensor, + weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor, moving_mean_tensor, moving_variance_tensor, + bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon_tensor): + self._layer_op = layer_op + self._bn_op = bn_op + self._output_tensor = output_tensor + self._input_tensor = input_tensor + self._weight_tensor = weight_tensor + self._gamma_tensor = gamma_tensor + self._beta_tensor = beta_tensor + self._mean_tensor = mean_tensor + self._variance_tensor = variance_tensor + self._moving_mean_tensor = moving_mean_tensor + self._moving_variance_tensor = moving_variance_tensor + self._bn_decay_mean_tensor = bn_decay_mean_tensor + self._bn_decay_var_tensor = bn_decay_var_tensor + self._batch_epsilon_tensor = batch_epsilon_tensor + + @property + def layer_op(self): + return self._layer_op + + @property + def bn_op(self): + return self._bn_op + + @property + def output_tensor(self): + return self._output_tensor + + @property + def input_tensor(self): + return self._input_tensor + + @property + def weight_tensor(self): + return self._weight_tensor + + @property + def gamma_tensor(self): + return self._gamma_tensor + + @property + def beta_tensor(self): + return self._beta_tensor + + @property + def mean_tensor(self): + return self._mean_tensor + + @property + def variance_tensor(self): + return self._variance_tensor + + @property + def moving_mean_tensor(self): + return self._moving_mean_tensor + + @property + def moving_variance_tensor(self): + return self._moving_variance_tensor + + @property + def batch_epsilon_tensor(self): + return self._batch_epsilon_tensor + + @property + def bn_decay_mean_tensor(self): + return self._bn_decay_mean_tensor + + @property + def bn_decay_var_tensor(self): + return self._bn_decay_var_tensor diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index ecf321ff573181c7a2e325770a8dde223bf0c021..c90a18ab0357f1bcbc5d8ccd48edf894d7baf5f9 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers -from tensorflow.contrib.quantize.python import copy_graph from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -34,6 +33,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +from tensorflow.python.training import saver as saver_lib batch_norm = layers.batch_norm conv2d = layers.conv2d @@ -46,26 +46,27 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): def _RunTestOverParameters(self, test_fn): parameters_list = [ - # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm) - (nn_ops.relu6, 'Relu6', False, False, False), - (nn_ops.relu, 'Relu', False, False, False), - (nn_ops.relu6, 'Relu6', True, False, False), - (nn_ops.relu, 'Relu', True, False, False), - (nn_ops.relu6, 'Relu6', False, True, False), - (nn_ops.relu, 'Relu', False, True, False), - (nn_ops.relu6, 'Relu6', True, True, False), - (nn_ops.relu, 'Relu', True, True, False), + # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm, + # freeze_batch_norm_delay) + (nn_ops.relu6, 'Relu6', False, False, False, 100), + (nn_ops.relu, 'Relu', False, False, False, None), + (nn_ops.relu6, 'Relu6', True, False, False, 100), + (nn_ops.relu, 'Relu', True, False, False, None), + (nn_ops.relu6, 'Relu6', False, True, False, 100), + (nn_ops.relu, 'Relu', False, True, False, None), + (nn_ops.relu6, 'Relu6', True, True, False, 100), + (nn_ops.relu, 'Relu', True, True, False, None), # Fused batch norm always has scaling enabled. - (nn_ops.relu6, 'Relu6', False, True, True), - (nn_ops.relu, 'Relu', False, True, True), - (nn_ops.relu6, 'Relu6', True, True, True), - (nn_ops.relu, 'Relu', True, True, True), + (nn_ops.relu6, 'Relu6', False, True, True, None), + (nn_ops.relu, 'Relu', False, True, True, 100), + (nn_ops.relu6, 'Relu6', True, True, True, None), + (nn_ops.relu, 'Relu', True, True, True, 100), ] for params in parameters_list: - test_fn(params[0], params[1], params[2], params[3], params[4]) + test_fn(params[0], params[1], params[2], params[3], params[4], params[5]) def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling, - fused_batch_norm): + fused_batch_norm, freeze_batch_norm_delay): """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. Args: @@ -75,6 +76,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): inputs to just before Relu*. has_scaling: Bool, when true the batch norm has scaling. fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance """ g = ops.Graph() with g.as_default(): @@ -99,12 +102,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) - fold_batch_norms.FoldBatchNorms(g) + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') self._AssertInputOpsAre(folded_mul, [ - scope + '/weights/read', + scope + '/correction_mult', self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold']) @@ -113,12 +117,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertEqual(folded_conv.type, 'Conv2D') self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul']) folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/Conv2D_Fold', + scope + '/correction_add', self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] @@ -128,7 +132,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self._RunTestOverParameters(self._TestFoldConv2d) def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass, - has_scaling, fused_batch_norm): + has_scaling, fused_batch_norm, + freeze_batch_norm_delay): """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. Tests that folding works even with an input shape where some dimensions are @@ -141,6 +146,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): inputs to just before Relu*. has_scaling: Bool, when true the batch norm has scaling. fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance """ g = ops.Graph() with g.as_default(): @@ -164,12 +171,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) - fold_batch_norms.FoldBatchNorms(g) + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') self._AssertInputOpsAre(folded_mul, [ - scope + '/weights/read', + scope + '/correction_mult', self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold']) @@ -177,12 +185,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold') self.assertEqual(folded_conv.type, 'Conv2D') self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul']) folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/Conv2D_Fold', + scope + '/correction_add', self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] @@ -192,7 +200,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self._RunTestOverParameters(self._TestFoldConv2dUnknownShape) def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass, - has_scaling, fused_batch_norm): + has_scaling, fused_batch_norm, + freeze_batch_norm_delay): """Tests folding cases: inputs -> FC with batch norm -> Relu*. Args: @@ -202,6 +211,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): inputs to just before Relu*. has_scaling: Bool, when true the batch norm has scaling. fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance """ g = ops.Graph() with g.as_default(): @@ -223,12 +234,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) - fold_batch_norms.FoldBatchNorms(g) + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') self._AssertInputOpsAre(folded_mul, [ - scope + '/weights/read', + scope + '/correction_mult', self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) ]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) @@ -237,12 +249,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertEqual(folded_conv.type, 'MatMul') self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul']) folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/MatMul_Fold', + scope + '/correction_add', self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] @@ -252,7 +264,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self._RunTestOverParameters(self._TestFoldFullyConnectedLayer) def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass, - has_scaling, fused_batch_norm): + has_scaling, fused_batch_norm, + freeze_batch_norm_delay): """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. Args: @@ -262,6 +275,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): inputs to just before Relu*. has_scaling: Bool, when true the batch norm has scaling. fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance """ g = ops.Graph() with g.as_default(): @@ -286,7 +301,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): node = math_ops.add(inputs, node, name='test/Add') relu(node, name='test/' + relu_op_name) - fold_batch_norms.FoldBatchNorms(g) + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) folded_mul = g.get_operation_by_name(scope + '/mul_fold') self.assertEqual(folded_mul.type, 'Mul') @@ -295,8 +311,7 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): else: scale_reshape_op_name = scope + '/scale_reshape' self._AssertInputOpsAre(folded_mul, - [scope + '/depthwise_weights/read', - scale_reshape_op_name]) + [scope + '/correction_mult', scale_reshape_op_name]) self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) scale_reshape = g.get_operation_by_name(scale_reshape_op_name) @@ -311,12 +326,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) - self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul']) folded_add = g.get_operation_by_name(scope + '/add_fold') self.assertEqual(folded_add.type, 'Add') self._AssertInputOpsAre(folded_add, [ - scope + '/depthwise_Fold', + scope + '/correction_add', self._BathNormBiasName(scope, fused_batch_norm) ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] @@ -326,7 +341,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) def _TestCompareFoldAndUnfolded(self, relu, relu_op_name, with_bypass, - has_scaling, fused_batch_norm): + has_scaling, fused_batch_norm, + freeze_batch_norm_delay): """Tests that running folded and unfolded BN returns the same results. Args: @@ -336,6 +352,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): inputs to just before Relu*. has_scaling: Bool, when true the batch norm has scaling. fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance """ random_seed.set_random_seed(1234) unfolded_g = ops.Graph() @@ -361,11 +379,12 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): if with_bypass: node = math_ops.add(inputs, node, name='test/Add') relu_node = relu(node, name='test/' + relu_op_name) - - folded_g = copy_graph.CopyGraph(unfolded_g) + folded_g = self._CopyGraph(unfolded_g) with folded_g.as_default(): - fold_batch_norms.FoldBatchNorms(folded_g) - + fold_batch_norms.FoldBatchNorms( + folded_g, + is_training=True, + freeze_batch_norm_delay=freeze_batch_norm_delay) with session.Session(graph=unfolded_g) as sess: sess.run(variables.global_variables_initializer()) grad_node = gradients.gradients(relu_node, inputs) @@ -443,5 +462,15 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): out_op = graph.get_operation_by_name(out_op_name) self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) + def _CopyGraph(self, graph): + """Return a copy of graph.""" + meta_graph = saver_lib.export_meta_graph( + graph=graph, collection_list=graph.get_all_collection_keys()) + graph_copy = ops.Graph() + with graph_copy.as_default(): + _ = saver_lib.import_meta_graph(meta_graph) + return graph_copy + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py index e3581cc55905a0af7d0464bc0ec673d3ed7f0363..b458f039df0523b5b8b07cff7d14643154124b95 100644 --- a/tensorflow/contrib/quantize/python/graph_matcher.py +++ b/tensorflow/contrib/quantize/python/graph_matcher.py @@ -18,8 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc -class OpTypePattern(object): + +class Pattern(object): + """The parent class of all patterns (e.g. OpTypePattern and OneofPattern).""" + + @abc.abstractmethod + def match(self, op, tensor): + """Returns the result of matching op/tensor against this pattern.""" + raise NotImplementedError('Method "match" not implemented.') + + +class OpTypePattern(Pattern): """A tree pattern that matches TF expressions with certain op types.""" def __init__(self, op_type, name=None, inputs=None): @@ -34,7 +45,7 @@ class OpTypePattern(object): similar TF op types. name: Optional string. The name of the pattern that can be looked up in MatchResult. - inputs: Optional list of `OpTypePattern`s or strings that specify the + inputs: Optional list of `Pattern`s or strings that specify the patterns for the inputs of a matching op. If None, this pattern accepts any inputs of a matching op. """ @@ -43,22 +54,51 @@ class OpTypePattern(object): if inputs is None: inputs = [] self._inputs = [ - input_pattern if isinstance(input_pattern, OpTypePattern) else - OpTypePattern(input_pattern) for input_pattern in inputs + input_pattern + if isinstance(input_pattern, Pattern) else OpTypePattern(input_pattern) + for input_pattern in inputs ] - @property - def op_type(self): - return self._op_type - - @property - def inputs(self): - return self._inputs - @property def name(self): return self._name + def match(self, op, tensor): + if self._op_type != '*': + if op.type not in self._op_type.split('|'): + return None + + match_result = MatchResult() + match_result.add(self, op, tensor) + + if not self._inputs: + # If pattern.inputs is empty, skips the rest and accepts all the inputs. + return match_result + + if len(op.inputs) != len(self._inputs): + return None + + for input_tensor, input_pattern in zip(op.inputs, self._inputs): + input_match_result = input_pattern.match(input_tensor.op, input_tensor) + if input_match_result is None: + return None + match_result.merge_from(input_match_result) + return match_result + + +class OneofPattern(Pattern): + """Matches one of the given sub-patterns.""" + + def __init__(self, sub_patterns): + self._sub_patterns = sub_patterns + + def match(self, op, tensor): + for sub_pattern in self._sub_patterns: + match_result = sub_pattern.match(op, tensor) + if match_result is not None: + return match_result + return None + class MatchResult(object): r"""Encapsulates the result of a match done by GraphMatcher. @@ -102,16 +142,36 @@ class MatchResult(object): return pattern_or_name if isinstance(pattern_or_name, str): + if pattern_or_name not in self._name_to_pattern: + return None return self._name_to_pattern[pattern_or_name] raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.' % type(pattern_or_name)) + def _get_op_tensor(self, pattern_or_name): + pattern = self._to_pattern(pattern_or_name) + if pattern is None: + return None + + if pattern not in self._pattern_to_op_tensor: + return None + + return self._pattern_to_op_tensor[pattern] + def get_op(self, pattern_or_name): - return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0] + op_tensor = self._get_op_tensor(pattern_or_name) + return op_tensor[0] if op_tensor else None def get_tensor(self, pattern_or_name): - return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1] + op_tensor = self._get_op_tensor(pattern_or_name) + return op_tensor[1] if op_tensor else None + + def merge_from(self, other_match_result): + # pylint: disable=protected-access + self._pattern_to_op_tensor.update(other_match_result._pattern_to_op_tensor) + self._name_to_pattern.update(other_match_result._name_to_pattern) + # pylint: enable=protected-access class GraphMatcher(object): @@ -121,7 +181,7 @@ class GraphMatcher(object): """Initializes a GraphMatcher. Args: - pattern: The `OpTypePattern` against which `GraphMatcher` matches + pattern: The `Pattern` against which `GraphMatcher` matches subgraphs. """ self._pattern = pattern @@ -133,7 +193,7 @@ class GraphMatcher(object): with key `pattern`. Args: - pattern: An `OpTypePattern`. + pattern: An `Pattern`. op: A `tf.Operation` to match against the pattern. tensor: the output `tf.Tensor` of `op` that is used by the matching op of `pattern`'s parent. Can be None if `pattern` is already the root of the @@ -142,20 +202,11 @@ class GraphMatcher(object): Returns: True if an TF expression rooted at `op` matches `pattern`. """ - if pattern.op_type != '*': - if op.type not in pattern.op_type.split('|'): - return False - - self._match_result.add(pattern, op, tensor) - - if not pattern.inputs: - # If pattern.inputs is empty, skips the rest and accepts all the inputs. - return True - - return len(op.inputs) == len(pattern.inputs) and all([ - self._match_pattern(input_pattern, input_tensor.op, input_tensor) - for input_tensor, input_pattern in zip(op.inputs, pattern.inputs) - ]) + match_result = pattern.match(op, tensor) + if match_result is None: + return False + self._match_result.merge_from(match_result) + return True def match_op(self, op): """Matches `op` against `self._pattern`. diff --git a/tensorflow/contrib/quantize/python/graph_matcher_test.py b/tensorflow/contrib/quantize/python/graph_matcher_test.py index e1572865e423e569ee3b280036c0e02b71b70648..6d587572181c125faa02d36fb54933cff24f11c6 100644 --- a/tensorflow/contrib/quantize/python/graph_matcher_test.py +++ b/tensorflow/contrib/quantize/python/graph_matcher_test.py @@ -105,7 +105,7 @@ class GraphMatcherTest(test_util.TensorFlowTestCase): self.assertEqual(match_result.get_op(y1_pattern), y1.op) self.assertEqual(match_result.get_tensor(y1_pattern), y1) - def test_oneof_pattern(self): + def test_oneof_type_pattern(self): # - + # / \ / \ # x y z @@ -125,6 +125,44 @@ class GraphMatcherTest(test_util.TensorFlowTestCase): for match_result in matcher.match_graph(g) ], [plus.op, minus.op]) + def test_oneof_pattern(self): + reshape_pattern = graph_matcher.OpTypePattern('Reshape') + transpose_pattern = graph_matcher.OneofPattern([ + graph_matcher.OpTypePattern( + 'Transpose', + name='transpose', + inputs=[ + graph_matcher.OpTypePattern( + 'Slice', name='slice', inputs=[reshape_pattern, '*', '*']), + '*' + ]), + graph_matcher.OpTypePattern( + 'Transpose', name='transpose', inputs=[reshape_pattern, '*']) + ]) + + matcher = graph_matcher.GraphMatcher(transpose_pattern) + + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=[6]) + reshape = array_ops.reshape(inputs, [2, 3]) + transpose = array_ops.transpose(reshape) + [match_result] = list(matcher.match_graph(g)) + self.assertEqual(match_result.get_tensor(reshape_pattern), reshape) + self.assertEqual(match_result.get_tensor('slice'), None) + self.assertEqual(match_result.get_op('transpose'), transpose.op) + + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=[6]) + reshape = array_ops.reshape(inputs, [2, 3]) + slicing = array_ops.slice(reshape, [0, 0], [-1, -1]) + transpose = array_ops.transpose(slicing) + [match_result] = list(matcher.match_graph(g)) + self.assertEqual(match_result.get_tensor(reshape_pattern), reshape) + self.assertEqual(match_result.get_tensor('slice'), slicing) + self.assertEqual(match_result.get_op('transpose'), transpose.op) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index f80d427ff0a6573ecd6562c443182797b5d22527..0a8e35080cb08f71dc28e33c6138a12656e5a5ea 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -53,7 +53,7 @@ def LastValueQuantize(inputs, init_max=6.0, updates_collection=ops.GraphKeys.UPDATE_OPS, vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - scope=None, + name_prefix='LastValueQuant', reuse=None, is_training=True, num_bits=8, @@ -73,7 +73,7 @@ def LastValueQuantize(inputs, computation. vars_collection: (Optional) collection where to store variables for quantization interval ends. - scope: Optional scope for variable_scope. + name_prefix: name_prefix for created nodes. reuse: whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. is_training: Whether the op is applied to a training or eval graph. @@ -84,13 +84,13 @@ def LastValueQuantize(inputs, a tensor containing quantized values. """ with variable_scope.variable_scope( - scope, 'LastValueQuantize', values=[inputs], reuse=reuse): + None, default_name=name_prefix, values=[inputs], reuse=reuse): input_shape = inputs.get_shape() input_dim = len(input_shape) if per_channel: # Only support quantizing 1-, 2- and 4-dimensional tensors. assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' - ' scope: %s' % (input_shape, scope)) + ' scope: %s' % (input_shape, name_prefix)) min_max_shape = [input_shape[-1]] else: min_max_shape = [] @@ -165,7 +165,7 @@ def MovingAvgQuantize(inputs, ema_decay=0.999, updates_collection=ops.GraphKeys.UPDATE_OPS, vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - scope=None, + name_prefix='MovingAvgQuantize', reuse=None, is_training=True, num_bits=8, @@ -186,7 +186,7 @@ def MovingAvgQuantize(inputs, computation. vars_collection: (Optional) collection where to store variables for quantization interval ends. - scope: Optional scope for variable_scope. + name_prefix: name_prefix for created nodes. reuse: whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. is_training: Whether the op is applied to a training or eval graph. @@ -197,13 +197,13 @@ def MovingAvgQuantize(inputs, a tensor containing quantized values. """ with variable_scope.variable_scope( - scope, 'MovingAvgQuantize', values=[inputs], reuse=reuse): + None, default_name=name_prefix, values=[inputs], reuse=reuse): input_shape = inputs.get_shape() input_dim = len(input_shape) if per_channel: # Only support quantizing 1-, 2- and 4-dimensional tensors. assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' - ' scope: %s' % (input_shape, scope)) + ' scope: %s' % (input_shape, name_prefix)) min_max_shape = [input_shape[-1]] else: min_max_shape = [] diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 50a2b4c91c9e7a2681f6041646a023a4225fb0c5..7a3f92f503a5d6f2b0fab2a499f8e8758809d0ed 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Logic to update a Tensorflow model graph with quantization operations.""" +"""Logic to update a TensorFlow model graph with quantization operations.""" from __future__ import absolute_import from __future__ import division @@ -21,37 +21,37 @@ from __future__ import print_function import re from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common +from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops from tensorflow.contrib.quantize.python import quant_ops from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.training import training_util -# Operation types used to select operations of interest. +# Quantizable operation types that are supported by the quantization rewrite. _QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'} -# Custom key for storing and retrieving update ops used by quantizing nodes. -_UPDATE_QUANT_OPS = 'update_quant_ops' +# Activations that are supported by the quantization rewrite. +_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'} + +# Weight types that are supported by the quantization rewrite. +# TODO(suharshs): Add support for ResourceVariable. +_WEIGHT_TYPES = {'Variable', 'VariableV2'} def Quantize(graph, + is_training, weight_bits=8, - weight_narrow_range=False, activation_bits=8, ema_decay=0.999, quant_delay=None, - vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - is_training=True, - quantize_folded_weights_use_ema=False): + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES): """Updates graph with quantization operations. Args: graph: Graph to modify. + is_training: Whether quantizing training graph or eval graph. weight_bits: Number of bits to use for quantizing weights. - weight_narrow_range: Whether to use a more efficient narrow range for - weights quantization. With weight_narrow_range true, the range is - [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1]. activation_bits: Number of bits to use for quantizing activations. ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update quantization intervals for quantizing activations (see here about EMA: @@ -61,346 +61,280 @@ def Quantize(graph, training. vars_collection: (Optional) Collection where to store the variables for quantization interval ends. - is_training: (Optional) Whether quantizing training graph or eval graph. - quantize_folded_weights_use_ema: (Optional, default False) Whether to - quantize weights after batchnorm-folding with exponential average - quantization. Raises: ValueError: When quantization fails. """ - context = _QuantizeContext(graph, weight_bits, weight_narrow_range, - activation_bits, ema_decay, quant_delay, - vars_collection, is_training, - quantize_folded_weights_use_ema) - - graph_ops = graph.get_operations() - - # Filter out backprop and summary related operations, leave only interesting - # op types. - def _IsInterestingOpWithWeights(op): - return (op.type in _QUANTIZABLE_TYPES and - not op.name.startswith(common.SKIPPED_PREFIXES)) - - for op in (op for op in graph_ops if _IsInterestingOpWithWeights(op)): - if op.name.endswith('/depthwise'): - # Separable convolution may consist of 2 convolution nodes. If so, skip - # .../depthwise and only quantize the top one. - separable_conv = context.GetOperationByNameDontThrow( - op.name[:-len('/depthwise')]) - if separable_conv and separable_conv.type == 'Conv2D': - continue - # Quantize add ops that come after Conv2D or DepthwiseConv2dNative. - if op.type in ['Conv2D', 'DepthwiseConv2dNative']: - add_context_re = re.search(r'^(.*)/[^/]+/', op.name) - if add_context_re is not None: - context.add_contexts.add(add_context_re.group(1)) - if not op.name.endswith('_Fold'): - folded_op = context.GetOperationByNameDontThrow(op.name + '_Fold') - # Do nothing if found, it will be quantized when it is iterated over. - if not folded_op: - context.QuantizeOpWithWeights(op, folded=False) - else: - context.QuantizeOpWithWeights(op, folded=True) - - context.QuantizeAddContexts() - - # Once all quantization ops have been inserted in the graph, collect update - # ops for their variables and modify the TF Slim update barrier (see - # https://www.tensorflow.org/code/tensorflow/contrib/slim/python/slim/learning.py) - # to depend on them. - try: - update_barrier = graph.get_operation_by_name('update_barrier') - except KeyError: - # In evaluation graph, this barrier may not exist. - return None - update_quant_ops = graph.get_collection_ref(_UPDATE_QUANT_OPS) - graph_editor.add_control_inputs(update_barrier, update_quant_ops) - - -class _QuantizeContext(object): - """Context holds references needed for quantization.""" - - def __init__(self, - graph, - weight_bits, - weight_narrow_range, - activation_bits, - ema_decay=0.999, - quant_delay=None, - vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - is_training=True, - quantize_folded_weights_use_ema=False): - """Initializes context to hold references needed for quantization. - - Args: - graph: Graph to modify. - weight_bits: Number of bits to use for quantizing weights. - weight_narrow_range: Whether to use a more efficient narrow range for - weights quantization. With weight_narrow_range true, the range is - [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1]. - activation_bits: Number of bits to use for quantizing activations. - ema_decay: (Optional) Float, EMA decay parameter. - quant_delay: (Optional, default None) Int, count of global steps for which - to delay quantization. This helps weights stabilize at the start of - training. - vars_collection: (Optional) Collection where to store the variables for - quantization interval ends. - is_training: (Optional) Whether quantizing training or eval graph. - quantize_folded_weights_use_ema: (Optional, default False) Whether to - quantize weights after batchnorm-folding with exponential average - quantization. - """ - self.graph = graph - self.weight_bits = weight_bits - self.weight_narrow_range = weight_narrow_range - self.activation_bits = activation_bits - self.ema_decay = ema_decay - self.quant_delay = quant_delay - self.vars_collection = vars_collection - self.is_training = is_training - self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema - self.input_to_ops_map = input_to_ops.InputToOps(graph) - self.add_contexts = set() - - def QuantizeAddContexts(self): - """Quantizes all add ops in self.add_contexts.""" - # Loop through sorted self.add_contexts so that op creation is - # deterministic. This is needed when using multiple worker replicas so that - # the ops can be initialized consistently. - for add_context in sorted(self.add_contexts): - add_op = self.GetOperationByNamesDontThrow([ - add_context + '/Add', add_context + '/add']) - if add_op is not None: - self._InsertQuantOp( - add_context, - add_op, - self.input_to_ops_map.ConsumerOperations(add_op), - name='add_quant', - moving_avg=True, - bits=self.activation_bits, - narrow_range=False) - - def QuantizeOpWithWeights(self, op, folded): - """Quantizes around the specific operation with or without batch norm. - - Args: - op: Operation to quantize. - folded: Operation has been folded and needs special handling if True. - Raises: - ValueError: When quantization fails. - """ - # Op name component before the last slash will be used as context. - context = re.search(r'^(.*)/([^/]+)', op.name).group(1) - - # Quantize weights. - if folded: - producer_op = self.graph.get_operation_by_name(context + '/mul_fold') - else: - try: - input_idx = next(i for i, v in enumerate(op.inputs) - if '/weights/' in v.name or - '/depthwise_weights' in v.name) - except StopIteration: - raise ValueError('No inputs to quantize for op: %s' % op) - producer_op = op.inputs[input_idx].op - - # If batch norm is used, the folded weights depend on the batch std, hence - # it is sensible to use EMA during training to smooth out the noise. This is - # controlled by the flag quantize_folded_weights_use_ema. Its default is - # False for backward compatibility. - # If there is no batch norm, weights do not depend on the batch and using - # the latest value of min and max is more efficient. - weight_use_ema = folded and self.quantize_folded_weights_use_ema - self._InsertQuantOp( - context, - producer_op, [op], - name='weights_quant', - moving_avg=weight_use_ema, - delay_requested=weight_use_ema, - bits=self.weight_bits, - narrow_range=self.weight_narrow_range) - - # Important: do not quantize biases here. During inference they are - # quantized to 32 bits, which is much finer than 8 bit quantization and - # depends on weight and input activation ranges. - - # Find activation and (optionally) Add operations to quantize. - activation_op, add_op, add_context = self._GetReluAndAddOperations(context, - op) - if add_op: - original_context = context - context = add_context - - # Quantize activation outputs. - consumer_ops = self.input_to_ops_map.ConsumerOperations(activation_op) - self._InsertQuantOp( + input_to_ops_map = input_to_ops.InputToOps(graph) + for layer_match in _FindLayersToQuantize(graph): + # Quantize the weights. + context = _GetContextFromOp(layer_match.layer_op) + _InsertQuantOp( context, - activation_op, + 'weights_quant', + layer_match.weight_tensor.op, [layer_match.layer_op], + is_training, + moving_avg=False, + ema_decay=ema_decay, + quant_delay=quant_delay, + narrow_range=True, + vars_collection=vars_collection, + bits=weight_bits) + + # Quantize the activations. + consumer_ops = input_to_ops_map.ConsumerOperations( + layer_match.activation_op) + add_context = context + if layer_match.bypass_op: + add_context = re.search(r'^(.*)/([^/]+)', context).group(1) + _InsertQuantOp( + add_context, + 'act_quant', + layer_match.activation_op, consumer_ops, - name='act_quant', + is_training, moving_avg=True, - init_min=0.0, - bits=self.activation_bits, - narrow_range=False) - - # When a bypass connection was found, also quantize Add op input. - if add_op: - def _QuantizeAddInput(add_input): - if folded: - return add_input.op.name.endswith('/add_fold') - else: - return add_input.op.name.startswith(original_context + '/') - - for add_input in add_op.inputs: - if _QuantizeAddInput(add_input): - self._InsertQuantOp( - original_context, - add_input.op, [add_op], - name='conv_quant', - moving_avg=True, - bits=self.activation_bits, - narrow_range=False) - - def _GetReluAndAddOperations(self, context, op): - """Looks up a Relu* and Add operations in given context. - - Args: - context: Context where to look for operations. - op: Operation to quantize. - - Returns: - A triplet (Operation, Operation, string), the first element is an end - point operation, the second is Add operation (optional), the third element - is string context where the Add operation was found (optional). - - Raises: - ValueError: When operations cannot be found. - """ - activation_op = common.GetEndpointActivationOp(self.graph, context) - if activation_op: - return activation_op, None, None - - if '/' in context: - # If no activation op is there, look for them one level up. - add_context = re.search(r'^(.*)/([^/]+)', context).group(1) - activation_op = common.GetEndpointActivationOp(self.graph, add_context) - if not activation_op: - # Still no Relu, can happen on the top layer, just find the next node up, - # make sure it is BiasAdd. - consumers = [c for outp in op.outputs for c in outp.consumers()] - if len(consumers) != 1 or consumers[0].type != 'BiasAdd': - raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type)) - return consumers[0], None, None - if add_context: - add_op = self.GetOperationByNamesDontThrow([ - add_context + '/Add', add_context + '/add']) - return activation_op, add_op, add_context - else: - raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type)) - - def GetOperationByNameDontThrow(self, name): - """Returns an Operation with the given name. - - Args: - name: Name of Operation to return. - - Returns: - The Operation with the given name. None if the name does not correspond to - any operation in the graph. - """ - try: - return self.graph.get_operation_by_name(name) - except KeyError: - return None - - def GetOperationByNamesDontThrow(self, names): - """Returns an Operation with one of the given names. - - Args: - names: Names of Operation to return. - - Returns: - The Operation with one of the given names. None if none of the names - corresponds to any operation in the graph. - """ - for name in names: - op = self.GetOperationByNameDontThrow(name) - if op is not None: - return op - return None - - def _InsertQuantOp( - self, - context, - producer, - consumers, - name, - moving_avg=True, - init_min=-6.0, - init_max=6.0, - delay_requested=True, - bits=8, - narrow_range=False,): - """Inserts a quant op between a producer op and (multiple) consumer ops. - - Args: - context: Context where producer and consumer operations are nested. - producer: Producer operation of the pairs where quantization will be - inserted. - consumers: Consumer operations of the pairs. - name: Name for the new quantization op within the context. - moving_avg: Specifies whether to use exponential moving average or just - the last value seen. - init_min: Starting minimum value for the new quantization op. - init_max: Starting maximum value for the new quantization op. - delay_requested: If true, implement quantization delay where needed. - False value explicitly disables delay quantization everywhere. - bits: Number of bits to use for quantization, must be between 2 and 8. - narrow_range: Whether to use the narrow quantization range - [1; 2^bits - 1] or wide range [0; 2^bits - 1]. - Raises: - ValueError: When producer operation is not directly connected to the - consumer operation. - """ - scope = context + '/' + name - inputs = producer.outputs[0] - if moving_avg: - quant = (quant_ops.MovingAvgQuantize( - inputs, - init_min=init_min, - init_max=init_max, - ema_decay=self.ema_decay, - is_training=self.is_training, - num_bits=bits, - narrow_range=narrow_range, - updates_collection=_UPDATE_QUANT_OPS, - vars_collection=self.vars_collection, - scope=scope)) - else: - quant = (quant_ops.LastValueQuantize( - inputs, - init_min=init_min, - init_max=init_max, - is_training=self.is_training, - num_bits=bits, - narrow_range=narrow_range, - updates_collection=_UPDATE_QUANT_OPS, - vars_collection=self.vars_collection, - scope=scope)) - - if delay_requested and self.quant_delay and self.quant_delay > 0: - activate_quant = math_ops.greater_equal( - training_util.get_or_create_global_step(), - self.quant_delay, - name=scope + '/activate_quant') - quant = control_flow_ops.cond( - activate_quant, - lambda: quant, - lambda: inputs, - name=scope + '/delayed_quant') - - nodes_modified_count = graph_editor.reroute_ts( - [quant], [inputs], can_modify=consumers) - if nodes_modified_count != len(consumers): - raise ValueError('Some inputs not quantized for ops: [%s]' % - ', '.join([consumer.name for consumer in consumers])) + ema_decay=ema_decay, + quant_delay=quant_delay, + vars_collection=vars_collection, + bits=activation_bits, + init_min=0.0) + + # Quantize the inputs and output to the bypass (if it exists). The input to + # the bypass is the bias add, and the output is the activation. + if layer_match.bypass_op is not None: + _InsertQuantOp( + context, + 'conv_quant', + layer_match.bias_add_op, [layer_match.bypass_op], + is_training, + moving_avg=True, + ema_decay=ema_decay, + quant_delay=quant_delay, + vars_collection=vars_collection, + bits=activation_bits) + _InsertQuantOp( + add_context, + 'add_quant', + layer_match.bypass_op, + input_to_ops_map.ConsumerOperations(layer_match.bypass_op), + is_training, + moving_avg=True, + ema_decay=ema_decay, + quant_delay=quant_delay, + vars_collection=vars_collection, + bits=activation_bits) + + +def _FindLayersToQuantize(graph): + """Matches layers in graph to quantize. + + Args: + graph: Graph to perform match on. + + Yields: + _LayerMatches. + """ + input_pattern = graph_matcher.OpTypePattern('*') + weight_var_pattern = graph_matcher.OpTypePattern('|'.join(_WEIGHT_TYPES)) + weight_pattern = graph_matcher.OpTypePattern( + 'Identity', inputs=[weight_var_pattern]) + + folded_weight_pattern = graph_matcher.OpTypePattern('Mul') + + # The weights inputs to the layer operation can either be from the Variable or + # the folded weight (Mul). + layer_pattern = graph_matcher.OpTypePattern( + '|'.join(_QUANTIZABLE_TYPES), + inputs=[ + input_pattern, + graph_matcher.OneofPattern([weight_pattern, folded_weight_pattern]) + ]) + + folded_bias_mul_pattern = graph_matcher.OpTypePattern( + 'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern]) + post_layer_op_correction_pattern = graph_matcher.OpTypePattern( + 'Add', inputs=[folded_bias_mul_pattern, + graph_matcher.OpTypePattern('*')]) + folded_bias_add_pattern = graph_matcher.OpTypePattern( + 'Add', + inputs=[ + post_layer_op_correction_pattern, + graph_matcher.OpTypePattern('*') + ]) + + bias_add_pattern = graph_matcher.OpTypePattern( + 'Add|BiasAdd', inputs=[layer_pattern, '*']) + + # The bias can come from the bias add or the folded bias add. + bypass_pattern_a = graph_matcher.OpTypePattern( + 'Add', + inputs=[ + graph_matcher.OneofPattern( + [bias_add_pattern, folded_bias_add_pattern]), '*' + ]) + bypass_pattern_b = graph_matcher.OpTypePattern( + 'Add', + inputs=[ + '*', + graph_matcher.OneofPattern( + [bias_add_pattern, folded_bias_add_pattern]) + ]) + + # The input to the activation can come from bias add, fold bias add or the + # bypasses. + activation_pattern = graph_matcher.OpTypePattern( + '|'.join(_ACTIVATION_TYPES), + inputs=[ + graph_matcher.OneofPattern([ + bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a, + bypass_pattern_b + ]) + ]) + + layer_matcher = graph_matcher.GraphMatcher(activation_pattern) + for match_result in layer_matcher.match_graph(graph): + layer_op = match_result.get_op(layer_pattern) + weight_tensor = match_result.get_tensor(weight_pattern) + if weight_tensor is None: + weight_tensor = match_result.get_tensor(folded_weight_pattern) + activation_op = match_result.get_op(activation_pattern) + bias_add_op = match_result.get_op(bias_add_pattern) + if bias_add_op is None: + bias_add_op = match_result.get_op(folded_bias_add_pattern) + bypass_op = match_result.get_op(bypass_pattern_a) + if bypass_op is None: + bypass_op = match_result.get_op(bypass_pattern_b) + yield _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, + bias_add_op) + + +class _LayerMatch(object): + """Contains all information related to a matched Layer.""" + + def __init__(self, layer_op, weight_tensor, activation_op, bypass_op, + bias_add_op): + self._layer_op = layer_op + self._weight_tensor = weight_tensor + self._activation_op = activation_op + self._bypass_op = bypass_op + self._bias_add_op = bias_add_op + + @property + def layer_op(self): + return self._layer_op + + @property + def weight_tensor(self): + return self._weight_tensor + + @property + def activation_op(self): + return self._activation_op + + @property + def bypass_op(self): + return self._bypass_op + + @property + def bias_add_op(self): + return self._bias_add_op + + +def _InsertQuantOp(context, + name, + producer, + consumers, + is_training, + moving_avg=True, + init_min=-6.0, + init_max=6.0, + bits=8, + ema_decay=0.999, + quant_delay=None, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + narrow_range=False): + """Inserts a quant op between a producer op and (multiple) consumer ops. + + Args: + context: Context w,here producer and consumer operations are nested. + name: Name for the new quantization op within the context. + producer: Producer operation of the pairs where quantization will be + inserted. + consumers: Consumer operations of the pairs. + is_training: Whether quantizing training graph or eval graph. + moving_avg: Specifies whether to use exponential moving average or just + the last value seen. + init_min: Starting minimum value for the new quantization op. + init_max: Starting maximum value for the new quantization op. + bits: Number of bits to use for quantization, must be between 2 and 8. + ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update + quantization intervals for quantizing activations (see here about EMA: + https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average). + quant_delay: (Optional, default None) Int, count of global steps for which + to delay quantization. This helps weights stabilize at the start of + training. + vars_collection: (Optional) Collection where to store the variables for + quantization interval ends. + narrow_range: Whether to use the narrow quantization range + [1; 2^bits - 1] or wide range [0; 2^bits - 1]. + Raises: + ValueError: When producer operation is not directly connected to the + consumer operation. + """ + name_prefix = _AddContextToName(context, name) + inputs = producer.outputs[0] + if moving_avg: + quant = ( + quant_ops.MovingAvgQuantize( + inputs, + init_min=init_min, + init_max=init_max, + ema_decay=ema_decay, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) + else: + quant = ( + quant_ops.LastValueQuantize( + inputs, + init_min=init_min, + init_max=init_max, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) + + if quant_delay and quant_delay > 0: + activate_quant = math_ops.greater_equal( + common.CreateOrGetQuantizationStep(), + quant_delay, + name=name_prefix + '/activate_quant') + quant = control_flow_ops.cond( + activate_quant, + lambda: quant, + lambda: inputs, + name=name_prefix + '/delayed_quant') + + nodes_modified_count = graph_editor.reroute_ts( + [quant], [inputs], can_modify=consumers) + if nodes_modified_count != len(consumers): + raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join( + [consumer.name for consumer in consumers])) + + +def _GetContextFromOp(op): + """Gets the root context name from the op name.""" + context_re = re.search(r'^(.*)/([^/]+)', op.name) + if context_re: + return context_re.group(1) + return '' + + +def _AddContextToName(context, name): + """Adds the context to the name if it exists.""" + if not context: + return name + return context + '/' + name diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py index bbd9743d8014ce495a4967e7484981f7e60ae4a3..5a3a74cec4864ad3808d485849334c81f569d300 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph.py +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -18,177 +18,198 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.quantize.python import copy_graph from tensorflow.contrib.quantize.python import fold_batch_norms from tensorflow.contrib.quantize.python import quantize from tensorflow.python.framework import ops -from tensorflow.python.ops import variables -def _create_graph(input_graph, - is_training, - elements=None, - device_name_or_function=None): - """Returns a transformed training input_graph for simulated quantization. +def _create_graph(input_graph=None, + is_training=True, + weight_bits=8, + activation_bits=8, + quant_delay=None, + freeze_bn_delay=None): + """Rewrites an input_graph in place for simulated quantization. - The forward pass has fake quantization ops inserted to simulate the error - introduced by quantization. + The graph has fake quantization ops inserted to simulate the error + introduced by quantization. Since the graph is transformed in place, + the expected behavior of previously held references to nodes and tensors may + change. Args: - input_graph: The tf.Graph to be transformed. + input_graph: The tf.Graph to be transformed, if None then defaults to the + default graph. is_training: Whether quantizing training or eval graph. - elements: (Optional) List of Tensors and Operations in input_graph whose - corresponding elements in the new graph will be returned. - device_name_or_function: (Optional) The device name or function to use. - - Returns: - g is new tf.Graph that is rewritten for simulated quantization. - l is a list of Tensors/Operations in g corresponding to the provided input - elements, if elements is not None. + weight_bits: Number of bits to use for quantizing weights. + activation_bits: Number of bits to use for quantizing activations. + quant_delay: Number of steps after which weights and activations are + quantized during training. + freeze_bn_delay: Number of steps after which moving mean and variance are + frozen and used instead of batch statistics during training. + freeze_bn_delay should be greater than quant_delay and should correspond + to the number of steps when training has almost converged Raises: ValueError: If elements contains an element that isn't a tf.Tensor or - tf.Operation. + tf.Operation. """ - # TODO(suharshs): Describe the process in more detail in the doc string. - g = copy_graph.CopyGraph(input_graph) - with g.as_default(): - with ops.device(device_name_or_function): - fold_batch_norms.FoldBatchNorms(g) - quantize.Quantize(g, is_training=is_training) - if elements is None: - return g - - return_elements = [] - for element in elements: - if isinstance(element, (ops.Tensor, variables.Variable)): - return_elements.append(g.get_tensor_by_name(element.name)) - elif isinstance(element, ops.Operation): - return_elements.append(g.get_operation_by_name(element.name)) - else: - raise ValueError( - 'elements must consist of Tensor or Operation objects, got: ', - str(element)) - return g, return_elements - - -def create_training_graph(input_graph, - elements=None, - device_name_or_function=None): - """Returns a transformed training input_graph for simulated quantization. - - The forward pass has fake quantization ops inserted to simulate the error - introduced by quantization. + + if input_graph is None: + input_graph = ops.get_default_graph() + with input_graph.as_default(): + fold_batch_norms.FoldBatchNorms( + input_graph, + freeze_batch_norm_delay=freeze_bn_delay, + is_training=is_training) + quantize.Quantize( + input_graph, + is_training, + quant_delay=quant_delay, + weight_bits=weight_bits, + activation_bits=activation_bits) + + +def create_training_graph(input_graph=None, quant_delay=0): + """Rewrites a training input_graph in place for simulated quantization. + + The graph has fake quantization ops inserted to simulate the error + introduced by quantization. Since the graph is transformed in place, + the expected behavior of previously held references to nodes and tensors may + change. + + The default value of quant_delay is suitable for finetuning an already trained + floating point model (recommended). + If one wants to train a quantized model from scratch, quant_delay should be + set to the number of steps it take the floating point model to converge. + Quantization will be activated at this point and effectively finetune the + model. If quant_delay is not provided when training from scratch, training can + often fail. Args: input_graph: The tf.Graph to be transformed. - elements: (Optional) List of Tensors and Operations in input_graph whose - corresponding elements in the new graph will be returned. - device_name_or_function: (Optional) The device name or function to use. - - Returns: - g is new tf.Graph that is rewritten for simulated quantization. - l is a list of Tensors/Operations in g corresponding to the provided input - elements, if elements is not None. + quant_delay: Number of steps after which weights and activations are + quantized during training. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or - tf.Operation. + tf.Operation. """ - return _create_graph( + # TODO(raghuramank) Need to have freeze_bn_delay be a function of batch size + # Currently the values below are hardcoded for mobilenetV1 on imagenet + # Please use the experimental API if you need to tune these values. + if quant_delay == 0: + # Corresponds to case of restoring from a floating point checkpoint + # In this case, we can freeze the moving mean and variance early on and + # switch to using them during training. Therefore, freeze_bn_delay is set to + # 2e5. + freeze_bn_delay = int(2e5) + else: + # If training from scratch, set freeze_bn_delay to 100 epochs after quant + # delay. With a batch size of 64, this corresponds to 20000*100=2M steps. + freeze_bn_delay = quant_delay + int(2e6) + + _create_graph( input_graph=input_graph, is_training=True, - elements=elements, - device_name_or_function=device_name_or_function) + quant_delay=quant_delay, + freeze_bn_delay=freeze_bn_delay) -def create_eval_graph(input_graph, elements=None, device_name_or_function=None): - """Returns a transformed eval input_graph for simulated quantization. +def create_eval_graph(input_graph=None): + """Rewrites an eval input_graph in place for simulated quantization. - The forward pass has fake quantization ops inserted to simulate the error - introduced by quantization. + The graph has fake quantization ops inserted to simulate the error + introduced by quantization. Since the graph is transformed in place, + the expected behavior of previously held references to nodes and tensors may + change. Args: - input_graph: The tf.Graph to be transformed. - elements: (Optional) List of Tensors and Operations in input_graph whose - corresponding elements in the new graph will be returned. - device_name_or_function: (Optional) The device name or function to use. - - Returns: - g is new tf.Graph that is rewritten for simulated quantization. - l is a list of Tensors/Operations in g corresponding to the provided input - elements, if elements is not None. + input_graph: The tf.Graph to be transformed, if None then defaults to the + default graph. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or - tf.Operation. + tf.Operation. """ - return _create_graph( - input_graph=input_graph, - is_training=False, - elements=elements, - device_name_or_function=device_name_or_function) + _create_graph(input_graph=input_graph, is_training=False) -def experimental_create_training_graph(input_graph, - elements=None, - device_name_or_function=None): - """Returns a transformed training input_graph for simulated quantization. +def experimental_create_training_graph(input_graph=None, + weight_bits=8, + activation_bits=8, + quant_delay=0, + freeze_bn_delay=int(2e5)): + """Rewrites a training input_graph in place for simulated quantization. This function has additional experimental options not (yet) available to create_training_graph. The resulting behavior may be undefined. - The forward pass has fake quantization ops inserted to simulate the error - introduced by quantization. - Args: - input_graph: The tf.Graph to be transformed. - elements: (Optional) List of Tensors and Operations in input_graph whose - corresponding elements in the new graph will be returned. - device_name_or_function: (Optional) The device name or function to use. + The graph has fake quantization ops inserted to simulate the error + introduced by quantization. Since the graph is transformed in place, + the expected behavior of previously held references to nodes and tensors may + change. - Returns: - g is new tf.Graph that is rewritten for simulated quantization. - l is a list of Tensors/Operations in g corresponding to the provided input - elements, if elements is not None. + The default value of quant_delay is suitable for finetuning an already trained + floating point model (recommended). + If one wants to train a quantized model from scratch, quant_delay should be + set to the number of steps it take the floating point model to converge. + Quantization will be activated at this point and effectively finetune the + model. If quant_delay is not provided when training from scratch, training can + often fail. + + Args: + input_graph: The tf.Graph to be transformed,if None then defaults to the + default graph. + weight_bits: Number of bits to use for quantizing weights. + activation_bits: Number of bits to use for quantizing activations. + quant_delay: Number of steps after which weights and activations are + quantized during training. + freeze_bn_delay: Number of steps after which moving mean and variance are + frozen and used instead of batch statistics during training. + freeze_bn_delay should be greater than quant_delay and should correspond + to when training has almost converged Raises: ValueError: If elements contains an element that isn't a tf.Tensor or tf.Operation. """ - return _create_graph( + + _create_graph( input_graph=input_graph, is_training=True, - elements=elements, - device_name_or_function=device_name_or_function) + weight_bits=weight_bits, + activation_bits=activation_bits, + quant_delay=quant_delay, + freeze_bn_delay=freeze_bn_delay) -def experimental_create_eval_graph(input_graph, - elements=None, - device_name_or_function=None): - """Returns a transformed eval input_graph for simulated quantization. +def experimental_create_eval_graph(input_graph=None, + weight_bits=8, + activation_bits=8): + """Rewrites an eval input_graph in place for simulated quantization. This function has additional experimental options not (yet) available to create_eval_graph. The resulting behavior may be undefined. - The forward pass has fake quantization ops inserted to simulate the error - introduced by quantization. + + The graph has fake quantization ops inserted to simulate the error + introduced by quantization. Since the graph is transformed in place, + the expected behavior of previously held references to nodes and tensors may + change. Args: - input_graph: The tf.Graph to be transformed. - elements: (Optional) List of Tensors and Operations in input_graph whose - corresponding elements in the new graph will be returned. - device_name_or_function: (Optional) The device name or function to use. + input_graph: The tf.Graph to be transformed, if None then defaults to the + default graph. + weight_bits: Number of bits to use for quantizing weights. + activation_bits: Number of bits to use for quantizing activations. + - Returns: - g is new tf.Graph that is rewritten for simulated quantization. - l is a list of Tensors/Operations in g corresponding to the provided input - elements, if elements is not None. Raises: ValueError: If elements contains an element that isn't a tf.Tensor or - tf.Operation. + tf.Operation. """ - return _create_graph( + _create_graph( input_graph=input_graph, is_training=False, - elements=elements, - device_name_or_function=device_name_or_function) + weight_bits=weight_bits, + activation_bits=activation_bits) diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py index 514862a0ab5b796718a04aa65a46e7a7e3b86330..6b9289ef5f4b847172e1f093a1e4b5b2d3bdab57 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -20,13 +20,11 @@ from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import quantize_graph -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -34,7 +32,7 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): # We have a lot of other tests that test the details of the rewrite, here we # just the specific features of the quantize_graph API. - def _RunTestOverParameters(self, test_fn): + def _RunTestOverAllRewrites(self, test_fn): rewrite_fns = [ quantize_graph.create_training_graph, quantize_graph.create_eval_graph, @@ -44,85 +42,189 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): for fn in rewrite_fns: test_fn(fn) - def testReturnedElements(self): - self._RunTestOverParameters(self._TestReturnElements) + def _RunTestOverTrainingRewrites(self, test_fn): + rewrite_fns = [ + quantize_graph.create_training_graph, + quantize_graph.experimental_create_training_graph, + ] + for fn in rewrite_fns: + test_fn(fn) - def _TestReturnElements(self, fn): - graph = ops.Graph() - with graph.as_default(): - a = constant_op.constant(1.0) - b = variables.Variable(2.0) - c = a + b - elements = [a, b, c.op] - q_graph, returned_elements = fn(graph, elements=elements) - # Make sure q_graph is different from graph. - self.assertTrue(graph != q_graph) - # Check that the returned elements are part of the new graph. - for returned_element in returned_elements: - self.assertEqual(q_graph, returned_element.graph) - # Check that the elements match with the one from the input graph. - for element, returned_element in zip(elements, returned_elements): - self.assertEqual(element.name, returned_element.name) - - def testNoReturnElements(self): - self._RunTestOverParameters(self._TestNoReturnElements) - - def _TestNoReturnElements(self, fn): - graph = ops.Graph() - with graph.as_default(): - a = constant_op.constant(1.0) - b = variables.Variable(2.0) - _ = a + b - q_graph = fn(graph) - # Check that quantize_graph didn't return a tuple when elements isn't - # provided. - self.assertTrue(isinstance(q_graph, ops.Graph)) - # Make sure q_graph is different from graph. - self.assertTrue(graph != q_graph) - - def testDeviceName(self): - self._RunTestOverParameters(self._TestDeviceName) - - def _TestDeviceName(self, fn): + def _RunTestOverEvalRewrites(self, test_fn): + rewrite_fns = [ + quantize_graph.create_eval_graph, + quantize_graph.experimental_create_eval_graph, + ] + for fn in rewrite_fns: + test_fn(fn) + + def _RunTestOverExperimentalRewrites(self, test_fn): + rewrite_fns = [ + quantize_graph.experimental_create_training_graph, + quantize_graph.experimental_create_eval_graph, + ] + for fn in rewrite_fns: + test_fn(fn) + + def testRewrite(self): + self._RunTestOverAllRewrites(self._TestRewrite) + + def _TestRewrite(self, rewrite_fn): graph = ops.Graph() with graph.as_default(): - batch_size, height, width, depth = 5, 128, 128, 3 - inputs = array_ops.zeros((batch_size, height, width, depth)) - conv = layers.conv2d( - inputs, - 32, [5, 5], - stride=2, - padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=None, - scope='test') - _ = nn_ops.relu6(conv) - - device_name = '/job:oink/task:0/device:CPU:0' - q_graph = fn(graph, device_name_or_function=device_name) + self._ConvLayer() orig_variable_names = set( [v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) - q_variables = q_graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - # Ensure that variables were added. - self.assertTrue(len(orig_variable_names) < len(q_variables)) - # All added variables should have the specified device name. - for var in q_variables: - if var.name not in orig_variable_names: - self.assertEqual(var.device, device_name) - def _WeightInit(self, stddev): - """Returns truncated normal variable initializer. + rewrite_fn(graph) - Function is defined purely to shorten the name so that it stops wrapping. - - Args: - stddev: Standard deviation of normal variable. + q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # Ensure that variables were added. + self.assertTrue(len(orig_variable_names) < len(q_variables)) - Returns: - An initialized that initialzes with a truncated normal variable. - """ - return init_ops.truncated_normal_initializer(stddev=stddev) + def testDefaultGraph(self): + self._RunTestOverAllRewrites(self._TestRewrite) + + def _TestDefaultGraph(self, rewrite_fn): + # Tests that the default graph is correctly used when no args are provided + # to rewrite_fn. + with ops.Graph().as_default() as g: + self._ConvLayer() + orig_variable_names = set( + [v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + rewrite_fn() + + q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # Ensure that variables were added. + self.assertTrue(len(orig_variable_names) < len(q_variables)) + + def testQuantDelay(self): + self._RunTestOverTrainingRewrites(self._TestQuantDelay) + + def _TestQuantDelay(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + quant_delay = 100 + rewrite_fn(quant_delay=quant_delay) + + quant_delay_found = False + for op in g.get_operations(): + # Check to see if the quant_delay is correctly set. + if 'activate_quant' in op.name and op.type == 'Const': + quant_delay_found = True + const_value = str(op.get_attr('value')) + self.assertTrue(('int64_val: %i' % quant_delay) in const_value) + self.assertTrue(quant_delay_found) + + def testWeightBits(self): + self._RunTestOverExperimentalRewrites(self._TestWeightBits) + + def _TestWeightBits(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + weight_bits = 4 + rewrite_fn(weight_bits=weight_bits) + + weights_quant_found = False + for op in g.get_operations(): + # Check to see if FakeQuant operations for weights have the right bits + # set. + if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars': + weights_quant_found = True + self.assertEqual(op.get_attr('num_bits'), weight_bits) + self.assertTrue(weights_quant_found) + + def testActivationBits(self): + self._RunTestOverExperimentalRewrites(self._TestActivationBits) + + def _TestActivationBits(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + activation_bits = 4 + rewrite_fn(activation_bits=activation_bits) + + act_quant_found = False + for op in g.get_operations(): + # Check to see if FakeQuant operations for activations have the right bits + # set. + act_quant_names = ['act_quant', 'conv_quant', 'add_quant'] + if any(s in op.name + for s in act_quant_names) and op.type == 'FakeQuantWithMinMaxVars': + act_quant_found = True + self.assertEqual(op.get_attr('num_bits'), activation_bits) + self.assertTrue(act_quant_found) + + def testTrainingQuantization(self): + self._RunTestOverTrainingRewrites(self._TestTrainingQuantization) + + def _TestTrainingQuantization(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + rewrite_fn() + + # Ensure that FakeQuant and variable update nodes were found. + quant_found = False + assign_min_last_found = False + assign_min_ema_found = False + assign_max_last_found = False + assign_max_ema_found = False + for op in g.get_operations(): + # Check that FakeQuant operations were added. + if op.type == 'FakeQuantWithMinMaxVars': + quant_found = True + # Check that update operations for the added min max variables exist in + # the graph. + if 'AssignMinLast' in op.name: + assign_min_last_found = True + elif 'AssignMinEma' in op.name: + assign_min_ema_found = True + elif 'AssignMaxLast' in op.name: + assign_max_last_found = True + elif 'AssignMaxEma' in op.name: + assign_max_ema_found = True + self.assertTrue(assign_min_last_found) + self.assertTrue(assign_min_ema_found) + self.assertTrue(assign_max_last_found) + self.assertTrue(assign_max_ema_found) + self.assertTrue(quant_found) + + def testEvalQuantization(self): + self._RunTestOverEvalRewrites(self._TestEvalQuantization) + + def _TestEvalQuantization(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + rewrite_fn() + + # Ensure that FakeQuant and variable update nodes were found. + quant_found = False + for op in g.get_operations(): + # Check that FakeQuant operations were added. + if op.type == 'FakeQuantWithMinMaxVars': + quant_found = True + # Check that update operations for the added min max variables don't + # exist in the graph. + update_names = [ + 'AssignMinLast', 'AssignMinEma', 'AssignMaxLast', 'AssignMaxEma' + ] + self.assertFalse(any(s in op.name for s in update_names)) + self.assertTrue(quant_found) + + def _ConvLayer(self): + """Add a basic convolution layer to the default graph.""" + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + weight_init = init_ops.truncated_normal_initializer + conv = layers.conv2d( + inputs, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=weight_init(0.09), + activation_fn=None, + scope='test') + _ = nn_ops.relu6(conv) if __name__ == '__main__': diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index 57dab03f162629f84adf1d15521b05f4014c4a80..639a7454a92aebd7289c59498cebff82cc003f75 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -29,7 +29,6 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest -from tensorflow.python.training import training batch_norm = layers.batch_norm conv2d = layers.conv2d @@ -73,8 +72,6 @@ class QuantizeTest(test_util.TensorFlowTestCase): """ graph = ops.Graph() with graph.as_default(): - training.create_global_step(graph) - batch_size, height, width, depth = 5, 128, 128, 3 inputs = array_ops.zeros((batch_size, height, width, depth)) stride = 1 if with_bypass else 2 @@ -91,7 +88,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph, quant_delay=delay) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) @@ -101,7 +98,11 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + '/Conv2D' + if delay and delay > 0: + output_op_name = scope + '/weights_quant/delayed_quant/Switch_1' + else: + output_op_name = scope + '/Conv2D' + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -148,8 +149,6 @@ class QuantizeTest(test_util.TensorFlowTestCase): """ graph = ops.Graph() with graph.as_default(): - training.create_global_step(graph) - batch_size, depth = 5, 256 inputs = array_ops.zeros((batch_size, depth)) out_depth = 256 if with_bypass else 128 @@ -165,7 +164,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph, quant_delay=delay) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + @@ -176,7 +175,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + '/MatMul' + if delay and delay > 0: + output_op_name = scope + '/weights_quant/delayed_quant/Switch_1' + else: + output_op_name = scope + '/MatMul' self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -222,8 +224,6 @@ class QuantizeTest(test_util.TensorFlowTestCase): """ graph = ops.Graph() with graph.as_default(): - training.create_global_step(graph) - batch_size, height, width, depth = 5, 128, 128, 3 inputs = array_ops.zeros((batch_size, height, width, depth)) stride = 1 if with_bypass else 2 @@ -240,7 +240,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph, quant_delay=delay) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + @@ -252,7 +252,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope + '/depthwise_weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + '/depthwise' + if delay and delay > 0: + output_op_name = scope + '/weights_quant/delayed_quant/Switch_1' + else: + output_op_name = scope + '/depthwise' self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -316,40 +319,11 @@ class QuantizeTest(test_util.TensorFlowTestCase): for params in parameters_list: test_fn(params[0], params[1], params[2], params[3], params[4]) - def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm): - """Tests quantization: inputs -> Conv2d with batch norm -> Activation. - - Args: - activation: Callable that returns an Operation, a factory method for the - Activation. - activation_op_name: String, name of the Activation operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Activation. - delay: Int (optional), delay in number of steps until quantization starts. - fused_batch_norm: Bool, when true use FusedBatchNorm. - """ - self._testQuantize_Conv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=True) - self._testQuantize_Conv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=False) - def testQuantize_Conv2dWithBatchNorm(self): self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) - def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm, - use_ema): + def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> Conv2d with batch norm -> Activation. Args: @@ -360,12 +334,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. - use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() with graph.as_default(): - training.create_global_step(graph) - batch_size, height, width, depth = 5, 128, 128, 3 inputs = array_ops.zeros((batch_size, height, width, depth)) stride = 1 if with_bypass else 2 @@ -392,25 +363,21 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - fold_batch_norms.FoldBatchNorms(graph) + fold_batch_norms.FoldBatchNorms(graph, is_training=True) - quantize.Quantize( - graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('AssignMinEma' - if use_ema else 'AssignMinLast'), - scope + '/weights_quant/' + ('AssignMaxEma' - if use_ema else 'AssignMaxLast'), - scope + '/mul_fold' + scope + '/weights_quant/' + 'AssignMinLast', + scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if (delay and use_ema) else '/Conv2D_Fold') + if delay else '/Conv2D_Fold') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -438,40 +405,11 @@ class QuantizeTest(test_util.TensorFlowTestCase): if delay else 'control_dependency') self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) - def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm): - """Tests quantization: inputs -> FC with batch norm -> Activation. - - Args: - activation: Callable that returns an Operation, a factory method for the - Activation. - activation_op_name: String, name of the Activation operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Activation. - delay: Int (optional), delay in number of steps until quantization starts. - fused_batch_norm: Bool, when true use FusedBatchNorm. - """ - self._testQuantize_FCWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=True) - self._testQuantize_FCWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=False) - def testQuantize_FCWithBatchNorm(self): self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) - def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm, - use_ema): + def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> FC with batch norm -> Activation. Args: @@ -482,12 +420,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. - use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() with graph.as_default(): - training.create_global_step(graph) - batch_size, depth = 5, 256 inputs = array_ops.zeros((batch_size, depth)) out_depth = 256 if with_bypass else 128 @@ -511,25 +446,21 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - fold_batch_norms.FoldBatchNorms(graph) + fold_batch_norms.FoldBatchNorms(graph, is_training=True) - quantize.Quantize( - graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('AssignMinEma' - if use_ema else 'AssignMinLast'), - scope + '/weights_quant/' + ('AssignMaxEma' - if use_ema else 'AssignMaxLast'), - scope + '/mul_fold' + scope + '/weights_quant/' + 'AssignMinLast', + scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if delay and use_ema else '/MatMul_Fold') + if delay else '/MatMul_Fold') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -557,42 +488,13 @@ class QuantizeTest(test_util.TensorFlowTestCase): if delay else 'control_dependency') self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) - def _TestQuantize_DepthwiseConv2dWithBatchNorm( - self, activation, activation_op_name, with_bypass, delay, - fused_batch_norm): - """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. - - Args: - activation: Callable that returns an Operation, a factory method for the - Activation. - activation_op_name: String, name of the Activation operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Activation. - delay: Int (optional), delay in number of steps until quantization starts. - fused_batch_norm: Bool, when true use FusedBatchNorm. - """ - self._testQuantize_DepthwiseConv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=True) - self._testQuantize_DepthwiseConv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=False) - def testQuantize_DepthwiseConv2dWithBatchNorm(self): self._RunBatchNormTestOverParameters( self._TestQuantize_DepthwiseConv2dWithBatchNorm) - def _testQuantize_DepthwiseConv2dWithBatchNorm( + def _TestQuantize_DepthwiseConv2dWithBatchNorm( self, activation, activation_op_name, with_bypass, delay, - fused_batch_norm, use_ema): + fused_batch_norm): """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. Args: @@ -603,12 +505,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. - use_ema: Bool, when true uses EMA quantization for BN folded weights. """ graph = ops.Graph() with graph.as_default(): - training.create_global_step(graph) - batch_size, height, width, depth = 5, 128, 128, 3 inputs = array_ops.zeros((batch_size, height, width, depth)) stride = 1 if with_bypass else 2 @@ -635,24 +534,20 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - fold_batch_norms.FoldBatchNorms(graph) + fold_batch_norms.FoldBatchNorms(graph, is_training=True) - quantize.Quantize( - graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('AssignMinEma' - if use_ema else 'AssignMinLast'), - scope + '/weights_quant/' + ('AssignMaxEma' - if use_ema else 'AssignMaxLast'), - scope + '/mul_fold' + scope + '/weights_quant/' + 'AssignMinLast', + scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if delay and use_ema else '/depthwise_Fold') + if delay else '/depthwise_Fold') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 1e4dd7cf67dbfbd16386fd740c7dcc83e05ad82a..bb7be0809421b64a019e73f00aac6c58524222e8 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -35,7 +35,15 @@ separable_conv2d = layers.separable_conv2d class QuantizeTest(test_util.TensorFlowTestCase): + def _RunTestOverParameters(self, test_fn): + params = [True, False] + for is_training in params: + test_fn(is_training) + def testInsertQuantOpFailsWhenOpsNotConnected(self): + pass + + def _TestInsertQuantOpFailsWhenOpsNotConnected(self, is_training): graph = ops.Graph() with graph.as_default(): batch_size, height, width, depth = 5, 128, 128, 3 @@ -45,17 +53,18 @@ class QuantizeTest(test_util.TensorFlowTestCase): activation_fn=None, scope='test') relu = nn_ops.relu6(inputs) - context = quantize._QuantizeContext(graph=graph, weight_bits=8, - weight_narrow_range=True, - activation_bits=8) # Inserting a quantization op between two unconnected ops should fail with # ValueError. with self.assertRaises(ValueError) as err: - context._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp') + quantize._InsertQuantOp('test', is_training, conv.op, [relu.op], + 'FailingQuantOp') self.assertEqual( str(err.exception), 'Some inputs not quantized for ops: [Relu6]') def testInsertQuantOpForAddAfterConv2d(self): + self._RunTestOverParameters(self._TestInsertQuantOpForAddAfterConv2d) + + def _TestInsertQuantOpForAddAfterConv2d(self, is_training): graph = ops.Graph() with graph.as_default(): batch_size, height, width, depth = 5, 128, 128, 3 @@ -70,8 +79,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True, - activation_bits=8) + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) quantization_node_name = 'FakeQuantWithMinMaxVars' add_quant = graph.get_operation_by_name('test/add_quant/' + @@ -79,6 +87,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertEqual(add_quant.type, quantization_node_name) def testInsertQuantOpForAddAfterSeparableConv2d(self): + self._RunTestOverParameters( + self._TestInsertQuantOpForAddAfterSeparableConv2d) + + def _TestInsertQuantOpForAddAfterSeparableConv2d(self, is_training): graph = ops.Graph() with graph.as_default(): batch_size, height, width, depth = 5, 128, 128, 3 @@ -94,8 +106,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True, - activation_bits=8) + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) quantization_node_name = 'FakeQuantWithMinMaxVars' add_quant = graph.get_operation_by_name('test/add_quant/' + diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py index 44998b3b6591221fde55d8d2d406d5141b1647f2..bc383a803496380aaba4d0248d2b7f93253b2b50 100644 --- a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py +++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py @@ -35,20 +35,34 @@ _VALID_PADDING = ["VALID", b"VALID"] _SAME_PADDING = ["SAME", b"SAME"] -def _stride_size(node): +def _stride_size(node, name_to_node): """Computes stride size given a TF node. Args: node: Tensorflow node (NodeDef proto). + name_to_node: For MaxPoolV2, mapping from variable name Tensorflow node. Returns: stride_x: Stride size for horizontal direction (integer). stride_y: Stride size for vertical direction (integer). + + Raises: + ValueError: If stride input cannot be found in `name_to_node`. """ - strides_attr = node.attr["strides"] - logging.vlog(4, "strides_attr = %s", strides_attr) - stride_y = strides_attr.list.i[1] - stride_x = strides_attr.list.i[2] + if node.op == "MaxPoolV2": + strides_input_name = node.input[2] + if not strides_input_name.endswith("/strides"): + raise ValueError("Strides name does not end with '/strides'") + strides_node = name_to_node[strides_input_name] + value = strides_node.attr["value"] + t = make_ndarray(value.tensor) + stride_y = t[1] + stride_x = t[2] + else: + strides_attr = node.attr["strides"] + logging.vlog(4, "strides_attr = %s", strides_attr) + stride_y = strides_attr.list.i[1] + stride_x = strides_attr.list.i[2] return stride_x, stride_y @@ -144,11 +158,12 @@ def _padding_size_conv_pool(node, kernel_size, stride, input_resolution=None): return total_padding, padding -def _pool_kernel_size(node): +def _pool_kernel_size(node, name_to_node): """Computes kernel size given a TF pooling node. Args: node: Tensorflow node (NodeDef proto). + name_to_node: For MaxPoolV2, mapping from node name to NodeDef. Returns: kernel_size_x: Kernel size for horizontal direction (integer). @@ -157,13 +172,27 @@ def _pool_kernel_size(node): Raises: ValueError: If pooling is invalid. """ - ksize = node.attr["ksize"] - kernel_size_y = ksize.list.i[1] - kernel_size_x = ksize.list.i[2] - if ksize.list.i[0] != 1: - raise ValueError("pool ksize for first dim is not 1") - if ksize.list.i[3] != 1: - raise ValueError("pool ksize for last dim is not 1") + if node.op == "MaxPoolV2": + ksize_input_name = node.input[1] + if not ksize_input_name.endswith("/ksize"): + raise ValueError("Kernel size name does not end with '/ksize'") + ksize_node = name_to_node[ksize_input_name] + value = ksize_node.attr["value"] + t = make_ndarray(value.tensor) + kernel_size_y = t[1] + kernel_size_x = t[2] + if t[0] != 1: + raise ValueError("pool ksize for first dim is not 1") + if t[3] != 1: + raise ValueError("pool ksize for last dim is not 1") + else: + ksize = node.attr["ksize"] + kernel_size_y = ksize.list.i[1] + kernel_size_x = ksize.list.i[2] + if ksize.list.i[0] != 1: + raise ValueError("pool ksize for first dim is not 1") + if ksize.list.i[3] != 1: + raise ValueError("pool ksize for last dim is not 1") return kernel_size_x, kernel_size_y @@ -243,7 +272,7 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): logging.vlog(3, "node.op = %s", node.op) logging.vlog(4, "node = %s", node) if node.op == "Conv2D" or node.op == "DepthwiseConv2dNative": - stride_x, stride_y = _stride_size(node) + stride_x, stride_y = _stride_size(node, name_to_node) kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_node) # Compute the padding for this node separately for each direction. total_padding_x, padding_x = _padding_size_conv_pool( @@ -260,9 +289,9 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): stride_y = 1 total_padding_x, padding_x, total_padding_y, padding_y = ( _padding_size_pad_layer(node, name_to_node)) - elif node.op == "MaxPool" or node.op == "AvgPool": - stride_x, stride_y = _stride_size(node) - kernel_size_x, kernel_size_y = _pool_kernel_size(node) + elif node.op == "MaxPool" or node.op == "MaxPoolV2" or node.op == "AvgPool": + stride_x, stride_y = _stride_size(node, name_to_node) + kernel_size_x, kernel_size_y = _pool_kernel_size(node, name_to_node) # Compute the padding for this node separately for each direction. total_padding_x, padding_x = _padding_size_conv_pool( node, kernel_size_x, stride_x, input_resolution[1] diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc index c33804906fc21cf2573b79091a76ab1ea86f5966..2def4f3f176b8d4d26c2c94168e9698f14649d94 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc @@ -15,8 +15,8 @@ limitations under the License. #define EIGEN_USE_THREADS -#include #include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h" +#include #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h index 9bb1724a2c0b70ee7ce7238cc179aded95935b26..d8c0a0631d38e55ef9653e0e88e90604ec0f0329 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_ #define TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_ +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #define Sum(a, b) ((a) + (b)) #define Prod(a, b) ((a) * (b)) @@ -58,11 +58,11 @@ inline T negative_infinity() { } // namespace reduce_functions -#define CALL_ALL_REDUCEOPS(func, ...) \ - func(Sum, functor::reduce_functions::zero, ##__VA_ARGS__) \ - func(Prod, functor::reduce_functions::one, ##__VA_ARGS__) \ - func(Max, functor::reduce_functions::negative_infinity, ##__VA_ARGS__) \ - func(Min, functor::reduce_functions::infinity, ##__VA_ARGS__) +#define CALL_ALL_REDUCEOPS(func, ...) \ + func(Sum, functor::reduce_functions::zero, ##__VA_ARGS__) \ + func(Prod, functor::reduce_functions::one, ##__VA_ARGS__) func( \ + Max, functor::reduce_functions::negative_infinity, ##__VA_ARGS__) \ + func(Min, functor::reduce_functions::infinity, ##__VA_ARGS__) #define ReduceSliceFunctorReduceop(reduceop, dummy) \ template \ diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc index 8e6870fadd428ae8a1937a5c0cb43b6763f6be28..9f2be03d718364058da6b63add8752c046798c5b 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc @@ -17,10 +17,10 @@ limitations under the License. #define EIGEN_USE_GPU +#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h" #include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { @@ -34,9 +34,9 @@ namespace functor { __global__ void ReduceSliceDeviceKernel##reduceop( \ Cuda3DLaunchConfig config, Index indices_width, Index bound, \ const T begin, const Index *indices, const T *input, T *out) { \ - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { \ - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { \ - CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { \ + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \ + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \ + CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { \ Index outidx = x * config.virtual_thread_count.y * \ config.virtual_thread_count.z + \ y * config.virtual_thread_count.z + z; \ @@ -68,8 +68,9 @@ namespace functor { if (sizex * sizey * sizez == 0) { \ return; \ } \ - Cuda3DLaunchConfig config = GetCuda3DLaunchConfig(sizex, sizey, sizez, d,\ - ReduceSliceDeviceKernel##reduceop, 0, 0); \ + Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \ + sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop, \ + 0, 0); \ \ ReduceSliceDeviceKernel##reduceop \ <<>>( \ diff --git a/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc index b8b56c0e229563a4e9bc930512c9fe49bd636e31..92879ab5356623dfa82fce8dff8db4d3036ae46c 100644 --- a/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc +++ b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc @@ -87,9 +87,9 @@ and 'indices' is [[0,1] [1,1] [0,2]], -the the output will be [[ 1, 2, 3] - [ 0, 0, 0] - [41,52,63]]. +the output will be [[ 1, 2, 3] + [ 0, 0, 0] + [41,52,63]]. ``` The data must be at least rank 1. The indices must be of shape (?,2) where the @@ -132,9 +132,9 @@ and 'indices' is [[0,1] [1,1] [0,2]], -the the output will be [[ 1, 2, 3] - [ 1, 1, 1] - [40,100,180]]. +the output will be [[ 1, 2, 3] + [ 1, 1, 1] + [40,100,180]]. ``` The data must be at least rank 1. The indices can be of shape (?,2) where the @@ -189,9 +189,9 @@ and 'indices' is [[0,1] [1,1] [0,2]], -the the output will be [[ 1, 20, 3] - [ -BIG_VALUE, -BIG_VALUE, -BIG_VALUE] - [ 400, 20, 60]]. +the output will be [[ 1, 20, 3] + [ -BIG_VALUE, -BIG_VALUE, -BIG_VALUE] + [ 400, 20, 60]]. ``` The data must be at least rank 1. The indices can be of shape (?,2) where the @@ -246,9 +246,9 @@ and 'indices' is [[0,1] [1,1] [0,2]], -the the output will be [[ 1, 20, 3] - [ +BIG_VALUE, +BIG_VALUE, +BIG_VALUE] - [ 1, 5, 3]]. +the output will be [[ 1, 20, 3] + [ +BIG_VALUE, +BIG_VALUE, +BIG_VALUE] + [ 1, 5, 3]]. ``` The data must be at least rank 1. The indices can be of shape (?,2) where the diff --git a/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py b/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py index 60a193db4c7f084d3262a69e2b8c5df66273e138..468886da20021646089bd1d222da1ebd4b5c7822 100644 --- a/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py +++ b/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import numpy as np -import unittest from tensorflow.contrib.reduce_slice_ops.python.ops import reduce_slice_ops from tensorflow.python.framework.test_util import TensorFlowTestCase diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.cc b/tensorflow/contrib/resampler/kernels/resampler_ops.cc index e02c1b6a2bd9daf9e1f81059f7c1f92106cebc8f..63c72836d793a3df4e96a0134f3a1534c288c8c8 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops.cc +++ b/tensorflow/contrib/resampler/kernels/resampler_ops.cc @@ -36,17 +36,12 @@ using GPUDevice = Eigen::GpuDevice; namespace functor { template -struct Resampler2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const CPUDevice& d, - const T* __restrict__ data, - const T* __restrict__ warp, - T* __restrict__ output, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points){ +struct Resampler2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const CPUDevice& d, + const T* __restrict__ data, const T* __restrict__ warp, + T* __restrict__ output, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points) { const int warp_batch_stride = num_sampling_points * 2; const int data_batch_stride = data_height * data_width * data_channels; const int output_batch_stride = num_sampling_points * data_channels; @@ -59,24 +54,19 @@ struct Resampler2DFunctor{ // The functions take care of performing the relevant pointer // arithmetics abstracting away the low level details in the // main loop over samples. Note that data is stored in NHWC format. - auto set_output = [&](const int sample_id, - const int channel, + auto set_output = [&](const int sample_id, const int channel, const T value) { - output[batch_id * output_batch_stride + - sample_id * data_channels + + output[batch_id * output_batch_stride + sample_id * data_channels + channel] = value; }; - auto get_data_point = [&](const int x, - const int y, - const int chan) { + auto get_data_point = [&](const int x, const int y, const int chan) { const bool point_is_in_range = (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); return point_is_in_range - ? data[batch_id * data_batch_stride + - data_channels * (y * data_width + x) + - chan] - : zero; + ? data[batch_id * data_batch_stride + + data_channels * (y * data_width + x) + chan] + : zero; }; for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) { @@ -89,8 +79,7 @@ struct Resampler2DFunctor{ // The effect is that the sampled signal smoothly goes to 0 outside // the original input domain, rather than presenting a jump // discontinuity at the image boundaries. - if (x > static_cast(-1.0) && - y > static_cast(-1.0) && + if (x > static_cast(-1.0) && y > static_cast(-1.0) && x < static_cast(data_width) && y < static_cast(data_height)) { // Precompute floor (f) and ceil (c) values for x and y. @@ -103,12 +92,10 @@ struct Resampler2DFunctor{ for (int chan = 0; chan < data_channels; ++chan) { const T img_fxfy = dx * dy * get_data_point(fx, fy, chan); - const T img_cxcy = (one - dx) * (one - dy) * - get_data_point(cx, cy, chan); - const T img_fxcy = dx * (one - dy) * - get_data_point(fx, cy, chan); - const T img_cxfy = (one - dx) * dy * - get_data_point(cx, fy, chan); + const T img_cxcy = + (one - dx) * (one - dy) * get_data_point(cx, cy, chan); + const T img_fxcy = dx * (one - dy) * get_data_point(fx, cy, chan); + const T img_cxfy = (one - dx) * dy * get_data_point(cx, fy, chan); set_output(sample_id, chan, img_fxfy + img_cxcy + img_fxcy + img_cxfy); } @@ -125,8 +112,8 @@ struct Resampler2DFunctor{ // estimate of the cost of each work unit is needed to correctly shard the // workload. Shard assumes each cost unit is 1ns, minimum cost per shard // being 10us. - const int64 cost = static_cast(num_sampling_points) * - data_channels * 1000; + const int64 cost = + static_cast(num_sampling_points) * data_channels * 1000; auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers, batch_size, cost, resample_batches); @@ -138,8 +125,8 @@ struct Resampler2DFunctor{ template class ResamplerOp : public ::tensorflow::OpKernel { public: - explicit ResamplerOp(::tensorflow::OpKernelConstruction* context) : - ::tensorflow::OpKernel(context) {} + explicit ResamplerOp(::tensorflow::OpKernelConstruction* context) + : ::tensorflow::OpKernel(context) {} void Compute(::tensorflow::OpKernelContext* ctx) override { const ::tensorflow::Tensor& data = ctx->input(0); @@ -158,16 +145,17 @@ class ResamplerOp : public ::tensorflow::OpKernel { ::tensorflow::errors::InvalidArgument( "warp should be at least a matrix, got shape ", warp_shape.DebugString())); - OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims()-1) == 2, + OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, ::tensorflow::errors::Unimplemented( "Only bilinear interpolation is supported, warping " "coordinates must be 2D; warp shape last entry should be " - "2, but shape vector is: ", warp_shape.DebugString())); + "2, but shape vector is: ", + warp_shape.DebugString())); OP_REQUIRES(ctx, data_shape.dim_size(0) == warp_shape.dim_size(0), ::tensorflow::errors::InvalidArgument( "Batch size of data and warp tensor must be the same, but " - "input shapes are: ", data_shape.DebugString(), ", ", - warp_shape.DebugString())); + "input shapes are: ", + data_shape.DebugString(), ", ", warp_shape.DebugString())); const int batch_size = data_shape.dim_size(0); const int data_height = data_shape.dim_size(1); const int data_width = data_shape.dim_size(2); @@ -180,16 +168,10 @@ class ResamplerOp : public ::tensorflow::OpKernel { // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU. if (num_sampling_points > 0) { - functor::Resampler2DFunctor()(ctx, - ctx->eigen_device(), - data.flat().data(), - warp.flat().data(), - output->flat().data(), - batch_size, - data_height, - data_width, - data_channels, - num_sampling_points); + functor::Resampler2DFunctor()( + ctx, ctx->eigen_device(), data.flat().data(), + warp.flat().data(), output->flat().data(), batch_size, + data_height, data_width, data_channels, num_sampling_points); } } @@ -197,12 +179,9 @@ class ResamplerOp : public ::tensorflow::OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ResamplerOp); }; - -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("Resampler") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Resampler").Device(DEVICE_CPU).TypeConstraint("T"), \ ResamplerOp); TF_CALL_half(REGISTER); @@ -211,40 +190,32 @@ TF_CALL_double(REGISTER); #undef REGISTER #if GOOGLE_CUDA -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("Resampler") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T"), \ - ResamplerOp) +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Resampler").Device(DEVICE_GPU).TypeConstraint("T"), \ + ResamplerOp) TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); #undef REGISTER #endif // GOOGLE_CUDA - namespace functor { template -struct ResamplerGrad2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const CPUDevice& d, - const T* __restrict__ data, - const T* __restrict__ warp, - const T* __restrict__ grad_output, - T* __restrict__ grad_data, - T* __restrict__ grad_warp, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points){ +struct ResamplerGrad2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const CPUDevice& d, + const T* __restrict__ data, const T* __restrict__ warp, + const T* __restrict__ grad_output, T* __restrict__ grad_data, + T* __restrict__ grad_warp, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points) { // Set gradients to 0, because the kernel incrementally updates the // tensor entries by adding partial contributions. - const int resampler_output_size = batch_size * num_sampling_points * - data_channels; + const int resampler_output_size = + batch_size * num_sampling_points * data_channels; const int grad_warp_size = resampler_output_size / data_channels * 2; - const int grad_data_size = data_height * data_width * data_channels * - batch_size; + const int grad_data_size = + data_height * data_width * data_channels * batch_size; memset(grad_data, 0, sizeof(T) * grad_data_size); memset(grad_warp, 0, sizeof(T) * grad_warp_size); @@ -260,35 +231,29 @@ struct ResamplerGrad2DFunctor{ // The functions take care of performing the relevant pointer // arithmetics abstracting away the low level details in the // main loop over samples. Note that data is stored in NHWC format. - auto get_data_point = [&](const int x, - const int y, - const int chan) { + auto get_data_point = [&](const int x, const int y, const int chan) { const bool point_is_in_range = - (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); + (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); return point_is_in_range - ? data[batch_id * data_batch_stride + - data_channels * (y * data_width + x) + - chan] - : zero; + ? data[batch_id * data_batch_stride + + data_channels * (y * data_width + x) + chan] + : zero; }; auto update_grad_data = [&](const int x, const int y, const int chan, const T value) { const bool point_is_in_range = (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); - if (point_is_in_range){ + if (point_is_in_range) { grad_data[batch_id * data_batch_stride + - data_channels * (y * data_width + x) + - chan] += value; + data_channels * (y * data_width + x) + chan] += value; } }; - auto update_grad_warp = [&](const int sample_id, - const int channel, + auto update_grad_warp = [&](const int sample_id, const int channel, const T value) { - grad_warp[batch_id * warp_batch_stride + - sample_id * 2 + - channel] += value; + grad_warp[batch_id * warp_batch_stride + sample_id * 2 + channel] += + value; }; for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) { @@ -301,8 +266,7 @@ struct ResamplerGrad2DFunctor{ // The effect is that the sampled signal smoothly goes to 0 outside // the original input domain, rather than presenting a jump // discontinuity at the image boundaries. - if (x > static_cast(-1.0) && - y > static_cast(-1.0) && + if (x > static_cast(-1.0) && y > static_cast(-1.0) && x < static_cast(data_width) && y < static_cast(data_height)) { // Precompute floor (f) and ceil (c) values for x and y. @@ -316,27 +280,25 @@ struct ResamplerGrad2DFunctor{ for (int chan = 0; chan < data_channels; ++chan) { const T grad_output_value = grad_output[batch_id * output_batch_stride + - sample_id * data_channels + - chan]; + sample_id * data_channels + chan]; const T img_fxfy = get_data_point(fx, fy, chan); const T img_cxcy = get_data_point(cx, cy, chan); const T img_fxcy = get_data_point(fx, cy, chan); const T img_cxfy = get_data_point(cx, fy, chan); // Update partial gradients wrt relevant warp field entries - update_grad_warp(sample_id, 0, - grad_output_value * - ((one - dy) * (img_cxcy - img_fxcy) + - dy * (img_cxfy - img_fxfy))); + update_grad_warp( + sample_id, 0, + grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) + + dy * (img_cxfy - img_fxfy))); - update_grad_warp(sample_id, 1, - grad_output_value * - ((one - dx) * (img_cxcy - img_cxfy) + - dx * (img_fxcy - img_fxfy))); + update_grad_warp( + sample_id, 1, + grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) + + dx * (img_fxcy - img_fxfy))); // Update partial gradients wrt sampled data - update_grad_data(fx, fy, chan, - grad_output_value * dx * dy); + update_grad_data(fx, fy, chan, grad_output_value * dx * dy); update_grad_data(cx, cy, chan, grad_output_value * (one - dx) * (one - dy)); update_grad_data(fx, cy, chan, @@ -355,8 +317,8 @@ struct ResamplerGrad2DFunctor{ // being 10us. // TODO(fviola): Check out if there is a better way of doing this. auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); - const int64 cost = static_cast(num_sampling_points) * - data_channels * 1000; + const int64 cost = + static_cast(num_sampling_points) * data_channels * 1000; ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers, batch_size, cost, update_grads_for_batches); } @@ -364,12 +326,11 @@ struct ResamplerGrad2DFunctor{ } // namespace functor - template class ResamplerGradOp : public ::tensorflow::OpKernel { public: - explicit ResamplerGradOp(::tensorflow::OpKernelConstruction* context) : - ::tensorflow::OpKernel(context) {} + explicit ResamplerGradOp(::tensorflow::OpKernelConstruction* context) + : ::tensorflow::OpKernel(context) {} void Compute(::tensorflow::OpKernelContext* ctx) override { const ::tensorflow::Tensor& data = ctx->input(0); @@ -383,7 +344,7 @@ class ResamplerGradOp : public ::tensorflow::OpKernel { "tensor must be a batch of 2d data; data shape should have " "4 entries corresponding to [batch_size, data_height, " "data_width, data_channels], but is: ", - data_shape.DebugString())); + data_shape.DebugString())); const int batch_size = data_shape.dim_size(0); const int data_height = data_shape.dim_size(1); const int data_width = data_shape.dim_size(2); @@ -394,7 +355,7 @@ class ResamplerGradOp : public ::tensorflow::OpKernel { ::tensorflow::errors::InvalidArgument( "warp should be at least a matrix, got shape ", warp_shape.DebugString())); - OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims()-1) == 2, + OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, ::tensorflow::errors::Unimplemented( "Only bilinear interpolation is supported, warping " "coordinates must be 2D; warp shape last entry should be " @@ -417,18 +378,11 @@ class ResamplerGradOp : public ::tensorflow::OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(1, warp.shape(), &grad_warp)); // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU. if (num_sampling_points > 0) { - functor::ResamplerGrad2DFunctor()(ctx, - ctx->eigen_device(), - data.flat().data(), - warp.flat().data(), - grad_output.flat().data(), - grad_data->flat().data(), - grad_warp->flat().data(), - batch_size, - data_height, - data_width, - data_channels, - num_sampling_points); + functor::ResamplerGrad2DFunctor()( + ctx, ctx->eigen_device(), data.flat().data(), + warp.flat().data(), grad_output.flat().data(), + grad_data->flat().data(), grad_warp->flat().data(), batch_size, + data_height, data_width, data_channels, num_sampling_points); } } @@ -436,11 +390,9 @@ class ResamplerGradOp : public ::tensorflow::OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ResamplerGradOp); }; -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("ResamplerGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("ResamplerGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ ResamplerGradOp); TF_CALL_half(REGISTER); @@ -449,11 +401,10 @@ TF_CALL_double(REGISTER); #undef REGISTER #if GOOGLE_CUDA -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("ResamplerGrad") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T"), \ - ResamplerGradOp) +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("ResamplerGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ + ResamplerGradOp) // Disable half and double precision since atomicAdds are not supported // TF_CALL_half(REGISTER); // TF_CALL_double(REGISTER); diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.h b/tensorflow/contrib/resampler/kernels/resampler_ops.h index 85d3676efac70fe9237d31c2be1fe75e67d70abd..7fe3b9c0df71f51e07d38ea15a672d79fdc70453 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops.h +++ b/tensorflow/contrib/resampler/kernels/resampler_ops.h @@ -29,38 +29,25 @@ namespace functor { // Helper functor for the Resampler Op in 2D template -struct Resampler2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const Device& d, - const T* __restrict__ data, - const T* __restrict__ warp, - T* __restrict__ output, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points); +struct Resampler2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const Device& d, + const T* __restrict__ data, const T* __restrict__ warp, + T* __restrict__ output, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points); }; - // Helper functor for the Resampler Gradient Op in 2D template -struct ResamplerGrad2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const Device& d, - const T* __restrict__ data, - const T* __restrict__ warp, - const T* __restrict__ grad_output, - T* __restrict__ grad_data, - T* __restrict__ grad_warp, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points); +struct ResamplerGrad2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const Device& d, + const T* __restrict__ data, const T* __restrict__ warp, + const T* __restrict__ grad_output, T* __restrict__ grad_data, + T* __restrict__ grad_warp, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points); }; - } // namespace functor } // namespace tensorflow diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc index 636847a212f27c738032128e3f3f653ec32f851b..3c07051f685c74b6e45fb782c80871f38dffbbf4 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc +++ b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc @@ -31,18 +31,15 @@ using GPUDevice = Eigen::GpuDevice; namespace { -#define GET_DATA_POINT(x, y) \ - data[batch_id * data_batch_stride + \ - data_channels * (y * data_width + x) + \ +#define GET_DATA_POINT(x, y) \ + data[batch_id * data_batch_stride + data_channels * (y * data_width + x) + \ chan] template __global__ void Resampler2DKernel(const T* __restrict__ data, const T* __restrict__ warp, - T* __restrict__ output, - const int batch_size, - const int data_height, - const int data_width, + T* __restrict__ output, const int batch_size, + const int data_height, const int data_width, const int data_channels, const int num_sampling_points) { const int output_data_size = batch_size * num_sampling_points * data_channels; @@ -75,10 +72,8 @@ __global__ void Resampler2DKernel(const T* __restrict__ data, // The effect is that the sampled signal smoothly goes to 0 outside // the original input domain, rather than presenting a jump // discontinuity at the image boundaries. - if (x > static_cast(-1.0) && - y > static_cast(-1.0) && - x < static_cast(data_width) && - y < static_cast(data_height)) { + if (x > static_cast(-1.0) && y > static_cast(-1.0) && + x < static_cast(data_width) && y < static_cast(data_height)) { // Precompute floor (f) and ceil (c) values for x and y. const int fx = std::floor(static_cast(x)); const int fy = std::floor(static_cast(y)); @@ -87,21 +82,20 @@ __global__ void Resampler2DKernel(const T* __restrict__ data, const T dx = static_cast(cx) - x; const T dy = static_cast(cy) - y; - const T img_fxfy = (fx >= 0 && fy >= 0) - ? dx * dy * GET_DATA_POINT(fx, fy) - : zero; + const T img_fxfy = + (fx >= 0 && fy >= 0) ? dx * dy * GET_DATA_POINT(fx, fy) : zero; const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1) - ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy) - : zero; + ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy) + : zero; const T img_fxcy = (fx >= 0 && cy <= data_height - 1) - ? dx * (one - dy) * GET_DATA_POINT(fx, cy) - : zero; + ? dx * (one - dy) * GET_DATA_POINT(fx, cy) + : zero; const T img_cxfy = (cx <= data_width - 1 && fy >= 0) - ? (one - dx) * dy * GET_DATA_POINT(cx, fy) - : zero; + ? (one - dx) * dy * GET_DATA_POINT(cx, fy) + : zero; output[out_index] = img_fxfy + img_cxcy + img_fxcy + img_cxfy; } else { @@ -115,24 +109,20 @@ __global__ void Resampler2DKernel(const T* __restrict__ data, namespace functor { template -struct Resampler2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const GPUDevice& d, - const T* __restrict__ data, - const T* __restrict__ warp, - T* __restrict__ output, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points) { - const int output_data_size = batch_size * num_sampling_points * data_channels; - ::tensorflow::CudaLaunchConfig config = - ::tensorflow::GetCudaLaunchConfig(output_data_size, d); - Resampler2DKernel - <<>>( - data, warp, output, batch_size, data_height, data_width, - data_channels, num_sampling_points); +struct Resampler2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const GPUDevice& d, + const T* __restrict__ data, const T* __restrict__ warp, + T* __restrict__ output, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points) { + const int output_data_size = + batch_size * num_sampling_points * data_channels; + ::tensorflow::CudaLaunchConfig config = + ::tensorflow::GetCudaLaunchConfig(output_data_size, d); + Resampler2DKernel + <<>>( + data, warp, output, batch_size, data_height, data_width, + data_channels, num_sampling_points); } }; @@ -145,26 +135,20 @@ template struct Resampler2DFunctor; namespace { -#define UPDATE_GRAD_DATA_POINT(x, y, v) \ - atomicAdd(grad_data + (batch_id * data_batch_stride + \ - data_channels * (y * data_width + x) + \ - chan), \ +#define UPDATE_GRAD_DATA_POINT(x, y, v) \ + atomicAdd(grad_data + (batch_id * data_batch_stride + \ + data_channels * (y * data_width + x) + chan), \ v) - template -__global__ void ResamplerGrad2DKernel(const T* __restrict__ data, - const T* __restrict__ warp, - const T* __restrict__ grad_output, - T* __restrict__ grad_data, - T* __restrict__ grad_warp, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points) { - const int resampler_output_size = batch_size * num_sampling_points * - data_channels; +__global__ void ResamplerGrad2DKernel( + const T* __restrict__ data, const T* __restrict__ warp, + const T* __restrict__ grad_output, T* __restrict__ grad_data, + T* __restrict__ grad_warp, const int batch_size, const int data_height, + const int data_width, const int data_channels, + const int num_sampling_points) { + const int resampler_output_size = + batch_size * num_sampling_points * data_channels; CUDA_1D_KERNEL_LOOP(index, resampler_output_size) { const int out_index = index; @@ -199,10 +183,8 @@ __global__ void ResamplerGrad2DKernel(const T* __restrict__ data, // The effect is that the sampled signal smoothly goes to 0 outside // the original input domain, rather than presenting a jump // discontinuity at the image boundaries. - if (x > static_cast(-1.0) && - y > static_cast(-1.0) && - x < static_cast(data_width) && - y < static_cast(data_height)) { + if (x > static_cast(-1.0) && y > static_cast(-1.0) && + x < static_cast(data_width) && y < static_cast(data_height)) { // Precompute floor (f) and ceil (c) values for x and y. const int fx = std::floor(static_cast(x)); const int fy = std::floor(static_cast(y)); @@ -211,21 +193,17 @@ __global__ void ResamplerGrad2DKernel(const T* __restrict__ data, const T dx = static_cast(cx) - x; const T dy = static_cast(cy) - y; - const T img_fxfy = (fx >= 0 && fy >= 0) - ? GET_DATA_POINT(fx, fy) - : zero; + const T img_fxfy = (fx >= 0 && fy >= 0) ? GET_DATA_POINT(fx, fy) : zero; const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1) - ? GET_DATA_POINT(cx, cy) - : zero; + ? GET_DATA_POINT(cx, cy) + : zero; - const T img_fxcy = (fx >= 0 && cy <= data_height - 1) - ? GET_DATA_POINT(fx, cy) - : zero; + const T img_fxcy = + (fx >= 0 && cy <= data_height - 1) ? GET_DATA_POINT(fx, cy) : zero; - const T img_cxfy = (cx <= data_width - 1 && fy >= 0) - ? GET_DATA_POINT(cx, fy) - : zero; + const T img_cxfy = + (cx <= data_width - 1 && fy >= 0) ? GET_DATA_POINT(cx, fy) : zero; // Update partial gradients wrt relevant warp field entries atomicAdd(grad_warp + warp_id_x, @@ -241,7 +219,7 @@ __global__ void ResamplerGrad2DKernel(const T* __restrict__ data, } if (cx <= data_width - 1 && cy <= data_height - 1) { UPDATE_GRAD_DATA_POINT(cx, cy, - grad_output_value * (one - dx) * (one - dy)); + grad_output_value * (one - dx) * (one - dy)); } if (fx >= 0 && cy <= data_height - 1) { UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy)); @@ -261,43 +239,37 @@ __global__ void ResamplerGrad2DKernel(const T* __restrict__ data, namespace functor { template -struct ResamplerGrad2DFunctor{ - void operator ()(::tensorflow::OpKernelContext* ctx, - const GPUDevice& d, - const T* __restrict__ data, - const T* __restrict__ warp, - const T* __restrict__ grad_output, - T* __restrict__ grad_data, - T* __restrict__ grad_warp, - const int batch_size, - const int data_height, - const int data_width, - const int data_channels, - const int num_sampling_points) { - // Set gradients to 0, because the kernel incrementally updates the - // tensor entries by adding partial contributions. - const int grad_warp_size = batch_size * num_sampling_points * 2; - const int grad_data_size = batch_size * data_height * data_width * - data_channels; - - ::tensorflow::CudaLaunchConfig config = - ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d); - ::tensorflow::SetZero - <<>>( - grad_warp_size, grad_warp); - - config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d); - ::tensorflow::SetZero - <<>>( - grad_data_size, grad_data); - - const int resampler_output_size = batch_size * num_sampling_points * - data_channels; - config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d); - ResamplerGrad2DKernel - <<>>( - data, warp, grad_output, grad_data, grad_warp, batch_size, - data_height, data_width, data_channels, num_sampling_points); +struct ResamplerGrad2DFunctor { + void operator()(::tensorflow::OpKernelContext* ctx, const GPUDevice& d, + const T* __restrict__ data, const T* __restrict__ warp, + const T* __restrict__ grad_output, T* __restrict__ grad_data, + T* __restrict__ grad_warp, const int batch_size, + const int data_height, const int data_width, + const int data_channels, const int num_sampling_points) { + // Set gradients to 0, because the kernel incrementally updates the + // tensor entries by adding partial contributions. + const int grad_warp_size = batch_size * num_sampling_points * 2; + const int grad_data_size = + batch_size * data_height * data_width * data_channels; + + ::tensorflow::CudaLaunchConfig config = + ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d); + ::tensorflow:: + SetZero<<>>( + grad_warp_size, grad_warp); + + config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d); + ::tensorflow:: + SetZero<<>>( + grad_data_size, grad_data); + + const int resampler_output_size = + batch_size * num_sampling_points * data_channels; + config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d); + ResamplerGrad2DKernel + <<>>( + data, warp, grad_output, grad_data, grad_warp, batch_size, + data_height, data_width, data_channels, num_sampling_points); } }; diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index c568c6760fd67b1902b0c1e6dc1aa439cb63de9b..67f31785b57fddef67733c18c3b744322532c28c 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -18,6 +18,7 @@ See @{$python/contrib.rnn} guide. @@RNNCell +@@LayerRNNCell @@BasicRNNCell @@BasicLSTMCell @@GRUCell @@ -68,6 +69,10 @@ See @{$python/contrib.rnn} guide. @@static_bidirectional_rnn @@stack_bidirectional_dynamic_rnn @@stack_bidirectional_rnn + + +@@transpose_batch_time +@@best_effort_input_batch_size """ from __future__ import absolute_import @@ -85,6 +90,8 @@ from tensorflow.contrib.rnn.python.ops.lstm_ops import * from tensorflow.contrib.rnn.python.ops.rnn import * from tensorflow.contrib.rnn.python.ops.rnn_cell import * +from tensorflow.python.ops.rnn import _best_effort_input_batch_size as best_effort_input_batch_size +from tensorflow.python.ops.rnn import _transpose_batch_time as transpose_batch_time from tensorflow.python.ops.rnn import static_bidirectional_rnn from tensorflow.python.ops.rnn import static_rnn from tensorflow.python.ops.rnn import static_state_saving_rnn diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.cc b/tensorflow/contrib/rnn/kernels/blas_gemm.cc index e62501e9b100484a7be3cc6ae0fc25905c0d0724..03006dab323a7c6dc83d9a17c035ef705f7b0366 100644 --- a/tensorflow/contrib/rnn/kernels/blas_gemm.cc +++ b/tensorflow/contrib/rnn/kernels/blas_gemm.cc @@ -36,11 +36,10 @@ perftools::gputools::DeviceMemory AsDeviceMemory(const T* cuda_memory) { namespace functor { template -void TensorCuBlasGemm::operator()(OpKernelContext* ctx, - bool transa, bool transb, uint64 m, - uint64 n, uint64 k, T alpha, const T* a, - int lda, const T* b, int ldb, T beta, T* c, - int ldc) { +void TensorCuBlasGemm::operator()(OpKernelContext* ctx, bool transa, + bool transb, uint64 m, uint64 n, uint64 k, + T alpha, const T* a, int lda, const T* b, + int ldb, T beta, T* c, int ldc) { #if GOOGLE_CUDA perftools::gputools::blas::Transpose trans[] = { perftools::gputools::blas::Transpose::kNoTranspose, diff --git a/tensorflow/contrib/rnn/kernels/gru_ops.cc b/tensorflow/contrib/rnn/kernels/gru_ops.cc index 0796f82b214620dd71d154fb8f8ec953dbcbb9ec..bd3d898fb09da0f490050c85b1e585502d8ecb2c 100644 --- a/tensorflow/contrib/rnn/kernels/gru_ops.cc +++ b/tensorflow/contrib/rnn/kernels/gru_ops.cc @@ -15,8 +15,8 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/contrib/rnn/kernels/gru_ops.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -61,9 +61,9 @@ class GRUCellBlockOp : public OpKernel { h_prev_tensor->dim_size(0), " vs. ", batch_size)); OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("h_prev.dims(1) != cell_size: ", - h_prev_tensor->dim_size(1), " vs. ", - cell_size)); + errors::InvalidArgument( + "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), + " vs. ", cell_size)); // Shape of 'w_ru' must be [input_size+cell_size, 2*cell_size] OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size, @@ -82,10 +82,10 @@ class GRUCellBlockOp : public OpKernel { "w_c.dim_size(0) != input_size + cell_size: ", w_c_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES( - ctx, w_c_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("w_c.dim_size(1) != cell_size: ", - w_c_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1), + " vs. ", cell_size)); // Shape of 'b_ru' must be [2*cell_size] OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2, @@ -97,10 +97,10 @@ class GRUCellBlockOp : public OpKernel { errors::InvalidArgument("Rank of b_ru must be 1", b_ru_tensor->dims(), " vs. 1", 1)); // Shape of 'b_c' must be [cell_size] - OP_REQUIRES( - ctx, b_c_tensor->dim_size(0) == cell_size, - errors::InvalidArgument("b_c.dim_size(0) != cell_size: ", - b_c_tensor->dim_size(0), " vs. ", cell_size)); + OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size, + errors::InvalidArgument( + "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0), + " vs. ", cell_size)); OP_REQUIRES(ctx, b_c_tensor->dims() == 1, errors::InvalidArgument("Rank of b_c must be 1", b_c_tensor->dims(), " vs. 1")); @@ -216,9 +216,9 @@ class GRUBlockCellGradOp : public OpKernel { h_prev_tensor->dim_size(0), " vs. ", batch_size)); OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("h_prev.dims(1) != cell_size: ", - h_prev_tensor->dim_size(1), " vs. ", - cell_size)); + errors::InvalidArgument( + "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), + " vs. ", cell_size)); // Shape of 'w_ru' must be [input_size+cell_size, 2*cell_size] OP_REQUIRES(ctx, w_ru_tensor->dim_size(0) == input_size + cell_size, @@ -237,10 +237,10 @@ class GRUBlockCellGradOp : public OpKernel { "w_c.dim_size(0) != input_size + cell_size: ", w_c_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES( - ctx, w_c_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("w_c.dim_size(1) != cell_size: ", - w_c_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, w_c_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "w_c.dim_size(1) != cell_size: ", w_c_tensor->dim_size(1), + " vs. ", cell_size)); // Shape of 'b_ru' must be [2*cell_size] OP_REQUIRES(ctx, b_ru_tensor->dim_size(0) == cell_size * 2, @@ -253,54 +253,54 @@ class GRUBlockCellGradOp : public OpKernel { b_ru_tensor->dims(), " vs. 1")); // Shape of 'b_c' must be [cell_size] - OP_REQUIRES( - ctx, b_c_tensor->dim_size(0) == cell_size, - errors::InvalidArgument("b_c.dim_size(0) != cell_size: ", - b_c_tensor->dim_size(0), " vs. ", cell_size)); + OP_REQUIRES(ctx, b_c_tensor->dim_size(0) == cell_size, + errors::InvalidArgument( + "b_c.dim_size(0) != cell_size: ", b_c_tensor->dim_size(0), + " vs. ", cell_size)); OP_REQUIRES(ctx, b_c_tensor->dims() == 1, errors::InvalidArgument("Rank of b_c must be 1 ", b_c_tensor->dims(), " vs. 1")); // Shape of 'r' must be [batch_size, cell_size] - OP_REQUIRES( - ctx, r_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("r.dims(0) != batch_size: ", - r_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, r_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("r.dims(1) != cell_size: ", - r_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, r_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "r.dims(0) != batch_size: ", r_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, r_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "r.dims(1) != cell_size: ", r_tensor->dim_size(1), " vs. ", + cell_size)); // Shape of 'u' must be [batch_size, cell_size] - OP_REQUIRES( - ctx, u_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("u.dims(0) != batch_size: ", - u_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, u_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("u.dims(1) != cell_size: ", - u_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, u_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "u.dims(0) != batch_size: ", u_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, u_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "u.dims(1) != cell_size: ", u_tensor->dim_size(1), " vs. ", + cell_size)); // Shape of 'c' must be [batch_size, cell_size] - OP_REQUIRES( - ctx, c_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("c.dims(0) != batch_size: ", - c_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, c_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("c.dims(1) != cell_size: ", - c_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, c_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "c.dims(0) != batch_size: ", c_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, c_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "c.dims(1) != cell_size: ", c_tensor->dim_size(1), " vs. ", + cell_size)); // Shape of 'd_h' must be [batch_size, cell_size] - OP_REQUIRES( - ctx, d_h_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("d_h.dims(0) != batch_size: ", - d_h_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, d_h_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("d_h.dims(1) != cell_size: ", - d_h_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, d_h_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "d_h.dims(0) != batch_size: ", d_h_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, d_h_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "d_h.dims(1) != cell_size: ", d_h_tensor->dim_size(1), + " vs. ", cell_size)); // Create output tensors. Tensor* d_x_tensor = nullptr; diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc index 941a457fd3ada312b981fb23c769ff9ecea9ff13..5e7cf0ce84d332bd24088cd78995f7843813328b 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc @@ -281,23 +281,23 @@ class LSTMBlockCellOp : public OpKernel { h_prev_tensor->dim_size(0), " vs. ", batch_size)); OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("h_prev.dims(1) != cell_size: ", - h_prev_tensor->dim_size(1), " vs. ", - cell_size)); + errors::InvalidArgument( + "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), + " vs. ", cell_size)); OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size, errors::InvalidArgument( "w.dim_size(0) != input_size + cell_size: ", w_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES( - ctx, w_tensor->dim_size(1) == cell_size * 4, - errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ", - w_tensor->dim_size(1), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4, + errors::InvalidArgument( + "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1), + " vs. ", cell_size * 4)); - OP_REQUIRES( - ctx, b_tensor->dim_size(0) == cell_size * 4, - errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ", - b_tensor->dim_size(0), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4, + errors::InvalidArgument( + "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0), + " vs. ", cell_size * 4)); // Allocate our output tensors. Tensor* i_tensor = nullptr; @@ -484,77 +484,77 @@ class LSTMBlockCellGradOp : public OpKernel { h_prev_tensor->dim_size(0), " vs. ", batch_size)); OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("h_prev.dims(1) != cell_size: ", - h_prev_tensor->dim_size(1), " vs. ", - cell_size)); + errors::InvalidArgument( + "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), + " vs. ", cell_size)); OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size, errors::InvalidArgument( "w.dim_size(0) != input_size + cell_size: ", w_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES( - ctx, w_tensor->dim_size(1) == cell_size * 4, - errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ", - w_tensor->dim_size(1), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4, + errors::InvalidArgument( + "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1), + " vs. ", cell_size * 4)); - OP_REQUIRES( - ctx, b_tensor->dim_size(0) == cell_size * 4, - errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ", - b_tensor->dim_size(0), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4, + errors::InvalidArgument( + "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0), + " vs. ", cell_size * 4)); - OP_REQUIRES( - ctx, i_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("i.dim_size(0) != batch_size: ", - i_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, i_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("i.dim_size(1) != cell_size: ", - i_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, i_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "i.dim_size(0) != batch_size: ", i_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, i_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "i.dim_size(1) != cell_size: ", i_tensor->dim_size(1), + " vs. ", cell_size)); - OP_REQUIRES( - ctx, cs_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("cs.dim_size(0) != batch_size: ", - cs_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, cs_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("cs.dim_size(1) != cell_size: ", - cs_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, cs_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "cs.dim_size(0) != batch_size: ", cs_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, cs_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "cs.dim_size(1) != cell_size: ", cs_tensor->dim_size(1), + " vs. ", cell_size)); - OP_REQUIRES( - ctx, f_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("f.dim_size(0) != batch_size: ", - f_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, f_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("i.dim_size(1) != cell_size: ", - f_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, f_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "f.dim_size(0) != batch_size: ", f_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, f_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "i.dim_size(1) != cell_size: ", f_tensor->dim_size(1), + " vs. ", cell_size)); - OP_REQUIRES( - ctx, o_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("o.dim_size(0) != batch_size: ", - o_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, o_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("o.dim_size(1) != cell_size: ", - o_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, o_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "o.dim_size(0) != batch_size: ", o_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, o_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "o.dim_size(1) != cell_size: ", o_tensor->dim_size(1), + " vs. ", cell_size)); - OP_REQUIRES( - ctx, ci_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("ci.dim_size(0) != batch_size: ", - ci_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, ci_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("ci.dim_size(1) != cell_size: ", - ci_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, ci_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "ci.dim_size(0) != batch_size: ", ci_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, ci_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "ci.dim_size(1) != cell_size: ", ci_tensor->dim_size(1), + " vs. ", cell_size)); - OP_REQUIRES( - ctx, co_tensor->dim_size(0) == batch_size, - errors::InvalidArgument("co.dim_size(0) != batch_size: ", - co_tensor->dim_size(0), " vs. ", batch_size)); - OP_REQUIRES( - ctx, co_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("co.dim_size(1) != cell_size: ", - co_tensor->dim_size(1), " vs. ", cell_size)); + OP_REQUIRES(ctx, co_tensor->dim_size(0) == batch_size, + errors::InvalidArgument( + "co.dim_size(0) != batch_size: ", co_tensor->dim_size(0), + " vs. ", batch_size)); + OP_REQUIRES(ctx, co_tensor->dim_size(1) == cell_size, + errors::InvalidArgument( + "co.dim_size(1) != cell_size: ", co_tensor->dim_size(1), + " vs. ", cell_size)); OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size, errors::InvalidArgument( @@ -860,9 +860,9 @@ class BlockLSTMOp : public OpKernel { h_prev_tensor->dim_size(0), " vs. ", batch_size)); OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, - errors::InvalidArgument("h_prev.dims(1) != cell_size: ", - h_prev_tensor->dim_size(1), " vs. ", - cell_size)); + errors::InvalidArgument( + "h_prev.dims(1) != cell_size: ", h_prev_tensor->dim_size(1), + " vs. ", cell_size)); const Tensor* w_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor)); @@ -872,46 +872,46 @@ class BlockLSTMOp : public OpKernel { errors::InvalidArgument( "w.dim_size(0) != input_size + cell_size: ", w_tensor->dim_size(0), " vs. ", input_size + cell_size)); - OP_REQUIRES( - ctx, w_tensor->dim_size(1) == cell_size * 4, - errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ", - w_tensor->dim_size(1), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, w_tensor->dim_size(1) == cell_size * 4, + errors::InvalidArgument( + "w.dim_size(1) != cell_size * 4: ", w_tensor->dim_size(1), + " vs. ", cell_size * 4)); const Tensor* wci_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor)); OP_REQUIRES(ctx, wci_tensor->dims() == 1, errors::InvalidArgument("wci must be 1D")); - OP_REQUIRES( - ctx, wci_tensor->dim_size(0) == cell_size, - errors::InvalidArgument("wci.dim_size(0) != cell_size: ", - wci_tensor->dim_size(0), " vs. ", cell_size)); + OP_REQUIRES(ctx, wci_tensor->dim_size(0) == cell_size, + errors::InvalidArgument( + "wci.dim_size(0) != cell_size: ", wci_tensor->dim_size(0), + " vs. ", cell_size)); const Tensor* wcf_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor)); OP_REQUIRES(ctx, wcf_tensor->dims() == 1, errors::InvalidArgument("wcf must be 1D")); - OP_REQUIRES( - ctx, wcf_tensor->dim_size(0) == cell_size, - errors::InvalidArgument("wcf.dim_size(0) != cell_size: ", - wcf_tensor->dim_size(0), " vs. ", cell_size)); + OP_REQUIRES(ctx, wcf_tensor->dim_size(0) == cell_size, + errors::InvalidArgument( + "wcf.dim_size(0) != cell_size: ", wcf_tensor->dim_size(0), + " vs. ", cell_size)); const Tensor* wco_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor)); OP_REQUIRES(ctx, wco_tensor->dims() == 1, errors::InvalidArgument("wco must be 1D")); - OP_REQUIRES( - ctx, wco_tensor->dim_size(0) == cell_size, - errors::InvalidArgument("wco.dim_size(0) != cell_size: ", - wco_tensor->dim_size(0), " vs. ", cell_size)); + OP_REQUIRES(ctx, wco_tensor->dim_size(0) == cell_size, + errors::InvalidArgument( + "wco.dim_size(0) != cell_size: ", wco_tensor->dim_size(0), + " vs. ", cell_size)); const Tensor* b_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor)); OP_REQUIRES(ctx, b_tensor->dims() == 1, errors::InvalidArgument("b must be 1D")); - OP_REQUIRES( - ctx, b_tensor->dim_size(0) == cell_size * 4, - errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ", - b_tensor->dim_size(0), " vs. ", cell_size * 4)); + OP_REQUIRES(ctx, b_tensor->dim_size(0) == cell_size * 4, + errors::InvalidArgument( + "b.dim_size(0) != cell_size * 4: ", b_tensor->dim_size(0), + " vs. ", cell_size * 4)); TensorShape batch_cell_shape({timelen, batch_size, cell_size}); Tensor* i_out; @@ -1065,9 +1065,9 @@ class BlockLSTMGradOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor)); const int64 cell_size = w_tensor->dim_size(1) / 4; OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0), - errors::InvalidArgument("w matrix rows don't match: ", - input_size + cell_size, " vs. ", - w_tensor->dim_size(0))); + errors::InvalidArgument( + "w matrix rows don't match: ", input_size + cell_size, + " vs. ", w_tensor->dim_size(0))); const Tensor* wci_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor)); @@ -1193,7 +1193,6 @@ class BlockLSTMGradOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::v(), batch_cell_shape, &h_grad_tensor)); - const Device& device = ctx->eigen_device(); functor::TensorZero()(device, cs_grad_tensor.flat()); diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h index bc6b85f3f1ab80b5ef5b4a8ba2e5242cf451adbe..d23cedc234b8c0e1a784346f28164ae79b8cbf89 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.h +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h @@ -92,7 +92,6 @@ struct TensorZeroPadding { } }; - struct LSTMBlockCell { LSTMBlockCell(const int batch_size, const int input_size, const int cell_size) : batch_size_(batch_size), diff --git a/tensorflow/contrib/rnn/ops/lstm_ops_test.cc b/tensorflow/contrib/rnn/ops/lstm_ops_test.cc index 544cd163c50062093acf7f5e942f67606936c0e3..68184b643e5e7a04ffecb804703051638514b7b2 100644 --- a/tensorflow/contrib/rnn/ops/lstm_ops_test.cc +++ b/tensorflow/contrib/rnn/ops/lstm_ops_test.cc @@ -149,8 +149,9 @@ TEST_F(LSTMOpsTest, BlockLSTMGrad_ShapeFn) { INFER_ERROR("must be rank 1", op, "?;?;?;?;?;?;?;?;[1,?]" + suffix); // Output with all input knowns makes known rank outputs. - INFER_OK(op, JoinedCopies("?", 18), "[?,?,?];" + JoinedCopies("[?,?]", 3) + - ";" + JoinedCopies("[?]", 4)); + INFER_OK( + op, JoinedCopies("?", 18), + "[?,?,?];" + JoinedCopies("[?,?]", 3) + ";" + JoinedCopies("[?]", 4)); // Output with copies input shapes to output. string input = strings::StrCat("?;[?,?,?];", JoinedCopies("[?,?]", 3), ";", diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index cafeb56ad88ba83fb42faf16db8ee1035da1deac..0e62b315b61cb3ceeb5cfd33bf5102a71abef83b 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -39,9 +39,6 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test -from tensorflow.python.framework import test_util -from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell - # pylint: enable=protected-access Linear = core_rnn_cell._Linear # pylint: disable=invalid-name @@ -84,19 +81,22 @@ class RNNCellTest(test.TestCase): ], [v.name for v in cell.trainable_variables]) self.assertFalse(cell.non_trainable_variables) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) self.assertEqual(res[0].shape, (1, 2)) def testBasicRNNCellNotTrainable(self): with self.test_session() as sess: + def not_trainable_getter(getter, *args, **kwargs): kwargs["trainable"] = False return getter(*args, **kwargs) with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5), + "root", + initializer=init_ops.constant_initializer(0.5), custom_getter=not_trainable_getter): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 2]) @@ -108,9 +108,10 @@ class RNNCellTest(test.TestCase): "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME ], [v.name for v in cell.non_trainable_variables]) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) self.assertEqual(res[0].shape, (1, 2)) def testGRUCell(self): @@ -121,9 +122,10 @@ class RNNCellTest(test.TestCase): m = array_ops.zeros([1, 2]) g, _ = rnn_cell_impl.GRUCell(2)(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) # Smoke test self.assertAllClose(res[0], [[0.175991, 0.175991]]) with variable_scope.variable_scope( @@ -133,10 +135,10 @@ class RNNCellTest(test.TestCase): m = array_ops.zeros([1, 2]) g, _ = rnn_cell_impl.GRUCell(2)(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], - {x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) # Smoke test self.assertAllClose(res[0], [[0.156736, 0.156736]]) @@ -148,11 +150,27 @@ class RNNCellTest(test.TestCase): m = array_ops.zeros([1, 2]) g, _ = contrib_rnn_cell.SRUCell(2)(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.509682, 0.509682]]) + + def testSRUCellWithDiffSize(self): + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.SRUCell(2)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) # Smoke test - self.assertAllClose(res[0], [[0.509682, 0.509682]]) + self.assertAllClose(res[0], [[0.55255556, 0.55255556]]) def testBasicLSTMCell(self): for dtype in [dtypes.float16, dtypes.float32]: @@ -164,8 +182,7 @@ class RNNCellTest(test.TestCase): m = array_ops.zeros([1, 8], dtype=dtype) cell = rnn_cell_impl.MultiRNNCell( [ - rnn_cell_impl.BasicLSTMCell( - 2, state_is_tuple=False) + rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) for _ in range(2) ], state_is_tuple=False) @@ -183,22 +200,21 @@ class RNNCellTest(test.TestCase): "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME ] - self.assertEqual( - expected_variable_names, - [v.name for v in cell.trainable_variables]) + self.assertEqual(expected_variable_names, + [v.name for v in cell.trainable_variables]) self.assertFalse(cell.non_trainable_variables) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], - {x.name: np.array([[1., 1.]]), - m.name: 0.1 * np.ones([1, 8])}) + res = sess.run([g, out_m], { + x.name: np.array([[1., 1.]]), + m.name: 0.1 * np.ones([1, 8]) + }) self.assertEqual(len(res), 2) variables = variables_lib.global_variables() self.assertEqual(expected_variable_names, [v.name for v in variables]) # The numbers in results were not calculated, this is just a # smoke test. - self.assertAllClose( - res[0], np.array([[0.240, 0.240]], dtype=np_dtype), 1e-2) + self.assertAllClose(res[0], np.array( + [[0.240, 0.240]], dtype=np_dtype), 1e-2) expected_mem = np.array( [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]], dtype=np_dtype) @@ -208,13 +224,13 @@ class RNNCellTest(test.TestCase): # Test BasicLSTMCell with input_size != num_units. x = array_ops.zeros([1, 3], dtype=dtype) m = array_ops.zeros([1, 4], dtype=dtype) - g, out_m = rnn_cell_impl.BasicLSTMCell( - 2, state_is_tuple=False)(x, m) + g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) res = sess.run( - [g, out_m], - {x.name: np.array([[1., 1., 1.]], dtype=np_dtype), - m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)}) + [g, out_m], { + x.name: np.array([[1., 1., 1.]], dtype=np_dtype), + m.name: 0.1 * np.ones([1, 4], dtype=np_dtype) + }) self.assertEqual(len(res), 2) def testBasicLSTMCellDimension0Error(self): @@ -232,9 +248,11 @@ class RNNCellTest(test.TestCase): g, out_m = rnn_cell_impl.BasicLSTMCell( num_units, state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) - sess.run([g, out_m], - {x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size - 1, state_size])}) + sess.run( + [g, out_m], { + x.name: 1 * np.ones([batch_size, input_size]), + m.name: 0.1 * np.ones([batch_size - 1, state_size]) + }) def testBasicLSTMCellStateSizeError(self): """Tests that state_size must be num_units * 2.""" @@ -251,9 +269,11 @@ class RNNCellTest(test.TestCase): g, out_m = rnn_cell_impl.BasicLSTMCell( num_units, state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) - sess.run([g, out_m], - {x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size, state_size])}) + sess.run( + [g, out_m], { + x.name: 1 * np.ones([batch_size, input_size]), + m.name: 0.1 * np.ones([batch_size, state_size]) + }) def testBasicLSTMCellStateTupleType(self): with self.test_session(): @@ -301,11 +321,12 @@ class RNNCellTest(test.TestCase): state_is_tuple=True) g, (out_m0, out_m1) = cell(x, (m0, m1)) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, out_m0, out_m1], { - x.name: np.array([[1., 1.]]), - m0.name: 0.1 * np.ones([1, 4]), - m1.name: 0.1 * np.ones([1, 4]) - }) + res = sess.run( + [g, out_m0, out_m1], { + x.name: np.array([[1., 1.]]), + m0.name: 0.1 * np.ones([1, 4]), + m1.name: 0.1 * np.ones([1, 4]) + }) self.assertEqual(len(res), 3) # The numbers in results were not calculated, this is just a smoke test. # Note, however, these values should match the original @@ -336,10 +357,11 @@ class RNNCellTest(test.TestCase): state_is_tuple=False) output, state = cell(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([output, state], { - x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]), - m.name: 0.1 * np.ones((batch_size, state_size)) - }) + res = sess.run( + [output, state], { + x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]), + m.name: 0.1 * np.ones((batch_size, state_size)) + }) self.assertEqual(len(res), 2) # The numbers in results were not calculated, this is mostly just a # smoke test. @@ -442,10 +464,10 @@ class RNNCellTest(test.TestCase): rnn_cell_impl.GRUCell(3), num_proj=3) g, new_m = cell(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, new_m], - {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1]])}) + res = sess.run([g, new_m], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]]) + }) self.assertEqual(res[1].shape, (1, 3)) # The numbers in results were not calculated, this is just a smoke test. self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) @@ -479,9 +501,11 @@ class RNNCellTest(test.TestCase): base_cell = rnn_cell_impl.GRUCell(3) g, m_new = base_cell(x, m) variable_scope.get_variable_scope().reuse_variables() + def residual_with_slice_fn(inp, out): inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3]) return inp_sliced + out + g_res, m_new_res = rnn_cell_impl.ResidualWrapper( base_cell, residual_with_slice_fn)(x, m) sess.run([variables_lib.global_variables_initializer()]) @@ -551,10 +575,10 @@ class RNNCellTest(test.TestCase): self.assertEqual(embedding_cell.output_size, 2) g, new_m = embedding_cell(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, new_m], - {x.name: np.array([[1]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g, new_m], { + x.name: np.array([[1]]), + m.name: np.array([[0.1, 0.1]]) + }) self.assertEqual(res[1].shape, (1, 2)) # The numbers in results were not calculated, this is just a smoke test. self.assertAllClose(res[0], [[0.17139, 0.17139]]) @@ -584,8 +608,8 @@ class RNNCellTest(test.TestCase): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 4]) _, ml = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) - for _ in range(2)], state_is_tuple=False)(x, m) + [rnn_cell_impl.GRUCell(2) for _ in range(2)], + state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) res = sess.run(ml, { x.name: np.array([[1., 1.]]), @@ -605,19 +629,20 @@ class RNNCellTest(test.TestCase): # Test incorrectness of state with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"): rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) - for _ in range(2)], state_is_tuple=True)(x, m_bad) + [rnn_cell_impl.GRUCell(2) for _ in range(2)], + state_is_tuple=True)(x, m_bad) _, ml = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) - for _ in range(2)], state_is_tuple=True)(x, m_good) + [rnn_cell_impl.GRUCell(2) for _ in range(2)], + state_is_tuple=True)(x, m_good) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run(ml, { - x.name: np.array([[1., 1.]]), - m_good[0].name: np.array([[0.1, 0.1]]), - m_good[1].name: np.array([[0.1, 0.1]]) - }) + res = sess.run( + ml, { + x.name: np.array([[1., 1.]]), + m_good[0].name: np.array([[0.1, 0.1]]), + m_good[1].name: np.array([[0.1, 0.1]]) + }) # The numbers in results were not calculated, this is just a # smoke test. However, these numbers should match those of @@ -628,8 +653,11 @@ class RNNCellTest(test.TestCase): class DropoutWrapperTest(test.TestCase): - def _testDropoutWrapper(self, batch_size=None, time_steps=None, - parallel_iterations=None, **kwargs): + def _testDropoutWrapper(self, + batch_size=None, + time_steps=None, + parallel_iterations=None, + **kwargs): with self.test_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): @@ -640,14 +668,14 @@ class DropoutWrapperTest(test.TestCase): x = constant_op.constant( [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32) m = rnn_cell_impl.LSTMStateTuple( - *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32) - ] * 2) + *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32 + )] * 2) else: x = constant_op.constant( np.random.randn(time_steps, batch_size, 3).astype(np.float32)) m = rnn_cell_impl.LSTMStateTuple(*[ - constant_op.constant( - [[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32) + constant_op. + constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32) ] * 2) outputs, final_state = rnn.dynamic_rnn( cell=rnn_cell_impl.DropoutWrapper( @@ -674,8 +702,8 @@ class DropoutWrapperTest(test.TestCase): res = self._testDropoutWrapper( input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep) true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], - [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) + [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], + dtype=np.float32) true_full_final_c = np.array( [[1.949385, 1.949385, 1.949385]], dtype=np.float32) self.assertAllClose(true_full_output, res[0]) @@ -687,8 +715,8 @@ class DropoutWrapperTest(test.TestCase): res = self._testDropoutWrapper( input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep) true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], - [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) + [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], + dtype=np.float32) true_full_final_c = np.array( [[1.949385, 1.949385, 1.949385]], dtype=np.float32) self.assertAllClose(true_full_output, res[0]) @@ -703,16 +731,20 @@ class DropoutWrapperTest(test.TestCase): ## consistent across both calls. Otherwise the seed may not end ## up being munged consistently across both graphs. res_standard_1 = self._testDropoutWrapper( - input_keep_prob=keep_some, output_keep_prob=keep_some, - state_keep_prob=keep_some, seed=10, + input_keep_prob=keep_some, + output_keep_prob=keep_some, + state_keep_prob=keep_some, + seed=10, parallel_iterations=1) # Clear away the graph and the test session (which keeps variables around) ops.reset_default_graph() self._ClearCachedSession() random_seed.set_random_seed(2) res_standard_2 = self._testDropoutWrapper( - input_keep_prob=keep_some, output_keep_prob=keep_some, - state_keep_prob=keep_some, seed=10, + input_keep_prob=keep_some, + output_keep_prob=keep_some, + state_keep_prob=keep_some, + seed=10, parallel_iterations=1) self.assertAllClose(res_standard_1[0], res_standard_2[0]) self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c) @@ -722,11 +754,12 @@ class DropoutWrapperTest(test.TestCase): keep_all = variable_scope.get_variable("all", initializer=1.0) keep_none = variable_scope.get_variable("none", initializer=1e-10) res = self._testDropoutWrapper( - input_keep_prob=keep_all, output_keep_prob=keep_none, + input_keep_prob=keep_all, + output_keep_prob=keep_none, state_keep_prob=keep_all) true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], - [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) + [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], + dtype=np.float32) true_full_final_c = np.array( [[1.949385, 1.949385, 1.949385]], dtype=np.float32) self.assertAllClose(np.zeros(res[0].shape), res[0]) @@ -739,13 +772,13 @@ class DropoutWrapperTest(test.TestCase): # Even though we dropout state, by default DropoutWrapper never # drops out the memory ("c") term of an LSTMStateTuple. res = self._testDropoutWrapper( - input_keep_prob=keep_all, output_keep_prob=keep_all, + input_keep_prob=keep_all, + output_keep_prob=keep_all, state_keep_prob=keep_none) - true_c_state = np.array( - [[1.713925, 1.713925, 1.713925]], dtype=np.float32) + true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32) true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], - [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) + [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], + dtype=np.float32) self.assertAllClose(true_full_output[0], res[0][0]) # Second output is modified by zero input state self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4) @@ -758,13 +791,14 @@ class DropoutWrapperTest(test.TestCase): keep_all = variable_scope.get_variable("all", initializer=1.0) keep_none = variable_scope.get_variable("none", initializer=1e-10) true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], - [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) + [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], + dtype=np.float32) true_full_final_c = np.array( [[1.949385, 1.949385, 1.949385]], dtype=np.float32) # All outputs are different because inputs are zeroed out res = self._testDropoutWrapper( - input_keep_prob=keep_none, output_keep_prob=keep_all, + input_keep_prob=keep_none, + output_keep_prob=keep_all, state_keep_prob=keep_all) self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4) self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4) @@ -774,9 +808,13 @@ class DropoutWrapperTest(test.TestCase): keep_some = 0.8 keep_all = variable_scope.get_variable("all", initializer=1.0) res = self._testDropoutWrapper( - input_keep_prob=keep_all, output_keep_prob=keep_some, - state_keep_prob=keep_all, variational_recurrent=True, - input_size=3, batch_size=5, time_steps=7) + input_keep_prob=keep_all, + output_keep_prob=keep_some, + state_keep_prob=keep_all, + variational_recurrent=True, + input_size=3, + batch_size=5, + time_steps=7) # Ensure the same dropout pattern for all time steps output_mask = np.abs(res[0]) > 1e-6 for m in output_mask[1:]: @@ -785,9 +823,13 @@ class DropoutWrapperTest(test.TestCase): def testDropoutWrapperRecurrentStateInputAndOutput(self): keep_some = 0.9 res = self._testDropoutWrapper( - input_keep_prob=keep_some, output_keep_prob=keep_some, - state_keep_prob=keep_some, variational_recurrent=True, - input_size=3, batch_size=5, time_steps=7) + input_keep_prob=keep_some, + output_keep_prob=keep_some, + state_keep_prob=keep_some, + variational_recurrent=True, + input_size=3, + batch_size=5, + time_steps=7) # Smoke test for the state/input masks. output_mask = np.abs(res[0]) > 1e-6 @@ -811,17 +853,27 @@ class DropoutWrapperTest(test.TestCase): random_seed.set_random_seed(2347) np.random.seed(23487) res0 = self._testDropoutWrapper( - input_keep_prob=keep_some, output_keep_prob=keep_some, - state_keep_prob=keep_some, variational_recurrent=True, - input_size=3, batch_size=5, time_steps=7, seed=-234987) + input_keep_prob=keep_some, + output_keep_prob=keep_some, + state_keep_prob=keep_some, + variational_recurrent=True, + input_size=3, + batch_size=5, + time_steps=7, + seed=-234987) ops.reset_default_graph() self._ClearCachedSession() random_seed.set_random_seed(2347) np.random.seed(23487) res1 = self._testDropoutWrapper( - input_keep_prob=keep_some, output_keep_prob=keep_some, - state_keep_prob=keep_some, variational_recurrent=True, - input_size=3, batch_size=5, time_steps=7, seed=-234987) + input_keep_prob=keep_some, + output_keep_prob=keep_some, + state_keep_prob=keep_some, + variational_recurrent=True, + input_size=3, + batch_size=5, + time_steps=7, + seed=-234987) output_mask = np.abs(res0[0]) > 1e-6 for time_step in output_mask: @@ -858,9 +910,10 @@ class SlimRNNCellTest(test.TestCase): g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m) # pylint: enable=protected-access sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) self.assertEqual(res[0].shape, (1, 2)) def testBasicRNNCellMatch(self): diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 0258d7202df20a536ae4240a532249b6b5e7e641..57521c6a9ba0b2d66639017b09c541e270276323 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -45,6 +45,7 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging from tensorflow.python.util import nest + class Plus1RNNCell(rnn_lib.RNNCell): """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" @@ -160,8 +161,7 @@ class RNNTest(test.TestCase): input_size = 5 max_length = 8 # unrolled up to this length inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) @@ -178,10 +178,9 @@ class RNNTest(test.TestCase): self.assertAllClose(v, input_value + 1.0) # Final state - self.assertAllClose( - values[-1], - max_length * np.ones( - (batch_size, input_size), dtype=np.float32)) + self.assertAllClose(values[-1], + max_length * np.ones( + (batch_size, input_size), dtype=np.float32)) def testDropout(self): cell = Plus1RNNCell() @@ -191,8 +190,7 @@ class RNNTest(test.TestCase): input_size = 5 max_length = 8 inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] with variable_scope.variable_scope("share_scope"): outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) @@ -207,8 +205,10 @@ class RNNTest(test.TestCase): with self.test_session(use_gpu=True) as sess: input_value = np.random.randn(batch_size, input_size) values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) - full_dropout_values = sess.run(dropped_outputs, - feed_dict={inputs[0]: input_value}) + full_dropout_values = sess.run( + dropped_outputs, feed_dict={ + inputs[0]: input_value + }) for v in values[:-1]: self.assertAllClose(v, input_value + 1.0) @@ -222,8 +222,7 @@ class RNNTest(test.TestCase): input_size = 5 max_length = 8 inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] with variable_scope.variable_scope("drop_scope"): dynamic_outputs, dynamic_state = rnn.static_rnn( @@ -234,12 +233,16 @@ class RNNTest(test.TestCase): input_value = np.random.randn(batch_size, input_size) dynamic_values = sess.run( dynamic_outputs, - feed_dict={inputs[0]: input_value, - sequence_length: [2, 3]}) + feed_dict={ + inputs[0]: input_value, + sequence_length: [2, 3] + }) dynamic_state_value = sess.run( [dynamic_state], - feed_dict={inputs[0]: input_value, - sequence_length: [2, 3]}) + feed_dict={ + inputs[0]: input_value, + sequence_length: [2, 3] + }) # outputs are fully calculated for t = 0, 1 for v in dynamic_values[:2]: @@ -289,8 +292,7 @@ class RNNTest(test.TestCase): input_size = 5 max_length = 8 # unrolled up to this length inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] return rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope=scope) @@ -316,8 +318,7 @@ class LSTMTest(test.TestCase): cell = rnn_cell.LSTMCell( num_units, initializer=initializer, state_is_tuple=False) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) @@ -343,8 +344,7 @@ class LSTMTest(test.TestCase): initializer=initializer, state_is_tuple=False) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) @@ -374,8 +374,7 @@ class LSTMTest(test.TestCase): initializer=initializer, state_is_tuple=False) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] with variable_scope.variable_scope("share_scope"): outputs, state = rnn.static_state_saving_rnn( @@ -388,7 +387,9 @@ class LSTMTest(test.TestCase): input_value = np.random.randn(batch_size, input_size) (last_state_value, saved_state_value) = sess.run( [state, state_saver.saved_state["save_lstm"]], - feed_dict={inputs[0]: input_value}) + feed_dict={ + inputs[0]: input_value + }) self.assertAllEqual(last_state_value, saved_state_value) def testNoProjNoShardingTupleStateSaver(self): @@ -406,8 +407,7 @@ class LSTMTest(test.TestCase): initializer=initializer, state_is_tuple=True) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] with variable_scope.variable_scope("share_scope"): outputs, state = rnn.static_state_saving_rnn( @@ -420,7 +420,9 @@ class LSTMTest(test.TestCase): input_value = np.random.randn(batch_size, input_size) last_and_saved_states = sess.run( state + (state_saver.saved_state["c"], state_saver.saved_state["m"]), - feed_dict={inputs[0]: input_value}) + feed_dict={ + inputs[0]: input_value + }) self.assertEqual(4, len(last_and_saved_states)) self.assertAllEqual(last_and_saved_states[:2], last_and_saved_states[2:]) @@ -432,16 +434,17 @@ class LSTMTest(test.TestCase): with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - state_saver = TestStateSaver(batch_size, { - "c0": num_units, - "m0": num_units, - "c1": num_units + 1, - "m1": num_units + 1, - "c2": num_units + 2, - "m2": num_units + 2, - "c3": num_units + 3, - "m3": num_units + 3 - }) + state_saver = TestStateSaver( + batch_size, { + "c0": num_units, + "m0": num_units, + "c1": num_units + 1, + "m1": num_units + 1, + "c2": num_units + 2, + "m2": num_units + 2, + "c3": num_units + 3, + "m3": num_units + 3 + }) def _cell(i): return rnn_cell.LSTMCell( @@ -459,8 +462,7 @@ class LSTMTest(test.TestCase): self.assertEqual(len(cell.state_size[i]), 2) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] state_names = (("c0", "m0"), ("c1", "m1"), ("c2", "m2"), ("c3", "m3")) @@ -475,10 +477,15 @@ class LSTMTest(test.TestCase): variables_lib.global_variables_initializer().run() input_value = np.random.randn(batch_size, input_size) - last_states = sess.run(list(nest.flatten(state)), - feed_dict={inputs[0]: input_value}) - saved_states = sess.run(list(state_saver.saved_state.values()), - feed_dict={inputs[0]: input_value}) + last_states = sess.run( + list(nest.flatten(state)), feed_dict={ + inputs[0]: input_value + }) + saved_states = sess.run( + list(state_saver.saved_state.values()), + feed_dict={ + inputs[0]: input_value + }) self.assertEqual(8, len(last_states)) self.assertEqual(8, len(saved_states)) flat_state_names = nest.flatten(state_names) @@ -499,8 +506,7 @@ class LSTMTest(test.TestCase): initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) + array_ops.placeholder(dtypes.float32, shape=(None, input_size)) ] cell = rnn_cell.LSTMCell( num_units, @@ -526,8 +532,7 @@ class LSTMTest(test.TestCase): initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) + array_ops.placeholder(dtypes.float32, shape=(None, input_size)) ] cell_notuple = rnn_cell.LSTMCell( num_units, @@ -569,14 +574,20 @@ class LSTMTest(test.TestCase): variables_lib.global_variables_initializer().run() input_value = np.random.randn(batch_size, input_size) - outputs_notuple_v = sess.run(outputs_notuple, - feed_dict={inputs[0]: input_value}) - outputs_tuple_v = sess.run(outputs_tuple, - feed_dict={inputs[0]: input_value}) + outputs_notuple_v = sess.run( + outputs_notuple, feed_dict={ + inputs[0]: input_value + }) + outputs_tuple_v = sess.run( + outputs_tuple, feed_dict={ + inputs[0]: input_value + }) self.assertAllEqual(outputs_notuple_v, outputs_tuple_v) - (state_notuple_v,) = sess.run((state_notuple,), - feed_dict={inputs[0]: input_value}) + (state_notuple_v,) = sess.run( + (state_notuple,), feed_dict={ + inputs[0]: input_value + }) state_tuple_v = sess.run(state_tuple, feed_dict={inputs[0]: input_value}) self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v)) @@ -593,8 +604,7 @@ class LSTMTest(test.TestCase): -0.01, 0.01, seed=self._seed) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) + array_ops.placeholder(dtypes.float32, shape=(None, input_size)) ] cell = rnn_cell.LSTMCell( @@ -625,8 +635,7 @@ class LSTMTest(test.TestCase): with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed) inputs = max_length * [ - array_ops.placeholder( - dtypes.float64, shape=(None, input_size)) + array_ops.placeholder(dtypes.float64, shape=(None, input_size)) ] cell = rnn_cell.LSTMCell( @@ -661,8 +670,7 @@ class LSTMTest(test.TestCase): max_length = 8 with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) + array_ops.placeholder(dtypes.float32, shape=(None, input_size)) ] initializer = init_ops.constant_initializer(0.001) @@ -721,8 +729,7 @@ class LSTMTest(test.TestCase): initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) inputs = max_length * [ - array_ops.placeholder( - dtypes.float64, shape=(None, input_size)) + array_ops.placeholder(dtypes.float64, shape=(None, input_size)) ] cell = rnn_cell.LSTMCell( @@ -743,16 +750,21 @@ class LSTMTest(test.TestCase): self.assertEqual(len(outputs), len(inputs)) - variables_lib.global_variables_initializer().run( - feed_dict={sequence_length: [2, 3]}) + variables_lib.global_variables_initializer().run(feed_dict={ + sequence_length: [2, 3] + }) input_value = np.asarray( np.random.randn(batch_size, input_size), dtype=np.float64) values = sess.run( - outputs, feed_dict={inputs[0]: input_value, - sequence_length: [2, 3]}) + outputs, feed_dict={ + inputs[0]: input_value, + sequence_length: [2, 3] + }) state_value = sess.run( - [state], feed_dict={inputs[0]: input_value, - sequence_length: [2, 3]}) + [state], feed_dict={ + inputs[0]: input_value, + sequence_length: [2, 3] + }) self.assertEqual(values[0].dtype, input_value.dtype) self.assertEqual(state_value[0].dtype, input_value.dtype) @@ -767,8 +779,7 @@ class LSTMTest(test.TestCase): initializer_d = init_ops.random_uniform_initializer( -1, 1, seed=self._seed + 1) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) + array_ops.placeholder(dtypes.float32, shape=(None, input_size)) ] cell = rnn_cell.LSTMCell( num_units, @@ -792,8 +803,10 @@ class LSTMTest(test.TestCase): variables_lib.global_variables_initializer().run() input_value = np.random.randn(batch_size, input_size) - output_values = sess.run(outputs0 + outputs1 + outputs2, - feed_dict={inputs[0]: input_value}) + output_values = sess.run( + outputs0 + outputs1 + outputs2, feed_dict={ + inputs[0]: input_value + }) outputs0_values = output_values[:max_length] outputs1_values = output_values[max_length:2 * max_length] outputs2_values = output_values[2 * max_length:] @@ -814,8 +827,7 @@ class LSTMTest(test.TestCase): with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) + array_ops.placeholder(dtypes.float32, shape=(None, input_size)) ] cell = rnn_cell.LSTMCell( num_units, @@ -833,8 +845,10 @@ class LSTMTest(test.TestCase): variables_lib.global_variables_initializer().run() input_value = np.random.randn(batch_size, input_size) - output_values = sess.run(outputs0 + outputs1, - feed_dict={inputs[0]: input_value}) + output_values = sess.run( + outputs0 + outputs1, feed_dict={ + inputs[0]: input_value + }) outputs0_values = output_values[:max_length] outputs1_values = output_values[max_length:] self.assertEqual(len(outputs0_values), len(outputs1_values)) @@ -861,8 +875,7 @@ class LSTMTest(test.TestCase): -0.01, 0.01, seed=self._seed) if in_graph_mode: inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) + array_ops.placeholder(dtypes.float32, shape=(None, input_size)) ] else: inputs = max_length * [ @@ -939,8 +952,7 @@ class LSTMTest(test.TestCase): -0.01, 0.01, seed=self._seed) if in_graph_mode: inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) + array_ops.placeholder(dtypes.float32, shape=(None, input_size)) ] else: inputs = max_length * [ @@ -1100,8 +1112,8 @@ class LSTMTest(test.TestCase): # Test gradients to inputs and variables w.r.t. outputs & final state static_grad_values = sess.run(static_gradients, feed_dict=feeds) - static_individual_grad_values = sess.run(static_individual_gradients, - feed_dict=feeds) + static_individual_grad_values = sess.run( + static_individual_gradients, feed_dict=feeds) static_individual_var_grad_values = sess.run( static_individual_variable_gradients, feed_dict=feeds) @@ -1148,8 +1160,10 @@ class LSTMTest(test.TestCase): # Generate gradients of several individual outputs w.r.t. inputs dynamic_individual_gradients = nest.flatten([ gradients_impl.gradients(y, [concat_inputs]) - for y in - [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] + for y in [ + split_outputs_dynamic[0], split_outputs_dynamic[-1], + state_dynamic + ] ]) # Generate gradients of individual variables w.r.t. inputs @@ -1159,8 +1173,10 @@ class LSTMTest(test.TestCase): "Count of trainable variables: %d" % len(trainable_variables)) dynamic_individual_variable_gradients = nest.flatten([ gradients_impl.gradients(y, trainable_variables) - for y in - [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] + for y in [ + split_outputs_dynamic[0], split_outputs_dynamic[-1], + state_dynamic + ] ]) # Test forward pass @@ -1170,8 +1186,8 @@ class LSTMTest(test.TestCase): # Test gradients to inputs and variables w.r.t. outputs & final state dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds) - dynamic_individual_grad_values = sess.run(dynamic_individual_gradients, - feed_dict=feeds) + dynamic_individual_grad_values = sess.run( + dynamic_individual_gradients, feed_dict=feeds) dynamic_individual_var_grad_values = sess.run( dynamic_individual_variable_gradients, feed_dict=feeds) @@ -1207,8 +1223,8 @@ class LSTMTest(test.TestCase): for i, (a, b) in enumerate( zip(static_individual_var_grad_values, dynamic_individual_var_grad_values)): - tf_logging.info("Comparing individual variable gradients iteration %d" % - i) + tf_logging.info( + "Comparing individual variable gradients iteration %d" % i) self.assertAllEqual(a, b) @test_util.run_in_graph_and_eager_modes() @@ -1223,10 +1239,7 @@ class BidirectionalRNNTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - def _createBidirectionalRNN(self, - use_shape, - use_sequence_length, - scope=None): + def _createBidirectionalRNN(self, use_shape, use_sequence_length, scope=None): num_units = 3 input_size = 5 batch_size = 2 @@ -1270,8 +1283,10 @@ class BidirectionalRNNTest(test.TestCase): # Run with pre-specified sequence length of 2, 3 out, s_fw, s_bw = sess.run( [outputs, state_fw, state_bw], - feed_dict={inputs[0]: input_value, - sequence_length: [2, 3]}) + feed_dict={ + inputs[0]: input_value, + sequence_length: [2, 3] + }) # Since the forward and backward LSTM cells were initialized with the # same parameters, the forward and backward output has to be the same, @@ -1312,8 +1327,10 @@ class BidirectionalRNNTest(test.TestCase): input_value, inputs, outputs, state_fw, state_bw, _ = ( self._createBidirectionalRNN(use_shape, False)) variables_lib.global_variables_initializer().run() - out, s_fw, s_bw = sess.run([outputs, state_fw, state_bw], - feed_dict={inputs[0]: input_value}) + out, s_fw, s_bw = sess.run( + [outputs, state_fw, state_bw], feed_dict={ + inputs[0]: input_value + }) # Since the forward and backward LSTM cells were initialized with the # same parameters, the forward and backward output has to be the same, @@ -1396,13 +1413,11 @@ class BidirectionalRNNTest(test.TestCase): use_time_major, use_sequence_length): with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: input_value, inputs, outputs, state_fw, state_bw, sequence_length = ( - self._createBidirectionalDynamicRNN(use_shape, - use_state_tuple, use_time_major, - use_sequence_length)) + self._createBidirectionalDynamicRNN( + use_shape, use_state_tuple, use_time_major, use_sequence_length)) variables_lib.global_variables_initializer().run() # Run with pre-specified sequence length of 2, 3 - feed_dict = ( - {sequence_length: [2, 3]} if use_sequence_length else {}) + feed_dict = ({sequence_length: [2, 3]} if use_sequence_length else {}) feed_dict.update({inputs[0]: input_value}) if use_state_tuple: out, c_fw, m_fw, c_bw, m_bw = sess.run( @@ -1538,8 +1553,7 @@ class MultiDimensionalLSTMTest(test.TestCase): sequence_length = [4, 6] with self.test_session(graph=ops_lib.Graph()) as sess: inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None,) + input_size) + array_ops.placeholder(dtypes.float32, shape=(None,) + input_size) ] inputs_using_dim = max_length * [ array_ops.placeholder( @@ -1585,14 +1599,22 @@ class MultiDimensionalLSTMTest(test.TestCase): input_total_size = (batch_size,) + input_size input_value = np.random.randn(*input_total_size) - outputs_static_v = sess.run(outputs_static, - feed_dict={inputs[0]: input_value}) - outputs_dynamic_v = sess.run(outputs_dynamic, - feed_dict={inputs[0]: input_value}) - outputs_bid_v = sess.run(outputs_bid, - feed_dict={inputs_using_dim[0]: input_value}) - outputs_sav_v = sess.run(outputs_sav, - feed_dict={inputs_using_dim[0]: input_value}) + outputs_static_v = sess.run( + outputs_static, feed_dict={ + inputs[0]: input_value + }) + outputs_dynamic_v = sess.run( + outputs_dynamic, feed_dict={ + inputs[0]: input_value + }) + outputs_bid_v = sess.run( + outputs_bid, feed_dict={ + inputs_using_dim[0]: input_value + }) + outputs_sav_v = sess.run( + outputs_sav, feed_dict={ + inputs_using_dim[0]: input_value + }) self.assertAllEqual(outputs_static_v, outputs_dynamic_v) self.assertAllEqual(outputs_static_v, outputs_sav_v) @@ -1602,16 +1624,26 @@ class MultiDimensionalLSTMTest(test.TestCase): outputs_bid_array = np.array(outputs_bid_v) self.assertAllEqual(outputs_static_array_double, outputs_bid_array) - state_static_v = sess.run(state_static, - feed_dict={inputs[0]: input_value}) - state_dynamic_v = sess.run(state_dynamic, - feed_dict={inputs[0]: input_value}) - state_bid_fw_v = sess.run(state_fw, - feed_dict={inputs_using_dim[0]: input_value}) - state_bid_bw_v = sess.run(state_bw, - feed_dict={inputs_using_dim[0]: input_value}) - state_sav_v = sess.run(state_sav, - feed_dict={inputs_using_dim[0]: input_value}) + state_static_v = sess.run( + state_static, feed_dict={ + inputs[0]: input_value + }) + state_dynamic_v = sess.run( + state_dynamic, feed_dict={ + inputs[0]: input_value + }) + state_bid_fw_v = sess.run( + state_fw, feed_dict={ + inputs_using_dim[0]: input_value + }) + state_bid_bw_v = sess.run( + state_bw, feed_dict={ + inputs_using_dim[0]: input_value + }) + state_sav_v = sess.run( + state_sav, feed_dict={ + inputs_using_dim[0]: input_value + }) self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v)) self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v)) @@ -1633,16 +1665,17 @@ class NestedLSTMTest(test.TestCase): with self.test_session(graph=ops_lib.Graph()) as sess: state_saver = TestStateSaver(batch_size, state_size) single_input = (array_ops.placeholder( - dtypes.float32, shape=(None, input_size)), array_ops.placeholder( - dtypes.float32, shape=(None, input_size))) + dtypes.float32, shape=(None, input_size)), + array_ops.placeholder( + dtypes.float32, shape=(None, input_size))) inputs = max_length * [single_input] inputs_c = (array_ops.stack([input_[0] for input_ in inputs]), array_ops.stack([input_[1] for input_ in inputs])) - single_input_using_dim = ( - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)), - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size))) + single_input_using_dim = (array_ops.placeholder( + dtypes.float32, shape=(batch_size, input_size)), + array_ops.placeholder( + dtypes.float32, + shape=(batch_size, input_size))) inputs_using_dim = max_length * [single_input_using_dim] # Create a cell for the whole test. This is fine because the cell has no @@ -1688,14 +1721,22 @@ class NestedLSTMTest(test.TestCase): input_total_size = (batch_size, input_size) input_value = (np.random.randn(*input_total_size), np.random.randn(*input_total_size)) - outputs_dynamic_v = sess.run(outputs_dynamic, - feed_dict={single_input: input_value}) - outputs_static_v = sess.run(outputs_static, - feed_dict={single_input: input_value}) - outputs_sav_v = sess.run(outputs_sav, - feed_dict={single_input_using_dim: input_value}) - outputs_bid_v = sess.run(outputs_bid, - feed_dict={single_input_using_dim: input_value}) + outputs_dynamic_v = sess.run( + outputs_dynamic, feed_dict={ + single_input: input_value + }) + outputs_static_v = sess.run( + outputs_static, feed_dict={ + single_input: input_value + }) + outputs_sav_v = sess.run( + outputs_sav, feed_dict={ + single_input_using_dim: input_value + }) + outputs_bid_v = sess.run( + outputs_bid, feed_dict={ + single_input_using_dim: input_value + }) self.assertAllEqual(outputs_static_v, np.transpose(outputs_dynamic_v, (1, 0, 2, 3))) @@ -1706,16 +1747,26 @@ class NestedLSTMTest(test.TestCase): outputs_bid_array = np.array(outputs_bid_v) self.assertAllEqual(outputs_static_array_double, outputs_bid_array) - state_dynamic_v = sess.run(state_dynamic, - feed_dict={single_input: input_value}) - state_static_v = sess.run(state_static, - feed_dict={single_input: input_value}) - state_bid_fw_v = sess.run(state_fw, - feed_dict={single_input_using_dim: input_value}) - state_bid_bw_v = sess.run(state_bw, - feed_dict={single_input_using_dim: input_value}) - state_sav_v = sess.run(state_sav, - feed_dict={single_input_using_dim: input_value}) + state_dynamic_v = sess.run( + state_dynamic, feed_dict={ + single_input: input_value + }) + state_static_v = sess.run( + state_static, feed_dict={ + single_input: input_value + }) + state_bid_fw_v = sess.run( + state_fw, feed_dict={ + single_input_using_dim: input_value + }) + state_bid_bw_v = sess.run( + state_bw, feed_dict={ + single_input_using_dim: input_value + }) + state_sav_v = sess.run( + state_sav, feed_dict={ + single_input_using_dim: input_value + }) self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v)) self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v)) @@ -1764,8 +1815,7 @@ class StateSaverRNNTest(test.TestCase): initializer=initializer, state_is_tuple=False) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] return rnn.static_state_saving_rnn( cell, @@ -1931,8 +1981,10 @@ class RawRNNTest(test.TestCase): (outputs_val, outputs_dynamic_rnn_val, final_state_val, final_state_dynamic_rnn_val) = sess.run( [outputs, outputs_dynamic_rnn, final_state, final_state_dynamic_rnn], - feed_dict={inputs: rand_input, - sequence_length: rand_seq_len}) + feed_dict={ + inputs: rand_input, + sequence_length: rand_seq_len + }) self.assertAllClose(outputs_dynamic_rnn_val, outputs_val) self.assertAllClose(final_state_dynamic_rnn_val, final_state_val) @@ -1945,12 +1997,16 @@ class RawRNNTest(test.TestCase): self.assertEqual(len(gradients), len(gradients_dynamic_rnn)) gradients_val = sess.run( gradients, - feed_dict={inputs: rand_input, - sequence_length: rand_seq_len}) + feed_dict={ + inputs: rand_input, + sequence_length: rand_seq_len + }) gradients_dynamic_rnn_val = sess.run( gradients_dynamic_rnn, - feed_dict={inputs: rand_input, - sequence_length: rand_seq_len}) + feed_dict={ + inputs: rand_input, + sequence_length: rand_seq_len + }) self.assertEqual(len(gradients_val), len(gradients_dynamic_rnn_val)) input_gradients_val = gradients_val[0] input_gradients_dynamic_rnn_val = gradients_dynamic_rnn_val[0] @@ -2067,14 +2123,13 @@ class RawRNNTest(test.TestCase): def loop_fn(time_, cell_output, cell_state, _): if cell_output is None: - emit_output = (array_ops.zeros( - [2, 3], dtype=dtypes.int32), array_ops.zeros( - [unknown_dim], dtype=dtypes.int64)) + emit_output = (array_ops.zeros([2, 3], dtype=dtypes.int32), + array_ops.zeros([unknown_dim], dtype=dtypes.int64)) next_state = cell.zero_state(batch_size, dtypes.float32) else: - emit_output = (array_ops.ones( - [batch_size, 2, 3], dtype=dtypes.int32), array_ops.ones( - [batch_size, unknown_dim], dtype=dtypes.int64)) + emit_output = (array_ops.ones([batch_size, 2, 3], dtype=dtypes.int32), + array_ops.ones( + [batch_size, unknown_dim], dtype=dtypes.int64)) next_state = cell_state elements_finished = array_ops.tile([time_ >= max_time], [batch_size]) finished = math_ops.reduce_all(elements_finished) @@ -2193,8 +2248,8 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): cell = rnn_cell.LSTMCell(num_units, use_peepholes=True) gpu_cell = DeviceWrapperCell(cell, cell_device) - inputs = np.random.randn(batch_size, time_steps, - input_size).astype(np.float32) + inputs = np.random.randn(batch_size, time_steps, input_size).astype( + np.float32) sequence_length = np.random.randint(0, time_steps, size=batch_size) if input_device is not None: @@ -2262,8 +2317,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): gpu_dev = test.gpu_device_name() run_metadata = self._execute_rnn_on( - rnn_device="/cpu:0", cell_device="/cpu:0", - input_device=gpu_dev) + rnn_device="/cpu:0", cell_device="/cpu:0", input_device=gpu_dev) cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): @@ -2278,8 +2332,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): return # Test requires access to a GPU gpu_dev = test.gpu_device_name() - run_metadata = self._execute_rnn_on( - input_device=gpu_dev) + run_metadata = self._execute_rnn_on(input_device=gpu_dev) cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 8a3894ef9d7042e66b52edefdf08b278dcc6c4f4..7b883ebc5d7756f1bdf445f900500a4b89e6cffd 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -1545,97 +1545,6 @@ class BenchmarkLSTMCellXLA(test.Benchmark): ])) -class WeightNormLSTMCellTest(test.TestCase): - """Compared cell output with pre-calculated values.""" - - def _cell_output(self, cell): - """Calculate cell output""" - - with self.test_session() as sess: - init = init_ops.constant_initializer(0.5) - with variable_scope.variable_scope("root", initializer=init): - x = array_ops.zeros([1, 2]) - c0 = array_ops.zeros([1, 2]) - h0 = array_ops.zeros([1, 2]) - - state0 = rnn_cell.LSTMStateTuple(c0, h0) - - xout, sout = cell()(x, state0) - - sess.run([variables.global_variables_initializer()]) - res = sess.run( - [xout, sout], { - x.name: np.array([[1., 1.]]), - c0.name: 0.1 * np.asarray([[0, 1]]), - h0.name: 0.1 * np.asarray([[2, 3]]), - }) - - actual_state_c = res[1].c - actual_state_h = res[1].h - - return actual_state_c, actual_state_h - - def testBasicCell(self): - """Tests cell w/o peepholes and w/o normalisation""" - - def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=False, use_peepholes=False) - - actual_c, actual_h = self._cell_output(cell) - - expected_c = np.array([[0.65937078, 0.74983585]]) - expected_h = np.array([[0.44923624, 0.49362513]]) - - self.assertAllClose(expected_c, actual_c, 1e-5) - self.assertAllClose(expected_h, actual_h, 1e-5) - - def testNonbasicCell(self): - """Tests cell with peepholes and w/o normalisation""" - - def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=False, use_peepholes=True) - - actual_c, actual_h = self._cell_output(cell) - - expected_c = np.array([[0.65937084, 0.7574988]]) - expected_h = np.array([[0.4792085, 0.53470564]]) - - self.assertAllClose(expected_c, actual_c, 1e-5) - self.assertAllClose(expected_h, actual_h, 1e-5) - - def testBasicCellWithNorm(self): - """Tests cell w/o peepholes and with normalisation""" - - def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=True, use_peepholes=False) - - actual_c, actual_h = self._cell_output(cell) - - expected_c = np.array([[0.50125383, 0.58805949]]) - expected_h = np.array([[0.32770363, 0.37397948]]) - - self.assertAllClose(expected_c, actual_c, 1e-5) - self.assertAllClose(expected_h, actual_h, 1e-5) - - def testNonBasicCellWithNorm(self): - """Tests cell with peepholes and with normalisation""" - - def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=True, use_peepholes=True) - - actual_c, actual_h = self._cell_output(cell) - - expected_c = np.array([[0.50125383, 0.59587258]]) - expected_h = np.array([[0.35041603, 0.40873795]]) - - self.assertAllClose(expected_c, actual_c, 1e-5) - self.assertAllClose(expected_h, actual_h, 1e-5) - - class WeightNormLSTMCellTest(test.TestCase): """Compared cell output with pre-calculated values.""" diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py index 4c964ec201f153d6c8293d3bf93bc231ff8f751d..81ca12317be484ba420b7bbfac822e91d6d38bff 100644 --- a/tensorflow/contrib/rnn/python/ops/gru_ops.py +++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py @@ -32,7 +32,7 @@ from tensorflow.python.util.deprecation import deprecated_args _gru_ops_so = loader.load_op_library( resource_loader.get_path_to_datafile("_gru_ops.so")) -LayerRNNCell = rnn_cell_impl._LayerRNNCell # pylint: disable=invalid-name,protected-access +LayerRNNCell = rnn_cell_impl.LayerRNNCell # pylint: disable=invalid-name @ops.RegisterGradient("GRUBlockCell") diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 04f342cd18271425068b2b02c2937236c900c5e2..4eb4fbcd92f0d7cb3bee712862c8950a1971b632 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import resource_loader _lstm_ops_so = loader.load_op_library( resource_loader.get_path_to_datafile("_lstm_ops.so")) -LayerRNNCell = rnn_cell_impl._LayerRNNCell # pylint: disable=invalid-name,protected-access +LayerRNNCell = rnn_cell_impl.LayerRNNCell # pylint: disable=invalid-name # pylint: disable=invalid-name @@ -572,9 +572,8 @@ class LSTMBlockWrapper(base_layer.Layer): def _gather_states(self, data, indices, batch_size): """Produce `out`, s.t. out(i, j) = data(indices(i), i, j).""" - mod_indices = indices * batch_size + math_ops.range(batch_size) - return array_ops.gather( - array_ops.reshape(data, [-1, self.num_units]), mod_indices) + return array_ops.gather_nd( + data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1)) class LSTMBlockFusedCell(LSTMBlockWrapper): diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 8adf5dce6ec76d8ac4f182929e0dfc81be946277..a6c2d9cdbb2b6f61d59960f708000e945c6115e9 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -32,12 +32,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl # pylint: disable=unused-import from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables # pylint: disable=unused-import from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import partitioned_variables -from tensorflow.python.ops import nn_impl from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -424,8 +424,9 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell): "W_O_diag", shape=[self._num_units], dtype=dtype) # initialize the first freq state to be zero - m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units], - dtype) + m_prev_freq = array_ops.zeros( + [inputs.shape[0].value or inputs.get_shape()[0], self._num_units], + dtype) for fq in range(len(freq_inputs)): c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units], [-1, self._num_units]) @@ -2285,7 +2286,7 @@ class GLSTMCell(rnn_cell_impl.RNNCell): else: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) self._output_size = num_units - self._linear1 = None + self._linear1 = [None] * number_of_groups self._linear2 = None @property @@ -2359,9 +2360,11 @@ class GLSTMCell(rnn_cell_impl.RNNCell): self._group_shape[0]) ], axis=1) - if self._linear1 is None: - self._linear1 = _Linear(x_g_id, 4 * self._group_shape[1], False) - R_k = self._linear1(x_g_id) # pylint: disable=invalid-name + linear = self._linear1[group_id] + if linear is None: + linear = _Linear(x_g_id, 4 * self._group_shape[1], False) + self._linear1[group_id] = linear + R_k = linear(x_g_id) # pylint: disable=invalid-name i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1) i_parts.append(i_k) @@ -2680,7 +2683,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell): return m, new_state -class SRUCell(rnn_cell_impl._LayerRNNCell): +class SRUCell(rnn_cell_impl.LayerRNNCell): """SRU, Simple Recurrent Unit Implementation based on @@ -2729,25 +2732,9 @@ class SRUCell(rnn_cell_impl._LayerRNNCell): input_depth = inputs_shape[1].value - # Here the contributor believes that the following constraints - # are implied. The reasoning is explained here with reference to - # the paper https://arxiv.org/pdf/1709.02755.pdf upon which this - # implementation is based. - # In section 2.1 Equation 5, specifically: - # h_t = r_t \odot g(c_t) + (1 - r_t) \odot x_t - # the pointwise operation between r_t and x_t means they have - # the same shape (since we are implementing an RNN cell, braodcasting - # does not happen to input of a single timestep); by the same - # reasons, x_t has the same shape as h_t, essentially mandating that - # input_depth = unit_num. - if input_depth != self._num_units: - raise ValueError("SRU requires input_depth == num_units, got " - "input_depth = %s, num_units = %s" % (input_depth, - self._num_units)) - self._kernel = self.add_variable( rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - shape=[input_depth, 3 * self._num_units]) + shape=[input_depth, 4 * self._num_units]) self._bias = self.add_variable( rnn_cell_impl._BIAS_VARIABLE_NAME, @@ -2760,8 +2747,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell): """Simple recurrent unit (SRU) with num_units cells.""" U = math_ops.matmul(inputs, self._kernel) - x_bar, f_intermediate, r_intermediate = array_ops.split( - value=U, num_or_size_splits=3, axis=1) + x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split( + value=U, num_or_size_splits=4, axis=1) f_r = math_ops.sigmoid( nn_ops.bias_add( @@ -2769,7 +2756,7 @@ class SRUCell(rnn_cell_impl._LayerRNNCell): f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1) c = f * state + (1.0 - f) * x_bar - h = r * self._activation(c) + (1.0 - r) * inputs + h = r * self._activation(c) + (1.0 - r) * x_tx return h, c diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index 64973ccccdc962757a727d7183bd70e94edcfd1b..dfa12e873a6aca806031c48d6f92e0432d0ea6e0 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -80,12 +80,12 @@ class GatherTreeOp : public OpKernel { max_sequence_lengths.shape().DebugString())); Tensor* beams; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams)); - typename TTypes::ConstTensor step_ids_t = step_ids.tensor(); - typename TTypes::ConstTensor parent_ids_t = parent_ids.tensor(); + typename TTypes::ConstTensor step_ids_t(step_ids.tensor()); + typename TTypes::ConstTensor parent_ids_t(parent_ids.tensor()); typename TTypes::ConstVec max_seq_lens_t = max_sequence_lengths.vec(); - typename TTypes::ConstScalar end_token_t = end_token.scalar(); - typename TTypes::Tensor beams_t = beams->tensor(); + typename TTypes::ConstScalar end_token_t(end_token.scalar()); + typename TTypes::Tensor beams_t(beams->tensor()); const T end_token_value = end_token_t(); functor::GatherTree()(ctx, device, step_ids_t, parent_ids_t, max_seq_lens_t, end_token_value, beams_t); diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 95dea312f3a4e77176a4bc4af290ad48c078deda..0a53fd66dbe4d28ea102773b9c5bae50b9d18e9c 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -331,7 +331,7 @@ def _luong_score(query, keys, scale): # batched matmul on: # [batch_size, 1, depth] . [batch_size, depth, max_time] # resulting in an output shape of: - # [batch_time, 1, max_time]. + # [batch_size, 1, max_time]. # we then squeeze out the center singleton dimension. score = math_ops.matmul(query, keys, transpose_b=True) score = array_ops.squeeze(score, [1]) @@ -924,8 +924,7 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, seed=sigmoid_noise_seed) super(LuongMonotonicAttention, self).__init__( - query_layer=layers_core.Dense( - num_units, name="query_layer", use_bias=False, dtype=dtype), + query_layer=None, memory_layer=layers_core.Dense( num_units, name="memory_layer", use_bias=False, dtype=dtype), memory=memory, diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index ef3722ee41bb0b49e5f81d4d6514e2f40d2ad9f1..3245cc5e72154289ea3ba000b9a30586a7ad03a9 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -184,6 +184,7 @@ class TrainingHelper(Helper): """ with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]): inputs = ops.convert_to_tensor(inputs, name="inputs") + self._inputs = inputs if not time_major: inputs = nest.map_structure(_transpose_batch_time, inputs) @@ -200,6 +201,14 @@ class TrainingHelper(Helper): self._batch_size = array_ops.size(sequence_length) + @property + def inputs(self): + return self._inputs + + @property + def sequence_length(self): + return self._sequence_length + @property def batch_size(self): return self._batch_size diff --git a/tensorflow/contrib/session_bundle/bundle_shim.py b/tensorflow/contrib/session_bundle/bundle_shim.py index 3149875e41f6f77b3bcbc0ab1a150cfdc59ad2ba..1db97020a2a81f4d034543e722a6cb7ba823f44a 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.py +++ b/tensorflow/contrib/session_bundle/bundle_shim.py @@ -82,7 +82,8 @@ def _convert_default_signature_to_signature_def(signatures): """ default_signature = signatures.default_signature signature_def = meta_graph_pb2.SignatureDef() - if default_signature.WhichOneof("type") == legacy_constants.REGRESSION_SIGNATURE: + if (default_signature.WhichOneof("type") == + legacy_constants.REGRESSION_SIGNATURE): regression_signature = default_signature.regression_signature signature_def.method_name = signature_constants.REGRESS_METHOD_NAME _add_input_to_signature_def(regression_signature.input.tensor_name, @@ -91,7 +92,8 @@ def _convert_default_signature_to_signature_def(signatures): _add_output_to_signature_def(regression_signature.output.tensor_name, signature_constants.REGRESS_OUTPUTS, signature_def) - elif default_signature.WhichOneof("type") == legacy_constants.CLASSIFICATION_SIGNATURE: + elif (default_signature.WhichOneof("type") == + legacy_constants.CLASSIFICATION_SIGNATURE): classification_signature = default_signature.classification_signature signature_def.method_name = signature_constants.CLASSIFY_METHOD_NAME _add_input_to_signature_def(classification_signature.input.tensor_name, @@ -132,8 +134,9 @@ def _convert_named_signatures_to_signature_def(signatures): signature_constants.PREDICT_OUTPUTS] # TODO(pdudnik): what if there are other signatures? Mimic cr/140900781 once # it is submitted. - if (input_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE or - output_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE): + if (input_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE + or output_signature.WhichOneof("type") != + legacy_constants.GENERIC_SIGNATURE): raise RuntimeError("Named input and output signatures can only be " "up-converted if they are generic signature. " "Input signature type is %s, output signature type is " diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc index 72f32a0f5554e4dd3e7cbf498a57ee6bfba57211..9a1dd9303f43591888dc49984d81c4a0c6af9846 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc @@ -493,17 +493,15 @@ TEST(BundleShimTest, DefaultAndNamedSignatureWithPredict) { ASSERT_FALSE( actual_signature_def_predict->second.inputs().find("foo-input") == actual_signature_def_predict->second.inputs().end()); - EXPECT_EQ("foo-input", - actual_signature_def_predict->second.inputs() - .find("foo-input") - ->second.name()); + EXPECT_EQ("foo-input", actual_signature_def_predict->second.inputs() + .find("foo-input") + ->second.name()); ASSERT_FALSE( actual_signature_def_predict->second.outputs().find("foo-output") == actual_signature_def_predict->second.outputs().end()); - EXPECT_EQ("foo-output", - actual_signature_def_predict->second.outputs() - .find("foo-output") - ->second.name()); + EXPECT_EQ("foo-output", actual_signature_def_predict->second.outputs() + .find("foo-output") + ->second.name()); EXPECT_EQ(kPredictMethodName, actual_signature_def_predict->second.method_name()); } diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py index f6f663aae766b783b85139f57a93e10f553e6bd1..08983337fccc138d40eb959cecc5bf9e47cf6cac 100644 --- a/tensorflow/contrib/session_bundle/exporter.py +++ b/tensorflow/contrib/session_bundle/exporter.py @@ -281,11 +281,12 @@ class Exporter(object): tmp_export_dir = compat.as_text(export_dir) + "-tmp" gfile.MakeDirs(tmp_export_dir) - self._saver.save(sess, - os.path.join( - compat.as_text(tmp_export_dir), - compat.as_text(constants.EXPORT_BASE_NAME)), - meta_graph_suffix=constants.EXPORT_SUFFIX_NAME) + self._saver.save( + sess, + os.path.join( + compat.as_text(tmp_export_dir), + compat.as_text(constants.EXPORT_BASE_NAME)), + meta_graph_suffix=constants.EXPORT_SUFFIX_NAME) # Run the asset callback. if self._assets_callback and self._assets_to_copy: @@ -301,12 +302,12 @@ class Exporter(object): if exports_to_keep: # create a simple parser that pulls the export_version from the directory. def parser(path): - if os.name == 'nt': - match = re.match("^" + export_dir_base.replace('\\','/') + "/(\\d{8})$", - path.path.replace('\\','/')) + if os.name == "nt": + match = re.match( + "^" + export_dir_base.replace("\\", "/") + "/(\\d{8})$", + path.path.replace("\\", "/")) else: - match = re.match("^" + export_dir_base + "/(\\d{8})$", - path.path) + match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path) if not match: return None return path._replace(export_version=int(match.group(1))) diff --git a/tensorflow/contrib/session_bundle/gc.py b/tensorflow/contrib/session_bundle/gc.py index 249c23c88f3043403e322b73b6c9df97e932a92a..514cc0f652c8d174bdb9bff2b2cf1ea38fdd7b1f 100644 --- a/tensorflow/contrib/session_bundle/gc.py +++ b/tensorflow/contrib/session_bundle/gc.py @@ -70,7 +70,6 @@ import heapq import math import os -from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.platform import gfile from tensorflow.python.util.deprecation import deprecated diff --git a/tensorflow/contrib/session_bundle/signature.cc b/tensorflow/contrib/session_bundle/signature.cc index 7133875ad53e77625bbe799f4f886c074a08f1bd..ed70a5b91b231067e8e69951ef7010406e6b22cf 100644 --- a/tensorflow/contrib/session_bundle/signature.cc +++ b/tensorflow/contrib/session_bundle/signature.cc @@ -38,9 +38,9 @@ namespace { Status BatchSizesMatch(const Tensor& input, const Tensor& output) { // Ensure the number of outputs match the number of inputs. if (input.dim_size(0) != output.dim_size(0)) { - return errors::Internal( - strings::StrCat("Input batch size did not match output batch size: ", - input.dim_size(0), " vs. ", output.dim_size(0))); + return errors::Internal(strings::StrCat( + "Input batch size did not match output batch size: ", input.dim_size(0), + " vs. ", output.dim_size(0))); } return Status::OK(); } @@ -100,8 +100,8 @@ Status GetNamedClassificationSignature( const auto& it = signatures.named_signatures().find(name); if (it == signatures.named_signatures().end()) { return errors::NotFound( - strings::StrCat("Missing signature named \"", name, "\" in: ", - DebugStringIfAvailable(signatures))); + strings::StrCat("Missing signature named \"", name, + "\" in: ", DebugStringIfAvailable(signatures))); } if (!it->second.has_classification_signature()) { return errors::FailedPrecondition( @@ -232,8 +232,8 @@ Status GetNamedSignature(const string& name, const auto& it = signatures.named_signatures().find(name); if (it == signatures.named_signatures().end()) { return errors::NotFound( - strings::StrCat("Missing signature named \"", name, "\" in: ", - DebugStringIfAvailable(signatures))); + strings::StrCat("Missing signature named \"", name, + "\" in: ", DebugStringIfAvailable(signatures))); } *signature = it->second; return Status::OK(); diff --git a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py index c04f1cf5bad358a14a1827df05a129339502c86f..e7743bdcba180929007d17bdf3b143c64643aacc 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.signal.python.ops import mfcc_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import spectral_ops_test_util @@ -49,6 +50,14 @@ class MFCCTest(test.TestCase): signal = random_ops.random_normal((2, 3, 5)) mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval() + def test_unknown_shape(self): + """A test that the op runs when shape and rank are unknown.""" + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(use_gpu=True): + signal = array_ops.placeholder_with_default( + random_ops.random_normal((2, 3, 5)), tensor_shape.TensorShape(None)) + self.assertIsNone(signal.shape.ndims) + mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval() if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/ops/mfcc_ops.py b/tensorflow/contrib/signal/python/ops/mfcc_ops.py index 6cef95f742515709f0f41632358c2d8663daed2c..4e842f7f10ae07448cc07e5f636ae80a820e656f 100644 --- a/tensorflow/contrib/signal/python/ops/mfcc_ops.py +++ b/tensorflow/contrib/signal/python/ops/mfcc_ops.py @@ -105,4 +105,4 @@ def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None): num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1] dct2 = spectral_ops.dct(log_mel_spectrograms) - return dct2 * math_ops.rsqrt(num_mel_bins * 2.0) + return dct2 * math_ops.rsqrt(math_ops.to_float(num_mel_bins) * 2.0) diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py index bca2e01d7bbefb18fd69a0eba27e3afb8f636724..a8b5deff6ca3a4a756d31b904e577f08f6155fd7 100644 --- a/tensorflow/contrib/signal/python/ops/spectral_ops.py +++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py @@ -144,7 +144,7 @@ def inverse_stft_window_fn(frame_step, overlaps = -(-frame_length // frame_step) # Ceiling division. denom = array_ops.pad(denom, [(0, overlaps * frame_step - frame_length)]) denom = array_ops.reshape(denom, [overlaps, frame_step]) - denom = math_ops.reduce_sum(denom, 0, keep_dims=True) + denom = math_ops.reduce_sum(denom, 0, keepdims=True) denom = array_ops.tile(denom, [overlaps, 1]) denom = array_ops.reshape(denom, [overlaps * frame_step]) diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md index c7a54cb9a2e9535efbdc179f1463cef379ebb1f9..2d9df8f27ee98431f51fd39c168325b8f625dce9 100644 --- a/tensorflow/contrib/slim/README.md +++ b/tensorflow/contrib/slim/README.md @@ -145,7 +145,7 @@ regular_variables_and_model_variables = slim.get_variables() How does this work? When you create a model variable via TF-Slim's layers or directly via the `slim.model_variable` function, TF-Slim adds the variable to -a the `tf.GraphKeys.MODEL_VARIABLES` collection. What if you have your own +the `tf.GraphKeys.MODEL_VARIABLES` collection. What if you have your own custom layers or variable creation routine but still want TF-Slim to manage or be aware of your model variables? TF-Slim provides a convenience function for adding the model variable to its collection: diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 0544404e9e252cca6d3650b805b91be25d705eea..b3b61e1dfe5671a7fbbee20b0c577ee5fad0fb9b 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -349,7 +349,8 @@ class Image(ItemHandler): shape=None, channels=3, dtype=dtypes.uint8, - repeated=False): + repeated=False, + dct_method=''): """Initializes the image. Args: @@ -368,6 +369,11 @@ class Image(ItemHandler): tf.decode_raw, repeated: if False, decodes a single image. If True, decodes a variable number of image strings from a 1D tensor of strings. + dct_method: An optional string. Defaults to empty string. It only takes + effect when image format is jpeg, used to specify a hint about the + algorithm used for jpeg decompression. Currently valid values + are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for + example, the jpeg library does not have that specific option. """ if not image_key: image_key = 'image/encoded' @@ -381,6 +387,7 @@ class Image(ItemHandler): self._channels = channels self._dtype = dtype self._repeated = repeated + self._dct_method = dct_method def tensors_to_item(self, keys_to_tensors): """See base class.""" @@ -406,9 +413,25 @@ class Image(ItemHandler): A tensor that represents decoded image of self._shape, or (?, ?, self._channels) if self._shape is not specified. """ + def decode_image(): - """Decodes a png or jpg based on the headers.""" - return image_ops.decode_image(image_buffer, self._channels) + """Decodes a image based on the headers.""" + return image_ops.decode_image(image_buffer, channels=self._channels) + + def decode_jpeg(): + """Decodes a jpeg image with specified '_dct_method'.""" + return image_ops.decode_jpeg( + image_buffer, channels=self._channels, dct_method=self._dct_method) + + def check_jpeg(): + """Checks if an image is jpeg.""" + # For jpeg, we directly use image_ops.decode_jpeg rather than decode_image + # in order to feed the jpeg specify parameter 'dct_method'. + return control_flow_ops.cond( + image_ops.is_jpeg(image_buffer), + decode_jpeg, + decode_image, + name='cond_jpeg') def decode_raw(): """Decodes a raw image.""" @@ -420,7 +443,7 @@ class Image(ItemHandler): math_ops.equal(image_format, 'RAW')): decode_raw, } image = control_flow_ops.case( - pred_fn_pairs, default=decode_image, exclusive=True) + pred_fn_pairs, default=check_jpeg, exclusive=True) image.set_shape([None, None, self._channels]) if self._shape is not None: diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index f5a9299d263450ba89617f38bf7a4c5cbc359cb1..c24bd048512daaae116e732ac437f7c9b6f6d7fc 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -42,7 +42,7 @@ from tensorflow.python.platform import flags from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary_iterator -from tensorflow.python.training import input +from tensorflow.python.training import input # pylint: disable=redefined-builtin from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import session_run_hook @@ -236,7 +236,7 @@ class SingleEvaluationTest(test.TestCase): def _prepareCheckpoint(self, checkpoint_path): init_op = control_flow_ops.group(variables.global_variables_initializer(), variables.local_variables_initializer()) - saver = saver_lib.Saver() + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1) with self.test_session() as sess: sess.run(init_op) saver.save(sess, checkpoint_path) diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 54362c87b561595697ee64b9d5e565fdc3f0bbe0..6a200de1ea172b4ccb38c0f5d889566ccaeef893 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -738,6 +738,7 @@ def train(train_op, if summary_writer is not None: train_step_kwargs['summary_writer'] = sv.summary_writer + total_loss = None should_retry = True while should_retry: try: @@ -770,10 +771,10 @@ def train(train_op, logging.info('Stopping Training.') sv.request_stop() break - except errors.OutOfRangeError: + except errors.OutOfRangeError as e: # OutOfRangeError is thrown when epoch limit per # tf.train.limit_epochs is reached. - logging.info('Caught OutOfRangeError. Stopping Training.') + logging.info('Caught OutOfRangeError. Stopping Training. %s', e) if logdir and sv.is_chief: logging.info('Finished training! Saving model to disk.') sv.saver.save(sess, sv.save_path, global_step=sv.global_step) diff --git a/tensorflow/contrib/slim/python/slim/learning_test.py b/tensorflow/contrib/slim/python/slim/learning_test.py index 4e816f9b11be2986d042f336bdc320ff47d8cc49..831c6e427ae78932bec09cea935f05a87723f1a3 100644 --- a/tensorflow/contrib/slim/python/slim/learning_test.py +++ b/tensorflow/contrib/slim/python/slim/learning_test.py @@ -197,9 +197,7 @@ class MultiplyGradientsTest(test.TestCase): gradient = constant_op.constant(self._grad_vec, dtype=dtypes.float32) variable = variables_lib.Variable(array_ops.zeros_like(gradient)) multiplier_flag = variables_lib.Variable(True) - tensor_multiplier = array_ops.where(multiplier_flag, - self._multiplier, - 1.0) + tensor_multiplier = array_ops.where(multiplier_flag, self._multiplier, 1.0) grad_to_var = (gradient, variable) gradient_multipliers = {variable: tensor_multiplier} @@ -212,11 +210,8 @@ class MultiplyGradientsTest(test.TestCase): sess.run(multiplier_flag.assign(False)) gradient_false_flag = sess.run(grad_to_var[0]) np_testing.assert_almost_equal(gradient_true_flag, - self._multiplied_grad_vec, - 5) - np_testing.assert_almost_equal(gradient_false_flag, - self._grad_vec, - 5) + self._multiplied_grad_vec, 5) + np_testing.assert_almost_equal(gradient_false_flag, self._grad_vec, 5) def LogisticClassifier(inputs): @@ -502,6 +497,7 @@ class TrainTest(test.TestCase): purpose. """ dump_root = tempfile.mkdtemp() + def dumping_wrapper(sess): # pylint: disable=invalid-name return dumping_wrapper_lib.DumpingDebugWrapperSession(sess, dump_root) @@ -519,16 +515,13 @@ class TrainTest(test.TestCase): train_op = learning.create_train_op(total_loss, optimizer) loss = learning.train( - train_op, - None, - number_of_steps=1, - session_wrapper=dumping_wrapper) + train_op, None, number_of_steps=1, session_wrapper=dumping_wrapper) self.assertIsNotNone(loss) run_root = glob.glob(os.path.join(dump_root, 'run_*'))[-1] dump = debug_data.DebugDumpDir(run_root) - self.assertAllEqual( - 0, dump.get_tensors('global_step', 0, 'DebugIdentity')[0]) + self.assertAllEqual(0, + dump.get_tensors('global_step', 0, 'DebugIdentity')[0]) def testTrainWithTrace(self): logdir = os.path.join( @@ -961,8 +954,8 @@ class TrainTest(test.TestCase): self.assertGreater(losses[0], losses[1]) def testTrainWithEpochLimit(self): - logdir = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()), - 'tmp_logs') + logdir = os.path.join( + tempfile.mkdtemp(prefix=self.get_temp_dir()), 'tmp_logs') with ops.Graph().as_default(): random_seed.set_random_seed(0) tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) @@ -982,7 +975,8 @@ class TrainTest(test.TestCase): self.assertIsNotNone(loss) self.assertLess(loss, .015) self.assertTrue(os.path.isfile('{}/model.ckpt-300.index'.format(logdir))) - self.assertTrue(os.path.isfile('{}/model.ckpt-300.data-00000-of-00001'.format(logdir))) + self.assertTrue( + os.path.isfile('{}/model.ckpt-300.data-00000-of-00001'.format(logdir))) if __name__ == '__main__': diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py index 7b609ae96b20a5c3d078777cc8fbb475e5eebb1b..a1282847bef981717d7fdb1474adbbaaae4621c0 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py @@ -47,8 +47,8 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): a_np = np.dot(a_np.T, a_np) # jacobi preconditioner jacobi_np = np.zeros_like(a_np) - jacobi_np[range(a_np.shape[0]), range(a_np.shape[1])] = (1.0 / - a_np.diagonal()) + jacobi_np[range(a_np.shape[0]), range(a_np.shape[1])] = ( + 1.0 / a_np.diagonal()) rhs_np = np.random.uniform( low=-1.0, high=1.0, size=shape_[0]).astype(dtype_) x_np = np.zeros_like(rhs_np) @@ -66,18 +66,30 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): x = array_ops.placeholder(dtype_) jacobi = array_ops.placeholder(dtype_) operator = util.create_operator(a) - preconditioners = [None, util.identity_operator(a), - util.create_operator(jacobi)] + preconditioners = [ + None, util.identity_operator(a), + util.create_operator(jacobi) + ] cg_results = [] for preconditioner in preconditioners: cg_graph = linear_equations.conjugate_gradient( - operator, rhs, preconditioner=preconditioner, - x=x, tol=tol, max_iter=max_iter) + operator, + rhs, + preconditioner=preconditioner, + x=x, + tol=tol, + max_iter=max_iter) if use_static_shape_: cg_val = sess.run(cg_graph) else: - cg_val = sess.run(cg_graph, feed_dict={a: a_np, rhs: rhs_np, x: x_np, - jacobi: jacobi_np}) + cg_val = sess.run( + cg_graph, + feed_dict={ + a: a_np, + rhs: rhs_np, + x: x_np, + jacobi: jacobi_np + }) norm_r0 = np.linalg.norm(rhs_np) norm_r = np.linalg.norm(cg_val.r) self.assertLessEqual(norm_r, tol * norm_r0) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py index 12e94369cbae462c21867657119cd2dd9ee29651..5d7534657bff27f7169e6a97bf4b03d4f6a35bc9 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py @@ -85,9 +85,11 @@ class UtilTest(test.TestCase): op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty]) else: op_shape_val, ax_val, aty_val = sess.run( - [op_shape, ax, aty], feed_dict={a: a_np, - x: x_np, - y: y_np}) + [op_shape, ax, aty], feed_dict={ + a: a_np, + x: x_np, + y: y_np + }) self.assertAllEqual(op_shape_val, [3, 2]) self.assertAllClose(ax_val, x_np) self.assertAllClose(aty_val, y_np) diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py index 4dfaa97ac9834ca3c13a9f8e8d721ddaba33bf7d..d791d467639b572e7831c1d1a582aa15585649b6 100644 --- a/tensorflow/contrib/solvers/python/ops/linear_equations.py +++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import linalg_ops @@ -84,10 +85,9 @@ def conjugate_gradient(operator, cg_state = collections.namedtuple("CGState", ["i", "x", "r", "p", "gamma"]) def stopping_criterion(i, state): - return math_ops.logical_and(i < max_iter, - linalg_ops.norm(state.r) > tol) + return math_ops.logical_and(i < max_iter, linalg_ops.norm(state.r) > tol) - def cg_step(i, state): + def cg_step(i, state): # pylint: disable=missing-docstring z = operator.apply(state.p) alpha = state.gamma / util.dot(state.p, z) x = state.x + alpha * state.p @@ -108,8 +108,7 @@ def conjugate_gradient(operator, rhs = array_ops.expand_dims(rhs, -1) if x is None: x = array_ops.expand_dims( - array_ops.zeros( - n, dtype=rhs.dtype.base_dtype), -1) + array_ops.zeros(n, dtype=rhs.dtype.base_dtype), -1) r0 = rhs else: x = array_ops.expand_dims(x, -1) @@ -119,7 +118,7 @@ def conjugate_gradient(operator, else: p0 = preconditioner.apply(r0) gamma0 = util.dot(r0, p0) - tol = tol * linalg_ops.norm(r0) + tol *= linalg_ops.norm(r0) i = constant_op.constant(0, dtype=dtypes.int32) state = cg_state(i=i, x=x, r=r0, p=p0, gamma=gamma0) _, state = control_flow_ops.while_loop(stopping_criterion, cg_step, diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py index c8b4e472c99e0bf081a7222a7976b1fbbb680825..360e7dbe75f595ff61fb83379089294371203813 100644 --- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py +++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py @@ -105,8 +105,8 @@ class SparsemaxLossTest(test.TestCase): tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu) np_loss = self._np_sparsemax_loss(z, q).astype(dtype) - self.assertAllCloseAccordingToType(np_loss, tf_loss_out, - half_atol=1e-2, half_rtol=5e-3) + self.assertAllCloseAccordingToType( + np_loss, tf_loss_out, half_atol=1e-2, half_rtol=5e-3) self.assertShapeEqual(np_loss, tf_loss_op) def _test_constant_add(self, dtype, random, use_gpu): @@ -116,17 +116,17 @@ class SparsemaxLossTest(test.TestCase): q = np.zeros((test_obs, 10)) q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1 - _, tf_loss_zpc = self._tf_sparsemax_loss( - z + c, q, dtype, use_gpu - ) + _, tf_loss_zpc = self._tf_sparsemax_loss(z + c, q, dtype, use_gpu) - _, tf_loss_z = self._tf_sparsemax_loss( - z, q, dtype, use_gpu - ) + _, tf_loss_z = self._tf_sparsemax_loss(z, q, dtype, use_gpu) - self.assertAllCloseAccordingToType(tf_loss_zpc, tf_loss_z, - float_atol=5e-6, float_rtol=5e-6, - half_atol=1e-2, half_rtol=1e-2) + self.assertAllCloseAccordingToType( + tf_loss_zpc, + tf_loss_z, + float_atol=5e-6, + float_rtol=5e-6, + half_atol=1e-2, + half_rtol=1e-2) def _test_sparsemax_loss_positive(self, dtype, random, use_gpu): """check sparsemax-loss proposition 4""" @@ -170,10 +170,7 @@ class SparsemaxLossTest(test.TestCase): with self.test_session(use_gpu=use_gpu): err = gradient_checker.compute_gradient_error( - logits, z.shape, - loss_op, (test_obs, ), - x_init_value=z, delta=1e-9 - ) + logits, z.shape, loss_op, (test_obs,), x_init_value=z, delta=1e-9) self.assertLess(err, 1e-4) @@ -192,8 +189,8 @@ class SparsemaxLossTest(test.TestCase): tf_grad = loss_grad_op.eval() np_grad = self._np_sparsemax_loss_grad(z, q).astype(dtype) - self.assertAllCloseAccordingToType(np_grad, tf_grad, - half_atol=1e-2, half_rtol=5e-3) + self.assertAllCloseAccordingToType( + np_grad, tf_grad, half_atol=1e-2, half_rtol=5e-3) self.assertShapeEqual(np_grad, loss_grad_op) def _test_dtype(self, dtype): @@ -220,5 +217,6 @@ class SparsemaxLossTest(test.TestCase): def testDouble(self): self._test_dtype('float64') -if __name__ == "__main__": + +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py index 82d36ee9cb21fb822e6df0c3632c49a4fd616825..259e62bd864fba3cc7d9aa387e02c8319438d658 100644 --- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py +++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py @@ -83,8 +83,8 @@ class SparsemaxTest(test.TestCase): tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu) p_sparemax = self._np_sparsemax(z).astype(dtype) - self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out, - half_atol=5e-3) + self.assertAllCloseAccordingToType( + p_sparemax, tf_sparsemax_out, half_atol=5e-3) self.assertShapeEqual(p_sparemax, tf_sparsemax_op) def _test_sparsemax_of_zero(self, dtype, random, use_gpu): @@ -111,9 +111,8 @@ class SparsemaxTest(test.TestCase): p_expected = np.zeros((test_obs, 10), dtype=dtype) p_expected[np.arange(0, test_obs), z_sort_arg[:, 0]] = 1 - tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax( - (1 / epsilon) * z, dtype, use_gpu - ) + tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax((1 / epsilon) * z, + dtype, use_gpu) self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out) self.assertShapeEqual(p_expected, tf_sparsemax_op) @@ -123,16 +122,12 @@ class SparsemaxTest(test.TestCase): z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype) c = random.uniform(low=-3, high=3, size=(test_obs, 1)).astype(dtype) - _, tf_sparsemax_zpc = self._tf_sparsemax( - z + c, dtype, use_gpu - ) + _, tf_sparsemax_zpc = self._tf_sparsemax(z + c, dtype, use_gpu) - _, tf_sparsemax_z = self._tf_sparsemax( - z, dtype, use_gpu - ) + _, tf_sparsemax_z = self._tf_sparsemax(z, dtype, use_gpu) - self.assertAllCloseAccordingToType(tf_sparsemax_zpc, tf_sparsemax_z, - half_atol=5e-3) + self.assertAllCloseAccordingToType( + tf_sparsemax_zpc, tf_sparsemax_z, half_atol=5e-3) def _test_permutation(self, dtype, random, use_gpu): """check sparsemax proposition 3""" @@ -143,12 +138,11 @@ class SparsemaxTest(test.TestCase): per = random.permutation(10) tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax( - z[i, per].reshape(1, -1), dtype, use_gpu - ) + z[i, per].reshape(1, -1), dtype, use_gpu) p_expected = p[i, per].reshape(1, -1) - self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out, - half_atol=5e-3) + self.assertAllCloseAccordingToType( + p_expected, tf_sparsemax_out, half_atol=5e-3) self.assertShapeEqual(p_expected, tf_sparsemax_op) def _test_diffrence(self, dtype, random, use_gpu): @@ -166,18 +160,14 @@ class SparsemaxTest(test.TestCase): continue self.assertTrue( - 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol, - "0 <= %.10f <= %.10f" % ( - p[val, j] - p[val, i], z[val, j] - z[val, i] + etol - ) - ) + 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol, + '0 <= %.10f <= %.10f' % (p[val, j] - p[val, i], + z[val, j] - z[val, i] + etol)) def _test_two_dimentional(self, dtype, random, use_gpu): """check two dimentation sparsemax case""" t = np.linspace(-2, 2, test_obs, dtype=dtype) - z = np.vstack([ - t, np.zeros(test_obs, dtype=dtype) - ]).T + z = np.vstack([t, np.zeros(test_obs, dtype=dtype)]).T tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu) @@ -196,10 +186,7 @@ class SparsemaxTest(test.TestCase): with self.test_session(use_gpu=use_gpu): err = gradient_checker.compute_gradient_error( - logits, z.shape, - sparsemax_op, z.shape, - x_init_value=z, delta=1e-9 - ) + logits, z.shape, sparsemax_op, z.shape, x_init_value=z, delta=1e-9) self.assertLess(err, 1e-4) @@ -248,5 +235,6 @@ class SparsemaxTest(test.TestCase): def testDouble(self): self._test_dtype('float64') -if __name__ == "__main__": + +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py index 73a5cf1e9287ea4e4350d88165744cf12db954bb..890ca20f4cabd65146e803e54e554a5c97e72427 100644 --- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py @@ -23,7 +23,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -from tensorflow.python.platform import resource_loader __all__ = ["sparsemax"] diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py index ba18f89e16c76a6ef3cb05df0c13f62eace6bbb1..582d1e6136df4d3ad3c8108ae9607d5fef519145 100644 --- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.util import loader -from tensorflow.python.platform import resource_loader from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/specs/BUILD b/tensorflow/contrib/specs/BUILD index 4b688690aef513dd683817b0b5c2ba4cb50f73d9..084953a0a226cde46ebd9d2031d20cb839180ca8 100644 --- a/tensorflow/contrib/specs/BUILD +++ b/tensorflow/contrib/specs/BUILD @@ -23,7 +23,6 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/ndlstm", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:logging_ops", diff --git a/tensorflow/contrib/specs/README.md b/tensorflow/contrib/specs/README.md index b764e6e714ea907cd4474a07843bda300a8e4c8b..bcf34e601f1ffe3ab7a8c0d2ad573da4c8c977e9 100644 --- a/tensorflow/contrib/specs/README.md +++ b/tensorflow/contrib/specs/README.md @@ -59,17 +59,6 @@ Reshaping: - `Squeeze` = tf.squeeze - `Expand` = tf.expand_dims -Multidimensional LSTM: - -These are intended as alternatives to 2D convolutions. For sequence models, -there will be other modeling primitives. - - - `Lstm2` = Fun(lstm2d.separable_lstm) # 2D-to-2D - - `Lstm2to1` = Fun(lstm2d.reduce_to_sequence) # 2D-to-1D - - `Lstm2to0` = Fun(lstm2d.reduce_to_final) # 2D-to-vector - - `Clstm2(n, m)` is a `Cl(n, [3,3])` followed by `Lstm2(m)` - - `Dws(n)` is a depthwise convolution `Cs(n, [1, 1])` - Other: - `Id` = identity diff --git a/tensorflow/contrib/specs/python/__init__.py b/tensorflow/contrib/specs/python/__init__.py index 52db61e421a52f4106ab1e2a4d7ee5c100b6b4bc..b6cc754023859f8d3668545dd5c2fd1d1581ecf5 100644 --- a/tensorflow/contrib/specs/python/__init__.py +++ b/tensorflow/contrib/specs/python/__init__.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=wildcard-import,g-importing-member +# pylint: disable=wildcard-import,g-importing-member,redefined-builtin from tensorflow.contrib.specs.python.params_ops import * from tensorflow.contrib.specs.python.specs import * from tensorflow.contrib.specs.python.specs_lib import * from tensorflow.contrib.specs.python.specs_ops import * from tensorflow.contrib.specs.python.summaries import * -# pylint: enable=wildcard-import +# pylint: enable=wildcard-import,redefined-builtin diff --git a/tensorflow/contrib/specs/python/specs_ops.py b/tensorflow/contrib/specs/python/specs_ops.py index a6bd4d16c284a8b1a370005a7c55d3b74b4fbf95..49b989b8d0fc83a3793263a2b59a98a8fe292c6a 100644 --- a/tensorflow/contrib/specs/python/specs_ops.py +++ b/tensorflow/contrib/specs/python/specs_ops.py @@ -23,8 +23,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers -from tensorflow.contrib.ndlstm.python import lstm1d -from tensorflow.contrib.ndlstm.python import lstm2d from tensorflow.contrib.specs.python import specs_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops @@ -122,17 +120,6 @@ Sig = Fun(math_ops.sigmoid) Tanh = Fun(math_ops.tanh) Smax = Fun(nn_ops.softmax) -# 2D LSTM - -Lstm2 = Fun(lstm2d.separable_lstm) -Lstm2to1 = Fun(lstm2d.reduce_to_sequence) # 2D to 1D -Lstm2to0 = Fun(lstm2d.reduce_to_final) # 2D to depth-only - - -def Clstm2(n, *args, **kw): - """2D LSTM with 3x3 pre-convolution.""" - return Cl(n, [3, 3]) | Lstm2(*args, **kw) - def Dws(n): """Depth-wise convolution + sigmoid (used after LSTM).""" @@ -143,13 +130,6 @@ def Dwm(n): """Depth-wise convolution + softmax (used after LSTM).""" return Cm(n, [1, 1]) - -# 1D LSTM - -Lstm1 = Fun(lstm1d.ndlstm_base) -Lstm1to0 = Fun(lstm1d.sequence_to_final) # 1D to depth-only -Ssm = Fun(lstm1d.sequence_softmax) - # Sharing of Variables diff --git a/tensorflow/contrib/specs/python/specs_test.py b/tensorflow/contrib/specs/python/specs_test.py index 41782a9fc9ada3d8a1ff30847971aea18f0ca1c7..9a4ad36793542a83105ad0dc1ef7c0624a6c1f99 100644 --- a/tensorflow/contrib/specs/python/specs_test.py +++ b/tensorflow/contrib/specs/python/specs_test.py @@ -149,36 +149,6 @@ class SpecsTest(test.TestCase): self.assertEqual(tuple(result.shape), (10, 20)) self.assertEqual(summaries.tf_spec_structure(spec, inputs), "_ sig sig") - def testLstm2(self): - with self.test_session(): - inputs = constant_op.constant(_rand(1, 64, 64, 5)) - spec = "net = Lstm2(15)" - outputs = specs.create_net(spec, inputs) - self.assertEqual(outputs.get_shape().as_list(), [1, 64, 64, 15]) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (1, 64, 64, 15)) - - def testLstm2to1(self): - with self.test_session(): - inputs = constant_op.constant(_rand(1, 64, 64, 5)) - spec = "net = Lstm2to1(15)" - outputs = specs.create_net(spec, inputs) - self.assertEqual(outputs.get_shape().as_list(), [1, 64, 15]) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (1, 64, 15)) - - def testLstm2to0(self): - with self.test_session(): - inputs = constant_op.constant(_rand(1, 64, 64, 5)) - spec = "net = Lstm2to0(15)" - outputs = specs.create_net(spec, inputs) - self.assertEqual(outputs.get_shape().as_list(), [1, 15]) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (1, 15)) - def testKeywordRestriction(self): with self.test_session(): inputs = constant_op.constant(_rand(10, 20)) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index 7d3b8b7437a9ff5aaa0834db79bca8883cd679c8..2d6d7ea6a3eff2562ba8def4117e3aa6f818b6fd 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -18,6 +18,42 @@ The operations in this package are safe to use with eager execution turned on or off. It has a more flexible API that allows summaries to be written directly from ops to places other than event log files, rather than propagating protos from @{tf.summary.merge_all} to @{tf.summary.FileWriter}. + +To use with eager execution enabled, write your code as follows: + +global_step = tf.train.get_or_create_global_step() +summary_writer = tf.contrib.summary.create_file_writer( + train_dir, flush_millis=10000) +with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): + # model code goes here + # and in it call + tf.contrib.summary.scalar("loss", my_loss) + # In this case every call to tf.contrib.summary.scalar will generate a record + # ... + +To use it with graph execution, write your code as follows: + +global_step = tf.train.get_or_create_global_step() +summary_writer = tf.contrib.summary.create_file_writer( + train_dir, flush_millis=10000) +with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): + # model definition code goes here + # and in it call + tf.contrib.summary.scalar("loss", my_loss) + # In this case every call to tf.contrib.summary.scalar will generate an op, + # note the need to run tf.contrib.summary.all_summary_ops() to make sure these + # ops get executed. + # ... + train_op = .... + +with tf.Session(...) as sess: + tf.global_variables_initializer().run() + tf.contrib.summary.initialize(graph=tf.get_default_graph()) + # ... + while not_done_training: + sess.run([train_op, tf.contrib.summary.all_summary_ops()]) + # ... + """ from __future__ import absolute_import diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index ee661dfdc11451bb72bc2741b0b54ebf5c1e6543..b6249fc92f712b21197c2167fb5d1c4af1f48ca5 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -110,7 +110,7 @@ class SummaryWriter(object): def __init__(self, resource): self._resource = resource - if context.in_eager_mode(): + if context.in_eager_mode() and self._resource is not None: self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="cpu:0") @@ -154,10 +154,12 @@ def initialize( to @{tf.get_default_session}. Raises: - RuntimeError: If in eager mode, or if the current thread has no - default @{tf.contrib.summary.SummaryWriter}. + RuntimeError: If the current thread has no default + @{tf.contrib.summary.SummaryWriter}. ValueError: If session wasn't passed and no default session. """ + if context.in_eager_mode(): + return if context.context().summary_writer_resource is None: raise RuntimeError("No default tf.contrib.summary.SummaryWriter found") if session is None: @@ -202,7 +204,7 @@ def create_file_writer(logdir, if flush_millis is None: flush_millis = constant_op.constant(2 * 60 * 1000) if filename_suffix is None: - filename_suffix = constant_op.constant("") + filename_suffix = constant_op.constant(".v2") return _make_summary_writer( name, gen_summary_ops.create_summary_file_writer, @@ -292,13 +294,9 @@ def all_summary_ops(): Returns: The summary ops. - - Raises: - RuntimeError: If in Eager mode. """ if context.in_eager_mode(): - raise RuntimeError( - "tf.contrib.summary.all_summary_ops is only supported in graph mode.") + return None return ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index dfaa4182bb867cc03480320eaf1804da36206655..bb7215f879411e91a1c47b87f5caede63fffea74 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -29,7 +29,6 @@ from tensorflow.core.framework import types_pb2 from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import state_ops @@ -59,12 +58,6 @@ _NUMPY_NUMERIC_TYPES = { class TargetTest(test_util.TensorFlowTestCase): - def testInvalidDirectory(self): - logdir = '/tmp/apath/that/doesnt/exist' - self.assertFalse(gfile.Exists(logdir)) - with self.assertRaises(errors.NotFoundError): - summary_ops.create_file_writer(logdir, max_queue=0, name='t0') - def testShouldRecordSummary(self): self.assertFalse(summary_ops.should_record_summaries()) with summary_ops.always_record_summaries(): diff --git a/tensorflow/contrib/summary/summary_test_internal.py b/tensorflow/contrib/summary/summary_test_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d3384735fb1eb1a048c7aa6da0037ee9fc6936 --- /dev/null +++ b/tensorflow/contrib/summary/summary_test_internal.py @@ -0,0 +1,60 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Internal helpers for tests in this directory.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +import sqlite3 + +from tensorflow.contrib.summary import summary_ops +from tensorflow.python.framework import test_util + + +class SummaryDbTest(test_util.TensorFlowTestCase): + """Helper for summary database testing.""" + + def setUp(self): + super(SummaryDbTest, self).setUp() + self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') + if os.path.exists(self.db_path): + os.unlink(self.db_path) + self.db = sqlite3.connect(self.db_path) + self.create_db_writer = functools.partial( + summary_ops.create_db_writer, + db_uri=self.db_path, + experiment_name='experiment', + run_name='run', + user_name='user') + + def tearDown(self): + self.db.close() + super(SummaryDbTest, self).tearDown() + + +def get_one(db, q, *p): + return db.execute(q, p).fetchone()[0] + + +def get_all(db, q, *p): + return unroll(db.execute(q, p).fetchall()) + + +def unroll(list_of_tuples): + return sum(list_of_tuples, ()) diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py index bda57e6a0ca8e1ddb979a80de276911c7738f0aa..8506c4be9c4ca8305b62da17c7246e6e18313bd3 100644 --- a/tensorflow/contrib/summary/summary_test_util.py +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -21,6 +21,7 @@ from __future__ import print_function import functools import os + import sqlite3 from tensorflow.contrib.summary import summary_ops diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 58a7fa095d8356229fdb5879bea99d316113c828..1e4cc3f0952ef74a1c89b7ed2d8c357fa8847ad5 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -497,6 +497,7 @@ py_library( ":tensor_forest_v4_ops_py", "//tensorflow/contrib/decision_trees/proto:generic_tree_model_py", "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_py", "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index a998ac1e111090a3702c0499a54ef1a5c1b3ac90..4abcc20ed334e706c8ae59e2127dfd6f4e152361 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import layers - +from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib @@ -43,8 +43,8 @@ from tensorflow.python.training import training_util KEYS_NAME = 'keys' LOSS_NAME = 'rf_training_loss' TREE_PATHS_PREDICTION_KEY = 'tree_paths' -VARIANCE_PREDICTION_KEY = 'regression_variance' - +VARIANCE_PREDICTION_KEY = 'prediction_variance' +ALL_SERVING_KEY = 'tensorforest_all' EPSILON = 0.000001 @@ -134,7 +134,8 @@ def get_model_fn(params, trainer_id=0, report_feature_importances=False, local_eval=False, - head_scope=None): + head_scope=None, + include_all_in_serving=False): """Return a model function given a way to construct a graph builder.""" if model_head is None: model_head = get_default_head(params, weights_name) @@ -238,7 +239,13 @@ def get_model_fn(params, model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance - + if include_all_in_serving: + # In order to serve the variance we need to add the prediction dict + # to output_alternatives dict. + if not model_ops.output_alternatives: + model_ops.output_alternatives = {} + model_ops.output_alternatives[ALL_SERVING_KEY] = ( + constants.ProblemType.UNSPECIFIED, model_ops.predictions) return model_ops return _model_fn @@ -293,7 +300,8 @@ class TensorForestEstimator(estimator.Estimator): report_feature_importances=False, local_eval=False, version=None, - head=None): + head=None, + include_all_in_serving=False): """Initializes a TensorForestEstimator instance. Args: @@ -339,6 +347,23 @@ class TensorForestEstimator(estimator.Estimator): version: Unused. head: A heads_lib.Head object that calculates losses and such. If None, one will be automatically created based on params. + include_all_in_serving: if True, allow preparation of the complete + prediction dict including the variance to be exported for serving with + the Servo lib; and it also requires calling export_savedmodel with + default_output_alternative_key=ALL_SERVING_KEY, i.e. + estimator.export_savedmodel(export_dir_base=your_export_dir, + serving_input_fn=your_export_input_fn, + default_output_alternative_key=ALL_SERVING_KEY) + if False, resort to default behavior, i.e. export scores and + probabilities but no variances. In this case + default_output_alternative_key should be None while calling + export_savedmodel(). + Note, that due to backward compatibility we cannot always set + include_all_in_serving to True because in this case calling + export_saved_model() without + default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the + saved_model_export_utils.get_output_alternatives() would raise + ValueError. Returns: A `TensorForestEstimator` instance. @@ -357,7 +382,9 @@ class TensorForestEstimator(estimator.Estimator): num_trainers=num_trainers, trainer_id=trainer_id, report_feature_importances=report_feature_importances, - local_eval=local_eval), + local_eval=local_eval, + include_all_in_serving=include_all_in_serving, + ), model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc index 76cfb4c9ca02269f9fee61c767acc6cb4a0b4ca7..cf0db788a419f64ed891df8aa097fa8826f6de91 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/hard_routing_function_op.cc @@ -99,18 +99,17 @@ class HardRoutingFunction : public OpKernel { const Tensor& tree_biases_tensor = context->input(2); if (input_data.shape().dim_size(0) > 0) { - OP_REQUIRES(context, input_data.shape().dims() == 2, - errors::InvalidArgument( - "input_data should be two-dimensional")); + OP_REQUIRES( + context, input_data.shape().dims() == 2, + errors::InvalidArgument("input_data should be two-dimensional")); } // Check tensor bounds. if (!CheckTensorBounds(context, input_data)) return; - const int32 num_data = static_cast( - input_data.shape().dim_size(0)); - const int32 num_features = static_cast( - input_data.shape().dim_size(1)); + const int32 num_data = static_cast(input_data.shape().dim_size(0)); + const int32 num_features = + static_cast(input_data.shape().dim_size(1)); Tensor* output_probability = nullptr; TensorShape output_probability_shape; @@ -125,9 +124,8 @@ class HardRoutingFunction : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, output_probability_shape, &output_probability)); - OP_REQUIRES_OK(context, - context->allocate_output(1, output_path_shape, - &output_path)); + OP_REQUIRES_OK( + context, context->allocate_output(1, output_path_shape, &output_path)); auto out_probability = output_probability->tensor(); auto out_path = output_path->tensor(); @@ -144,12 +142,11 @@ class HardRoutingFunction : public OpKernel { out_probability(i, 0) = 1.0; out_path(i, 0) = 0; for (int j = 0; j < tree_depth_ - 1; j++) { - float left_prob = LeftProbability(point, - tree_parameters_tensor.Slice(j, j+1), - tree_biases(j), - num_features); + float left_prob = + LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1), + tree_biases(j), num_features); - int32 left_child = 2*node + 1; + int32 left_child = 2 * node + 1; int32 right_child = left_child + 1; float dot_product = 0.0; diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc index 28f50f1a32eb1827a242d527cd42c58487877959..f64155fa55af22d57c6619d8a39da0455dc0de65 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_gradient_op.cc @@ -85,12 +85,9 @@ REGISTER_OP("KFeatureGradient") class KFeatureGradient : public OpKernel { public: - explicit KFeatureGradient(OpKernelConstruction* context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("layer_num", - &layer_num_)); - OP_REQUIRES_OK(context, context->GetAttr("random_seed", - &random_seed_)); + explicit KFeatureGradient(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_)); + OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_)); } void Compute(OpKernelContext* context) override { @@ -101,14 +98,14 @@ class KFeatureGradient : public OpKernel { const Tensor& routing_tensor = context->input(3); // Extract dimensions from input tensors. - const int32 num_data = static_cast( - input_data_tensor.shape().dim_size(0)); - const int32 num_features = static_cast( - input_data_tensor.shape().dim_size(1)); - const int32 num_nodes = static_cast( - tree_parameters_tensor.shape().dim_size(0)); - const int32 num_features_per_node = static_cast( - tree_parameters_tensor.shape().dim_size(1)); + const int32 num_data = + static_cast(input_data_tensor.shape().dim_size(0)); + const int32 num_features = + static_cast(input_data_tensor.shape().dim_size(1)); + const int32 num_nodes = + static_cast(tree_parameters_tensor.shape().dim_size(0)); + const int32 num_features_per_node = + static_cast(tree_parameters_tensor.shape().dim_size(1)); // Construct output tensors. Tensor* out_routes = nullptr; @@ -127,12 +124,12 @@ class KFeatureGradient : public OpKernel { out_weights_shape.AddDim(num_nodes); out_weights_shape.AddDim(num_features_per_node); - OP_REQUIRES_OK(context, context->allocate_output( - 0, out_routes_shape, &out_routes)); - OP_REQUIRES_OK(context, context->allocate_output( - 1, out_data_shape, &out_data)); - OP_REQUIRES_OK(context, context->allocate_output( - 2, out_weights_shape, &out_weights)); + OP_REQUIRES_OK(context, + context->allocate_output(0, out_routes_shape, &out_routes)); + OP_REQUIRES_OK(context, + context->allocate_output(1, out_data_shape, &out_data)); + OP_REQUIRES_OK( + context, context->allocate_output(2, out_weights_shape, &out_weights)); tensorforest::Initialize(*out_data, 0.0f); @@ -148,18 +145,13 @@ class KFeatureGradient : public OpKernel { std::vector feature_set; for (int i = 0; i < num_data; i++) { - const Tensor point = input_data_tensor.Slice(i, i+1); + const Tensor point = input_data_tensor.Slice(i, i + 1); feature_set.clear(); // Traverse the tree from the bottom up. for (int j = num_nodes - 1; j >= 0; j--) { - tensorforest::GetFeatureSet( - layer_num_, - j, - random_seed_, - num_features, - num_features_per_node, - &feature_set); + tensorforest::GetFeatureSet(layer_num_, j, random_seed_, num_features, + num_features_per_node, &feature_set); // Compute routing gradient. // j is a leaf node. @@ -170,12 +162,8 @@ class KFeatureGradient : public OpKernel { int32 right_child = left_child + 1; float left_prob = LeftProbabilityK( - point, - feature_set, - tree_parameters_tensor.Slice(j, j+1), - tree_biases(j), - num_features, - num_features_per_node); + point, feature_set, tree_parameters_tensor.Slice(j, j + 1), + tree_biases(j), num_features, num_features_per_node); float right_prob = 1.0f - left_prob; diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc index 9bc42eb61fae013de3e4ea73aaf371cdaa4ccf9a..e7cafb144da84865ad2b4ea0c33866ddb89119a5 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/k_feature_routing_function_op.cc @@ -43,7 +43,6 @@ using shape_inference::ShapeHandle; using tensorforest::CheckTensorBounds; using tensorforest::LeftProbabilityK; - // The term 'routing function' is synonymous with 'the probability // that an instance is routed to each leaf node.' It is defined in // 'Deep Neural Decision Forests' by Kontschieder et al. @@ -96,10 +95,8 @@ class KFeatureRoutingFunction : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("max_nodes", &max_nodes_)); OP_REQUIRES_OK(context, context->GetAttr("num_features_per_node", &num_features_per_node_)); - OP_REQUIRES_OK(context, context->GetAttr("layer_num", - &layer_num_)); - OP_REQUIRES_OK(context, context->GetAttr("random_seed", - &random_seed_)); + OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_)); + OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_)); } void Compute(OpKernelContext* context) override { @@ -108,27 +105,25 @@ class KFeatureRoutingFunction : public OpKernel { const Tensor& tree_biases_tensor = context->input(2); if (input_data.shape().dim_size(0) > 0) { - OP_REQUIRES(context, input_data.shape().dims() == 2, - errors::InvalidArgument( - "input_data should be two-dimensional")); + OP_REQUIRES( + context, input_data.shape().dims() == 2, + errors::InvalidArgument("input_data should be two-dimensional")); } // Check tensor bounds. if (!CheckTensorBounds(context, input_data)) return; - const int32 num_data = static_cast( - input_data.shape().dim_size(0)); - const int32 num_features = static_cast( - input_data.shape().dim_size(1)); + const int32 num_data = static_cast(input_data.shape().dim_size(0)); + const int32 num_features = + static_cast(input_data.shape().dim_size(1)); Tensor* output_probabilities = nullptr; TensorShape output_shape; output_shape.AddDim(num_data); output_shape.AddDim(max_nodes_); - OP_REQUIRES_OK(context, - context->allocate_output(0, output_shape, - &output_probabilities)); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, + &output_probabilities)); auto out_probs = output_probabilities->tensor(); const auto tree_biases = tree_biases_tensor.tensor(); @@ -136,30 +131,22 @@ class KFeatureRoutingFunction : public OpKernel { // Iteratively compute the probability of reaching each leaf. std::vector feature_set; for (int i = 0; i < num_data; i++) { - const Tensor point = input_data.Slice(i, i+1); + const Tensor point = input_data.Slice(i, i + 1); out_probs(i, 0) = 1.0f; for (int j = 0; j < max_nodes_ / 2; j++) { feature_set.clear(); - tensorforest::GetFeatureSet( - layer_num_, - i, - random_seed_, - num_features, - num_features_per_node_, - &feature_set); - - int32 left_child = 2*j + 1; + tensorforest::GetFeatureSet(layer_num_, i, random_seed_, num_features, + num_features_per_node_, &feature_set); + + int32 left_child = 2 * j + 1; int32 right_child = left_child + 1; float prob = out_probs(i, j); - float left_prob = LeftProbabilityK(point, - feature_set, - tree_parameters_tensor.Slice(j, j+1), - tree_biases(j), - num_features, - num_features_per_node_); + float left_prob = LeftProbabilityK( + point, feature_set, tree_parameters_tensor.Slice(j, j + 1), + tree_biases(j), num_features, num_features_per_node_); out_probs(i, left_child) = prob * left_prob; out_probs(i, right_child) = prob * (1.0f - left_prob); diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc index 4027e732b3f52585c2149c3cdc71535664f04ed4..0c2eaabe8f3e1e1377a8d5c5308aaec00030a20f 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/routing_function_op.cc @@ -90,46 +90,43 @@ class RoutingFunction : public OpKernel { const Tensor& tree_biases_tensor = context->input(2); if (input_data.shape().dim_size(0) > 0) { - OP_REQUIRES(context, input_data.shape().dims() == 2, - errors::InvalidArgument( - "input_data should be two-dimensional")); + OP_REQUIRES( + context, input_data.shape().dims() == 2, + errors::InvalidArgument("input_data should be two-dimensional")); } // Check tensor bounds. if (!CheckTensorBounds(context, input_data)) return; - const int32 num_data = static_cast( - input_data.shape().dim_size(0)); - const int32 num_features = static_cast( - input_data.shape().dim_size(1)); + const int32 num_data = static_cast(input_data.shape().dim_size(0)); + const int32 num_features = + static_cast(input_data.shape().dim_size(1)); Tensor* output_probabilities = nullptr; TensorShape output_shape; output_shape.AddDim(num_data); output_shape.AddDim(max_nodes_); - OP_REQUIRES_OK(context, - context->allocate_output(0, output_shape, - &output_probabilities)); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, + &output_probabilities)); auto out_probs = output_probabilities->tensor(); const auto tree_biases = tree_biases_tensor.tensor(); // Iteratively compute the probability of reaching each leaf. for (int i = 0; i < num_data; i++) { - const Tensor point = input_data.Slice(i, i+1); + const Tensor point = input_data.Slice(i, i + 1); out_probs(i, 0) = 1.0; for (int j = 0; j < max_nodes_ / 2; j++) { - int32 left_child = 2*j + 1; + int32 left_child = 2 * j + 1; int32 right_child = left_child + 1; float prob = out_probs(i, j); - float left_prob = LeftProbability(point, - tree_parameters_tensor.Slice(j, j+1), - tree_biases(j), - num_features); + float left_prob = + LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1), + tree_biases(j), num_features); out_probs(i, left_child) = prob * left_prob; out_probs(i, right_child) = prob * (1.0 - left_prob); diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc index 66aa293dc1cb93b82f06d838ad7b0f9c09761585..c9df09bfda44e665ed013da383e1e9a2c665c454 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_function_op.cc @@ -96,10 +96,9 @@ class StochasticHardRoutingFunction : public OpKernel { explicit StochasticHardRoutingFunction(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("tree_depth", &tree_depth_)); - OP_REQUIRES_OK(context, context->GetAttr("random_seed", - &random_seed_)); + OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_)); single_rand_ = std::unique_ptr( - new random::PhiloxRandom(random_seed_)); + new random::PhiloxRandom(random_seed_)); rng_ = std::unique_ptr( new random::SimplePhilox(single_rand_.get())); } @@ -111,20 +110,19 @@ class StochasticHardRoutingFunction : public OpKernel { const Tensor& tree_biases_tensor = context->input(2); if (input_data.shape().dim_size(0) > 0) { - OP_REQUIRES(context, input_data.shape().dims() == 2, - errors::InvalidArgument( - "input_data should be two-dimensional")); + OP_REQUIRES( + context, input_data.shape().dims() == 2, + errors::InvalidArgument("input_data should be two-dimensional")); } // Check tensor bounds. if (!CheckTensorBounds(context, input_data)) return; - const int32 num_data = static_cast( - input_data.shape().dim_size(0)); - const int32 num_features = static_cast( - input_data.shape().dim_size(1)); - const int32 num_nodes = static_cast( - tree_parameters_tensor.shape().dim_size(0)); + const int32 num_data = static_cast(input_data.shape().dim_size(0)); + const int32 num_features = + static_cast(input_data.shape().dim_size(1)); + const int32 num_nodes = + static_cast(tree_parameters_tensor.shape().dim_size(0)); Tensor* output_probability = nullptr; TensorShape output_probability_shape; @@ -139,9 +137,8 @@ class StochasticHardRoutingFunction : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, output_probability_shape, &output_probability)); - OP_REQUIRES_OK(context, - context->allocate_output(1, output_path_shape, - &output_path)); + OP_REQUIRES_OK( + context, context->allocate_output(1, output_path_shape, &output_path)); auto out_probability = output_probability->tensor(); auto out_path = output_path->tensor(); @@ -150,19 +147,18 @@ class StochasticHardRoutingFunction : public OpKernel { // Stochastically traverse the tree to a leaf. for (int i = 0; i < num_data; i++) { - const Tensor point = input_data.Slice(i, i+1); + const Tensor point = input_data.Slice(i, i + 1); int32 node = 0; out_probability(i, 0) = 1.0; out_path(i, 0) = 0; for (int j = 0; j < tree_depth_ - 1; j++) { - int32 left_child = 2*node + 1; + int32 left_child = 2 * node + 1; int32 right_child = left_child + 1; - float left_prob = LeftProbability(point, - tree_parameters_tensor.Slice(j, j+1), - tree_biases(j), - num_features); + float left_prob = + LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1), + tree_biases(j), num_features); if (left_prob < rng_->RandFloat()) { CHECK_LT(i, num_data); diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc index 0b5afe464f4b9608af0feca584aaa799f5980f46..b0d8b832b5437db7a4b3026e80ae99d0391d7f7a 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/stochastic_hard_routing_gradient_op.cc @@ -149,14 +149,14 @@ class StochasticHardRoutingGradient : public OpKernel { TensorShape output_bias_shape; output_bias_shape.AddDim(num_data); - OP_REQUIRES_OK(context, context->allocate_output( - 0, output_routing_shape, &output_routing)); - OP_REQUIRES_OK(context, context->allocate_output( - 1, output_data_shape, &output_data)); - OP_REQUIRES_OK(context, context->allocate_output( - 2, output_parameters_shape, &output_parameters)); - OP_REQUIRES_OK(context, context->allocate_output( - 3, output_bias_shape, &output_bias)); + OP_REQUIRES_OK(context, context->allocate_output(0, output_routing_shape, + &output_routing)); + OP_REQUIRES_OK( + context, context->allocate_output(1, output_data_shape, &output_data)); + OP_REQUIRES_OK(context, context->allocate_output(2, output_parameters_shape, + &output_parameters)); + OP_REQUIRES_OK( + context, context->allocate_output(3, output_bias_shape, &output_bias)); tensorforest::Initialize(*output_routing, 0.0); tensorforest::Initialize(*output_data, 0.0); @@ -178,7 +178,7 @@ class StochasticHardRoutingGradient : public OpKernel { const Tensor point = input_data.Slice(i, i + 1); // Traverses the tree from the bottom up. - for (int j = tree_depth_-1; j > -1; j--) { + for (int j = tree_depth_ - 1; j > -1; j--) { int32 node = path(i, j); CHECK_LT(node, num_nodes); diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc index cacad03e274c3279eb3706e71e1bcdf8433ca1ef..25825a78a1498490009fe4ff6bbfc67493727037 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc @@ -64,8 +64,7 @@ REGISTER_OP("UnpackPath") class UnpackPath : public OpKernel { public: - explicit UnpackPath(OpKernelConstruction* context) - : OpKernel(context) {} + explicit UnpackPath(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { VLOG(1) << "unpack start"; @@ -73,8 +72,8 @@ class UnpackPath : public OpKernel { const Tensor& path_values_tensor = context->input(1); const int32 num_data = static_cast(path_tensor.shape().dim_size(0)); - const int32 tree_depth = static_cast( - path_tensor.shape().dim_size(1)); + const int32 tree_depth = + static_cast(path_tensor.shape().dim_size(1)); const int32 num_nodes = MathUtil::IPow(2, tree_depth) - 1; @@ -107,7 +106,6 @@ class UnpackPath : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("UnpackPath").Device(DEVICE_CPU), - UnpackPath); +REGISTER_KERNEL_BUILDER(Name("UnpackPath").Device(DEVICE_CPU), UnpackPath); } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc index c091a73c4e48a47bdccea3ec99371faab9c586c2..34388fe1aab72895a805141ec66a71ecf0f42ba4 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc @@ -25,9 +25,7 @@ namespace tensorforest { using tensorflow::Tensor; -float LeftProbability(const Tensor& point, - const Tensor& weight, - float bias, +float LeftProbability(const Tensor& point, const Tensor& weight, float bias, int num_features) { const auto p = point.unaligned_flat(); const auto w = weight.unaligned_flat(); @@ -41,11 +39,8 @@ float LeftProbability(const Tensor& point, return 1.0 / (1.0 + exp(-dot_product + bias)); } -float LeftProbabilityK(const Tensor& point, - std::vector feature_set, - const Tensor& weight, - float bias, - int num_features, +float LeftProbabilityK(const Tensor& point, std::vector feature_set, + const Tensor& weight, float bias, int num_features, int k) { const auto p = point.unaligned_flat(); const auto w = weight.unaligned_flat(); diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h index c5902184f95ea8f97be4a10d1101a38333359d44..69a0143a4e319157a4526ca80fbb3f6472902b31 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h @@ -24,16 +24,11 @@ namespace tensorflow { namespace tensorforest { // Returns the probability that the point falls to the left. -float LeftProbability(const Tensor& point, - const Tensor& weight, - float bias, +float LeftProbability(const Tensor& point, const Tensor& weight, float bias, int num_features); -float LeftProbabilityK(const Tensor& point, - std::vector feature_set, - const Tensor& weight, - float bias, - int num_features, +float LeftProbabilityK(const Tensor& point, std::vector feature_set, + const Tensor& weight, float bias, int num_features, int k); // Returns a random set of num_features_to_pick features in the @@ -49,5 +44,3 @@ void GetFeatureSet(int32 tree_num, int32 node_num, int32 random_seed, } // namespace tensorflow #endif // LEARNING_LIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_ - - diff --git a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc index 47b49a379c4b7a17d35b52c1403f67c2f07aeeaf..b21a9179777c21f65435e136aa6082e27fb3b78c 100644 --- a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc +++ b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc @@ -30,15 +30,13 @@ namespace tensorflow { using tensorforest::CheckTensorBounds; - float Convert(const string& in) { const std::size_t intval = std::hash()(in); return static_cast(intval); } - -void Evaluate(const Tensor& input_data, Tensor output_data, - int32 start, int32 end) { +void Evaluate(const Tensor& input_data, Tensor output_data, int32 start, + int32 end) { auto out_data = output_data.unaligned_flat(); const auto in_data = input_data.unaligned_flat(); @@ -59,9 +57,8 @@ class ReinterpretStringToFloat : public OpKernel { if (!CheckTensorBounds(context, input_data)) return; Tensor* output_data = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, input_data.shape(), - &output_data)); + OP_REQUIRES_OK( + context, context->allocate_output(0, input_data.shape(), &output_data)); // Evaluate input data in parallel. const int32 num_data = static_cast(input_data.NumElements()); @@ -73,8 +70,8 @@ class ReinterpretStringToFloat : public OpKernel { auto work = [&input_data, output_data, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); - Evaluate(input_data, *output_data, - static_cast(start), static_cast(end)); + Evaluate(input_data, *output_data, static_cast(start), + static_cast(end)); }; Shard(num_threads, worker_threads->workers, num_data, 100, work); } diff --git a/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc b/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc index dd2a98b08cdb486c98c161390a3a1f81d31e1f4b..60740c2be3703141805c7eae0ac384edf934ab3d 100644 --- a/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc +++ b/tensorflow/contrib/tensor_forest/kernels/scatter_add_ndim_op.cc @@ -22,7 +22,6 @@ #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/platform/logging.h" - namespace tensorflow { using tensorforest::CheckTensorBounds; @@ -38,20 +37,19 @@ class ScatterAddNdim : public OpKernel { if (indices_tensor.shape().dim_size(0) > 0) { OP_REQUIRES(context, indices_tensor.shape().dims() == 2, - errors::InvalidArgument( - "indices should be two-dimensional")); + errors::InvalidArgument("indices should be two-dimensional")); const int32 delta_dims = deltas_tensor.shape().dims(); OP_REQUIRES( context, indices_tensor.shape().dim_size(1) + delta_dims == - input_tensor.shape().dims() + 1, + input_tensor.shape().dims() + 1, errors::InvalidArgument( "Number of indices dimensions should be the same as input " "rank.")); OP_REQUIRES( context, indices_tensor.shape().dim_size(0) == - deltas_tensor.shape().dim_size(0), + deltas_tensor.shape().dim_size(0), errors::InvalidArgument( "Number of updates should be same as number of indices.")); } else { @@ -68,8 +66,8 @@ class ScatterAddNdim : public OpKernel { const auto indices = indices_tensor.tensor(); const auto deltas = deltas_tensor.unaligned_flat(); - const int32 num_dims = static_cast( - indices_tensor.shape().dim_size(1)); + const int32 num_dims = + static_cast(indices_tensor.shape().dim_size(1)); // Figure out if indices don't specify a complete position in the // input tensor. @@ -80,10 +78,9 @@ class ScatterAddNdim : public OpKernel { // Calculate index multipliers. std::vector multipliers; - OP_REQUIRES( - context, input.size() < std::numeric_limits::max(), - errors::InvalidArgument( - "Input must contain less than 2^31 total elements")); + OP_REQUIRES(context, input.size() < std::numeric_limits::max(), + errors::InvalidArgument( + "Input must contain less than 2^31 total elements")); int32 last_size = static_cast(input.size()); for (int32 j = 0; j < num_dims; j++) { diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc b/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc index 94e12cea5a072f0746e642196d55f3a3b13a16c3..44997ec5d6d5fdb9aab52ab7a50f46a731bfda66 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc +++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc @@ -65,8 +65,8 @@ void GetTwoBest(int max, const std::function& score_fn, float ClassificationSplitScore( const Eigen::Tensor& splits, - const Eigen::Tensor& rights, - int32 num_classes, int i) { + const Eigen::Tensor& rights, int32 num_classes, + int i) { Eigen::array offsets; // Class counts are stored with the total in [0], so the length of each // count vector is num_classes + 1. @@ -74,7 +74,7 @@ float ClassificationSplitScore( Eigen::array extents; extents[0] = num_classes; return WeightedGiniImpurity(splits.slice(offsets, extents)) + - WeightedGiniImpurity(rights.slice(offsets, extents)); + WeightedGiniImpurity(rights.slice(offsets, extents)); } void GetTwoBestClassification(const Tensor& total_counts, @@ -90,29 +90,28 @@ void GetTwoBestClassification(const Tensor& total_counts, // in seg faults, so we have to go with flat views of these tensors. However, // it is still pretty efficient because we put off evaluation until the // score is actually returned. - const auto tc = total_counts.Slice( - accumulator, accumulator + 1).unaligned_flat(); + const auto tc = + total_counts.Slice(accumulator, accumulator + 1).unaligned_flat(); // TODO(gilberth): See if we can delay evaluation here by templating the // arguments to ClassificationSplitScore. - const Eigen::Tensor splits = split_counts.Slice( - accumulator, accumulator + 1).unaligned_flat(); + const Eigen::Tensor splits = + split_counts.Slice(accumulator, accumulator + 1).unaligned_flat(); Eigen::array bcast; bcast[0] = num_splits; const Eigen::Tensor rights = tc.broadcast(bcast) - splits; - std::function score_fn = std::bind( - ClassificationSplitScore, splits, rights, num_classes, - std::placeholders::_1); + std::function score_fn = + std::bind(ClassificationSplitScore, splits, rights, num_classes, + std::placeholders::_1); GetTwoBest(num_splits, score_fn, best_score, best_index, second_best_score, second_best_index); } -int32 BestFeatureClassification( - const Tensor& total_counts, const Tensor& split_counts, - int32 accumulator) { +int32 BestFeatureClassification(const Tensor& total_counts, + const Tensor& split_counts, int32 accumulator) { float best_score; float second_best_score; int best_feature_index; @@ -130,8 +129,7 @@ float RegressionSplitScore( const Eigen::Tensor& splits_square, const Eigen::Tensor& right_sums, const Eigen::Tensor& right_squares, - int32 accumulator, - int32 num_regression_dims, int i) { + int32 accumulator, int32 num_regression_dims, int i) { Eigen::array offsets = {i * num_regression_dims + 1}; Eigen::array extents = {num_regression_dims - 1}; float left_count = splits_count_accessor(accumulator, i, 0); @@ -141,15 +139,15 @@ float RegressionSplitScore( // Guard against divide-by-zero. if (left_count > 0) { - score += WeightedVariance( - splits_sum.slice(offsets, extents), - splits_square.slice(offsets, extents), left_count); + score += + WeightedVariance(splits_sum.slice(offsets, extents), + splits_square.slice(offsets, extents), left_count); } if (right_count > 0) { - score += WeightedVariance(right_sums.slice(offsets, extents), - right_squares.slice(offsets, extents), - right_count); + score += + WeightedVariance(right_sums.slice(offsets, extents), + right_squares.slice(offsets, extents), right_count); } return score; } @@ -159,20 +157,20 @@ void GetTwoBestRegression(const Tensor& total_sums, const Tensor& total_squares, int32 accumulator, float* best_score, int* best_index, float* second_best_score, int* second_best_index) { const int32 num_splits = static_cast(split_sums.shape().dim_size(1)); - const int32 num_regression_dims = static_cast( - split_sums.shape().dim_size(2)); + const int32 num_regression_dims = + static_cast(split_sums.shape().dim_size(2)); // Ideally, Eigen::Tensor::chip would be best to use here but it results // in seg faults, so we have to go with flat views of these tensors. However, // it is still pretty efficient because we put off evaluation until the // score is actually returned. - const auto tc_sum = total_sums.Slice( - accumulator, accumulator + 1).unaligned_flat(); - const auto tc_square = total_squares.Slice( - accumulator, accumulator + 1).unaligned_flat(); - const auto splits_sum = split_sums.Slice( - accumulator, accumulator + 1).unaligned_flat(); - const auto splits_square = split_squares.Slice( - accumulator, accumulator + 1).unaligned_flat(); + const auto tc_sum = + total_sums.Slice(accumulator, accumulator + 1).unaligned_flat(); + const auto tc_square = + total_squares.Slice(accumulator, accumulator + 1).unaligned_flat(); + const auto splits_sum = + split_sums.Slice(accumulator, accumulator + 1).unaligned_flat(); + const auto splits_square = + split_squares.Slice(accumulator, accumulator + 1).unaligned_flat(); // Eigen is infuriating to work with, usually resulting in all kinds of // unhelpful compiler errors when trying something that seems sane. This // helps us do a simple thing like access the first element (the counts) @@ -193,10 +191,10 @@ void GetTwoBestRegression(const Tensor& total_sums, const Tensor& total_squares, best_score, best_index, second_best_score, second_best_index); } -int32 BestFeatureRegression( - const Tensor& total_sums, const Tensor& total_squares, - const Tensor& split_sums, const Tensor& split_squares, - int32 accumulator) { +int32 BestFeatureRegression(const Tensor& total_sums, + const Tensor& total_squares, + const Tensor& split_sums, + const Tensor& split_squares, int32 accumulator) { float best_score; float second_best_score; int best_feature_index; @@ -207,10 +205,11 @@ int32 BestFeatureRegression( return best_feature_index; } -bool BestSplitDominatesRegression( - const Tensor& total_sums, const Tensor& total_squares, - const Tensor& split_sums, const Tensor& split_squares, - int32 accumulator) { +bool BestSplitDominatesRegression(const Tensor& total_sums, + const Tensor& total_squares, + const Tensor& split_sums, + const Tensor& split_squares, + int32 accumulator) { // TODO(thomaswc): Implement this, probably as part of v3. return false; } @@ -599,7 +598,6 @@ bool Decide(float value, float bias, DataColumnTypes type) { } } - void GetParentWeightedMean(float leaf_sum, const float* leaf_data, float parent_sum, const float* parent_data, float valid_leaf_threshold, int num_outputs, diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h index dad9df4898844eaa17bdfe5b4b298a95377fd12e..edbac6700677633cbd4d41f7040b4859ca599c4a 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h +++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h @@ -45,13 +45,10 @@ const int32 LEAF_NODE = -1; const int32 FREE_NODE = -2; // Used to indicate column types, e.g. categorical vs. float -enum DataColumnTypes { - kDataFloat = 0, - kDataCategorical = 1 -}; +enum DataColumnTypes { kDataFloat = 0, kDataCategorical = 1 }; // Calculates the sum of a tensor. -template +template T Sum(Tensor counts) { Eigen::Tensor count_sum = counts.unaligned_flat().sum(); @@ -97,7 +94,7 @@ float WeightedGiniImpurity(const T& counts) { return RawWeightedGiniImpurity(smoothed); } -template +template float WeightedVariance(const T1& sums, const T2& squares, float count) { const auto e_x = sums / count; const auto e_x2 = squares / count; @@ -120,10 +117,11 @@ int32 BestFeatureRegression(const Tensor& total_sums, // Returns true if the best split's variance is sufficiently smaller than // that of the next best split. -bool BestSplitDominatesRegression( - const Tensor& total_sums, const Tensor& total_squares, - const Tensor& split_sums, const Tensor& split_squares, - int32 accumulator); +bool BestSplitDominatesRegression(const Tensor& total_sums, + const Tensor& total_squares, + const Tensor& split_sums, + const Tensor& split_squares, + int32 accumulator); // Performs booststrap_samples bootstrap samples of the best split's class // counts and the second best splits's class counts, and returns true if at @@ -178,10 +176,8 @@ bool DecideNode(const GetFeatureFnType& get_dense, // isn't present in sparse_input_indices. sparse_input_indices is assumed // to be sorted. template -float FindSparseValue( - const T1& sparse_input_indices, - const T2& sparse_input_values, - int32 i, int32 j) { +float FindSparseValue(const T1& sparse_input_indices, + const T2& sparse_input_values, int32 i, int32 j) { int32 low = 0; int32 high = sparse_input_values.dimension(0); while (low < high) { @@ -273,7 +269,6 @@ int32 GetNumSparseFeatures(const T1& indices, int32 input_index, // categorical data, it is value != bias. bool Decide(float value, float bias, DataColumnTypes type = kDataFloat); - // Returns true if all the splits are initialized. Since they get initialized // in order, we can simply infer this from the last split. // This should only be called for a single allocator's candidate features diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc b/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc index 7485a695dfba93fd3f57c19096b205b10e2fa8b5..08553545502c21eb8f2d68bfd342f8ba7c081adb 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils_test.cc @@ -44,11 +44,13 @@ TEST(TestWeightedVariance, Basic) { Tensor squares = test::AsTensor({29, 12}, {2}); EXPECT_FLOAT_EQ(WeightedVariance(sums.unaligned_flat(), - squares.unaligned_flat(), 3), 2.0); + squares.unaligned_flat(), 3), + 2.0); Tensor zero = test::AsTensor({0}, {1}); EXPECT_FLOAT_EQ(WeightedVariance(zero.unaligned_flat(), - zero.unaligned_flat(), 1), 0); + zero.unaligned_flat(), 1), + 0); } TEST(TestInitialize, Basic) { @@ -94,17 +96,16 @@ TEST(BestFeatureClassification, Basic) { const int32 num_accumulators = 4; const int32 num_splits = 3; const int32 num_classes = 4; - Tensor totals = test::AsTensor({1, 5, 6, 7, - 0, 0, 0, 0, - 30, 10, 10, 10, // this one - -1, -1, -1, -1}, - {num_accumulators, num_classes}); - Tensor splits = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 30, 10, 10, 10, 10, 0, 0, 10, 19, 5, 6, 8, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, - {num_accumulators, num_splits, num_classes}); + Tensor totals = test::AsTensor( + {1, 5, 6, 7, 0, 0, 0, 0, 30, 10, 10, 10, // this one + -1, -1, -1, -1}, + {num_accumulators, num_classes}); + Tensor splits = + test::AsTensor({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 10, + 10, 10, 10, 0, 0, 10, 19, 5, 6, 8, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {num_accumulators, num_splits, num_classes}); EXPECT_EQ(BestFeatureClassification(totals, splits, 2), 1); } @@ -114,17 +115,16 @@ TEST(BestFeatureClassification, NoWinner) { const int32 num_splits = 3; const int32 num_classes = 4; // When counts are all the same, the most reasonable thing to do is pick 0. - Tensor totals = test::AsTensor({1, 5, 6, 7, - 0, 0, 0, 0, - 18, 6, 6, 6, // this one - -1, -1, -1, -1}, - {num_accumulators, num_classes}); - Tensor splits = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 9, 3, 3, 3, 9, 3, 3, 3, 9, 3, 3, 3, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, - {num_accumulators, num_splits, num_classes}); + Tensor totals = + test::AsTensor({1, 5, 6, 7, 0, 0, 0, 0, 18, 6, 6, 6, // this one + -1, -1, -1, -1}, + {num_accumulators, num_classes}); + Tensor splits = + test::AsTensor({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 3, + 3, 3, 9, 3, 3, 3, 9, 3, 3, 3, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {num_accumulators, num_splits, num_classes}); EXPECT_EQ(BestFeatureClassification(totals, splits, 2), 0); } @@ -133,36 +133,34 @@ TEST(BestFeatureRegression, Basic) { const int32 num_accumulators = 4; const int32 num_splits = 3; const int32 num_classes = 4; - Tensor total_sums = test::AsTensor( - {1, 5, 6, 7, - 0, 0, 0, 0, - 10, 8, 6, 9, // this one - -1, -1, -1, -1}, - {num_accumulators, num_classes}); + Tensor total_sums = + test::AsTensor({1, 5, 6, 7, 0, 0, 0, 0, 10, 8, 6, 9, // this one + -1, -1, -1, -1}, + {num_accumulators, num_classes}); Tensor total_squares = test::AsTensor( - {1, 5, 6, 7, - 0, 0, 0, 0, - 100, 50, 40, 45, // this one + {1, 5, 6, 7, 0, 0, 0, 0, 100, 50, 40, 45, // this one -1, -1, -1, -1}, {num_accumulators, num_classes}); - Tensor split_sums = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 10, 8, 6, 9, 9, 8, 5, 9, 0, 0, 0, 0, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, - {num_accumulators, num_splits, num_classes}); + Tensor split_sums = + test::AsTensor({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 8, + 6, 9, 9, 8, 5, 9, 0, 0, 0, 0, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {num_accumulators, num_splits, num_classes}); // lower the variance by lowering one of the squares just a little. - Tensor split_squares = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 100, 50, 40, 45, 100, 50, 40, 43, 0, 0, 0, 0, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, - {num_accumulators, num_splits, num_classes}); + Tensor split_squares = + test::AsTensor( + {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 100, 50, 40, 45, 100, 50, 40, 43, 0, 0, 0, 0, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {num_accumulators, num_splits, num_classes}); EXPECT_EQ(BestFeatureRegression(total_sums, total_squares, split_sums, - split_squares, 2), 1); + split_squares, 2), + 1); } TEST(BestFeatureRegression, NoWinner) { @@ -170,37 +168,33 @@ TEST(BestFeatureRegression, NoWinner) { const int32 num_splits = 3; const int32 num_classes = 4; // when counts are all the same, the most reasonable thing to do is pick 0. - Tensor total_sums = test::AsTensor( - {1, 5, 6, 7, - 0, 0, 0, 0, - 10, 8, 6, 9, // this one - -1, -1, -1, -1}, - {num_accumulators, num_classes}); + Tensor total_sums = + test::AsTensor({1, 5, 6, 7, 0, 0, 0, 0, 10, 8, 6, 9, // this one + -1, -1, -1, -1}, + {num_accumulators, num_classes}); Tensor total_squares = test::AsTensor( - {1, 5, 6, 7, - 0, 0, 0, 0, - 100, 50, 40, 45, // this one + {1, 5, 6, 7, 0, 0, 0, 0, 100, 50, 40, 45, // this one -1, -1, -1, -1}, {num_accumulators, num_classes}); - Tensor split_sums = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 10, 8, 6, 9, 10, 8, 6, 9, 10, 8, 6, 9, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, - {num_accumulators, num_splits, num_classes}); + Tensor split_sums = + test::AsTensor({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 8, + 6, 9, 10, 8, 6, 9, 10, 8, 6, 9, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {num_accumulators, num_splits, num_classes}); Tensor split_squares = test::AsTensor( - {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 100, 50, 40, 45, 100, 50, 40, 45, 100, 50, 40, 45, // this one - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, + {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 100, 50, 40, 45, 100, 50, 40, 45, 100, 50, 40, 45, // this one + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}, {num_accumulators, num_splits, num_classes}); EXPECT_EQ(BestFeatureRegression(total_sums, total_squares, split_sums, - split_squares, 2), 0); + split_squares, 2), + 0); } } // namespace tensorforest } // namespace tensorflow - diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc index 81e2a1b2a1b720574210e376fa786923367794a6..f4a7058ddb8bfdd6393a9369006aabc29d058d3b 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc @@ -14,8 +14,8 @@ // ============================================================================= #include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" namespace tensorflow { @@ -58,8 +58,7 @@ CandidateGraphRunner::CandidateGraphRunner( // Features don't change, store them in a tensor. const auto& oblique = split.inequality_left_child_test().oblique(); const int32 feat_size = oblique.features_size(); - features_.reset( - new Tensor(tensorflow::DT_INT32, TensorShape({feat_size}))); + features_.reset(new Tensor(tensorflow::DT_INT32, TensorShape({feat_size}))); auto feat = features_->flat(); int i = 0; for (const auto& id : oblique.features()) { @@ -67,10 +66,10 @@ CandidateGraphRunner::CandidateGraphRunner( } } -void CandidateGraphRunner::RunOp( - const string& name, const TensorNameValueList& inputs, - const std::vector& output_tensor_names, - std::vector* outputs) { +void CandidateGraphRunner::RunOp(const string& name, + const TensorNameValueList& inputs, + const std::vector& output_tensor_names, + std::vector* outputs) { std::vector op_name; if (name != kNoOp) { op_name.push_back(name); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h index cced26b9036ba8ba6c5994b7483261a062f80588..328af28725af016e90b30ae2d303ffba15c81c1f 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h @@ -26,7 +26,6 @@ namespace tensorflow { namespace tensorforest { - // Keep a tree ensemble in memory for efficient evaluation and mutation. class DecisionTreeResource : public ResourceBase { public: @@ -35,15 +34,12 @@ class DecisionTreeResource : public ResourceBase { string DebugString() override { return strings::StrCat("DecisionTree[size=", - decision_tree_->decision_tree().nodes_size(), - "]"); + decision_tree_->decision_tree().nodes_size(), "]"); } void MaybeInitialize(); - const decision_trees::Model& decision_tree() const { - return *decision_tree_; - } + const decision_trees::Model& decision_tree() const { return *decision_tree_; } decision_trees::Model* mutable_decision_tree() { return decision_tree_.get(); @@ -59,9 +55,7 @@ class DecisionTreeResource : public ResourceBase { // Resets the resource and frees the proto. // Caller needs to hold the mutex lock while calling this. - void Reset() { - decision_tree_.reset(new decision_trees::Model()); - } + void Reset() { decision_tree_.reset(new decision_trees::Model()); } mutex* get_mutex() { return &mu_; } @@ -84,7 +78,6 @@ class DecisionTreeResource : public ResourceBase { std::vector> node_evaluators_; }; - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h index 85ce7b825b11983307370bb3ac30eeec9b6b2c99..bf2b2aaa3c8f433ab4fc145217857112f7a0a579 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h @@ -22,7 +22,6 @@ namespace tensorflow { namespace tensorforest { - // Base class for evaluators of decision nodes that effectively copy proto // contents into C++ structures for faster execution. class DecisionNodeEvaluator { diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc index 5c49b87443e7b1f4ef532256ae2efdc9fa985d8a..af5cf72a3c0bea0eef45c3446acf52ff389c6751 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc @@ -20,11 +20,11 @@ namespace tensorflow { namespace { +using tensorflow::decision_trees::InequalityTest; +using tensorflow::decision_trees::MatchingValuesTest; using tensorflow::tensorforest::InequalityDecisionNodeEvaluator; using tensorflow::tensorforest::MatchingValuesDecisionNodeEvaluator; using tensorflow::tensorforest::ObliqueInequalityDecisionNodeEvaluator; -using tensorflow::decision_trees::InequalityTest; -using tensorflow::decision_trees::MatchingValuesTest; TEST(InequalityDecisionNodeEvaluatorTest, TestLessOrEqual) { InequalityTest test; @@ -124,4 +124,3 @@ TEST(ObliqueDecisionNodeEvaluatorTest, Basic) { } // namespace } // namespace tensorflow - diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h index 0d6712e9e552d7045eb198f7e65d04eb42eff920..eea0be27caf0a022ba7acaacd359c75a2df4eedb 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h @@ -40,9 +40,7 @@ class FertileStatsResource : public ResourceBase { model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_); } - string DebugString() override { - return "FertileStats"; - } + string DebugString() override { return "FertileStats"; } void ExtractFromProto(const FertileStats& stats); @@ -50,8 +48,7 @@ class FertileStatsResource : public ResourceBase { // Resets the resource and frees the proto. // Caller needs to hold the mutex lock while calling this. - void Reset() { - } + void Reset() {} // Reset the stats for a node, but leave the leaf_stats intact. void ResetSplitStats(int32 node_id, int32 depth) { @@ -84,7 +81,6 @@ class FertileStatsResource : public ResourceBase { // was found. bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth); - private: mutex mu_; std::shared_ptr model_op_; @@ -94,7 +90,6 @@ class FertileStatsResource : public ResourceBase { void AllocateNode(int32 node_id, int32 depth); }; - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc index 3ce630e3a9691b87ad291a9f29616f741953dd84..da600d34eacdf27514709240723e5bb730cfe7f0 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc @@ -20,7 +20,6 @@ #include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h" #include "tensorflow/core/lib/random/distribution_sampler.h" - namespace tensorflow { namespace tensorforest { @@ -454,14 +453,14 @@ void DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const { class_stats->add_value()->set_float_value(total_counts_[i]); } - for (int split_num = 0; split_num < num_splits(); ++split_num) { + for (int split_num = 0; split_num < num_splits(); ++split_num) { auto* cand = slot->add_candidates(); *cand->mutable_split() = splits_[split_num]; auto* left_stats = cand->mutable_left_stats() ->mutable_classification() ->mutable_dense_counts(); for (int i = 0; i < num_outputs_; ++i) { - left_stats->add_value()->set_float_value(left_count(split_num, i)); + left_stats->add_value()->set_float_value(left_count(split_num, i)); } } } @@ -546,7 +545,7 @@ void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const { (*class_stats)[entry.first] = val; } - for (int split_num = 0; split_num < num_splits(); ++split_num) { + for (int split_num = 0; split_num < num_splits(); ++split_num) { auto* cand = slot->add_candidates(); *cand->mutable_split() = splits_[split_num]; auto* left_stats = cand->mutable_left_stats() @@ -561,8 +560,8 @@ void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const { } } -float SparseClassificationGrowStats::GiniScore( - int split, float* left_sum, float* right_sum) const { +float SparseClassificationGrowStats::GiniScore(int split, float* left_sum, + float* right_sum) const { float left_square = 0, right_square = 0; *left_sum = 0; *right_sum = 0; @@ -844,12 +843,11 @@ void LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const { total_squares->add_value()->set_float_value(total_sum_squares_[i]); } - for (int split_num = 0; split_num < num_splits(); ++split_num) { + for (int split_num = 0; split_num < num_splits(); ++split_num) { auto* cand = slot->add_candidates(); *cand->mutable_split() = splits_[split_num]; - auto* sums = cand->mutable_left_stats() - ->mutable_regression() - ->mutable_mean_output(); + auto* sums = + cand->mutable_left_stats()->mutable_regression()->mutable_mean_output(); auto* squares = cand->mutable_left_stats() ->mutable_regression() ->mutable_mean_output_squares(); @@ -891,20 +889,17 @@ float LeastSquaresRegressionGrowStats::SplitVariance(int split) const { float total_variance = 0; for (int i = 0; i < params_.num_outputs(); ++i) { // Left side - const float le_x = - left_sum(split, i) / left_counts_[split]; + const float le_x = left_sum(split, i) / left_counts_[split]; - const float le_x2 = - left_square(split, i) / left_counts_[split]; + const float le_x2 = left_square(split, i) / left_counts_[split]; total_variance += le_x2 - le_x * le_x; // Right side const float re_x = (total_sum_[i] - left_sum(split, i)) / (weight_sum_ - left_counts_[split]); - const float re_x2 = - (total_sum_squares_[i] - left_square(split, i)) / - (weight_sum_ - left_counts_[split]); + const float re_x2 = (total_sum_squares_[i] - left_square(split, i)) / + (weight_sum_ - left_counts_[split]); total_variance += re_x2 - re_x * re_x; } return total_variance; @@ -937,8 +932,7 @@ bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const { left->set_weight_sum(left_counts_[best_index]); auto* left_output_sum = left_reg_stats->mutable_mean_output(); for (int i = 0; i < num_outputs; ++i) { - left_output_sum->add_value()->set_float_value( - left_sum(best_index, i)); + left_output_sum->add_value()->set_float_value(left_sum(best_index, i)); } // Right @@ -947,8 +941,8 @@ bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const { right->set_weight_sum(weight_sum_ - left_counts_[best_index]); auto* right_output_sum = right_reg_stats->mutable_mean_output(); for (int i = 0; i < num_outputs; ++i) { - right_output_sum->add_value()->set_float_value( - total_sum_[i] - left_sum(best_index, i)); + right_output_sum->add_value()->set_float_value(total_sum_[i] - + left_sum(best_index, i)); } return true; } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h index f938d08c84d72b4c5a71e8f7fb0f639aa70e3e49..dc3e9fe79d32a19930d500b62b520eddb4b41aa8 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h @@ -73,21 +73,15 @@ class GrowStats { const InputTarget* target, int example) {} void RemoveSplit(int split_num); - int num_splits() const { - return splits_.size(); - } + int num_splits() const { return splits_.size(); } - float weight_sum() const { - return weight_sum_; - } + float weight_sum() const { return weight_sum_; } virtual bool IsInitialized() const { return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_; } - int32 depth() const { - return depth_; - } + int32 depth() const { return depth_; } protected: GrowStats(const TensorForestParams& params, int32 depth); @@ -206,8 +200,8 @@ class ClassificationStats : public GrowStats { virtual float left_count(int split, int class_num) const = 0; virtual float right_count(int split, int class_num) const = 0; - virtual void ClassificationAddLeftExample( - int split, int64 int_label, float weight) = 0; + virtual void ClassificationAddLeftExample(int split, int64 int_label, + float weight) = 0; virtual void ClassificationAddRightExample(int split, int64 int_label, float weight) { // Does nothing by default, but sub-classes can override. @@ -316,7 +310,7 @@ class DenseClassificationGrowStats : public ClassificationStats { void PackToProto(FertileSlot* slot) const override; void InitLeafClassStats(int best_split_index, LeafStat* left_stats, - LeafStat* right_stats) const; + LeafStat* right_stats) const override; protected: void ClassificationAddSplitStats() override { @@ -375,15 +369,13 @@ class SparseClassificationGrowStats : public ClassificationStats { SparseClassificationGrowStats(const TensorForestParams& params, int32 depth) : ClassificationStats(params, depth) {} - void Initialize() override { - Clear(); - } + void Initialize() override { Clear(); } void ExtractFromProto(const FertileSlot& slot) override; void PackToProto(FertileSlot* slot) const override; void InitLeafClassStats(int best_split_index, LeafStat* left_stats, - LeafStat* right_stats) const; + LeafStat* right_stats) const override; protected: void ClassificationAddSplitStats() override { @@ -476,7 +468,7 @@ class FixedSizeSparseClassificationGrowStats : public ClassificationStats { void PackToProto(FertileSlot* slot) const override; void InitLeafClassStats(int best_split_index, LeafStat* left_stats, - LeafStat* right_stats) const; + LeafStat* right_stats) const override; protected: void ClassificationAddSplitStats() override { @@ -562,9 +554,9 @@ class LeastSquaresRegressionGrowStats : public GrowStats { } void RemoveSplitStats(int split_num) override { left_sums_.erase(left_sums_.begin() + num_outputs_ * split_num, - left_sums_.begin() + num_outputs_ * (split_num + 1)); + left_sums_.begin() + num_outputs_ * (split_num + 1)); left_squares_.erase(left_squares_.begin() + num_outputs_ * split_num, - left_squares_.begin() + num_outputs_ * (split_num + 1)); + left_squares_.begin() + num_outputs_ * (split_num + 1)); left_counts_.erase(left_counts_.begin() + split_num, left_counts_.begin() + (split_num + 1)); } @@ -605,7 +597,6 @@ class LeastSquaresRegressionGrowStats : public GrowStats { std::vector left_counts_; }; - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc index ceb58d2ead5c2f148c96d9cb9532a73688593d33..26e989928e00de1b2ae1646abf216adfbec2be4f 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc @@ -24,21 +24,21 @@ namespace tensorflow { namespace { -using tensorflow::tensorforest::GrowStats; -using tensorflow::tensorforest::TestableInputTarget; -using tensorflow::tensorforest::FertileSlot; +using tensorflow::decision_trees::BinaryNode; +using tensorflow::decision_trees::FeatureId; +using tensorflow::decision_trees::InequalityTest; using tensorflow::tensorforest::DenseClassificationGrowStats; -using tensorflow::tensorforest::SparseClassificationGrowStats; +using tensorflow::tensorforest::FertileSlot; using tensorflow::tensorforest::FixedSizeClassStats; using tensorflow::tensorforest::FixedSizeSparseClassificationGrowStats; +using tensorflow::tensorforest::GrowStats; using tensorflow::tensorforest::LeastSquaresRegressionGrowStats; -using tensorflow::tensorforest::TensorForestParams; +using tensorflow::tensorforest::SparseClassificationGrowStats; using tensorflow::tensorforest::SPLIT_FINISH_BASIC; using tensorflow::tensorforest::SPLIT_FINISH_DOMINATE_HOEFFDING; using tensorflow::tensorforest::SPLIT_PRUNE_HOEFFDING; -using tensorflow::decision_trees::BinaryNode; -using tensorflow::decision_trees::InequalityTest; -using tensorflow::decision_trees::FeatureId; +using tensorflow::tensorforest::TensorForestParams; +using tensorflow::tensorforest::TestableInputTarget; BinaryNode MakeSplit(const string& feat, float val) { BinaryNode split; @@ -52,8 +52,7 @@ BinaryNode MakeSplit(const string& feat, float val) { return split; } -void RunBatch(GrowStats* stats, - const TestableInputTarget* target) { +void RunBatch(GrowStats* stats, const TestableInputTarget* target) { std::unique_ptr dataset( new tensorflow::tensorforest::TestableDataSet( {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, 2)); @@ -102,18 +101,10 @@ class TestableRunningStats : public DenseClassificationGrowStats { TestableRunningStats(const TensorForestParams& params, int32 depth) : DenseClassificationGrowStats(params, depth) {} - float test_left_sum(int split) { - return get_left_gini()->sum(split); - } - float test_left_square(int split) { - return get_left_gini()->square(split); - } - float test_right_sum(int split) { - return get_right_gini()->sum(split); - } - float test_right_square(int split) { - return get_right_gini()->square(split); - } + float test_left_sum(int split) { return get_left_gini()->sum(split); } + float test_left_square(int split) { return get_left_gini()->square(split); } + float test_right_sum(int split) { return get_right_gini()->sum(split); } + float test_right_square(int split) { return get_right_gini()->square(split); } }; TEST(GrowStatsDenseClassificationTest, BasicRunningStats) { @@ -166,9 +157,7 @@ class TestableFinishEarly : public DenseClassificationGrowStats { int num_times_called_; protected: - void CheckFinishEarlyHoeffding() override { - ++num_times_called_; - } + void CheckFinishEarlyHoeffding() override { ++num_times_called_; } }; TEST(GrowStatsDenseClassificationTest, TestFinishEarly) { @@ -212,7 +201,6 @@ TEST(GrowStatsDenseClassificationTest, TestFinishEarly) { ASSERT_EQ(stat->num_times_called_, 9); } - TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) { TensorForestParams params; params.set_num_outputs(2); @@ -224,7 +212,8 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) { finish->set_type(SPLIT_FINISH_BASIC); finish->mutable_check_every_steps()->set_constant_value(100); params.mutable_pruning_type()->set_type(SPLIT_PRUNE_HOEFFDING); - params.mutable_pruning_type()->mutable_prune_every_samples() + params.mutable_pruning_type() + ->mutable_prune_every_samples() ->set_constant_value(1); // On each iteration, we add two examples, one of class 0 and one @@ -234,8 +223,8 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) { std::vector weights = {1, 1}; TestableInputTarget target(labels, weights, 1); std::unique_ptr dataset( - new tensorflow::tensorforest::TestableDataSet( - {-1.0, -1.0, 1.0, -1.0}, 2)); + new tensorflow::tensorforest::TestableDataSet({-1.0, -1.0, 1.0, -1.0}, + 2)); DenseClassificationGrowStats stats(params, 1); stats.Initialize(); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc index 14cb19d36f33e478728aba3e28b7bea11b691d34..d43884481afbbbc988d6eb80e01e49663df6914b 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc @@ -21,8 +21,6 @@ namespace tensorflow { namespace tensorforest { namespace { -const int32 SPARSE_DEFAULT = 0; - bool DecideInequalityTest(const decision_trees::InequalityTest& test, float value) { float bias = test.threshold().float_value(); @@ -111,10 +109,10 @@ void TensorDataSet::set_input_tensors(const Tensor& dense, dense_data_.reset(new DenseStorageType(dense.tensor())); } if (sparse_indices.shape().dims() == 2) { - sparse_indices_.reset(new SparseIndicesStorageType( - sparse_indices.tensor())); - sparse_values_.reset(new SparseValuesStorageType( - sparse_values.tensor())); + sparse_indices_.reset( + new SparseIndicesStorageType(sparse_indices.tensor())); + sparse_values_.reset( + new SparseValuesStorageType(sparse_values.tensor())); sparse_batch_size_ = sparse_shape.tensor()(0); } original_dense_tensor_ = dense; diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h index eafad6b591672f67ae816405ff603f9aaba30a1b..c544a8c75e9bfe8fe6bbea8913e7be17d868bfef 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h @@ -93,9 +93,7 @@ class TensorDataSet { // an int32 you can avoid the atoi32. virtual float GetExampleValue(int example, int32 feature_id) const; - int num_features() { - return available_features_.size(); - } + int num_features() { return available_features_.size(); } const Tensor& original_tensor() const { return original_dense_tensor_; } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h index 44ec09c50ef3d092bd1bf7f051f492e1fffdd05b..d4402b6055a36d38042a0e6cfa07b532ec11c093 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h @@ -79,9 +79,7 @@ class TensorInputTarget : public StoredInputTarget { return (*target_)(example_index * num_targets_ + target_index); } - const Tensor& original_tensor() const { - return original_tensor_; - } + const Tensor& original_tensor() const { return original_tensor_; } protected: Tensor original_tensor_; diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc index d43c068e462ff78b114fb29bd8cf0ee0c6080fcd..83614a25314117ef9ba29b4dcf6ebee8f7f3e226 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc @@ -160,6 +160,5 @@ void RegressionLeafModelOperator::ExportModel( } } - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc index ffd92c01f9a59719e6bb2458c2f28253c364a2e8..ab4191809b6a7400114acf85991c74acfac55505 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc @@ -26,19 +26,19 @@ namespace { using tensorflow::decision_trees::Leaf; using tensorflow::tensorforest::DenseClassificationLeafModelOperator; using tensorflow::tensorforest::LeafModelOperator; -using tensorflow::tensorforest::SparseClassificationLeafModelOperator; -using tensorflow::tensorforest::SparseOrDenseClassificationLeafModelOperator; using tensorflow::tensorforest::LeafStat; using tensorflow::tensorforest::RegressionLeafModelOperator; -using tensorflow::tensorforest::TestableInputTarget; +using tensorflow::tensorforest::SparseClassificationLeafModelOperator; +using tensorflow::tensorforest::SparseOrDenseClassificationLeafModelOperator; using tensorflow::tensorforest::TensorForestParams; +using tensorflow::tensorforest::TestableInputTarget; const int32 kNumClasses = 3; constexpr char kRegressionStatProto[] = - "weight_sum: 3 " - "regression { " - "mean_output { " + "weight_sum: 3 " + "regression { " + "mean_output { " "value { " " float_value: 27 " "} " @@ -48,8 +48,8 @@ constexpr char kRegressionStatProto[] = "value { " " float_value: 10 " "} " - "} " - "mean_output_squares { " + "} " + "mean_output_squares { " "value {" " float_value: 245" "}" @@ -59,8 +59,8 @@ constexpr char kRegressionStatProto[] = "value {" " float_value: 46" "}" - "}" -"}"; + "}" + "}"; void TestClassificationNormalUse(const std::unique_ptr& op) { Leaf l; @@ -83,7 +83,6 @@ void TestClassificationNormalUse(const std::unique_ptr& op) { EXPECT_FLOAT_EQ(op->GetOutputValue(l, 1), 3.4); } - TEST(DenseLeafModelOperatorsTest, NormalUse) { TensorForestParams params; params.set_num_outputs(kNumClasses); @@ -182,7 +181,7 @@ TEST(SparseLeafModelOperatorsTest, InitWithExisting) { std::unique_ptr leaf(new Leaf); - op->ExportModel( *stat, leaf.get()); + op->ExportModel(*stat, leaf.get()); // Make sure it was initialized correctly. EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 0), 1.1); @@ -194,7 +193,6 @@ TEST(SparseLeafModelOperatorsTest, InitWithExisting) { EXPECT_EQ(leaf->sparse_vector().sparse_value().size(), kNumClasses); } - TEST(RegressionLeafModelOperatorsTest, NormalUse) { TensorForestParams params; params.set_num_outputs(kNumClasses); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params.h b/tensorflow/contrib/tensor_forest/kernels/v4/params.h index b0ed949424756cc498d4b7ad1fb1867fff11b265..7583e3d0402a3a1d07f3696727b285747dc887de 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/params.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/params.h @@ -24,7 +24,6 @@ namespace tensorforest { // Return the value of the given depth-dependent parameter given a leaf's depth. float ResolveParam(const DepthDependentParam& param, int32 depth); - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc index 801881af1368dc33f00b356d12bea07ae3161ef6..4010a71006d58df0bec6d3686a9c47433b46fdd4 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc @@ -71,5 +71,3 @@ TEST(ParamsTest, TestThreshold) { } } // namespace - - diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc index cdb1d80a4bbd47d1481ecde2348bef500bd125f1..b7b60d0ab8c2670cec8b029d1f42c5edd3690afe 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc @@ -52,8 +52,8 @@ std::unique_ptr SplitCollectionOperator::CreateGrowStats( new SparseClassificationGrowStats(params_, depth)); case STATS_LEAST_SQUARES_REGRESSION: - return std::unique_ptr(new LeastSquaresRegressionGrowStats( - params_, depth)); + return std::unique_ptr( + new LeastSquaresRegressionGrowStats(params_, depth)); case STATS_FIXED_SIZE_SPARSE_GINI: return std::unique_ptr( @@ -136,8 +136,7 @@ void SplitCollectionOperator::CreateAndInitializeCandidateWithExample( stats_.at(node_id)->AddSplit(split, input_data, target, example); } -bool SplitCollectionOperator::BestSplit(int32 node_id, - SplitCandidate* best, +bool SplitCollectionOperator::BestSplit(int32 node_id, SplitCandidate* best, int32* depth) const { auto* slot = stats_.at(node_id).get(); *depth = slot->depth(); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h index ad52f89faddb15be77644b5dc374aca73c46b149..c606ff98c67f411a5817f0282238fdaf3be03642 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h @@ -71,9 +71,7 @@ class SplitCollectionOperator { } // Perform any necessary cleanup for any tracked state for the slot. - virtual void ClearSlot(int32 node_id) { - stats_.erase(node_id); - } + virtual void ClearSlot(int32 node_id) { stats_.erase(node_id); } // Return true if slot is fully initialized. virtual bool IsInitialized(int32 node_id) const; diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc index 0bec198e97e8215d2cfdb9ada5355dd5b0d2d97b..c749fbe69e17769c2f2b69bcf541eb0eb8b9e7e8 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc @@ -32,9 +32,9 @@ namespace tensorforest { // smoothed_sum = stats.sum() + #_classes float GiniImpurity(const LeafStat& stats, int32 num_classes) { const float smoothed_sum = num_classes + stats.weight_sum(); - return 1.0 - ( - (stats.classification().gini().square() - + 2 * stats.weight_sum() + num_classes) / (smoothed_sum * smoothed_sum)); + return 1.0 - ((stats.classification().gini().square() + + 2 * stats.weight_sum() + num_classes) / + (smoothed_sum * smoothed_sum)); } float WeightedGiniImpurity(const LeafStat& stats, int32 num_classes) { @@ -46,21 +46,20 @@ void UpdateGini(LeafStat* stats, float old_val, float weight) { // Equivalent to stats->square() - old_val * old_val + new_val * new_val, // (for new_val = old_val + weight), but more numerically stable. stats->mutable_classification()->mutable_gini()->set_square( - stats->classification().gini().square() - + weight * weight + 2 * old_val * weight); + stats->classification().gini().square() + weight * weight + + 2 * old_val * weight); } - float Variance(const LeafStat& stats, int output) { if (stats.weight_sum() == 0) { return 0; } const float e_x = - stats.regression().mean_output().value(output).float_value() - / stats.weight_sum(); + stats.regression().mean_output().value(output).float_value() / + stats.weight_sum(); const auto e_x2 = - stats.regression().mean_output_squares().value(output).float_value() - / stats.weight_sum(); + stats.regression().mean_output_squares().value(output).float_value() / + stats.weight_sum(); return e_x2 - e_x * e_x; } @@ -75,8 +74,7 @@ float TotalVariance(const LeafStat& stats) { float SmoothedGini(float sum, float square, int num_classes) { // See comments for GiniImpurity above. const float smoothed_sum = num_classes + sum; - return 1.0 - - (square + 2 * sum + num_classes) / (smoothed_sum * smoothed_sum); + return 1.0 - (square + 2 * sum + num_classes) / (smoothed_sum * smoothed_sum); } float WeightedSmoothedGini(float sum, float square, int num_classes) { diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h index 289c81e9d51dbc5d2023f7eabce8c2089748645d..38deb3e3cd816aae5fe66f26cd4b934316d03ce4 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h @@ -27,9 +27,7 @@ class TestableInputTarget : public StoredInputTarget> { : StoredInputTarget(new std::vector(t), new std::vector(w), num_t) {} - int NumItems() const { - return target_->size(); - } + int NumItems() const { return target_->size(); } int32 GetTargetAsClassIndex(int example_index, int target_index) const override { @@ -51,7 +49,6 @@ class TestableInputTarget : public StoredInputTarget> { } }; - class TestableDataSet : public TensorDataSet { public: TestableDataSet(const std::vector& data, int num_features) diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD index 6ff5a9e2b18ead9ea9f77f796b91b05d9b895489..4175d8adb58a85728519042a9870e8c4590232ba 100644 --- a/tensorflow/contrib/tensorboard/db/BUILD +++ b/tensorflow/contrib/tensorboard/db/BUILD @@ -40,7 +40,6 @@ cc_library( hdrs = ["summary_db_writer.h"], copts = tf_copts(), deps = [ - ":schema", ":summary_converter", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc index d891e86e53f4d760bfaea0e67601cfda037a4564..85b3e7231bcb433e9510522597c03c5f764f06cf 100644 --- a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc @@ -42,14 +42,14 @@ class SummaryFileWriter : public SummaryWriterInterface { if (is_dir.code() != tensorflow::error::NOT_FOUND) { return is_dir; } - TF_RETURN_IF_ERROR(env_->CreateDir(logdir)); + TF_RETURN_IF_ERROR(env_->RecursivelyCreateDir(logdir)); } mutex_lock ml(mu_); events_writer_ = tensorflow::MakeUnique(io::JoinPath(logdir, "events")); - if (!events_writer_->InitWithSuffix(filename_suffix)) { - return errors::Unknown("Could not initialize events writer."); - } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + events_writer_->InitWithSuffix(filename_suffix), + "Could not initialize events writer."); last_flush_ = env_->NowMicros(); is_initialized_ = true; return Status::OK(); @@ -151,9 +151,8 @@ class SummaryFileWriter : public SummaryWriterInterface { events_writer_->WriteEvent(*e); } queue_.clear(); - if (!events_writer_->Flush()) { - return errors::InvalidArgument("Could not flush events file."); - } + TF_RETURN_WITH_CONTEXT_IF_ERROR(events_writer_->Flush(), + "Could not flush events file."); last_flush_ = env_->NowMicros(); return Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..cf67c27b70f1a8c761b71074d3eb5cd962a68488 --- /dev/null +++ b/tensorflow/contrib/tensorrt/BUILD @@ -0,0 +1,246 @@ +# Description: +# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow +# and provide TensorRT operators and converter package. +# APIs are meant to change over time. + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_copts", + "tf_cuda_library", + "tf_custom_op_library", + "tf_custom_op_library_additional_deps", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", +) +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load( + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", +) + +tf_cuda_cc_test( + name = "tensorrt_test_cc", + size = "small", + srcs = ["tensorrt_test.cc"], + tags = [ + "manual", + "notap", + ], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_custom_op_library( + name = "python/ops/_trt_engine_op.so", + srcs = ["ops/trt_engine_op.cc"], + deps = [ + ":trt_engine_op_kernel", + ":trt_shape_function", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_cuda_library( + name = "trt_shape_function", + srcs = ["shape_fn/trt_shfn.cc"], + hdrs = ["shape_fn/trt_shfn.h"], + visibility = ["//visibility:public"], + deps = [ + ":trt_logging", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), +) + +cc_library( + name = "trt_engine_op_kernel", + srcs = ["kernels/trt_engine_op.cc"], + hdrs = ["kernels/trt_engine_op.h"], + copts = tf_copts(), + deps = [ + ":trt_logging", + "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:stream_executor_headers_lib", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), + alwayslink = 1, +) + +tf_gen_op_libs( + op_lib_names = ["trt_engine_op"], + deps = if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_cuda_library( + name = "trt_logging", + srcs = ["log/trt_logger.cc"], + hdrs = ["log/trt_logger.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_gen_op_wrapper_py( + name = "trt_engine_op", + deps = [ + ":trt_engine_op_op_lib", + ":trt_logging", + ":trt_shape_function", + ], +) + +tf_custom_op_py_library( + name = "trt_engine_op_loader", + srcs = ["python/ops/trt_engine_op.py"], + dso = [ + ":python/ops/_trt_engine_op.so", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:resources", + ], +) + +py_library( + name = "init_py", + srcs = [ + "__init__.py", + "python/__init__.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":trt_convert_py", + ":trt_ops_py", + ], +) + +py_library( + name = "trt_ops_py", + srcs_version = "PY2AND3", + deps = [ + ":trt_engine_op", + ":trt_engine_op_loader", + ], +) + +py_library( + name = "trt_convert_py", + srcs = ["python/trt_convert.py"], + srcs_version = "PY2AND3", + deps = [ + ":wrap_conversion", + ], +) + +tf_py_wrap_cc( + name = "wrap_conversion", + srcs = ["trt_conversion.i"], + copts = tf_copts(), + deps = [ + ":trt_conversion", + "//tensorflow/core:framework_lite", + "//util/python:python_headers", + ], +) + +# Library for the node-level conversion portion of TensorRT operation creation +tf_cuda_library( + name = "trt_conversion", + srcs = [ + "convert/convert_graph.cc", + "convert/convert_nodes.cc", + ], + hdrs = [ + "convert/convert_graph.h", + "convert/convert_nodes.h", + ], + deps = [ + ":segment", + ":trt_logging", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:devices", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/optimizers:constant_folding", + "//tensorflow/core/grappler/optimizers:layout_optimizer", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), +) + +# Library for the segmenting portion of TensorRT operation creation +cc_library( + name = "segment", + srcs = ["segment/segment.cc"], + hdrs = [ + "segment/segment.h", + "segment/union_find.h", + ], + linkstatic = 1, + deps = [ + "//tensorflow/core:graph", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "@protobuf_archive//:protobuf_headers", + ], +) + +tf_cc_test( + name = "segment_test", + size = "small", + srcs = ["segment/segment_test.cc"], + deps = [ + ":segment", + "//tensorflow/c:c_api", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..dfcce0fd00eedf3341850bbc23927dc3b2e2d2aa --- /dev/null +++ b/tensorflow/contrib/tensorrt/README.md @@ -0,0 +1,40 @@ +Using TensorRT in TensorFlow +============================ + +This module provides necessary bindings and introduces TRT_engine_op +operator that wraps a subgraph in TensorRT. + +Compilation +----------- + +In order to compile the module, you need to have a local TensorRT +installation (libnvinfer.so and respective include files). During the +configuration step, TensorRT should be enabled and installation path +should be set. If installed through package managers (deb,rpm), +configure script should find the necessary components from the system +automatically. If installed from tar packages, user has to set path to +location where the library is installed during configuration. + + +``` +bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_package +bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/ +``` + +After the installation of tensorflow package, TensorRT transformation +will be available. An example use is shown below. + +```python +import tensorflow as tf +import tensorflow.contrib.tensorrt as trt +#... create and train or load model +gdef = sess.graph.as_graph_def() +trt_gdef = trt.create_inference_graph( + gdef, #original graph_def + ["output"], #name of output node(s) + max_batch_size, #maximum batch size to run the inference + max_workspace_size_bytes) # max memory for TensorRT to use +tf.reset_default_graph() +tf.import_graph_def(graph_def=trt_gdef) +#...... run inference +``` diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd551d70b4385b14b84b7b98a6d16b0c03733d38 --- /dev/null +++ b/tensorflow/contrib/tensorrt/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Exposes the python wrapper for TensorRT graph transforms.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.tensorrt.python import * +# pylint: enable=unused-import,wildcard-import diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc new file mode 100644 index 0000000000000000000000000000000000000000..899448004f917b36b35fb871a66a9d857736a338 --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -0,0 +1,273 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/devices.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/optimizers/layout_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +namespace convert { +namespace { + +static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) { + // LINT.IfChange + // TODO(jie): Segmentation shouldn't associated with op name. + // Split it into a registration for each kernel. + static const std::set candidate_ops = { + "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu", + "Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean" + }; + // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) + return candidate_ops.count(node_def.op()); +} + +void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, + const std::set& subgraph_node_ids, + tensorflow::EdgeSet* incoming_edges) { + for (int node_id : subgraph_node_ids) { + const tensorflow::Node* node = graph.FindNodeId(node_id); + for (const tensorflow::Edge* edge : node->in_edges()) { + if (!subgraph_node_ids.count(edge->src()->id()) && + !edge->src()->IsSource()) { + incoming_edges->insert(edge); + } + } + } +} + +void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, + const std::set& subgraph_node_ids, + tensorflow::EdgeSet* outgoing_edges) { + for (int node_id : subgraph_node_ids) { + const tensorflow::Node* node = graph.FindNodeId(node_id); + for (const tensorflow::Edge* edge : node->out_edges()) { + if (!subgraph_node_ids.count(edge->dst()->id()) && + !edge->dst()->IsSink()) { + outgoing_edges->insert(edge); + } + } + } +} + +std::pair ParseTensorName(string name, int default_idx = 0) { + int idx = default_idx; + size_t sep = name.find_last_of(':'); + if (sep != string::npos) { + name = name.substr(0, sep); + idx = std::stoi(name.substr(sep + 1)); + } + return std::make_pair(name, idx); +} + +std::unordered_map> BuildTensorNameMap( + const std::vector& tensor_names) { + std::unordered_map> result; + for (string const& tensor_name : tensor_names) { + string node_name; + int index; + std::tie(node_name, index) = ParseTensorName(tensor_name); + result[node_name].push_back(index); + } + return result; +} + +tensorflow::Status ConvertSubGraphToTensorRT( + const std::vector& output_names, + const std::set& subgraph_node_ids, + size_t max_batch_size, // Max batch size that engine will be created for + // Max amount of memory that engine will be allowed to consume, in bytes + size_t max_workspace_size_bytes, + const tensorflow::grappler::GraphProperties& graph_properties, + tensorflow::Graph* graph) { + tensorflow::EdgeSet subgraph_incoming_edges; + GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges); + + std::vector> subgraph_inputs; + + // Collect inputs by looking for incoming edges + for (const tensorflow::Edge* edge : subgraph_incoming_edges) { + subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); + } + std::set> subgraph_outputs_set; + // Collect outputs referenced from output_names + auto output_name_to_index_map = BuildTensorNameMap(output_names); + for (int node_id : subgraph_node_ids) { + tensorflow::Node* node = graph->FindNodeId(node_id); + if (output_name_to_index_map.count(node->name())) { + for (int index : output_name_to_index_map.at(node->name())) { + subgraph_outputs_set.insert({node_id, index}); + } + } + } + // Collect outputs referenced from outgoing edges + tensorflow::EdgeSet subgraph_outgoing_edges; + GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges); + for (const tensorflow::Edge* edge : subgraph_outgoing_edges) { + subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()}); + } + // Impose an ordering on the outputs + std::vector> subgraph_outputs( + subgraph_outputs_set.begin(), subgraph_outputs_set.end()); + // Build TensorRT node and add it to the graph + tensorflow::NodeDef trt_node_def; + TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef( + *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs, + max_batch_size, max_workspace_size_bytes, graph_properties, + &trt_node_def)); + tensorflow::Status status; + tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status); + TF_RETURN_IF_ERROR(status); + + // Re-map outgoing edges to use the new TRT node instead of the orig subgraph + std::map, int> subgraph_edge_to_output_map; + for (size_t i = 0; i < subgraph_outputs.size(); ++i) { + subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i}); + } + TF_RETURN_IF_ERROR(status); + for (const tensorflow::Edge* edge : subgraph_outgoing_edges) { + std::pair old_src = {edge->src()->id(), edge->src_output()}; + int new_src_output = subgraph_edge_to_output_map.at(old_src); + TF_RETURN_IF_ERROR(graph->UpdateEdge(trt_node, new_src_output, edge->dst(), + edge->dst_input())); + } + // Remove the original subgraph + for (int node_id : subgraph_node_ids) { + tensorflow::Node* node = graph->FindNodeId(node_id); + // Don't remove the input placeholders + if (node->type_string() == "Placeholder") { + continue; + } + graph->RemoveNode(node); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status BuildNodeMap( + const tensorflow::Graph& graph, + std::unordered_map* node_map) { + for (auto* node : graph.op_nodes()) { + if (!node_map->insert({node->name(), node}).second) { + return tensorflow::errors::AlreadyExists( + "Node name is not unique in graph: " + node->name()); + } + } + return tensorflow::Status::OK(); +} + +} // namespace + +tensorflow::Status ConvertGraphDefToTensorRT( + const tensorflow::GraphDef& graph_def, + const std::vector& output_names, size_t max_batch_size, + size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def) { + // Optimization pass + tensorflow::grappler::GrapplerItem item; + item.fetch = output_names; + tensorflow::GraphDef gdef; + + // Layout optimization + item.graph = graph_def; + tensorflow::grappler::LayoutOptimizer optimizer; + tensorflow::grappler::Cluster* cluster; + + // Virtual cluster + tensorflow::DeviceProperties device_properties; + device_properties.set_type("GPU"); + device_properties.mutable_environment()->insert({"architecture", "6"}); + cluster = + new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}}); + + TF_RETURN_IF_ERROR(optimizer.Optimize(cluster, item, &gdef)); + + // Constant folding + item.graph = gdef; + tensorflow::grappler::ConstantFolding fold(nullptr); + TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef)); + + // AJ refactoring shape inference through grappler/GraphProperties. + tensorflow::grappler::GraphProperties static_graph_properties(item); + TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(false)); + + // Build full graph + tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), + gdef.library()); + tensorflow::Graph graph(flib); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( + tensorflow::GraphConstructorOptions(), gdef, &graph)); + + // Segment the graph into subgraphs that can be converted to TensorRT + tensorflow::tensorrt::segment::SegmentOptions segment_options; + + // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT) + for (auto node : output_names) { + segment_options.exclude_node_list.insert(node); + } + + // TODO(sami): this should be passed as a knob!!!! + segment_options.minimum_segment_size = 2; + tensorflow::tensorrt::segment::SegmentNodesVector segments; + TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( + gdef, IsTensorRTCandidate, segment_options, &segments)); + if (segments.size() > 1) { + VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size(); + } + std::unordered_map node_map; + TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); + for (const std::set& subgraph_node_names : segments) { + std::set subgraph_node_ids; + for (const string& node_name : subgraph_node_names) { + subgraph_node_ids.insert(node_map.at(node_name)->id()); + } + TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT( + output_names, subgraph_node_ids, max_batch_size, + max_workspace_size_bytes, static_graph_properties, &graph)); + } + graph.ToGraphDef(new_graph_def); + return tensorflow::Status::OK(); +} + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..154ad3f2e8fb0ae702448097fbdece510df30223 --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +// max_batch_size: maximum batch size which can be used for inference for +// optimization targets inference run with max batch size. +// max_workspace_size_bytes: The upper bound of memory allowence for +// engine building. +tensorflow::Status ConvertGraphDefToTensorRT( + const tensorflow::GraphDef& graph_def, + const std::vector& output_names, size_t max_batch_size, + size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def); + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc new file mode 100644 index 0000000000000000000000000000000000000000..9ee717dd7fb1eff4a11fb104cf5806ec8ab853d2 --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -0,0 +1,1601 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tensor_coding.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorrt/include/NvInfer.h" + +// Check if the types are equal. Cast to int first so that failure log message +// would work! +#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +namespace { + +inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, + nvinfer1::DataType* trt_dtype) { + switch (tf_dtype) { + case tensorflow::DataType::DT_FLOAT: + *trt_dtype = nvinfer1::DataType::kFLOAT; + break; + case tensorflow::DataType::DT_INT8: + *trt_dtype = nvinfer1::DataType::kINT8; + break; + case tensorflow::DataType::DT_HALF: + *trt_dtype = nvinfer1::DataType::kHALF; + break; + default: + return tensorflow::errors::InvalidArgument("Unsupported data type"); + } + return tensorflow::Status::OK(); +} + +inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) { + nvinfer1::Dims dims; + dims.nbDims = tensor.dims(); + for (int i = 0; i < dims.nbDims; i++) { + dims.d[i] = tensor.dim_size(i); + } + return dims; +} + +inline int64_t GetShapeSize(nvinfer1::Dims shape) { + // Returns total number of elements in shape + int64_t count = 1; + for (int d = 0; d < shape.nbDims; ++d) { + count *= shape.d[d]; + } + return count; +} + +static std::vector> CreateSamePadding( + const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel, + const std::vector& input_dims) { + std::vector> padding(input_dims.size()); + CHECK_EQ((size_t)stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+? + + for (size_t i = 0; i < input_dims.size(); ++i) { + // Formula to calculate the padding + int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] - + input_dims[i]; + p = (p > 0) ? p : 0; + + // Right precedence padding, like in TensorFlow + int left = p / 2; + int right = p - left; + + VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right + << "paras: " << input_dims[i] << ", " << stride.d[i] << ", " + << "kernel: " << kernel.d[i]; + padding[i] = {left, right}; + } + return padding; +} + +class TRT_ShapedWeights { + public: + TRT_ShapedWeights(tensorflow::DataType type, const void* values, + nvinfer1::Dims shape) + : shape_(shape), type_(type), values_(values), empty_weight_flag_(false) { + // Note: this->shape.type[] is not used + } + + explicit TRT_ShapedWeights(tensorflow::DataType type) + : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {} + + TRT_ShapedWeights(const TRT_ShapedWeights& rhs) + : shape_(rhs.shape_), + type_(rhs.type_), + values_(rhs.values_), + empty_weight_flag_(rhs.empty_weight_flag_) {} + + int64_t count() const { + int64_t c = 1; + for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i]; + return c; + } + + nvinfer1::Weights GetWeightsForTRT() const { + nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT); + TF_CHECK_OK(ConvertDType(type_, &trt_type)); + if (empty_weight_flag_) return nvinfer1::Weights{trt_type, nullptr, 0}; + + // Note: this->shape.type[] is not used + return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)}; + } + + const void* GetValues() const { return values_; } + + void SetValues(const void* values) { values_ = values; } + + size_t size_bytes() const { + int type_size = tensorflow::DataTypeSize(this->type_); + return this->count() * type_size; + } + + // Default converter + operator nvinfer1::Weights() const { return GetWeightsForTRT(); } + + nvinfer1::Dims shape_; + tensorflow::DataType type_; + + private: + const void* values_; + bool empty_weight_flag_; +}; + +class TRT_TensorOrWeights { + public: + explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor) + : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {} + explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights) + : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {} + TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) + : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {} + ~TRT_TensorOrWeights() {} + + bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; } + bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; } + + nvinfer1::ITensor* tensor() { + CHECK_EQ(is_tensor(), true); + return tensor_; + } + const nvinfer1::ITensor* tensor() const { + CHECK_EQ(is_tensor(), true); + return tensor_; + } + TRT_ShapedWeights& weights() { + CHECK_EQ(is_weights(), true); + return weights_; + } + const TRT_ShapedWeights& weights() const { + CHECK_EQ(is_weights(), true); + return weights_; + } + nvinfer1::Dims shape() const { + if (is_tensor()) { + return tensor()->getDimensions(); + } else { + return weights().shape_; + } + } + + private: + nvinfer1::ITensor* tensor_; + TRT_ShapedWeights weights_; + enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } variant_; +}; + +class TFAttrs { + public: + explicit TFAttrs(const tensorflow::NodeDef& tf_node) { + for (const auto& attr : tf_node.attr()) { + attrs_.insert({attr.first, &attr.second}); + } + } + bool count(string key) const { return attrs_.count(key); } + tensorflow::AttrValue const* at(string key) const { + if (!attrs_.count(key)) { + LOG(FATAL) << "Attribute not found: " << key; + } + return attrs_.at(key); + } + template + T get(string key) const; + template + T get(string key, const T& default_value) const { + return attrs_.count(key) ? this->get(key) : default_value; + } + + private: + typedef std::map AttrMap; + AttrMap attrs_; +}; + +template <> +string TFAttrs::get(string key) const { + return this->at(key)->s(); +} + +template <> +std::vector TFAttrs::get>(string key) const { + auto attr = this->at(key)->list().i(); + return std::vector(attr.begin(), attr.end()); +} + +template <> +nvinfer1::Dims TFAttrs::get(string key) const { + auto values = this->get>(key); + nvinfer1::Dims dims; + dims.nbDims = values.size(); + std::copy(values.begin(), values.end(), dims.d); + // Note: No dimension type information is included + return dims; +} + +template <> +nvinfer1::DataType TFAttrs::get(string key) const { + nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); + TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); + return trt_dtype; +} + +template <> +tensorflow::DataType TFAttrs::get(string key) const { + return this->at(key)->type(); +} + +template +void Reorder4(nvinfer1::DimsNCHW shape, const T* idata, + nvinfer1::DimsNCHW istrides, T* odata, + nvinfer1::DimsNCHW ostrides) { + for (int n = 0; n < shape.n(); ++n) { + for (int c = 0; c < shape.c(); ++c) { + for (int h = 0; h < shape.h(); ++h) { + for (int w = 0; w < shape.w(); ++w) { + odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() + + w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() + + h * istrides.h() + w * istrides.w()]; + } + } + } + } +} + +void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, + TRT_ShapedWeights* oweights) { + CHECK_EQ(iweights.type_, oweights->type_); + CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); + int r = iweights.shape_.d[0]; + int s = iweights.shape_.d[1]; + int c = iweights.shape_.d[2]; + int k = iweights.shape_.d[3]; + oweights->shape_.d[0] = k; + oweights->shape_.d[1] = c; + oweights->shape_.d[2] = r; + oweights->shape_.d[3] = s; + nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; + nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; + switch (iweights.type_) { + case tensorflow::DataType::DT_FLOAT: + Reorder4({k, c, r, s}, static_cast(iweights.GetValues()), + istrides, + static_cast(const_cast(oweights->GetValues())), + ostrides); + break; + default: + LOG(FATAL) << "!!!!!!!!!!!!!!!!!!!!!!!!broke!!!!!!!!!!!!"; + } +} + +struct InferDeleter { + template + void operator()(T* obj) const { + if (obj) { + obj->destroy(); + } + } +}; + +template +inline std::shared_ptr infer_object(T* obj) { + return std::shared_ptr(obj, InferDeleter()); +} + +// Logger for GIE info/warning/errors +class Converter; + +using OpConverter = + std::function const&, + std::vector*)>; + +class Converter { + std::unordered_map trt_tensors_; + std::unordered_map op_registry_; + nvinfer1::INetworkDefinition* trt_network_; + std::list> temp_bufs_; + + void register_op_converters(); + + std::vector get_inputs( + const tensorflow::NodeDef& node_def) { + std::vector inputs; + for (const auto& input_name : node_def.input()) { + VLOG(2) << "Retrieve input: " << input_name; + inputs.push_back(trt_tensors_.at(input_name)); + } + return inputs; + } + + public: + explicit Converter(nvinfer1::INetworkDefinition* trt_network) + : trt_network_(trt_network) { + this->register_op_converters(); + } + + TRT_ShapedWeights get_temp_weights(tensorflow::DataType type, + nvinfer1::Dims shape) { + TRT_ShapedWeights weights(type, nullptr, shape); + // TODO(jie): check weights size_bytes. 0 means type error + temp_bufs_.push_back(std::vector(weights.size_bytes())); + weights.SetValues(temp_bufs_.back().data()); + return weights; + } + + TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) { + return this->get_temp_weights(weights.type_, weights.shape_); + } + + tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) { + std::vector inputs = this->get_inputs(node_def); + string op = node_def.op(); + if (!op_registry_.count(op)) { + return tensorflow::errors::Unimplemented( + "No converter registered for op: " + op); + } + OpConverter op_converter = op_registry_.at(op); + std::vector outputs; + TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); + for (size_t i = 0; i < outputs.size(); ++i) { + TRT_TensorOrWeights output = outputs.at(i); + // TODO(jie): tf protobuf seems to be omitting the :0 suffix + string output_name = node_def.name(); + if (i != 0) output_name = output_name + ":" + std::to_string(i); + if (output.is_tensor()) { + output.tensor()->setName(output_name.c_str()); + } + VLOG(2) << "Write out tensor: " << output_name; + if (!trt_tensors_.insert({output_name, output}).second) { + return tensorflow::errors::AlreadyExists( + "Output tensor already exists for op: " + op); + } + } + return tensorflow::Status::OK(); + } + + nvinfer1::INetworkDefinition* network() { return trt_network_; } + + TRT_TensorOrWeights get_tensor(string name) { + if (!trt_tensors_.count(name)) { + return TRT_TensorOrWeights(nullptr); + } + return trt_tensors_.at(name); + } + + bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) { + return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second; + } + + nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor, + std::vector order) { + auto dims = input_tensor->getDimensions(); + + // TODO(jie): change the return to status and properly exit + if (order.size() - 1 != size_t(dims.nbDims)) + LOG(ERROR) << "Dimension does not match, fail gracefully"; + + nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor); + nvinfer1::Permutation permutation; + for (int32_t i = 0; i < dims.nbDims; ++i) { + permutation.order[i] = order[i + 1] - 1; + } + layer->setFirstTranspose(permutation); + + nvinfer1::Dims reshape_dims; + reshape_dims.nbDims = dims.nbDims; + for (int32_t i = 0; i < reshape_dims.nbDims; ++i) { + reshape_dims.d[i] = 0; + reshape_dims.type[i] = dims.type[i]; + } + layer->setReshapeDimensions(reshape_dims); + return layer->getOutput(0); + } +}; + +// **************************************************************************** +// Constant folding functions +// TODO(jie): once optimizer kicks in, we should have done constant folding +// there. +//*****************************************************************************/ +struct LambdaFactory { + enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB }; + OP_CATEGORY op; + + template + std::function unary() { + switch (op) { + case OP_CATEGORY::RSQRT: { + VLOG(2) << "RSQRT GETS DONE"; + return [](T t) -> T { return 1.0 / std::sqrt(t); }; + } + case OP_CATEGORY::NEG: + return [](T t) -> T { return -t; }; + default: + VLOG(2) << "Not supported op for unary: " << static_cast(op); + return nullptr; + } + } + + template + std::function binary() { + switch (op) { + case OP_CATEGORY::ADD: + return [](T l, T r) -> T { return l + r; }; + case OP_CATEGORY::SUB: + return [](T l, T r) -> T { return l - r; }; + case OP_CATEGORY::MUL: + return [](T l, T r) -> T { return l * r; }; + default: + LOG(WARNING) << "Not supported op for binary: " << static_cast(op); + } + return [](T l, T r) -> T { + LOG(FATAL) << "Unsupported op type "; + return l; + }; + } + + template + std::function broadcast_r(T val) { + VLOG(2) << "LAMBDA VAL : " << val; + switch (op) { + case OP_CATEGORY::ADD: + return [val](T l) -> T { + VLOG(2) << "LAMBDA VAL : " << val; + return l + val; + }; + // Return [val](T l)-> T {return l+val;}; + case OP_CATEGORY::SUB: + return [val](T l) -> T { + VLOG(2) << "LAMBDA VAL : " << val; + return l - val; + }; + case OP_CATEGORY::MUL: + return [val](T l) -> T { + VLOG(2) << "LAMBDA VAL : " << val; + return l * val; + }; + default: + LOG(WARNING) << "Not supported op for binary: " << static_cast(op); + } + return [val](T l) -> T { + LOG(FATAL) << "Unsupported op type "; + return l; + }; + } + + template + std::function broadcast_l(T val) { + VLOG(2) << "LAMBDA VAL : " << val; + switch (op) { + case OP_CATEGORY::ADD: + return [val](T l) -> T { + VLOG(2) << "LAMBDA VAL : " << val; + return val + l; + }; + case OP_CATEGORY::SUB: + return [val](T l) -> T { + VLOG(2) << "LAMBDA VAL : " << val; + return val - l; + }; + case OP_CATEGORY::MUL: + return [val](T l) -> T { + VLOG(2) << "LAMBDA VAL : " << val; + return val * l; + }; + default: + LOG(ERROR) << "Not supported op for binary: " << static_cast(op); + } + return [val](T l) -> T { + LOG(FATAL) << "Unsupported op type "; + return l; + }; + } +}; + +tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, + TRT_ShapedWeights* oweights, + LambdaFactory unary_op) { + CHECK_EQ(iweights.type_, oweights->type_); + switch (iweights.type_) { + case tensorflow::DataType::DT_FLOAT: { + auto inp = static_cast(iweights.GetValues()); + auto oup = static_cast(const_cast(oweights->GetValues())); + std::transform(inp, inp + iweights.count(), oup, unary_op.unary()); + break; + } + default: + return tensorflow::errors::Unimplemented( + "Data type not supported: " + + tensorflow::DataTypeString(iweights.type_)); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l, + const TRT_ShapedWeights& iweights_r, + TRT_ShapedWeights* oweights, + LambdaFactory binary_op) { + // Assume iweights_l.type == iweight_r.type + CHECK_EQ(iweights_l.type_, oweights->type_); + CHECK_EQ(iweights_r.type_, oweights->type_); + VLOG(2) << "SANITY CHECK!"; + + switch (iweights_l.type_) { + case tensorflow::DataType::DT_FLOAT: { + auto inp_l = static_cast(iweights_l.GetValues()); + auto inp_r = static_cast(iweights_r.GetValues()); + auto oup = static_cast(const_cast(oweights->GetValues())); + + if (iweights_l.count() != iweights_r.count()) { + // We only supports broadcast of RankZero + if (iweights_l.count() == 1) { + VLOG(2) << "I bet it is not working!" << (*inp_l); + std::transform(inp_r, inp_r + iweights_r.count(), oup, + binary_op.broadcast_l(*inp_l)); + } else if (iweights_r.count() == 1) { + VLOG(2) << "I bet it is not working!" << (*inp_r); + std::transform(inp_l, inp_l + iweights_l.count(), oup, + binary_op.broadcast_r(*inp_r)); + } else { + return tensorflow::errors::Unimplemented( + "Binary op with non-rankZero broadcast not supported"); + } + } else { + std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup, + binary_op.binary()); + } + break; + } + default: + return tensorflow::errors::Unimplemented( + "Data type not supported: " + + tensorflow::DataTypeString(iweights_l.type_)); + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status ConstantFoldUnary( + Converter& ctx, const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + TRT_ShapedWeights weights_input = inputs.at(0).weights(); + + // Allocate output weights + TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input); + + // FIXME assume type matches input weights + // Get trt type & shape + // Maybe this part has to be moved into the block of rsqrt later + // Check type consistency + CHECK_EQ(weights_input.type_, + TFAttrs(node_def).get("T")); + + // Maybe I should do a switch + LambdaFactory unary_op; + if (node_def.op() == "Rsqrt") { + // Compute rsqrt + unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT; + auto ret = UnaryCompute(weights_input, &weights_output, unary_op); + // PAss the output + if (ret == tensorflow::Status::OK()) { + outputs->push_back(TRT_TensorOrWeights(weights_output)); + } + return ret; + } else { + return tensorflow::errors::Unimplemented("Binary op not supported: " + + node_def.op()); + } +} + +// TODO(jie,ben) broadcast is needed yet not implemented +// Let's get the simple stuff working first. Maybe we should fall bakc to TF +// approach for constant folding +tensorflow::Status ConstantFoldBinary( + Converter& ctx, const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + TRT_ShapedWeights weights_input_l = inputs.at(0).weights(); + TRT_ShapedWeights weights_input_r = inputs.at(1).weights(); + + // Check type consistency + CHECK_EQ(weights_input_l.type_, weights_input_r.type_); + + if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims) + return tensorflow::errors::Unimplemented( + "Binary op implicit broadcast not supported: " + node_def.op()); + + // TODO(jie): constant fold should really fall back to TF. + int nb_dims = weights_input_l.shape_.nbDims; + nvinfer1::Dims output_shape; + output_shape.nbDims = nb_dims; + VLOG(2) << "nb_dims: " << nb_dims + << ", the other: " << weights_input_r.shape_.nbDims; + for (int i = 0; i < nb_dims; i++) { + if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) { + output_shape.d[i] = weights_input_l.shape_.d[i]; + } else if (weights_input_l.shape_.d[i] == 1 || + weights_input_r.shape_.d[i] == 1) { + output_shape.d[i] = + std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]); + } else { + return tensorflow::errors::Unimplemented( + "Binary op with incompatible shape at, " + node_def.op()); + } + VLOG(2) << "left: " << weights_input_l.shape_.d[i] + << "right: " << weights_input_r.shape_.d[i] + << "output: " << output_shape.d[i]; + } + + // FIXME assume type matches input weights + // Get trt type & shape + TFAttrs attrs(node_def); + // Maybe this part has to be moved into the block of rsqrt later + tensorflow::DataType dtype = attrs.get("T"); + + // Allocate output weights + TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape); + + // Maybe I should do a switch + LambdaFactory binary_op; + if (node_def.op() == "Sub") { + binary_op.op = LambdaFactory::OP_CATEGORY::SUB; + } else if (node_def.op() == "Mul") { + binary_op.op = LambdaFactory::OP_CATEGORY::MUL; + } else if (node_def.op() == "Add") { + binary_op.op = LambdaFactory::OP_CATEGORY::ADD; + } else { + return tensorflow::errors::Unimplemented("Binary op not supported: " + + node_def.op()); + } + auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output, + binary_op); + + // Pass the output + if (ret == tensorflow::Status::OK()) { + outputs->push_back(TRT_TensorOrWeights(weights_output)); + } + + return ret; +} + +// TODO(jie): broadcast is needed yet not implemented. +// Only implemented channel wise for the time being +tensorflow::Status BinaryTensorOpWeight( + Converter& ctx, const tensorflow::NodeDef& node_def, + const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights, + std::vector* outputs) { + // FIXME assume type matches input weights + // Get trt type & shape + // Maybe this part has to be moved into the block of rsqrt later + + // Check type consistency + auto dtype = TFAttrs(node_def).get("T"); + CHECK_EQ_TYPE(tensor->getType(), dtype); // Cast to int for error messages + nvinfer1::DataType ttype; + TF_CHECK_OK(ConvertDType(weights.type_, &ttype)); + CHECK_EQ_TYPE(ttype, dtype); // Cast to int for error message + + // Check scale mode + auto dims_w = weights.shape_; + auto dims_t = tensor->getDimensions(); + + // Default to channel-wise + auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; + + if (weights.count() == 1) { + VLOG(2) << "UNIFORM"; + scale_mode = nvinfer1::ScaleMode::kUNIFORM; + } else { + // No broadcasting on Batch dimension; + assert(dims_w.d[0] == 1); + + // Broadcasting on Channel dimension only allowed in kUNIFORM + assert(dims_w.d[1] == dims_t.d[0]); + assert(dims_w.nbDims == dims_t.nbDims); + + // Default is element; + for (int i = 2; i < dims_w.nbDims; i++) { + if (dims_w.d[i] != dims_t.d[i - 1]) { + scale_mode = nvinfer1::ScaleMode::kCHANNEL; + break; + } + } + if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) { + scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; + for (int i = 2; i < dims_w.nbDims; i++) { + if (dims_w.d[i] != 1) + return tensorflow::errors::InvalidArgument( + "Weight shape not compatible at, " + node_def.name()); + } + } + } + + // Prepare weights + TRT_ShapedWeights shift_weights(weights.type_); + TRT_ShapedWeights scale_weights(weights.type_); + TRT_ShapedWeights power_weights(weights.type_); + + // Maybe I should do a switch + if (node_def.op() == "Sub") { + TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights); + LambdaFactory unary_op; + unary_op.op = LambdaFactory::OP_CATEGORY::NEG; + TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op)); + shift_weights = neg_weights; + } else if (node_def.op() == "Mul") { + scale_weights = weights; + } else if (node_def.op() == "Add") { + shift_weights = weights; + } else { + return tensorflow::errors::Unimplemented("Binary op not supported: " + + node_def.op()); + } + + nvinfer1::IScaleLayer* layer = ctx.network()->addScale( + *const_cast(tensor), scale_mode, shift_weights, + scale_weights, power_weights); + + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + // Pass the output + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status BinaryTensorOpTensor( + Converter& ctx, const tensorflow::NodeDef& node_def, + const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r, + std::vector* outputs) { + static const std::unordered_map ops{ + {"Add", nvinfer1::ElementWiseOperation::kSUM}, + {"Mul", nvinfer1::ElementWiseOperation::kPROD}, + // {"max", nvinfer1::ElementWiseOperation::kMAX}, + // {"min", nvinfer1::ElementWiseOperation::kMIN}, + {"Sub", nvinfer1::ElementWiseOperation::kSUB}, + {"Div", nvinfer1::ElementWiseOperation::kDIV}, + }; + + // FIXME assume type matches input weights + // Get trt type & shape + TFAttrs attrs(node_def); + // Maybe this part has to be moved into the block of rsqrt later + nvinfer1::DataType dtype = attrs.get("T"); + + // Check type consistency + CHECK_EQ_TYPE(tensor_l->getType(), dtype); + CHECK_EQ_TYPE(tensor_r->getType(), dtype); + auto op_pair = ops.find(node_def.op()); + if (op_pair == ops.end()) + return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + + " not supported at: " + + node_def.name()); + + nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( + *const_cast(tensor_l), + *const_cast(tensor_r), op_pair->second); + + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + // Pass the output + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertPlaceholder( + Converter& ctx, const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + VLOG(2) << "Placeholder should have been replace already"; + return tensorflow::errors::Unimplemented(", cannot convert Placeholder op"); + // OK this make sense since we are supposed to replace it with input + TFAttrs attrs(node_def); + nvinfer1::DataType dtype = attrs.get("dtype"); + nvinfer1::Dims dims = attrs.get("shape"); + + dims.nbDims--; + for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1]; + + nvinfer1::ITensor* output = + ctx.network()->addInput(node_def.name().c_str(), dtype, dims); + if (!output) { + return tensorflow::errors::InvalidArgument("Failed to create Input layer"); + } + outputs->push_back(TRT_TensorOrWeights(output)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertConv2D(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + // TODO(jie): handle NHWC/NCHW transpose; + TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck); + ReorderRSCKToKCRS(weights_rsck, &weights); + TRT_ShapedWeights biases(weights.type_); + int noutput = weights.shape_.d[0]; + nvinfer1::DimsHW kernel_size; + kernel_size.h() = weights.shape_.d[2]; + kernel_size.w() = weights.shape_.d[3]; + TFAttrs attrs(node_def); + + int h_index = 2; + int w_index = 3; + auto data_format = attrs.get("data_format"); + if (data_format == "NHWC") { + tensor = ctx.TransposeTensor(const_cast(tensor), + {0, 3, 1, 2}); + h_index = 1; + w_index = 2; + // TODO(jie): transpose it + } + + // TODO(jie): stride. (NHWC/NCHW) + auto tf_stride = attrs.get>("strides"); + nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + + auto tensor_dim = tensor->getDimensions(); + std::vector> padding; + // TODO(jie): padding. + if (attrs.get("padding") == "SAME") { + // This is NCHW tensor with no batch dimension. + // 1 -> h + // 2 -> w + padding = CreateSamePadding( + stride, kernel_size, + {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); + } else { + padding = {{0, 0}, {0, 0}}; + } + + if (padding[0].first != padding[0].second || + padding[1].first != padding[1].second) { + // TODO(jie): handle asymmetric padding + VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second + << padding[1].first << padding[1].second; + + auto dim_before = tensor->getDimensions(); + VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1] + << dim_before.d[2] << ", " << dim_before.d[3]; + auto pad_layer = ctx.network()->addPadding( + *const_cast(tensor), + nvinfer1::DimsHW(padding[0].first, padding[1].first), + nvinfer1::DimsHW(padding[0].second, padding[1].second)); + padding = {{0, 0}, {0, 0}}; + tensor = pad_layer->getOutput(0); + auto dim_after = tensor->getDimensions(); + VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1] + << dim_after.d[2] << ", " << dim_after.d[3]; + } + + nvinfer1::IConvolutionLayer* layer = + ctx.network()->addConvolution(*const_cast(tensor), + noutput, kernel_size, weights, biases); + + layer->setStride(stride); + layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + auto dim_after = output_tensor->getDimensions(); + VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1] + << dim_after.d[2] << ", " << dim_after.d[3]; + + if (data_format == "NHWC") { + // TODO(jie): transpose it back! + output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); + } else { + VLOG(2) << "NCHW !!!!"; + } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertPool(Converter& ctx, + const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + TFAttrs attrs(node_def); + + int h_index = 2; + int w_index = 3; + auto data_format = attrs.get("data_format"); + if (data_format == "NHWC") { + h_index = 1; + w_index = 2; + tensor = ctx.TransposeTensor(const_cast(tensor), + {0, 3, 1, 2}); + } else { + VLOG(2) << "NCHW !!!!"; + } + nvinfer1::PoolingType type; + // TODO(jie): support other pooling type + if (node_def.op() == "MaxPool") + type = nvinfer1::PoolingType::kMAX; + else + return tensorflow::errors::Unimplemented("Only supports Max pool"); + + // TODO(jie): NCHW + auto tf_stride = attrs.get>("strides"); + nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + + auto tf_kernel = attrs.get>("ksize"); + nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); + + auto tensor_dim = tensor->getDimensions(); + std::vector> padding; + // TODO(jie): padding. + if (attrs.get("padding") == "SAME") { + // This is NCHW tensor with no batch dimension. + // 1 -> h + // 2 -> w + padding = CreateSamePadding( + stride, ksize, + {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); + } else if (attrs.get("padding") == "VALID") { + // No padding for valid padding here + VLOG(2) << "No padding added for VALID padding in pool" << node_def.name(); + padding = {{0, 0}, {0, 0}}; + } else { + return tensorflow::errors::Unimplemented( + "Current MaxPool cannot support padding other than SAME"); + } + + if (padding[0].first != padding[0].second || + padding[1].first != padding[1].second) { + // TODO(jie): handle asymmetric padding + VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second + << padding[1].first << padding[1].second; + auto pad_layer = ctx.network()->addPadding( + *const_cast(tensor), + nvinfer1::DimsHW(padding[0].first, padding[1].first), + nvinfer1::DimsHW(padding[0].second, padding[1].second)); + padding = {{0, 0}, {0, 0}}; + tensor = pad_layer->getOutput(0); + } + + nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling( + *const_cast(tensor), type, ksize); + + layer->setStride(stride); + layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + if (data_format == "NHWC") { + // TODO(jie): transpose it back! + output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); + } else { + VLOG(2) << "NCHW !!!!"; + } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertActivation( + Converter& ctx, const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + nvinfer1::IActivationLayer* layer = ctx.network()->addActivation( + *const_cast(tensor), nvinfer1::ActivationType::kRELU); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertScale(Converter& ctx, + const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) + return tensorflow::errors::Unimplemented( + "Only supports tensor op weight for now, at " + node_def.name()); + // Implement tensor binaryOp weight [channel wise] for now; + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + + // TODO(jie): handle NHWC/NCHW transpose; + TRT_ShapedWeights weights = inputs.at(1).weights(); + TRT_ShapedWeights empty_weights(weights.type_); + + TFAttrs attrs(node_def); + + // Transpose NHWC + auto data_format = attrs.get("data_format"); + if (data_format == "NHWC") { + tensor = ctx.TransposeTensor(const_cast(tensor), + {0, 3, 1, 2}); + // TODO(jie): transpose it + } else { + VLOG(2) << "NCHW !!!!"; + } + nvinfer1::IScaleLayer* layer = ctx.network()->addScale( + *const_cast(tensor), nvinfer1::ScaleMode::kCHANNEL, + weights, empty_weights, empty_weights); + + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + if (data_format == "NHWC") { + // TODO(jie): transpose it back! + output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); + } else { + VLOG(2) << "NCHW !!!!"; + } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertConst(Converter& ctx, + const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + const auto& weights_tensor = node_def.attr().at("value").tensor(); + + // Get trt type & shape + TFAttrs attrs(node_def); + const tensorflow::DataType dtype = attrs.get("dtype"); + + // Create shaped weights as output + tensorflow::Tensor tensor; + if (!tensor.FromProto(weights_tensor)) + return tensorflow::errors::Internal("Cannot parse weight tensor proto: " + + node_def.name()); + + TRT_ShapedWeights weights(dtype); + if (!weights_tensor.float_val().empty()) { + VLOG(2) << "SCALAR!!!" << node_def.name(); + nvinfer1::Dims scalar_shape; + if (tensor.dims() > 0) { + VLOG(2) << "Dimensions: " << tensor.dims(); + weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(), + GetTensorShape(tensor)); + } else { + VLOG(2) << "Dimensions: " << tensor.dims(); + scalar_shape.nbDims = 1; + scalar_shape.d[0] = 1; + scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; + for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) { + scalar_shape.d[i] = 0; + scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL; + } + weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(), + scalar_shape); + } + } else if (!weights_tensor.tensor_content().empty()) { + VLOG(2) << "TENSOR!!!" << node_def.name(); + const auto& content = weights_tensor.tensor_content(); + + weights = ctx.get_temp_weights(dtype, GetTensorShape(tensor)); + if (content.size() > 0) { + const int dtype_size = tensorflow::DataTypeSize(dtype); + CHECK_EQ(0, content.size() % dtype_size) + << "Tensor content size (" << content.size() + << ") is not a multiple of " << dtype_size; + port::CopyToArray( + content, static_cast(const_cast(weights.GetValues()))); + } + } else { + return tensorflow::errors::Unimplemented( + "Not supported constant type, at " + node_def.name()); + } + // Pass the output + outputs->push_back(TRT_TensorOrWeights(weights)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertIdentity( + Converter& ctx, const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + outputs->push_back(inputs.at(0)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertBinary(Converter& ctx, + const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + if (inputs.size() != 2) + return tensorflow::errors::FailedPrecondition( + "Binary ops require two tensor input, at " + node_def.name()); + + if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) + return ConstantFoldBinary(ctx, node_def, inputs, outputs); + + if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) + return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(), + inputs.at(1).weights(), outputs); + + if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) + return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(), + inputs.at(0).weights(), outputs); + + if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) + return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(), + inputs.at(1).tensor(), outputs); + + return tensorflow::errors::Unknown("Binary op input error, at " + + node_def.name()); +} + +tensorflow::Status ConvertUnary(Converter& ctx, + const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + if (inputs.size() != 1) + return tensorflow::errors::FailedPrecondition( + "Unary ops require single tensor input, at " + node_def.name()); + + if (inputs.at(0).is_weights()) + return ConstantFoldUnary(ctx, node_def, inputs, outputs); + else if (inputs.at(0).is_tensor()) + return tensorflow::errors::Unimplemented( + "Unary op for tensor not supported, at " + node_def.name()); + + return tensorflow::errors::Unknown("Binary op input error, at " + + node_def.name()); +} + +tensorflow::Status ConvertReduce(Converter& ctx, + const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) + return tensorflow::errors::InvalidArgument( + "Input expects tensor and weights, at" + node_def.name()); + + // Implement tensor binaryOp weight [channel wise] for now; + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + auto dims = tensor->getDimensions(); + // Restore implicit batch dimension + int nb_dims = dims.nbDims + 1; + + TRT_ShapedWeights index_list = inputs.at(1).weights(); + + TFAttrs attrs(node_def); + // TODO(jie): handle data type. + // Index type here is done through TF type, so I can leverage their + // EnumToDataType for my cast + auto index_type = attrs.get("Tidx"); + + // Only expect to handle INT32 as attributes for now + if (index_type != tensorflow::DataType::DT_INT32) + return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32"); + auto index_list_data = + static_cast(const_cast(index_list.GetValues())); + + // Hack warning: have to fall back to pool layer since reduce is not in public + // TRT yet. + if (nb_dims != 4) + return tensorflow::errors::InvalidArgument( + "TRT only support reduce on 4 dimensional tensors, at" + + node_def.name()); + if (index_list.count() > 2) + return tensorflow::errors::InvalidArgument( + "TRT cannot support reduce on more than 2 dimensions, at" + + node_def.name()); + + std::set idx_set; + // We cannot operate on Channel. permutation flag used to transpose tensor + int permuted_index = -1; + for (int i = 0; i < index_list.count(); i++) { + if (index_list_data[i] == 0) + return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" + + node_def.name()); + if (index_list_data[i] == 1) permuted_index = 1; + idx_set.emplace(index_list_data[i]); + } + + std::vector permutation_order(nb_dims); + nvinfer1::DimsHW pool_kernel; + if (permuted_index == 1) { + for (int i = 2; i < nb_dims; i++) { + if (idx_set.count(i)) { + permuted_index = i; + break; + } + } + for (int i = 0; i < nb_dims; i++) permutation_order[i] = i; + + permutation_order[permuted_index] = 1; + permutation_order[1] = permuted_index; + + // Apply permutation before extracting dimension for pool_kernel + tensor = ctx.TransposeTensor(const_cast(tensor), + permutation_order); + } + + // Apply permutation before extracting dimension for pool_kernel + pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1; + pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1; + + nvinfer1::ITensor* output_tensor; + + if (node_def.op() == "Mean") { + nvinfer1::IPoolingLayer* layer = + ctx.network()->addPooling(*const_cast(tensor), + nvinfer1::PoolingType::kAVERAGE, pool_kernel); + output_tensor = layer->getOutput(0); + } else { + return tensorflow::errors::Unimplemented( + "Op not supported " + node_def.op() + " , at " + node_def.name()); + } + if (permuted_index != -1) { + // Apply permutation before extracting dimension for pool_kernel + output_tensor = ctx.TransposeTensor( + const_cast(output_tensor), permutation_order); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertPad(Converter& ctx, + const tensorflow::NodeDef& node_def, + std::vector const& inputs, + std::vector* outputs) { + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) + return tensorflow::errors::InvalidArgument( + "Input expects tensor and weights, at" + node_def.name()); + + // Implement tensor binaryOp weight [channel wise] for now; + nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + auto dims = tensor->getDimensions(); + // Restore implicit batch dimension + int nb_dims = dims.nbDims + 1; + + TRT_ShapedWeights pads = inputs.at(1).weights(); + + TFAttrs attrs(node_def); + // Padding type here is done through TF type + // so I can leverage their EnumToDataType for my cast + auto padding_type = attrs.get("Tpaddings"); + // TODO(jie): handle data type conversion for TRT? + + if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) + return tensorflow::errors::InvalidArgument( + "Pad only supports explicit padding on 4 dimensional tensor, at " + + node_def.name()); + + // Only expect to handle INT32 as attributes for now + if (padding_type != tensorflow::DataType::DT_INT32) + return tensorflow::errors::Unimplemented( + "Tpaddings supports only DT_INT32"); + auto pad_data = static_cast(const_cast(pads.GetValues())); + + std::vector pad_index; + for (int i = 0; i < nb_dims; i++) { + if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) + pad_index.push_back(i); + } + + // No padding at all, we should exit + if (pad_index.size() == 0) { + outputs->push_back(inputs.at(0)); + return tensorflow::Status::OK(); + } + + // Only supports padding on less than 2 axis GIE-2579 + if (pad_index.size() > 2) + return tensorflow::errors::InvalidArgument( + "Padding layer does not support padding on > 2"); + + // Padding on batch dimension is not supported + if (pad_index[0] == 0) + return tensorflow::errors::InvalidArgument( + "Padding layer does not support padding on batch dimension"); + + // Not doing the legit thing here. ignoring padding on dim 1 and 3; + // TODO(jie): implement pad as uff parser + if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) + return tensorflow::errors::Unimplemented( + "Padding layer does not support padding on dimension 1 and 3 yet"); + + bool legit_pad = true; + nvinfer1::DimsHW pre_padding(0, 0); + nvinfer1::DimsHW post_padding(0, 0); + + std::vector permuted_pad_index(pad_index); + if (pad_index[0] == 1) { + legit_pad = false; + tensor = ctx.TransposeTensor(const_cast(tensor), + {0, 3, 2, 1}); + permuted_pad_index[0] = 3; + } + + for (size_t i = 0; i < pad_index.size(); i++) { + int index = pad_index[i]; + if (permuted_pad_index[i] == 2) { + pre_padding.h() = pad_data[index * 2]; + post_padding.h() = pad_data[index * 2 + 1]; + } else if (permuted_pad_index[i] == 3) { + pre_padding.w() = pad_data[index * 2]; + post_padding.w() = pad_data[index * 2 + 1]; + } + } + + nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding( + *const_cast(tensor), pre_padding, post_padding); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + if (!legit_pad) + output_tensor = ctx.TransposeTensor( + const_cast(output_tensor), {0, 3, 2, 1}); + + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + +void Converter::register_op_converters() { + // vgg_16 slim implementation + op_registry_["Placeholder"] = ConvertPlaceholder; + op_registry_["Conv2D"] = ConvertConv2D; + op_registry_["Relu"] = ConvertActivation; + op_registry_["MaxPool"] = ConvertPool; + // This could be really handled as ConvertBinary + op_registry_["BiasAdd"] = ConvertScale; + op_registry_["Const"] = ConvertConst; + // op_registry_["MatMul"] = ConvertFullyConnected; // Not used in vgg + // TODO(ben,jie): this is a temp hack. + op_registry_["Identity"] = ConvertIdentity; // Identity should be removed + // op_registry_["AvgPool"] = ConvertPool; + + // resnet_50_v1 slim implementation + op_registry_["Add"] = ConvertBinary; + op_registry_["Mul"] = ConvertBinary; + op_registry_["Sub"] = ConvertBinary; + op_registry_["Rsqrt"] = ConvertUnary; + op_registry_["Mean"] = ConvertReduce; + op_registry_["Pad"] = ConvertPad; + // TODO(ben,jie): Add more ops +} + +} // namespace + +tensorflow::Status ConvertSubGraphToTensorRTNodeDef( + const tensorflow::Graph& graph, const std::set& subgraph_node_ids, + const std::vector>& input_inds, + const std::vector>& output_inds, size_t max_batch_size, + size_t max_workspace_size_bytes, + const tensorflow::grappler::GraphProperties& graph_properties, + tensorflow::NodeDef* trt_node) { + // Visit nodes in reverse topological order and construct the TRT network. + + // Toposort + std::vector order_vec; + tensorflow::GetPostOrder(graph, &order_vec); + // Select just the subgraph + std::list order; + for (tensorflow::Node* node : order_vec) { + if (subgraph_node_ids.count(node->id())) { + // We want topological order to contstruct the + // network layer by layer + order.push_front(node); + } + } + // Topological order is needed to build TRT network + + tensorflow::tensorrt::Logger trt_logger; + + auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger)); + if (!trt_builder) { + return tensorflow::errors::Internal( + "Failed to create TensorRT builder object"); + } + + auto trt_network = infer_object(trt_builder->createNetwork()); + if (!trt_network) { + return tensorflow::errors::Internal( + "Failed to create TensorRT network object"); + } + + // Build the network + Converter converter(trt_network.get()); + + std::vector input_names; + std::vector input_dtypes; + for (std::pair const& input : input_inds) { + int node_id = input.first; + int output_idx = input.second; + tensorflow::Node* node = graph.FindNodeId(node_id); + auto node_name = node->name(); + input_names.push_back(node_name); // Insert original node name without port + // TODO(jie): alternative :) + if (!graph_properties.HasOutputProperties(node_name)) + return tensorflow::errors::Internal("Failed to find input node: " + + node_name); + + auto op_info_vec = graph_properties.GetOutputProperties(node_name); + if (static_cast(op_info_vec.size()) < output_idx) + return tensorflow::errors::Internal( + "Accessing output index of: " + std::to_string(output_idx) + + ", at node: " + node_name + " with output entry from shape_map: " + + std::to_string(op_info_vec.size())); + + auto op_info = op_info_vec.at(output_idx); + + tensorflow::DataType tf_dtype = op_info.dtype(); + input_dtypes.push_back(tf_dtype); + + nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); + TF_CHECK_OK(ConvertDType(tf_dtype, &dtype)); + + VLOG(2) << "Accessing output index of: " << std::to_string(output_idx) + << ", at node: " << node_name + << " with output entry from shape_map: " + << std::to_string(op_info_vec.size()); + + // TODO(ben,jie): update TRT input format/dimension + nvinfer1::DimsCHW input_dim_psuedo_chw; + for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1; + + for (int i = 1; i < op_info.shape().dim_size(); i++) { + VLOG(2) << "dimension: " << i + << " , size: " << op_info.shape().dim(i).size(); + input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size(); + } + + // TODO(ben,jie): proper way to restore input tensor name? + auto input_tensor_name = node_name; + if (output_idx != 0) + input_tensor_name = node_name + ":" + std::to_string(output_idx); + + nvinfer1::ITensor* input_tensor = converter.network()->addInput( + input_tensor_name.c_str(), dtype, input_dim_psuedo_chw); + + if (!input_tensor) + return tensorflow::errors::InvalidArgument( + "Failed to create Input layer"); + VLOG(2) << "Input tensor name :" << input_tensor_name; + + if (!converter.insert_input_tensor(input_tensor_name, input_tensor)) + return tensorflow::errors::AlreadyExists( + "Output tensor already exists for op: " + input_tensor_name); + } + + VLOG(2) << "Finished sorting"; + + for (const tensorflow::Node* node : order) { + const tensorflow::NodeDef& node_def = node->def(); + VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op(); + TF_RETURN_IF_ERROR(converter.convert_node(node_def)); + } + + VLOG(2) << "Finished conversion"; + + // Gather output metadata + std::vector output_names; + std::vector output_dtypes; + for (std::pair const& output : output_inds) { + int node_id = output.first; + int output_idx = output.second; + tensorflow::Node* node = graph.FindNodeId(node_id); + string op_name = node->name(); + string tensor_name = op_name; + if (output_idx != 0) + tensor_name = tensor_name + ":" + std::to_string(output_idx); + VLOG(2) << "Output tensor name: " << tensor_name; + output_names.push_back(tensor_name); + auto tensor_or_weights = converter.get_tensor(tensor_name); + if (!tensor_or_weights.is_tensor()) { + return tensorflow::errors::InvalidArgument( + "Output node is weights not tensor"); + } + nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); + if (!tensor) { + return tensorflow::errors::NotFound("Output tensor not found: " + + tensor_name); + } + converter.network()->markOutput(*tensor); + tensorflow::DataType tf_dtype = node->output_type(output_idx); + output_dtypes.push_back(tf_dtype); + nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; + TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); + tensor->setType(trt_dtype); + } + + VLOG(2) << "Finished output"; + // TODO(jie): static_id is not thread safe. + static int static_id = 0; + + // Build the engine + trt_builder->setMaxBatchSize(max_batch_size); + trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes); + VLOG(0) << "Starting build engine " << static_id; + // TODO(ben,jie): half2 and int8 mode support + string engine_plan_string; + { + auto trt_engine = + infer_object(trt_builder->buildCudaEngine(*converter.network())); + VLOG(0) << "Built network"; + auto engine_plan = infer_object(trt_engine->serialize()); + VLOG(0) << "Serialized engine"; + const char* engine_plan_data = + static_cast(engine_plan->data()); + engine_plan_string = + string(engine_plan_data, engine_plan_data + engine_plan->size()); + } + + VLOG(0) << "Finished engine"; + + // Build the TRT op + // TODO(sami,ben,jie): proper naming! + tensorflow::NodeDefBuilder op_builder( + tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp"); + std::vector income_edges; + for (size_t i = 0; i < input_names.size(); ++i) { + int output_idx = input_inds.at(i).second; + // We wired up the input here already, it is redundant to do it again in + // ConvertSubGraphToTensorRT(convert_graph.cc) + auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut( + input_names.at(i), output_idx, input_dtypes.at(i)); + income_edges.push_back(incoming_edge); + } + tensorflow::gtl::ArraySlice input_list( + income_edges); + op_builder.Input(input_list); + + VLOG(0) << "Finished op preparation"; + + auto status = op_builder.Attr("serialized_engine", engine_plan_string) + .Attr("input_nodes", input_names) + .Attr("output_nodes", output_names) + .Attr("OutT", output_dtypes) + .Finalize(trt_node); + + VLOG(0) << status.ToString() << " finished op building"; + + return tensorflow::Status::OK(); +} + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h new file mode 100644 index 0000000000000000000000000000000000000000..2e7fd19566e1ed3719b932c7443a9c3f652b2d3e --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/lib/core/status.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +tensorflow::Status ConvertSubGraphToTensorRTNodeDef( + const tensorflow::Graph& graph, const std::set& subgraph_node_ids, + const std::vector>& + input_inds, // {node_id, output_idx} + const std::vector>& + output_inds, // {node_id, output_idx} + size_t max_batch_size, size_t max_workspace_size_bytes, + const tensorflow::grappler::GraphProperties& graph_prop, + tensorflow::NodeDef* trt_node); + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8efdf63ebebc4d7a199c60635ca64348d2b30505 --- /dev/null +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" + +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" + +namespace tensorflow { +namespace tensorrt { +static ::tensorflow::tensorrt::Logger logger; + +TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { + // read serialized_engine + string serialized_engine; + OP_REQUIRES_OK(context, + context->GetAttr("serialized_engine", &serialized_engine)); + + // register input output node name in trt_sub_graph + OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_)); + OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_)); + + // TODO(samikama) runtime should be taken from a resourcemanager as well. + // Only engine should be in the op and context and runtime should be taken + // from resourcemanager + nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); + trt_engine_ptr_.reset(infer->deserializeCudaEngine( + serialized_engine.c_str(), serialized_engine.size(), nullptr)); + + trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); + // Runtime is safe to delete after engine creation + infer->destroy(); +} + +void TRTEngineOp::Compute(OpKernelContext* context) { + int num_binding = context->num_inputs() + context->num_outputs(); + std::vector buffers(num_binding); + + size_t binding_index; + int num_batch = 0; + bool valid = true; + for (int i = 0; i < context->num_inputs(); i++) { + // Grab the input tensor + binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str()); + + const Tensor& input_tensor = context->input(i); + const TensorShape& input_shape = input_tensor.shape(); + if (i == 0) { + num_batch = input_shape.dim_size(0); + } else if (num_batch != input_shape.dim_size(0)) { + valid = false; + break; + } + switch (trt_engine_ptr_->getBindingDataType(binding_index)) { + case nvinfer1::DataType::kFLOAT: + buffers[binding_index] = (void*)(input_tensor.flat().data()); + break; + case nvinfer1::DataType::kHALF: + LOG(FATAL) << "half size is not supported yet!"; + break; + case nvinfer1::DataType::kINT8: + LOG(FATAL) << "int8 is not supported yet!"; + break; + } + } + + // Might want a different way to inform the user of batch size inconsistency + if (!valid) LOG(WARNING) << "input data inconsistent batch size"; + + for (int i = 0; i < static_cast(output_nodes_.size()); i++) { + // This is bad that we have to reallocate output buffer every run. + // Create an output tensor + binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str()); + Tensor* output_tensor = nullptr; + + TensorShape output_shape; + if (binding_index != -1) { + auto dims = trt_engine_ptr_->getBindingDimensions(binding_index); + std::vector trt_shape(dims.nbDims + 1); + trt_shape[0] = num_batch; + for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; + OP_REQUIRES_OK(context, + TensorShapeUtils::MakeShape( + trt_shape.data(), trt_shape.size(), &output_shape)); + } else { + LOG(FATAL) << "output node not found, at " << output_nodes_[i]; + break; + } + + OP_REQUIRES_OK(context, + context->allocate_output(i, output_shape, &output_tensor)); + switch (trt_engine_ptr_->getBindingDataType(binding_index)) { + case nvinfer1::DataType::kFLOAT: + buffers[binding_index] = + reinterpret_cast(output_tensor->flat().data()); + break; + case nvinfer1::DataType::kHALF: + LOG(FATAL) << "half size is not supported yet!"; + break; + case nvinfer1::DataType::kINT8: + LOG(FATAL) << "int8 is not supported yet!"; + break; + } + } + // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast(context->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + + // execution handled by TF since we are getting stream from TF. + // it is safe for CPU pointer array (buffers) to go out of scope after enqueue + trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], *stream, nullptr); +} + +REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0964b4b18a781143fdd7884a2904321b9d14e354 --- /dev/null +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ + +#include +#include +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +class Logger; + +class TRTEngineOp : public OpKernel { + public: + explicit TRTEngineOp(OpKernelConstruction* context); + + void Compute(OpKernelContext* context) override; + + private: + template + struct Destroyer { + void operator()(T* d) { d->destroy(); } + }; + + template + using destroyed_ptr = std::unique_ptr>; + destroyed_ptr trt_engine_ptr_; + // TODO(samikama): context should go to a resource manager! + destroyed_ptr trt_execution_context_ptr_; + + std::vector input_nodes_; + std::vector output_nodes_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/contrib/tensorrt/log/trt_logger.cc new file mode 100644 index 0000000000000000000000000000000000000000..7add8cb8b3d2a04206ee4174e79a1a4b86e05f30 --- /dev/null +++ b/tensorflow/contrib/tensorrt/log/trt_logger.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace tensorrt { + +// Use TF logging for TensorRT informations +void Logger::log(Severity severity, const char* msg) { + // Suppress info-level messages + switch (severity) { + case Severity::kINFO: { // Mark TRT info messages as debug! + VLOG(2) << msg; + break; + } + case Severity::kWARNING: { + LOG(WARNING) << msg; + break; + } + case Severity::kERROR: { + LOG(ERROR) << msg; + break; + } + case Severity::kINTERNAL_ERROR: { + LOG(FATAL) << msg; + break; + } + // This is useless for now. But would catch it in future if enum changes. It + // is always good to have default case! + default: { + LOG(FATAL) << name_ << "Got unknown severity level from TRT " << msg; + break; + } + } +} +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h new file mode 100644 index 0000000000000000000000000000000000000000..d71f66b933a8068a6276a7e070755e0075543bb5 --- /dev/null +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ + +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +// Logger for GIE info/warning/errors +class Logger : public nvinfer1::ILogger { + private: + void log(nvinfer1::ILogger::Severity severity, const char* msg) override; + + string name_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..079d73f7bec3f9a9740e455b31a259cec287f849 --- /dev/null +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +namespace shape_inference { +extern Status TRTEngineOpShapeInference(InferenceContext* c); +} + +REGISTER_OP("TRTEngineOp") + .Attr("serialized_engine: string") + .Attr("input_nodes: list(string)") + .Attr("output_nodes: list(string)") + .Attr("InT: list({float32})") + .Attr("OutT: list({float32})") + .Input("in_tensor: InT") + .Output("out_tensor: OutT") + .SetShapeFn(shape_inference::TRTEngineOpShapeInference); + +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e050a768ce97af1fc1d2c85cb52640b4c6a6a97 --- /dev/null +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Exposes the python wrapper for TensorRT graph transforms.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long +from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph +# pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py new file mode 100644 index 0000000000000000000000000000000000000000..31a313182be9a2fca7457a539670dbc911ccabb1 --- /dev/null +++ b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py @@ -0,0 +1,34 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Exposes the Python wrapper of TRTEngineOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import platform + +if platform.system() != "Windows": + # pylint: disable=wildcard-import,unused-import,g-import-not-at-top + from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import * + + from tensorflow.contrib.util import loader + from tensorflow.python.platform import resource_loader + # pylint: enable=wildcard-import,unused-import,g-import-not-at-top + + _trt_engine_op = loader.load_op_library( + resource_loader.get_path_to_datafile("_trt_engine_op.so")) +else: + raise RuntimeError("Windows platforms are not supported") diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..9454862f857ab743712ce409ff007de55e72a68e --- /dev/null +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -0,0 +1,103 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Exposes the Python wrapper conversion to trt_graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long +import six as _six +from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import errors +from tensorflow.python.framework import errors_impl as _impl +from tensorflow.python.framework import ops + + +# TODO(skama): get outputs from session when implemented as c++ +# optimization pass +def create_inference_graph(input_graph_def, + outputs, + max_batch_size=1, + max_workspace_size_bytes=2 << 20): + """Python wrapper for the TRT transormation. + + + Args: + input_graph_def: GraphDef object containing a model to be transformed. + outputs: List of tensors or node names for the model outputs. + max_batch_size: max size for the input batch + max_workspace_size_bytes: parameter to control memory allocation (in Bytes) + + Returns: + New GraphDef with TRTEngineOps placed in graph replacing subgraphs. + + Raises: + RuntimeError: if the returned status message is malformed. + """ + + def py2bytes(inp): + return inp + + def py3bytes(inp): + return inp.encode("utf-8", errors="surrogateescape") + + def py2string(inp): + return inp + + def py3string(inp): + return inp.decode("utf-8") + + if _six.PY2: + to_bytes = py2bytes + to_string = py2string + else: + to_bytes = py3bytes + to_string = py3string + + out_names = [] + for i in outputs: + if isinstance(i, ops.Tensor): + out_names.append(to_bytes(i.name)) + else: + out_names.append(to_bytes(i)) + + input_graph_def_str = input_graph_def.SerializeToString() + + # TODO(sami): Fix this when we can return status from C++ library + # There is a problem with the TF internal library setup that doesn't + # allow us to return a status object from C++. Thus we return a + # pair or strings where first one is encoded status and the second + # one is the transformed graphs protobuf string. + out = trt_convert(input_graph_def_str, out_names, max_batch_size, + max_workspace_size_bytes) + status = to_string(out[0]) + output_graph_def_string = out[1] + del input_graph_def_str # Save some memory + if len(status) < 2: + raise _impl.UnknownError(None, None, status) + if status[:2] != "OK": + msg = status.split(";") + if len(msg) == 1: + raise RuntimeError("Status message is malformed {}".format(status)) + # pylint: disable=protected-access + raise _impl._make_specific_exception(None, None, ";".join(msg[1:]), + int(msg[0])) + # pylint: enable=protected-access + output_graph_def = graph_pb2.GraphDef() + output_graph_def.ParseFromString(output_graph_def_string) + del output_graph_def_string # Save some memory + return output_graph_def diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc new file mode 100644 index 0000000000000000000000000000000000000000..6193f0b0a13f6985d5fc8dd4c6fc09b15f72f139 --- /dev/null +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -0,0 +1,253 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/segment/segment.h" + +#include +#include +#include + +#include "tensorflow/contrib/tensorrt/segment/union_find.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tensorrt { +namespace segment { + +namespace { + +bool CanContractEdge(const tensorflow::Edge* edge, + const tensorflow::Graph& graph) { + const tensorflow::Node* src = edge->src(); + const tensorflow::Node* dst = edge->dst(); + + // Can't contract edge if doing so would cause a cycle in the + // graph. So, if there is a directed path from 'src' to 'dst', other + // than 'edge' (or any other direct edge from 'src' to 'dst'), then + // combining 'src' and 'dst' will cause a cycle along that path. + // + // In practice, to avoid modifying the graph and to take advantage + // of existing graph functions, we perform an equivalent. + // 1. Get all nodes incoming to 'dst', excluding 'src' + // 2. Reverse DFS from those nodes + // 3. If reverse DFS reaches 'src' then we have a cycle + std::vector dfs_start_nodes; + for (tensorflow::Node* node : dst->in_nodes()) { + if (node != src) { + dfs_start_nodes.push_back(node); + } + } + + bool is_cycle = false; + if (!dfs_start_nodes.empty()) { + tensorflow::ReverseDFSFrom(graph, dfs_start_nodes, {}, + [&is_cycle, src](tensorflow::Node* node) { + if (node == src) { + is_cycle = true; + } + }); + } + + return !is_cycle; +} + +void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph, + std::vector* remove_edges) { + // Transfer all inputs and outputs of 'dst' to 'src' except edges + // connecting the two. + tensorflow::Node* src = edge->src(); + tensorflow::Node* dst = edge->dst(); + + // We can use '0' for input/output index because we don't need them + // to be accurate for the way we are using the graph. + std::vector in_edges(dst->in_edges().begin(), + dst->in_edges().end()); + for (const tensorflow::Edge* in_edge : in_edges) { + if (in_edge->src() != src) { + tensorflow::Edge* e = const_cast(in_edge); + if (e->src() == graph->source_node()) { + graph->AddEdge(e->src(), e->src_output(), src, + tensorflow::Graph::kControlSlot); + } else { + graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */); + } + } + } + + std::vector out_edges(dst->out_edges().begin(), + dst->out_edges().end()); + for (const tensorflow::Edge* out_edge : out_edges) { + tensorflow::Edge* e = const_cast(out_edge); + if (e->dst() == graph->sink_node()) { + graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(), + e->dst_input()); + } else { + graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input()); + } + } + + // Return the edges that must be removed to disconnect 'dst' from + // the graph. We don't actually remove 'dst' since the caller holds + // references to all the nodes. + for (const auto& in_edge : dst->in_edges()) { + remove_edges->push_back(in_edge); + } + for (const auto& out_edge : dst->out_edges()) { + remove_edges->push_back(out_edge); + } +} + +} // namespace + +tensorflow::Status SegmentGraph( + const tensorflow::GraphDef& gdef, + const std::function& candidate_fn, + const SegmentOptions& options, SegmentNodesVector* segments) { + // Create a Graph representation of the GraphDef. + tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), + gdef.library()); + tensorflow::Graph graph(flib); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( + tensorflow::GraphConstructorOptions(), gdef, &graph)); + + // tensorflow::DumpGraph("Pre-Segment", &graph); + + // Use a union-find to collect the nodes that belong to the same + // segment. A node value of nullptr indicates that the node is not a + // candidate for TRT. + std::vector> node_segments; + for (int i = 0; i < graph.num_node_ids(); ++i) { + tensorflow::Node* node = graph.FindNodeId(i); + if (options.exclude_node_list.count(node->name()) != 0 || + !candidate_fn(node->def())) { + node = nullptr; + } + node_segments.emplace_back(node); + } + + // The segmentation algorithm below visits nodes in reverse + // topological order and attempts to merge nodes along output + // edges. That means that subgraphs grow from the output-side of the + // network towards the inputs. In general this is not guaranteed to + // produce a globally optimal segmentation. In the future if we have + // a measure of how beneficial it is to include a given node in a + // TRT subgraph then we can revisit this algorithm to take advantage + // of that information. + std::vector order; + tensorflow::GetPostOrder(graph, &order); + + for (const tensorflow::Node* node : order) { + // All output nodes of 'node' have been visited... + VLOG(2) << "Trying node " << node->name(); + + // 'node' must be a TRT candidate... + if (node_segments[node->id()].Value() == nullptr) { + VLOG(2) << "... not a TRT candidate"; + continue; + } + + // Contract output edges to combine 'node' with output + // nodes. Iterate since combining two nodes may unblock other + // combining. + while (true) { + std::set contract_edges; + for (const tensorflow::Edge* out_edge : node->out_edges()) { + VLOG(2) << "... out node " << out_edge->dst()->name(); + + // Out node must be TRT candidate... + if (node_segments[out_edge->dst()->id()].Value() == nullptr) { + VLOG(2) << "... ... not a TRT candidate"; + continue; + } + + if (CanContractEdge(out_edge, graph)) { + VLOG(2) << "... ... can contract"; + contract_edges.insert(out_edge); + } else { + VLOG(2) << "... ... cannot contract, would form cycle"; + } + } + + if (contract_edges.empty()) { + break; + } + + // Contract edges and collect the adjacent nodes into the same + // segment/subgraph. + while (!contract_edges.empty()) { + const tensorflow::Edge* contract_edge = *contract_edges.begin(); + const tensorflow::Node* src = contract_edge->src(); + const tensorflow::Node* dst = contract_edge->dst(); + + VLOG(2) << "Merge " << src->name() << " <- " << dst->name(); + node_segments[src->id()].Merge(&node_segments[dst->id()]); + + // Contracting the edge leaves disconnected graph edges. + // Remove these from the graph and from 'contract_edges' so we + // don't visit them again. + tensorflow::Edge* e = const_cast(contract_edge); + std::vector remove_edges; + ContractEdge(e, &graph, &remove_edges); + + for (const tensorflow::Edge* r : remove_edges) { + contract_edges.erase(r); + graph.RemoveEdge(r); + } + } + } + } + + // Collect the segments/subgraphs. Each subgraph is represented by a + // set of the names of the nodes in that subgraph. + std::unordered_map> sg_map; + for (auto& u : node_segments) { + if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) { + sg_map[u.ParentValue()->name()].insert(u.Value()->name()); + } + } + + // Convert the segments into the expected return format + for (const auto& itr : sg_map) { + const auto& segment_node_names = itr.second; + if (VLOG_IS_ON(1)) { + string s; + for (const auto& name : segment_node_names) { + s += " " + name; + } + VLOG(1) << "Segment " << segments->size() << ":" << s; + } + + // Don't use small segments. + if (static_cast(segment_node_names.size()) < + options.minimum_segment_size) { + VLOG(1) << "Segment " << segments->size() << " has only " + << segment_node_names.size() << " nodes, dropping"; + continue; + } + + segments->emplace_back(segment_node_names); + } + + return tensorflow::Status::OK(); +} + +} // namespace segment +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h new file mode 100644 index 0000000000000000000000000000000000000000..ee6e2b3ed26cd1fabc0e952d882d549046cd9a30 --- /dev/null +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tensorrt { +namespace segment { + +using SegmentNodesVector = std::vector>; + +struct SegmentOptions { + // Segment must contain at least this many nodes. + int minimum_segment_size = 2; + std::set exclude_node_list; +}; + +// Get the subgraphs of a graph that can be handled by TensorRT. +// +// @param gdef The GraphDef describing the network +// @param candidate_fn A function that returns true for a NodeDef if +// that node can be handled by TensorRT. +// @param segments Returns the TensorRT segments/subgraphs. Each entry +// in the vector describes a subgraph by giving a set of the names of +// all the NodeDefs in that subgraph. +// @return the status. +tensorflow::Status SegmentGraph( + const tensorflow::GraphDef& gdef, + const std::function& candidate_fn, + const SegmentOptions& options, SegmentNodesVector* segments); + +} // namespace segment +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..74cbc5f2b376b76324eed06d251767da6f928e3e --- /dev/null +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -0,0 +1,367 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tensorrt { +namespace segment { +namespace test { + +class SegmentTest : public ::testing::Test { + public: + bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); + + TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name); + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name); + + std::function MakeCandidateFn( + const std::set& node_names); + + protected: + void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, + TF_Operation** op); + void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name, TF_Operation** op, bool check); + + SegmentOptions default_options_; +}; + +bool SegmentTest::GetGraphDef(TF_Graph* graph, + tensorflow::GraphDef* graph_def) { + TF_Status* s = TF_NewStatus(); + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphToGraphDef(graph, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length); + TF_DeleteBuffer(buffer); + TF_DeleteStatus(s); + return ret; +} + +std::function SegmentTest::MakeCandidateFn( + const std::set& node_names) { + return [node_names](const NodeDef& node) -> bool { + return node_names.find(node.name()) != node_names.end(); + }; +} + +void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s, + const char* name, TF_Operation** op) { + TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); + TF_SetAttrType(desc, "dtype", TF_INT32); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s, + const char* name) { + TF_Operation* op; + PlaceholderHelper(graph, s, name, &op); + return op; +} + +void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name, TF_Operation** op, + bool check) { + TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); + TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; + TF_AddInputList(desc, add_inputs, 2); + *op = TF_FinishOperation(desc, s); + if (check) { + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); + } +} + +TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r, + TF_Graph* graph, TF_Status* s, + const char* name) { + TF_Operation* op; + AddHelper(l, r, graph, s, name, &op, true); + return op; +} + +TEST_F(SegmentTest, Empty) { + TF_Graph* graph = TF_NewGraph(); + + GraphDef graph_def; + ASSERT_TRUE(GetGraphDef(graph, &graph_def)); + + SegmentNodesVector segments; + ASSERT_EQ( + SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments), + tensorflow::Status::OK()); + + // Expect no segments/subgraphs. + EXPECT_TRUE(segments.empty()); + TF_DeleteGraph(graph); +} + +TEST_F(SegmentTest, Simple) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // feed + // // || + // add0 add1 + // | | / + // | add2 + // | / || + // add3 add4 + // | / + // + // + TF_Operation* feed = Placeholder(graph, s, "feed"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); + + TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); + TF_Operation* add4 = Add(add2, add2, graph, s, "add4"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); + + GraphDef graph_def; + ASSERT_TRUE(GetGraphDef(graph, &graph_def)); + + SegmentNodesVector segments; + ASSERT_EQ( + SegmentGraph(graph_def, + MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}), + default_options_, &segments), + tensorflow::Status::OK()); + + // Expect all Add operations to be collapsed into a single segment + ASSERT_EQ(segments.size(), 1); + std::vector expected{"add0", "add1", "add2", "add3", "add4"}; + for (const auto& ex : expected) { + EXPECT_TRUE(segments[0].find(ex) != segments[0].end()) + << "Missing expected node " << ex; + } + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST_F(SegmentTest, AvoidCycle) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // add2 is not a TRT candidate so add0/add3 cannot be formed as a + // subgraph + // + // feed + // // || + // add0 add1 + // | | / + // | add2 + // | / || + // add3 add4 + // | / + // + // + TF_Operation* feed = Placeholder(graph, s, "feed"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); + + TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); + TF_Operation* add4 = Add(add2, add2, graph, s, "add4"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); + + GraphDef graph_def; + ASSERT_TRUE(GetGraphDef(graph, &graph_def)); + + SegmentNodesVector segments; + ASSERT_EQ( + SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}), + default_options_, &segments), + tensorflow::Status::OK()); + + // Expect no subgraphs + EXPECT_EQ(segments.size(), 0); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST_F(SegmentTest, Multiple) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // add5 is not a TRT candidate so two subgraphs should be formed + // + // feed + // // || || + // add0 add1 add7 + // | | / / || + // | add2-----add5 add8 + // | / | | | | + // add3 add4 add6 + // | | / + // + // + TF_Operation* feed = Placeholder(graph, s, "feed"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); + + TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add7 = Add(feed, feed, graph, s, "add7"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add5 = Add(add2, add7, graph, s, "add5"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add8 = Add(add7, add7, graph, s, "add8"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); + TF_Operation* add4 = Add(add2, add5, graph, s, "add4"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); + TF_Operation* add6 = Add(add5, add8, graph, s, "add6"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add6"), string(TF_OperationName(add6))); + + GraphDef graph_def; + ASSERT_TRUE(GetGraphDef(graph, &graph_def)); + + SegmentNodesVector segments; + ASSERT_EQ(SegmentGraph(graph_def, + MakeCandidateFn({"add0", "add1", "add2", "add3", + "add4", "add6", "add7", "add8"}), + default_options_, &segments), + tensorflow::Status::OK()); + + // Expect two subgraphs + EXPECT_EQ(segments.size(), 2); + + std::vector expected0{"add0", "add1", "add2", "add3"}; + for (const auto& ex : expected0) { + EXPECT_TRUE(segments[0].find(ex) != segments[0].end()) + << "Missing expected node " << ex; + } + + std::vector expected1{"add6", "add8"}; + for (const auto& ex : expected1) { + EXPECT_TRUE(segments[1].find(ex) != segments[1].end()) + << "Missing expected node " << ex; + } + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST_F(SegmentTest, BigIfElse) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // add2 is not a TRT candidate + // + // feed + // || + // add0 + // // || + // add1 add4 + // || || + // add2 add5 + // || || + // add3 add6 + // || // + // add7 + // || + // + // + TF_Operation* feed = Placeholder(graph, s, "feed"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); + + TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add1 = Add(add0, add0, graph, s, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add2 = Add(add1, add1, graph, s, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add3 = Add(add2, add2, graph, s, "add3"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add4 = Add(add0, add0, graph, s, "add4"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add5 = Add(add4, add4, graph, s, "add5"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add6 = Add(add5, add5, graph, s, "add6"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Operation* add7 = Add(add3, add6, graph, s, "add7"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(string("add7"), string(TF_OperationName(add7))); + + GraphDef graph_def; + ASSERT_TRUE(GetGraphDef(graph, &graph_def)); + + SegmentNodesVector segments; + ASSERT_EQ(SegmentGraph(graph_def, + MakeCandidateFn({"add0", "add1", "add3", "add4", + "add5", "add6", "add7"}), + default_options_, &segments), + tensorflow::Status::OK()); + + // Expect 2 subgraphs + EXPECT_EQ(segments.size(), 2); + + std::vector expected0{"add3", "add4", "add5", "add6", "add7"}; + for (const auto& ex : expected0) { + EXPECT_TRUE(segments[0].find(ex) != segments[0].end()) + << "Missing expected node " << ex; + } + + std::vector expected1{"add0", "add1"}; + for (const auto& ex : expected1) { + EXPECT_TRUE(segments[1].find(ex) != segments[1].end()) + << "Missing expected node " << ex; + } + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +} // namespace test +} // namespace segment +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/segment/union_find.h b/tensorflow/contrib/tensorrt/segment/union_find.h new file mode 100644 index 0000000000000000000000000000000000000000..1c64ebbb0ae532a4776ab8963515d19fd3b23b4c --- /dev/null +++ b/tensorflow/contrib/tensorrt/segment/union_find.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ + +namespace tensorflow { +namespace tensorrt { +namespace segment { + +// Union-Find data structure. +// Each cluster has an associated value; when merging clusters we can control +// which value becomes the representative of the merged clusters. Values must be +// copyable. +template +class UnionFind { + public: + UnionFind() : size_(1), parent_(nullptr) {} + explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {} + + // Returns the number of elements in a cluster. + int Size() { return FindRoot()->size_; } + + // Merges this cluster with 'other'. This cluster's value becomes + // the value of the merged cluster; the value of 'other' is ignored. + void Merge(UnionFind* other); + + // Each cluster has an associated value. Retrieves the value associated + // with this cluster. + T& ParentValue() { return FindRoot()->value_; } + + // Get the original value of this node. + T& Value() { return value_; } + + private: + // Finds the root element of the cluster. Performs path compression. + UnionFind* FindRoot(); + + int size_; + UnionFind* parent_; + T value_; +}; + +template +void UnionFind::Merge(UnionFind* other) { + UnionFind* a = FindRoot(); + UnionFind* b = other->FindRoot(); + if (a == b) return; + + b->parent_ = a; + a->size_ += b->size_; +} + +template +UnionFind* UnionFind::FindRoot() { + if (!parent_) return this; + // Path compression: update intermediate nodes to point to the root of the + // equivalence class. + parent_ = parent_->FindRoot(); + return parent_; +} + +} // namespace segment +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc new file mode 100644 index 0000000000000000000000000000000000000000..8b475177bc670ddae2b26b6a494f758eba20b2c3 --- /dev/null +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" + +#include +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace shape_inference { + +tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { + tensorflow::tensorrt::Logger logger; + string serialized_engine; + TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine)); + nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); + nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( + serialized_engine.c_str(), serialized_engine.size(), nullptr); + + int num_batch = -1; + std::vector<::tensorflow::DataType> input_type; + TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type)); + for (size_t i = 0; i < context->num_inputs(); i++) { + // Check if input shape is legit + auto input_shape = context->input(i); + for (int j = 0; j < context->Rank(input_shape); j++) { + auto dim_handler = context->Dim(input_shape, j); + if (j == 0) { + if (i == 0) { + num_batch = context->Value(dim_handler); + } else if (num_batch != context->Value(dim_handler)) { + // TODO(jie): TensorRT engine requires consistent batch between inputs + // tensors. Segmenter should be aware of this. + LOG(FATAL) << "TensorRT engine requires consistent batch size"; + } + } + } + } + + // Arrange input here + std::vector input_nodes; + TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes)); + + // Arrange output here + std::vector output_nodes; + TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes)); + for (size_t i = 0; i < output_nodes.size(); i++) { + int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str()); + ShapeHandle output_shape; + std::vector dim_vec; + dim_vec.emplace_back(context->MakeDim(num_batch)); + if (binding_index != -1) { + auto dims = trt_engine->getBindingDimensions(binding_index); + for (int j = 0; j < dims.nbDims; j++) { + dim_vec.emplace_back(context->MakeDim(dims.d[j])); + } + } else { + LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i]; + } + output_shape = context->MakeShape(dim_vec); + context->set_output(i, output_shape); + } + + return Status::OK(); +} + +} // namespace shape_inference +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h new file mode 100644 index 0000000000000000000000000000000000000000..4b50f66699f0965639e22169ee7d71e860314bf0 --- /dev/null +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_ + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace shape_inference { +Status TRTEngineOpShapeInference(InferenceContext* c); +} // namespace shape_inference +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_ diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e11522ea5bda7f5a303d6ea332148dbd7b17f162 --- /dev/null +++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "cuda/include/cuda.h" +#include "cuda/include/cuda_runtime_api.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace { + +class Logger : public nvinfer1::ILogger { + public: + void log(nvinfer1::ILogger::Severity severity, const char* msg) override { + switch (severity) { + case Severity::kINFO: + LOG(INFO) << msg; + break; + case Severity::kWARNING: + LOG(WARNING) << msg; + break; + case Severity::kINTERNAL_ERROR: + case Severity::kERROR: + LOG(ERROR) << msg; + break; + default: + break; + } + } +}; + +class ScopedWeights { + public: + ScopedWeights(float value) : value_(value) { + w.type = nvinfer1::DataType::kFLOAT; + w.values = &value_; + w.count = 1; + } + const nvinfer1::Weights& get() { return w; } + + private: + float value_; + nvinfer1::Weights w; +}; + +const char* kInputTensor = "input"; +const char* kOutputTensor = "output"; + +// Creates a network to compute y=2x+3. +nvinfer1::IHostMemory* CreateNetwork() { + Logger logger; + nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger); + ScopedWeights weights(2.0); + ScopedWeights bias(3.0); + + nvinfer1::INetworkDefinition* network = builder->createNetwork(); + // Add the input. + auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT, + nvinfer1::DimsCHW{1, 1, 1}); + EXPECT_NE(input, nullptr); + // Add the hidden layer. + auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get()); + EXPECT_NE(layer, nullptr); + // Mark the output. + auto output = layer->getOutput(0); + output->setName(kOutputTensor); + network->markOutput(*output); + // Build the engine + builder->setMaxBatchSize(1); + builder->setMaxWorkspaceSize(1 << 10); + auto engine = builder->buildCudaEngine(*network); + EXPECT_NE(engine, nullptr); + // Serialize the engine to create a model, then close everything. + nvinfer1::IHostMemory* model = engine->serialize(); + network->destroy(); + engine->destroy(); + builder->destroy(); + return model; +} + +// Executes the network. +void Execute(nvinfer1::IExecutionContext& context, const float* input, + float* output) { + const nvinfer1::ICudaEngine& engine = context.getEngine(); + + // We have two bindings: input and output. + ASSERT_EQ(engine.getNbBindings(), 2); + const int input_index = engine.getBindingIndex(kInputTensor); + const int output_index = engine.getBindingIndex(kOutputTensor); + + // Create GPU buffers and a stream + void* buffers[2]; + ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float))); + ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float))); + cudaStream_t stream; + ASSERT_EQ(0, cudaStreamCreate(&stream)); + + // Copy the input to the GPU, execute the network, and copy the output back. + // + // Note that since the host buffer was not created as pinned memory, these + // async copies are turned into sync copies. So the following synchronization + // could be removed. + ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float), + cudaMemcpyHostToDevice, stream)); + context.enqueue(1, buffers, stream, nullptr); + ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float), + cudaMemcpyDeviceToHost, stream)); + cudaStreamSynchronize(stream); + + // Release the stream and the buffers + cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaFree(buffers[input_index])); + ASSERT_EQ(0, cudaFree(buffers[output_index])); +} + +TEST(TensorrtTest, BasicFunctions) { + // Create the network model. + nvinfer1::IHostMemory* model = CreateNetwork(); + // Use the model to create an engine and then an execution context. + Logger logger; + nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger); + nvinfer1::ICudaEngine* engine = + runtime->deserializeCudaEngine(model->data(), model->size(), nullptr); + model->destroy(); + nvinfer1::IExecutionContext* context = engine->createExecutionContext(); + + // Execute the network. + float input = 1234; + float output; + Execute(*context, &input, &output); + EXPECT_EQ(output, input * 2 + 3); + + // Destroy the engine. + context->destroy(); + engine->destroy(); + runtime->destroy(); +} + +} // namespace +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py new file mode 100644 index 0000000000000000000000000000000000000000..18dba94acb3724cb2b5a1c53227bcf08bf9f8fcc --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -0,0 +1,88 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +# normally we should do import tensorflow as tf and then +# tf.placeholder, tf.constant, tf.nn.conv2d etc but +# it looks like internal builds don't like it so +# importing every module individually + +from tensorflow.contrib import tensorrt as trt +from tensorflow.core.protobuf import config_pb2 as cpb2 +from tensorflow.python.client import session as csess +from tensorflow.python.framework import constant_op as cop +from tensorflow.python.framework import dtypes as dtypes +from tensorflow.python.framework import importer as importer +from tensorflow.python.framework import ops as ops +from tensorflow.python.ops import array_ops as aops +from tensorflow.python.ops import nn as nn +from tensorflow.python.ops import nn_ops as nn_ops + + +def get_simple_graph_def(): + """Create a simple graph and return its graph_def""" + g = ops.Graph() + with g.as_default(): + a = aops.placeholder( + dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") + e = cop.constant( + [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], + name="weights", + dtype=dtypes.float32) + conv = nn.conv2d( + input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") + b = cop.constant( + [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32) + t = nn.bias_add(conv, b, name="biasAdd") + relu = nn.relu(t, "relu") + idty = aops.identity(relu, "ID") + v = nn_ops.max_pool( + idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + aops.squeeze(v, name="output") + return g.as_graph_def() + + +def run_graph(gdef, dumm_inp): + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + ops.reset_default_graph() + g = ops.Graph() + with g.as_default(): + inp, out = importer.import_graph_def( + graph_def=gdef, return_elements=["input", "output"]) + inp = inp.outputs[0] + out = out.outputs[0] + with csess.Session( + config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + val = sess.run(out, {inp: dumm_inp}) + return val + + +if "__main__" in __name__: + inp_dims = (100, 24, 24, 2) + dummy_input = np.random.random_sample(inp_dims) + gdef = get_simple_graph_def() + # Get optimized graph + trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0]) + o1 = run_graph(gdef, dummy_input) + o2 = run_graph(trt_graph, dummy_input) + o3 = run_graph(trt_graph, dummy_input) + assert np.array_equal(o1, o2) + assert np.array_equal(o3, o2) # sanity check + print("Pass") diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i new file mode 100644 index 0000000000000000000000000000000000000000..d679945d569c1784448b6cb09c2f431b9cda56d7 --- /dev/null +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* Wrap trt_conversion */ +%{ +#define SWIG_FILE_WITH_INIT +%} +%include "std_pair.i" +%include "tensorflow/python/platform/base.i" + +%{ +PyObject* pair_helper(std::pair* in) { + PyObject *first(nullptr), *second(nullptr), *tuple(nullptr); + first = PyBytes_FromStringAndSize(in->first.data(), in->first.length()); + if (!first) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "Pair conversion first argument failed"); + } + return NULL; + } + second = PyBytes_FromStringAndSize(in->second.data(), in->second.length()); + if (!second) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, + "Pair conversion second argument failed"); + } + return NULL; + } + tuple = Py_BuildValue("(OO)", first, second); + if (!tuple) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, + "Tuple creation from pair failed!"); + } + return NULL; + } + return tuple; +} +%} +%typemap(out) std::pair { + PyObject *tuple = pair_helper(&$1); + if (!tuple) SWIG_fail; + $result = tuple; +} +%{ +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/stat_summarizer.h" +#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +%} + +%ignoreall +%unignore tensorflow; +%unignore trt_convert; + +%{ +std::pair trt_convert( + string graph_def_string, // The serialized GraphDef string. + std::vector output_names, + size_t max_batch_size, + size_t max_workspace_size_bytes + // Unfortunately we can't use TF_Status here since it + // is in c/c_api and brings in a lot of other libraries + // which in turn declare ops. These ops are included + // statically in our library and cause an abort when + // module is loaded due to double registration + // until Tensorflow properly exposes these headers + // we have to work around this by returning a string + // and converting it to exception on python side. + //,TF_Status* out_status) { +) { +#if GOOGLE_CUDA && GOOGLE_TENSORRT + string out_status; + + tensorflow::GraphDef graph_def; + if (!graph_def.ParseFromString(graph_def_string)) { + out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; + return std::pair{out_status, ""}; + } + + if (!output_names.size()) { + out_status = "InvalidArgument;Size of the output_names vector is 0"; + return std::pair{out_status, ""}; + // return ""; + } + tensorflow::GraphDef outGraph; + tensorflow::Status conversion_status = + tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( + graph_def, output_names, max_batch_size, max_workspace_size_bytes, + &outGraph); + if (!conversion_status.ok()) { + auto retCode = (int)conversion_status.code(); + char buff[2000]; + snprintf(buff, 2000, "%d;%s", retCode, + conversion_status.error_message().c_str()); + out_status = buff; + return std::pair{out_status, ""}; + } + string result; + if (!outGraph.SerializeToString(&result)) { + out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; + return std::pair{out_status, ""}; + } + out_status = "OK;All good!"; + return std::pair{out_status, result}; +#else + // Returns FAILED_PRECONDITION. + return std::pair{"9;TensorRT is not enabled!", ""}; +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +} +%} + +std::pair trt_convert(string graph_def_string, + std::vector output_names, + size_t max_batch_size, + size_t max_workspace_size_bytes); + + +%unignoreall diff --git a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv index 02a60d1cf61765c7c916803fe918d8b7b186405e..b49a0662c29b1d810f4be31ca1f318f0571f533e 100644 --- a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv +++ b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv @@ -1,100 +1,100 @@ -0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867 -1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303 -2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864 -3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426 -4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223 -5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842 -6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606 -7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347 -8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951 -9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228 -10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897 -11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634 -12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594 -13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394 -14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609 -15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449 -16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251 -17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382 -18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767 -19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713 -20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251 -21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811 -22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681 -23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735 -24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436 -25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899 -26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814 -27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727 -28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582 -29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555 -30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696 -31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548 -32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627 -33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104 -34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156 -35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459 -36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576 -37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584 -38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577 -39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467 -40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566 -41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909 -42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021 -43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831 -44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905 -45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271 -46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094 -47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554 -48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769 -49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606 -50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629 -51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199 -52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961 -53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122 -54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454 -55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301 -56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182 -57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365 -58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011 -59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449 -60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229 -61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259 -62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272 -63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989 -64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496 -65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376 -66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206 -67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502 -68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219 -69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125 -70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514 -71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166 -72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832 -73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913 -74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188 -75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388 -76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136 -77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766 -78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959 -79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083 -80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483 -81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656 -82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107 -83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991 -84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527 -85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649 -86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788 -87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289 -88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298 -89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873 -90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669 -91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462 -92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232 -93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225 -94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288 -95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086 -96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161 -97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227 -98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937 -99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724 +0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0. +1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0. +2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0. +3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0. +4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0. +5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0. +6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0. +7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0. +8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0. +9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0. +10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0. +11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0. +12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0. +13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0. +14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0. +15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0. +16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0. +17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0. +18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0. +19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0. +20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0. +21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0. +22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0. +23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0. +24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0. +25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0. +26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0. +27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0. +28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0. +29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0. +30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0. +31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0. +32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0. +33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0. +34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0. +35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0. +36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0. +37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0. +38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0. +39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0. +40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0. +41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0. +42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0. +43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0. +44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0. +45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0. +46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0. +47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0. +48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0. +49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0. +50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0. +51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0. +52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0. +53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0. +54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0. +55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0. +56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0. +57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0. +58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0. +59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0. +60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0. +61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0. +62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0. +63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0. +64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0. +65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0. +66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0. +67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0. +68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0. +69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0. +70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0. +71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0. +72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0. +73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0. +74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0. +75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0. +76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0. +77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0. +78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0. +79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0. +80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0. +81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0. +82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0. +83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0. +84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0. +85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0. +86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0. +87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0. +88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0. +89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0. +90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0. +91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0. +92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0. +93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0. +94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0. +95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0. +96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0. +97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0. +98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0. +99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0. diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py index c7193cef6915f9d0caf5b52fc084129cbc736994..f37cafcc502dc9415db0829b9b067b862f87dca7 100644 --- a/tensorflow/contrib/timeseries/examples/lstm.py +++ b/tensorflow/contrib/timeseries/examples/lstm.py @@ -18,13 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools from os import path +import tempfile import numpy import tensorflow as tf from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators from tensorflow.contrib.timeseries.python.timeseries import model as ts_model +from tensorflow.contrib.timeseries.python.timeseries import state_management try: import matplotlib # pylint: disable=g-import-not-at-top @@ -45,7 +48,8 @@ _DATA_FILE = path.join(_MODULE_PATH, "data/multivariate_periods.csv") class _LSTMModel(ts_model.SequentialTimeSeriesModel): """A time series model-building example using an RNNCell.""" - def __init__(self, num_units, num_features, dtype=tf.float32): + def __init__(self, num_units, num_features, exogenous_feature_columns=None, + dtype=tf.float32): """Initialize/configure the model object. Note that we do not start graph building here. Rather, this object is a @@ -55,6 +59,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): num_units: The number of units in the model's LSTMCell. num_features: The dimensionality of the time series (features per timestep). + exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn + objects representing features which are inputs to the model but are + not predicted by it. These must then be present for training, + evaluation, and prediction. dtype: The floating point data type to use. """ super(_LSTMModel, self).__init__( @@ -62,6 +70,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): train_output_names=["mean"], predict_output_names=["mean"], num_features=num_features, + exogenous_feature_columns=exogenous_feature_columns, dtype=dtype) self._num_units = num_units # Filled in by initialize_graph() @@ -69,7 +78,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): self._lstm_cell_run = None self._predict_from_lstm_output = None - def initialize_graph(self, input_statistics): + def initialize_graph(self, input_statistics=None): """Save templates for components, which can then be used repeatedly. This method is called every time a new graph is created. It's safe to start @@ -80,18 +89,19 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): input_statistics: A math_utils.InputStatistics object. """ super(_LSTMModel, self).initialize_graph(input_statistics=input_statistics) - self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units) - # Create templates so we don't have to worry about variable reuse. - self._lstm_cell_run = tf.make_template( - name_="lstm_cell", - func_=self._lstm_cell, - create_scope_now_=True) - # Transforms LSTM output into mean predictions. - self._predict_from_lstm_output = tf.make_template( - name_="predict_from_lstm_output", - func_= - lambda inputs: tf.layers.dense(inputs=inputs, units=self.num_features), - create_scope_now_=True) + with tf.variable_scope("", use_resource=True): + # Use ResourceVariables to avoid race conditions. + self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units) + # Create templates so we don't have to worry about variable reuse. + self._lstm_cell_run = tf.make_template( + name_="lstm_cell", + func_=self._lstm_cell, + create_scope_now_=True) + # Transforms LSTM output into mean predictions. + self._predict_from_lstm_output = tf.make_template( + name_="predict_from_lstm_output", + func_=functools.partial(tf.layers.dense, units=self.num_features), + create_scope_now_=True) def get_start_state(self): """Return initial state for the time series model.""" @@ -100,6 +110,8 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): tf.zeros([], dtype=tf.int64), # The previous observation or prediction. tf.zeros([self.num_features], dtype=self.dtype), + # The most recently seen exogenous features. + tf.zeros(self._get_exogenous_embedding_shape(), dtype=self.dtype), # The state of the RNNCell (batch dimension removed since this parent # class will broadcast). [tf.squeeze(state_element, axis=0) @@ -127,7 +139,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): loss (note that we could also return other measures of goodness of fit, although only "loss" will be optimized). """ - state_from_time, prediction, lstm_state = state + state_from_time, prediction, exogenous, lstm_state = state with tf.control_dependencies( [tf.assert_equal(current_times, state_from_time)]): # Subtract the mean and divide by the variance of the series. Slightly @@ -139,16 +151,22 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): (prediction - transformed_values) ** 2, axis=-1) # Keep track of the new observation in model state. It won't be run # through the LSTM until the next _imputation_step. - new_state_tuple = (current_times, transformed_values, lstm_state) + new_state_tuple = (current_times, transformed_values, + exogenous, lstm_state) return (new_state_tuple, predictions) def _prediction_step(self, current_times, state): """Advance the RNN state using a previous observation or prediction.""" - _, previous_observation_or_prediction, lstm_state = state + _, previous_observation_or_prediction, exogenous, lstm_state = state + # Update LSTM state based on the most recent exogenous and endogenous + # features. + inputs = tf.concat([previous_observation_or_prediction, exogenous], + axis=-1) lstm_output, new_lstm_state = self._lstm_cell_run( - inputs=previous_observation_or_prediction, state=lstm_state) + inputs=inputs, state=lstm_state) next_prediction = self._predict_from_lstm_output(lstm_output) - new_state_tuple = (current_times, next_prediction, new_lstm_state) + new_state_tuple = (current_times, next_prediction, + exogenous, new_lstm_state) return new_state_tuple, {"mean": self._scale_back_data(next_prediction)} def _imputation_step(self, current_times, state): @@ -160,36 +178,75 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): def _exogenous_input_step( self, current_times, current_exogenous_regressors, state): - """Update model state based on exogenous regressors.""" - raise NotImplementedError( - "Exogenous inputs are not implemented for this example.") + """Save exogenous regressors in model state for use in _prediction_step.""" + state_from_time, prediction, _, lstm_state = state + return (state_from_time, prediction, + current_exogenous_regressors, lstm_state) def train_and_predict( - csv_file_name=_DATA_FILE, training_steps=200, estimator_config=None): + csv_file_name=_DATA_FILE, training_steps=200, estimator_config=None, + export_directory=None): """Train and predict using a custom time series model.""" # Construct an Estimator from our LSTM model. + exogenous_feature_columns = [ + # Exogenous features are not part of the loss, but can inform + # predictions. In this example the features have no extra information, but + # are included as an API example. + tf.contrib.layers.real_valued_column( + "2d_exogenous_feature", dimension=2)] estimator = ts_estimators.TimeSeriesRegressor( - model=_LSTMModel(num_features=5, num_units=128), - optimizer=tf.train.AdamOptimizer(0.001), config=estimator_config) + model=_LSTMModel(num_features=5, num_units=128, + exogenous_feature_columns=exogenous_feature_columns), + optimizer=tf.train.AdamOptimizer(0.001), config=estimator_config, + # Set state to be saved across windows. + state_manager=state_management.ChainingStateManager()) reader = tf.contrib.timeseries.CSVReader( csv_file_name, column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,) - + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5)) + + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5 + + ("2d_exogenous_feature",) * 2)) train_input_fn = tf.contrib.timeseries.RandomWindowInputFn( reader, batch_size=4, window_size=32) estimator.train(input_fn=train_input_fn, steps=training_steps) evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader) evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1) # Predict starting after the evaluation + predict_exogenous_features = { + "2d_exogenous_feature": numpy.concatenate( + [numpy.ones([1, 100, 1]), numpy.zeros([1, 100, 1])], + axis=-1)} (predictions,) = tuple(estimator.predict( input_fn=tf.contrib.timeseries.predict_continuation_input_fn( - evaluation, steps=100))) + evaluation, steps=100, + exogenous_features=predict_exogenous_features))) times = evaluation["times"][0] observed = evaluation["observed"][0, :, :] predicted_mean = numpy.squeeze(numpy.concatenate( [evaluation["mean"][0], predictions["mean"]], axis=0)) all_times = numpy.concatenate([times, predictions["times"]], axis=0) + + # Export the model in SavedModel format. + if export_directory is None: + export_directory = tempfile.mkdtemp() + input_receiver_fn = estimator.build_raw_serving_input_receiver_fn() + export_location = estimator.export_savedmodel( + export_directory, input_receiver_fn) + # Predict using the SavedModel + with tf.Graph().as_default(): + with tf.Session() as session: + signatures = tf.saved_model.loader.load( + session, [tf.saved_model.tag_constants.SERVING], export_location) + saved_model_output = ( + tf.contrib.timeseries.saved_model_utils.predict_continuation( + continue_from=evaluation, signatures=signatures, + session=session, steps=100, + exogenous_features=predict_exogenous_features)) + # The exported model gives the same results as the Estimator.predict() + # call above. + numpy.testing.assert_allclose( + predictions["mean"], + numpy.squeeze(saved_model_output["mean"], axis=0)) return times, observed, all_times, predicted_mean diff --git a/tensorflow/contrib/timeseries/examples/lstm_test.py b/tensorflow/contrib/timeseries/examples/lstm_test.py index 3cace567266d497b12d836f44a335bbe5d916949..ca56e38ca079f71b38cf29605a295a50929945e8 100644 --- a/tensorflow/contrib/timeseries/examples/lstm_test.py +++ b/tensorflow/contrib/timeseries/examples/lstm_test.py @@ -36,7 +36,8 @@ class LSTMExampleTest(test.TestCase): def test_periodicity_learned(self): (observed_times, observed_values, all_times, predicted_values) = lstm.train_and_predict( - training_steps=100, estimator_config=_SeedRunConfig()) + training_steps=100, estimator_config=_SeedRunConfig(), + export_directory=self.get_temp_dir()) self.assertAllEqual([100], observed_times.shape) self.assertAllEqual([100, 5], observed_values.shape) self.assertAllEqual([200], all_times.shape) diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 3738dfa154d4f39b9562446972443ed88f3fbe8b..f8355f366fe8e191ab570fd271bbe4a8bf71c73d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.layers.python.layers import feature_column + from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import feature_keys from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib @@ -72,15 +74,14 @@ class TimeSeriesRegressor(estimator_lib.Estimator): # tf.Example containing all features (times, values, any exogenous features) # and serialized model state (possibly also as a tf.Example). def build_raw_serving_input_receiver_fn(self, - exogenous_features=None, default_batch_size=None, default_series_length=None): """Build an input_receiver_fn for export_savedmodel which accepts arrays. + Automatically creates placeholders for exogenous `FeatureColumn`s passed to + the model. + Args: - exogenous_features: A dictionary mapping feature keys to exogenous - features (either Numpy arrays or Tensors). Used to determine the shapes - of placeholders for these features. default_batch_size: If specified, must be a scalar integer. Sets the batch size in the static shape information of all feature Tensors, which means only this batch size will be accepted by the exported model. If None @@ -94,9 +95,6 @@ class TimeSeriesRegressor(estimator_lib.Estimator): An input_receiver_fn which may be passed to the Estimator's export_savedmodel. """ - if exogenous_features is None: - exogenous_features = {} - def _serving_input_receiver_fn(): """A receiver function to be passed to export_savedmodel.""" placeholders = {} @@ -119,14 +117,22 @@ class TimeSeriesRegressor(estimator_lib.Estimator): dtype=self._model.dtype), shape=(default_batch_size, default_series_length, self._model.num_features))) - for feature_key, feature_value in exogenous_features.items(): - value_tensor = ops.convert_to_tensor(feature_value) - value_tensor.get_shape().with_rank_at_least(2) - feature_shape = value_tensor.get_shape().as_list() - feature_shape[0] = default_batch_size - feature_shape[1] = default_series_length + with ops.Graph().as_default(): + # Default placeholders have only an unknown batch dimension. Make them + # in a separate graph, then splice in the series length to the shapes + # and re-create them in the outer graph. + exogenous_feature_shapes = { + key: (value.get_shape(), value.dtype) for key, value + in feature_column.make_place_holder_tensors_for_base_features( + self._model.exogenous_feature_columns).items()} + for feature_key, (batch_only_feature_shape, value_dtype) in ( + exogenous_feature_shapes.items()): + batch_only_feature_shape = batch_only_feature_shape.with_rank_at_least( + 1).as_list() + feature_shape = ([default_batch_size, default_series_length] + + batch_only_feature_shape[1:]) placeholders[feature_key] = array_ops.placeholder( - dtype=value_tensor.dtype, name=feature_key, shape=feature_shape) + dtype=value_dtype, name=feature_key, shape=feature_shape) # Models may not know the shape of their state without creating some # variables/ops. Avoid polluting the default graph by making a new one. We # use only static metadata from the returned Tensors. diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py index d4ee59036624cffb216709e096981d362670e416..04225333b9377447f46d32663df76aece97a51e7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py +++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py @@ -500,6 +500,41 @@ class CSVReader(ReaderBaseTimeSeriesParser): return features +class TFExampleReader(ReaderBaseTimeSeriesParser): + """Reads and parses `tf.Example`s from a TFRecords file.""" + + def __init__(self, + filenames, + features): + """Configure `tf.Example` parsing. + + Args: + filenames: A filename or list of filenames to read the time series + from. Each line must have columns corresponding to `column_names`. + features: A dictionary mapping from feature keys to `tf.FixedLenFeature` + objects. Must include `TrainEvalFeatures.TIMES` (scalar integer) and + `TrainEvalFeatures.VALUES` (floating point vector) features. + Raises: + ValueError: If required times/values features are not present. + """ + if feature_keys.TrainEvalFeatures.TIMES not in features: + raise ValueError("'{}' is a required column.".format( + feature_keys.TrainEvalFeatures.TIMES)) + if feature_keys.TrainEvalFeatures.VALUES not in features: + raise ValueError("'{}' is a required column.".format( + feature_keys.TrainEvalFeatures.VALUES)) + self._features = features + super(TFExampleReader, self).__init__(filenames=filenames) + + def _get_reader(self): + return io_ops.TFRecordReader() + + def _process_records(self, examples): + """Parse `tf.Example`s into `Tensors`.""" + return parsing_ops.parse_example( + serialized=examples, features=self._features) + + class TimeSeriesInputFn(object): """Base for classes which create batches of windows from a time series.""" diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py index ed78a835a4d451e9e7d18bb833d8ebed6c05a195..703537abf0fe3985aaf0434cc633cb410dd6bd4c 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py @@ -27,7 +27,11 @@ from tensorflow.contrib.timeseries.python.timeseries import input_pipeline from tensorflow.contrib.timeseries.python.timeseries import test_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures +from tensorflow.core.example import example_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.lib.io import tf_record +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import coordinator as coordinator_lib @@ -52,6 +56,21 @@ def _make_csv_time_series(num_features, num_samples, test_tmpdir): return filename +def _make_tfexample_series(num_features, num_samples, test_tmpdir): + _, data_file = tempfile.mkstemp(dir=test_tmpdir) + with tf_record.TFRecordWriter(data_file) as writer: + for i in range(num_samples): + example = example_pb2.Example() + times = example.features.feature[TrainEvalFeatures.TIMES] + times.int64_list.value.append(i) + values = example.features.feature[TrainEvalFeatures.VALUES] + values.float_list.value.extend( + [float(i) * 2. + feature_number + for feature_number in range(num_features)]) + writer.write(example.SerializeToString()) + return data_file + + def _make_numpy_time_series(num_features, num_samples): times = numpy.arange(num_samples) values = times[:, None] * 2. + numpy.arange(num_features)[None, :] @@ -107,6 +126,19 @@ class RandomWindowInputFnTests(test.TestCase): time_series_reader = input_pipeline.CSVReader([filename]) self._test_out_of_order(time_series_reader, discard_out_of_order=False) + def test_tfexample_sort_out_of_order(self): + filename = _make_tfexample_series( + num_features=1, num_samples=50, + test_tmpdir=self.get_temp_dir()) + time_series_reader = input_pipeline.TFExampleReader( + [filename], + features={ + TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature( + shape=[], dtype=dtypes.int64), + TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature( + shape=[1], dtype=dtypes.float32)}) + self._test_out_of_order(time_series_reader, discard_out_of_order=False) + def test_numpy_sort_out_of_order(self): data = _make_numpy_time_series(num_features=1, num_samples=50) time_series_reader = input_pipeline.NumpyReader(data) @@ -183,6 +215,20 @@ class RandomWindowInputFnTests(test.TestCase): self._test_multivariate(time_series_reader=time_series_reader, num_features=2) + def test_tfexample_multivariate(self): + filename = _make_tfexample_series( + num_features=2, num_samples=50, + test_tmpdir=self.get_temp_dir()) + time_series_reader = input_pipeline.TFExampleReader( + [filename], + features={ + TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature( + shape=[], dtype=dtypes.int64), + TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature( + shape=[2], dtype=dtypes.float32)}) + self._test_multivariate(time_series_reader=time_series_reader, + num_features=2) + def test_numpy_multivariate(self): data = _make_numpy_time_series(num_features=3, num_samples=50) time_series_reader = input_pipeline.NumpyReader(data) @@ -248,6 +294,20 @@ class WholeDatasetInputFnTests(test.TestCase): self._whole_dataset_input_fn_test_template( time_series_reader=time_series_reader, num_features=1, num_samples=50) + def test_tfexample(self): + filename = _make_tfexample_series( + num_features=4, num_samples=100, + test_tmpdir=self.get_temp_dir()) + time_series_reader = input_pipeline.TFExampleReader( + [filename], + features={ + TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature( + shape=[], dtype=dtypes.int64), + TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature( + shape=[4], dtype=dtypes.float32)}) + self._whole_dataset_input_fn_test_template( + time_series_reader=time_series_reader, num_features=4, num_samples=100) + def test_numpy(self): data = _make_numpy_time_series(num_features=4, num_samples=100) time_series_reader = input_pipeline.NumpyReader(data) diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index b32b5c5494ae14187954b900119678a5b53a3602..bac7d1ebf59b28d4688a3d1a69ecdc1fc12248e0 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -22,6 +22,7 @@ import abc import collections from tensorflow.contrib import layers +from tensorflow.contrib.layers import feature_column from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures @@ -83,6 +84,11 @@ class TimeSeriesModel(object): self._stats_means = None self._stats_sigmas = None + @property + def exogenous_feature_columns(self): + """`FeatureColumn` objects for features which are not predicted.""" + return self._exogenous_feature_columns + # TODO(allenl): Move more of the generic machinery for generating and # predicting into TimeSeriesModel, and possibly share it between generate() # and predict() @@ -250,6 +256,23 @@ class TimeSeriesModel(object): """ pass + def _get_exogenous_embedding_shape(self): + """Computes the shape of the vector returned by _process_exogenous_features. + + Returns: + The shape as a list. Does not include a batch dimension. + """ + if not self._exogenous_feature_columns: + return (0,) + with ops.Graph().as_default(): + placeholder_features = ( + feature_column.make_place_holder_tensors_for_base_features( + self._exogenous_feature_columns)) + embedded = layers.input_from_feature_columns( + columns_to_tensors=placeholder_features, + feature_columns=self._exogenous_feature_columns) + return embedded.get_shape().as_list()[1:] + def _process_exogenous_features(self, times, features): """Create a single vector from exogenous features. diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py index 5980fc5d5deccc151b01c72fa19b734a7c485bdc..1fb4a3c121c8d7c1daf8fc4a3f59a8b8de38bf8f 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py @@ -187,9 +187,7 @@ class StateSpaceEquivalenceTests(test.TestCase): estimator.train(combined_input_fn, steps=1) export_location = estimator.export_savedmodel( self.get_temp_dir(), - estimator.build_raw_serving_input_receiver_fn( - exogenous_features={ - "exogenous": numpy.zeros((0, 0), dtype=numpy.float32)})) + estimator.build_raw_serving_input_receiver_fn()) with ops.Graph().as_default() as graph: random_model.initialize_graph() with self.test_session(graph=graph) as session: @@ -209,7 +207,7 @@ class StateSpaceEquivalenceTests(test.TestCase): features={ feature_keys.FilteringFeatures.TIMES: [1, 2], feature_keys.FilteringFeatures.VALUES: [1., 2.], - "exogenous": [-1., -2.]}) + "exogenous": [[-1.], [-2.]]}) second_split_filtering = saved_model_utils.filter_continuation( continue_from=first_split_filtering, signatures=signatures, @@ -217,7 +215,7 @@ class StateSpaceEquivalenceTests(test.TestCase): features={ feature_keys.FilteringFeatures.TIMES: [3, 4], feature_keys.FilteringFeatures.VALUES: [3., 4.], - "exogenous": [-3., -4.] + "exogenous": [[-3.], [-4.]] }) combined_filtering = saved_model_utils.filter_continuation( continue_from={ @@ -227,7 +225,7 @@ class StateSpaceEquivalenceTests(test.TestCase): features={ feature_keys.FilteringFeatures.TIMES: [1, 2, 3, 4], feature_keys.FilteringFeatures.VALUES: [1., 2., 3., 4.], - "exogenous": [-1., -2., -3., -4.] + "exogenous": [[-1.], [-2.], [-3.], [-4.]] }) split_predict = saved_model_utils.predict_continuation( continue_from=second_split_filtering, @@ -235,14 +233,14 @@ class StateSpaceEquivalenceTests(test.TestCase): session=session, steps=1, exogenous_features={ - "exogenous": [[-5.]]}) + "exogenous": [[[-5.]]]}) combined_predict = saved_model_utils.predict_continuation( continue_from=combined_filtering, signatures=signatures, session=session, steps=1, exogenous_features={ - "exogenous": [[-5.]]}) + "exogenous": [[[-5.]]]}) for state_key, combined_state_value in combined_filtering.items(): if state_key == feature_keys.FilteringResults.TIMES: continue diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 0199313bc8d0214a547498b97e9a1d83ee37b708..c48e84ddfaac8ac9c07e061847315eab3fd72152 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -36,13 +36,16 @@ py_library( name = "tpu_estimator", srcs = [ "python/tpu/tpu_config.py", + "python/tpu/tpu_context.py", "python/tpu/tpu_estimator.py", + "python/tpu/tpu_system_metadata.py", "python/tpu/util.py", ], srcs_version = "PY2AND3", deps = [ ":tpu_lib", ":tpu_py", + "//tensorflow/contrib/summary:summary_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", diff --git a/tensorflow/contrib/tpu/ops/infeed_ops.cc b/tensorflow/contrib/tpu/ops/infeed_ops.cc index 849c4a1102787870b372c35740cf0fe271efa078..efc546f9a6077de9cac5a5acefa3fc7206547fc6 100644 --- a/tensorflow/contrib/tpu/ops/infeed_ops.cc +++ b/tensorflow/contrib/tpu/ops/infeed_ops.cc @@ -41,6 +41,7 @@ REGISTER_OP("InfeedEnqueue") .Attr("dtype: type") .Attr("shape: shape = {}") .Attr("device_ordinal: int = -1") + .SetShapeFn(shape_inference::NoOutputs) .SetIsStateful() .Doc(R"doc( An op which feeds a single Tensor value into the computation. @@ -58,6 +59,7 @@ REGISTER_OP("InfeedEnqueueTuple") .Attr("dtypes: list(type)") .Attr("shapes: list(shape)") .Attr("device_ordinal: int = -1") + .SetShapeFn(shape_inference::NoOutputs) .SetIsStateful() .Doc(R"doc( An op which feeds multiple Tensor values into the computation as an XLA tuple. diff --git a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc index 28417b89e0d4e0c5b2ca4f4794d29ab8a31049d7..f8de8baa65339383c7f92284ee274a434f12f8c2 100644 --- a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc +++ b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc @@ -212,4 +212,20 @@ An op that shuts down a running distributed TPU system. The Op returns an error if no system is running. )doc"); -} // namespace tensorflow +REGISTER_OP("SessionStatus") + .Input("fetch_start_timestamp: double") + .Output("status: string") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Not for public usage. + +Returns messages from the current session as a serialized SessionStatusProto. + +This includes the current state of the compiler, along with any critical +logging or warning messages. + +fetch_start_timestamp: any messages earlier than this will be excluded from the +returned proto. +)doc"); + +} // end namespace tensorflow diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 1cded9f8cf01b931d1d535a54effd54459dd8e9a..b1ef9fde37fe0647965f0818895be37d2d56d207 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/contrib/tpu/profiler/version.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -47,22 +48,40 @@ string GetCurrentTimeStampAsString() { return s; } -ProfileResponse Profile(const string& service_addr, int duration_ms) { +Status ValidateHostPortPair(const string& host_port) { + uint32 port; + std::vector parts = str_util::Split(host_port, ':'); + // Must be host:port, port must be a number, host must not contain a '/', + // host also must not be empty. + if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) || + parts[0].find("/") != string::npos || parts[0].empty()) { + return errors::InvalidArgument("Could not interpret \"", host_port, + "\" as a host-port pair."); + } + return Status::OK(); +} + +ProfileResponse Profile(const string& service_addr, int duration_ms, + const ProfileOptions& opts) { ProfileRequest request; request.set_duration_ms(duration_ms); request.set_max_events(kMaxEvents); request.add_tools("input_pipeline"); request.add_tools("overview_page"); + *request.mutable_opts() = opts; std::cout << "Limiting the number of trace events to " << kMaxEvents << std::endl; ::grpc::ClientContext context; ::grpc::ChannelArguments channel_args; // TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available. + // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their + // `ValidateHostPortPair` checks for empty host string case. channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits::max()); std::unique_ptr stub = TPUProfiler::NewStub(::grpc::CreateCustomChannel( - service_addr, ::grpc::InsecureChannelCredentials(), channel_args)); + "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), + channel_args)); ProfileResponse response; TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response))); return response; @@ -76,13 +95,21 @@ int main(int argc, char** argv) { tensorflow::string FLAGS_service_addr; tensorflow::string FLAGS_logdir; int FLAGS_duration_ms = 2000; + int FLAGS_num_tracing_attempts = 3; + bool FLAGS_include_dataset_ops = true; std::vector flag_list = { tensorflow::Flag("service_addr", &FLAGS_service_addr, "Address of TPU profiler service e.g. localhost:8466"), tensorflow::Flag("logdir", &FLAGS_logdir, - "Path of TensorBoard log directory e.g. /tmp/tb_log"), + "Path of TensorBoard log directory e.g. /tmp/tb_log, " + "gs://tb_bucket"), tensorflow::Flag("duration_ms", &FLAGS_duration_ms, "Duration of tracing in ms. Default is 2000ms."), + tensorflow::Flag("num_tracing_attempts", &FLAGS_num_tracing_attempts, + "Automatically retry N times when no trace event " + "is collected. Default is 3."), + tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops, + "Set to false to profile longer TPU device traces."), }; std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION @@ -91,14 +118,46 @@ int main(int argc, char** argv) { tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) { - std::printf("%s", usage.c_str()); + std::cout << usage.c_str() << std::endl; + return 2; + } + tensorflow::Status status = + tensorflow::tpu::ValidateHostPortPair(FLAGS_service_addr); + if (!status.ok()) { + std::cout << status.error_message() << std::endl; + std::cout << usage.c_str() << std::endl; return 2; } tensorflow::port::InitMain(argv[0], &argc, &argv); - int duration_ms = FLAGS_duration_ms; - tensorflow::ProfileResponse response = - tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms); + // Sets the minimum duration_ms and tracing attempts to one. + int duration_ms = std::max(FLAGS_duration_ms, 1); + int remaining_attempts = std::max(FLAGS_num_tracing_attempts, 1); + tensorflow::ProfileOptions opts; + opts.set_include_dataset_ops(FLAGS_include_dataset_ops); + tensorflow::ProfileResponse response; + + while (true) { + std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " + << "Remaining attempt(s): " << remaining_attempts-- << std::endl; + response = tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms, opts); + if (remaining_attempts <= 0 || !response.encoded_trace().empty()) break; + std::cout << "No trace event is collected. Automatically retrying." + << std::endl + << std::endl; + } + + if (response.encoded_trace().empty()) { + std::cout << "No trace event is collected after " + << FLAGS_num_tracing_attempts << " attempt(s). " + << "Perhaps, you want to try again (with more attempts?)." + << std::endl + << "Tip: increase number of attempts with --num_tracing_attempts." + << std::endl; + // Don't dump profile data if no trace is collected. + return 0; + } + // Use the current timestamp as the run name. tensorflow::string run = tensorflow::tpu::GetCurrentTimeStampAsString(); TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile( diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index b842951eb2c22792a22d9a16c022d3122391f4e8..ebd6185faad28ae7a22eb33f6b358eb2344c9c22 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -151,10 +151,7 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run, TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(profile_run_dir)); // Ignore computation_graph for now. - const bool empty_trace = response.encoded_trace().empty(); - if (empty_trace) { - *os << "No trace event is collected." << std::endl; - } else { + if (!response.encoded_trace().empty()) { LOG(INFO) << "Converting trace events to TraceViewer JSON."; TF_RETURN_IF_ERROR( DumpTraceToLogDirectory(profile_run_dir, response.encoded_trace(), os)); @@ -165,11 +162,9 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run, TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir, response.op_profile(), os)); } - if (!empty_trace && !response.tool_data().empty()) { - for (const auto& tool_data : response.tool_data()) { - TF_RETURN_IF_ERROR( - DumpToolDataToLogDirectory(profile_run_dir, tool_data, os)); - } + for (const auto& tool_data : response.tool_data()) { + TF_RETURN_IF_ERROR( + DumpToolDataToLogDirectory(profile_run_dir, tool_data, os)); } return Status::OK(); diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h index 25b958bcfeab7e0cfd9c180b8af4057e9bdfc73b..29ef977bacfd61e163be49558c5b94277ed479c1 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h @@ -27,7 +27,10 @@ namespace tpu { // The following tools are supported: // - Trace viewer // - Op profile -// - HLO computation graph +// - Input pipeline analyzer +// - Overview page +// Note: this function creates a directory even when all fields in +// ProfileResponse are unset/empty. Status WriteTensorboardTPUProfile(const string& logdir, const string& run, const ProfileResponse& response, std::ostream* os); diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 846db1332991e8c84f51dc7e6bcc3592a955991e..a730d6142d890cc41f72176cf617ac0b0434192c 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl import flags import os import subprocess @@ -24,13 +25,36 @@ import sys import tensorflow as tf -tf.flags.DEFINE_string('service_addr', '', - 'Address of TPU profiler service e.g. localhost:8466') -tf.flags.DEFINE_string('logdir', '', - 'Path of TensorBoard log directory e.g. /tmp/tb_log') -tf.flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.') +# Cloud TPU Cluster Resolvers +flags.DEFINE_string( + 'gcp_project', None, + 'Project name for the Cloud TPU-enabled project. If not specified, we ' + 'will attempt to automatically detect the GCE project from metadata.') +flags.DEFINE_string( + 'tpu_zone', + None, + help='GCE zone where the Cloud TPU is located in. If not specified, we ' + 'will attempt to automatically detect the GCE project from metadata.') +flags.DEFINE_string('tpu_name', None, + 'Name of the Cloud TPU for Cluster Resolvers. You must ' + 'specify either this flag or --master.') -FLAGS = tf.flags.FLAGS +# Tool specific parameters +flags.DEFINE_string( + 'service_addr', None, 'Address of TPU profiler service e.g. ' + 'localhost:8466, you must specify either this flag or --tpu_name.') +flags.DEFINE_string('logdir', None, + 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' + 'gs://tb_bucket') +flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.') +flags.DEFINE_integer('num_tracing_attempts', 3, + 'Automatically retry N times when no trace ' + 'event is collected.') +flags.DEFINE_boolean('include_dataset_ops', True, + 'Set to false to profile longer TPU ' + 'device traces.') + +FLAGS = flags.FLAGS EXECUTABLE = 'data/capture_tpu_profile' @@ -39,14 +63,35 @@ def run_main(): def main(unused_argv=None): - if not FLAGS.service_addr or not FLAGS.logdir: - sys.exit('service_addr and logdir must be provided.') + tf.logging.set_verbosity(tf.logging.INFO) + + if FLAGS.service_addr is None and FLAGS.tpu_name is None: + sys.exit('You must specify either --service_addr or --tpu_name.') + + if FLAGS.service_addr is not None: + if FLAGS.tpu_name is not None: + tf.logging.warn('Both --service_addr and --tpu_name are set. Ignoring ' + '--tpu_name and using --service_addr.') + service_addr = FLAGS.service_addr + else: + tpu_cluster_resolver = ( + tf.contrib.cluster_resolver.TPUClusterResolver( + tpu_names=[FLAGS.tpu_name], + zone=FLAGS.tpu_zone, + project=FLAGS.gcp_project)) + service_addr = tpu_cluster_resolver.get_master() + service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466') + + if not FLAGS.logdir: + sys.exit('logdir must be provided.') executable_path = os.path.join(os.path.dirname(__file__), EXECUTABLE) logdir = os.path.expandvars(os.path.expanduser(FLAGS.logdir)) cmd = [executable_path] - cmd.append('--logdir='+logdir) - cmd.append('--service_addr='+FLAGS.service_addr) - cmd.append('--duration_ms='+str(FLAGS.duration_ms)) + cmd.append('--logdir=' + logdir) + cmd.append('--service_addr=' + service_addr) + cmd.append('--duration_ms=' + str(FLAGS.duration_ms)) + cmd.append('--num_tracing_attempts=' + str(FLAGS.num_tracing_attempts)) + cmd.append('--include_dataset_ops=' + str(FLAGS.include_dataset_ops).lower()) subprocess.call(cmd) diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index 92196638318f4a551619d04ba730ac66a58d596e..8d99835b64152629c66607e6792495eb36319eb8 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -20,7 +20,7 @@ from __future__ import print_function from setuptools import setup -_VERSION = '1.4.3-a2' +_VERSION = '1.6.0-rc1' CONSOLE_SCRIPTS = [ 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', @@ -47,20 +47,16 @@ setup( # 4 - Beta # 5 - Production/Stable 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', - 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Scientific/Engineering :: Artificial Intelligence', @@ -69,4 +65,5 @@ setup( 'Topic :: Software Development :: Libraries :: Python Modules', ], license='Apache 2.0', - keywords='tensorflow performance tpu',) + keywords='tensorflow performance tpu', +) diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto index 5440bbbfdd75207bd209c19d5cc42dc69504d39b..2094294baad63ae73712c8648b588accd4551ef8 100644 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -61,6 +61,11 @@ message OpMetricsResult { message OpMetricsDbResult { // A bunch of OpMetricsResults. repeated OpMetricsResult metrics_db = 1; + // The total host infeed-enqueue duration in picoseconds. + optional uint64 total_host_infeed_enq_duration_ps = 2; + // The total of the difference between the start times of two + // consecutive infeed-enqueues (per host) in picoseconds. + optional uint64 total_host_infeed_enq_start_timestamp_ps_diff = 3; } // Result proto for StepInfo. diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index bf30d2ce091302eaf361a0018464d3b7de94ea6d..f3f3302ceb3d27dbb21bdce753aeb2d7fcd77448 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -13,6 +13,14 @@ service TPUProfiler { } } +message ProfileOptions { + // We don't collect the dataset ops by default for better trace-viewer + // scalability. The caller can mannually set this field to include the ops. + bool include_dataset_ops = 1; + + // next-field: 2 +} + message ProfileRequest { // In future, the caller will be able to customize when profiling starts and // stops. For now, it collects `duration_ms` milliseconds worth of data. @@ -25,10 +33,13 @@ message ProfileRequest { // required profiling tools name such as "input_pipeline_analyzer" etc repeated string tools = 3; + // Optional profiling options that control how a TF session will be profiled. + ProfileOptions opts = 4; + // In future, the caller will indicate which TF session is being profiled, and // only data relating to that program will be returned. For now, we assume // all activity during the profiling period is relevant. - // next-field: 4 + // next-field: 5 } message ProfileToolData { diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h index 0f645a549296b0f05acfb7ae564be1daf37925f8..dc6a934891138018d32d511750120453bdf290cf 100644 --- a/tensorflow/contrib/tpu/profiler/version.h +++ b/tensorflow/contrib/tpu/profiler/version.h @@ -16,6 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ #define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ -#define TPU_PROFILER_VERSION "1.4.3" +#define TPU_PROFILER_VERSION "1.5.0" #endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index a49a3dcf2999053d9b0d5ffcb6411e693702d785..97876216793e0e6b20b7c072cac4f575b8fd48be 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -47,7 +47,8 @@ if platform.system() != "Windows": # types are supported. _SUPPORTED_INFEED_DTYPES = set([ - dtypes.int32, dtypes.bfloat16, dtypes.float32 + dtypes.bool, dtypes.int32, dtypes.bfloat16, dtypes.float32, + dtypes.complex64 ]) def infeed_dequeue(dtype, shape, name=None): diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py index ee202610a8a8a1406363b3010771e7806d5d84bf..bdd9b88af55fa4fb483ddbdbe5c51d7076cce675 100644 --- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py +++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py @@ -87,6 +87,8 @@ class DeviceAssignment(object): core_assignment.shape)) self._core_assignment = core_assignment + self._task_and_cores_to_replicas = self._compute_task_and_cores_to_replicas( + self._core_assignment, self._topology_tasks) def _invert_topology(self, topology): """Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps.""" @@ -100,6 +102,34 @@ class DeviceAssignment(object): devices[x, y, z] = device return tasks, devices + def _compute_task_and_cores_to_replicas(self, core_assignment, + topology_tasks): + """Computes a nested dict which maps task and logical core to replicas.""" + task_and_cores_to_replicas = {} + for replica in xrange(core_assignment.shape[0]): + for dx in xrange(core_assignment.shape[1]): + for dy in xrange(core_assignment.shape[2]): + for dz in xrange(core_assignment.shape[3]): + x, y, z = core_assignment[replica, dx, dy, dz, :] + task_id = topology_tasks[x, y, z] + if task_id not in task_and_cores_to_replicas: + task_and_cores_to_replicas[task_id] = {} + logical_core = (dx, dy, dz) + if logical_core not in task_and_cores_to_replicas[task_id]: + task_and_cores_to_replicas[task_id][logical_core] = set() + + task_and_cores_to_replicas[task_id][logical_core].add(replica) + + task_to_sorted_replica_id = {} + + for task, core_to_replicas in task_and_cores_to_replicas.items(): + core_to_sorted_replicas = {} + for core, replicas in core_to_replicas.items(): + core_to_sorted_replicas[core] = sorted(replicas) + + task_to_sorted_replica_id[task] = core_to_sorted_replicas + return task_to_sorted_replica_id + @property def topology(self): """A `Topology` that describes the TPU topology.""" @@ -119,6 +149,11 @@ class DeviceAssignment(object): """ return self._computation_shape + @property + def num_cores_per_replica(self): + """The number of cores per replica.""" + return np.prod(self.computation_shape) + @property def num_replicas(self): """The number of replicas of the computation.""" @@ -148,6 +183,26 @@ class DeviceAssignment(object): logical_offset = tuple([replica] + logical_core.tolist() + [slice(3)]) return tuple(self.core_assignment[logical_offset]) + def lookup_replicas(self, task_id, logical_core): + """Lookup replica ids by task number and logical core. + + Args: + task_id: TensorFlow task number. + logical_core: A tuple of three integers which represents a logical core. + Returns: + A sorted list of the replicas that are attached to that task and + loical_core. + Raises: + ValueError: If no replica exisis in the task which contains the logical + core. + """ + try: + return self._task_and_cores_to_replicas[task_id][logical_core] + except KeyError: + raise ValueError( + "Can not find any replica in task: {} contains logical_core: {} ". + format(task_id, logical_core)) + def tpu_ordinal(self, replica=0, logical_core=None): """Returns the ordinal of the TPU device assigned to a logical core.""" coordinates = self._coordinates(replica, logical_core) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 8fec379aad8a90d06cd05f4858d25656384a12b2..d5f54ff4fd278f0c84f79e0079bfb7a409dfba8d 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -153,10 +153,11 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): raise NotImplementedError( "Non-resource Variables are not supported inside TPU computations " "(operator name: %s)" % op.name) - # pylint: enable=protected-access if _TPU_REPLICATE_ATTR in op.node_def.attr: raise ValueError("TPU computations cannot be nested") - op.node_def.attr[_TPU_REPLICATE_ATTR].s = compat.as_bytes(self._name) + op._set_attr(_TPU_REPLICATE_ATTR, + attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) + # pylint: enable=protected-access op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 0c2580211ab7674d841ca1953c9327df9488bb8e..644070218214643923b9ca3ee138615ec568e8b5 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -23,6 +23,8 @@ import collections import json import os +import numpy as np + from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.platform import tf_logging as logging @@ -31,29 +33,44 @@ from tensorflow.python.platform import tf_logging as logging _TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV _SERVICE_KEY = run_config_lib._SERVICE_KEY _TPU_WORKER_JOB_NAME = 'tpu_worker_job_name' +_NUM_CORES_PER_HOST = 8 # pylint: enable=protected-access +# TODO(b/72511246) Provide a simplified api to configure model parallelism. class TPUConfig( collections.namedtuple('TPUConfig', [ 'iterations_per_loop', 'num_shards', + 'computation_shape', 'per_host_input_for_training', 'tpu_job_name', 'initial_infeed_sleep_secs', ])): - """TPU related configuration required by `TPUEstimator`. + r"""TPU related configuration required by `TPUEstimator`. Args: - iterations_per_loop: This is the number of train steps runnining in TPU + iterations_per_loop: This is the number of train steps running in TPU system before returning to CPU host for each `Session.run`. This means global step is increased `iterations_per_loop` times in one `Session.run`. It is recommended to be set as number of global steps for next checkpoint. - num_shards: The number of TPU shards in the system. + num_shards: (Deprecated, ignored by TPUEstimator). + The number of model replicas in the system. For non-model-parallelism + case, this number equals the total number of TPU cores. For + model-parallelism, the total number of TPU cores equals + product(computation_shape) * num_shards. + computation_shape: Defaults to `None`, which disables model parallelism. A + list of size 3 which describes the shape of a model replica's block of + cores. This is required by model-parallelism which enables partitioning + the model to multiple cores. For example, [2, 2, 1] means the model is + partitioned across 4 cores which span two cores in both x and y + coordinates. Please refer to ${tf.contrib.tpu.TopologyProto} for the + geometry of a TPU mesh. per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host rather than Per-Core. With Per-Host input pipeline deployment, `input_fn` - is invoked once on each host. To be precise, with a global batch size + is invoked once on each host. With Per-Core input pipeline deployment, it + is invoked once for each core. To be precise, with a global batch size `train_batch_size` in `TPUEstimator` constructor, the batch size for each shard is `train_batch_size` // #hosts. With Per-Core input pipeline deployment, the shard batch size is `train_batch_size` // #cores. @@ -64,11 +81,15 @@ class TPUConfig( initial_infeed_sleep_secs: The number of seconds the infeed thread should wait before enqueueing the first batch. This helps avoid timeouts for models that require a long compilation time. + + Raises: + ValueError: If `computation_shape` or `computation_shape` are invalid. """ def __new__(cls, iterations_per_loop=2, - num_shards=2, + num_shards=None, + computation_shape=None, per_host_input_for_training=True, tpu_job_name=None, initial_infeed_sleep_secs=None): @@ -78,7 +99,22 @@ class TPUConfig( 'TPUConfig iterations_per_loop') # Check num_shards. - util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') + if num_shards is not None: + util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') + + # Check computation_shape + if computation_shape is not None and len(computation_shape) != 3: + raise ValueError( + 'computation_shape must be a list with length 3 or None; got {}'. + format(str(computation_shape))) + + if computation_shape is not None: + computation_shape_array = np.asarray(computation_shape, dtype=np.int32) + # This prevents any computation being replicated across multiple hosts, so + # that each host feeds the same number of computations. + if any(computation_shape_array < 1) or any(computation_shape_array > 2): + raise ValueError('computation_shape elements can only be 1 or 2; got ' + 'computation_shape={}'.format(computation_shape)) # Check initial_infeed_sleep_secs. if initial_infeed_sleep_secs: @@ -91,6 +127,7 @@ class TPUConfig( cls, iterations_per_loop=iterations_per_loop, num_shards=num_shards, + computation_shape=computation_shape, per_host_input_for_training=per_host_input_for_training, tpu_job_name=tpu_job_name, initial_infeed_sleep_secs=initial_infeed_sleep_secs) @@ -111,8 +148,7 @@ class RunConfig(run_config_lib.RunConfig): evaluation_master: a string. The address of the master to use for eval. Defaults to master if not set. master: a string. The address of the master to use for training. - tf_random_seed: an int. Sets the TensorFlow random seed. Defaults to None, - which initializes it randomly based on the environment. + **kwargs: keyword config parameters. """ super(RunConfig, self).__init__(**kwargs) self._tpu_config = tpu_config or TPUConfig() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py index 60884aa32f932413b49ea2193a145828489ea04c..37ef3dbe1e66efe18b13ab9153ee346c08b9774a 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py @@ -43,6 +43,16 @@ class TPURunConfigTest(test.TestCase): tpu_config_lib.RunConfig( tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0)) + def test_fail_with_invalid_computation_shape(self): + with self.assertRaisesRegexp(ValueError, + 'computation_shape must be a list with length' + ' 3 or None'): + tpu_config_lib.TPUConfig(computation_shape=[2, 1]) + + with self.assertRaisesRegexp(ValueError, + 'computation_shape elements can only be'): + tpu_config_lib.TPUConfig(computation_shape=[1, 3, 1]) + class TPURunConfigMasterTest(test.TestCase): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c46ea741ea64ca37089431f8ed66cad7bc31fb --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -0,0 +1,510 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""TPU system metdata and associated tooling.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from contextlib import contextmanager +import copy + +import numpy as np + +from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment +from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.platform import tf_logging as logging + + +_DEFAULT_JOB_NAME = 'tpu_worker' +_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' +_LOCAL_MASTERS = ('', 'local') + + +class _TPUContext(object): + """A context holds immutable states of TPU computation. + + This immutable object holds TPUEstimator config, train/eval batch size, and + `TPUEstimator.use_tpu`, which is expected to be passed around. It also + provides utility functions, basded on the current state, to determine other + information commonly required by TPU computation, such as TPU device names, + TPU hosts, shard batch size, etc. + + N.B. As `mode` is not immutable state in Estimator, but essential to + distinguish between TPU training and evaluation, a common usage for + _TPUContext with `mode` is as follows: + ``` + with _ctx.with_mode(mode) as ctx: + if ctx.is_running_on_cpu(): + ... + ``` + """ + + def __init__(self, config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu): + self._config = config + self._train_batch_size = train_batch_size + self._eval_batch_size = eval_batch_size + self._predict_batch_size = predict_batch_size + self._use_tpu = use_tpu + self._model_parallelism_enabled = ( + use_tpu and config.tpu_config.computation_shape) + self._mode = None + + self._lazy_tpu_system_metadata_dict = {} # key by master address + self._lazy_device_assignment_dict = {} # key by master address + self._lazy_validation_dict = {} # key by ModeKeys + + def _assert_mode(self): + if self._mode is None: + raise RuntimeError( + '`mode` needs to be set via contextmanager `with_mode`.') + return self._mode + + @contextmanager + def with_mode(self, mode): + # NOTE(xiejw): Shallow copy is enough. It will share he lazy dictionaries, + # such as _lazy_tpu_system_metadata_dict between new copy and the original + # one. Note that all lazy states stored in properties _lazy_foo are sort of + # immutable as they should be same for the process lifetime. + new_ctx = copy.copy(self) + new_ctx._mode = mode # pylint: disable=protected-access + yield new_ctx + + @property + def mode(self): + return self._assert_mode() + + def _get_master_address(self): + mode = self._assert_mode() + config = self._config + master = ( + config.master + if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master) + return master + + def _get_tpu_system_metadata(self): + """Gets the (maybe cached) TPU system metadata.""" + master = self._get_master_address() + tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) + if tpu_system_metadata is not None: + return tpu_system_metadata + + # pylint: disable=protected-access + tpu_system_metadata = ( + tpu_system_metadata_lib._query_tpu_system_metadata( + master, + run_config=self._config, + query_topology=self.model_parallelism_enabled)) + + self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata + return tpu_system_metadata + + def _get_device_assignment(self): + """Gets the (maybe cached) TPU device assignment.""" + master = self._get_master_address() + device_assignment = self._lazy_device_assignment_dict.get(master) + if device_assignment is not None: + return device_assignment + + tpu_system_metadata = self._get_tpu_system_metadata() + + device_assignment = tpu_device_assignment.device_assignment( + tpu_system_metadata.topology, + computation_shape=self._config.tpu_config.computation_shape, + num_replicas=self.num_replicas) + + logging.info('computation_shape: %s', + str(self._config.tpu_config.computation_shape)) + logging.info('num_replicas: %d', self.num_replicas) + logging.info('device_assignment.topology.device_coordinates: %s', + str(device_assignment.topology.device_coordinates)) + logging.info('device_assignment.core_assignment: %s', + str(device_assignment.core_assignment)) + + self._lazy_device_assignment_dict[master] = device_assignment + return device_assignment + + @property + def model_parallelism_enabled(self): + return self._model_parallelism_enabled + + @property + def device_assignment(self): + return (self._get_device_assignment() + if self._model_parallelism_enabled else None) + + @property + def num_of_cores_per_host(self): + metadata = self._get_tpu_system_metadata() + return metadata.num_of_cores_per_host + + @property + def num_cores(self): + metadata = self._get_tpu_system_metadata() + return metadata.num_cores + + @property + def num_of_replicas_per_host(self): + if self.model_parallelism_enabled: + return self.num_replicas // self.num_hosts + else: + return self.num_of_cores_per_host + + @property + def num_replicas(self): + num_cores_in_system = self.num_cores + + if self.model_parallelism_enabled: + computation_shape_array = np.asarray( + self._config.tpu_config.computation_shape, dtype=np.int32) + num_cores_per_replica = np.prod(computation_shape_array) + if num_cores_per_replica > num_cores_in_system: + raise ValueError( + 'The num of cores required by the model parallelism, specified by ' + 'TPUConfig.computation_shape, is larger than the total num of ' + 'TPU cores in the system. computation_shape: {}, num cores ' + 'in the system: {}'.format( + self._config.tpu_config.computation_shape, + num_cores_in_system)) + + if num_cores_in_system % num_cores_per_replica != 0: + raise RuntimeError( + 'The num of cores in the system ({}) is not divisible by the num ' + 'of cores ({}) required by the model parallelism, specified by ' + 'TPUConfig.computation_shape. This should never happen!'.format( + num_cores_in_system, num_cores_per_replica)) + + return num_cores_in_system // num_cores_per_replica + else: + return num_cores_in_system + + @property + def num_hosts(self): + metadata = self._get_tpu_system_metadata() + return metadata.num_hosts + + @property + def config(self): + return self._config + + def is_input_sharded_per_core(self): + """Return true if input_fn is invoked per-core (other than per-host).""" + mode = self._assert_mode() + return (mode == model_fn_lib.ModeKeys.TRAIN and + not self._config.tpu_config.per_host_input_for_training) + + def is_running_on_cpu(self, is_export_mode=False): + """Determines whether the input_fn and model_fn should be invoked on CPU. + + This API also validates user provided configuration, such as batch size, + according the lazy initialized TPU system metadata. + + Args: + is_export_mode: Indicates whether the current mode is for exporting the + model, when mode == PREDICT. Only with this bool, we could + tell whether user is calling the Estimator.predict or + Estimator.export_savedmodel, which are running on TPU and CPU + respectively. Parent class Estimator does not distingush these two. + + Returns: + bool, whether current input_fn or model_fn should be running on CPU. + + Raises: + ValueError: any configuration is invalid. + """ + + is_running_on_cpu = self._is_running_on_cpu(is_export_mode) + if not is_running_on_cpu: + self._validate_tpu_configuration() + return is_running_on_cpu + + def _is_running_on_cpu(self, is_export_mode): + """Determines whether the input_fn and model_fn should be invoked on CPU.""" + mode = self._assert_mode() + + if not self._use_tpu: + return True + + if mode != model_fn_lib.ModeKeys.PREDICT: + return False + + # There are actually 2 use cases when running with mode.PREDICT: prediction + # and saving the model. We run actual predictions on the TPU, but + # model export is run on the CPU. + if is_export_mode: + return True + + return False + + @property + def global_batch_size(self): + mode = self._assert_mode() + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + elif mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + elif mode == model_fn_lib.ModeKeys.PREDICT: + return self._predict_batch_size + else: + return None + + @property + def batch_size_for_input_fn(self): + """Returns the shard batch size for `input_fn`.""" + global_batch_size = self.global_batch_size + + if self.is_running_on_cpu(): + return global_batch_size + + # On TPU + if self.is_input_sharded_per_core(): + # We prohibit per core input sharding for the model parallelism case, + # therefore it is safe to use num_cores here. + return global_batch_size // self.num_cores + else: + return global_batch_size // self.num_hosts + + @property + def batch_size_for_model_fn(self): + """Returns the shard batch size for `model_fn`.""" + global_batch_size = self.global_batch_size + + if self.is_running_on_cpu(): + return global_batch_size + + # On TPU. always sharded per shard. + return global_batch_size // self.num_replicas + + @property + def master_job(self): + """Returns the job name to use to place TPU computations on. + + Returns: + A string containing the job name, or None if no job should be specified. + + Raises: + ValueError: If the user needs to specify a tpu_job_name, because we are + unable to infer the job name automatically, or if the user-specified job + names are inappropriate. + """ + run_config = self._config + # If the user specifies the tpu_job_name, use that. + if run_config.tpu_config.tpu_job_name: + return run_config.tpu_config.tpu_job_name + + # The tpu job is determined by the run_config. Right now, this method is + # required as tpu_config is not part of the RunConfig. + mode = self._assert_mode() + master = ( + run_config.evaluation_master + if mode == model_fn_lib.ModeKeys.EVAL else run_config.master) + if master in _LOCAL_MASTERS: + return None + + if (not run_config.session_config or + not run_config.session_config.cluster_def.job): + return _DEFAULT_JOB_NAME + cluster_def = run_config.session_config.cluster_def + job_names = set([job.name for job in cluster_def.job]) + if _DEFAULT_JOB_NAME in job_names: + # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. + raise ValueError('Currently, tpu_worker is not an allowed job name.') + if len(job_names) == 1: + return cluster_def.job[0].name + if len(job_names) == 2: + if _DEFAULT_COORDINATOR_JOB_NAME in job_names: + job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) + return job_names.pop() + # TODO(b/67716447): Include more sophisticated heuristics. + raise ValueError( + 'Could not infer TPU job name. Please specify a tpu_job_name as part ' + 'of your TPUConfig.') + + @property + def tpu_host_placement_function(self): + """Returns the TPU host place function.""" + master = self.master_job + + def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name + assert _sentinal is None + if core_id is not None and host_id is not None: + raise RuntimeError( + 'core_id and host_id can have only one non-None value.') + + if master is None: + return '/replica:0/task:0/device:CPU:0' + else: + if core_id is not None: + host_id = core_id / self.num_of_cores_per_host + return '/job:%s/task:%d/device:CPU:0' % (master, host_id) + + return _placement_function + + @property + def tpu_device_placement_function(self): + """Returns a TPU device placement Fn.""" + master = self.master_job + job_device = '' if master is None else ('/job:%s' % master) + + def _placement_function(i): + if self.model_parallelism_enabled: + return self.device_assignment.tpu_device(replica=i, job=master) + else: + num_of_cores_per_host = self.num_of_cores_per_host + host_id = i / num_of_cores_per_host + ordinal_id = i % num_of_cores_per_host + return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id) + + return _placement_function + + @property + def tpu_ordinal_function(self): + """Returns the TPU ordinal fn.""" + + def _tpu_ordinal_function(index): + """Return the TPU ordinal associated with a shard. + + Required because the enqueue ops are placed on CPU. + + Args: + index: the shard index + + Returns: + The ordinal of the TPU device the shard's infeed should be placed on. + """ + if self.model_parallelism_enabled: + return self.device_assignment.tpu_ordinal(replica=index) + else: + return index % self.num_of_cores_per_host + + return _tpu_ordinal_function + + def _validate_tpu_configuration(self): + """Validates the configuration based on the TPU system metadata.""" + mode = self._assert_mode() + if self._lazy_validation_dict.get(mode): + return + + # All following information is obtained from TPU system metadata. + num_cores = self.num_cores + num_replicas = self.num_replicas + num_hosts = self.num_hosts + + if not num_cores: + tpu_system_metadata = self._get_tpu_system_metadata() + raise RuntimeError( + 'Cannot find any TPU cores in the system. Please double check ' + 'Tensorflow master address and TPU worker(s). Available devices ' + 'are {}.'.format(tpu_system_metadata.devices)) + + if self._config.tpu_config.num_shards: + user_provided_num_replicas = self._config.tpu_config.num_shards + if user_provided_num_replicas != num_replicas: + message = ( + 'TPUConfig.num_shards is not set correctly. According to TPU ' + 'system metadata for Tensorflow master ({}): num_replicas should ' + 'be ({}), got ({}). For non-model-parallelism, num_replicas should ' + 'be the total num of TPU cores in the system. For ' + 'model-parallelism, the total number of TPU cores should be ' + 'product(computation_shape) * num_replicas. Please set it ' + 'accordingly or leave it as `None`'.format( + self._get_master_address(), num_replicas, + user_provided_num_replicas)) + + raise ValueError(message) + + if mode == model_fn_lib.ModeKeys.TRAIN: + if self._train_batch_size % num_replicas != 0: + raise ValueError( + 'train batch size {} must be divisible by number of replicas {}' + .format(self._train_batch_size, num_replicas)) + + elif mode == model_fn_lib.ModeKeys.EVAL: + if self._eval_batch_size is None: + raise ValueError( + 'eval_batch_size in TPUEstimator constructor cannot be `None`' + 'if .evaluate is running on TPU.') + if self._eval_batch_size % num_replicas != 0: + raise ValueError( + 'eval batch size {} must be divisible by number of replicas {}' + .format(self._eval_batch_size, num_replicas)) + if num_hosts > 1: + raise ValueError( + 'TPUEstimator.evaluate should be running on single TPU worker. ' + 'got {}.'.format(num_hosts)) + else: + assert mode == model_fn_lib.ModeKeys.PREDICT + if self._predict_batch_size is None: + raise ValueError( + 'predict_batch_size in TPUEstimator constructor should not be ' + '`None` if .predict is running on TPU.') + if self._predict_batch_size % num_replicas != 0: + raise ValueError( + 'predict batch size {} must be divisible by number of replicas {}' + .format(self._predict_batch_size, num_replicas)) + if num_hosts > 1: + raise ValueError( + 'TPUEstimator.predict should be running on single TPU worker. ' + 'got {}.'.format(num_hosts)) + + # Record the state "validated" into lazy dictionary. + self._lazy_validation_dict[mode] = True + + +class _OneCoreTPUContext(_TPUContext): + """Special _TPUContext for one core usage.""" + + def __init__(self, config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu): + + super(_OneCoreTPUContext, self).__init__( + config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu) + + def _get_tpu_system_metadata(self): + """Gets the (maybe cached) TPU system metadata.""" + master = self._get_master_address() + tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) + if tpu_system_metadata is not None: + return tpu_system_metadata + + tpu_system_metadata = ( + tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access + num_cores=1, + num_hosts=1, + num_of_cores_per_host=1, + topology=None, + devices=[])) + + self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata + return tpu_system_metadata + + +def _get_tpu_context(config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu): + """Returns an instance of `_TPUContext`.""" + + if (config.tpu_config.num_shards == 1 and + config.tpu_config.computation_shape is None): + logging.warning( + 'Setting TPUConfig.num_shards==1 is an unsupported behavior. ' + 'Please fix as soon as possible (leaving num_shards as None.') + return _OneCoreTPUContext(config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu) + + return _TPUContext(config, train_batch_size, eval_batch_size, + predict_batch_size, use_tpu) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 2ae3a26a853bf4941ac3855ec525293b5a508a2a..ff53fe4f5d0e219f56d77d3476640bb023c7535a 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function import collections -from contextlib import contextmanager import copy +import signal import threading import time import traceback @@ -29,21 +29,23 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.summary import summary_ops as contrib_summary from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config +from tensorflow.contrib.tpu.python.tpu import tpu_context from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.contrib.tpu.python.tpu import util as util_lib - from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 - +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -59,6 +61,7 @@ from tensorflow.python.training import evaluation from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.training import training_util +from tensorflow.python.util import tf_inspect _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. @@ -66,9 +69,15 @@ _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' + _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] -# TODO(b/65703635): Flip the value and remove all dead code. + +# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is +# only used for per-core based deployments. For per-host based pipelines, if a +# user returns a Dataset instance it will be automatically wrapped in a +# tf.while_loop (This can be disabled by returning features and labels +# explicitly). _WRAP_INPUT_FN_INTO_WHILE_LOOP = False @@ -138,234 +147,6 @@ def _increase_eval_step_op(iterations_per_loop): use_locking=True) -_DEFAULT_JOB_NAME = 'tpu_worker' -_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' -_LOCAL_MASTERS = ('', 'local') - - -class _TPUContext(object): - """A context holds immutable states of TPU computation. - - This immutable object holds TPUEstimator config, train/eval batch size, and - `TPUEstimator.use_tpu`, which is expected to be passed around. It also - provides utility functions, basded on the current state, to determine other - information commonly required by TPU computation, such as TPU device names, - TPU hosts, shard batch size, etc. - - N.B. As `mode` is not immutable state in Estimator, but essential to - distinguish between TPU training and evaluation, a common usage for - _TPUContext with `mode` is as follows: - ``` - with _ctx.with_mode(mode) as ctx: - if ctx.is_running_on_cpu(): - ... - ``` - """ - - def __init__(self, config, train_batch_size, eval_batch_size, use_tpu): - self._config = config - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size - self._use_tpu = use_tpu - self._num_shards_or_none = self._config.tpu_config.num_shards - self._mode = None - - def _assert_mode(self): - if self._mode is None: - raise RuntimeError( - '`mode` needs to be set via contextmanager `with_mode`.') - return self._mode - - @property - def num_of_cores_per_host(self): - num_cores = self.num_cores - return min(num_cores, 8) - - @contextmanager - def with_mode(self, mode): - new_ctx = copy.copy(self) # Shallow copy is enough. - new_ctx._mode = mode # pylint: disable=protected-access - yield new_ctx - - @property - def mode(self): - return self._assert_mode() - - @property - def num_cores(self): - # TODO(xiejw): Adds lazy num_shards initialization. - return self._num_shards_or_none - - @property - def num_hosts(self): - return self.num_cores // self.num_of_cores_per_host - - @property - def config(self): - return self._config - - def is_input_sharded_per_core(self): - """Return true if input_fn is invoked per-core (other than per-host).""" - self._assert_mode() - return (self._mode == model_fn_lib.ModeKeys.TRAIN and - not self._config.tpu_config.per_host_input_for_training) - - def is_running_on_cpu(self): - """Determines whether the input_fn and model_fn should be invoked on CPU.""" - mode = self._assert_mode() - return ((not self._use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or - (mode == model_fn_lib.ModeKeys.EVAL and - self._eval_batch_size is None)) - - @property - def global_batch_size(self): - mode = self._assert_mode() - if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: - raise RuntimeError('Internal error, EVAL on TPU is not enabled, but ' - '`global_batch_size` is called.') - return (self._train_batch_size - if mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size) - - @property - def batch_size_for_input_fn(self): - """Returns the shard batch size for `input_fn`.""" - mode = self._assert_mode() - # Special case for eval. - if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: - return None - if self.is_running_on_cpu(): - if mode == model_fn_lib.ModeKeys.TRAIN: - return self._train_batch_size - if mode == model_fn_lib.ModeKeys.EVAL: - return self._eval_batch_size - return None - - global_batch_size = ( - self._train_batch_size - if mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size) - # On TPU - if self.is_input_sharded_per_core(): - return global_batch_size // self.num_cores - else: - return global_batch_size // self.num_hosts - - @property - def batch_size_for_model_fn(self): - """Returns the shard batch size for `model_fn`.""" - mode = self._assert_mode() - # Special case for eval. - if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: - return None - if self.is_running_on_cpu(): - if mode == model_fn_lib.ModeKeys.TRAIN: - return self._train_batch_size - if mode == model_fn_lib.ModeKeys.EVAL: - return self._eval_batch_size - return None - - # On TPU. always sharded per core. - if mode == model_fn_lib.ModeKeys.TRAIN: - return self._train_batch_size // self.num_cores - else: - return self._eval_batch_size // self.num_cores - - @property - def master_job(self): - """Returns the job name to use to place TPU computations on. - - Returns: - A string containing the job name, or None if no job should be specified. - - Raises: - ValueError: If the user needs to specify a tpu_job_name, because we are - unable to infer the job name automatically, or if the user-specified job - names are inappropriate. - """ - run_config = self._config - # If the user specifies the tpu_job_name, use that. - if run_config.tpu_config.tpu_job_name: - return run_config.tpu_config.tpu_job_name - - # The tpu job is determined by the run_config. Right now, this method is - # required as tpu_config is not part of the RunConfig. - mode = self._assert_mode() - master = ( - run_config.evaluation_master - if mode == model_fn_lib.ModeKeys.EVAL else run_config.master) - if master in _LOCAL_MASTERS: - return None - - if (not run_config.session_config or - not run_config.session_config.cluster_def.job): - return _DEFAULT_JOB_NAME - cluster_def = run_config.session_config.cluster_def - job_names = set([job.name for job in cluster_def.job]) - if _DEFAULT_JOB_NAME in job_names: - # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. - raise ValueError('Currently, tpu_worker is not an allowed job name.') - if len(job_names) == 1: - return cluster_def.job[0].name - if len(job_names) == 2: - if _DEFAULT_COORDINATOR_JOB_NAME in job_names: - job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) - return job_names.pop() - # TODO(b/67716447): Include more sophisticated heuristics. - raise ValueError( - 'Could not infer TPU job name. Please specify a tpu_job_name as part ' - 'of your TPUConfig.') - - @property - def tpu_host_placement_function(self): - """Returns the TPU host place function.""" - master = self.master_job - - def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name - assert _sentinal is None - if core_id is not None and host_id is not None: - raise RuntimeError( - 'core_id and host_id can have only one non-None value.') - - if master is None: - return '/replica:0/task:0/device:CPU:0' - else: - # This assumes that if using more than 8 shards, - # the job configuration varies 'task'. - if core_id is not None: - host_id = core_id / 8 - return '/job:%s/task:%d/device:CPU:0' % (master, host_id) - - return _placement_function - - @property - def tpu_device_placement_function(self): - master = self.master_job - job_device = '' if master is None else ('/job:%s' % master) - - def _placement_function(i): - return '%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8) - - return _placement_function - - @property - def tpu_ordinal_function(self): - """Returns the TPU ordinal fn.""" - - def _tpu_ordinal_function(index): - """Return the TPU ordinal associated with a shard. - - Required because the enqueue ops are placed on CPU. - - Args: - index: the shard index - - Returns: - The ordinal of the TPU device the shard's infeed should be placed on. - """ - return index % 8 - - return _tpu_ordinal_function - - class _SIGNAL(object): """Signal used to control the thread of infeed/outfeed. @@ -384,7 +165,8 @@ class TPUEstimatorSpec( 'train_op', 'eval_metrics', 'export_outputs', - 'scaffold_fn' + 'scaffold_fn', + 'host_call' ])): """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. @@ -410,6 +192,15 @@ class TPUEstimatorSpec( `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This function should not capture any Tensors in `model_fn`. + + `host_call` is a tuple of a `function` and a list or dictionary of `tensors` + to pass to that function and returns a list of Tensors. `host_call` currently + works for train() and evaluate(). The Tensors returned by the function is + executed on the CPU on every step, so there is communication overhead when + sending tensors from TPU to CPU. To reduce the overhead, try reducing the + size of the tensors. The `tensors` are concatenated along their major (batch) + dimension, and so must be >= rank 1. The `host_call` is useful for writing + summaries with @{tf.contrib.summary.create_file_writer}. """ def __new__(cls, @@ -419,10 +210,15 @@ class TPUEstimatorSpec( train_op=None, eval_metrics=None, export_outputs=None, - scaffold_fn=None): + scaffold_fn=None, + host_call=None): """Creates a validated `TPUEstimatorSpec` instance.""" + host_calls = {} if eval_metrics is not None: - _EvalMetrics.validate(eval_metrics) + host_calls['eval_metrics'] = eval_metrics + if host_call is not None: + host_calls['host_call'] = host_call + _OutfeedHostCall.validate(host_calls) return super(TPUEstimatorSpec, cls).__new__( cls, mode=mode, @@ -431,12 +227,23 @@ class TPUEstimatorSpec( train_op=train_op, eval_metrics=eval_metrics, export_outputs=export_outputs, - scaffold_fn=scaffold_fn) + scaffold_fn=scaffold_fn, + host_call=host_call) def as_estimator_spec(self): """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" - eval_metric_ops = _EvalMetrics.to_metric_metric_ops_for_cpu( - self.eval_metrics) + host_calls = {} + if self.eval_metrics is not None: + host_calls['eval_metrics'] = self.eval_metrics + if self.host_call is not None: + host_calls['host_call'] = self.host_call + host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls) + eval_metric_ops = None + if self.eval_metrics is not None: + eval_metric_ops = host_call_ret['eval_metrics'] + hooks = None + if self.host_call is not None: + hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( mode=self.mode, @@ -445,7 +252,10 @@ class TPUEstimatorSpec( train_op=self.train_op, eval_metric_ops=eval_metric_ops, export_outputs=self.export_outputs, - scaffold=scaffold) + scaffold=scaffold, + training_hooks=hooks, + evaluation_hooks=hooks, + prediction_hooks=hooks) class _OpQueueContext(object): @@ -467,12 +277,12 @@ class _OpQueueContext(object): def read_iteration_counts(self): while True: - signal = self._queue.get(block=True) - logging.debug('%s read signal %s', self._name, signal) - if signal == _SIGNAL.STOP: - logging.info('%s received signal, stopping.', self._name) + iterations = self._queue.get(block=True) + logging.debug('%s read iterations %s', self._name, iterations) + if iterations == _SIGNAL.STOP: + logging.info('%s received shutdown signal, stopping.', self._name) return - yield signal + yield iterations def join(self): logging.info('Shutting down %s thread.' % self._name) @@ -480,6 +290,22 @@ class _OpQueueContext(object): self._thread.join() +class _OpSignalOnceQueueContext(_OpQueueContext): + """Manages work queue and thread for a infeed/outfeed thread. + + This subclass only signals once. + """ + + def __init__(self, name, target, args): + super(_OpSignalOnceQueueContext, self).__init__(name, target, args) + self._has_signaled = False + + def send_next_batch_signal(self, iterations): + if not self._has_signaled: + self._queue.put(iterations) + self._has_signaled = True + + class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): """A Session hook setting up the TPU initialization, infeed, and outfeed. @@ -489,12 +315,19 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): dequeue. """ - def __init__(self, ctx, enqueue_ops, dequeue_ops=None): + def __init__(self, + ctx, + enqueue_ops, + dequeue_ops, + run_infeed_loop_on_coordinator=True): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops + + self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) + self._session_cancel_timer = None self._feed_error = None @@ -503,8 +336,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - self._init_op = [tpu.initialize_system(job=self._master_job)] - self._finalize_op = [tpu.shutdown_system(job=self._master_job)] + self._init_ops = [tpu.initialize_system(job=self._master_job)] + self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] + + summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() + self._init_ops.extend(summary_writer_init_ops) + # Get all the writer resources from the initializer, so we know what to + # flush. + for op in summary_writer_init_ops: + self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) def _log_error(self, session, error): """Log an infeed or outfeed error. @@ -516,8 +356,9 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): emitting a stack trace for the infeed. Args: - session: `tf.Session`, session to be terminated - error: exception that triggered logging. + session: `tf.Session`, session to be terminated error: exception that + triggered logging. + error: the Exception to log. """ logging.warning( '\n\n' @@ -538,7 +379,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): # for TPU computation waits for the infeed enqueue forever. Close the # Session to cancel the main thread Session.run execution. # - # However, sleep for 2 minutes before explicit closing to give some time + # We sleep for a few seconds before closing to give some time # for the TPU compilation error, if any, propagating, from TPU to CPU # host. Compilation errors should be reported by the main thread so that # the program can be interrupted and users can take action. Due to a race @@ -551,7 +392,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): # If the main session is still running, the infeed/outfeed errors are # legitimate, and should be logged. - if not self._finished: + if not self._finished and self._feed_error: logging.error('Feed error: %s', self._feed_error) logging.error('Closing session. A RuntimeError should follow.') session.close() @@ -569,15 +410,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): logging.info('%s thread starting after sleep', self._name) try: - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - for _ in queue_ctx.read_iteration_counts(): - session.run(self._enqueue_ops) - else: + if self._run_infeed_loop_on_coordinator: for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) session.run(self._enqueue_ops) - logging.debug('Infeed thread finished, shutting down.') + else: + for _ in queue_ctx.read_iteration_counts(): + session.run(self._enqueue_ops) + logging.info('Infeed thread finished, shutting down.') except Exception as e: # pylint: disable=broad-except self._log_error(session, e) @@ -588,40 +429,42 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): for i in xrange(steps): logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) session.run(self._dequeue_ops) + logging.info('Outfeed thread finished, shutting down.') except Exception as e: # pylint: disable=broad-except self._log_error(session, e) + def _create_infeed_controller(self, name, target, args): + return _OpQueueContext(name=name, target=target, args=args) + def after_create_session(self, session, coord): logging.info('Init TPU system') - session.run( - self._init_op, - options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) + session.run(self._init_ops, + options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) logging.info('Start infeed thread controller') - self._infeed_controller = _OpQueueContext( + self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) - if self._dequeue_ops is not None: - logging.info('Start outfeed thread controller') - self._outfeed_controller = _OpQueueContext( - name='OutfeedController', target=self._run_outfeed, args=(session,)) + logging.info('Start outfeed thread controller') + self._outfeed_controller = _OpQueueContext( + name='OutfeedController', target=self._run_outfeed, args=(session,)) def before_run(self, run_context): - if self._feed_error: - logging.warning('Feed error occurred, terminating session.') - run_context.request_stop() - return + self._feed_error = None + + # Wait for the cancellation timer to complete before continuing. + if self._session_cancel_timer: + self._session_cancel_timer.join() + self._session_cancel_timer = None iterations = run_context.session.run(self._iterations_per_loop_var) logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) self._infeed_controller.send_next_batch_signal(iterations) - if self._dequeue_ops is not None: - # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop. - logging.info('Dequeue next (%d) batch(es) of data from outfeed.', - iterations) - self._outfeed_controller.send_next_batch_signal(iterations) + logging.info('Dequeue next (%d) batch(es) of data from outfeed.', + iterations) + self._outfeed_controller.send_next_batch_signal(iterations) def end(self, session): if self._session_cancel_timer: @@ -632,12 +475,21 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): logging.info('Stop infeed thread controller') self._infeed_controller.join() - if self._dequeue_ops is not None: - logging.info('Stop output thread controller') - self._outfeed_controller.join() + logging.info('Stop output thread controller') + self._outfeed_controller.join() logging.info('Shutdown TPU system.') - session.run(self._finalize_op) + session.run(self._finalize_ops) + + +class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): + + def __init__(self, ctx, enqueue_ops, dequeue_ops): + super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( + ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False) + + def _create_infeed_controller(self, name, target, args): + return _OpSignalOnceQueueContext(name=name, target=target, args=args) class _TPUStopAtStepHook(session_run_hook.SessionRunHook): @@ -727,6 +579,47 @@ class _SetEvalIterationsHook(session_run_hook.SessionRunHook): self._iterations_per_loop_var.load(self._num_steps, session=session) +class _StoppingPredictHook(session_run_hook.SessionRunHook): + """Hook that requests stop according to the stopping signal in prediction.""" + + def __init__(self, scalar_stopping_signal): + self._scalar_stopping_signal = scalar_stopping_signal + + def begin(self): + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() + + def after_create_session(self, session, coord): + # This is not necessary as we do not run infeed enqueue and outfeed dequeue + # in side threads for prediction model. But it makes the + # TPUInfeedOutfeedSessionHook prints nice message. + self._iterations_per_loop_var.load(1, session=session) + + def before_run(self, run_context): + return session_run_hook.SessionRunArgs(self._scalar_stopping_signal) + + def after_run(self, run_context, run_values): + _ = run_context + scalar_stopping_signal = run_values.results + if _StopSignals.should_stop(scalar_stopping_signal): + # NOTE(xiejw): In prediction, stopping signals are inserted for each + # batch. And we append one more batch to signal the system it should stop. + # The data flow might look like + # + # batch 0: images, labels, stop = 0 (user provideded) + # batch 1: images, labels, stop = 0 (user provideded) + # ... + # batch 99: images, labels, stop = 0 (user provideded) + # batch 100: images, labels, stop = 1 (TPUEstimator appended) + # + # where the final batch (id = 100) is appended by TPUEstimator, so we + # should drop it before returning the predictions to user. + # To achieve that, we throw the OutOfRangeError in after_run. Once + # Monitored Session sees this error in SessionRunHook.after_run, the + # "current" prediciton, i.e., batch with id=100, will be discarded + # immediately + raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') + + def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn, inputs_structure_recorder): """Generates infeed enqueue ops for per-core input_fn on a single host.""" @@ -738,11 +631,14 @@ def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn, per_host_sharded_inputs = [] for core_ordinal in range(num_cores_per_host): with ops.name_scope('ordinal_%d' % (core_ordinal)): - inputs = input_fn() - if isinstance(inputs, tuple): - features, labels = inputs - else: - features, labels = inputs, None + inputs = _Inputs.from_input_fn(input_fn()) + if inputs.is_dataset: + raise TypeError( + '`input_fn` returning `Dataset` is not yet supported in ' + 'per-Core input pipeline deployment yet. Please set ' + 'TPUConfig.per_host_input_for_training to True or return ' + '`features` and `labels` from `input_fn`') + features, labels = inputs.features_and_labels() inputs_structure_recorder.validate_and_record_structure( features, labels) @@ -765,36 +661,76 @@ def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn, def generate_per_host_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, batch_axis, device): + ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id): """Generates infeed enqueue ops for per-host input_fn on a single host.""" captured_infeed_queue = _CapturedObject() + hooks = [] + + with ops.device(device): + inputs = _Inputs.from_input_fn(input_fn()) + + is_dataset = inputs.is_dataset + if ctx.mode == model_fn_lib.ModeKeys.PREDICT: + if not is_dataset: + raise TypeError( + 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' + '`features` and `labels`.') + inputs = _InputsWithStoppingSignals( + dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn) + + if is_dataset: + hooks.append(inputs.dataset_initializer_hook()) + + # TODO(ylc): Refactoring the code to merge the tpu ordinal logic here and the + # _TPUContext.tpu_ordinal_function. We should either introduce another + # abstraction or a different helper method. + def _tpu_ordinal_function_impl(shard_index_in_host): + # We put both enqueue/dequeue op at tpu.core(0) in each replica. + replica = ctx.device_assignment.lookup_replicas( + host_id, (0, 0, 0))[shard_index_in_host] + return ctx.device_assignment.tpu_ordinal(replica=replica) + + if ctx.model_parallelism_enabled: + tpu_ordinal_function = _tpu_ordinal_function_impl + else: + tpu_ordinal_function = None + def enqueue_ops_fn(): with ops.device(device): - num_cores_per_host = ctx.num_of_cores_per_host - inputs = input_fn() - if isinstance(inputs, tuple): - features, labels = inputs - else: - features, labels = inputs, None - inputs_structure_recorder.validate_and_record_structure(features, labels) + num_of_replicas_per_host = ctx.num_of_replicas_per_host + # Convert user input to features and labels. If the user returns a + # dataset, it is initialized and the features and labels extracted via + # `dataset.iterator.get_next()` + features, labels = inputs.features_and_labels() + signals = inputs.signals() + + inputs_structure_recorder.validate_and_record_structure( + features, labels, signals) unsharded_tensor_list = ( inputs_structure_recorder.flatten_features_and_labels( - features, labels)) + features, labels, signals)) infeed_queue = tpu_feed.InfeedQueue( tuple_types=[t.dtype for t in unsharded_tensor_list], tuple_shapes=[t.shape for t in unsharded_tensor_list], shard_dimensions=batch_axis) captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_number_of_shards(num_cores_per_host) - + infeed_queue.set_number_of_shards(num_of_replicas_per_host) per_host_enqueue_ops = ( infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_tensor_list, placement_function=lambda x: device)) - return per_host_enqueue_ops + unsharded_tensor_list, + placement_function=lambda x: device, + tpu_ordinal_function=tpu_ordinal_function)) + if signals is None: + return per_host_enqueue_ops + else: + return { + 'ops': per_host_enqueue_ops, + 'signals': signals, + } - return enqueue_ops_fn, captured_infeed_queue + return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset class _InputPipeline(object): @@ -834,6 +770,7 @@ class _InputPipeline(object): self._feature_names = [] self._label_names = [] self._has_labels = False + self._signals_helper = None # Internal state. self._initialized = False @@ -841,7 +778,7 @@ class _InputPipeline(object): def has_labels(self): return self._has_labels - def validate_and_record_structure(self, features, labels): + def validate_and_record_structure(self, features, labels, signals=None): """Validates and records the structure of features` and `labels`.""" def _extract_key_names(tensor_or_dict): @@ -854,6 +791,10 @@ class _InputPipeline(object): feature_names = _extract_key_names(features) label_names = _extract_key_names(labels) + if signals is not None and self._signals_helper is None: + # Record signals helper. + self._signals_helper = _SignalsHelper(signals) + if self._initialized: # Verify the structure is same. The following should never happen. assert feature_names == self._feature_names, 'feature keys mismatched' @@ -866,7 +807,7 @@ class _InputPipeline(object): self._label_names = label_names self._has_labels = has_labels - def flatten_features_and_labels(self, features, labels): + def flatten_features_and_labels(self, features, labels, signals=None): """Flattens the `features` and `labels` to a single tensor list.""" flattened_inputs = [] if self._feature_names: @@ -882,6 +823,9 @@ class _InputPipeline(object): flattened_inputs.extend([labels[name] for name in self._label_names]) else: flattened_inputs.append(labels) + + if signals is not None: + flattened_inputs.extend(_SignalsHelper.as_tensor_list(signals)) return flattened_inputs def unflatten_features_and_labels(self, flattened_inputs): @@ -907,7 +851,11 @@ class _InputPipeline(object): else: expected_num_labels = 0 - expected_num_tensors = expected_num_features + expected_num_labels + expected_num_signals = ( + self._signals_helper.num_signals if self._signals_helper else 0) + + expected_num_tensors = ( + expected_num_features + expected_num_labels + expected_num_signals) if expected_num_tensors != len(flattened_inputs): raise ValueError( @@ -924,13 +872,20 @@ class _InputPipeline(object): if expected_num_labels == 0: unflattened_label = None elif self._label_names: - unflattened_label = dict( - zip(self._label_names, flattened_inputs[expected_num_features:])) + label_list = flattened_inputs[ + expected_num_features:expected_num_features + expected_num_labels] + unflattened_label = dict(zip(self._label_names, label_list)) else: # Single tensor case. unflattened_label = flattened_inputs[expected_num_features] - return unflattened_features, unflattened_label + signals = None + if expected_num_signals != 0: + tensor_list_for_signals = flattened_inputs[ + expected_num_features + expected_num_labels:] + signals = self._signals_helper.unflatten(tensor_list_for_signals) + + return _Inputs(unflattened_features, unflattened_label, signals=signals) def __init__(self, input_fn, batch_axis, ctx): """Constructor. @@ -958,25 +913,34 @@ class _InputPipeline(object): # While tf.while_loop is called, the body function, which invokes # `enqueue_fn` passed in, is called to construct the graph. So, input_fn # structure is recorded. - enqueue_ops = self._invoke_input_fn_and_record_structure() + enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = ( + self._invoke_input_fn_and_record_structure()) self._validate_input_pipeline() def dequeue_fn(): """dequeue_fn is used by TPU to retrieve the tensors.""" - values = self._infeed_queue.generate_dequeue_op() + # In the model-parallel case, both the host-side and device-side + # computations must agree on the core on which infeed takes place. We + # choose to perform infeed on logical core 0 of each replica. + with ops.device(tpu.core(0)): + values = self._infeed_queue.generate_dequeue_op() # The unflatten process uses the structure information recorded above. return self._inputs_structure_recorder.unflatten_features_and_labels( values) - return (enqueue_ops, dequeue_fn) + return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator) def _invoke_input_fn_and_record_structure(self): """Deploys the input pipeline and record input structure.""" enqueue_ops = [] infeed_queues = [] + all_hooks = [] num_hosts = self._ctx.num_hosts tpu_host_placement_fn = self._ctx.tpu_host_placement_function + + run_infeed_loop_on_coordinator = True + if self._sharded_per_core: # Per-Core input pipeline deployment. # Invoke input pipeline for each core and placed on the corresponding @@ -990,6 +954,7 @@ class _InputPipeline(object): self._ctx, self._input_fn, self._inputs_structure_recorder)) if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + run_infeed_loop_on_coordinator = False enqueue_ops.append( _wrap_computation_in_while_loop( device=host_device, op_fn=enqueue_ops_fn)) @@ -1003,15 +968,32 @@ class _InputPipeline(object): host_device = tpu_host_placement_fn(host_id=host_id) with ops.device(host_device): with ops.name_scope('input_pipeline_task%d' % (host_id)): - enqueue_ops_fn, captured_infeed_queue = ( + enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( generate_per_host_enqueue_ops_fn_for_host( self._ctx, self._input_fn, self._inputs_structure_recorder, - self._batch_axis, host_device)) - - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + self._batch_axis, host_device, host_id)) + all_hooks.extend(hooks) + + # NOTE(xiejw): We dispatch here based on the return type of the + # users `input_fn`. + # + # 1. If input_fn returns a Dataset instance, we initialize the + # iterator outside of tf.while_loop, and call the iterator.get_next + # inside tf.while_loop. This should be always safe. + # + # 2. If input_fn returns (features, labels), it is too late to wrap + # them inside tf.while_loop, as resource initialization cannot be + # handled in TF control flow properly. In this case, we will use + # python loop to enqueue the data into TPU system. This may be + # slow compared to the previous case. + if is_dataset: + run_infeed_loop_on_coordinator = False + wrap_fn = ( + _wrap_computation_in_while_loop + if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else + _wrap_computation_in_while_loop_with_stopping_signals) enqueue_ops.append( - _wrap_computation_in_while_loop( - device=host_device, op_fn=enqueue_ops_fn)) + wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) else: enqueue_ops.append(enqueue_ops_fn()) infeed_queues.append(captured_infeed_queue.get()) @@ -1019,7 +1001,7 @@ class _InputPipeline(object): # dequeue is dtypes and types. So, any one can be used. Here, grab the # first one. self._infeed_queue = infeed_queues[0] - return enqueue_ops + return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator def _validate_input_pipeline(self): # Perform some sanity checks to log user friendly information. We should @@ -1076,15 +1058,18 @@ class _ModelFnWrapper(object): infeed dequeue channel. Returns: - A Fn representing the train step for TPU. + A tuple of train_fn, host_calls, and captured scaffold_fn. The train_fn + representing the train step for TPU. """ + host_call = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() def train_step(loss): """Training step function for use inside a while loop.""" del loss # unused; required in function signature. - features, labels = dequeue_fn() + inputs = dequeue_fn() + features, labels = inputs.features_and_labels() estimator_spec = self._verify_estimator_spec( self._call_model_fn(features, labels)) @@ -1095,10 +1080,18 @@ class _ModelFnWrapper(object): else: captured_scaffold_fn.capture(None) + # We must run train_op to update the variables prior to running the + # outfeed. with ops.control_dependencies([train_op]): - return array_ops.identity(loss) + host_call_outfeed_ops = [] + if (isinstance(estimator_spec, TPUEstimatorSpec) and + estimator_spec.host_call is not None): + host_call.record({'host_call': estimator_spec.host_call}) + host_call_outfeed_ops = host_call.create_enqueue_op() + with ops.control_dependencies(host_call_outfeed_ops): + return array_ops.identity(loss) - return train_step, captured_scaffold_fn + return train_step, host_call, captured_scaffold_fn def convert_to_single_tpu_eval_step(self, dequeue_fn): """Converts user provided model_fn` as a single eval step on TPU. @@ -1123,15 +1116,16 @@ class _ModelFnWrapper(object): infeed dequeue channel. Returns: - A tuple of eval_fn and eval_metrics. The eval_fn representing the eval - step for TPU. and eval_metrics is an `_EvalMetrics` instance. + A tuple of eval_fn, host_calls, and captured scaffold_fn. The eval_fn + representing the eval step for TPU. """ - eval_metrics = _EvalMetrics(self._ctx) + host_calls = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() def eval_step(total_loss): """Evaluation step function for use inside a while loop.""" - features, labels = dequeue_fn() + inputs = dequeue_fn() + features, labels = inputs.features_and_labels() tpu_estimator_spec = self._call_model_fn(features, labels) if not isinstance(tpu_estimator_spec, TPUEstimatorSpec): @@ -1141,15 +1135,68 @@ class _ModelFnWrapper(object): loss = tpu_estimator_spec.loss captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) - eval_metrics.record(tpu_estimator_spec) - outfeed_ops = tpu_ops.outfeed_enqueue_tuple(eval_metrics.outfeed_tensors) - - with ops.control_dependencies([outfeed_ops]): + to_record = {} + to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics + if tpu_estimator_spec.host_call is not None: + # We assume that evaluate won't update global step, so we don't wrap + # this host_call. + to_record['host_call'] = tpu_estimator_spec.host_call + host_calls.record(to_record) + + with ops.control_dependencies(host_calls.create_enqueue_op()): return math_ops.add(total_loss, loss) - return eval_step, eval_metrics, captured_scaffold_fn + return eval_step, host_calls, captured_scaffold_fn + + def convert_to_single_tpu_predict_step(self, dequeue_fn): + """Converts user provided model_fn` as a single predict step on TPU. + + Args: + dequeue_fn: The function to retrieve inputs, features and labels, from TPU + infeed dequeue channel. + + Returns: + A tuple of predict_fn, host_calls, and captured scaffold_fn. The + predict_fn representing the predict step for TPU. + """ + host_calls = _OutfeedHostCall(self._ctx) + captured_scaffold_fn = _CapturedObject() + + def predict_step(unused_scalar_stopping_signal): + """Evaluation step function for use inside a while loop.""" + inputs = dequeue_fn() + features, labels = inputs.features_and_labels() + stopping_signals = inputs.signals() + + assert stopping_signals is not None, ( + 'Internal Error: `signals` is missing.') - def _call_model_fn(self, features, labels): + tpu_estimator_spec = self._call_model_fn( + features, labels, is_export_mode=False) + if not isinstance(tpu_estimator_spec, TPUEstimatorSpec): + raise RuntimeError( + 'estimator_spec used by TPU prediction must have type' + '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) + + captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) + to_record = {} + identity_fn = lambda **kwargs: kwargs + # TODO(xiejw): Adds validation for prediction dictionrary. + # TODO(xiejw): Adds support for single tensor as predictions. + if not isinstance(tpu_estimator_spec.predictions, dict): + raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') + to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] + to_record['signals'] = [identity_fn, stopping_signals] + if tpu_estimator_spec.host_call is not None: + to_record['host_call'] = tpu_estimator_spec.host_call + host_calls.record(to_record) + + with ops.control_dependencies(host_calls.create_enqueue_op()): + return _StopSignals.as_scalar_stopping_signal(stopping_signals) + + return predict_step, host_calls, captured_scaffold_fn + + def _call_model_fn(self, features, labels, is_export_mode=True): """Calls the model_fn with required parameters.""" model_fn_args = util.fn_args(self._model_fn) kwargs = {} @@ -1180,7 +1227,7 @@ class _ModelFnWrapper(object): params[_BATCH_SIZE_KEY] = batch_size_for_model_fn estimator_spec = self._model_fn(features=features, **kwargs) - if (self._ctx.is_running_on_cpu() and + if (self._ctx.is_running_on_cpu(is_export_mode) and isinstance(estimator_spec, TPUEstimatorSpec)): # The estimator_spec will be passed to `Estimator` directly, which expects # type `EstimatorSpec`. @@ -1207,158 +1254,212 @@ class _ModelFnWrapper(object): return estimator_spec -class _EvalMetrics(object): - """Class wraps TPUEstimator.eval_metrics.""" +class _OutfeedHostCall(object): + """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec.""" def __init__(self, ctx): self._ctx = ctx - self._metric_fn = None - self._is_dict = False - self._tensor_keys = [] - self._tensors = [] - self._tensor_dtypes = [] - self._tensor_shapes = [] - self._recorded = False + self._names = [] + # All of these are dictionaries of lists keyed on the name. + self._host_fns = {} + self._tensor_keys = collections.defaultdict(list) + self._tensors = collections.defaultdict(list) + self._tensor_dtypes = collections.defaultdict(list) + self._tensor_shapes = collections.defaultdict(list) @staticmethod - def validate(eval_metrics): - """Validates the `eval_metrics` in `TPUEstimatorSpec`.""" - - if not isinstance(eval_metrics, (tuple, list)): - raise ValueError('eval_metrics should be tuple or list') - if len(eval_metrics) != 2: - raise ValueError('eval_metrics should have two elements.') - if not callable(eval_metrics[0]): - raise TypeError('eval_metrics[0] should be callable.') - if not isinstance(eval_metrics[1], (tuple, list, dict)): - raise ValueError('eval_metrics[1] should be tuple or list, or dict.') - - if isinstance(eval_metrics[1], (tuple, list)): - fn_args = util.fn_args(eval_metrics[0]) - if len(eval_metrics[1]) != len(fn_args): - raise RuntimeError( - 'In TPUEstimatorSpec.eval_metrics, length of tensors does not ' - 'match method args of metric_fn.') + def validate(host_calls): + """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`.""" + + for name, host_call in host_calls.items(): + if not isinstance(host_call, (tuple, list)): + raise ValueError('{} should be tuple or list'.format(name)) + if len(host_call) != 2: + raise ValueError('{} should have two elements.'.format(name)) + if not callable(host_call[0]): + raise TypeError('{}[0] should be callable.'.format(name)) + if not isinstance(host_call[1], (tuple, list, dict)): + raise ValueError('{}[1] should be tuple or list, or dict.'.format(name)) + + if isinstance(host_call[1], (tuple, list)): + fullargspec = tf_inspect.getfullargspec(host_call[0]) + fn_args = util.fn_args(host_call[0]) + # wrapped_hostcall_with_global_step uses varargs, so we allow that. + if fullargspec.varargs is None and len(host_call[1]) != len(fn_args): + raise RuntimeError( + 'In TPUEstimatorSpec.{}, length of tensors {} does not match ' + 'method args of the function, which takes {}.'.format( + name, len(host_call[1]), len(fn_args))) @staticmethod - def to_metric_metric_ops_for_cpu(eval_metrics): - """Converts `TPUEstimatorSpec.eval_metrics` to `eval_metric_ops` for CPU.""" - if not eval_metrics: - return None - - _EvalMetrics.validate(eval_metrics) + def create_cpu_hostcall(host_calls): + """Runs on the host_call on CPU instead of TPU when use_tpu=False.""" + + _OutfeedHostCall.validate(host_calls) + ret = {} + for name, host_call in host_calls.items(): + host_fn, tensors = host_call + if isinstance(tensors, (tuple, list)): + ret[name] = host_fn(*tensors) + else: + # Must be dict. + try: + ret[name] = host_fn(**tensors) + except TypeError as e: + logging.warning( + 'Exception while calling %s: %s. It is likely the tensors ' + '(%s[1]) do not match the ' + 'function\'s arguments', name, e, name) + raise e + return ret + + def record(self, host_calls): + """Records the host_call structure.""" + + for name, host_call in host_calls.items(): + host_fn, tensor_list_or_dict = host_call + self._names.append(name) + self._host_fns[name] = host_fn + + if isinstance(tensor_list_or_dict, dict): + for (key, tensor) in six.iteritems(tensor_list_or_dict): + self._tensor_keys[name].append(key) + self._tensors[name].append(tensor) + self._tensor_dtypes[name].append(tensor.dtype) + self._tensor_shapes[name].append(tensor.shape) + else: + # List or tuple. + self._tensor_keys[name] = None + for tensor in tensor_list_or_dict: + self._tensors[name].append(tensor) + self._tensor_dtypes[name].append(tensor.dtype) + self._tensor_shapes[name].append(tensor.shape) - metric_fn, tensors = eval_metrics + def create_enqueue_op(self): + """Create the op to enqueue the recorded host_calls. - if isinstance(tensors, (tuple, list)): - return metric_fn(*tensors) - else: - # Must be dict. - try: - return metric_fn(**tensors) - except TypeError as e: - logging.warning( - 'Exception while calling metric_fn for evalution: %s. ' - 'It is likely the tensors (eval_metrics[1]) do not match the ' - 'metric_fn arguments', e) - raise e - - def record(self, spec): - """Records the eval_metrics structure in `spec`.""" - if self._recorded: - raise RuntimeError('Eval metrics have been recorded already.') - - self._metric_fn, tensor_list_or_dict = spec.eval_metrics - - if isinstance(tensor_list_or_dict, dict): - self._is_dict = True - for (key, tensor) in six.iteritems(tensor_list_or_dict): - self._tensor_keys.append(key) - self._tensors.append(tensor) - self._tensor_dtypes.append(tensor.dtype) - self._tensor_shapes.append(tensor.shape) - else: - # List or tuple. - self._is_dict = False - self._tensors = tensor_list_or_dict - for tensor in tensor_list_or_dict: - self._tensor_dtypes.append(tensor.dtype) - self._tensor_shapes.append(tensor.shape) - self._recorded = True + Returns: + A list of enqueue ops, which is empty if there are no host calls. + """ + if not self._names: + return [] - @property - def outfeed_tensors(self): - if not self._recorded: - raise RuntimeError('Eval metrics have not been recorded yet') - return self._tensors + tensors = [] + # TODO(jhseu): Consider deduping tensors. + for name in self._names: + tensors.extend(self._tensors[name]) - def to_metric_metric_ops_for_tpu(self, dummy_update_op): - """Creates the eval_metric_ops now based on the TPU outfeed. + with ops.device(tpu.core(0)): + return [tpu_ops.outfeed_enqueue_tuple(tensors)] - `eval_metric_ops` is defined in `EstimatorSpec`. From all shards, tensors - are dequeued from outfeed and then concatenated (along batch size dimension) - to form global-like tensors. All global-like tensors are passed to the - metric fn. + def create_tpu_hostcall(self): + """Sends the tensors through outfeed and runs the host_fn on CPU. - Args: - dummy_update_op: A dummy update op. + The tensors are concatenated along dimension 0 to form a global tensor + across all shards. The concatenated function is passed to the host_fn and + executed on the first host. Returns: - A tuple of (`eval_metric_ops` and `update_ops`), where `update_ops` should - be invoked in Outfeed dequeue thread, which drive the outfeed dequeue and - update the state of metrics. + A dictionary mapping name to the return type of the host_call by that + name. Raises: RuntimeError: If outfeed tensor is scalar. """ + if not self._names: + return [] - num_cores = self._ctx.num_cores - + ret = {} # For each i, dequeue_ops[i] is a list containing the tensors from all # shards. This list is concatenated later. dequeue_ops = [] - for i in xrange(len(self._tensors)): - dequeue_ops.append([]) - - # Outfeed ops execute on each JF node. + tensor_dtypes = [] + tensor_shapes = [] + for name in self._names: + for _ in self._tensors[name]: + dequeue_ops.append([]) + for dtype in self._tensor_dtypes[name]: + tensor_dtypes.append(dtype) + for shape in self._tensor_shapes[name]: + tensor_shapes.append(shape) + + # Outfeed ops execute on each replica's first logical core. Note: we must + # constraint it such that we have at most one outfeed dequeue and enqueue + # per replica. tpu_device_placement_fn = self._ctx.tpu_device_placement_function - for i in xrange(num_cores): + for i in xrange(self._ctx.num_replicas): with ops.device(tpu_device_placement_fn(i)): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( - dtypes=self._tensor_dtypes, shapes=self._tensor_shapes) + dtypes=tensor_dtypes, shapes=tensor_shapes) for j, item in enumerate(outfeed_tensors): dequeue_ops[j].append(item) - # It is assumed evaluation always happends on single host TPU system. So, + # Deconstruct dequeue ops. + dequeue_ops_by_name = {} + pos = 0 + for name in self._names: + dequeue_ops_by_name[name] = dequeue_ops[pos:pos+len(self._tensors[name])] + pos += len(self._tensors[name]) + + # It is assumed evaluation always happens on single host TPU system. So, # place all ops on tpu host if possible. + # + # TODO(jhseu): Evaluate whether this is right for summaries. with ops.device(self._ctx.tpu_host_placement_function(core_id=0)): - for i, item in enumerate(dequeue_ops): - if dequeue_ops[i][0].shape.ndims == 0: - raise RuntimeError( - 'All tensors outfed from TPU should preseve batch size ' - 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) - # TODO(xiejw): Allow users to specify the axis for batch size dimension. - dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) + for name in self._names: + dequeue_ops = dequeue_ops_by_name[name] + for i, item in enumerate(dequeue_ops): + if dequeue_ops[i][0].shape.ndims == 0: + raise RuntimeError( + 'All tensors outfed from TPU should preserve batch size ' + 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) + # TODO(xiejw): Allow users to specify the axis for batch size + # dimension. + dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) + + if self._tensor_keys[name] is not None: + # The user-provided eval_metrics[1] is a dict. + dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops)) + try: + ret[name] = self._host_fns[name](**dequeue_ops) + except TypeError as e: + logging.warning( + 'Exception while calling %s: %s. It is likely the tensors ' + '(%s[1]) do not match the ' + 'function\'s arguments', name, e, name) + raise e + else: + ret[name] = self._host_fns[name](*dequeue_ops) - if self._is_dict: - dequeue_ops = dict(zip(self._tensor_keys, dequeue_ops)) - try: - eval_metric_ops = self._metric_fn(**dequeue_ops) - except TypeError as e: - logging.warning( - 'Exception while calling metric_fn for evalution: %s. ' - 'It is likely the tensors (eval_metrics[1]) do not match the ' - 'metric_fn arguments', e) - raise e - else: - eval_metric_ops = self._metric_fn(*dequeue_ops) + return ret - eval_update_ops = [] - for k, v in eval_metric_ops.items(): - eval_metric_ops[k] = (v[0], dummy_update_op) - eval_update_ops.append(v[1]) - return eval_metric_ops, eval_update_ops +class _OutfeedHostCallHook(session_run_hook.SessionRunHook): + """Hook to run host calls when use_tpu=False.""" + + def __init__(self, tensors): + self._tensors = tensors + + def begin(self): + # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than + # create a separate hook to guarantee execution order, because summaries + # need to be initialized before the outfeed thread starts. + # TODO(jhseu): Make a wrapper hook instead? + self._init_ops = contrib_summary.summary_writer_initializer_op() + # Get all the writer resources from the initializer, so we know what to + # flush. + self._finalize_ops = [] + for op in self._init_ops: + self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) + + def after_create_session(self, session, coord): + session.run(self._init_ops) + + def before_run(self, run_context): + return basic_session_run_hooks.SessionRunArgs(self._tensors) + + def end(self, session): + session.run(self._finalize_ops) class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): @@ -1387,6 +1488,23 @@ class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): logging.info('examples/sec: %g', examples_per_sec) +class InstallSignalHandlerHook(session_run_hook.SessionRunHook): + """Change SIGINT (CTRL^C) handler to force quit the process. + + The default behavior often results in hanging processes. + The original handler is restored after training/evaluation. + """ + + def __init__(self): + self._signal_fn = signal.getsignal(signal.SIGINT) + + def before_run(self, run_context): + signal.signal(signal.SIGINT, signal.SIG_DFL) + + def end(self, session): + signal.signal(signal.SIGINT, self._signal_fn) + + class TPUEstimator(estimator_lib.Estimator): """Estimator with TPU support. @@ -1394,30 +1512,28 @@ class TPUEstimator(estimator_lib.Estimator): replicating inputs and models for each core, and returning to host periodically to run hooks. - If `use_tpu` is false, all training, evaluation, and predict are executed on - CPU. - - For training, TPUEstimator transforms a global batch size in params to a - per-shard batch size when calling the `input_fn` and `model_fn`. Users should - specify `train_batch_size` in constructor, and then get the batch size for - each shard in `input_fn` and `model_fn` by `params['batch_size']`. If - `TPUConfig.per_host_input_for_training` is `True`, `input_fn` is invoked per - host rather than per core. In this case, a global batch size is transformed a - per-host batch size in params for `input_fn`, but `model_fn` still gets - per-core batch size. - - For evaluation, if `eval_batch_size` is None, it is executed on CPU, even if - `use_tpu` is `True`. If `eval_batch_size` is not `None`, it is executed on - TPU, which is an experimental feature. In this case, `model_fn` should return - `TPUEstimatorSpec` instead of `EstimatorSpec`, which expects the - `eval_metrics` for TPU evaluation. - + TPUEstimator transforms a global batch size in params to a per-shard batch + size when calling the `input_fn` and `model_fn`. Users should specify + global batch size in constructor, and then get the batch size for each shard + in `input_fn` and `model_fn` by `params['batch_size']`. + For training, `model_fn` gets per-core batch size; `input_fn` may get + per-core or per-host batch size depending on + `per_host_input_for_training` in `TPUConfig`. + For evaluation, `model_fn` gets per-core batch size and `input_fn` get + per-host batch size. + + `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics` + for TPU evaluation. `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns a dict from metric string name to the result of calling a metric function, namely a `(metric_tensor, update_op)` tuple. + One can set `use_tpu` to `False` for testing. All training, evaluation, and + predict will be executed on CPU. `input_fn` and `model_fn` will receive + `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`. + Current limitations: 1. TPU evaluation only works on single host. @@ -1472,6 +1588,7 @@ class TPUEstimator(estimator_lib.Estimator): use_tpu=True, train_batch_size=None, eval_batch_size=None, + predict_batch_size=None, batch_axis=None): """Constructs an `TPUEstimator` instance. @@ -1490,18 +1607,17 @@ class TPUEstimator(estimator_lib.Estimator): basic python types. There are reserved keys for `TPUEstimator`, including 'batch_size'. use_tpu: A bool indicating whether TPU support is enabled. Currently, - - TPU training respects this bit. - - If true, see `eval_batch_size` for evaluate support. + - TPU training and evaluation respect this bit. - Predict still happens on CPU. train_batch_size: An int representing the global training batch size. TPUEstimator transforms this global batch size to a per-shard batch size, as params['batch_size'], when calling `input_fn` and `model_fn`. - Cannot be `None` if `use_tpu` is `True`. Must be divisible by - `config.tpu_config.num_shards`. - eval_batch_size: An int representing the global training batch size. - Currently, if `None`, evaluation is still executed on CPU (even when - `use_tpu` is True). In near future, `use_tpu` will be the only option to - switch between TPU/CPU evaluation. + Cannot be `None` if `use_tpu` is `True`. + Must be divisible by total number of replicas. + eval_batch_size: An int representing evaluation batch size. + Must be divisible by total number of replicas. + predict_batch_size: An int representing the prediction batch size. + Must be divisible by total number of replicas. batch_axis: A python tuple of int values describing how each tensor produced by the Estimator `input_fn` should be split across the TPU compute shards. For example, if your input_fn produced (images, labels) @@ -1525,30 +1641,24 @@ class TPUEstimator(estimator_lib.Estimator): _RESERVED_PARAMS_KEYS, params)) if use_tpu: + # Perform some very basic validations. More validations will be found in + # _TPUContext. if train_batch_size is None: raise ValueError('`train_batch_size` cannot be `None`') - if not isinstance(train_batch_size, int): - raise ValueError('`train_batch_size` must be an int') - if train_batch_size < 1: - raise ValueError('`train_batch_size` must be positive') - - # The specified batch size is the batch size for the entire computation. - # The input_fn and model_fn are called per-shard, so we want to calculate - # the per-shard batch size and pass that. - if train_batch_size % config.tpu_config.num_shards != 0: + util_lib.check_positive_integer(train_batch_size, 'train_batch_size') + + if (not config.tpu_config.per_host_input_for_training and + config.tpu_config.computation_shape): raise ValueError( - 'train batch size {} must be divisible by number of shards {}' - .format(train_batch_size, config.tpu_config.num_shards)) + 'Model parallelism only supports per host input for training. ' + 'Please adjust TPURunconfig.per_host_input_for_training.') if eval_batch_size is not None: - if config.tpu_config.num_shards > 8: - raise NotImplementedError( - 'TPU evaluation is only supported with one host.') + util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size') - if eval_batch_size % config.tpu_config.num_shards != 0: - raise ValueError( - 'eval batch size {} must be divisible by number of shards {}' - .format(eval_batch_size, config.tpu_config.num_shards)) + if predict_batch_size is not None: + util_lib.check_positive_integer(predict_batch_size, + 'predict_batch_size') # Verifies the model_fn signature according to Estimator framework. estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access @@ -1568,8 +1678,11 @@ class TPUEstimator(estimator_lib.Estimator): self._config.tpu_config.iterations_per_loop) # All properties passed to _TPUContext are immutable. - self._ctx = _TPUContext(self._config, train_batch_size, eval_batch_size, - use_tpu) + # pylint: disable=protected-access + self._ctx = tpu_context._get_tpu_context( + self._config, train_batch_size, + eval_batch_size, predict_batch_size, + use_tpu) def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -1657,7 +1770,9 @@ class TPUEstimator(estimator_lib.Estimator): if batch_size_for_input_fn is not None: kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn - if ctx.is_running_on_cpu(): + # For export_savedmodel, input_fn is never passed to Estimator. So, + # `is_export_mode` must be False. + if ctx.is_running_on_cpu(is_export_mode=False): with ops.device('/device:CPU:0'): return input_fn(**kwargs) @@ -1684,8 +1799,13 @@ class TPUEstimator(estimator_lib.Estimator): with self._ctx.with_mode(mode) as ctx: model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) - # TODO(jhseu): Move to PREDICT to TPU. - if ctx.is_running_on_cpu(): + # For export_savedmodel, input_fn is never passed to Estimator. So, + # if features is callable, it means it is the input_fn passed by + # TPUEstimator._call_input_fn. Then we can know if the mode == PREDICT, + # it implies, it is the .predict API, not export_savedmodel API. + is_export_mode = not callable(features) + + if ctx.is_running_on_cpu(is_export_mode=is_export_mode): logging.info('Running %s on CPU', mode) return model_fn_wrapper.call_without_tpu(features, labels) @@ -1695,22 +1815,31 @@ class TPUEstimator(estimator_lib.Estimator): input_fn = features input_holders = _InputPipeline(input_fn, batch_axis, ctx) - enqueue_ops, dequeue_fn = ( + enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) if mode == model_fn_lib.ModeKeys.TRAIN: - loss, scaffold = ( + loss, host_call, scaffold = ( _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) + host_ops = host_call.create_tpu_hostcall() + if host_ops is None: + host_ops = [] hooks = [ - TPUInfeedOutfeedSessionHook(ctx, enqueue_ops), + TPUInfeedOutfeedSessionHook( + ctx, + enqueue_ops, + host_ops, + run_infeed_loop_on_coordinator=( + run_infeed_loop_on_coordinator)), ExamplesPerSecondHook(ctx.global_batch_size), + InstallSignalHandlerHook(), training.LoggingTensorHook( { 'loss': array_ops.identity(loss), 'step': training.get_global_step() }, every_n_secs=30) - ] + ] + input_hooks summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) with ops.control_dependencies([loss]): update_ops = _sync_variables_ops() @@ -1725,40 +1854,114 @@ class TPUEstimator(estimator_lib.Estimator): train_op=control_flow_ops.group(*update_ops), scaffold=scaffold) - # Now eval. - total_loss, eval_metric_ops, scaffold = _eval_on_tpu_system( + if mode == model_fn_lib.ModeKeys.EVAL: + total_loss, host_calls, scaffold = _eval_on_tpu_system( + ctx, model_fn_wrapper, dequeue_fn) + iterations_per_loop_var = _create_or_get_iterations_per_loop() + mean_loss = math_ops.div(total_loss, + math_ops.cast( + iterations_per_loop_var, + dtype=total_loss.dtype)) + + # Creates a dummy metric update_op for all metrics. Estimator expects + # all metrics in eval_metric_ops have update_op and calls them one by + # one. The real metric update_ops are invoked in a separated thread. + # So, here give Estimator the dummy op for all metrics. + with ops.control_dependencies([mean_loss]): + # After TPU evaluation computation is done (the mean_loss tensor), + # reads all variables back from TPU and updates the eval step + # counter properly + internal_ops_to_run = _sync_variables_ops() + internal_ops_to_run.append( + _increase_eval_step_op(iterations_per_loop_var)) + with ops.control_dependencies(internal_ops_to_run): + dummy_update_op = control_flow_ops.no_op() + + host_call_ret = host_calls.create_tpu_hostcall() + eval_metric_ops = {} + eval_update_ops = [] + for k, v in host_call_ret['eval_metrics'].items(): + eval_metric_ops[k] = (v[0], dummy_update_op) + eval_update_ops.append(v[1]) + + if 'host_call' not in host_call_ret: + host_ops = [] + else: + host_ops = host_call_ret['host_call'] + hooks = [ + TPUInfeedOutfeedSessionHook( + ctx, + enqueue_ops, + eval_update_ops + host_ops, + run_infeed_loop_on_coordinator=( + run_infeed_loop_on_coordinator)), + ] + input_hooks + + return model_fn_lib.EstimatorSpec( + mode, + loss=mean_loss, + evaluation_hooks=hooks, + eval_metric_ops=eval_metric_ops, + scaffold=scaffold) + + # Predict + assert mode == model_fn_lib.ModeKeys.PREDICT + + dummy_predict_op, host_calls, scaffold = _predict_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) - iterations_per_loop_var = _create_or_get_iterations_per_loop() - mean_loss = math_ops.div(total_loss, - math_ops.cast( - iterations_per_loop_var, - dtype=total_loss.dtype)) - - # Creates a dummy metric update_op for all metrics. Estimator expects - # all metrics in eval_metric_ops have update_op and calls them one by - # one. The real metric update_ops are invoked in a separated thread. So, - # here give Estimator the dummy op for all metrics. - with ops.control_dependencies([mean_loss]): - # After TPU evaluation computation is done (the mean_loss tensor), - # reads all variables back from TPU and updates the eval step counter - # properly + with ops.control_dependencies([dummy_predict_op]): internal_ops_to_run = _sync_variables_ops() - internal_ops_to_run.append( - _increase_eval_step_op(iterations_per_loop_var)) with ops.control_dependencies(internal_ops_to_run): - dummy_update_op = control_flow_ops.no_op() + dummy_predict_op = control_flow_ops.no_op() + + # In train and evaluation, the main TPU program is passed to monitored + # training session to run. Infeed enqueue and outfeed dequeue are + # executed in side threads. This is not the configuration for + # prediction mode. + # + # For prediction, the Estimator executes the EstimatorSpec.predictions + # directly and yield the element (via generator) to call site. So, the + # outfeed based prediction must be passed to MonitoredSession directly. + # Other parts of the TPU execution are organized as follows. + # + # 1. All outfeed based Tensors must be grouped with predictions Tensors + # to form a single invocation. This avoid the issue we might trigger + # multiple outfeeds incorrectly. To achieve this, `host_call` is + # placed in control_dependencies of `stopping_signals`, and + # `stopping_signals` is passed into _StoppingPredictHook, which sets + # the `stopping_signals` as SessionRunArgs. MonitoredSession merges + # all SessionRunArgs with the fetch in session.run together. + # + # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue) + # are grouped together. They will be launched once and only once in + # side threads and they quit naturally according to the SAME stopping + # condition. + enqueue_ops.append(dummy_predict_op) + + host_call_ret = host_calls.create_tpu_hostcall() + if 'host_call' not in host_call_ret: + host_ops = [] + else: + host_ops = host_call_ret['host_call'] + + predictions = host_call_ret['predictions'] + stopping_signals = host_call_ret['signals'] + + with ops.control_dependencies(host_ops): + host_ops = [] # Empty, we do do not need it anymore. + scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal( + stopping_signals) - eval_metric_ops, eval_update_ops = ( - eval_metric_ops.to_metric_metric_ops_for_tpu(dummy_update_op)) hooks = [ - TPUInfeedOutfeedSessionHook(ctx, enqueue_ops, eval_update_ops), - ] + _StoppingPredictHook(scalar_stopping_signal), + TPUInfeedOutfeedSessionHookForPrediction(ctx, enqueue_ops, + host_ops), + ] + input_hooks return model_fn_lib.EstimatorSpec( mode, - loss=mean_loss, - evaluation_hooks=hooks, - eval_metric_ops=eval_metric_ops, + prediction_hooks=hooks, + predictions=predictions, scaffold=scaffold) return _model_fn @@ -1766,10 +1969,9 @@ class TPUEstimator(estimator_lib.Estimator): def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - num_cores = ctx.num_cores iterations_per_loop_var = _create_or_get_iterations_per_loop() - single_tpu_eval_step, eval_metric_ops, captured_scaffold_fn = ( + single_tpu_eval_step, host_calls, captured_scaffold_fn = ( model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)) def multi_tpu_eval_steps_on_single_shard(): @@ -1781,19 +1983,19 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): (loss,) = tpu.shard( multi_tpu_eval_steps_on_single_shard, inputs=[], - num_shards=num_cores, - outputs_from_all_shards=False) + num_shards=ctx.num_replicas, + outputs_from_all_shards=False, + device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) - return loss, eval_metric_ops, scaffold + return loss, host_calls, scaffold def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - num_cores = ctx.num_cores iterations_per_loop_var = _create_or_get_iterations_per_loop() - single_tpu_train_step, captured_scaffold_fn = ( + single_tpu_train_step, host_call, captured_scaffold_fn = ( model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) def multi_tpu_train_steps_on_single_shard(): @@ -1805,11 +2007,40 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): (loss,) = tpu.shard( multi_tpu_train_steps_on_single_shard, inputs=[], + num_shards=ctx.num_replicas, + outputs_from_all_shards=False, + device_assignment=ctx.device_assignment) + + scaffold = _get_scaffold(captured_scaffold_fn) + return loss, host_call, scaffold + + +def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): + """Executes `model_fn_wrapper` multiple times on all TPU shards.""" + num_cores = ctx.num_cores + + single_tpu_predict_step, host_calls, captured_scaffold_fn = ( + model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)) + + def multi_tpu_predict_steps_on_single_shard(): + + def cond(scalar_stopping_signal): + return math_ops.logical_not( + _StopSignals.should_stop(scalar_stopping_signal)) + + inputs = [_StopSignals.NON_STOPPING_SIGNAL] + outputs = training_loop.while_loop( + cond, single_tpu_predict_step, inputs=inputs, name=b'loop') + return outputs + + (dummy_predict_op,) = tpu.shard( + multi_tpu_predict_steps_on_single_shard, + inputs=[], num_shards=num_cores, outputs_from_all_shards=False) scaffold = _get_scaffold(captured_scaffold_fn) - return loss, scaffold + return dummy_predict_op, host_calls, scaffold def _wrap_computation_in_while_loop(device, op_fn): @@ -1830,6 +2061,29 @@ def _wrap_computation_in_while_loop(device, op_fn): parallel_iterations=1) +def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn): + """Wraps the ops generated by `op_fn` in tf.while_loop.""" + + def cond(scalar_stopping_signal): + return math_ops.logical_not( + _StopSignals.should_stop(scalar_stopping_signal)) + + def computation(unused_scalar_stopping_signal): + return_value = op_fn() + execute_ops = return_value['ops'] + signals = return_value['signals'] + with ops.control_dependencies(execute_ops): + return _StopSignals.as_scalar_stopping_signal(signals) + + # By setting parallel_iterations=1, the parallel execution in while_loop is + # basically turned off. + with ops.device(device): + return control_flow_ops.while_loop( + cond, + computation, [_StopSignals.NON_STOPPING_SIGNAL], + parallel_iterations=1) + + def _validate_tpu_training_graph(): """Validate graph before running distributed training. @@ -1920,3 +2174,194 @@ class _CapturingContext(control_flow_ops.ControlFlowContext): def __exit__(self, _, __, ___): # pylint: disable=invalid-name self._g._set_control_flow_context(self._old) # pylint: disable=protected-access + + +class _Inputs(object): + """A data structure representing the input_fn returned values. + + This also supports the returned value from input_fn as `Dataset`. + """ + + def __init__(self, features=None, labels=None, dataset=None, signals=None): + if dataset is not None and (features is not None or labels is not None or + signals is not None): + raise RuntimeError('Internal Error: Either (features and labels) or ' + 'dataset should be provided, not both. Please file ' + 'bug') + + self._features = features + self._labels = labels + self._signals = signals + + self._dataset = dataset + self._iterator = None + + @staticmethod + def from_input_fn(return_values): + """Returns an `_Inputs` instance according to `input_fn` return value.""" + if isinstance(return_values, dataset_ops.Dataset): + dataset = return_values + return _Inputs(dataset=dataset) + + features, labels = _Inputs._parse_inputs(return_values) + return _Inputs(features, labels) + + @staticmethod + def _parse_inputs(return_values): + if isinstance(return_values, tuple): + features, labels = return_values + else: + features, labels = return_values, None + return features, labels + + @property + def is_dataset(self): + """Returns True if the return value from input_fn is Dataset.""" + return self._dataset is not None + + def dataset_initializer_hook(self): + """Returns a `SessionRunHook` to initialize this dataset. + + This must be called before `features_and_labels`. + """ + iterator = self._dataset.make_initializable_iterator() + # pylint: disable=protected-access + hook = estimator_lib._DatasetInitializerHook(iterator) + self._iterator = iterator + return hook + + def features_and_labels(self): + """Gets `features` and `labels`.""" + if self.is_dataset: + return _Inputs._parse_inputs(self._iterator.get_next()) + + return (self._features, self._labels) + + def signals(self): + return self._signals + + @property + def dataset(self): + return self._dataset + + +# TODO(xiejw): Extend this to support final partial batch. +class _InputsWithStoppingSignals(_Inputs): + """Inputs with `_StopSignals` inserted into the dataset.""" + + def __init__(self, dataset, batch_size): + + assert dataset is not None + + user_provided_dataset = dataset.map( + _InputsWithStoppingSignals.insert_stopping_signal( + stop=False, batch_size=batch_size)) + final_batch_dataset = dataset.take(1).map( + _InputsWithStoppingSignals.insert_stopping_signal( + stop=True, batch_size=batch_size)) + dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2) + + super(_InputsWithStoppingSignals, self).__init__(dataset=dataset) + self._current_inputs = None + + def features_and_labels(self): + if self._current_inputs is not None: + raise RuntimeError( + 'Internal Error: The previous inputs have not been properly ' + 'consumed. First call features_and_labels, then call signals.') + + inputs_with_signals = self._iterator.get_next() + features = inputs_with_signals['features'] + labels = inputs_with_signals.get('labels') + + self._current_inputs = inputs_with_signals + return features, labels + + def signals(self): + """Returns the `Signals` from `_Inputs`.""" + if self._current_inputs is None: + raise RuntimeError( + 'Internal Error: The current inputs have not been properly ' + 'generated. First call features_and_labels, then call signals.') + signals = self._current_inputs['signals'] + self._current_inputs = None + return signals + + @staticmethod + def insert_stopping_signal(stop, batch_size): + """Inserts stopping_signal into dataset via _map_fn. + + Here we change the data structure in the dataset, such that the return value + is a dictionary now and `features`, `labels`, and `signals` are three + distinguished keys in that dict. This provides a better structure, which + eases the process to decompose the inputs (see `features_and_labels`). + + Args: + stop: bool, state of current stopping signals. + batch_size: int, batch size. + + Returns: + A map_fn passed to dataset.map API. + """ + + def _map_fn(*args): + features, labels = _Inputs._parse_inputs(args) + new_input_dict = {} + new_input_dict['features'] = features + if labels is not None: + new_input_dict['labels'] = labels + new_input_dict['signals'] = _StopSignals( + stop=stop, batch_size=batch_size).as_dict() + return new_input_dict + + return _map_fn + + +class _StopSignals(object): + """Signals class holding all logic to handle TPU stopping condition.""" + + NON_STOPPING_SIGNAL = 0.0 + STOPPING_SIGNAL = 1.0 + + def __init__(self, stop, batch_size): + self._stop = stop + self._batch_size = batch_size + + def as_dict(self): + shape = [self._batch_size, 1] + dtype = dtypes.float32 + + if self._stop: + stopping = array_ops.ones(shape=shape, dtype=dtype) + else: + stopping = array_ops.zeros(shape=shape, dtype=dtype) + + return {'stopping': stopping} + + @staticmethod + def as_scalar_stopping_signal(signals): + return array_ops.identity(signals['stopping'][0][0]) + + @staticmethod + def should_stop(scalar_stopping_signal): + return scalar_stopping_signal >= _StopSignals.STOPPING_SIGNAL + + +class _SignalsHelper(object): + """A general helper class to handle common signals manipulation.""" + + def __init__(self, signals): + self._signal_keys = [] + for key in sorted(signals.iterkeys()): + self._signal_keys.append(key) + + @property + def num_signals(self): + return len(self._signal_keys) + + def unflatten(self, tensor_list): + return dict(zip(self._signal_keys, tensor_list)) + + @staticmethod + def as_tensor_list(signals): + return [signals[key] for key in sorted(signals.iterkeys())] diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py index f8ba7d45e20b2f48e1409427665878df40a6db02..f5af03f33ca8f13af517007672e9ce0e12be6205 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py @@ -244,7 +244,8 @@ class ShardingPolicy(object): str(shapes), self.number_of_shards)) unsharded_shapes = [self._unshard_shape(s) for s in shapes] for i in xrange(self.number_of_shards - 1): - if unsharded_shapes[i] != unsharded_shapes[self.number_of_shards - 1]: + if not unsharded_shapes[i].is_compatible_with( + unsharded_shapes[self.number_of_shards - 1]): raise ValueError( "sharded shapes %s are not consistent shards of a full shape " "sharded %d ways along dimension %d" % ( diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..493d1848c072caa5254fc87c67badc2e99ec16ee --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -0,0 +1,155 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""TPU system metadata and associated tooling.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re + +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.platform import tf_logging as logging + +_PINGING_MASTER_TIMEOUT_IN_MS = 60 * 1000 # 1 min +_RETRY_TIMES = 120 +_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins + +_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$') + +# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration, +# including num_cores and num_hosts. +_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [ + 'num_cores', + 'num_hosts', + 'num_of_cores_per_host', + 'topology', + 'devices', +]) + + +def _query_tpu_system_metadata(master_address, run_config, + query_topology=False): + """Automatically detects the TPU system metadata in the system.""" + tpu_core_count = 0 + devices = [] + device_dict = collections.defaultdict(list) + + retry_count = 1 + while True: + logging.info('Querying Tensorflow master (%s) for TPU system metadata.', + master_address) + try: + with ops.Graph().as_default(): + with session_lib.Session( + master_address, + config=_get_session_config_with_timeout( + _PINGING_MASTER_TIMEOUT_IN_MS, run_config)) as sess: + devices = sess.list_devices() + for device in devices: + match = _TPU_DEVICE_REG.match(device.name) + if match: + host_id = match.group(1) + core_id = match.group(2) + device_dict[host_id].append(core_id) + tpu_core_count += 1 + break + except errors.DeadlineExceededError: + msg = ('Fail to connect Tensorflow master. It could be the TPU worker is ' + 'not ready (still under scheduling) or Tensorflow ' + 'master address is correct: got (%s).' % + (master_address)) + + # TODO(xiejw): For local or grpc master we might not need retry logic + # here. + if retry_count <= _RETRY_TIMES: + logging.warning('%s', msg) + logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES) + retry_count += 1 + else: + raise ValueError(msg) + + num_of_cores_per_host = 0 + if tpu_core_count: + num_cores_per_host_set = set( + [len(core_ids) for core_ids in device_dict.values()]) + if len(num_cores_per_host_set) != 1: + raise RuntimeError( + 'TPU cores on each host is not same. This should not happen!. ' + 'devices: {}'.format(devices)) + num_of_cores_per_host = num_cores_per_host_set.pop() + + topology = None + if query_topology: + if not tpu_core_count: + raise RuntimeError( + 'Cannot find any TPU cores in the system (master address {}). ' + 'This usually means the master address is incorrect or the ' + 'TPU worker has some problems. Available devices: {}'.format( + master_address, devices)) + + topology = _obtain_topology(master_address, run_config) + + metadata = _TPUSystemMetadata( + num_cores=tpu_core_count, + num_hosts=len(device_dict), + num_of_cores_per_host=num_of_cores_per_host, + topology=topology, + devices=devices) + + if tpu_core_count: + logging.info('Found TPU system:') + logging.info('*** Num TPU Cores: %d', metadata.num_cores) + logging.info('*** Num TPU Workers: %d', metadata.num_hosts) + logging.info('*** Num TPU Cores Per Worker: %d', + metadata.num_of_cores_per_host) + logging.info('*** Available Devices: %s', metadata.devices) + else: + logging.info('Failed to find TPU: %s', metadata) + return metadata + + +def _obtain_topology(master_address, run_config): + try: + logging.info('Initializing TPU system (master: %s) to fetch topology ' + 'for model parallelism. This might take a while.', + master_address) + with ops.Graph().as_default(): + session_config = _get_session_config_with_timeout( + _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, run_config) + with session_lib.Session( + master_address, config=session_config) as sess: + topology = sess.run(tpu.initialize_system()) + return topology + except errors.DeadlineExceededError: + raise ValueError( + 'Fail to initialize TPU system with master (%s). ' + 'Please double check the TPU system is functional.' % ( + master_address)) + + +def _get_session_config_with_timeout(timeout_in_secs, run_config): + cluster_def = None + if run_config.session_config and run_config.session_config.cluster_def.job: + cluster_def = run_config.session_config.cluster_def + + config = config_pb2.ConfigProto( + operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def) + return config diff --git a/tensorflow/contrib/tpu/tpu_estimator.md b/tensorflow/contrib/tpu/tpu_estimator.md index ca1255b16b1575d291df51dfde696b36c38359ae..4ef8f9eebdb165e5fe221be8670276bf943159b3 100644 --- a/tensorflow/contrib/tpu/tpu_estimator.md +++ b/tensorflow/contrib/tpu/tpu_estimator.md @@ -231,7 +231,7 @@ Refer to this link for all [Cloud TPU documentation](https://cloud.google.com/tp ### Profiling -You can profile the `worker` by using instructions as spcified in the [Cloud TPU Tools](https://cloud.google.com/tpu/docs/cloud-tpu-tools). +You can profile the `worker` by using instructions as specified in the [Cloud TPU Tools](https://cloud.google.com/tpu/docs/cloud-tpu-tools). ### Is `int64` supported? diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index cccaa2b833ee764921508a5b6d6affe0b8822ede..6db373d2d5e20ea7da449530b2730403c3bb64cc 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -26,6 +26,7 @@ py_library( "python/training/resample.py", "python/training/sampling_ops.py", "python/training/sequence_queueing_state_saver.py", + "python/training/tensor_queue_dataset.py", "python/training/training.py", "python/training/tuner.py", ], @@ -285,6 +286,28 @@ py_test( ], ) +py_test( + name = "tensor_queue_dataset_test", + size = "large", + srcs = ["python/training/tensor_queue_dataset_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":training_py", + "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:random_seed", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py index 2a0ef0e6b3750b4f0464f1f4390819e1fc2c7872..dbdbb08a8252c799924812c83fff7f0631424761 100644 --- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py +++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py @@ -53,7 +53,7 @@ class BatchSequencesWithStatesTest(test.TestCase): sp_tensor1 = sparse_tensor.SparseTensor( array_ops.constant(ind1, dtypes.int64), array_ops.constant(val1, dtypes.int64), - array_ops.constant(shape1, dtypes.int64)) + array_ops.placeholder_with_default(shape1, shape=[2])) ind2 = np.array([ [0, 0, 1], [0, 1, 0], @@ -68,7 +68,7 @@ class BatchSequencesWithStatesTest(test.TestCase): sp_tensor2 = sparse_tensor.SparseTensor( array_ops.constant(ind2, dtypes.int64), array_ops.constant(val2, dtypes.int64), - array_ops.constant(shape2, dtypes.int64)) + array_ops.placeholder_with_default(shape2, shape=[3])) sp_tensor3 = sparse_tensor.SparseTensor( array_ops.constant([[1, 9], [2, 2], [2, 10]], dtypes.int64), array_ops.constant([7, 15, 2], dtypes.int64), @@ -320,6 +320,18 @@ class BatchSequencesWithStatesTest(test.TestCase): def testNotAMultiple(self): num_unroll = 3 # Not a divisor of value_length - # so padding would have been necessary. + + # Use placeholder_with_default in sequences to make sure we get runtime + # error instead of shape inference error + sequences = { + "seq1": array_ops.placeholder_with_default(self.sequences["seq1"], + shape=(None, 5)), + "seq2": array_ops.placeholder_with_default(self.sequences["seq2"], + shape=(None, 4, 2)), + "seq3": self.sequences["seq3"], + "seq4": self.sequences["seq4"], + } + with self.test_session() as sess: with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, ".*should be a multiple of: 3, but saw " @@ -330,7 +342,7 @@ class BatchSequencesWithStatesTest(test.TestCase): with coord.stop_on_exception(): next_batch = sqss.batch_sequences_with_states( input_key=self.key, - input_sequences=self.sequences, + input_sequences=sequences, input_context=self.context, input_length=3, initial_states=self.initial_states, @@ -493,6 +505,18 @@ class BatchSequencesWithStatesTest(test.TestCase): expected_seq4_batch2=expected_seq4_batch2) +class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest): + + def setUp(self): + self._prev_value = ops._USE_C_API + ops._USE_C_API = True + super(BatchSequencesWithStatesTestWithCApi, self).setUp() + + def tearDown(self): + super(BatchSequencesWithStatesTestWithCApi, self).tearDown() + ops._USE_C_API = self._prev_value + + class PaddingTest(test.TestCase): def testPaddingInvalidLengths(self): diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..409aba817c1ec37003eb98f000f6cf8918234c5d --- /dev/null +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py @@ -0,0 +1,200 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrappers for Datasets and Iterators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util import nest as tf_nest + + +class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset): + """A `Dataset` that prepends a queue to another `Dataset`. + + A vector of handles to the queue is returned as the first component of + the associated iterator. This vector can be passed to + `enqueue_in_queue_dataset` to add new elements to the queue. + """ + + def __init__(self, input_dataset, batch_size, padded_shapes, padding_values): + """Initialize `PrependFromQueueAndPaddedBatchDataset`.""" + super(_PrependFromQueueAndPaddedBatchDataset, self).__init__() + if sparse.any_sparse(input_dataset.output_classes): + raise TypeError( + "Batching of padded sparse tensors is not currently supported") + self._input_dataset = input_dataset + self._batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + # pylint: disable=protected-access + if padded_shapes is None: + self._padded_shapes = nest.map_structure( + dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes) + else: + self._padded_shapes = nest.map_structure_up_to( + input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor, + padded_shapes) + padding_values = ( + padding_values if padding_values is not None else + dataset_ops._default_padding(input_dataset)) + self._padding_values = nest.map_structure_up_to( + input_dataset.output_shapes, dataset_ops._padding_value_to_tensor, + padding_values, input_dataset.output_types) + # pylint: enable=protected-access + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_dataset_ops.prepend_from_queue_and_padded_batch_dataset( + self._input_dataset._as_variant_tensor(), + batch_size=self._batch_size, + padded_shapes=[ + ops.convert_to_tensor(s, dtype=dtypes.int64) + for s in nest.flatten(self._padded_shapes) + ], + padding_values=nest.flatten(self._padding_values), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + # pylint: enable=protected-access + + @property + def output_classes(self): + return (ops.Tensor, self._input_dataset.output_classes) + + def _as_batch_shape(self, shape_like): + return tensor_shape.vector(None).concatenate( + tensor_util.constant_value_as_shape(shape_like)) + + @property + def output_shapes(self): + # First output is a variant representing the Queue + return (tensor_shape.vector(None), + nest.map_structure(self._as_batch_shape, self._padded_shapes)) + + @property + def output_types(self): + # First output is a variant representing the Queue + return (dtypes.variant, self._input_dataset.output_types) + + +def prepend_from_queue_and_padded_batch_dataset(batch_size, + padding_values=None, + padded_shapes=None): + """A transformation that prepends a queue to a `Dataset` and batches results. + + A vector of handles to the queue is returned as the first component of the + associated iterator. This vector can be passed to `enqueue_in_queue_dataset` + to add new elements to the queue. + + Below is an example of how this dataset might be used to split incoming + variable-length sequences into "head" and "rest" parts, where "rest" parts + are re-enqueued back into the dataset. A more realistic example would + perform some calculation on the "head" and modify some components of "rest" + with the result (before re-enqueueing). + + ```python + dataset = tf.data.Dataset.from_tensor_slices([2*x for x in range(10)]) + # Make a dataset of variable-length vectors and their lengths. + dataset = dataset.map(lambda count: (count, tf.ones((count,)))) + # Emit a queue we can prepend to, and counts/values as padded batch. + dataset = dataset.apply( + tf.contrib.training.prepend_from_queue_and_padded_batch_dataset( + batch_size=10)) + dataset = dataset.prefetch(1) + + iterator = dataset.make_one_shot_iterator() + queue, (count, padded_value) = iterator.get_next() + + # Split the padded_value into two pieces: head and rest + rest_indices = tf.squeeze(tf.where(count > 3), axis=1) + bound = tf.minimum(3, tf.reduce_max(count)) + value_head = padded_value[:, :bound] + count_rest = tf.gather(count - 3, rest_indices) + value_rest = tf.gather(padded_value[:, bound:], rest_indices) + queue_rest = tf.gather(queue, rest_indices) + enqueue_rest_op = tf.contrib.training.enqueue_in_queue_dataset( + queue_rest, (count_rest, value_rest)) + with tf.control_dependencies([enqueue_rest_op]): + calculation = fn(value_head) + + while True: # Will raise OutOfRange when finished with all pieces. + session.run(calculation) + ``` + + Args: + batch_size: `int64` scalar tensor. The batch size to use when performing + padded batching. + padding_values: (optional) Nested tuple of scalar tensors. If provided, + the structure and dtypes of padding_values should match that of + incoming dataset's `output_types`. + padded_shapes: (optional) Nested tuple of `int64` vector tensors. + If provided, the structure must match that of the incoming dataset's + `output_types`. If not provided, the incoming dataset's `output_shapes` + is used. Any unknown (`None` or `-1`) dimensions in the shapes are + treated as being unique per-batch: for each batch time, an unknown + dimension is replaced with the maximum given value of this dimension + across all tensors for the given component in the batch. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _PrependFromQueueAndPaddedBatchDataset( + dataset, + batch_size=batch_size, + padding_values=padding_values, + padded_shapes=padded_shapes) + + return _apply_fn + + +def enqueue_in_queue_dataset(queue, components): + """Enqueue components into queue from `PrependFromQueueAndPaddedBatchDataset`. + + The components' dtypes and shapes must be compatible with the `output_shapes` + attribute of the `dataset` created by + `prepend_from_queue_and_padded_batch_dataset`. This operation supports both + non-batched and batched modes. + + For more details, see the example in the docstring for + `prepend_from_queue_and_padded_batch_dataset`. + + Args: + queue: `variant` scalar or vector tensor. + The tensor emitted by the first component of the iterator associated with + `prepend_from_queue_and_padded_batch_dataset`. If this is a scalar, + then the `components` input tensors should not have a prepended batch + dimension. + components: Nested tuple of tensors, each with a leading batch dimension + if `queue` is a vector. The structure, dtypes, and shapes + (excluding batch dimension) must match the nested tuples + `dataset.output_types[1]` and `dataset.output_shapes[1]` (the non-queue + output types and shapes) of the `dataset` emitted by + the original `prepend_from_queue_and_padded_batch_dataset` call. + + Returns: + An `Operation` that enqueues `components` into the dataset(s) associated + with entries of `queue`. + """ + return gen_dataset_ops.enqueue_in_queue_dataset( + queue=queue, components=tf_nest.flatten(components)) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0338f409a203c232e63e99534a8f6d6a43fa661e --- /dev/null +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py @@ -0,0 +1,355 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TensorQueueDataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): + + def testNoEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + self.assertEqual((dtypes.variant, dtypes.int32), dataset.output_types) + self.assertAllEqual(([None],) * 2, + [x.as_list() for x in dataset.output_shapes]) + iterator = dataset.make_one_shot_iterator() + _, value = iterator.get_next() + self.assertEqual([0], self.evaluate(value)) + self.assertEqual([1], self.evaluate(value)) + self.assertEqual([2], self.evaluate(value)) + with self.assertRaisesOpError("End of sequence"): + self.evaluate(value) + + def testBatchedNoEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) + iterator = dataset.make_one_shot_iterator() + _, value = iterator.get_next() + self.assertAllEqual([0, 1], self.evaluate(value)) + self.assertAllEqual([2], self.evaluate(value)) + with self.assertRaisesOpError("End of sequence"): + self.evaluate(value) + + def testBatchedWithBiggerPaddingNoEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset( + batch_size=2, padded_shapes=[3])) + iterator = dataset.make_one_shot_iterator() + _, value = iterator.get_next() + self.assertAllEqual([[0, 0, 0], [1, 0, 0]], self.evaluate(value)) + self.assertAllEqual([[2, 0, 0]], self.evaluate(value)) + with self.assertRaisesOpError("End of sequence"): + self.evaluate(value) + + def testBatchedWithBiggerPaddingOneEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset( + batch_size=1, padded_shapes=[3])) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) + with self.test_session() as sess: + self.assertAllEqual([[0, 0, 0]], sess.run(value)) + value_1, _ = sess.run([value, enqueue_negative]) + self.assertAllEqual([[1, 0, 0]], value_1) + value_2, _ = sess.run([value, enqueue_negative]) + self.assertAllEqual([[-1, 0, 0]], value_2) + value_3 = sess.run(value) + self.assertAllEqual([[1, 0, 0]], value_3) + value_4, _ = sess.run([value, enqueue_negative]) + self.assertAllEqual([[2, 0, 0]], value_4) + value_5 = sess.run(value) + self.assertAllEqual([[-2, 0, 0]], value_5) + with self.assertRaisesOpError("End of sequence"): + sess.run(value) + + def testOneEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) + with self.test_session() as sess: + self.assertEqual([0], sess.run(value)) + value_1, _ = sess.run([value, enqueue_negative]) + self.assertEqual([1], value_1) + value_2, _ = sess.run([value, enqueue_negative]) + self.assertEqual([-1], value_2) + value_3 = sess.run(value) + self.assertEqual([1], value_3) + value_4, _ = sess.run([value, enqueue_negative]) + self.assertEqual([2], value_4) + value_5 = sess.run(value) + self.assertEqual([-2], value_5) + with self.assertRaisesOpError("End of sequence"): + sess.run(value) + + def testBatchedOneEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) + enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]], + array_ops.expand_dims( + value[0], axis=0)) + with self.test_session() as sess: + value_0, _ = sess.run([value, enqueue_negative]) + self.assertAllEqual([0, 1], value_0) + value_1, _ = sess.run([value, enqueue_zeroth]) + self.assertAllEqual([0, -1], value_1) + value_2, _ = sess.run([value, enqueue_negative]) + self.assertAllEqual([0, 2], value_2) + self.assertAllEqual([0, -2], sess.run(value)) + with self.assertRaisesOpError("End of sequence"): + sess.run(value) + + def testManyEnqueue(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + enqueue_many_more = [ + tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i) + for i in range(1000) + ] + with self.test_session() as sess: + value_0, _ = sess.run((value, enqueue_many_more)) + self.assertEqual([0], value_0) + rest = [] + for _ in range(1000): + rest.append(sess.run(value)) + self.assertEquals([[100 + i] for i in range(1000)], sorted(rest)) + # Going back to the original input. + value_1, _ = sess.run((value, enqueue_many_more)) + self.assertEqual(1, value_1) + rest = [] + for _ in range(1000): + rest.append(sess.run(value)) + self.assertEquals([[100 + i + 1] for i in range(1000)], sorted(rest)) + with self.assertRaisesOpError("End of sequence"): + sess.run(value) + + def testEnqueueWithPrefetch(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + # Prefetching will request additional values before they are + # available to the queue. + dataset = dataset.prefetch(buffer_size=3) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1) + with self.test_session() as sess: + i = 0 + while i < 4: + received, _ = sess.run((value, enqueue)) + if received.size > 0: + self.assertAllEqual([i], received) + i += 1 + received_last = False + while True: + try: + received = sess.run(value) + if received.size > 0: + self.assertAllEqual([4], received) + received_last = True + except errors.OutOfRangeError: + break + self.assertTrue(received_last) + + def testDatasetWithPaddedShapeSmallerThanInputFails(self): + dataset = dataset_ops.Dataset.from_tensor_slices([[0, 0, 0]]).repeat(None) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset( + batch_size=1, padded_shapes=[2])) + iterator = dataset.make_one_shot_iterator() + _, value = iterator.get_next() + with self.test_session() as sess: + with self.assertRaisesOpError( + r"Incompatible input shapes at component 0 between " + r"input dataset this dataset: \[3\] vs. \[2\]"): + sess.run(value) + + def testEnqueueWithIncompatibleInputsFailsWithInformativeError(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0]).repeat(None) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + iterator = dataset.make_one_shot_iterator() + queue_handle, value = iterator.get_next() + + enqueue_bad_structure = tqd.enqueue_in_queue_dataset( + queue_handle, (value, value)) + enqueue_bad_dtype = tqd.enqueue_in_queue_dataset(queue_handle, + np.array( + [1.0], + dtype=np.float32)) + enqueue_bad_shape_no_batch_dim = tqd.enqueue_in_queue_dataset( + queue_handle, ([1],)) + enqueue_bad_shape = tqd.enqueue_in_queue_dataset(queue_handle, + np.array( + [[1]], dtype=np.int32)) + + with self.test_session() as sess: + with self.assertRaisesOpError( + "mismatched number of tensors. Queue expects 1 tensors but " + "tried to insert 2"): + sess.run(enqueue_bad_structure) + with self.assertRaisesOpError(r"Expected component 0 to have batched " + r"shape \[1,...\], but saw shape: \[\]"): + sess.run(enqueue_bad_shape_no_batch_dim) + with self.assertRaisesOpError( + r"mismatched shapes at component 0. Attempted to insert tensor " + r"with shape \[1\] but queue expected shape: \[\]"): + sess.run(enqueue_bad_shape) + with self.assertRaisesOpError( + r"mismatched dtypes at component 0. Attempted to insert tensor " + r"of type float but queue expected type: int32"): + sess.run(enqueue_bad_dtype) + + def testEnqueueWithPaddedBatchFailsWithInformativeError(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) + with self.assertRaisesRegexp( + TypeError, r"Unable to create padding for field of type 'variant'"): + dataset.padded_batch(batch_size=10, padded_shapes=[1]) + + def testOneEnqueueWithPadding(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) + # Make a dataset of variable-length vectors and their lengths. + dataset = dataset.map( + lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) + # Emit a queue we can prepend to, and counts/values as padded + # batch. + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=3)) + + iterator = dataset.make_one_shot_iterator() + queue, (count, padded_value) = iterator.get_next() + + # Split the padded_value into two pieces: head and rest + rest_indices = array_ops.squeeze(array_ops.where(count > 2), axis=1) + bound = math_ops.minimum(2, math_ops.reduce_max(count)) + value_head = padded_value[:, :bound] + count_rest = array_ops.gather(count - 2, rest_indices) + value_rest = array_ops.gather(padded_value, rest_indices)[:, bound:] + queue_rest = array_ops.gather(queue, rest_indices) + enqueue_rest_op = tqd.enqueue_in_queue_dataset(queue_rest, + (count_rest, value_rest)) + with ops.control_dependencies([enqueue_rest_op]): + calc = array_ops.identity(value_head) + + with self.test_session() as sess: + self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc)) + self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc)) + self.assertAllEqual([[6, 6]], sess.run(calc)) + self.assertAllEqual([[6, 6]], sess.run(calc)) + # Get some final batches due to prefetching. + for _ in range(3): + try: + self.assertAllEqual( + np.empty(shape=(0, 0), dtype=np.int32), sess.run(calc)) + except errors.OutOfRangeError as e: + self.assertTrue(str(e).startswith("End of sequence")) + + def testNonstandardPadding(self): + dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) + # Make a dataset of variable-length vectors and their lengths. + dataset = dataset.map( + lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) + # Emit a queue we can prepend to, and counts/values as padded + # batch. + dataset = dataset.apply( + tqd.prepend_from_queue_and_padded_batch_dataset( + batch_size=3, padding_values=( + 0, + -1, + ))) + + iterator = dataset.make_one_shot_iterator() + _, (unused_count, padded_value) = iterator.get_next() + + with self.test_session() as sess: + self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]], + sess.run(padded_value)) + self.assertAllEqual([[6] * 6], sess.run(padded_value)) + with self.assertRaisesOpError("End of sequence"): + sess.run(padded_value) + + +# TODO(ebrevdo): Figure out how to use run_core_tests to test state +# saving of an iterator that's had some tensors enqueued into its queue. +class PrependFromQueueAndPaddedBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testPrependFromQueueAndPaddedBatch(self): + + def build_dataset(seq_lens): + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + lambda x: array_ops.fill([x], x)).apply( + tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=4)) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + def testPrependFromQueueAndPaddedBatchNonDefaultPadding(self): + + def build_dataset(seq_lens): + + def fill_tuple(x): + filled = array_ops.fill([x], x) + return (filled, string_ops.as_string(filled)) + + padded_shape = [-1] + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + fill_tuple).apply( + tqd.prepend_from_queue_and_padded_batch_dataset( + batch_size=4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, ""))) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc index 2992a61ea8186caada394208e9c27ddffe896dd1..9675428e56e93c9669753371dbca47d56325b0c4 100644 --- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc @@ -142,9 +142,9 @@ Status ConvertConstantsToImmutable(const string& in_graph_filename, const auto load_graph_status = ReadBinaryProto(default_env, in_graph_filename, &graph_def); if (!load_graph_status.ok()) { - return tensorflow::errors::NotFound("Failed to load graph at '", - in_graph_filename, "' : ", - load_graph_status.error_message()); + return tensorflow::errors::NotFound( + "Failed to load graph at '", in_graph_filename, + "' : ", load_graph_status.error_message()); } NodeConverter node_converter; diff --git a/tensorflow/contrib/util/inspect_checkpoint.cc b/tensorflow/contrib/util/inspect_checkpoint.cc index 39088aeaad68e26344b2e89ce10ae6da8026e481..9b578ceb07548b8d198f64bc859d31c92774a286 100644 --- a/tensorflow/contrib/util/inspect_checkpoint.cc +++ b/tensorflow/contrib/util/inspect_checkpoint.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/tensor_slice_reader.h" namespace tensorflow { diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md index 1b99f4ce4f645d0c59b2552cf26f47495cbbba73..58fed4e5cb4c24b0f21dfe9b99cf4c665d2591c7 100644 --- a/tensorflow/contrib/verbs/README.md +++ b/tensorflow/contrib/verbs/README.md @@ -25,9 +25,9 @@ The design is based on TensorFlow r1.0. An RDMA path is added between servers fo During the server setup, an RDMA manager is created to manage low-level RDMA components such as RDMA channel and RDMA adapter, an RDMA rendezvous manager is created to oversee send/recv operations between servers. Following the distributed TensorFlow design philosophy, the send operation is passive, i.e. merely placing a tensor in the local out-going table. It is the receive operation that actually initiates the tensor transfer. TensorFlow dynamically allocates memory for tensors that are to be sent or received. This causes difficulty for RDMA operations where pinned memory is required. Few remedies are possible: -1. The memory is pinned, transfered, then unpinned for each and every tensor to be transferred. This incurs significant operation overhead since pinning and unpinning memory for each dynamically generated tensor is slow. +1. The memory is pinned, transferred, then unpinned for each and every tensor to be transferred. This incurs significant operation overhead since pinning and unpinning memory for each dynamically generated tensor is slow. 2. Buffer is pre-allocated and pinned for each tensor. This incurs large memory overhead and extra copying from the tensor to its pinned buffer, but may still be faster than the former. -3. Following HKUST research on the use of GPU direct, and their [GDR implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gdr/README.md), there is a smart way to benefit from the TensorFlow allocation theme which is mostly pool based, i.e allocators pre-allocate a large memory block, and allocate the tensors from there. By attaching a custom Visitor to relevant alloactors, we can do a single registration of the entire memory block, which zeros the registration overhead. Once the block is registered, each new tensor allocated will be at a registred address, which will allow us to do direct RDMA writes to it. +3. Following HKUST research on the use of GPU direct, and their [GDR implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gdr/README.md), there is a smart way to benefit from the TensorFlow allocation theme which is mostly pool based, i.e allocators pre-allocate a large memory block, and allocate the tensors from there. By attaching a custom Visitor to relevant allocators, we can do a single registration of the entire memory block, which zeros the registration overhead. Once the block is registered, each new tensor allocated will be at a registered address, which will allow us to do direct RDMA writes to it. For best performance, we will adopt HKUST 0 copies approach in our solution. This means: @@ -77,7 +77,7 @@ When the receiver receives the **RDMA_MESSAGE_META_DATA_RESPONSE**, it will loca 1. Update the local meta-data cache. 2. Reallocate the result/proxy tensors. -3. Re-send the tensor request. For tracability, the new message has a different name: **RDMA_MESSAGE_TENSOR_RE_REQUEST**. +3. Re-send the tensor request. For traceability, the new message has a different name: **RDMA_MESSAGE_TENSOR_RE_REQUEST**. When the sender receives a **RDMA_MESSAGE_TENSOR_RE_REQUEST**, it will locate the relevant **RdmaTensorResponse** using the request index specified in the message, and invoke its **Resume()** method, which will RDMA write the contents of the tensor that was cloned earlier, to the new remote address specified in the re-request. @@ -93,7 +93,7 @@ When the receiver receives the RDMA write, it will locate the relevant **RdmaTen 1. When the sender receives a tensor request, the source tensor may or may not be ready yet. The situation is handled through a process of tag matching: * If the request arrives before the tensor is ready, then a callback is put in a local table, and will be invoked once the tensor arrives. - * If the tensor is ready before the request arives, than the tensor is put in a local table. When the request arrives, it will invoke the callback immediatly. + * If the tensor is ready before the request arives, than the tensor is put in a local table. When the request arrives, it will invoke the callback immediately. In code it is done by calling **RecvLocalAsync()**, which receives the tensor's key, step-id, and the callback. 2. When the callback is invoked, the relevant tensor is removed from the tag matching table. In the case where we need to send the tensor's meta-data, the **RdmaTensorResponse** will store a copy of the tensor until the re-request arrives. 3. The sending of protocol messages (**RDMA_MESSAGE_TENSOR_REQUEST**, **RDMA_MESSAGE_META_DATA_RESPONSE** and **RDMA_MESSAGE_TENSOR_RE_REQUEST**) is done by the class **RdmaMessageBuffer**. All messages are sent using RDMA writes from/to fixed messages buffers. This implies that we cannot send on a specific channel more than one message at a time. In order to synchronize the messages, the **RdmaMessageBuffer** holds the a local and remote buffer statuses which can be either busy or idle. When a write is issued, both statuses will be changed to busy. When the write-complete event is received, the local status is changed to idle. When the write is received on the remote side, the remote side will parse the message, and return an ACK back to the sending side on which the sending side will update the remote status to idle. When both the local and remote statuses are idle, the next message can be sent. @@ -115,7 +115,7 @@ When the receiver receives the RDMA write, it will locate the relevant **RdmaTen * Reallocate the result tensor (and proxy tensor if required). * Re-send the request to the remote side. * **RecvTensorContent()** - Receive tensor content from the remote side (RDMA write was completed). - * Decode proto if required and/or move to GPU if the content was not written to it directly (GPU direct is not avaliable). + * Decode proto if required and/or move to GPU if the content was not written to it directly (GPU direct is not available). * Invoke the done callback. * **class RdmaTensorResponse** - Holds and manages information for a single tensor response throughout the entire send cycle. API: * **Start()** - Start the response sequence. @@ -153,7 +153,7 @@ When the receiver receives the RDMA write, it will locate the relevant **RdmaTen * request_index - Request index. * is_dead/data_type/tensor_shape/tensor_bytes - The up-to-date meta-data. * checksum - In data validation mode, this will hold the checksum of the source tensor. -* **RDMA_MESSAGE_TENSOR_RE_REQUEST** - (receiver ==> sender) Tensor re-requset after meta-data update and reallocation of result/proxy tensors. +* **RDMA_MESSAGE_TENSOR_RE_REQUEST** - (receiver ==> sender) Tensor re-request after meta-data update and reallocation of result/proxy tensors. * type - The message type. * name (name_size) - Name of the requested tensor. * step_id - Step ID. diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index 47ed83f521c5e6165c906ea557e74faf27df2112..1a0b5028febb7b11f979abd179a3227a2615252d 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -49,8 +49,8 @@ VerbsServer::~VerbsServer() { Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, GrpcChannelCache** channel_cache) { string name_prefix = - strings::StrCat("/job:", server_def.job_name(), "/replica:0", "/task:", - server_def.task_index()); + strings::StrCat("/job:", server_def.job_name(), "/replica:0", + "/task:", server_def.task_index()); GrpcChannelSpec channel_spec; TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 94973a0e520e494ce2ccc947a803e10681ff5e21..30ac270109dbbb77cd6a400d1feaa1ac116456c1 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -193,6 +193,7 @@ CORE_PROTO_SRCS = [ "protobuf/rewriter_config.proto", "protobuf/tensor_bundle.proto", "protobuf/saver.proto", + "util/event.proto", "util/memmapped_file_system.proto", "util/saved_tensor_slice.proto", ] @@ -211,7 +212,6 @@ ADDITIONAL_CORE_PROTO_SRCS = [ "protobuf/named_tensor.proto", "protobuf/saved_model.proto", "protobuf/tensorflow_server.proto", - "util/event.proto", "util/test_log.proto", ] @@ -376,12 +376,22 @@ cc_library( hdrs = ["platform/abi.h"], ) +cc_library( + name = "session_message", + srcs = ["util/session_message.cc"], + hdrs = ["util/session_message.h"], + deps = [ + ":framework", + ":lib", + ":protos_all_cc", + ], +) + cc_library( name = "stacktrace_handler", srcs = ["platform/stacktrace_handler.cc"], hdrs = ["platform/stacktrace_handler.h"], deps = [ - ":abi", ":lib", ":lib_platform", ], @@ -434,6 +444,7 @@ tf_cuda_library( "framework/common_shape_fns.h", "framework/control_flow.h", # TODO(josh11b): Make internal? "framework/dataset.h", + "framework/dataset_stateful_op_whitelist.h", "framework/device_base.h", "framework/function.h", "framework/graph_def_util.h", @@ -454,6 +465,7 @@ tf_cuda_library( "framework/reader_interface.h", "framework/reader_op_kernel.h", "framework/register_types.h", + "framework/register_types_traits.h", "framework/resource_mgr.h", "framework/resource_op_kernel.h", "framework/selective_registration.h", @@ -611,6 +623,7 @@ tf_gen_op_libs( "list_ops", "lookup_ops", "logging_ops", + "manip_ops", "math_ops", "nn_ops", "no_op", @@ -693,6 +706,7 @@ cc_library( ":list_ops_op_lib", ":logging_ops_op_lib", ":lookup_ops_op_lib", + ":manip_ops_op_lib", ":math_ops_op_lib", ":nn_ops_op_lib", ":no_op_op_lib", @@ -784,6 +798,7 @@ tf_cuda_library( "graph/graph.h", "graph/graph_constructor.h", "graph/graph_def_builder.h", + "graph/graph_def_builder_util.h", "graph/node_builder.h", "graph/validate.h", "graph/while_context.h", @@ -823,6 +838,7 @@ cc_library( "//tensorflow/core/kernels:dataset_ops", "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:functional_ops", "//tensorflow/core/kernels:histogram_op", "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", @@ -830,6 +846,7 @@ cc_library( "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", + "//tensorflow/core/kernels:manip", "//tensorflow/core/kernels:math", "//tensorflow/core/kernels:multinomial_op", "//tensorflow/core/kernels:nn", @@ -1152,6 +1169,7 @@ cc_library( deps = [ ":protos_all_cc_impl", "//third_party/eigen3", + "@nsync//:nsync_cpp", "@protobuf_archive//:protobuf", ], alwayslink = 1, @@ -1317,6 +1335,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "framework/kernel_def_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "framework/kernel_def.proto", + visibility = ["//visibility:public"], +) + tf_pyclif_proto_library( name = "framework/node_def_pyclif", proto_lib = ":protos_all_cc", @@ -1352,6 +1377,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "protobuf/device_properties_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "protobuf/device_properties.proto", + visibility = ["//visibility:public"], +) + # ----------------------------------------------------------------------------- # Internal targets @@ -1711,6 +1743,9 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ "platform/variant_coding.h", "graph/edgeset.h", "graph/graph.h", + "graph/graph_def_builder.h", + "graph/node_builder.h", + "graph/tensor_id.h", ] + glob( [ "example/**/*.h", @@ -1728,6 +1763,7 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ "framework/reader_base.*", "util/memmapped_file_system.*", "util/memmapped_file_system_writer.*", + "util/session_message.*", "util/version_info.cc", ], ) + select({ @@ -1797,6 +1833,9 @@ tf_cuda_library( ] + [ "graph/edgeset.cc", "graph/graph.cc", + "graph/graph_def_builder.cc", + "graph/node_builder.cc", + "graph/tensor_id.cc", "graph/while_context.h", "graph/while_context.cc", ], @@ -1811,6 +1850,7 @@ tf_cuda_library( "framework/resource_handle.cc", "util/memmapped_file_system.*", "util/memmapped_file_system_writer.*", + "util/session_message.cc", "util/version_info.cc", ], ) + select({ @@ -1896,6 +1936,13 @@ cc_library( ], ) +tf_cuda_library( + name = "cuda_device_functions", + hdrs = ["util/cuda_device_functions.h"], + visibility = ["//visibility:public"], + deps = [":framework_lite"], +) + # TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"? cc_library( name = "protos_cc", @@ -1918,6 +1965,7 @@ GRAPH_HDRS = [ "graph/graph.h", "graph/graph_constructor.h", # NOTE(mrry): Don't include the .cc since it depends on common_runtime. "graph/graph_def_builder.h", + "graph/graph_def_builder_util.h", "graph/graph_partition.h", "graph/mkl_layout_pass.h", "graph/mkl_tfconversion_pass.h", @@ -1938,12 +1986,9 @@ tf_cuda_library( "graph/colors.cc", "graph/control_flow.cc", "graph/costmodel.cc", - "graph/graph_def_builder.cc", "graph/graph_partition.cc", - "graph/node_builder.cc", "graph/optimizer_cse.cc", "graph/subgraph.cc", - "graph/tensor_id.cc", "graph/validate.cc", ], hdrs = GRAPH_HDRS, @@ -1972,6 +2017,7 @@ tf_cuda_library( "common_runtime/shape_refiner.h", "framework/versions.h", "graph/graph_constructor.cc", # Depends on common_runtime. + "graph/graph_def_builder_util.cc", # Depends on common_runtime. "public/session.h", "public/session_options.h", "public/version.h", @@ -2216,12 +2262,25 @@ tf_cuda_library( ] + tf_additional_device_tracer_deps(), ) +cc_library( + name = "gpu_id", + srcs = ["common_runtime/gpu/gpu_id_manager.cc"], + hdrs = [ + "common_runtime/gpu/gpu_id.h", + "common_runtime/gpu/gpu_id_manager.h", + ], + deps = [ + ":lib", + ], +) + GPU_RUNTIME_HEADERS = [ "common_runtime/gpu/gpu_bfc_allocator.h", "common_runtime/gpu/gpu_cudamalloc_allocator.h", "common_runtime/gpu/gpu_debug_allocator.h", "common_runtime/gpu/gpu_device.h", "common_runtime/gpu/gpu_id.h", + "common_runtime/gpu/gpu_id_manager.h", "common_runtime/gpu/gpu_id_utils.h", "common_runtime/gpu/gpu_init.h", "common_runtime/gpu/gpu_managed_allocator.h", @@ -2240,7 +2299,6 @@ tf_cuda_library( "common_runtime/gpu/gpu_debug_allocator.cc", "common_runtime/gpu/gpu_device.cc", "common_runtime/gpu/gpu_device_factory.cc", - "common_runtime/gpu/gpu_id_utils.cc", "common_runtime/gpu/gpu_managed_allocator.cc", "common_runtime/gpu/gpu_stream_util.cc", "common_runtime/gpu/gpu_util.cc", @@ -2255,6 +2313,7 @@ tf_cuda_library( ":core_cpu_lib", ":framework", ":framework_internal", + ":gpu_id", ":gpu_init_impl", ":gpu_lib", ":graph", @@ -2406,7 +2465,6 @@ cc_library( deps = [ ":lib", ":lib_internal", - ":stacktrace_handler", ":test", # buildcleaner: keep "//tensorflow/core/platform/default/build_config:test_main", ], @@ -2845,6 +2903,7 @@ tf_cc_tests_gpu( linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":gpu_headers_lib", + ":gpu_id", ":gpu_runtime", ":test", ], @@ -2856,7 +2915,7 @@ tf_cc_tests_gpu( srcs = glob(["user_ops/**/*_test.cc"]) + [ "common_runtime/gpu/gpu_bfc_allocator_test.cc", "common_runtime/gpu/gpu_device_test.cc", - "common_runtime/gpu/gpu_id_utils_test.cc", + "common_runtime/gpu/gpu_id_manager_test.cc", "common_runtime/gpu/gpu_event_mgr_test.cc", "common_runtime/gpu/pool_allocator_test.cc", ], @@ -2868,6 +2927,7 @@ tf_cc_tests_gpu( ":direct_session", ":framework", ":framework_internal", + ":gpu_id", ":gpu_runtime", ":lib", ":lib_internal", @@ -3263,6 +3323,7 @@ tf_cc_test_gpu( ":direct_session", ":framework", ":framework_internal", + ":gpu_id", ":gpu_runtime", ":lib", ":lib_internal", diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 81187ff6b772633105e0962d9da8f87d6cfd9558..58dbac4e8edac7079d315fbfcdafbd136793df0b 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -96,6 +96,7 @@ tf_cc_test( srcs = ["api_test.cc"], data = [ ":base_api_def", + ":python_api_def", ], deps = [ ":excluded_ops_lib", diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 112c55ccc3ba1262b48c1b6c0890b3ae22744383..477a0b670e49f8aa4ee8c250d4957886eb865ed5 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -41,8 +41,9 @@ namespace tensorflow { namespace { constexpr char kDefaultApiDefDir[] = "tensorflow/core/api_def/base_api"; +constexpr char kPythonApiDefDir[] = + "tensorflow/core/api_def/python_api"; constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt"; -} // namespace // Reads golden ApiDef files and returns a map from file name to ApiDef file // contents. @@ -66,9 +67,93 @@ void GetGoldenApiDefs(Env* env, const string& api_files_dir, } } -class ApiTest : public ::testing::Test { +void TestAllApiDefsHaveCorrespondingOp( + const OpList& ops, const std::unordered_map& api_defs_map) { + std::unordered_set op_names; + for (const auto& op : ops.op()) { + op_names.insert(op.name()); + } + for (const auto& name_and_api_def : api_defs_map) { + ASSERT_TRUE(op_names.find(name_and_api_def.first) != op_names.end()) + << name_and_api_def.first << " op has ApiDef but missing from ops. " + << "Does api_def_" << name_and_api_def.first << " need to be deleted?"; + } +} + +void TestAllApiDefInputArgsAreValid( + const OpList& ops, const std::unordered_map& api_defs_map) { + for (const auto& op : ops.op()) { + const auto api_def_iter = api_defs_map.find(op.name()); + if (api_def_iter == api_defs_map.end()) { + continue; + } + const auto& api_def = api_def_iter->second; + for (const auto& api_def_arg : api_def.in_arg()) { + bool found_arg = false; + for (const auto& op_arg : op.input_arg()) { + if (api_def_arg.name() == op_arg.name()) { + found_arg = true; + break; + } + } + ASSERT_TRUE(found_arg) + << "Input argument " << api_def_arg.name() + << " (overwritten in api_def_" << op.name() + << ".pbtxt) is not defined in OpDef for " << op.name(); + } + } +} + +void TestAllApiDefOutputArgsAreValid( + const OpList& ops, const std::unordered_map& api_defs_map) { + for (const auto& op : ops.op()) { + const auto api_def_iter = api_defs_map.find(op.name()); + if (api_def_iter == api_defs_map.end()) { + continue; + } + const auto& api_def = api_def_iter->second; + for (const auto& api_def_arg : api_def.out_arg()) { + bool found_arg = false; + for (const auto& op_arg : op.output_arg()) { + if (api_def_arg.name() == op_arg.name()) { + found_arg = true; + break; + } + } + ASSERT_TRUE(found_arg) + << "Output argument " << api_def_arg.name() + << " (overwritten in api_def_" << op.name() + << ".pbtxt) is not defined in OpDef for " << op.name(); + } + } +} + +void TestAllApiDefAttributeNamesAreValid( + const OpList& ops, const std::unordered_map& api_defs_map) { + for (const auto& op : ops.op()) { + const auto api_def_iter = api_defs_map.find(op.name()); + if (api_def_iter == api_defs_map.end()) { + continue; + } + const auto& api_def = api_def_iter->second; + for (const auto& api_def_attr : api_def.attr()) { + bool found_attr = false; + for (const auto& op_attr : op.attr()) { + if (api_def_attr.name() == op_attr.name()) { + found_attr = true; + } + } + ASSERT_TRUE(found_attr) + << "Attribute " << api_def_attr.name() << " (overwritten in api_def_" + << op.name() << ".pbtxt) is not defined in OpDef for " << op.name(); + } + } +} +} // namespace + +class BaseApiTest : public ::testing::Test { protected: - ApiTest() { + BaseApiTest() { OpRegistry::Global()->Export(false, &ops_); const std::vector multi_line_fields = {"description"}; @@ -80,7 +165,7 @@ class ApiTest : public ::testing::Test { }; // Check that all ops have an ApiDef. -TEST_F(ApiTest, AllOpsAreInApiDef) { +TEST_F(BaseApiTest, AllOpsAreInApiDef) { auto* excluded_ops = GetExcludedOps(); for (const auto& op : ops_.op()) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { @@ -94,16 +179,8 @@ TEST_F(ApiTest, AllOpsAreInApiDef) { } // Check that ApiDefs have a corresponding op. -TEST_F(ApiTest, AllApiDefsHaveCorrespondingOp) { - std::unordered_set op_names; - for (const auto& op : ops_.op()) { - op_names.insert(op.name()); - } - for (const auto& name_and_api_def : api_defs_map_) { - ASSERT_TRUE(op_names.find(name_and_api_def.first) != op_names.end()) - << name_and_api_def.first << " op has ApiDef but missing from ops. " - << "Does api_def_" << name_and_api_def.first << " need to be deleted?"; - } +TEST_F(BaseApiTest, AllApiDefsHaveCorrespondingOp) { + TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_); } string GetOpDefHasDocStringError(const string& op_name) { @@ -117,7 +194,7 @@ string GetOpDefHasDocStringError(const string& op_name) { // Check that OpDef's do not have descriptions and summaries. // Descriptions and summaries must be in corresponding ApiDefs. -TEST_F(ApiTest, OpDefsShouldNotHaveDocs) { +TEST_F(BaseApiTest, OpDefsShouldNotHaveDocs) { auto* excluded_ops = GetExcludedOps(); for (const auto& op : ops_.op()) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { @@ -143,62 +220,56 @@ TEST_F(ApiTest, OpDefsShouldNotHaveDocs) { // Checks that input arg names in an ApiDef match input // arg names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefInputArgsAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_arg : api_def.in_arg()) { - bool found_arg = false; - for (const auto& op_arg : op.input_arg()) { - if (api_def_arg.name() == op_arg.name()) { - found_arg = true; - break; - } - } - ASSERT_TRUE(found_arg) - << "Input argument " << api_def_arg.name() - << " (overwritten in api_def_" << op.name() - << ".pbtxt) is not defined in OpDef for " << op.name(); - } - } +TEST_F(BaseApiTest, AllApiDefInputArgsAreValid) { + TestAllApiDefInputArgsAreValid(ops_, api_defs_map_); } // Checks that output arg names in an ApiDef match output // arg names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefOutputArgsAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_arg : api_def.out_arg()) { - bool found_arg = false; - for (const auto& op_arg : op.output_arg()) { - if (api_def_arg.name() == op_arg.name()) { - found_arg = true; - break; - } - } - ASSERT_TRUE(found_arg) - << "Output argument " << api_def_arg.name() - << " (overwritten in api_def_" << op.name() - << ".pbtxt) is not defined in OpDef for " << op.name(); - } - } +TEST_F(BaseApiTest, AllApiDefOutputArgsAreValid) { + TestAllApiDefOutputArgsAreValid(ops_, api_defs_map_); } // Checks that attribute names in an ApiDef match attribute // names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefAttributeNamesAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_attr : api_def.attr()) { - bool found_attr = false; - for (const auto& op_attr : op.attr()) { - if (api_def_attr.name() == op_attr.name()) { - found_attr = true; - } - } - ASSERT_TRUE(found_attr) - << "Attribute " << api_def_attr.name() << " (overwritten in api_def_" - << op.name() << ".pbtxt) is not defined in OpDef for " << op.name(); - } +TEST_F(BaseApiTest, AllApiDefAttributeNamesAreValid) { + TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); +} + +class PythonApiTest : public ::testing::Test { + protected: + PythonApiTest() { + OpRegistry::Global()->Export(false, &ops_); + const std::vector multi_line_fields = {"description"}; + + Env* env = Env::Default(); + GetGoldenApiDefs(env, kPythonApiDefDir, &api_defs_map_); } + OpList ops_; + std::unordered_map api_defs_map_; +}; + +// Check that ApiDefs have a corresponding op. +TEST_F(PythonApiTest, AllApiDefsHaveCorrespondingOp) { + TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_); } + +// Checks that input arg names in an ApiDef match input +// arg names in corresponding OpDef. +TEST_F(PythonApiTest, AllApiDefInputArgsAreValid) { + TestAllApiDefInputArgsAreValid(ops_, api_defs_map_); +} + +// Checks that output arg names in an ApiDef match output +// arg names in corresponding OpDef. +TEST_F(PythonApiTest, AllApiDefOutputArgsAreValid) { + TestAllApiDefOutputArgsAreValid(ops_, api_defs_map_); +} + +// Checks that attribute names in an ApiDef match attribute +// names in corresponding OpDef. +TEST_F(PythonApiTest, AllApiDefAttributeNamesAreValid) { + TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); +} + } // namespace tensorflow diff --git a/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt index 5d21d7bab699ff481c65ed44eb9bf66ec14ea387..ac05b54eea95f70e4a6db843aab13adf7b94602c 100644 --- a/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_AssignAddVariableOp.pbtxt @@ -20,10 +20,7 @@ END } summary: "Adds a value to the current value of a variable." description: < [3, 4, 0, 1, 2] + +# shifting along multiple dimensions +# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]] + +# shifting along the same axis multiple times +# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] +``` +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt index db890cb2f51256fd9dabaa8aa590ccde37eec343..5e2912fcdd7324f219b430860784903f85f31dca 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt @@ -16,7 +16,7 @@ END } summary: "Computes the maximum along segments of a tensor." description: <::min()`. +If the maximum is empty for a given segment ID `i`, it outputs the smallest +possible value for the specific numeric type, +`output[i] = numeric_limits::lowest()`.
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..55ea69b5dd5f7fda5c877ca5771ec2cbb86e3a9a --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt @@ -0,0 +1,33 @@ +op { + graph_op_name: "UnsortedSegmentMin" + in_arg { + name: "segment_ids" + description: <::max()`. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..577ff53d60c5a174b4ba43a667885a6983b2dfb9 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt @@ -0,0 +1,32 @@ +op { + graph_op_name: "UnsortedSegmentProd" + in_arg { + name: "segment_ids" + description: <requested_size; } -size_t BFCAllocator::AllocatedSize(void* ptr) { +size_t BFCAllocator::AllocatedSize(const void* ptr) { mutex_lock l(lock_); BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr); CHECK(h != kInvalidChunkHandle) @@ -539,7 +539,7 @@ size_t BFCAllocator::AllocatedSize(void* ptr) { return c->size; } -int64 BFCAllocator::AllocationId(void* ptr) { +int64 BFCAllocator::AllocationId(const void* ptr) { mutex_lock l(lock_); BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr); CHECK(h != kInvalidChunkHandle) diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h index 3dd011a58e4724a8db34703ec68055c3a3a26fa3..b8e773503c7a2f8024e8a6f58247ad343a762f71 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.h +++ b/tensorflow/core/common_runtime/bfc_allocator.h @@ -62,11 +62,11 @@ class BFCAllocator : public VisitableAllocator { bool TracksAllocationSizes() override; - size_t RequestedSize(void* ptr) override; + size_t RequestedSize(const void* ptr) override; - size_t AllocatedSize(void* ptr) override; + size_t AllocatedSize(const void* ptr) override; - int64 AllocationId(void* ptr) override; + int64 AllocationId(const void* ptr) override; void GetStats(AllocatorStats* stats) override; @@ -127,10 +127,10 @@ class BFCAllocator : public VisitableAllocator { string DebugString(BFCAllocator* a, bool recurse) NO_THREAD_SAFETY_ANALYSIS { string dbg; - strings::StrAppend(&dbg, " Size: ", strings::HumanReadableNumBytes(size), - " | Requested Size: ", - strings::HumanReadableNumBytes(requested_size), - " | in_use: ", in_use()); + strings::StrAppend( + &dbg, " Size: ", strings::HumanReadableNumBytes(size), + " | Requested Size: ", strings::HumanReadableNumBytes(requested_size), + " | in_use: ", in_use()); if (recurse && prev != BFCAllocator::kInvalidChunkHandle) { Chunk* p = a->ChunkFromHandle(prev); strings::StrAppend(&dbg, ", prev: ", p->DebugString(a, false)); diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 0398c2a60d1fe4dfeed91e242272f13dd45389b2..b5a51d2526d95313d4564337ae0420472bc0b3da 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -328,7 +328,8 @@ void FindConstantFoldableNodes( ConsiderConstantFoldableNode( n, opts, nodes, constant_control_deps, shape_replacement_map, &internal_node_inserted); - }); + }, + NodeComparatorName()); // If we have inserted just leaf level nodes, then there is nothing to fold. if (!internal_node_inserted) { nodes->clear(); @@ -339,8 +340,8 @@ void FindConstantFoldableNodes( typedef std::pair NodeAndOutput; int64 UniqueConstantId() { - static std::atomic_int_fast64_t id; - return id.fetch_add(1); + static std::atomic_int_fast64_t unique_constant_id; + return unique_constant_id.fetch_add(1); } // Adds n to constant_graph which is being built up for subsequent evaluation of @@ -386,14 +387,12 @@ void AddShapeNodeToConstantGraph( const std::unordered_map>& shape_replacement_map, std::unordered_map>* node_map, - Graph* constant_graph) { + const ConstantFoldNameGenerator& generate_new_name, Graph* constant_graph) { std::vector& added = (*node_map)[n]; const string& node_name = n->name(); for (const Tensor& t : shape_replacement_map.at(n)) { auto builder = - NodeDefBuilder(strings::StrCat(constant_graph->NewName(node_name), - "__cf__", UniqueConstantId()), - "Const") + NodeDefBuilder(generate_new_name(constant_graph, node_name), "Const") .Attr("dtype", t.dtype()) .Attr("value", t); NodeDef def; @@ -414,7 +413,8 @@ Graph* GetConstantGraph( const Graph* orig_graph, const std::vector& nodes, const std::unordered_map>& shape_replacement_map, - std::map* tensors_to_fetch) { + std::map* tensors_to_fetch, + const ConstantFoldNameGenerator& generate_new_name) { Graph* constant_graph = new Graph(orig_graph->op_registry()); std::unordered_map> node_map; node_map[orig_graph->source_node()] = {constant_graph->source_node()}; @@ -424,7 +424,7 @@ Graph* GetConstantGraph( AddNodeToConstantGraph(n, &node_map, constant_graph); } else { AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map, - constant_graph); + generate_new_name, constant_graph); } } @@ -458,10 +458,11 @@ Graph* GetConstantGraph( // replacement was successful, false otherwise. // 'control_deps' is the set of nodes that should be control predecessors of the // new constant node. -bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device, - NodeAndOutput tensor, const Tensor& constant, - const gtl::FlatSet& control_deps, - int64 max_constant_size_in_bytes) { +bool ReplaceTensorWithConstant( + Graph* graph, Device* partition_device, NodeAndOutput tensor, + const Tensor& constant, const gtl::FlatSet& control_deps, + int64 max_constant_size_in_bytes, + const ConstantFoldNameGenerator& generate_new_name) { // Be conservative when replacing a tensor with a constant, when not // running on CPU. // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY @@ -509,9 +510,7 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device, } const string& node_name = n->name(); Node* constant_node; - auto builder = NodeDefBuilder(strings::StrCat(graph->NewName(node_name), - "__cf__", UniqueConstantId()), - "Const") + auto builder = NodeDefBuilder(generate_new_name(graph, node_name), "Const") .Attr("dtype", constant.dtype()) .Attr("value", constant); if (partition_device) { @@ -555,6 +554,13 @@ Status ConstantFold(const ConstantFoldingOptions& opts, FunctionLibraryRuntime* function_library, Env* env, Device* partition_device, Graph* graph, bool* was_mutated) { DumpGraph("Before", graph); + ConstantFoldNameGenerator generate_new_name = opts.generate_new_name; + if (generate_new_name == nullptr) { + generate_new_name = [](Graph* graph, string old_name) { + return strings::StrCat(graph->NewName(old_name), "__cf__", + UniqueConstantId()); + }; + } std::vector constant_foldable_nodes; std::unordered_map> constant_control_deps; @@ -571,7 +577,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts, std::map tensors_to_fetch; std::unique_ptr constant_graph( GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map, - &tensors_to_fetch)); + &tensors_to_fetch, generate_new_name)); DumpGraph("Constant graph", constant_graph.get()); if (tensors_to_fetch.empty()) { @@ -585,7 +591,16 @@ Status ConstantFold(const ConstantFoldingOptions& opts, std::vector tensors_to_fetch_names; std::vector tensors_to_replace; - for (auto n : tensors_to_fetch) { + // Sorting the nodes based on the name gives us a stable ordering between runs + // for the same graph. + std::vector> tensors_to_fetch_sorted( + tensors_to_fetch.begin(), tensors_to_fetch.end()); + std::sort(tensors_to_fetch_sorted.begin(), tensors_to_fetch_sorted.end(), + [](const std::pair& n1, + const std::pair& n2) { + return n1.first.first->name() < n2.first.first->name(); + }); + for (auto n : tensors_to_fetch_sorted) { tensors_to_fetch_names.push_back( strings::StrCat(n.first.first->name(), ":", n.first.second)); tensors_to_replace.push_back({n.second, n.first.second}); @@ -617,7 +632,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts, constant_control_deps[tensors_to_replace[c].first]; if (ReplaceTensorWithConstant( graph, partition_device, tensors_to_replace[c], outputs[c], - control_deps, opts.max_constant_size_in_bytes)) { + control_deps, opts.max_constant_size_in_bytes, generate_new_name)) { ++num_nodes_replaced; } } diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h index e4d724c58a25347db3e40a0d024acf1ac97ea575..b1e1fb831963bccb81731752ec76b9d5be123d9f 100644 --- a/tensorflow/core/common_runtime/constant_folding.h +++ b/tensorflow/core/common_runtime/constant_folding.h @@ -24,6 +24,11 @@ limitations under the License. namespace tensorflow { +// This generator type is used to generate a name for the newly folded node +// based on the node's old name. +using ConstantFoldNameGenerator = + std::function; + // Options specific to constant folding optimizations. struct ConstantFoldingOptions { // If "consider" is not a nullptr, then only constant fold a node "n" if @@ -37,6 +42,11 @@ struct ConstantFoldingOptions { // The maximum size of each constant created during constant folding // optimization. int64 max_constant_size_in_bytes = 10 * 1024 * 1024; + + // A generator for the name suffix of constant folded nodes. A + // default id generator that monotonically increases is used if nullptr is + // passed. + ConstantFoldNameGenerator generate_new_name = nullptr; }; // Perform constant folding optimization on "graph". diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index 923a4d924936386ce0e06c6355c2a4d0af5cc4a4..6ac9319ad1e2c4953c2d82257dac6a3aeeffcd5c 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -121,6 +121,58 @@ TEST_F(ConstantFoldingTest, Basic) { {2, 2}); } +// Tests that different node creation ordering creates same graph after constant +// folding. +TEST_F(ConstantFoldingTest, DeterministicFolding) { + auto build_graph_and_constant_folding = [](Graph& g, bool swap) -> Status { + Scope s = Scope::NewRootScope(); + auto a = ops::Const(s, {1.0}, {}); + auto b = ops::Const(s, {2.0}, {}); + + if (swap) { + auto add1 = ops::Add(s.WithOpName("add1"), a, b); + auto add2 = ops::Add(s.WithOpName("add2"), a, b); + auto s1 = + ops::_Send(s.WithOpName("s1"), add1, "add1", "sender", 0, "receiver"); + auto s2 = + ops::_Send(s.WithOpName("s2"), add2, "add2", "sender", 0, "receiver"); + } else { + // Swap the order of node creation. + auto add2 = ops::Add(s.WithOpName("add2"), a, b); + auto add1 = ops::Add(s.WithOpName("add1"), a, b); + auto s1 = + ops::_Send(s.WithOpName("s1"), add1, "add1", "sender", 0, "receiver"); + auto s2 = + ops::_Send(s.WithOpName("s2"), add2, "add2", "sender", 0, "receiver"); + } + + TF_CHECK_OK(s.ToGraph(&g)); + bool was_mutated; + int64 unique_id = 0; + auto generate_new_name = [&unique_id](Graph* graph, string old_name) { + return strings::StrCat(graph->NewName(old_name), "__cf__", unique_id++); + }; + ConstantFoldingOptions opt{}; + opt.generate_new_name = generate_new_name; + TF_CHECK_OK( + ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated)); + return Status::OK(); + }; + + Graph g1(OpRegistry::Global()); + TF_ASSERT_OK(build_graph_and_constant_folding(g1, false)); + Graph g2(OpRegistry::Global()); + TF_ASSERT_OK(build_graph_and_constant_folding(g2, true)); + EXPECT_EQ(g1.num_nodes(), g2.num_nodes()); + auto index = NodeNameIndex(g2); + + // All the nodes in g1 are expected to be present in g2. + for (int64 i = 0; i < g1.num_nodes(); ++i) { + Node* n1 = g1.FindNodeId(i); + EXPECT_GT(index.count(n1->name()), 0); + } +} + TEST_F(ConstantFoldingTest, ConsiderFunction) { Scope s = Scope::NewRootScope(); BuildSimpleGraph(&s); diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc index 0507076c8c3734083ac0ef7ffea0edebf180ad1a..fd9c4222a7afd4914415c9c62e1ced118ea75d1f 100644 --- a/tensorflow/core/common_runtime/device_set_test.cc +++ b/tensorflow/core/common_runtime/device_set_test.cc @@ -88,7 +88,9 @@ TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) { // D3 is prioritized below D1. AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0"); EXPECT_EQ((std::vector{ - DeviceType("d2"), DeviceType("d1"), DeviceType("d3"), + DeviceType("d2"), + DeviceType("d1"), + DeviceType("d3"), }), types()); } diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 20c59ad42b35865f6fdad60e8bc8ac5ffebc4415..ecbffcbf6c4030bde82f2abe0e7779bf9c5a9870 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -61,7 +61,6 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/env_var.h" - namespace tensorflow { namespace { @@ -472,9 +471,9 @@ Status DirectSession::Run(const RunOptions& run_options, Executor::Args args; args.step_id = step_id_counter_.fetch_add(1); - TF_RETURN_IF_ERROR( - GetOrCreateExecutors(input_tensor_names, output_names, target_nodes, - &executors_and_keys, &run_state_args)); + TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names, + target_nodes, &executors_and_keys, + &run_state_args)); const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1); std::unique_ptr debugger_state; @@ -1251,7 +1250,7 @@ Status DirectSession::GetOrCreateExecutors( item->device = device; Executor* executor; TF_RETURN_IF_ERROR( - NewLocalExecutor(params, partition_graph.release(), &executor)); + NewLocalExecutor(params, std::move(partition_graph), &executor)); item->executor.reset(executor); } diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 99b33e2ef0d532aca08dfb538857d347d22a7351..b75a4f76d94f704cf38a6c4657b6089a863c085f 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -436,10 +436,7 @@ TEST(DirectSessionTest, FetchMultipleTimes) { } } -REGISTER_OP("Darth") - .Input("x: float") - .Output("y: float") - .Doc(R"doc( +REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc( Darth promises one return value. x: float @@ -972,39 +969,38 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib, std::atomic num_done(0); // Runs session to compute :0 using inter_op thread pool . - auto add_session_run_call = [use_global_pools, &def, &options, &sessions, - &sessions_mu, - &num_done](thread::ThreadPool* tp, Node* node, - int inter_op_pool) { - auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu, - inter_op_pool, node, &num_done]() { - RunOptions run_options; - run_options.set_inter_op_thread_pool(inter_op_pool); - std::vector outputs; - - Session* session; - if (use_global_pools) { - std::unique_ptr s(NewSession(options)); - TF_ASSERT_OK(s->Create(def)); - session = s.get(); - - mutex_lock l(sessions_mu); - sessions.emplace_back(std::move(s)); - } else { - session = sessions[0].get(); - } + auto add_session_run_call = + [use_global_pools, &def, &options, &sessions, &sessions_mu, &num_done]( + thread::ThreadPool* tp, Node* node, int inter_op_pool) { + auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu, + inter_op_pool, node, &num_done]() { + RunOptions run_options; + run_options.set_inter_op_thread_pool(inter_op_pool); + std::vector outputs; + + Session* session; + if (use_global_pools) { + std::unique_ptr s(NewSession(options)); + TF_ASSERT_OK(s->Create(def)); + session = s.get(); + + mutex_lock l(sessions_mu); + sessions.emplace_back(std::move(s)); + } else { + session = sessions[0].get(); + } - Status s = session->Run(run_options, {} /* inputs */, - {node->name() + ":0"} /* output_names */, {}, - &outputs, nullptr /* run_metadata */); - TF_CHECK_OK(s); - ASSERT_EQ(1, outputs.size()); - auto flat = outputs[0].flat(); - EXPECT_FLOAT_EQ(1.2, flat(0)); - num_done.fetch_add(1); - }; - tp->Schedule(fn); - }; + Status s = session->Run(run_options, {} /* inputs */, + {node->name() + ":0"} /* output_names */, {}, + &outputs, nullptr /* run_metadata */); + TF_CHECK_OK(s); + ASSERT_EQ(1, outputs.size()); + auto flat = outputs[0].flat(); + EXPECT_FLOAT_EQ(1.2, flat(0)); + num_done.fetch_add(1); + }; + tp->Schedule(fn); + }; // For blocking states: // - Starts at 0, BlockingOp::Compute will move to 1. diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index df9cf0c91f1b7e5521061b6915fc1b7ed609e003..31fb128f937ae46eefb309fc9bab8167e54846a7 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -161,14 +161,14 @@ static void TestHWAccelerator(bool enableHWTrace) { x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); #ifdef TENSORFLOW_USE_SYCL x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // y = A * x Node* y = test::graph::Matmul(&graph, a, x, false, false); y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); #ifdef TENSORFLOW_USE_SYCL -y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); -#endif // TENSORFLOW_USE_SYCL + y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); +#endif // TENSORFLOW_USE_SYCL Node* y_neg = test::graph::Unary(&graph, "Neg", y); y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); @@ -181,7 +181,7 @@ y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); (*options.config.mutable_device_count())["GPU"] = 1; #ifdef TENSORFLOW_USE_SYCL (*options.config.mutable_device_count())["SYCL"] = 1; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL options.config.set_allow_soft_placement(true); options.config.mutable_graph_options()->set_build_cost_model(1); std::unique_ptr session(NewSession(options)); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 9d03caff1e1e89c4c667f94853352580545e70e5..b06b75d6585f01640374eb7ab9842bf441cf9411 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -172,7 +172,7 @@ void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) { stats->AddAllocation(allocator_pair.first, allocator_pair.second); } auto* ms = stats->stats()->mutable_memory_stats(); - ms->set_temp_memory_size(ctx->temp_memory_size()); + ms->set_temp_memory_size(ctx->temp_memory_allocated()); for (const auto& alloc_id : ctx->persistent_alloc_ids()) { ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); } @@ -332,8 +332,8 @@ class GraphView { class ExecutorImpl : public Executor { public: - ExecutorImpl(const LocalExecutorParams& p, const Graph* g) - : params_(p), graph_(g), gview_() { + ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr g) + : params_(p), graph_(std::move(g)), gview_() { CHECK(p.create_kernel != nullptr); CHECK(p.delete_kernel != nullptr); } @@ -348,7 +348,6 @@ class ExecutorImpl : public Executor { for (auto fiter : frame_info_) { delete fiter.second; } - delete graph_; } Status Initialize(); @@ -412,7 +411,7 @@ class ExecutorImpl : public Executor { // Owned. LocalExecutorParams params_; - const Graph* graph_; + std::unique_ptr graph_; GraphView gview_; // A cached value of params_ @@ -605,11 +604,11 @@ void GetMaxPendingCounts(const Node* n, size_t* max_pending, } Status ExecutorImpl::Initialize() { - gview_.Initialize(graph_); + gview_.Initialize(graph_.get()); // Build the information about frames in this subgraph. ControlFlowInfo cf_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_, &cf_info)); + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info)); // Cache this value so we make this virtual function call once, rather // that O(# steps * # nodes per step) times. @@ -676,9 +675,9 @@ Status ExecutorImpl::Initialize() { // Initialize PendingCounts only after item->pending_id is initialized for // all nodes. - InitializePending(graph_, cf_info); + InitializePending(graph_.get(), cf_info); - return gview_.SetAllocAttrs(graph_, params_.device); + return gview_.SetAllocAttrs(graph_.get(), params_.device); } Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) { @@ -1415,7 +1414,7 @@ void ExecutorImpl::InitializePending(const Graph* graph, } void ExecutorState::RunAsync(Executor::DoneCallback done) { - const Graph* graph = impl_->graph_; + const Graph* graph = impl_->graph_.get(); TaggedNodeSeq ready; // Ask the device to fill in the device context map. @@ -1609,7 +1608,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { auto done = [this, state]() { Device* device = impl_->params_.device; NodeExecStatsWrapper* stats = state->stats; // Shorthand - Entry* first_input = state->first_input; // Shorthand + Entry* first_input = state->first_input; // Shorthand nodestats::SetOpEnd(stats); EntryVector outputs; @@ -1776,6 +1775,19 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, entry->ref_mu = nullptr; inp->tensor = entry->val.get(); + // The dtype of entry->ref could have been changed by another operation + // that ran after the operation that "produced" it executed, so + // re-validate that the type of the dereferenced tensor matches the + // expected input type. + if (item.input_type(i) != inp->tensor->dtype()) { + return AttachDef( + errors::InvalidArgument( + i, "-th input expects type ", + DataTypeString(item.input_type(i)), + " but automatically dereferenced input tensor has type ", + DataTypeString(inp->tensor->dtype())), + item.kernel->def()); + } } } } @@ -2593,9 +2605,10 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { } // end namespace -Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph, +Status NewLocalExecutor(const LocalExecutorParams& params, + std::unique_ptr graph, Executor** executor) { - ExecutorImpl* impl = new ExecutorImpl(params, graph); + ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph)); const Status s = impl->Initialize(); if (s.ok()) { *executor = impl; diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index 3fd932da5b6c44833ba940351dad6cf373ffa05c..adf80a2417e2a86e874dd1d1068a1bbb611ff882 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -122,9 +122,8 @@ class Executor { // Creates an Executor that computes the given "graph". // -// If successful, returns the constructed executor in "*executor". The -// caller keeps the ownership of "device". The returned executor takes -// the ownership of "graph". Otherwise, returns an error status. +// If successful, returns the constructed executor in "*executor". Otherwise, +// returns an error status. // // "params" provides a set of context for the executor. We expect that // different context would provide different implementations. @@ -143,7 +142,8 @@ struct LocalExecutorParams { Executor::Args::NodeOutputsCallback node_outputs_cb; }; ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params, - const Graph* graph, Executor** executor); + std::unique_ptr graph, + Executor** executor); // A class to help run multiple executors in parallel and wait until // all of them are complete. diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index e9c4328f29e2c941afd8e14142beb0db224110d8..b941819838a7b155d8c8f54985bd6ae8bc15ce9d 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -97,12 +97,11 @@ static Node* AddNoOp(Graph* g) { static Node* AddIdentity(Graph* g, Endpoint input) { DCHECK_LT(0, input.dtype()); - DCHECK_LT(input.dtype(), DT_FLOAT_REF); NodeDef ndef; ndef.set_name(g->NewName(kNodeLabel)); ndef.set_op("Identity"); ndef.add_input(input.name()); - AddNodeAttr("T", input.dtype(), &ndef); + AddNodeAttr("T", BaseType(input.dtype()), &ndef); Status s; Node* ret = g->AddNode(ndef, &s); TF_CHECK_OK(s); @@ -183,6 +182,10 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { string DebugString(Handle h) override; + Status Clone(std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + FunctionLibraryRuntime** out_flr) override; + private: typedef FunctionLibraryRuntimeImpl ME; @@ -205,7 +208,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { // The instantiated and transformed function is encoded as a Graph // object, and an executor is created for the graph. struct Item : public core::RefCounted { - const Graph* graph = nullptr; // Owned by exec. + const Graph* graph = nullptr; // Owned by exec. const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned. FunctionBody* func_graph = nullptr; Executor* exec = nullptr; @@ -628,7 +631,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { }; Graph* graph = g.get(); Executor* exec; - TF_RETURN_IF_ERROR(NewLocalExecutor(params, g.release(), &exec)); + TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &exec)); { // Guard item since it is already inserted in items_. @@ -895,6 +898,21 @@ string FunctionLibraryRuntimeImpl::DebugString(Handle handle) { } } +Status FunctionLibraryRuntimeImpl::Clone( + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + FunctionLibraryRuntime** out_flr) { + TF_RETURN_IF_ERROR( + parent_->Clone(env_, graph_def_version_, optimizer_.options(), + custom_kernel_creator_, out_lib_def, out_pflr)); + *out_flr = (*out_pflr)->GetFLR(device_->name()); + if (out_flr != nullptr) { + return Status::OK(); + } else { + return errors::Internal("Cloning FunctionLibraryRuntime failed."); + } +} + namespace { struct CustomCreatorSingleton { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index cad3b3801e74a00a9f6fb6b236842f5caeaf72bc..63ad0d231c28a5af144b61e967a73e8ecfe6049a 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -71,11 +71,11 @@ class FunctionTest : public ::testing::Test { arg_types_ = result.arg_types; ret_types_ = result.ret_types; - Graph* g = new Graph(OpRegistry::Global()); + std::unique_ptr g(new Graph(OpRegistry::Global())); GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; - TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g)); + TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get())); const int version = g->versions().producer(); LocalExecutorParams params; @@ -89,7 +89,7 @@ class FunctionTest : public ::testing::Test { DeleteNonCachedKernel(kernel); }; Executor* exec; - TF_CHECK_OK(NewLocalExecutor(params, g, &exec)); + TF_CHECK_OK(NewLocalExecutor(params, std::move(g), &exec)); exec_.reset(exec); } @@ -787,7 +787,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto x4_x2_scale = ops::Const( - s.WithOpName("x4/x2/scale/_15__cf__9") + s.WithOpName("x4/x2/scale/_12__cf__6") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale); @@ -993,13 +993,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1); auto scale = ops::Const( - s.WithOpName("scale/_5__cf__10") + s.WithOpName("scale/_6__cf__11") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale); auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x); auto const0 = ops::Const( - s.WithOpName("Func/_1/sy/_6__cf__11") + s.WithOpName("Func/_1/sy/_5__cf__10") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 0, {0}); auto func1_rx = ops::internal::BroadcastGradientArgs( diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc index 9e4b617d2bd5b070f5b8bdeedabb15b94d212743..67caeb3495c6b0600f12c9b20ef73ee90f8b3e0d 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -154,8 +154,9 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) { a.DeallocateRaw(t3); a.DeallocateRaw(t4); } - CheckStats(&a, 4097, 0, 1024 * sizeof(float) + 1048576 * sizeof(int64) + - 2048 * sizeof(double) + 10485760 * sizeof(float), + CheckStats(&a, 4097, 0, + 1024 * sizeof(float) + 1048576 * sizeof(int64) + + 2048 * sizeof(double) + 10485760 * sizeof(float), 10485760 * sizeof(float)); // At the end, we should have coalesced all memory into one region diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc index cd29a5c50b6d4d6e9b36ad627fe72d855bde1372..63ed0b8be16ecb187113311db5283c8d4f3b1a5e 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc @@ -121,18 +121,20 @@ void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) { bool GPUDebugAllocator::TracksAllocationSizes() { return true; } -size_t GPUDebugAllocator::RequestedSize(void* ptr) { - auto req_size = - base_allocator_->RequestedSize(static_cast(ptr) - MASK_BYTES); +size_t GPUDebugAllocator::RequestedSize(const void* ptr) { + auto req_size = base_allocator_->RequestedSize(static_cast(ptr) - + MASK_BYTES); return req_size - 2 * MASK_BYTES; } -size_t GPUDebugAllocator::AllocatedSize(void* ptr) { - return base_allocator_->AllocatedSize(static_cast(ptr) - MASK_BYTES); +size_t GPUDebugAllocator::AllocatedSize(const void* ptr) { + return base_allocator_->AllocatedSize(static_cast(ptr) - + MASK_BYTES); } -int64 GPUDebugAllocator::AllocationId(void* ptr) { - return base_allocator_->AllocationId(static_cast(ptr) - MASK_BYTES); +int64 GPUDebugAllocator::AllocationId(const void* ptr) { + return base_allocator_->AllocationId(static_cast(ptr) - + MASK_BYTES); } void GPUDebugAllocator::GetStats(AllocatorStats* stats) { @@ -201,11 +203,11 @@ void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) { return base_allocator_->AddFreeVisitor(visitor); } -size_t GPUNanResetAllocator::RequestedSize(void* ptr) { +size_t GPUNanResetAllocator::RequestedSize(const void* ptr) { return base_allocator_->RequestedSize(ptr); } -size_t GPUNanResetAllocator::AllocatedSize(void* ptr) { +size_t GPUNanResetAllocator::AllocatedSize(const void* ptr) { return base_allocator_->AllocatedSize(ptr); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h index 139fa2847e5e4e9b114e5289572da68419d002c7..adce3a84368ced958002443721016778cb6df028 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h @@ -42,9 +42,9 @@ class GPUDebugAllocator : public VisitableAllocator { void AddAllocVisitor(Visitor visitor) override; void AddFreeVisitor(Visitor visitor) override; bool TracksAllocationSizes() override; - size_t RequestedSize(void* ptr) override; - size_t AllocatedSize(void* ptr) override; - int64 AllocationId(void* ptr) override; + size_t RequestedSize(const void* ptr) override; + size_t AllocatedSize(const void* ptr) override; + int64 AllocationId(const void* ptr) override; void GetStats(AllocatorStats* stats) override; void ClearStats() override; @@ -73,8 +73,8 @@ class GPUNanResetAllocator : public VisitableAllocator { void DeallocateRaw(void* ptr) override; void AddAllocVisitor(Visitor visitor) override; void AddFreeVisitor(Visitor visitor) override; - size_t RequestedSize(void* ptr) override; - size_t AllocatedSize(void* ptr) override; + size_t RequestedSize(const void* ptr) override; + size_t AllocatedSize(const void* ptr) override; void GetStats(AllocatorStats* stats) override; void ClearStats() override; diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 933d700f6042bf51f11f773d731cece6ef5af436..15ff15fd5ab28605c4ab0904e62305edc3815adb 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h" #include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h" @@ -65,6 +66,10 @@ limitations under the License. #include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/stream_executor_util.h" +#if !defined(PLATFORM_GOOGLE) +#include "cuda/cuda_config.h" +#endif + namespace tensorflow { // Eigen Ops directly allocate memory only for temporary buffers used @@ -99,7 +104,7 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface { reinterpret_cast(scratch + Eigen::kCudaScratchSize); stream_ = cuda_stream; allocator_ = alloc; - const int cuda_gpu_id = GpuIdUtil::TfToCudaGpuId(tf_gpu_id).value(); + const int cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id).value(); device_prop_ = &Eigen::m_deviceProperties[cuda_gpu_id]; } @@ -311,7 +316,7 @@ Status BaseGPUDevice::Init(const SessionOptions& options) { gpu_device_info_->stream = streams_[0]->compute; gpu_device_info_->default_context = device_contexts_[0]; gpu_device_info_->event_mgr = em_.get(); - gpu_device_info_->gpu_id = GpuIdUtil::TfToCudaGpuId(tf_gpu_id_).value(); + gpu_device_info_->gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id_).value(); set_tensorflow_gpu_device_info(gpu_device_info_); // Whether and how the GPU device uses its own threadpool. @@ -762,9 +767,11 @@ int64 MinSystemMemory(int64 available_memory) { // is necessary. min_system_memory *= 2; #endif + #if defined(ANDROID_TEGRA) - // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM and Video RAM - min_system_memory = 1<<30; + // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM + // and Video RAM + min_system_memory = 1 << 30; #endif return min_system_memory; } @@ -833,6 +840,9 @@ void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context, } } +const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000; +const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1; + Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { @@ -892,6 +902,35 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, } } + std::vector interconnect_maps; + TF_RETURN_IF_ERROR( + GetInterconnectMaps(visible_gpu_order, gpu_manager, &interconnect_maps)); + + // Print each interconnect map to the log. + for (const InterconnectMap& im : interconnect_maps) { + LOG(INFO) << "Device interconnect " << im.name << " with strength " + << im.strength << " edge matrix:"; + string line_buf = " "; + for (int i = 0; i < visible_gpu_order.size(); ++i) { + strings::StrAppend(&line_buf, visible_gpu_order[i].value(), " "); + } + LOG(INFO) << line_buf; + for (int i = 0; i < visible_gpu_order.size(); ++i) { + line_buf = strings::StrCat(visible_gpu_order[i].value(), ": "); + CudaGpuId cuda_id_i = visible_gpu_order[i]; + for (int j = 0; j < visible_gpu_order.size(); ++j) { + CudaGpuId cuda_id_j = visible_gpu_order[j]; + if (im.directed_links.find({cuda_id_i, cuda_id_j}) != + im.directed_links.end()) { + line_buf.append("Y "); + } else { + line_buf.append("N "); + } + } + LOG(INFO) << line_buf; + } + } + const auto& virtual_devices = gpu_options.experimental().virtual_devices(); if (!virtual_devices.empty()) { TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings( @@ -902,9 +941,9 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, valid_cuda_gpu_ids == visible_gpu_order); } int next_tf_gpu_id = 0; + std::vector memory_limit_bytes; for (int i = 0; i < num_gpus_to_use; ++i) { const CudaGpuId cuda_gpu_id = valid_cuda_gpu_ids[i]; - std::vector memory_limit_bytes; if (virtual_devices.empty() || virtual_devices.Get(i).memory_limit_mb_size() == 0) { int64 single_virtual_device_memory_limit = 0; @@ -918,14 +957,31 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, return static_cast(mb) * (1ll << 20); }); } - for (int64 bytes : memory_limit_bytes) { + while (next_tf_gpu_id < memory_limit_bytes.size()) { TfGpuId tf_gpu_id(next_tf_gpu_id); ++next_tf_gpu_id; - GpuIdUtil::InsertTfCudaGpuIdPair(tf_gpu_id, cuda_gpu_id); - TF_RETURN_IF_ERROR( - CreateGPUDevice(options, name_prefix, tf_gpu_id, bytes, devices)); + GpuIdManager::InsertTfCudaGpuIdPair(tf_gpu_id, cuda_gpu_id); } } + const int num_tf_gpus = next_tf_gpu_id; + + LocalityMap device_localities; + TF_RETURN_IF_ERROR( + GetDeviceLocalities(num_tf_gpus, interconnect_maps, &device_localities)); + + // Build the GPUDevices + CHECK_EQ(next_tf_gpu_id, memory_limit_bytes.size()); + for (int di = 0; di < num_tf_gpus; ++di) { + TfGpuId tf_gpu_id(di); + int64 bytes = memory_limit_bytes[di]; + auto it = device_localities.find(tf_gpu_id); + if (it == device_localities.end()) { + return errors::Internal("Failed to find DeviceLocality for GPU device ", + tf_gpu_id.value()); + } + TF_RETURN_IF_ERROR(CreateGPUDevice(options, name_prefix, tf_gpu_id, bytes, + it->second, devices)); + } return Status::OK(); } @@ -949,41 +1005,19 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options, const string& name_prefix, TfGpuId tf_gpu_id, int64 memory_limit, + const DeviceLocality& dev_locality, std::vector* devices) { CHECK_GE(tf_gpu_id.value(), 0); const string device_name = strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value()); - - // Look up the device, to see its attributes. GpuIdUtil::CheckValidTfGpuId(tf_gpu_id); - gpu::StreamExecutor* se = - GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie(); - const gpu::DeviceDescription& desc = se->GetDeviceDescription(); - int numa_node = desc.numa_node(); - if (numa_node < 0) { - // For some reason the StreamExecutor couldn't get the NUMA - // affinity of the GPU. If this is not a multi-socket mobo with - // GPUs local to different buses, it doesn't matter. If it is, we - // may run into trouble later with data transfer operations. The - // trouble may manifest as slower than expected performance, or - // outright failures. - LOG(INFO) << "Could not identify NUMA node of " << device_name - << ", defaulting to 0. Your kernel may not have been built " - << "with NUMA support."; - numa_node = 0; - } + CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id); + int numa_node = dev_locality.numa_node(); Bytes allocated_bytes = static_cast(memory_limit); - // Get GPU bus_id from its reported NUMA affinity. Because GPUs are - // virtualized in some environments, we can't just use the GPU id. - // NUMA locales are indexed from 0, buses are indexed from 1. - DeviceLocality dev_locality; - dev_locality.set_bus_id(numa_node + 1); - const CudaGpuId cuda_gpu_id = GpuIdUtil::TfToCudaGpuId(tf_gpu_id); - VLOG(1) << "GPUDevice id " << cuda_gpu_id << " on bus " - << dev_locality.bus_id() << " numa: " << numa_node - << " pci: " << desc.pci_bus_id(); - + gpu::StreamExecutor* se = + GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + const gpu::DeviceDescription& desc = se->GetDeviceDescription(); LOG(INFO) << "Creating TensorFlow device (" << device_name << " with " << (memory_limit >> 20) << " MB memory) -> physical GPU (" << GetShortDeviceDescription(cuda_gpu_id, desc) << ")"; @@ -1000,6 +1034,116 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options, return Status::OK(); } +namespace { +std::unique_ptr, bool>> +GetPeerAccessMap(gpu::Platform* platform, + const std::vector& visible_gpu_order) { + std::unique_ptr, bool>> map( + new std::map, bool>); + for (CudaGpuId cuda_gpu_i : visible_gpu_order) { + for (CudaGpuId cuda_gpu_j : visible_gpu_order) { + gpu::StreamExecutor* from = + GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie(); + gpu::StreamExecutor* to = + GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie(); + (*map)[{cuda_gpu_i, cuda_gpu_j}] = from->CanEnablePeerAccessTo(to); + } + } + + return map; +} + +} // namespace + +Status BaseGPUDeviceFactory::GetInterconnectMaps( + const std::vector& visible_gpu_order, gpu::Platform* gpu_manager, + std::vector* maps) { + // The default interconnect map is obtained from the StreamExecutor. + auto access_map = GetPeerAccessMap(gpu_manager, visible_gpu_order); + maps->resize(1); + InterconnectMap& imap = maps->at(0); + imap.name = "StreamExecutor"; + imap.strength = InterconnectMap::kStreamExecutorStrength; + for (CudaGpuId cuda_id_i : visible_gpu_order) { + for (CudaGpuId cuda_id_j : visible_gpu_order) { + if (cuda_id_i == cuda_id_j) continue; + if ((*access_map)[{cuda_id_i, cuda_id_j}]) { + imap.directed_links.insert({cuda_id_i, cuda_id_j}); + } + } + } + return Status::OK(); +} + +Status BaseGPUDeviceFactory::GetDeviceLocalities( + int num_tf_gpus, const std::vector& interconnects, + LocalityMap* localities) { + std::vector all_tf_gpu_ids; + for (int i = 0; i < num_tf_gpus; ++i) { + all_tf_gpu_ids.push_back(TfGpuId(i)); + } + for (TfGpuId tf_gpu_id : all_tf_gpu_ids) { + CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id); + // Get GPU bus_id from its reported NUMA affinity. Because GPUs are + // virtualized in some environments, we can't just use the GPU id. + // NUMA locales are indexed from 0, buses are indexed from 1. + gpu::StreamExecutor* se = + GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + const gpu::DeviceDescription& desc = se->GetDeviceDescription(); + int numa_node = desc.numa_node(); + if (numa_node < 0) { + // For some reason the StreamExecutor couldn't get the NUMA + // affinity of the GPU. If this is not a multi-socket mobo with + // GPUs local to different buses, it doesn't matter. If it is, we + // may run into trouble later with data transfer operations. The + // trouble may manifest as slower than expected performance, or + // outright failures. + LOG(INFO) << "Could not identify NUMA node of CUDA gpu id " << cuda_gpu_id + << ", defaulting to 0. Your kernel may not have been built " + << "with NUMA support."; + numa_node = 0; + } + DeviceLocality dev_locality; + dev_locality.set_numa_node(numa_node); + dev_locality.set_bus_id(numa_node + 1); + + // Set LocalLinks from InterconnectMaps. + LocalLinks* links = dev_locality.mutable_links(); + for (const InterconnectMap& imap : interconnects) { + for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) { + CudaGpuId cuda_gpu_dst = GpuIdManager::TfToCudaGpuId(tf_gpu_dst); + if (imap.directed_links.find({cuda_gpu_id, cuda_gpu_dst}) != + imap.directed_links.end()) { + InterconnectLink* ilink = links->add_link(); + ilink->set_device_id(tf_gpu_dst.value()); + ilink->set_type(imap.name); + ilink->set_strength(imap.strength); + } + } + } + + // If this is one of multiple virtual GPUs on the same physical GPU + // add high strength links to the others. + for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) { + if (tf_gpu_id == tf_gpu_dst) continue; + CudaGpuId cuda_gpu_dst = GpuIdManager::TfToCudaGpuId(tf_gpu_dst); + if (cuda_gpu_id == cuda_gpu_dst) { + InterconnectLink* ilink = links->add_link(); + ilink->set_device_id(tf_gpu_dst.value()); + ilink->set_type("SAME_DEVICE"); + ilink->set_strength(InterconnectMap::kSameDeviceStrength); + } + } + + (*localities)[tf_gpu_id] = dev_locality; + VLOG(1) << "GPUDevice CudaGpuId " << cuda_gpu_id << " TfGpuId " << tf_gpu_id + << " on bus " << dev_locality.bus_id() << " numa: " << numa_node + << " pci: " << desc.pci_bus_id() + << " DeviceLocality: " << dev_locality.DebugString(); + } + return Status::OK(); +} + static int GetDefaultMinGPUMultiprocessorCount( gpu::Platform* gpu_manager, const std::vector& visible_gpu_order) { @@ -1105,38 +1249,19 @@ std::vector GetSupportedCudaComputeCapabilities() { return cuda_caps; } -std::unique_ptr, bool>> GetPeerAccessMap( - gpu::Platform* platform, const std::vector& visible_gpu_order) { - std::unique_ptr, bool>> map( - new std::map, bool>); - for (int i = 0; i < visible_gpu_order.size(); ++i) { - const CudaGpuId i_gpu_id = visible_gpu_order[i]; - for (int j = 0; j < visible_gpu_order.size(); ++j) { - const CudaGpuId j_gpu_id = visible_gpu_order[j]; - gpu::StreamExecutor* from = - GpuIdUtil::ExecutorForCudaGpuId(platform, i_gpu_id).ValueOrDie(); - gpu::StreamExecutor* to = - GpuIdUtil::ExecutorForCudaGpuId(platform, j_gpu_id).ValueOrDie(); - (*map)[{i, j}] = from->CanEnablePeerAccessTo(to); - } - } - - return map; -} - Status EnablePeerAccess(gpu::Platform* platform, const std::vector& visible_gpu_order) { int possible_peer_count = 0; int enabled_peer_count = 0; for (int i = 0; i < visible_gpu_order.size(); ++i) { - const CudaGpuId i_gpu_id = visible_gpu_order[i]; + const CudaGpuId cuda_gpu_i = visible_gpu_order[i]; for (int j = 0; j < visible_gpu_order.size(); ++j) { - const CudaGpuId j_gpu_id = visible_gpu_order[j]; + const CudaGpuId cuda_gpu_j = visible_gpu_order[j]; // We have already validated that ExecutorForDevice() calls return OK. gpu::StreamExecutor* from = - GpuIdUtil::ExecutorForCudaGpuId(platform, i_gpu_id).ValueOrDie(); + GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie(); gpu::StreamExecutor* to = - GpuIdUtil::ExecutorForCudaGpuId(platform, j_gpu_id).ValueOrDie(); + GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie(); if (from->CanEnablePeerAccessTo(to)) { ++possible_peer_count; @@ -1144,7 +1269,7 @@ Status EnablePeerAccess(gpu::Platform* platform, if (!status.ok()) { LOG(WARNING) << "Unable to enable peer access between device ordinals " - << i_gpu_id << " and " << j_gpu_id << ", status: " << status; + << cuda_gpu_i << " and " << cuda_gpu_j << ", status: " << status; } else { ++enabled_peer_count; } @@ -1215,27 +1340,6 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds( if (new_gpu_found && visible_gpu_order.size() > 1) { // Enable peer access TF_RETURN_IF_ERROR(EnablePeerAccess(gpu_manager, visible_gpu_order)); - - // Print out a matrix showing which devices can DMA to one - // another. - LOG(INFO) << "Device peer to peer matrix"; - auto access_map = GetPeerAccessMap(gpu_manager, visible_gpu_order); - string line_buf = "DMA: "; - for (int i = 0; i < visible_gpu_order.size(); ++i) { - strings::StrAppend(&line_buf, visible_gpu_order[i].value(), " "); - } - LOG(INFO) << line_buf; - for (int i = 0; i < visible_gpu_order.size(); ++i) { - line_buf = strings::StrCat(visible_gpu_order[i].value(), ": "); - for (int j = 0; j < visible_gpu_order.size(); ++j) { - if ((*access_map)[{i, j}]) { - line_buf.append("Y "); - } else { - line_buf.append("N "); - } - } - LOG(INFO) << line_buf; - } } auto cuda_supported_capabilities = GetSupportedCudaComputeCapabilities(); diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h index 41e60b4884673673f2e791cbbafa4ef0091bdf8f..c88daa8ff87589a3fc48f4c7693d073d6adf9a5a 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.h +++ b/tensorflow/core/common_runtime/gpu/gpu_device.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/common_runtime/local_device.h" @@ -88,7 +89,7 @@ class BaseGPUDevice : public LocalDevice { // Returns the CUDA GPU id of this device within the native driver system; // e.g., for CUDA this is the ordinal of the GPU within the system. - int gpu_id() const { return GpuIdUtil::TfToCudaGpuId(tf_gpu_id_).value(); } + int gpu_id() const { return GpuIdManager::TfToCudaGpuId(tf_gpu_id_).value(); } // The executor that provides control for the device; e.g., for CUDA this // corresponds to the cuda context. @@ -140,17 +141,50 @@ class BaseGPUDeviceFactory : public DeviceFactory { Status CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) override; + struct InterconnectMap { + // Name of interconnect technology, if known. + string name; + // If possible, strength should approximate Gb/sec bandwidth rate. + // Where architecture-specific subclassing is not done that won't + // always be possible. The minimum expectation is that + // faster links should have a higher value than slower links. + int32 strength; + static const int kSameDeviceStrength; + static const int kStreamExecutorStrength; + std::set> directed_links; + }; + + protected: + // Populates *maps with interconnect maps for all local direct access + // pathways between GPUs. + virtual Status GetInterconnectMaps( + const std::vector& visible_gpu_order, + gpu::Platform* gpu_manager, std::vector* maps); + + struct TfGpuIdHash { + std::size_t operator()(const TfGpuId& id) const noexcept { + return std::hash{}(id.value()); + } + }; + typedef std::unordered_map LocalityMap; + // Populates *localities with the DeviceLocality descriptor for + // every TfGpuId. + virtual Status GetDeviceLocalities( + int num_tf_gpus, const std::vector& interconnects, + LocalityMap* localities); + private: // Creates a BaseGPUDevice associated with 'tf_gpu_id', allocates (strictly) // 'memory_limit' bytes of GPU memory to it, and adds it to the 'devices' // vector. Status CreateGPUDevice(const SessionOptions& options, const string& name_prefix, TfGpuId tf_gpu_id, - int64 memory_limit, std::vector* devices); + int64 memory_limit, const DeviceLocality& dev_locality, + std::vector* devices); virtual BaseGPUDevice* CreateGPUDevice(const SessionOptions& options, const string& name, Bytes memory_limit, - const DeviceLocality& locality, + const DeviceLocality& dev_locality, TfGpuId tf_gpu_id, const string& physical_device_desc, Allocator* gpu_allocator, diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc index ff46be9c015ac3d0ad59e302f53d52c4bd3e25ea..b56823204afe8ee52e0ea376b1a79d91d6932fa0 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc @@ -180,6 +180,18 @@ TEST(GPUDeviceTest, MultipleVirtualDevices) { EXPECT_EQ(2, devices.size()); EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit()); EXPECT_EQ(456 << 20, devices[1]->attributes().memory_limit()); + ASSERT_EQ(1, devices[0]->attributes().locality().links().link_size()); + ASSERT_EQ(1, devices[1]->attributes().locality().links().link_size()); + EXPECT_EQ(1, devices[0]->attributes().locality().links().link(0).device_id()); + EXPECT_EQ("SAME_DEVICE", + devices[0]->attributes().locality().links().link(0).type()); + EXPECT_EQ(BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength, + devices[0]->attributes().locality().links().link(0).strength()); + EXPECT_EQ(0, devices[1]->attributes().locality().links().link(0).device_id()); + EXPECT_EQ("SAME_DEVICE", + devices[1]->attributes().locality().links().link(0).type()); + EXPECT_EQ(BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength, + devices[1]->attributes().locality().links().link(0).strength()); for (auto d : devices) delete d; } diff --git a/tensorflow/core/common_runtime/gpu/gpu_id.h b/tensorflow/core/common_runtime/gpu/gpu_id.h index 4e9c4abce1264d0533c10c1d4dcfcc3f1455e727..2a6caea2967dcd0a1d3d6550aa428a882408ea17 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id.h +++ b/tensorflow/core/common_runtime/gpu/gpu_id.h @@ -40,7 +40,7 @@ namespace tensorflow { // a BaseGPUDevice. Note that the configuration allows us to create multiple // BaseGPUDevice per GPU hardware in order to use multi CUDA streams on the // hardware, so the mapping between TF GPU id and CUDA GPU id is not a 1:1 -// mappping, see the example below. +// mapping, see the example below. // // For example, assuming that in the machine we have GPU device with index 0, 1, // 2 and 3 (physical GPU id). Setting "CUDA_VISIBLE_DEVICES=1,2,3" will create diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc similarity index 79% rename from tensorflow/core/common_runtime/gpu/gpu_id_utils.cc rename to tensorflow/core/common_runtime/gpu/gpu_id_manager.cc index 92cd19453f14c886c0d105a5c1809b7fdbcafc9b..207afdca75642b14c1617c8abae4fd5e9916f020 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include @@ -24,10 +24,10 @@ limitations under the License. namespace tensorflow { namespace { // Manages the map between TfGpuId and CUDA GPU id. -class GpuIdManager { +class TfToCudaGpuIdMap { public: - static GpuIdManager* singleton() { - static auto* manager = new GpuIdManager; + static TfToCudaGpuIdMap* singleton() { + static auto* manager = new TfToCudaGpuIdMap; return manager; } @@ -62,13 +62,13 @@ class GpuIdManager { }; } // namespace -void GpuIdUtil::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, - CudaGpuId cuda_gpu_id) { - GpuIdManager::singleton()->InsertOrDie(tf_gpu_id, cuda_gpu_id); +void GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, + CudaGpuId cuda_gpu_id) { + TfToCudaGpuIdMap::singleton()->InsertOrDie(tf_gpu_id, cuda_gpu_id); } -CudaGpuId GpuIdUtil::TfToCudaGpuId(TfGpuId tf_gpu_id) { - return CudaGpuId(GpuIdManager::singleton()->FindOrDie(tf_gpu_id)); +CudaGpuId GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id) { + return CudaGpuId(TfToCudaGpuIdMap::singleton()->FindOrDie(tf_gpu_id)); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..33925d8c36f44a9d2c7abc8f2801f3f203bcb982 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h @@ -0,0 +1,33 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_MANAGER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_MANAGER_H_ + +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" + +namespace tensorflow { + +// Class that manages the translation between Tensorflow GPU ids and CUDA GPU +// ids. +class GpuIdManager { + public: + static void InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id); + static CudaGpuId TfToCudaGpuId(TfGpuId tf_gpu_id); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_MANAGER_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils_test.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc similarity index 67% rename from tensorflow/core/common_runtime/gpu/gpu_id_utils_test.cc rename to tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc index bebe00a4317becdba1fc6146b4eb188b93933fff..bdbd8d065b398159305504202ed342c08cc3ee7d 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_utils_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/platform/test.h" @@ -21,33 +21,33 @@ limitations under the License. namespace tensorflow { namespace test { -TEST(GpuIdTest, Basics) { +TEST(GpuIdManagerTest, Basics) { TfGpuId key_0(0); CudaGpuId value_0(0); - GpuIdUtil::InsertTfCudaGpuIdPair(key_0, value_0); - EXPECT_EQ(value_0, GpuIdUtil::TfToCudaGpuId(key_0)); + GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0); + EXPECT_EQ(value_0, GpuIdManager::TfToCudaGpuId(key_0)); // Multiple calls to map the same value is ok. - GpuIdUtil::InsertTfCudaGpuIdPair(key_0, value_0); - EXPECT_EQ(value_0, GpuIdUtil::TfToCudaGpuId(key_0)); + GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0); + EXPECT_EQ(value_0, GpuIdManager::TfToCudaGpuId(key_0)); // Map a different TfGpuId to a different value. TfGpuId key_1(3); CudaGpuId value_1(2); - GpuIdUtil::InsertTfCudaGpuIdPair(key_1, value_1); - EXPECT_EQ(value_1, GpuIdUtil::TfToCudaGpuId(key_1)); + GpuIdManager::InsertTfCudaGpuIdPair(key_1, value_1); + EXPECT_EQ(value_1, GpuIdManager::TfToCudaGpuId(key_1)); // Mapping a different TfGpuId to the same value is ok. TfGpuId key_2(10); - GpuIdUtil::InsertTfCudaGpuIdPair(key_2, value_1); - EXPECT_EQ(value_1, GpuIdUtil::TfToCudaGpuId(key_2)); + GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_1); + EXPECT_EQ(value_1, GpuIdManager::TfToCudaGpuId(key_2)); // Mapping the same TfGpuId to a different value will crash the program. - ASSERT_DEATH(GpuIdUtil::InsertTfCudaGpuIdPair(key_2, value_0), + ASSERT_DEATH(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_0), "Mapping the same TfGpuId to a different CUDA GPU id"); // Getting an nonexistent mapping will crash the program. - ASSERT_DEATH(GpuIdUtil::TfToCudaGpuId(TfGpuId(100)), + ASSERT_DEATH(GpuIdManager::TfToCudaGpuId(TfGpuId(100)), "Could not find the mapping for TfGpuId"); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h index 6d196b16eddfb4b77db97cd098538a7e1f7cae5b..2e90687fe8854460dc2ec683d8587ab2ceadf42e 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h +++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_ #include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/lib/gtl/int_type.h" #include "tensorflow/core/platform/stream_executor.h" @@ -27,9 +28,6 @@ namespace gpu = ::perftools::gputools; // Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids. class GpuIdUtil { public: - static void InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id); - static CudaGpuId TfToCudaGpuId(TfGpuId tf_gpu_id); - // Convenient methods for getting the associated executor given a TfGpuId or // CudaGpuId. static gpu::port::StatusOr ExecutorForCudaGpuId( @@ -42,12 +40,12 @@ class GpuIdUtil { } static gpu::port::StatusOr ExecutorForTfGpuId( TfGpuId tf_gpu_id) { - return ExecutorForCudaGpuId(GpuIdUtil::TfToCudaGpuId(tf_gpu_id)); + return ExecutorForCudaGpuId(GpuIdManager::TfToCudaGpuId(tf_gpu_id)); } // Verify that the cuda_gpu_id associated with a TfGpuId is legitimate. static void CheckValidTfGpuId(TfGpuId tf_gpu_id) { - const CudaGpuId cuda_gpu_id = GpuIdUtil::TfToCudaGpuId(tf_gpu_id); + const CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id); const int visible_device_count = GPUMachineManager()->VisibleDeviceCount(); CHECK_LT(cuda_gpu_id.value(), visible_device_count) << "cuda_gpu_id is outside discovered device range." diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc index 7763a4f2e6f50292e78b4d16d8d4a3ee84d4163b..2500425359c424fa479af6dd34d6a0312c404577 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc @@ -108,7 +108,8 @@ TEST_F(GpuStreamUtilTest, StreamOverrides) { ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0, "/device:GPU:0"); Output n = ops::MatMul(root, {}, {}); - ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0, "/cpu:0"); + ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0, + "/cpu:0"); Graph g(OpRegistry::Global()); TF_ASSERT_OK(root.ToGraph(&g)); diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc index 995fd1253fb9f352742410199174b8567e92351b..61013bd1acd254b6e927a8d41accaeda424d6ebc 100644 --- a/tensorflow/core/common_runtime/gpu/process_state.cc +++ b/tensorflow/core/common_runtime/gpu/process_state.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h" #include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h" #include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/common_runtime/gpu/pool_allocator.h" @@ -88,8 +89,8 @@ ProcessState::~ProcessState() { } string ProcessState::MemDesc::DebugString() { - return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, ", dma: ", - gpu_registered, ", nic: ", nic_registered); + return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, + ", dma: ", gpu_registered, ", nic: ", nic_registered); } ProcessState::MemDesc ProcessState::PtrType(const void* ptr) { @@ -124,7 +125,7 @@ Allocator* ProcessState::GetGPUAllocator(const GPUOptions& options, return nullptr; } - const CudaGpuId cuda_gpu_id = GpuIdUtil::TfToCudaGpuId(tf_gpu_id); + const CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id); gpu_allocator = new GPUBFCAllocator(cuda_gpu_id, total_bytes, options, strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc")); @@ -230,8 +231,24 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) { // TODO(tucker): actually maintain separate CPUAllocators for // different numa_nodes. For now, just one. numa_node = 0; - mutex_lock lock(mu_); + { + // Here we optimize the most common use case where cuda_host_allocators_ + // and cuda_al_ have already been populated and since we're only reading + // these vectors, we can get by with a shared lock. In the slower case, + // we take a unique lock and populate these vectors. + tf_shared_lock lock(mu_); + + if (FLAGS_brain_gpu_record_mem_types && + static_cast(cuda_al_.size()) > 0) { + return cuda_al_[0]; + } + if (static_cast(cuda_host_allocators_.size()) > numa_node) { + return cuda_host_allocators_[0]; + } + } + + mutex_lock lock(mu_); // Find the first valid StreamExecutor to request CUDA host memory // through, since any will work. // diff --git a/tensorflow/core/common_runtime/gpu/process_state.h b/tensorflow/core/common_runtime/gpu/process_state.h index abe458f685b5425d3dc4c469a33251c2b531fb80..f6e234967306476542cec3038ea2e271cca2dc8c 100644 --- a/tensorflow/core/common_runtime/gpu/process_state.h +++ b/tensorflow/core/common_runtime/gpu/process_state.h @@ -155,8 +155,8 @@ class RecordingAllocator : public Allocator { a_->DeallocateRaw(p); } bool TracksAllocationSizes() override { return a_->TracksAllocationSizes(); } - size_t RequestedSize(void* p) override { return a_->RequestedSize(p); } - size_t AllocatedSize(void* p) override { return a_->AllocatedSize(p); } + size_t RequestedSize(const void* p) override { return a_->RequestedSize(p); } + size_t AllocatedSize(const void* p) override { return a_->AllocatedSize(p); } void GetStats(AllocatorStats* stats) override { a_->GetStats(stats); } void ClearStats() override { a_->ClearStats(); } ProcessState::MDMap* mm_; // not owned diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 3b309e915cdd2c6d5eead9ed0312f3873bcf7335..33a5d60eb7ec4de829d3c0784f909ef42cf994d1 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -340,8 +340,11 @@ Status GraphExecutionState::OptimizeGraph( std::unordered_map device_map; Device* cpu_device = nullptr; for (const auto& device : device_set_->devices()) { - device_map[device->name()] = - grappler::GetDeviceInfo(device->parsed_name()); + DeviceProperties props = grappler::GetDeviceInfo(device->parsed_name()); + if (props.type() == "UNKNOWN") { + continue; + } + device_map[device->name()] = props; if (device->parsed_name().id == 0 && StringPiece(device->parsed_name().type) == "CPU" && device->GetAllocator(AllocatorAttributes()) != nullptr) { diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h index db2686ce2c45aa4c9997a624bb12720d63710b65..2312e1a89fd1fd5734fab4316c25ca2e39f16ae5 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.h +++ b/tensorflow/core/common_runtime/graph_execution_state.h @@ -139,9 +139,7 @@ class GraphExecutionState { // The graph returned by BuildGraph may contain only the pruned // graph, whereas some clients may want access to the full graph. - const Graph* full_graph() { - return graph_; - } + const Graph* full_graph() { return graph_; } // Returns the node with the given name, or null if it does not exist. const Node* get_node_by_name(const string& name) const { diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 8477cea126f1808d9472bd4f4127fd43e172848e..80246281cde373863e4da1bb8d86bee39bfb9dfd 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -52,6 +52,8 @@ class GraphOptimizer { shape_map, const std::function& cse_consider_fn = nullptr); + const OptimizerOptions& options() { return opts_; } + private: OptimizerOptions opts_; diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index a21304f7ef843706d564bd3f3a511324fd3189d6..f1082a60030fb3c289de35b4cab397c527f8afca 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -156,21 +156,21 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, // should not be running expensive operators. auto runner = [](Executor::Args::Closure c) { c(); }; - // Take ownership and pass to NewLocalExecutor - Graph* g = graph_to_run.release(); - LocalExecutorParams params; // The ownership of the output tensors are bound to this device's lifetime. params.device = cpu_device_.get(); params.function_library = function_library; - params.create_kernel = [this, g](const NodeDef& ndef, OpKernel** kernel) { - return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef, - g->versions().producer(), kernel); + const int producer = graph_to_run->versions().producer(); + params.create_kernel = [this, producer](const NodeDef& ndef, + OpKernel** kernel) { + return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef, producer, + kernel); }; params.delete_kernel = [](OpKernel* kernel) { delete kernel; }; Executor* executor; - TF_RETURN_IF_ERROR(NewLocalExecutor(params, g, &executor)); + TF_RETURN_IF_ERROR( + NewLocalExecutor(params, std::move(graph_to_run), &executor)); std::unique_ptr executor_unref(executor); Executor::Args args; diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index 420dfe338efb473e36eb02a757fa957d15ba64df..64d884947568381eb2e5f60ab181b3c8c709d53b 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -39,6 +39,7 @@ limitations under the License. namespace tensorflow { namespace test { +// TODO(hongm): Convert `g` and `init` to using std::unique_ptr. Benchmark::Benchmark(const string& device, Graph* g, const SessionOptions* options, Graph* init, Rendezvous* rendez) { @@ -85,7 +86,8 @@ Benchmark::Benchmark(const string& device, Graph* g, if (init) { Executor* init_exec; - TF_CHECK_OK(NewLocalExecutor(params, init, &init_exec)); + TF_CHECK_OK( + NewLocalExecutor(params, std::unique_ptr(init), &init_exec)); Executor::Args args; args.rendezvous = rendez_; args.runner = runner; @@ -93,7 +95,7 @@ Benchmark::Benchmark(const string& device, Graph* g, delete init_exec; } - TF_CHECK_OK(NewLocalExecutor(params, g, &exec_)); + TF_CHECK_OK(NewLocalExecutor(params, std::unique_ptr(g), &exec_)); } Benchmark::~Benchmark() { diff --git a/tensorflow/core/common_runtime/memory_types.cc b/tensorflow/core/common_runtime/memory_types.cc index 76b926ba40053288360f0e4e6fe2a37bd44ff0b4..090a16ebeb10007261666aeb6491a1785dd2e5c4 100644 --- a/tensorflow/core/common_runtime/memory_types.cc +++ b/tensorflow/core/common_runtime/memory_types.cc @@ -47,7 +47,7 @@ struct EndpointEq { static Status ProcessMemoryTypes( const DeviceType& device_type, const Graph* g, const std::function& fn) { - if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL ) { + if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL) { // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always // compatible. return Status::OK(); diff --git a/tensorflow/core/common_runtime/memory_types_test.cc b/tensorflow/core/common_runtime/memory_types_test.cc index 2a834ddca4236c626c6252f63c97118e8e1f0bd0..a093585571994e8b161b46a7fc397cdc3cd4254c 100644 --- a/tensorflow/core/common_runtime/memory_types_test.cc +++ b/tensorflow/core/common_runtime/memory_types_test.cc @@ -36,7 +36,7 @@ TEST(MemoryTypeChecker, Int32OK) { #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g)); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL delete g; } @@ -64,7 +64,7 @@ TEST(MemoryTypeChecker, Int32NotOk) { // But we can insert _HostSend/_HostRecv to ensure the invariant. TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_SYCL, "/device:SYCL:0", g)); TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g)); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL delete g; } @@ -91,7 +91,7 @@ TEST(MemoryTypeChecker, MemoryTypeForOutput) { TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_SYCL, g, si, 0, &memory_type)); // int Switch's output on GPU has HOST_MEMORY constraint. EXPECT_EQ(memory_type, HOST_MEMORY); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL delete g; } diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index c7a2b616c7b4ebeecf7e8f00b45a30263b36dd40..71e0de9724680cfcc012ae04782b90b867e0095b 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -21,11 +21,10 @@ limitations under the License. #ifdef INTEL_MKL -#include #include #include #include "tensorflow/core/common_runtime/bfc_allocator.h" -#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/common_runtime/visitable_allocator.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mem.h" @@ -46,7 +45,7 @@ class MklSubAllocator : public SubAllocator { /// CPU allocator for MKL that wraps BFC allocator and intercepts /// and redirects memory allocation calls from MKL. -class MklCPUAllocator : public Allocator { +class MklCPUAllocator : public VisitableAllocator { public: // Constructor and other standard functions @@ -119,6 +118,14 @@ class MklCPUAllocator : public Allocator { void ClearStats() override { allocator_->ClearStats(); } + void AddAllocVisitor(Visitor visitor) override { + allocator_->AddAllocVisitor(visitor); + } + + void AddFreeVisitor(Visitor visitor) override { + allocator_->AddFreeVisitor(visitor); + } + private: // Hooks provided by this allocator for memory allocation routines from MKL @@ -153,7 +160,7 @@ class MklCPUAllocator : public Allocator { /// The alignment that we need for the allocations static const size_t kAlignment = 64; - Allocator* allocator_; // owned by this class + VisitableAllocator* allocator_; // owned by this class }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h index c5b76592e1b4b86863009ef42b7bb7106377d054..75dce7c7feb2269fc994cbb8c5efd4b3799e75dd 100644 --- a/tensorflow/core/common_runtime/placer.h +++ b/tensorflow/core/common_runtime/placer.h @@ -88,9 +88,9 @@ class Placer { void AssignAndLog(int assigned_device, Node* node) const; void LogDeviceAssignment(const Node* node) const; - Graph* const graph_; // Not owned. - const DeviceSet* const devices_; // Not owned. - const SessionOptions* options_; // Not owned. + Graph* const graph_; // Not owned. + const DeviceSet* const devices_; // Not owned. + const SessionOptions* options_; // Not owned. const bool log_device_placement_; TF_DISALLOW_COPY_AND_ASSIGN(Placer); diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 5d87b1e279ab0390a642df8f285fd451803ba29a..098024d2195aad8ef651120181ab271be168f92a 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -193,7 +194,7 @@ class PlacerTest : public ::testing::Test { // Builds the given graph, and (if successful) indexes the node // names for use in placement, and later lookup. Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) { - TF_RETURN_IF_ERROR(builder.ToGraph(out_graph)); + TF_RETURN_IF_ERROR(GraphDefBuilderToGraph(builder, out_graph)); nodes_by_name_.clear(); for (Node* node : out_graph->nodes()) { nodes_by_name_[node->name()] = node->id(); @@ -619,9 +620,9 @@ TEST_F(PlacerTest, TestReferenceConnectionIgnoreInfeasible) { Node* input = ops::SourceOp( "TestDevice", b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0")); - Node* var = ops::SourceOp("TestVariable", - b.opts().WithName("var_0").WithDevice( - "/job:a/task:0/device:fakegpu:0")); + Node* var = + ops::SourceOp("TestVariable", b.opts().WithName("var_0").WithDevice( + "/job:a/task:0/device:fakegpu:0")); // This op is specified on CPU, but in practice will be ignored, // because the reference edges forces it on GPU. diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 12947e284a36fef171caf6af0c46d59ca89efb61..e205e34aa0f6afb1363d65bd23403d4b50f056eb 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -70,23 +70,6 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( } } -ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Env* env, int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options) - : ProcessFunctionLibraryRuntime(device_mgr, env, graph_def_version, lib_def, - optimizer_options, - nullptr /* cluster_flr */) {} - -ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Env* env, int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator) - : ProcessFunctionLibraryRuntime( - device_mgr, env, graph_def_version, lib_def, optimizer_options, - std::move(custom_kernel_creator), nullptr /* cluster_flr */) {} - /* static */ Status ProcessFunctionLibraryRuntime::SendTensors( const string& source_device, const string& target_device, @@ -158,7 +141,7 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( } FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( - const string& device_name) { + const string& device_name) const { Device* device = nullptr; if (device_name != kDefaultFLRDevice) { if (!device_mgr_->LookupDevice(device_name, &device).ok()) { @@ -263,7 +246,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseHandle( string target_device; { mutex_lock l(mu_); - CHECK_EQ(1, function_data_.count(handle)); + CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle; target_device = function_data_[handle].target_device; } flr = GetFLR(target_device); @@ -350,4 +333,16 @@ void ProcessFunctionLibraryRuntime::Run( done(errors::Internal("Could not find device")); } +Status ProcessFunctionLibraryRuntime::Clone( + Env* env, int graph_def_version, const OptimizerOptions& optimizer_options, + CustomKernelCreator custom_kernel_creator, + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr) { + out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_)); + out_pflr->reset(new ProcessFunctionLibraryRuntime( + device_mgr_, env, graph_def_version, out_lib_def->get(), + optimizer_options, std::move(custom_kernel_creator), parent_)); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index a1adc4b6b35950339b727774c45014ef71839554..0473e16d242814930a9de17c88d4851d0d73edbe 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -29,12 +29,13 @@ class ProcessFunctionLibraryRuntime { // Creates FunctionLibraryRuntime objects for each device in the provided // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent // (if provided) outlive this object. - ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, - int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - DistributedFunctionLibraryRuntime* parent); + ProcessFunctionLibraryRuntime( + const DeviceMgr* device_mgr, Env* env, int graph_def_version, + const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options, + DistributedFunctionLibraryRuntime* parent = nullptr); + // With `custom_kernel_creator`. ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, @@ -42,17 +43,6 @@ class ProcessFunctionLibraryRuntime { CustomKernelCreator custom_kernel_creator, DistributedFunctionLibraryRuntime* parent); - ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, - int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options); - - ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, - int graph_def_version, - const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator); - // Sends `tensors_to_send` from `source_device` to `target_device` using // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the // Rendezvous. `device_context` should be the DeviceContext of the device @@ -85,7 +75,7 @@ class ProcessFunctionLibraryRuntime { static const char kDefaultFLRDevice[]; // Returns the FunctionLibraryRuntime for the corresponding device_name. - FunctionLibraryRuntime* GetFLR(const string& device_name); + FunctionLibraryRuntime* GetFLR(const string& device_name) const; // Returns the device incarnation for the given device_name. Status GetDeviceIncarnation(const string& device_name, int64* incarnation); @@ -145,6 +135,12 @@ class ProcessFunctionLibraryRuntime { // Removes handle from the state owned by this object. Status RemoveHandle(FunctionLibraryRuntime::Handle handle); + Status Clone(Env* env, int graph_def_version, + const OptimizerOptions& optimizer_options, + CustomKernelCreator custom_kernel_creator, + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr); + friend class FunctionLibraryRuntimeImpl; mutable mutex mu_; diff --git a/tensorflow/core/common_runtime/session_factory.cc b/tensorflow/core/common_runtime/session_factory.cc index 0234d4c37250d8ed3c645759dd17f94093e57df0..4dbe113e44ee0b7a6eba44ace3c1ff8daa17059f 100644 --- a/tensorflow/core/common_runtime/session_factory.cc +++ b/tensorflow/core/common_runtime/session_factory.cc @@ -60,8 +60,8 @@ const string RegisteredFactoriesErrorMessageLocked() { str_util::Join(factory_types, ", "), "}."); } string SessionOptionsToString(const SessionOptions& options) { - return strings::StrCat("target: \"", options.target, "\" config: ", - ProtoShortDebugString(options.config)); + return strings::StrCat("target: \"", options.target, + "\" config: ", ProtoShortDebugString(options.config)); } } // namespace diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index d7e01144c9ef3aa09ddd212947eafe48ccff555b..cb900db10af98496cfdfafa5a38296bfdc4e996b 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -226,22 +226,23 @@ void StepStatsCollector::BuildCostModel( if (node) { for (int i = 0; i < stats.output_size(); ++i) { const auto& output = stats.output(i); - cm->RecordMaxMemorySize(node, i, Bytes(output.tensor_description() - .allocation_description() - .allocated_bytes()), + cm->RecordMaxMemorySize(node, i, + Bytes(output.tensor_description() + .allocation_description() + .allocated_bytes()), stats.output(i).tensor_description().shape(), node->output_types()[i]); - cm->RecordAllocationId(node, i, output.tensor_description() - .allocation_description() - .allocation_id()); + cm->RecordAllocationId(node, i, + output.tensor_description() + .allocation_description() + .allocation_id()); } cm->RecordMemoryStats(node, stats.memory_stats()); // Use hardware stats to record the execution time if they're available, // otherwise use the regular (less accurate) stats string node_name = dev_stats.regular_stats->node_stats(i).node_name(); - if (dev_stats.hardware_stats && - name_to_hw_node_stats.find(node_name) != - name_to_hw_node_stats.end()) { + if (dev_stats.hardware_stats && name_to_hw_node_stats.find(node_name) != + name_to_hw_node_stats.end()) { const NodeExecStats& hw_stats = name_to_hw_node_stats[node_name]; cm->RecordMaxExecutionTime( node, Microseconds(hw_stats.op_end_rel_micros())); diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.cc b/tensorflow/core/common_runtime/sycl/sycl_allocator.cc index 9094824ee734a9398db5aca2a507af4acd07c26b..02bd8b8f3bc692728ce73176f6268d95f860dc9b 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_allocator.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.cc @@ -80,7 +80,7 @@ void SYCLAllocator::ClearStats() override { size_t SYCLAllocator::RequestedSize(void* ptr) { mutex_lock lock(mu_); - if(!sycl_device_) { + if (!sycl_device_) { return 0; } const auto& buffer = sycl_device_->get_sycl_buffer(ptr); diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.h b/tensorflow/core/common_runtime/sycl/sycl_allocator.h index cca9f92c62e2a4f4d57c8a6111b53dccee505f93..550f1933322420fc97da2bb588c719c73ea5ae4d 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_allocator.h +++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.h @@ -20,10 +20,10 @@ limitations under the License. #ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_ #define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_ +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -56,14 +56,13 @@ class SYCLAllocator : public Allocator { // Clear the SYCL device used by the Allocator void ClearSYCLDevice() { mutex_lock lock(mu_); - if(sycl_device_) { + if (sycl_device_) { delete sycl_device_; sycl_device_ = nullptr; } } private: - mutable mutex mu_; Eigen::SyclDevice* sycl_device_ GUARDED_BY(mu_); // owned AllocatorStats stats_ GUARDED_BY(mu_); diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.h b/tensorflow/core/common_runtime/sycl/sycl_device.h index cc272d156ef67a4f4f93f35603ffe301d154932a..7c09e0b8f194c7dc8a594aa487ec62e00d5b5e39 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.h +++ b/tensorflow/core/common_runtime/sycl/sycl_device.h @@ -187,9 +187,9 @@ class GSYCLInterface { type = "Unknown"; } - return strings::StrCat("id: ", device_id, ", type: ", type, ", name: ", - name.c_str(), ", vendor: ", vendor.c_str(), - ", profile: ", profile.c_str()); + return strings::StrCat( + "id: ", device_id, ", type: ", type, ", name: ", name.c_str(), + ", vendor: ", vendor.c_str(), ", profile: ", profile.c_str()); } }; diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc index 19c14770dcad7a3ca045ccb4ff68189c943d8cff..14f7727659d91db2373a1ac8ad0e46258cc32fbe 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc @@ -26,7 +26,6 @@ class SYCLDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions &options, const string &name_prefix, std::vector *devices) override { - auto syclInterface = GSYCLInterface::instance(); size_t n = 1; @@ -37,13 +36,11 @@ class SYCLDeviceFactory : public DeviceFactory { for (int i = 0; i < n; i++) { string name = strings::StrCat(name_prefix, "/device:SYCL:", i); - devices->push_back( - new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality() - , syclInterface->GetShortDeviceDescription(i) - , syclInterface->GetSYCLAllocator(i) - , syclInterface->GetCPUAllocator(i) - , syclInterface->GetSYCLContext(i)) - ); + devices->push_back(new SYCLDevice( + options, name, Bytes(256 << 20), DeviceLocality(), + syclInterface->GetShortDeviceDescription(i), + syclInterface->GetSYCLAllocator(i), syclInterface->GetCPUAllocator(i), + syclInterface->GetSYCLContext(i))); } return Status::OK(); @@ -51,6 +48,6 @@ class SYCLDeviceFactory : public DeviceFactory { }; REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory, 200); -} +} // namespace tensorflow #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/common_runtime/sycl/sycl_util.h b/tensorflow/core/common_runtime/sycl/sycl_util.h index 83016b706a57033bfdaec932f763bc118434db90..3124ed23c92eb542e90e6c077fc703fb84b38a18 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_util.h +++ b/tensorflow/core/common_runtime/sycl/sycl_util.h @@ -20,8 +20,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ -#include "tensorflow/core/common_runtime/device.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/common_runtime/device.h" // For DMA helper #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index a32badef6dfdb8b62662da880c99842b1cafd13c..40cb8353cdccb4307f09b537ff7016e3dca5a8da 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -196,7 +196,10 @@ tf_cc_test( srcs = ["debug_gateway_test.cc"], args = ["--heap_check=local"], linkstatic = tf_kernel_tests_linkstatic(), - tags = ["no_gpu"], + tags = [ + "no_cuda_on_cpu_tap", + "no_gpu", + ], deps = [ ":debug", ":debug_gateway_internal", diff --git a/tensorflow/core/debug/debug_gateway.cc b/tensorflow/core/debug/debug_gateway.cc index 616ced3d0f3d9cfed683120e792b40eb9010fe06..2e1aabd1cc8066df6a5f7e6dd0aa27c6a16ef614 100644 --- a/tensorflow/core/debug/debug_gateway.cc +++ b/tensorflow/core/debug/debug_gateway.cc @@ -24,31 +24,31 @@ limitations under the License. namespace tensorflow { DebugGateway::DebugGateway(DirectSession* session) : session_(session) { - session_->node_outputs_callback_ = [this]( - const string& node_name, const int output_slot, const Tensor* tensor, - const bool is_ref, OpKernelContext* ctx) { - if (comp_cb_ != nullptr && output_slot <= 0) { - // The node completion callback is invoked once for a node regardless - // of whether the node has zero, one or more outputs. - // The output_slot can be negative (-1, or kControlSlot) if - // node_outputs_callback_ is invoked for a node with no output. If that - // is the case, notify the callback that the node in question has no - // output. - comp_cb_(node_name, output_slot == 0); - } - - // Copy tensor values (e.g., from GPU to host) only if the - // value callback is not nullptr. - if (val_cb_ != nullptr && output_slot >= 0) { - CopyTensor( - node_name, output_slot, tensor, ctx, - [this, node_name, output_slot, is_ref](const Tensor* copied_tensor) { - val_cb_(node_name, output_slot, *copied_tensor, is_ref); - }); - } - - return Status::OK(); - }; + session_->node_outputs_callback_ = + [this](const string& node_name, const int output_slot, + const Tensor* tensor, const bool is_ref, OpKernelContext* ctx) { + if (comp_cb_ != nullptr && output_slot <= 0) { + // The node completion callback is invoked once for a node regardless + // of whether the node has zero, one or more outputs. + // The output_slot can be negative (-1, or kControlSlot) if + // node_outputs_callback_ is invoked for a node with no output. If + // that is the case, notify the callback that the node in question has + // no output. + comp_cb_(node_name, output_slot == 0); + } + + // Copy tensor values (e.g., from GPU to host) only if the + // value callback is not nullptr. + if (val_cb_ != nullptr && output_slot >= 0) { + CopyTensor(node_name, output_slot, tensor, ctx, + [this, node_name, output_slot, + is_ref](const Tensor* copied_tensor) { + val_cb_(node_name, output_slot, *copied_tensor, is_ref); + }); + } + + return Status::OK(); + }; } DebugGateway::~DebugGateway() { @@ -86,7 +86,8 @@ void DebugGateway::CopyTensor(const string& node_name, const int output_slot, // Determine if the tensor is on device (GPU) or host (CPU). // The second part of the check is necessary because even an OpKernel on // may have output tensors allocated on CPU. - if ((device->name().find("GPU:") != string::npos || device->name().find("SYCL:") != string::npos) && + if ((device->name().find("GPU:") != string::npos || + device->name().find("SYCL:") != string::npos) && !ctx->output_alloc_attr(output_slot).on_host()) { // GPU tensors: Copy it to host (CPU). DeviceContext* device_ctxt = ctx->op_device_context(); diff --git a/tensorflow/core/debug/debug_gateway_test.cc b/tensorflow/core/debug/debug_gateway_test.cc index 57583349069a0b4deb137cb09564cdbb3909a4b0..b1bbd3f6980b16c13a1e5c9cd3a0f6c4bb8c1217 100644 --- a/tensorflow/core/debug/debug_gateway_test.cc +++ b/tensorflow/core/debug/debug_gateway_test.cc @@ -390,9 +390,9 @@ TEST_F(SessionDebugMinusAXTest, debug_gateway.SetNodeValueCallback( [this, &mu, &val_callback_count, &a_debug_identity_node_name, &x_debug_identity_node_name, &y_debug_identity_node_name, - &debug_identity_tensor_vals, &callbacks_done, &kConcurrentRuns]( - const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { + &debug_identity_tensor_vals, &callbacks_done, + &kConcurrentRuns](const string& node_name, const int output_slot, + const Tensor& tensor_value, const bool is_ref) { mutex_lock l(mu); if (node_name == a_debug_identity_node_name && output_slot == 0) { @@ -560,21 +560,21 @@ TEST_F(SessionDebugOutputSlotWithoutOutgoingEdgeTest, Notification callbacks_done; std::vector debug_identity_tensor_vals; - debug_gateway.SetNodeValueCallback([this, &mu, &callbacks_done, - &debug_identity_node_name, - &debug_identity_tensor_vals]( - const string& node_name, const int output_slot, - const Tensor& tensor_value, const bool is_ref) { - mutex_lock l(mu); + debug_gateway.SetNodeValueCallback( + [this, &mu, &callbacks_done, &debug_identity_node_name, + &debug_identity_tensor_vals]( + const string& node_name, const int output_slot, + const Tensor& tensor_value, const bool is_ref) { + mutex_lock l(mu); - if (node_name == debug_identity_node_name && output_slot == 0) { - debug_identity_tensor_vals.push_back(tensor_value); + if (node_name == debug_identity_node_name && output_slot == 0) { + debug_identity_tensor_vals.push_back(tensor_value); - if (!callbacks_done.HasBeenNotified()) { - callbacks_done.Notify(); - } - } - }); + if (!callbacks_done.HasBeenNotified()) { + callbacks_done.Notify(); + } + } + }); // Add DebugIdentity watch on c:0, which does not have an outgoing edge. RunOptions run_opts; diff --git a/tensorflow/core/debug/debug_grpc_testlib.cc b/tensorflow/core/debug/debug_grpc_testlib.cc index a312f789d8444360a0892faa4b3a0f9a0bdf7a32..f70931e926507c72287588da278a3b8d6bb19122 100644 --- a/tensorflow/core/debug/debug_grpc_testlib.cc +++ b/tensorflow/core/debug/debug_grpc_testlib.cc @@ -30,7 +30,7 @@ namespace test { ::grpc::Status TestEventListenerImpl::SendEvents( ::grpc::ServerContext* context, - ::grpc::ServerReaderWriter< ::tensorflow::EventReply, ::tensorflow::Event>* + ::grpc::ServerReaderWriter<::tensorflow::EventReply, ::tensorflow::Event>* stream) { Event event; diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index f81445c20bd2ba56a6d7d3bb4ddefc71f5199784..baa8c08fdf1508cd599d4c9523b06954280a609d 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -574,8 +574,6 @@ Status DebugIO::CloseDebugURL(const string& debug_url) { } } -static Status CloseDebugURL(const string& debug_url) { return Status::OK(); } - Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key, const Tensor& tensor, const uint64 wall_time_us, diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc index 2f83c2415b831cc1a2b90d4e6a2046218e6fe5f6..0807a85b8b39cf8bf479227bd6b6bd581e2ba9b0 100644 --- a/tensorflow/core/debug/debug_io_utils_test.cc +++ b/tensorflow/core/debug/debug_io_utils_test.cc @@ -57,7 +57,8 @@ class DebugIOUtilsTest : public ::testing::Test { TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) { DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/device:GPU:2", "hidden_1/MatMul", 0, "DebugIdentity"); - EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2", debug_node_key.device_name); + EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2", + debug_node_key.device_name); EXPECT_EQ("hidden_1/MatMul", debug_node_key.node_name); EXPECT_EQ(0, debug_node_key.output_slot); EXPECT_EQ("DebugIdentity", debug_node_key.debug_op); diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index f4ee841032bf2b78b70fd446a6e4679bd9c943f1..9e152aa0823b67fceb7f103cc6e090f00870f88a 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -145,6 +145,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:worker_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/executor_test.cc b/tensorflow/core/distributed_runtime/executor_test.cc index 5b115f9a4d4ea3e9b99228918e16fc354d5a99fe..e34224205bac48a2dba1bf8cb07f9c623cd38281 100644 --- a/tensorflow/core/distributed_runtime/executor_test.cc +++ b/tensorflow/core/distributed_runtime/executor_test.cc @@ -57,7 +57,7 @@ class ExecutorTest : public ::testing::Test { } // Resets executor_ with a new executor based on a graph 'gdef'. - void Create(const Graph* graph) { + void Create(std::unique_ptr graph) { const int version = graph->versions().producer(); LocalExecutorParams params; params.device = device_; @@ -69,7 +69,7 @@ class ExecutorTest : public ::testing::Test { DeleteNonCachedKernel(kernel); }; delete exec_; - TF_CHECK_OK(NewLocalExecutor(params, graph, &exec_)); + TF_CHECK_OK(NewLocalExecutor(params, std::move(graph), &exec_)); runner_ = [this](std::function fn) { thread_pool_->Schedule(fn); }; rendez_ = NewLocalRendezvous(); } @@ -144,12 +144,12 @@ Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation, TEST_F(ExecutorTest, SimpleAdd) { // c = a + b - Graph* g = new Graph(OpRegistry::Global()); - auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB); - auto tmp = test::graph::Add(g, in0, in1); - test::graph::Send(g, tmp, "c", BOB, 1, ALICE); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); + auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB); + auto tmp = test::graph::Add(g.get(), in0, in1); + test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE); + Create(std::move(g)); Rendezvous::Args args; TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false)); // in0 = 1.0 @@ -172,15 +172,15 @@ TEST_F(ExecutorTest, SelfAdd) { // // b <- v10 // All nodes are executed by one thread. - Graph* g = new Graph(OpRegistry::Global()); - auto v = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto v = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); const int N = 10; for (int i = 1; i <= N; ++i) { - v = test::graph::Add(g, v, v); + v = test::graph::Add(g.get(), v, v); } // out <- v10 - test::graph::Send(g, v, "b", BOB, 1, ALICE); - Create(g); + test::graph::Send(g.get(), v, "b", BOB, 1, ALICE); + Create(std::move(g)); Rendezvous::Args args; // a = 1.0 TF_ASSERT_OK( @@ -229,9 +229,9 @@ void BuildTree(int N, Graph* g) { } TEST_F(ExecutorTest, RandomTree) { - Graph* g = new Graph(OpRegistry::Global()); - BuildTree(4096, g); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + BuildTree(4096, g.get()); + Create(std::move(g)); Rendezvous::Args args; TF_ASSERT_OK( rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false)); @@ -262,9 +262,9 @@ void BuildConcurrentAddAssign(Graph* g) { #ifndef THREAD_SANITIZER TEST_F(ExecutorTest, ConcurrentAddAssign) { - Graph* g = new Graph(OpRegistry::Global()); - BuildConcurrentAddAssign(g); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + BuildConcurrentAddAssign(g.get()); + Create(std::move(g)); for (int iters = 0; iters < 16; ++iters) { Rendezvous* rendez = NewLocalRendezvous(); TF_ASSERT_OK(Run(rendez)); @@ -281,12 +281,12 @@ TEST_F(ExecutorTest, ConcurrentAddAssign) { #endif TEST_F(ExecutorTest, SimpleSwitchLive) { - Graph* g = new Graph(OpRegistry::Global()); - auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Constant(g, VB(false)); - auto tmp = test::graph::Switch(g, in0, in1); - test::graph::Send(g, tmp, "c", BOB, 1, ALICE); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); + auto in1 = test::graph::Constant(g.get(), VB(false)); + auto tmp = test::graph::Switch(g.get(), in0, in1); + test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE); + Create(std::move(g)); Rendezvous::Args args; TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false)); // in0 = 1.0 @@ -300,12 +300,12 @@ TEST_F(ExecutorTest, SimpleSwitchLive) { } TEST_F(ExecutorTest, SimpleSwitchDead) { - Graph* g = new Graph(OpRegistry::Global()); - auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Constant(g, VB(true)); - auto tmp = test::graph::Switch(g, in0, in1); - test::graph::Send(g, tmp, "c", BOB, 1, ALICE); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); + auto in1 = test::graph::Constant(g.get(), VB(true)); + auto tmp = test::graph::Switch(g.get(), in0, in1); + test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE); + Create(std::move(g)); Rendezvous::Args args; TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false)); // in0 = 1.0 @@ -319,16 +319,16 @@ TEST_F(ExecutorTest, SimpleSwitchDead) { TEST_F(ExecutorTest, Abort) { // e = a + b + c + d - Graph* g = new Graph(OpRegistry::Global()); - auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB); - auto in2 = test::graph::Recv(g, "c", "float", ALICE, 1, BOB); - auto in3 = test::graph::Recv(g, "d", "float", ALICE, 1, BOB); - auto add0 = test::graph::Add(g, in0, in1); - auto add1 = test::graph::Add(g, in2, in3); - auto add2 = test::graph::Add(g, add0, add1); - test::graph::Send(g, add2, "e", BOB, 1, ALICE); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); + auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB); + auto in2 = test::graph::Recv(g.get(), "c", "float", ALICE, 1, BOB); + auto in3 = test::graph::Recv(g.get(), "d", "float", ALICE, 1, BOB); + auto add0 = test::graph::Add(g.get(), in0, in1); + auto add1 = test::graph::Add(g.get(), in2, in3); + auto add2 = test::graph::Add(g.get(), add0, add1); + test::graph::Send(g.get(), add2, "e", BOB, 1, ALICE); + Create(std::move(g)); // Needs 4 inputs (recv). One of them is aborted. rendez_->Ref(); @@ -371,17 +371,17 @@ TEST_F(ExecutorTest, Abort) { } TEST_F(ExecutorTest, RecvInvalidDtype) { - Graph* g = new Graph(OpRegistry::Global()); + std::unique_ptr g(new Graph(OpRegistry::Global())); // An input vector of type float of size 1. - auto one = test::graph::Recv(g, "one", "float", ALICE, 1, BOB); + auto one = test::graph::Recv(g.get(), "one", "float", ALICE, 1, BOB); // A floating point variable vector of size 1. - auto var = test::graph::Var(g, DT_FLOAT, TensorShape({1})); + auto var = test::graph::Var(g.get(), DT_FLOAT, TensorShape({1})); // Initialize the variable with input. - auto init = test::graph::Assign(g, var, one); + auto init = test::graph::Assign(g.get(), var, one); // Output - auto* two = test::graph::Send(g, var, "two", BOB, 1, ALICE); + auto* two = test::graph::Send(g.get(), var, "two", BOB, 1, ALICE); g->AddControlEdge(init, two); // Ensures run after init. - Create(g); + Create(std::move(g)); Rendezvous* rendez = NewLocalRendezvous(); // Send a double instead of float. TF_ASSERT_OK(rendez->Send(Key(ALICE, 1, BOB, "one"), Rendezvous::Args(), @@ -396,11 +396,11 @@ TEST_F(ExecutorTest, RecvInvalidDtype) { } TEST_F(ExecutorTest, RecvInvalidRefDtype) { - Graph* g = new Graph(OpRegistry::Global()); + std::unique_ptr g(new Graph(OpRegistry::Global())); // A var that always produces as invalid dtype. - auto var = test::graph::InvalidRefType(g, DT_FLOAT, DT_DOUBLE); - test::graph::Send(g, var, "out", BOB, 1, ALICE); - Create(g); + auto var = test::graph::InvalidRefType(g.get(), DT_FLOAT, DT_DOUBLE); + test::graph::Send(g.get(), var, "out", BOB, 1, ALICE); + Create(std::move(g)); Rendezvous* rendez = NewLocalRendezvous(); EXPECT_TRUE(errors::IsInternal(Run(rendez))); Tensor output; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 0120f612ac8bee32999304b1a6f63fff3802606a..7878ebb5f06db0f64e9216250da2a79352274ab3 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -271,7 +271,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, skip_cost_models_ = false; } TF_RETURN_IF_ERROR( - NewLocalExecutor(params, subgraph.release(), &unit->root)); + NewLocalExecutor(params, std::move(subgraph), &unit->root)); } return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index d0ca2a625778ff73c6d40492cc5d02ec81ef3cc6..cc35264b8fe0b6decc325dab793c6a5fe6ad097f 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -140,7 +140,7 @@ class GraphMgr { GraphMgr* graph_mgr; }; - const WorkerEnv* worker_env_; // Not owned. + const WorkerEnv* worker_env_; // Not owned. DeviceMgr* device_mgr_; CostModelManager cost_model_manager_; diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index d1dc622ce79df1a98c3712e447a66bad3baecba1..1a488303ac73b8628b9d3fe4050ad9144724348e 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -528,8 +528,8 @@ void Master::ListDevices(const ListDevicesRequest* req, auto session = FindMasterSession(req->session_handle()); if (session == nullptr) { done(errors::InvalidArgument( - "Session ", req->session_handle(), - " is not found. Possibly, this master has restarted.")); + "Session ", req->session_handle(), + " is not found. Possibly, this master has restarted.")); return; } core::ScopedUnref ref(session); diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index dcc25e4426df386da2543f76239a1468af4bc3d2..878a1398c9d382a4b2018712ca9f9e48c11a9345 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -1448,6 +1448,7 @@ Status MasterSession::DoPartialRun(CallOptions* opts, const auto count = run_state->count; pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE; + pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE; pss.report_tensor_allocations_upon_oom = req.options().report_tensor_allocations_upon_oom(); @@ -1610,6 +1611,7 @@ Status MasterSession::DoRunWithLocalExecution( TRACEPRINTF("stepid %llu", step_id); pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE; + pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE; pss.report_tensor_allocations_upon_oom = req.options().report_tensor_allocations_upon_oom(); // Build the cost model every 'build_cost_model_every' steps after skipping an diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc index 121c58762f10a87fea059ce43b190f70e49e1f64..f2c1f3489c388d6a5fff729b1c8f98136532105c 100644 --- a/tensorflow/core/distributed_runtime/master_test.cc +++ b/tensorflow/core/distributed_runtime/master_test.cc @@ -61,7 +61,7 @@ class MasterTest : public ::testing::Test { // rpc calls. Status CreateSession(const GraphDef& def, string* handle, - int64* initial_version) { + int64* initial_version) { ::grpc::ClientContext ctx; CreateSessionRequest req; *(req.mutable_graph_def()) = def; @@ -77,7 +77,7 @@ class MasterTest : public ::testing::Test { } Status ExtendSession(const string& handle, const GraphDef& def, - int64 current_version, int64* new_version) { + int64 current_version, int64* new_version) { ::grpc::ClientContext ctx; ExtendSessionRequest req; req.set_session_handle(handle); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc index 7efc0ba6d8510fb0d462df13f7b3ebf68e939313..613188244fcb196a2bca7307d536a652a0f7f551 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc @@ -157,7 +157,7 @@ class MultiGrpcChannelCache : public CachingGrpcChannelCache { } } - void ListWorkers(std::vector* workers) const override { + void ListWorkers(std::vector* workers) override { for (GrpcChannelCache* cache : caches_) { cache->ListWorkers(workers); } @@ -216,7 +216,7 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache { } ~SparseGrpcChannelCache() override {} - void ListWorkers(std::vector* workers) const override { + void ListWorkers(std::vector* workers) override { workers->reserve(workers->size() + host_ports_.size()); for (const auto& id_host_port : host_ports_) { workers->emplace_back(MakeAddress(job_id_, id_host_port.first)); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h index de9840fca8c312bffefea501522210cafc2af82e..48b9d958aa921b0e758fc17a0f4da7c3a13e6c16 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h @@ -65,7 +65,7 @@ class GrpcChannelCache { // was created to handle. Worker names are in the format // /job:/task: // e.g. /job:mnist/task:2 - virtual void ListWorkers(std::vector* workers) const = 0; + virtual void ListWorkers(std::vector* workers) = 0; // If found, returns a gRPC channel that is connected to the remote // worker named by 'target'. 'target' is of the following diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index ac279937730466514451d7e81257d2110e128eff..b4d18d8607eaddd75f4e395e71fbd75554645a61 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -185,23 +185,22 @@ class GrpcMasterService : public AsyncServiceInterface { MutableRunStepResponseWrapper* wrapped_response = new NonOwnedProtoRunStepResponse(&call->response); call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - master_impl_->RunStep(call_opts, wrapped_request, wrapped_response, - [call, call_opts, wrapped_request, wrapped_response, - trace](const Status& status) { - call->ClearCancelCallback(); - delete call_opts; - delete wrapped_request; - delete trace; - if (call->request.store_errors_in_response_body() && - !status.ok()) { - call->response.set_status_code(status.code()); - call->response.set_status_error_message( - status.error_message()); - call->SendResponse(ToGrpcStatus(Status::OK())); - } else { - call->SendResponse(ToGrpcStatus(status)); - } - }); + master_impl_->RunStep( + call_opts, wrapped_request, wrapped_response, + [call, call_opts, wrapped_request, wrapped_response, + trace](const Status& status) { + call->ClearCancelCallback(); + delete call_opts; + delete wrapped_request; + delete trace; + if (call->request.store_errors_in_response_body() && !status.ok()) { + call->response.set_status_code(status.code()); + call->response.set_status_error_message(status.error_message()); + call->SendResponse(ToGrpcStatus(Status::OK())); + } else { + call->SendResponse(ToGrpcStatus(status)); + } + }); ENQUEUE_REQUEST(RunStep, true); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h index 4e203e260a1a370cc2bc7e40c3ce9e84da4d3ad4..6ae94b74417c3fb6c4da1589bb9f532cb6d79930 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h @@ -89,9 +89,9 @@ class MasterService final { ::grpc::Status ExtendSession(::grpc::ClientContext* context, const ExtendSessionRequest& request, ExtendSessionResponse* response) override; - ::grpc::Status PartialRunSetup( - ::grpc::ClientContext* context, const PartialRunSetupRequest& request, - PartialRunSetupResponse* response) override; + ::grpc::Status PartialRunSetup(::grpc::ClientContext* context, + const PartialRunSetupRequest& request, + PartialRunSetupResponse* response) override; ::grpc::Status RunStep(::grpc::ClientContext* context, const RunStepRequest& request, RunStepResponse* response) override; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index 70418f63686843414dca6c5ae4907ee263dc2904..1088e9be66ceb7fbddfaed0691423745f362343f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -69,8 +69,7 @@ class GrpcRemoteMaster : public MasterInterface { ::grpc::ClientContext ctx; auto trace = TraceRpc("RunStep/Client", &ctx); return Call(&ctx, call_options, &request->ToProto(), - get_proto_from_wrapper(response), - &MasterServiceStub::RunStep); + get_proto_from_wrapper(response), &MasterServiceStub::RunStep); } Status CloseSession(CallOptions* call_options, @@ -114,8 +113,9 @@ class GrpcRemoteMaster : public MasterInterface { template Status Call(::grpc::ClientContext* ctx, CallOptions* call_options, const Request* request, Response* response, - ::grpc::Status (MasterServiceStub::*pfunc)( - ::grpc::ClientContext*, const Request&, Response*)) { + ::grpc::Status (MasterServiceStub::*pfunc)(::grpc::ClientContext*, + const Request&, + Response*)) { ctx->set_fail_fast(false); SetDeadline(ctx, call_options->GetTimeout()); return FromGrpcStatus((stub_.get()->*pfunc)(ctx, *request, response)); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h b/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h index dd114d39c62f6b69a3fb9ea4401459f963137a1f..730124c25e9a3e8d102a9dd39e4c4a17f2ce39d1 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h @@ -66,7 +66,7 @@ class GrpcBufferWriter final } // It's dangerous to keep an inlined grpc_slice as the backup slice, since // on a following Next() call, a reference will be returned to this slice - // via GRPC_SLICE_START_PTR, which will not be an adddress held by + // via GRPC_SLICE_START_PTR, which will not be an address held by // slice_buffer_. have_backup_ = backup_slice_.refcount != NULL; byte_count_ -= count; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc index 373eecffcab1dded60de7ffea96ba58208bb692c..5597ee7a76a55f125dd0db82eceb58f5e922ab13 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc @@ -21,11 +21,8 @@ namespace tensorflow { namespace test { // ErrorOp::Compute returns an error. -REGISTER_OP("Error") - .Input("in: T") - .Output("out: T") - .Attr("T: type") - .Attr("message: string"); +REGISTER_OP("Error").Input("in: T").Output("out: T").Attr("T: type").Attr( + "message: string"); class ErrorOp : public OpKernel { public: explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -66,11 +63,8 @@ REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU), // DelayOp::AsyncCompute sleeps for "micros"-econd and then returns // its input. -REGISTER_OP("Delay") - .Input("in: T") - .Output("out: T") - .Attr("T: type") - .Attr("micros: int"); +REGISTER_OP("Delay").Input("in: T").Output("out: T").Attr("T: type").Attr( + "micros: int"); class DelayOp : public AsyncOpKernel { public: explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc index bb14e0197b7b0ea44c4a75528f4919045574f4c5..2ed07e3669a3badd82b8ef27f45bac2b712c8978 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -34,7 +34,7 @@ namespace { class GrpcWorkerCache : public WorkerCachePartial { public: // TODO(ncteisen): consider adding a config var or flag for this - static constexpr const size_t kGrpcWorkerCacheThreadCount = 8; + static constexpr const size_t kGrpcWorkerCacheThreadCount = 2; explicit GrpcWorkerCache(GrpcChannelCache* channel_cache, WorkerInterface* local_worker, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 95811476f789be0225231f86aa0242db71b81199..1beb198732ad40ed9e21f66c665ff82a231eebb6 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -52,7 +52,7 @@ namespace { class GrpcWorkerService : public AsyncServiceInterface { // TODO(ncteisen): consider adding a config var or flag for this - static constexpr const size_t kGrpcWorkerServiceThreadCount = 8; + static constexpr const size_t kGrpcWorkerServiceThreadCount = 2; public: GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder) @@ -444,6 +444,24 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, }); } +void GrpcWorker::LoggingAsync(const LoggingRequest* request, + LoggingResponse* response, StatusCallback done) { + auto env = this->env(); + if (env) { + auto session_mgr = (SessionMgr*)env->session_mgr; + if (session_mgr) { + session_mgr->SetLogging(request->rpc_logging()); + for (const auto& step_id : request->fetch_step_id()) { + session_mgr->RetrieveLogs(step_id, response); + } + if (request->clear()) { + session_mgr->ClearLogs(); + } + } + } + done(Status::OK()); +} + WorkerEnv* GrpcWorker::env() { return env_; } std::unique_ptr NewGrpcWorker(WorkerEnv* env) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h index 78a21fd9f6ecb6deac171bb5c4a16fa074988fa2..fbddbda9e6f9e5561d4db0e035a48ed8db0d8559 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -40,6 +40,9 @@ class GrpcWorker : public Worker { ::grpc::ByteBuffer* response, StatusCallback done); + virtual void LoggingAsync(const LoggingRequest* request, + LoggingResponse* response, StatusCallback done); + WorkerEnv* env(); private: diff --git a/tensorflow/core/distributed_runtime/rpcbench_test.cc b/tensorflow/core/distributed_runtime/rpcbench_test.cc index b2668fae25a8a6bc60b37ddfaa83b8b523c3a6f5..d3af7417e61105c788b8029c84c222e49a0d2830 100644 --- a/tensorflow/core/distributed_runtime/rpcbench_test.cc +++ b/tensorflow/core/distributed_runtime/rpcbench_test.cc @@ -184,8 +184,8 @@ static void BM_Helper(int iters, int width, int num_stages, int tensor_size, testing::SetLabel( strings::StrCat(def.node_size(), " nodes; ", - use_multiple_devices ? "Multi device" : "Single device", - "; tensor bytes/send: ", tensor_size * sizeof(float))); + use_multiple_devices ? "Multi device" : "Single device", + "; tensor bytes/send: ", tensor_size * sizeof(float))); std::vector outputs; diff --git a/tensorflow/core/distributed_runtime/scheduler.cc b/tensorflow/core/distributed_runtime/scheduler.cc index 4766f4c33b654481f7d99ab82939e33e77564771..9dae5b3b926fab14c2b36955436d3956baa29fdd 100644 --- a/tensorflow/core/distributed_runtime/scheduler.cc +++ b/tensorflow/core/distributed_runtime/scheduler.cc @@ -17,9 +17,9 @@ limitations under the License. #include -#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/util/util.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/scheduler.h b/tensorflow/core/distributed_runtime/scheduler.h index eabcaccdd1e6c1a732f8871bc9da6265bd9a8dd8..ef87b9834dba50cf628a8c29c70b0266661d6227 100644 --- a/tensorflow/core/distributed_runtime/scheduler.h +++ b/tensorflow/core/distributed_runtime/scheduler.h @@ -16,15 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ -#include #include +#include #include #include #include -#include "tensorflow/core/graph/costmodel.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/graph/costmodel.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index 8db49e7f151517a51de1f64242031a8bd9bd96e6..51b9547f53ba687c863b0fd11647e7bb82d80e03 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -43,8 +43,8 @@ SessionMgr::SessionMgr( worker_cache_factory_(std::move(worker_cache_factory)) {} string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { - return strings::StrCat("/job:", server_def.job_name(), - "/replica:0/task:", server_def.task_index()); + return strings::StrCat("/job:", server_def.job_name(), "/replica:0/task:", + server_def.task_index()); } Status SessionMgr::CreateSession(const string& session, @@ -64,8 +64,13 @@ Status SessionMgr::CreateSession(const string& session, TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache)); } + if (worker_cache != nullptr & default_worker_cache_.get() != nullptr) { + worker_cache->SetLogging(this->is_logging_active_); + } + CHECK(!worker_env_->local_devices.empty()) << "The WorkerEnv must have at least one device in `local_devices`."; + std::vector renamed_devices; for (Device* d : worker_env_->local_devices) { renamed_devices.push_back(RenamedDevice::NewRenamedDevice( @@ -113,4 +118,77 @@ std::shared_ptr SessionMgr::LegacySession() { return legacy_session_; } +void SessionMgr::SetLogging(bool active) { + mutex_lock l(mu_); + this->is_logging_active_ = active; + // Legacy Session + if (legacy_session_) { + auto* worker_cache = legacy_session_->worker_cache.get(); + if (worker_cache) { + worker_cache->SetLogging(active); + } + } + + for (const auto& session_kv : sessions_) { + auto session = session_kv.second.get(); + if (session) { + auto* worker_cache = session->worker_cache.get(); + if (worker_cache) { + worker_cache->SetLogging(active); + } + } + } +} + +void SessionMgr::RetrieveLogs(tensorflow::int64 step_id, + LoggingResponse* response) { + mutex_lock l(mu_); + // Legacy Session + if (legacy_session_) { + auto* worker_cache = legacy_session_->worker_cache.get(); + if (worker_cache) { + auto step_stats = StepStats(); + if (worker_cache->RetrieveLogs(step_id, &step_stats)) { + auto* labeled_step_stats = response->add_step(); + labeled_step_stats->set_step_id(step_id); + labeled_step_stats->mutable_step_stats()->Swap(&step_stats); + } + } + } + for (const auto& session_kv : sessions_) { + auto session = session_kv.second.get(); + if (session) { + auto* worker_cache = session->worker_cache.get(); + if (worker_cache) { + auto step_stats = StepStats(); + if (worker_cache->RetrieveLogs(step_id, &step_stats)) { + auto* labeled_step_stats = response->add_step(); + labeled_step_stats->set_step_id(step_id); + labeled_step_stats->mutable_step_stats()->Swap(&step_stats); + } + } + } + } +} + +void SessionMgr::ClearLogs() { + mutex_lock l(mu_); + // Legacy Session + if (legacy_session_) { + auto* worker_cache = legacy_session_->worker_cache.get(); + if (worker_cache) { + worker_cache->ClearLogs(); + } + } + + for (const auto& session_kv : sessions_) { + auto session = session_kv.second.get(); + if (session) { + auto* worker_cache = session->worker_cache.get(); + if (worker_cache) { + worker_cache->ClearLogs(); + } + } + } +} } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index 3ce260d12e92e3458fe12f3f5b5723f9c39b5f4b..4c9702d522cede454d5efd15669eaec2b0c1c1b1 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" +#include "tensorflow/core/protobuf/worker.pb.h" namespace tensorflow { @@ -56,6 +57,12 @@ class SessionMgr { static string WorkerNameFromServerDef(const ServerDef& server_def); + void SetLogging(bool active); + + void RetrieveLogs(tensorflow::int64 step_id, LoggingResponse* response); + + void ClearLogs(); + private: const WorkerEnv* const worker_env_; // Not owned. @@ -75,6 +82,8 @@ class SessionMgr { std::unique_ptr default_worker_cache_; std::shared_ptr legacy_session_; + bool is_logging_active_ = false; + const WorkerCacheFactory worker_cache_factory_; std::shared_ptr WorkerSessionForSessionUnlocked( diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc index fe2d1a12934dde814344b70f52fbc972f74347e0..34a4013547b5feef12b49198bff4e733f1b9e932 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/tensor_coding.cc @@ -81,7 +81,7 @@ void TensorResponse::InitPartial(const RecvTensorResponse& response) { Status TensorResponse::ParseFrom(Source* source) { if (!on_host_) { protobuf::io::CodedInputStream input(source->contents()); - input.SetTotalBytesLimit(INT_MAX, INT_MAX); // Unlimited + input.SetTotalBytesLimit(INT_MAX); // Unlimited // Pre-parse into local storage, then delegate to device. if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) { @@ -217,7 +217,7 @@ bool TensorResponse::ParseTensorSubmessage( bool TensorResponse::ParseFast(Source* source) { protobuf::io::CodedInputStream input(source->contents()); - input.SetTotalBytesLimit(INT_MAX, INT_MAX); // Unlimited + input.SetTotalBytesLimit(INT_MAX); // Unlimited while (true) { auto p = input.ReadTagWithCutoff(127); int tag = GetTagFieldNumber(p.first); diff --git a/tensorflow/core/distributed_runtime/worker_cache_logger.cc b/tensorflow/core/distributed_runtime/worker_cache_logger.cc index 702af78c88014d54fe2f72a8266e5e7e43b3cfb9..95ca3c3b4d11fac0d103eb52f19d5b0b2f4ad3ea 100644 --- a/tensorflow/core/distributed_runtime/worker_cache_logger.cc +++ b/tensorflow/core/distributed_runtime/worker_cache_logger.cc @@ -97,9 +97,8 @@ void WorkerCacheLogger::RecordDataTransfer(int64 step_id, int64 start_usecs, const string& tensor_name, const string& src_device, const string& dst_device, - int64 bytes, - const string& details, - const string& transfer_method_name){ + int64 bytes, const string& details, + const string& transfer_method_name) { NodeExecStats* ns = new NodeExecStats; ns->set_node_name(transfer_method_name); if (details.empty()) { diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h index 4e9352ee32227376957157c7ada63390689ac39a..d977935b8a392adf1f78c38955f77f6f364502c9 100644 --- a/tensorflow/core/example/feature_util.h +++ b/tensorflow/core/example/feature_util.h @@ -56,9 +56,9 @@ limitations under the License. // // To add values to feature_lists: // AppendFeatureValues({4.0}, -// GetFeatureList("movie_ratings", &se)->Add()); +// GetFeatureList("images", &se)->Add()); // AppendFeatureValues({5.0, 3.0}, -// GetFeatureList("movie_ratings", &se)->Add()); +// GetFeatureList("images", &se)->Add()); // This will create a feature list keyed as "images" with two features: // feature_lists { // feature_list { diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 2bd19663fc6f45aeae857a5bdd69bf41d5a94bd4..94bf34afa49f586e1bb61c1654865a5abc9abe19 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -113,7 +113,7 @@ class CPUAllocator : public Allocator { stats_.max_alloc_size = 0; } - size_t AllocatedSizeSlow(void* ptr) override { + size_t AllocatedSizeSlow(const void* ptr) override { return port::MallocExtension_GetAllocatedSize(ptr); } diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 5a95d3a15d1699e518e16cd300bccfb7a40ab50f..3ce1b612464291eceb6e08d9b0f2deca70cda27a 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -156,7 +156,7 @@ class Allocator { // // REQUIRES: 'ptr!=nullptr' and points to a buffer previously // allocated by this allocator. - virtual size_t RequestedSize(void* ptr) { + virtual size_t RequestedSize(const void* ptr) { CHECK(false) << "allocator doesn't track sizes"; return size_t(0); } @@ -169,7 +169,7 @@ class Allocator { // // REQUIRES: 'ptr!=nullptr' and points to a buffer previously // allocated by this allocator. - virtual size_t AllocatedSize(void* ptr) { return RequestedSize(ptr); } + virtual size_t AllocatedSize(const void* ptr) { return RequestedSize(ptr); } // Returns either 0 or an identifier assigned to the buffer at 'ptr' // when the buffer was returned by AllocateRaw. If non-zero, the @@ -180,7 +180,7 @@ class Allocator { // // REQUIRES: 'ptr!=nullptr' and points to a buffer previously // allocated by this allocator. - virtual int64 AllocationId(void* ptr) { return 0; } + virtual int64 AllocationId(const void* ptr) { return 0; } // Returns the allocated size of the buffer at 'ptr' if known, // otherwise returns 0. This method can be called when @@ -188,7 +188,7 @@ class Allocator { // // REQUIRES: 'ptr!=nullptr' and points to a buffer previously // allocated by this allocator. - virtual size_t AllocatedSizeSlow(void* ptr) { + virtual size_t AllocatedSizeSlow(const void* ptr) { if (TracksAllocationSizes()) { return AllocatedSize(ptr); } @@ -312,17 +312,19 @@ class AllocatorWrapper : public Allocator { return wrapped_->TracksAllocationSizes(); } - size_t RequestedSize(void* ptr) override { + size_t RequestedSize(const void* ptr) override { return wrapped_->RequestedSize(ptr); } - size_t AllocatedSize(void* ptr) override { + size_t AllocatedSize(const void* ptr) override { return wrapped_->AllocatedSize(ptr); } - int64 AllocationId(void* ptr) override { return wrapped_->AllocationId(ptr); } + int64 AllocationId(const void* ptr) override { + return wrapped_->AllocationId(ptr); + } - size_t AllocatedSizeSlow(void* ptr) override { + size_t AllocatedSizeSlow(const void* ptr) override { return wrapped_->AllocatedSizeSlow(ptr); } diff --git a/tensorflow/core/framework/bfloat16.cc b/tensorflow/core/framework/bfloat16.cc index 0efe43fde2dadd42aa03d3bf2968d2cbfb113e8d..6025be517048d33b20f7af15ef7ad1339adebdf9 100644 --- a/tensorflow/core/framework/bfloat16.cc +++ b/tensorflow/core/framework/bfloat16.cc @@ -21,13 +21,13 @@ void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) { const uint16_t* p = reinterpret_cast(src); uint16_t* q = reinterpret_cast(dst); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - for (; size != 0; p += 2, q++, size--) { - *q = p[0]; - } + for (; size != 0; p += 2, q++, size--) { + *q = p[0]; + } #else - for (; size != 0; p += 2, q++, size--) { - *q = p[1]; - } + for (; size != 0; p += 2, q++, size--) { + *q = p[1]; + } #endif } @@ -35,15 +35,15 @@ void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) { const uint16_t* p = reinterpret_cast(src); uint16_t* q = reinterpret_cast(dst); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - for (; size != 0; p++, q += 2, size--) { - q[0] = *p; - q[1] = 0; - } + for (; size != 0; p++, q += 2, size--) { + q[0] = *p; + q[1] = 0; + } #else - for (; size != 0; p++, q += 2, size--) { - q[0] = 0; - q[1] = *p; - } + for (; size != 0; p++, q += 2, size--) { + q[0] = 0; + q[1] = *p; + } #endif } diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 7ab8e3ec188a223e35b47b6f9517abd9327b23f8..623248b6ce6adff8ed323acad7dae300742f8eba 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -49,7 +49,11 @@ Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size, break; } if (*output_size < 0) { - return errors::InvalidArgument("computed output size would be negative"); + return errors::InvalidArgument( + "Computed output size would be negative: ", *output_size, + " [input_size: ", input_size, + ", effective_filter_size: ", effective_filter_size, + ", stride: ", stride, "]"); } return Status::OK(); } @@ -1356,10 +1360,11 @@ Status ScatterNdUpdateShape(InferenceContext* c) { Status s = c->Merge(prefix_indices, prefix_updates, &unused); if (!s.ok()) { return errors::InvalidArgument( - "The outer ", num_outer_dims, " dimensions of indices.shape=", - c->DebugString(indices_shape), " must match the outer ", - num_outer_dims, " dimensions of updates.shape=", - c->DebugString(updates_shape), ": ", s.error_message()); + "The outer ", num_outer_dims, + " dimensions of indices.shape=", c->DebugString(indices_shape), + " must match the outer ", num_outer_dims, + " dimensions of updates.shape=", c->DebugString(updates_shape), + ": ", s.error_message()); } ShapeHandle input_suffix; diff --git a/tensorflow/core/kernels/data/dataset.cc b/tensorflow/core/framework/dataset.cc similarity index 99% rename from tensorflow/core/kernels/data/dataset.cc rename to tensorflow/core/framework/dataset.cc index 2ea6875567604e4e5bf7c990ad6a42ed8c5dafaa..4145ef7bc9d22632db3d0a71f8901a671dd95ee5 100644 --- a/tensorflow/core/kernels/data/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/framework/dataset.h" + #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/node_builder.h" diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 2c2c7e7c585c9364e1d08280d5fe76f1bf1eff23..6ab23d92a421df8b5fb9bcf637ad805d67577aa1 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,64 +12,605 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ -#ifndef TENSORFLOW_FRAMEWORK_DATASET_H_ -#define TENSORFLOW_FRAMEWORK_DATASET_H_ +#include + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/tracing.h" + +// Polymorphic datasets should support all primitive TensorFlow +// types. Use this macro to expand `m(T)` once for each primitive type +// `T`, e.g. to build a `switch` statement. +#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m) namespace tensorflow { -namespace dataset { -// Registry for stateful ops that need to be used in dataset functions. -// See below macro for usage details. -class WhitelistedStatefulOpRegistry { + +// Interface for reading values from a key-value store. +// Used for restoring iterator state. +class IteratorStateReader { + public: + virtual Status ReadScalar(StringPiece key, int64* val) = 0; + virtual Status ReadScalar(StringPiece key, string* val) = 0; + virtual Status ReadTensor(StringPiece key, Tensor* val) = 0; + virtual bool Contains(StringPiece key) = 0; + + virtual ~IteratorStateReader() {} +}; + +// Interface for writing values to a key-value store. +// Used for saving iterator state. +class IteratorStateWriter { + public: + virtual Status WriteScalar(StringPiece key, const int64 val) = 0; + virtual Status WriteScalar(StringPiece key, const string& val) = 0; + virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0; + + virtual ~IteratorStateWriter() {} +}; + +// Forward declarations to avoid introducing a dependency on headers in +// "tensorflow/core/graph/...". +class GraphDefBuilder; +class GraphDatasetBase; +class Node; + +// Wrapper around GraphDefBuilder. Used to serialize Dataset graph. +class GraphDefBuilderWrapper { + public: + explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {} + + // Adds a Const node with scalar value to the Graph. + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. + // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + template + Status AddScalar(const T& val, Node** output) { + Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); + val_t.scalar()() = val; + AddTensorInternal(val_t, output); + if (*output == nullptr) { + return errors::Internal("AddScalar: Failed to build Const op."); + } + return Status::OK(); + } + + // Adds a Const node with vector value to the Graph. + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. + // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice? + template + Status AddVector(const std::vector& val, Node** output) { + Tensor val_t = Tensor(DataTypeToEnum::v(), + TensorShape({static_cast(val.size())})); + for (int i = 0; i < val.size(); i++) { + val_t.flat()(i) = val[i]; + } + AddTensorInternal(val_t, output); + if (*output == nullptr) { + return errors::Internal("AddVector: Failed to build Const op."); + } + return Status::OK(); + } + + // Adds a Const node with Tensor value to the Graph. + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. + // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + Status AddTensor(const Tensor& val, Node** output) { + AddTensorInternal(val, output); + if (*output == nullptr) { + return errors::Internal("AddTensor: Failed to build Const op."); + } + return Status::OK(); + } + + Status AddDataset(const GraphDatasetBase* dataset, + const std::vector& inputs, Node** output) { + return AddDataset(dataset, inputs, {}, output); + } + + // Adds a node corresponding to the `DatasetType` to the Graph. + // Return value of `DatasetType::op_name()` is used as the op type for the + // node. + // Values for the output_types and output_shapes node attributes are also + // written if those attributes are defined in the OpDef. + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. + // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + Status AddDataset(const GraphDatasetBase* dataset, + const std::vector& inputs, + const std::vector>& attrs, + Node** output) { + std::vector> enumerated_inputs(inputs.size()); + for (int i = 0; i < inputs.size(); i++) { + enumerated_inputs[i] = std::make_pair(i, inputs[i]); + } + return AddDataset(dataset, enumerated_inputs, {}, attrs, output); + } + + Status AddDataset( + const GraphDatasetBase* dataset, + const std::vector>& inputs, + const std::vector>>& list_inputs, + const std::vector>& attrs, + Node** output); + + // Adds a user-defined function with name `function_name` to the graph and + // recursively adds all functions it references. If a function with a matching + // name has already been added, returns with OK status. If a user-defined with + // name `function_name` is not found in the FunctionLibraryDefinition, returns + // an InvalidArgumentError. If the function with name `function_name` or any + // of its dependent functions are stateful, returns an InvalidArgument error. + Status AddFunction(OpKernelContext* ctx, const string& function_name); + + template + void BuildAttrValue(const T& value, AttrValue* attr) { + SetAttrValue(value, attr); + } + + private: + void AddTensorInternal(const Tensor& val, Node** output); + + Status EnsureFunctionIsStateless(OpKernelContext* ctx, + const string& function_name) const { + const FunctionLibraryDefinition* lib_def = + ctx->function_library()->GetFunctionLibraryDefinition(); + const FunctionDef* function_def = lib_def->Find(function_name); + if (!function_def) { + return errors::InvalidArgument("Unable to find FunctionDef for ", + function_name, " in registry."); + } + for (const NodeDef& node_def : function_def->node_def()) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def)); + // TODO(b/65524810): Hack to allow functions to capture Dataset op + // nodes needed for FlatMap. Currently, source datasets nodes have been + // marked stateful to avoid constant folding since we do not have a + // good way of serializing them. + if (IsOpWhitelisted(op_def)) { + continue; + } + if (op_def->is_stateful()) { + return errors::InvalidArgument( + "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ", + "in function ", function_name, " is stateful. ", + "Saving stateful functions is not supported yet."); + } + } + return Status::OK(); + } + + // Returns whether an op has been whitelisted for use inside map_fns. + // Uses a heuristic to whitelist source dataset ops which have been + // marked stateful due to b/65524810. + // Also looks up the `op_def->name` in the global + // `WhitelistedStatefulOpRegistry`. + bool IsOpWhitelisted(const OpDef* op_def) const { + return (StringPiece(op_def->name()).ends_with("Dataset") && + op_def->output_arg_size() == 1 && + op_def->output_arg(0).type() == DT_VARIANT) || + dataset::WhitelistedStatefulOpRegistry::Global()->Contains( + op_def->name()); + } + + bool HasAttr(const string& op_type_name, const string& attr_name) const; + + bool HasAttr(const OpDef* op_def, const string& attr_name) const { + for (auto attr : op_def->attr()) { + if (attr.name() == attr_name) { + return true; + } + } + return false; + } + + Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) { + if (attr_value.has_func()) { + TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name())); + } else if (attr_value.has_list()) { + for (const NameAttrList& name_attr_list : attr_value.list().func()) { + TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name())); + } + } + return Status::OK(); + } + + GraphDefBuilder* b_; +}; + +class StatsAggregator; + +// A cut-down version of OpKernelContext for running computations in +// iterators. Note that we cannot simply use OpKernelContext here +// because we might run computation in an iterator whose lifetime is +// not nested within the lifetime of a single OpKernelContext +// (e.g. asynchronous prefetching). +// +// TODO(mrry): We will probably need to support more of +// OpKernelContext here. For example, should allocation be handled by +// the IteratorContext? +// TODO(mrry): We're making some daring assumptions about the lifetime +// of the runner passed in here. A runner will be deleted when the original +// step ends, but all existing runners only close over session-lifetime (or +// longer-lived) state, so we can make a copy of the function. There's nothing +// in the definition of the API from which we took the runner to guarantee that +// what we are doing is safe. We should formalize the properties here. +class IteratorContext { + public: + struct Params { + // Interface to operating system functionality. + Env* env; + + // Function call support. + std::function)> runner = nullptr; + + // A function that returns the current `StatsAggregator` instance to be + // used when recording statistics about the iterator. + // + // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator` + // is a property of the `IteratorResource` (which this class does not know + // about), and (ii) it can change after the `IteratorContext` has been + // created. Better suggestions are welcome! + std::function()> stats_aggregator_getter = + nullptr; + + // The FunctionLibraryRuntime object to be used to make function calls. + FunctionLibraryRuntime* lib = nullptr; + std::shared_ptr function_library = nullptr; + + // The Allocator to be used to allocate the output of an iterator. + std::function allocator_getter = nullptr; + }; + + explicit IteratorContext(Params params) : params_(std::move(params)) {} + + Env* env() const { return params_.env; } + + std::function)>* runner() { + return ¶ms_.runner; + } + + std::shared_ptr stats_aggregator() { + if (params_.stats_aggregator_getter) { + return params_.stats_aggregator_getter(); + } else { + return nullptr; + } + } + + std::shared_ptr function_library() { + return params_.function_library; + } + + FunctionLibraryRuntime* lib() { return params_.lib; } + + void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; } + + Allocator* allocator(AllocatorAttributes attrs) { + return params_.allocator_getter(attrs); + } + + private: + Params params_; +}; + +// Represents the current position in a range of outputs, where the +// range of outputs is typically represented by an `DatasetBase`, +// defined below. +class IteratorBase { + public: + virtual ~IteratorBase() {} + + // Gets the next output from the range that this iterator is traversing. + // + // If at least one output remains in this iterator's range, that + // output will be stored in `*out_tensors` and `false` will be + // stored in `*end_of_sequence`. + // + // If no more outputs remain in this iterator's range, `true` will + // be stored in `*end_of_sequence`, and the content of + // `*out_tensors` will be undefined. + // + // This method is thread-safe. + // + // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and + // potentially remove this method. + virtual Status GetNext(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) = 0; + + // Returns a vector of DataType values, representing the respective + // element types of each tuple component in the outputs of this + // iterator. + virtual const DataTypeVector& output_dtypes() const = 0; + + // Returns a vector of tensor shapes, representing the respective + // (and possibly partially defined) shapes of each tuple component + // in the outputs of this iterator. + virtual const std::vector& output_shapes() const = 0; + + // Saves the state of this iterator. + virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { + return SaveInternal(writer); + } + + // Restores the state of this iterator. + virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) { + return RestoreInternal(ctx, reader); + } + + protected: + // This is needed so that sub-classes of IteratorBase can call + // `SaveInternal` on their parent iterators, e.g., in + // `RepeatDataasetOp::Dataset`. + Status SaveParent(IteratorStateWriter* writer, + const std::unique_ptr& parent) { + return parent->SaveInternal(writer); + } + + // This is needed so that sub-classes of IteratorBase can call + // `RestoreInternal` on their parent iterators, e.g., in + // `RepeatDataasetOp::Dataset`. + Status RestoreParent(IteratorContext* ctx, IteratorStateReader* reader, + const std::unique_ptr& parent) { + return parent->RestoreInternal(ctx, reader); + } + + // Saves the state of this iterator recursively. + virtual Status SaveInternal(IteratorStateWriter* writer) { + return errors::Unimplemented("SaveInternal"); + } + + // Restores the state of this iterator recursively. + virtual Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) { + return errors::Unimplemented("RestoreInternal"); + } +}; + +// Represents a (potentially infinite) range of outputs, where each +// output is a tuple of tensors. +class DatasetBase : public core::RefCounted { public: - Status Add(StringPiece op_name) { - op_names_.insert(op_name); + // Returns a new iterator for iterating over the range of elements in + // this dataset. + // + // This method may be called multiple times on the same instance, + // and the resulting iterators will have distinct state. Each + // iterator will traverse all elements in this dataset from the + // start. + // + // Ownership of the created iterator will be transferred to the caller. + // + // The prefix identifies the sequence of iterators leading up to the newly + // created iterator. + virtual std::unique_ptr MakeIterator( + const string& prefix) const = 0; + + // Returns a vector of DataType values, representing the respective + // element types of each tuple component in the outputs of this + // dataset. + virtual const DataTypeVector& output_dtypes() const = 0; + + // Returns a vector of tensor shapes, representing the respective + // (and possibly partially defined) shapes of each tuple component + // in the outputs of this dataset. + virtual const std::vector& output_shapes() const = 0; + + // A human-readable debug string for this dataset. + virtual string DebugString() = 0; + + // Serializes the dataset and writes it to the `writer`. + virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const { + return errors::Unimplemented("DatasetBase::Save"); + } + + protected: + // TODO(srbs): Ideally all graph related logic should reside in + // GraphDatasetBase. However, that would require Datasets defined in all ops + // to derive from GraphDatasetBase. Once that is done we can move + // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase. + class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { + public: + DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} + Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset, + Node** output) { + return dataset->AsGraphDefInternal(ctx, this, output); + } + }; + + virtual Status AsGraphDefInternal(OpKernelContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const { + return AsGraphDefInternal(b, node); + } + + virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** node) const { + return errors::Unimplemented("AsGraphDefInternal"); + } +}; + +// Base-class for datasets that are built by ops. +class GraphDatasetBase : public DatasetBase { + public: + GraphDatasetBase(OpKernelContext* ctx) + : op_name_(ctx->op_kernel().type_string()) {} + + const string op_name() const { return op_name_; } + + Status Save(OpKernelContext* ctx, + IteratorStateWriter* writer) const override { + string serialized_graph_def; + string output_node; + TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node)); return Status::OK(); } - bool Contains(StringPiece op_name) { - return op_names_.find(op_name) != op_names_.end(); + // Key for storing the Dataset graph in the serialized format. + static const char kDatasetGraphKey[]; + + // Key for storing the output node of the Dataset graph in the serialized + // format. + static const char kDatasetGraphOutputNodeKey[]; + + private: + Status Serialize(OpKernelContext* ctx, string* serialized_graph_def, + string* output_node) const; + + const string op_name_; +}; + +// Represents an iterator that is associated with a particular parent dataset. +template +class DatasetIterator : public IteratorBase { + public: + struct Params { + // Owns one reference on the shared dataset resource. + const DatasetType* dataset; + + // Identifies the sequence of iterators leading up to this iterator. + const string prefix; + }; + + explicit DatasetIterator(const Params& params) : params_(params) { + params_.dataset->Ref(); + } + + ~DatasetIterator() override { params_.dataset->Unref(); } + + // The dataset from which this iterator was created. + const DatasetType* dataset() const { return params_.dataset; } + + // The sequence of iterators leading up to this iterator. + const string prefix() const { return params_.prefix; } + + const DataTypeVector& output_dtypes() const override { + return params_.dataset->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return params_.dataset->output_shapes(); } - static WhitelistedStatefulOpRegistry* Global() { - static WhitelistedStatefulOpRegistry* reg = - new WhitelistedStatefulOpRegistry; - return reg; + Status GetNext(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) final { + port::Tracing::TraceMe activity(params_.prefix); + Status s = GetNextInternal(ctx, out_tensors, end_of_sequence); + if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) { + s = errors::Internal( + "Iterator \"", params_.prefix, + "\" returned OutOfRange without setting `*end_of_sequence`. This " + "indicates that an error may have occurred. Original message: ", + s.error_message()); + LOG(ERROR) << s; + } + return s; + } + + Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final { + TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer)); + return IteratorBase::Save(ctx, writer); + } + + protected: + // Internal implementation of GetNext that is wrapped in tracing logic. + virtual Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) = 0; + + string full_name(const string& name) const { + return strings::StrCat(prefix(), ":", name); } private: - WhitelistedStatefulOpRegistry() {} - WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy); - WhitelistedStatefulOpRegistry operator=( - WhitelistedStatefulOpRegistry const& copy); - std::set op_names_; + Params params_; }; -} // namespace dataset - -// Use this macro to whitelist an op that is marked stateful but needs to be -// used inside a map_fn in an input pipeline. This is only needed if you wish -// to be able to checkpoint the state of the input pipeline. We currently -// do not allow stateful ops to be defined inside of map_fns since it is not -// possible to save their state. -// Note that the state of the whitelisted ops inside functions will not be -// saved during checkpointing, hence this should only be used if the op is -// marked stateful for reasons like to avoid constant folding during graph -// optimiztion but is not stateful. -// If possible, try to remove the stateful flag on the op first. -// Example usage: +// Encapsulates the work required to plug a DatasetBase into the core TensorFlow +// graph execution engine. +class DatasetOpKernel : public OpKernel { + public: + DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) final; + + protected: + // Subclasses should implement this method. It will be called during Compute + // execution. + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0; + + template + Status ParseScalarArgument(OpKernelContext* ctx, + const StringPiece& argument_name, T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); + } +}; + +// Encapsulates the work required to plug unary Datasets into the core +// TensorFlow graph execution engine. +class UnaryDatasetOpKernel : public DatasetOpKernel { + public: + UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) = 0; +}; + +// Encapsulates the work required to plug binary Datasets into the core +// TensorFlow graph execution engine. +class BinaryDatasetOpKernel : public DatasetOpKernel { + public: + BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase* another_input, + DatasetBase** output) = 0; +}; + +// Validates and extracts a `DatasetBase` object from `tensor`. // -// WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LegacyStatefulReader"); +// `tensor` must have been written by a call to SetVariantTensorToDataset(). +// +// The retrieved pointer is a borrowed reference to the dataset, which is owned +// by the tensor. The consumer must either acquire its own reference to the +// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not +// destroyed or mutated while the retrieved pointer is in use. +Status GetDatasetFromVariantTensor(const Tensor& tensor, + DatasetBase** out_dataset); + +// Stores a `DatasetBase` object in `tensor`. // -#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS(name) \ - WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name) -#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \ - WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) -#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ - static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ - ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \ - name) +// The ownership of `dataset` is transferred to `tensor`. +Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor); } // namespace tensorflow -#endif // TENSORFLOW_FRAMEWORK_DATASET_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h new file mode 100644 index 0000000000000000000000000000000000000000..3b48999edb37da4fdf232f2cbcd61df7affb40f2 --- /dev/null +++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace dataset { +// Registry for stateful ops that need to be used in dataset functions. +// See below macro for usage details. +class WhitelistedStatefulOpRegistry { + public: + Status Add(StringPiece op_name) { + op_names_.insert(op_name); + return Status::OK(); + } + + bool Contains(StringPiece op_name) { + return op_names_.find(op_name) != op_names_.end(); + } + + static WhitelistedStatefulOpRegistry* Global() { + static WhitelistedStatefulOpRegistry* reg = + new WhitelistedStatefulOpRegistry; + return reg; + } + + private: + WhitelistedStatefulOpRegistry() {} + WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy); + WhitelistedStatefulOpRegistry operator=( + WhitelistedStatefulOpRegistry const& copy); + std::set op_names_; +}; + +} // namespace dataset + +// Use this macro to whitelist an op that is marked stateful but needs to be +// used inside a map_fn in an input pipeline. This is only needed if you wish +// to be able to checkpoint the state of the input pipeline. We currently +// do not allow stateful ops to be defined inside of map_fns since it is not +// possible to save their state. +// Note that the state of the whitelisted ops inside functions will not be +// saved during checkpointing, hence this should only be used if the op is +// marked stateful for reasons like to avoid constant folding during graph +// optimiztion but is not stateful. +// If possible, try to remove the stateful flag on the op first. +// Example usage: +// +// WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LegacyStatefulReader"); +// +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS(name) \ + WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name) +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \ + WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ + static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \ + name) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ diff --git a/tensorflow/core/framework/device_attributes.proto b/tensorflow/core/framework/device_attributes.proto index 9983bcb6bec63602c2e624a183a111622f7f2ace..0b3c0d5bdf9f3db858d631dcf67d1120022520f2 100644 --- a/tensorflow/core/framework/device_attributes.proto +++ b/tensorflow/core/framework/device_attributes.proto @@ -6,10 +6,26 @@ option java_outer_classname = "DeviceAttributesProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.framework"; +message InterconnectLink { + int32 device_id = 1; + string type = 2; + int32 strength = 3; +}; + +message LocalLinks { + repeated InterconnectLink link = 1; +}; + message DeviceLocality { // Optional bus locality of device. Default value of 0 means // no specific locality. Specific localities are indexed from 1. int32 bus_id = 1; + + // Optional NUMA locality of device. + int32 numa_node = 2; + + // Optional local interconnect links to other devices. + LocalLinks links = 3; }; message DeviceAttributes { diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc index ad301a8aa4ba4be5b7031d00984d8e6febf1583e..70d1e20a17c6cbf75a15d32a97216f6a1354ccf4 100644 --- a/tensorflow/core/framework/fake_input.cc +++ b/tensorflow/core/framework/fake_input.cc @@ -104,8 +104,8 @@ Status FakeInputImpl::AddInputToBuilder() { Status status = GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts); if (!status.ok()) { return errors::InvalidArgument( - "Could not infer list of types for input '", arg_->name(), "': ", - status.error_message()); + "Could not infer list of types for input '", arg_->name(), + "': ", status.error_message()); } SourceList(dts); return Status::OK(); @@ -131,8 +131,8 @@ Status FakeInputImpl::GetN(int* n) const { Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n); if (!status.ok()) { return errors::InvalidArgument("Could not infer length of input '", - arg_->name(), "': ", - status.error_message()); + arg_->name(), + "': ", status.error_message()); } } return Status::OK(); @@ -153,8 +153,8 @@ Status FakeInputImpl::GetDataType(DataType* dt) const { *dt = attr->default_value().type(); } else { return errors::InvalidArgument("Could not infer type for input '", - arg_->name(), "': ", - status.error_message()); + arg_->name(), + "': ", status.error_message()); } } } else { diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 0224f252270cdfb856957be33b3dd857ecb07ec9..eae8e6c3c10c4b49081aed0e253d9a6f382f562b 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1064,26 +1064,36 @@ Status FunctionLibraryDefinition::AddLibrary( return Status::OK(); } -void FunctionLibraryDefinition::RemoveFunction(const string& func) { +Status FunctionLibraryDefinition::RemoveFunction(const string& func) { const auto& i = function_defs_.find(func); - DCHECK(i != function_defs_.end()); + if (i == function_defs_.end()) { + return errors::InvalidArgument("Tried to remove non-existent function ", + func); + } function_defs_.erase(i); + return Status::OK(); } -void FunctionLibraryDefinition::RemoveGradient(const string& func) { +Status FunctionLibraryDefinition::RemoveGradient(const string& func) { const auto& i = func_grad_.find(func); - DCHECK(i != func_grad_.end()); + if (i == func_grad_.end()) { + return errors::InvalidArgument("Tried to remove non-existent gradient ", + func); + } func_grad_.erase(i); + return Status::OK(); } void FunctionLibraryDefinition::Remove( const std::vector& funcs, const std::vector& funcs_with_grads) { for (const string& f : funcs) { - RemoveFunction(f); + Status s = RemoveFunction(f); + DCHECK(s.ok()); } for (const string& f : funcs_with_grads) { - RemoveGradient(f); + Status s = RemoveGradient(f); + DCHECK(s.ok()); } } @@ -1264,8 +1274,8 @@ FunctionDef FunctionDefHelper::Define(const string& name, } for (const string& a : src.arg) { const auto iter = ret_index.find(a); - CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '" - << src.ret[0] << "' of " << name; + CHECK(iter != ret_index.end()) + << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name; n->add_input(iter->second); } for (const string& d : src.dep) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 3bb5638cdf232c144157b587a7431f435e2fa6ea..e27001133bbb5056abf1a3e1f5b9d69c8e01bc56 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -35,6 +35,7 @@ namespace tensorflow { class CancellationManager; class GraphDef; class OpKernel; +class ProcessFunctionLibraryRuntime; class ResourceMgr; class Rendezvous; class ScopedStepContainer; @@ -312,6 +313,14 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // This operation is atomic. Status AddGradientDef(const GradientDef& grad); + // Remove function `func` from the library. Returns non-OK Status unless + // `func` is in the library. + Status RemoveFunction(const string& func); + + // Remove gradient of function `func` from the library. Returns non-OK Status + // unless `func` has a gradient. + Status RemoveGradient(const string& func); + // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. @@ -384,13 +393,6 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // attr from. const FunctionDef* GetAttrImpl(const NodeDef& ndef) const; - // Remove function `func` from the library. `func` must be in the library. - void RemoveFunction(const string& func); - - // Remove gradient of function `func` from the library. `func` must have - // a gradient. - void RemoveGradient(const string& func); - // Remove all functions in `funcs` and all gradients of // functions in `funcs_with_grads` from this library. void Remove(const std::vector& funcs, @@ -534,6 +536,10 @@ class FunctionLibraryRuntime { virtual int graph_def_version() = 0; typedef uint64 LocalHandle; + + virtual Status Clone(std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + FunctionLibraryRuntime** out_flr) = 0; }; // Returns a canonicalized string for the instantiation of the @@ -656,7 +662,7 @@ bool RegisterOp(const string& op, Creator func); // Returns OK the gradient creator for the "op" is found (may be // nullptr if REGISTER_OP_NO_GRADIENT is used. Status GetOpGradientCreator(const string& op, Creator* creator); -}; +}; // namespace gradient // Declare explicit instantiations of GetAttr #define GET_ATTR(T) \ diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index bd018b7243897a5b45aa35d7fb94ca1ee1b12e75..1f670535d575e9bbc4196fb1f1e1c381d33ae204 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -35,8 +35,8 @@ namespace tensorflow { string SummarizeGraphDef(const GraphDef& graph_def) { string ret; - strings::StrAppend(&ret, "versions = ", - ProtoShortDebugString(graph_def.versions()), ";\n"); + strings::StrAppend( + &ret, "versions = ", ProtoShortDebugString(graph_def.versions()), ";\n"); for (const NodeDef& node : graph_def.node()) { strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n"); } @@ -90,9 +90,9 @@ static Status RemoveNewDefaultAttrsFromNodeDef( FindAttr(attr.first, *producer_op_def); if (producer_attr_def == nullptr) { return errors::InvalidArgument( - "Attr '", attr.first, "' missing in producer's OpDef: ", - SummarizeOpDef(*producer_op_def), " but found in node: ", - SummarizeNodeDef(*node_def)); + "Attr '", attr.first, + "' missing in producer's OpDef: ", SummarizeOpDef(*producer_op_def), + " but found in node: ", SummarizeNodeDef(*node_def)); } // ...and it has the same value as the default in producer, if (producer_attr_def->has_default_value() && diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h index 99a5d0a054e9fe2c5dd729e165276369ebea7a71..4c38fbbe591a5d07ba4cbbea00dcbfb41ca2f403 100644 --- a/tensorflow/core/framework/numeric_types.h +++ b/tensorflow/core/framework/numeric_types.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ #include - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" // Disable clang-format to prevent 'FixedPoint' header from being included // before 'Tensor' header on which it depends. @@ -43,12 +42,47 @@ typedef Eigen::QUInt16 quint16; } // namespace tensorflow + + + +static inline tensorflow::bfloat16 FloatToBFloat16(float float_val) { +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return *reinterpret_cast( + reinterpret_cast(&float_val)); +#else + return *reinterpret_cast( + &(reinterpret_cast(&float_val)[1])); +#endif +} + namespace Eigen { -// TOOD(xpan): We probably need to overwrite more methods to have correct eigen -// behavior. E.g. loest(), is_integer, etc. See NumTraits.h in eigen. +// TODO(xpan): We probably need to overwrite more methods to have correct eigen +// behavior. E.g. epsilon(), dummy_precision, etc. See NumTraits.h in eigen. template <> struct NumTraits - : GenericNumTraits {}; + : GenericNumTraits { + enum { + IsInteger = 0, + IsSigned = 1, + RequireInitialization = 0 + }; + static EIGEN_STRONG_INLINE tensorflow::bfloat16 highest() { + return FloatToBFloat16(NumTraits::highest()); + } + + static EIGEN_STRONG_INLINE tensorflow::bfloat16 lowest() { + return FloatToBFloat16(NumTraits::lowest()); + } + + static EIGEN_STRONG_INLINE tensorflow::bfloat16 infinity() { + return FloatToBFloat16(NumTraits::infinity()); + } + + static EIGEN_STRONG_INLINE tensorflow::bfloat16 quiet_NaN() { + return FloatToBFloat16(NumTraits::quiet_NaN()); + } +}; + using ::tensorflow::operator==; using ::tensorflow::operator!=; diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index a4e8add6c49b823948eb5978f99239bb4d9b52ef..2d035ab90d0f4493f6b6f572d0dd8550f5098e7e 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -170,20 +170,20 @@ const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) { return nullptr; } -#define VALIDATE(EXPR, ...) \ - do { \ - if (!(EXPR)) { \ - return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ", \ - ProtoShortDebugString(op_def)); \ - } \ +#define VALIDATE(EXPR, ...) \ + do { \ + if (!(EXPR)) { \ + return errors::InvalidArgument( \ + __VA_ARGS__, "; in OpDef: ", ProtoShortDebugString(op_def)); \ + } \ } while (false) static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, bool output, std::set* names) { const string suffix = strings::StrCat( output ? " for output '" : " for input '", arg.name(), "'"); - VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), "Duplicate name: ", - arg.name()); + VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), + "Duplicate name: ", arg.name()); VALIDATE(HasAttrStyleType(arg), "Missing type", suffix); if (!arg.number_attr().empty()) { @@ -250,8 +250,8 @@ Status ValidateOpDef(const OpDef& op_def) { std::set names; // for detecting duplicate names for (const auto& attr : op_def.attr()) { // Validate name - VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), "Duplicate name: ", - attr.name()); + VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), + "Duplicate name: ", attr.name()); DataType dt; VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ", attr.name(), " that matches a data type"); @@ -680,8 +680,8 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, if (!penultimate_attr.has_default_value() || !new_attr->has_default_value()) { return errors::InvalidArgument("Missing default for attr '", - penultimate_attr.name(), "' in op: ", - SummarizeOpDef(new_op)); + penultimate_attr.name(), + "' in op: ", SummarizeOpDef(new_op)); } // Actually test that the attr's default value hasn't changed. diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h index d1613ee89b29ef0bcdd97b1bc3c34edbcb65f5d8..0ba1325a03b148e0a1c8fe94723e2dc5503773d1 100644 --- a/tensorflow/core/framework/op_def_util.h +++ b/tensorflow/core/framework/op_def_util.h @@ -82,7 +82,7 @@ bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2); uint64 AttrDefHash(const OpDef::AttrDef& a); // Returns true if all AttrDefs in `a1` equal corresponding AttrDefs in -// `a2`. Corrspondence is established by name. +// `a2`. Correspondence is established by name. bool RepeatedAttrDefEqual(const protobuf::RepeatedPtrField& a1, const protobuf::RepeatedPtrField& a2); diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc index 28809c11c58704479c9c45b1de96dffef3d575bd..2b9812d4fcbc145540155959b19dd37cf902c1a2 100644 --- a/tensorflow/core/framework/op_def_util_test.cc +++ b/tensorflow/core/framework/op_def_util_test.cc @@ -200,10 +200,11 @@ TEST_F(ValidateOpDefTest, BadAttrDefault) { "default_value { list { s: ['foo'] } } }"), "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op " "'BadAttrDef'"); - ExpectFailure(TestBuilder(OpDefBuilder("GoodAttrDef") - .Attr("a: list(type) >=2 = [DT_STRING]")), - "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op " - "'GoodAttrDef'"); + ExpectFailure( + TestBuilder( + OpDefBuilder("GoodAttrDef").Attr("a: list(type) >=2 = [DT_STRING]")), + "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op " + "'GoodAttrDef'"); } TEST_F(ValidateOpDefTest, NoRefTypes) { @@ -213,9 +214,10 @@ TEST_F(ValidateOpDefTest, NoRefTypes) { ExpectFailure( TestBuilder(OpDefBuilder("BadAttrDef").Attr("T: type = DT_INT32_REF")), "AttrValue must not have reference type value of int32_ref"); - ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef") - .Attr("T: list(type) = [DT_STRING_REF]")), - "AttrValue must not have reference type value of string_ref"); + ExpectFailure( + TestBuilder( + OpDefBuilder("BadAttrDef").Attr("T: list(type) = [DT_STRING_REF]")), + "AttrValue must not have reference type value of string_ref"); } TEST_F(ValidateOpDefTest, BadAttrMin) { @@ -245,9 +247,10 @@ TEST_F(ValidateOpDefTest, BadAttrAllowed) { TF_EXPECT_OK(TestBuilder( OpDefBuilder("GoodAttrtude").Attr("x: numbertype = DT_INT32"))); // Not in list of allowed types. - ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") - .Attr("x: numbertype = DT_STRING")), - "attr 'x' of string is not in the list of allowed values"); + ExpectFailure( + TestBuilder( + OpDefBuilder("BadAttrtude").Attr("x: numbertype = DT_STRING")), + "attr 'x' of string is not in the list of allowed values"); ExpectFailure( TestBuilder(OpDefBuilder("BadAttrtude") .Attr("x: list(realnumbertype) = [DT_COMPLEX64]")), @@ -260,9 +263,10 @@ TEST_F(ValidateOpDefTest, BadAttrAllowed) { TF_EXPECT_OK(TestBuilder( OpDefBuilder("GoodAttrtude").Attr("x: {'foo', 'bar'} = 'bar'"))); // Not in list of allowed strings. - ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") - .Attr("x: {'foo', 'bar'} = 'baz'")), - "attr 'x' of \"baz\" is not in the list of allowed values"); + ExpectFailure( + TestBuilder( + OpDefBuilder("BadAttrtude").Attr("x: {'foo', 'bar'} = 'baz'")), + "attr 'x' of \"baz\" is not in the list of allowed values"); ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") .Attr("x: list({'foo', 'bar'}) = ['baz']")), "attr 'x' of \"baz\" is not in the list of allowed values"); diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index e78b6ab5d977c6ea2f0dec66988432a617154916..5f2eb9d99ab11f9862bd277d93af61c05e2517f4 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -266,35 +266,6 @@ static void StringReplace(const string& from, const string& to, string* s) { *s = str_util::Join(split, to.c_str()); } -static void RenameInDocs(const string& from, const string& to, OpDef* op_def) { - const string from_quoted = strings::StrCat("`", from, "`"); - const string to_quoted = strings::StrCat("`", to, "`"); - for (int i = 0; i < op_def->input_arg_size(); ++i) { - if (!op_def->input_arg(i).description().empty()) { - StringReplace(from_quoted, to_quoted, - op_def->mutable_input_arg(i)->mutable_description()); - } - } - for (int i = 0; i < op_def->output_arg_size(); ++i) { - if (!op_def->output_arg(i).description().empty()) { - StringReplace(from_quoted, to_quoted, - op_def->mutable_output_arg(i)->mutable_description()); - } - } - for (int i = 0; i < op_def->attr_size(); ++i) { - if (!op_def->attr(i).description().empty()) { - StringReplace(from_quoted, to_quoted, - op_def->mutable_attr(i)->mutable_description()); - } - } - if (!op_def->summary().empty()) { - StringReplace(from_quoted, to_quoted, op_def->mutable_summary()); - } - if (!op_def->description().empty()) { - StringReplace(from_quoted, to_quoted, op_def->mutable_description()); - } -} - static void RenameInDocs(const string& from, const string& to, ApiDef* api_def) { const string from_quoted = strings::StrCat("`", from, "`"); @@ -325,7 +296,6 @@ static void RenameInDocs(const string& from, const string& to, } } - namespace { // Initializes given ApiDef with data in OpDef. diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h index 94fe194a1a5032b472259d26145ba7cd4460191c..ff38e4b22141a7f1b7212a516ec5adbd5c7aad79 100644 --- a/tensorflow/core/framework/op_gen_lib.h +++ b/tensorflow/core/framework/op_gen_lib.h @@ -47,7 +47,6 @@ string PBTxtToMultiline(StringPiece pbtxt, const std::vector& multi_line_fields); string PBTxtFromMultiline(StringPiece multiline_pbtxt); - // Takes a list of files with ApiDefs text protos, and allows you to // look up the specific ApiDef for any given op. class ApiDefMap { diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index aee3a0afbca23a180d5415fef2b1b405f23b3f53..8654437059ca449432e6381b9eb3c4ba15e56f48 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -79,8 +79,14 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, // OpKernel ------------------------------------------------------------------ +// TODO(mrry): Convert to std::make_unique when available. OpKernel::OpKernel(OpKernelConstruction* context) - : def_(new NodeDef(context->def())), + : OpKernel(context, + std::unique_ptr(new NodeDef(context->def()))) {} + +OpKernel::OpKernel(OpKernelConstruction* context, + std::unique_ptr node_def) + : def_(std::move(node_def)), input_types_(context->input_types().begin(), context->input_types().end()), input_memory_types_(context->input_memory_types().begin(), @@ -101,7 +107,8 @@ OpKernel::OpKernel(OpKernelConstruction* context) // Kernels executing on GPU/SYCL tie very few resources on the CPU where the // scheduler runs: we consider them as inexpensive. - expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && context->device_type() != DeviceType(DEVICE_SYCL); + expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && + context->device_type() != DeviceType(DEVICE_SYCL); } OpKernel::~OpKernel() {} @@ -253,7 +260,7 @@ OpKernelContext::OpKernelContext(Params* params) OpKernelContext::OpKernelContext(Params* params, int num_outputs) : params_(params), outputs_(num_outputs), - temp_memory_size_(0), + temp_memory_allocated_(0), persistent_memory_allocated_(0) { Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); params_->ensure_eigen_gpu_device(); @@ -467,7 +474,7 @@ std::unique_ptr OpKernelContext::forward_input( return nullptr; } // Check that input and output memory types match, i.e. - // that they either both live in host or both live in device memmory. + // that they either both live in host or both live in device memory. if (input_memory_type(input_index) != output_memory_type) { return nullptr; } @@ -662,12 +669,11 @@ Status OpKernelContext::allocate_temp( const AllocationAttributes& allocation_attr) { Status s = allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr); - if (track_allocations() && out_temp->TotalBytes() > 0) { + if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) { Allocator* a = get_allocator(allocator_attr); if (a->TracksAllocationSizes()) { - int64 alloc_size = - a->AllocatedSize(const_cast(out_temp->tensor_data().data())); - record_temp_memory_size(alloc_size); + int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data()); + record_temp_memory_allocation(alloc_size, *out_temp); } } return s; @@ -685,6 +691,15 @@ Status OpKernelContext::allocate_persistent(DataType type, if (out_tensor) { *out_tensor = out_persistent->AccessTensor(this); } + if (track_allocations()) { + Tensor* t = out_persistent->AccessTensor(this); + Allocator* a = get_allocator(attr); + if (a->TracksAllocationSizes()) { + int64 alloc_size = a->AllocatedSize(t->tensor_data().data()); + int64 alloc_id = a->AllocationId(t->tensor_data().data()); + record_persistent_memory_allocation(alloc_size, alloc_id); + } + } } return s; } @@ -709,6 +724,22 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) { DCHECK_EQ(mutable_output(index), nullptr); record_tensor_reference(tensor); outputs_[index] = TensorValue(new Tensor(tensor)); + if (track_allocations() && tensor.TotalBytes() > 0) { + mutex_lock l(stats_mu_); + if (!temp_tensor_buffer_and_size_) { + return; + } + auto it = std::find_if(temp_tensor_buffer_and_size_->begin(), + temp_tensor_buffer_and_size_->end(), + [&tensor](const std::pair& e) { + return e.first == static_cast( + tensor.tensor_data().data()); + }); + if (it != temp_tensor_buffer_and_size_->end()) { + temp_memory_allocated_ -= it->second; + temp_tensor_buffer_and_size_->erase(it); + } + } } void OpKernelContext::set_output_ref(int index, mutex* mu, @@ -786,19 +817,60 @@ Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, outputs); } -bool OpKernelContext::allocate_on_host(AllocatorAttributes alloc_attr) const { - return alloc_attr.on_host() || device()->attributes().device_type() == "CPU"; +void OpKernelContext::record_temp_memory_allocation(int64 size, + const Tensor& t) { + mutex_lock l(stats_mu_); + temp_memory_allocated_ += size; + if (!temp_tensor_buffer_and_size_) { + temp_tensor_buffer_and_size_.reset( + new gtl::InlinedVector, 2>()); + } + temp_tensor_buffer_and_size_->emplace_back( + static_cast(t.tensor_data().data()), size); +} + +int64 OpKernelContext::temp_memory_allocated() const { + mutex_lock l(stats_mu_); + return temp_memory_allocated_; } void OpKernelContext::record_persistent_memory_allocation(int64 size, int64 alloc_id) { + mutex_lock l(stats_mu_); persistent_memory_allocated_ += size; - persistent_alloc_ids_.push_back(alloc_id); + if (alloc_id >= 0) { + if (!persistent_alloc_ids_) { + persistent_alloc_ids_.reset(new gtl::InlinedVector()); + } + persistent_alloc_ids_->push_back(alloc_id); + } +} + +int64 OpKernelContext::persistent_memory_allocated() const { + mutex_lock l(stats_mu_); + return persistent_memory_allocated_; } std::vector OpKernelContext::persistent_alloc_ids() const { - return std::vector(persistent_alloc_ids_.begin(), - persistent_alloc_ids_.end()); + mutex_lock l(stats_mu_); + if (persistent_alloc_ids_) { + return std::vector(persistent_alloc_ids_->begin(), + persistent_alloc_ids_->end()); + } else { + return std::vector(); + } +} + +void OpKernelContext::clear_recorded_memory() { + mutex_lock l(stats_mu_); + temp_memory_allocated_ = 0; + persistent_memory_allocated_ = 0; + if (temp_tensor_buffer_and_size_) { + temp_tensor_buffer_and_size_->clear(); + } + if (persistent_alloc_ids_) { + persistent_alloc_ids_->clear(); + } } // OpKernel registration ------------------------------------------------------ @@ -943,13 +1015,6 @@ Status FindKernelRegistration(const DeviceType& device_type, return Status::OK(); } -Status FindKernelRegistration(const DeviceType& device_type, const Node& node, - const KernelRegistration** reg, - bool* was_attr_mismatch) { - return FindKernelRegistration(device_type, node.def(), reg, - was_attr_mismatch); -} - } // namespace // TODO(irving): Change const NodeDef& to const Node& diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index b72f1405cffd83439dd837fa7f8e641ecf44e2ae..5ccd45efc980393aa02582595dde873be7426e26 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -75,6 +75,14 @@ class OpKernel { // OpKernel won't be instantiated by the scheduler, so you may perform // expensive initialization in the descendant's constructor. explicit OpKernel(OpKernelConstruction* context); + + // Specialized constructor that enables the descendant to provide a different + // `NodeDef` value. For example, this constructor can be used to provide a + // stripped-down `NodeDef` that does not contain the full set of attrs (such + // as tensor values) if the descendant stores them in a different form. + explicit OpKernel(OpKernelConstruction* context, + std::unique_ptr node_def); + virtual ~OpKernel(); // An OpKernel's computation can be either synchronous or @@ -901,9 +909,13 @@ class OpKernelContext { } AllocatorAttributes input_alloc_attr(int index) const { - DCHECK_GE(index, 0); - DCHECK_LT(index, params_->input_alloc_attrs->size()); - return (*params_->input_alloc_attrs)[index]; + if (params_->input_alloc_attrs == nullptr) { + return AllocatorAttributes(); + } else { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_->input_alloc_attrs->size()); + return (*params_->input_alloc_attrs)[index]; + } } AllocatorAttributes output_alloc_attr(int index) const { @@ -1034,24 +1046,27 @@ class OpKernelContext { TensorValue release_output(int index); bool track_allocations() const { return params_->track_allocations; } - bool allocate_on_host(AllocatorAttributes alloc_attr) const; - // Records temporary memory sizes. - void record_temp_memory_size(int64 size) { temp_memory_size_ += size; } + // Records temp memory allocation. Tensor object is recorded to identify the + // case where temp memory is used as output memory. + void record_temp_memory_allocation(int64 size, const Tensor& t) + LOCKS_EXCLUDED(stats_mu_); // Returns recorded size of temporary memory; - int64 temp_memory_size() const { return temp_memory_size_; } + int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_); // Records persistent memory allocation, size can be negative indicating // deallocation. - void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1); + void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1) + LOCKS_EXCLUDED(stats_mu_); // Returns recorded size and ids of persistent memory. - int64 persistent_memory_allocated() const { - return persistent_memory_allocated_; - } + int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_); + + std::vector persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_); - std::vector persistent_alloc_ids() const; + // Resets counters for temp and persistent memory and recorded ids. + void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_); bool input_is_ref(int index) const; @@ -1096,9 +1111,15 @@ class OpKernelContext { bool is_output_dead_ = false; - int64 temp_memory_size_; - gtl::InlinedVector persistent_alloc_ids_; - int64 persistent_memory_allocated_; + // The following data members are only used when allocation tracking is + // enabled. + mutable mutex stats_mu_; + int64 temp_memory_allocated_ GUARDED_BY(stats_mu_); + int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_); + std::unique_ptr, 2>> + temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_); + std::unique_ptr> persistent_alloc_ids_ + GUARDED_BY(stats_mu_); TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext); }; diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index 94a9d1335a7c46372e05633431427d44fc46e027..b53b877f28d2c80e969fb418aa316ad96c6e2eaa 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -510,10 +510,9 @@ TEST_F(OpKernelBuilderTest, BuilderBoth) { } REGISTER_OP("BuildTypeAttr").Attr("T: type"); -REGISTER_KERNEL_BUILDER(Name("BuildTypeAttr") - .Device(DEVICE_CPU) - .TypeConstraint("T"), - DummyKernel); +REGISTER_KERNEL_BUILDER( + Name("BuildTypeAttr").Device(DEVICE_CPU).TypeConstraint("T"), + DummyKernel); TEST_F(OpKernelBuilderTest, BuilderTypeAttr) { ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"}); @@ -525,10 +524,9 @@ TEST_F(OpKernelBuilderTest, BuilderTypeAttr) { } REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)"); -REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr") - .Device(DEVICE_CPU) - .TypeConstraint("T"), - DummyKernel); +REGISTER_KERNEL_BUILDER( + Name("BuildTypeListAttr").Device(DEVICE_CPU).TypeConstraint("T"), + DummyKernel); TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) { ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"}); @@ -574,14 +572,12 @@ TEST_F(OpKernelBuilderTest, DuplicateKernel) { } REGISTER_OP("DuplicateKernelForT").Attr("T: type"); -REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT") - .Device(DEVICE_CPU) - .TypeConstraint("T"), - DummyKernel); -REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT") - .Device(DEVICE_CPU) - .TypeConstraint("T"), - DummyKernel); +REGISTER_KERNEL_BUILDER( + Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint("T"), + DummyKernel); +REGISTER_KERNEL_BUILDER( + Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint("T"), + DummyKernel); TEST_F(OpKernelBuilderTest, DuplicateKernelForT) { const NodeDef ndef = diff --git a/tensorflow/core/framework/reader_base.cc b/tensorflow/core/framework/reader_base.cc index b8c771a0a1955b29f78478f60972b22d804351b2..f84ef0f953cf23e3fb2af210706586f95cfbb8ad 100644 --- a/tensorflow/core/framework/reader_base.cc +++ b/tensorflow/core/framework/reader_base.cc @@ -178,9 +178,9 @@ void ReaderBase::Read(QueueInterface* queue, string* key, string* value, " must set *at_end=true, *produced=true, or return an error."); } if (!status.ok() && produced) { - status = errors::Internal("ReadLocked() for ", name(), - " set *produced=true *and* returned an error: ", - status.ToString()); + status = errors::Internal( + "ReadLocked() for ", name(), + " set *produced=true *and* returned an error: ", status.ToString()); } if (status.ok() && at_end) { status = OnWorkFinishedLocked(); diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h index e062adffe821464cd349227cde17b9d4db54c44e..e90596980f840588768c7883031f1ad179628833 100644 --- a/tensorflow/core/framework/register_types.h +++ b/tensorflow/core/framework/register_types.h @@ -179,7 +179,7 @@ limitations under the License. // Call "m" on all types. #define TF_CALL_ALL_TYPES(m) \ - TF_CALL_POD_TYPES(m) TF_CALL_string(m) TF_CALL_resource(m) + TF_CALL_POD_TYPES(m) TF_CALL_string(m) TF_CALL_resource(m) TF_CALL_variant(m) // Call "m" on POD and string types. #define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m) @@ -211,14 +211,12 @@ limitations under the License. #define TF_CALL_SYCL_double(m) #else // TENSORFLOW_SYCL_NO_DOUBLE #define TF_CALL_SYCL_double(m) TF_CALL_double(m) -#endif // TENSORFLOW_SYCL_NO_DOUBLE +#endif // TENSORFLOW_SYCL_NO_DOUBLE #ifdef __ANDROID_TYPES_SLIM__ -#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) +#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) #else // __ANDROID_TYPES_SLIM__ -#define TF_CALL_SYCL_NUMBER_TYPES(m) \ - TF_CALL_float(m) \ - TF_CALL_SYCL_double(m) -#endif // __ANDROID_TYPES_SLIM__ +#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) TF_CALL_SYCL_double(m) +#endif // __ANDROID_TYPES_SLIM__ #endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ diff --git a/tensorflow/core/framework/register_types_traits.h b/tensorflow/core/framework/register_types_traits.h index c1fe5517c6986838a07f67c0f2fa5474f89ffa33..ab35c2f0951d21e63fe06e378461c019e45495f1 100644 --- a/tensorflow/core/framework/register_types_traits.h +++ b/tensorflow/core/framework/register_types_traits.h @@ -23,7 +23,7 @@ typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/platform/types.h" @@ -79,7 +79,7 @@ template <> struct proxy_type_pod { typedef float type; }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL /// If POD we use proxy_type_pod, otherwise this maps to identiy. template @@ -99,7 +99,7 @@ struct proxy_type { #ifdef TENSORFLOW_USE_SYCL #define TF_CALL_SYCL_PROXY_TYPES(m) \ TF_CALL_double(m) TF_CALL_float(m) TF_CALL_int32(m) -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_ diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc index 32b8ad784d5228a40a073d166f33972def380280..de148f0bd3474421c1361cf7ae4aa681107aa883 100644 --- a/tensorflow/core/framework/rendezvous_test.cc +++ b/tensorflow/core/framework/rendezvous_test.cc @@ -69,9 +69,7 @@ class LocalRendezvousTest : public ::testing::Test { rendez_ = NewLocalRendezvous(); } - ~LocalRendezvousTest() override { - rendez_->Unref(); - } + ~LocalRendezvousTest() override { rendez_->Unref(); } void SchedClosure(std::function fn) { threads_.Schedule(std::move(fn)); @@ -99,8 +97,8 @@ string V(const Tensor& tensor) { Rendezvous::ParsedKey MakeKey(const string& name) { string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890, - "/job:mnist/replica:1/task:2/device:GPU:0", name, - FrameAndIter(0, 0)); + "/job:mnist/replica:1/task:2/device:GPU:0", + name, FrameAndIter(0, 0)); Rendezvous::ParsedKey k; TF_EXPECT_OK(Rendezvous::ParseKey(s, &k)); return k; diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index d552ec1693f89a6695609681f2e8bffa9d78f93c..e3cc848a169bd848b8f3617d552938ba1ced3663 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -32,7 +32,7 @@ class ShapeRefinerTest; namespace grappler { class GraphProperties; class SymbolicShapeManager; -} +} // namespace grappler namespace shape_inference { diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index a9b63ca60e4574bb0d59c4b939ac157e62f317e8..f48a7b9c47df3cfa93434ccf585dda8c5a29a2ba 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -760,7 +760,10 @@ TEST_F(ShapeInferenceTest, MergePrefix) { NodeDef def; InferenceContext c(kVersion, &def, MakeOpDef(4, 2), { - Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}), + Unknown(), + S({-1, 2}), + S({1, -1, 3}), + S({2, 4}), }, {}, {}, {}); diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 77a3edcc10e9c5ceb8bf26570c3e271f9e853444..5d32b71628263fe89d6f54fd07b2fe18bbb55e53 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -886,8 +886,9 @@ bool Tensor::CanUseDMA() const { namespace { // Print from left dim to right dim recursively. template -void PrintOneDim(int dim_index, gtl::InlinedVector shape, int64 limit, - int shape_size, T* data, int64* data_index, string* result) { +void PrintOneDim(int dim_index, const gtl::InlinedVector& shape, + int64 limit, int shape_size, const T* data, int64* data_index, + string* result) { if (*data_index >= limit) return; int64 element_count = shape[dim_index]; // We have reached the right-most dimension of the tensor. @@ -1024,9 +1025,8 @@ StringPiece Tensor::tensor_data() const { } bool Tensor::SharesBufferWith(const Tensor& b) const { - CHECK_NE(nullptr, buf_); - CHECK_NE(nullptr, b.buf_); - return buf_->root_buffer() == b.buf_->root_buffer(); + return buf_ != nullptr && b.buf_ != nullptr && + buf_->root_buffer() == b.buf_->root_buffer(); } string Tensor::DebugString() const { diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 94c39c53a6fbb6a30e054346a2ec608a6970c373..62c42ba652356a5128d4a337e34a3b449781b445 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -660,8 +660,7 @@ void Tensor::FillDimsAndValidateCompatibleShape( template typename TTypes::Tensor Tensor::shaped( gtl::ArraySlice new_sizes) { - CheckType(DataTypeToEnum::v()); - CHECK(IsAligned()); + CheckTypeAndIsAligned(DataTypeToEnum::v()); Eigen::array dims; FillDimsAndValidateCompatibleShape(new_sizes, &dims); return typename TTypes::Tensor(base(), dims); diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index adb41b81c6ec019ce51a3871ca329c82f8a1f4b7..fe2ba375aa0c5c50009b3155338cd8860070d47a 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -191,9 +191,6 @@ class TensorShapeBase : public TensorShapeRep { /// Appends all the dimensions from `shape`. void AppendShape(const TensorShapeBase& shape); - // Maximum number of dimensions in a tensor. - static constexpr int MaxDimensions() { return 254; } - /// \brief Insert a dimension somewhere in the `TensorShape`. /// REQUIRES: `0 <= d <= dims()` /// REQUIRES: `size >= 0` diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc index d8a9c0bac5b950157044dae07771b6733481ac9e..d7517bb311d517351f4dd2a59438780482485dff 100644 --- a/tensorflow/core/framework/tensor_shape_test.cc +++ b/tensorflow/core/framework/tensor_shape_test.cc @@ -582,7 +582,8 @@ TEST(TensorShapeTest, Large) { TEST(TensorShapeTest, Overflow) { int64 one = 1; std::vector> overflows = { - {1 << 30, 1 << 30, 1 << 30}, {1 << 5, (one << 60) + 1}, + {1 << 30, 1 << 30, 1 << 30}, + {1 << 5, (one << 60) + 1}, }; for (const auto& overflow : overflows) { TensorShapeProto proto; diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 81644388abcf9c14bc5812069f25906a7f72b4cc..b613effd18bbbaf107a56b518859024db1c9bbb2 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -1085,6 +1085,21 @@ class DummyCPUAllocator : public Allocator { void DeallocateRaw(void* ptr) override {} }; +TEST(Tensor, SharesBufferWith) { + Tensor a_empty; + Tensor b_empty; + Tensor a(DT_FLOAT, TensorShape({1})); + Tensor b(DT_FLOAT, TensorShape({1})); + Tensor copy(a); + EXPECT_FALSE(a_empty.SharesBufferWith(a_empty)); + EXPECT_FALSE(a_empty.SharesBufferWith(b_empty)); + EXPECT_FALSE(a_empty.SharesBufferWith(a)); + EXPECT_FALSE(a_empty.SharesBufferWith(copy)); + EXPECT_TRUE(a.SharesBufferWith(a)); + EXPECT_FALSE(a.SharesBufferWith(b)); + EXPECT_TRUE(a.SharesBufferWith(copy)); +} + TEST(Tensor, FailureToAllocate) { TensorShape shape({1}); DummyCPUAllocator allocator; diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc index a8d141230093152397c792588a716c00556df77d..8f480d65f25012b858d7d375196b2693d3a533b9 100644 --- a/tensorflow/core/framework/tensor_testutil.cc +++ b/tensorflow/core/framework/tensor_testutil.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include "tensorflow/core/framework/tensor_testutil.h" +#include namespace tensorflow { namespace test { diff --git a/tensorflow/core/framework/tensor_types.h b/tensorflow/core/framework/tensor_types.h index 921f88dc0ba09e7904333613b728021751d5425c..a5c1a56bfc06a9785f08c468f78bda5111e15409 100644 --- a/tensorflow/core/framework/tensor_types.h +++ b/tensorflow/core/framework/tensor_types.h @@ -25,7 +25,8 @@ template struct TTypes { // Rank- tensor of scalar type T. typedef Eigen::TensorMap, - Eigen::Aligned> Tensor; + Eigen::Aligned> + Tensor; typedef Eigen::TensorMap< Eigen::Tensor, Eigen::Aligned> ConstTensor; @@ -33,35 +34,42 @@ struct TTypes { // Unaligned Rank- tensor of scalar type T. typedef Eigen::TensorMap > UnalignedTensor; - typedef Eigen::TensorMap > UnalignedConstTensor; + typedef Eigen::TensorMap< + Eigen::Tensor > + UnalignedConstTensor; typedef Eigen::TensorMap, - Eigen::Aligned> Tensor32Bit; + Eigen::Aligned> + Tensor32Bit; // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. typedef Eigen::TensorMap< Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, - Eigen::Aligned> Scalar; + Eigen::Aligned> + Scalar; typedef Eigen::TensorMap, Eigen::RowMajor, IndexType>, - Eigen::Aligned> ConstScalar; + Eigen::Aligned> + ConstScalar; // Unaligned Scalar tensor of scalar type T. - typedef Eigen::TensorMap, Eigen::RowMajor, IndexType> > UnalignedScalar; + typedef Eigen::TensorMap< + Eigen::TensorFixedSize, Eigen::RowMajor, IndexType> > + UnalignedScalar; typedef Eigen::TensorMap, Eigen::RowMajor, IndexType> > UnalignedConstScalar; // Rank-1 tensor (vector) of scalar type T. typedef Eigen::TensorMap, - Eigen::Aligned> Flat; + Eigen::Aligned> + Flat; typedef Eigen::TensorMap< Eigen::Tensor, Eigen::Aligned> ConstFlat; typedef Eigen::TensorMap, - Eigen::Aligned> Vec; + Eigen::Aligned> + Vec; typedef Eigen::TensorMap< Eigen::Tensor, Eigen::Aligned> ConstVec; @@ -69,16 +77,19 @@ struct TTypes { // Unaligned Rank-1 tensor (vector) of scalar type T. typedef Eigen::TensorMap > UnalignedFlat; - typedef Eigen::TensorMap > UnalignedConstFlat; + typedef Eigen::TensorMap< + Eigen::Tensor > + UnalignedConstFlat; typedef Eigen::TensorMap > UnalignedVec; typedef Eigen::TensorMap< - Eigen::Tensor > UnalignedConstVec; + Eigen::Tensor > + UnalignedConstVec; // Rank-2 tensor (matrix) of scalar type T. typedef Eigen::TensorMap, - Eigen::Aligned> Matrix; + Eigen::Aligned> + Matrix; typedef Eigen::TensorMap< Eigen::Tensor, Eigen::Aligned> ConstMatrix; @@ -86,8 +97,9 @@ struct TTypes { // Unaligned Rank-2 tensor (matrix) of scalar type T. typedef Eigen::TensorMap > UnalignedMatrix; - typedef Eigen::TensorMap > UnalignedConstMatrix; + typedef Eigen::TensorMap< + Eigen::Tensor > + UnalignedConstMatrix; }; typedef typename TTypes::Tensor32Bit::Index Index32; diff --git a/tensorflow/core/framework/tracking_allocator.cc b/tensorflow/core/framework/tracking_allocator.cc index 65c98ad1ee7bf68249389890c05cd968ddbf068c..2df402573a58ad3728e03a22d391b32766c49b00 100644 --- a/tensorflow/core/framework/tracking_allocator.cc +++ b/tensorflow/core/framework/tracking_allocator.cc @@ -113,7 +113,7 @@ bool TrackingAllocator::TracksAllocationSizes() { return track_sizes_locally_ || allocator_->TracksAllocationSizes(); } -size_t TrackingAllocator::RequestedSize(void* ptr) { +size_t TrackingAllocator::RequestedSize(const void* ptr) { if (track_sizes_locally_) { mutex_lock lock(mu_); auto it = in_use_.find(ptr); @@ -126,7 +126,7 @@ size_t TrackingAllocator::RequestedSize(void* ptr) { } } -size_t TrackingAllocator::AllocatedSize(void* ptr) { +size_t TrackingAllocator::AllocatedSize(const void* ptr) { if (track_sizes_locally_) { mutex_lock lock(mu_); auto it = in_use_.find(ptr); @@ -139,7 +139,7 @@ size_t TrackingAllocator::AllocatedSize(void* ptr) { } } -int64 TrackingAllocator::AllocationId(void* ptr) { +int64 TrackingAllocator::AllocationId(const void* ptr) { if (track_sizes_locally_) { mutex_lock lock(mu_); auto it = in_use_.find(ptr); diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h index 4825ed414f0dc64d98ab848e18f8aceb88629f40..f6c3c0b71b951c3b89b0d444c0d2d588a395dadd 100644 --- a/tensorflow/core/framework/tracking_allocator.h +++ b/tensorflow/core/framework/tracking_allocator.h @@ -64,9 +64,9 @@ class TrackingAllocator : public Allocator { const AllocationAttributes& allocation_attr) override; void DeallocateRaw(void* ptr) override; bool TracksAllocationSizes() override; - size_t RequestedSize(void* ptr) override; - size_t AllocatedSize(void* ptr) override; - int64 AllocationId(void* ptr) override; + size_t RequestedSize(const void* ptr) override; + size_t AllocatedSize(const void* ptr) override; + int64 AllocationId(const void* ptr) override; void GetStats(AllocatorStats* stats) override; void ClearStats() override; @@ -125,7 +125,7 @@ class TrackingAllocator : public Allocator { size_t allocated_size; int64 allocation_id; }; - std::unordered_map in_use_ GUARDED_BY(mu_); + std::unordered_map in_use_ GUARDED_BY(mu_); int64 next_allocation_id_ GUARDED_BY(mu_); }; diff --git a/tensorflow/core/framework/tracking_allocator_test.cc b/tensorflow/core/framework/tracking_allocator_test.cc index 4e32a907f20f34183abbbc57b93c38197710fa51..2cdc7edd2d1e9f2634a96e85879dc45a53f633cc 100644 --- a/tensorflow/core/framework/tracking_allocator_test.cc +++ b/tensorflow/core/framework/tracking_allocator_test.cc @@ -39,7 +39,7 @@ class TestableSizeTrackingAllocator : public Allocator { port::Free(ptr); } bool TracksAllocationSizes() override { return true; } - size_t RequestedSize(void* ptr) override { + size_t RequestedSize(const void* ptr) override { const auto& iter = size_map_.find(ptr); EXPECT_NE(size_map_.end(), iter); return iter->second; @@ -47,7 +47,7 @@ class TestableSizeTrackingAllocator : public Allocator { void GetStats(AllocatorStats* stats) override { stats->Clear(); } private: - std::unordered_map size_map_; + std::unordered_map size_map_; }; class NoMemoryAllocator : public Allocator { diff --git a/tensorflow/core/framework/types_test.cc b/tensorflow/core/framework/types_test.cc index 5ddc9865633623561760bbcb06d1edf4eecec7a6..60f2b4135a68c4eed618e3efb07758fbab85fa07 100644 --- a/tensorflow/core/framework/types_test.cc +++ b/tensorflow/core/framework/types_test.cc @@ -70,8 +70,8 @@ TEST(TypesTest, kDataTypeRefOffset) { << "Extra reference enum " << enum_descriptor->FindValueByNumber(e_ref)->name() << " without corresponding base enum with value " << e; - ASSERT_LT(DataType_MAX, e_ref) << "Gap in reference types, missing value for " - << e_ref; + ASSERT_LT(DataType_MAX, e_ref) + << "Gap in reference types, missing value for " << e_ref; // Make sure there are no enums defined after the last regular type before // the first reference type. diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index 0e2a410429d199998722e68280a8438465988ddd..c9e8dd2217e0dc0225fa38d0739d1551e0ba2433 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -177,10 +177,10 @@ class UnaryVariantOpRegistry { Op op_type_; StringPiece device_, typename_; }; - //friend declaration for operator== + // friend declaration for operator== // needed for clang template - friend bool operator==(const FuncTuple &l, const FuncTuple &r); + friend bool operator==(const FuncTuple& l, const FuncTuple& r); struct TupleHash { template std::size_t operator()( @@ -208,7 +208,8 @@ class UnaryVariantOpRegistry { binary_op_fns; // Find or insert a string into a persistent string storage - // container; return the StringPiece pointing to the permanent string location. + // container; return the StringPiece pointing to the permanent string + // location. static StringPiece GetPersistentStringPiece(const string& str) { const auto string_storage = PersistentStringStorage(); auto found = string_storage->find(str); diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc index 0cdcdb66856f0135720277bb7fab23dd24d3dde9..99ced0c0f5daa7c722aa4060e9a954855411010b 100644 --- a/tensorflow/core/graph/algorithm_test.cc +++ b/tensorflow/core/graph/algorithm_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" @@ -81,7 +82,7 @@ TEST(AlgorithmTest, ReversePostOrder) { BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3")); Graph g(OpRegistry::Global()); - TF_ASSERT_OK(b.ToGraph(&g)); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); std::vector order; // Test reverse post order: @@ -139,7 +140,7 @@ TEST(AlgorithmTest, ReversePostOrderStable) { BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t3")); Graph g(OpRegistry::Global()); - TF_ASSERT_OK(b.ToGraph(&g)); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); std::vector order; // Test reverse post order generates expected ordering. diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc index b1e6cf64e837a04b0121a5e2c5c5a905cf1821f6..4f3a6ec38cb88213c7127df41823bc16e9834d09 100644 --- a/tensorflow/core/graph/costmodel.cc +++ b/tensorflow/core/graph/costmodel.cc @@ -57,10 +57,10 @@ void CostModel::MergeFromLocal(const Graph& g, const CostModel& cm) { const int local_id = cm.Id(n); const int global_id = Id(n); if (local_id < 0 || global_id < 0) continue; - Ensure(global_id); + int num_slots = cm.slot_bytes_[local_id].size(); + Ensure(global_id, num_slots); count_[global_id] += cm.count_[local_id]; time_[global_id] += cm.time_[local_id]; - int num_slots = cm.slot_bytes_[local_id].size(); if (num_slots > 0) { if (slot_bytes_[global_id].empty()) { slot_bytes_[global_id].resize(num_slots); @@ -78,11 +78,11 @@ void CostModel::MergeFromGlobal(const CostModel& cm) { CHECK(is_global_); CHECK_EQ(true, cm.is_global()); const int num_nodes = cm.count_.size(); - Ensure(num_nodes); - for (int i = 0; i < num_nodes; ++i) { + for (int i = num_nodes - 1; i >= 0; --i) { count_[i] += cm.count_[i]; time_[i] += cm.time_[i]; int num_slots = cm.slot_bytes_[i].size(); + Ensure(i, num_slots); if (num_slots > 0) { if (slot_bytes_[i].empty()) { slot_bytes_[i].resize(num_slots); @@ -106,7 +106,7 @@ void CostModel::MergeFromStats(const NodeNameToCostIdMap& map, // copy/send/recv nodes, feed/fetch, etc. if (iter == map.end()) continue; int32 global_id = iter->second; - Ensure(global_id); + Ensure(global_id, ns.output_size()); int64 elapsed_micros = ns.op_end_rel_micros() - ns.op_start_rel_micros(); count_[global_id]++; time_[global_id] += elapsed_micros; @@ -122,7 +122,7 @@ void CostModel::MergeFromStats(const NodeNameToCostIdMap& map, } } -void CostModel::Ensure(int id) { +void CostModel::Ensure(int id, int num_outputs) { if (slot_bytes_.size() <= static_cast(id)) { slot_bytes_.resize(id + 1); count_.resize(id + 1); @@ -131,25 +131,37 @@ void CostModel::Ensure(int id) { max_exec_time_.resize(id + 1); output_port_alloc_ids_.resize(id + 1); } + if (num_outputs > 0) { + auto perslot = &slot_bytes_[id]; + auto output_port_alloc_ids = &output_port_alloc_ids_[id]; + auto max_mem_usage = &max_mem_usage_[id]; + + CHECK_LE(perslot->size(), num_outputs); + DCHECK_EQ(output_port_alloc_ids->size(), perslot->size()); + DCHECK_EQ(max_mem_usage->output_port_mem.size(), perslot->size()); + DCHECK_EQ(max_mem_usage->output_port_shape.size(), perslot->size()); + DCHECK_EQ(max_mem_usage->output_port_type.size(), perslot->size()); + + perslot->resize(num_outputs, Bytes(-1)); + output_port_alloc_ids->resize(num_outputs, -1); + max_mem_usage->output_port_mem.resize(num_outputs, Bytes(-1)); + max_mem_usage->output_port_shape.resize(num_outputs, unknown_shape_); + max_mem_usage->output_port_type.resize(num_outputs, DT_INVALID); + } } void CostModel::SetNumOutputs(const Node* node, int num_outputs) { const int id = Id(node); if (id < 0) return; - Ensure(id); + // Do not resize the number of slots before checking its existing number of + // slots. + Ensure(id, 0); auto perslot = &slot_bytes_[id]; - auto max_mem_usage = &max_mem_usage_[id]; - auto output_port_alloc_ids = &output_port_alloc_ids_[id]; if (!perslot->empty()) { - CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node=" - << node->name(); - } else { - perslot->resize(num_outputs, Bytes(-1)); - output_port_alloc_ids->resize(num_outputs, -1); - max_mem_usage->output_port_mem.resize(num_outputs, Bytes(-1)); - max_mem_usage->output_port_shape.resize(num_outputs, unknown_shape_); - max_mem_usage->output_port_type.resize(num_outputs, DT_INVALID); + CHECK_EQ(num_outputs, perslot->size()) + << "Cannot resize slot_bytes, node=" << node->name(); } + Ensure(id, num_outputs); } void CostModel::RecordCount(const Node* node, int count) { @@ -198,7 +210,7 @@ void CostModel::RecordTime(const Node* node, Microseconds time) { const int id = Id(node); if (id < 0) return; DCHECK(node->IsOp()) << node->DebugString(); - Ensure(id); + Ensure(id, node->num_outputs()); time_[id] += time; } @@ -240,7 +252,13 @@ void CostModel::RecordMaxMemorySize(const Node* node, int output_slot, const DataType& dtype) { const int id = Id(node); if (id < 0) return; - Ensure(id); + if (output_slot >= node->num_outputs()) { + LOG(ERROR) << "Unexpected output slot for node " << node->DebugString() + << ". Got " << output_slot << " but its num_outputs is " + << node->num_outputs(); + return; + } + Ensure(id, node->num_outputs()); auto& current_max = max_mem_usage_[id].output_port_mem[output_slot]; // If the memory allocator doesn't track memory usage, let's infer a lower // bound from the tensor shape and its data type. @@ -316,7 +334,7 @@ void CostModel::RecordMemoryStats(const Node* node, void CostModel::RecordMaxExecutionTime(const Node* node, Microseconds time) { const int id = Id(node); if (id < 0) return; - Ensure(id); + Ensure(id, node->num_outputs()); max_exec_time_[id] = std::max(max_exec_time_[id], time); } @@ -332,7 +350,7 @@ void CostModel::RecordAllocationId(const Node* node, int output_slot, int64 alloc_id) { const int id = Id(node); if (id < 0) return; - Ensure(id); + Ensure(id, node->num_outputs()); output_port_alloc_ids_[id][output_slot] = alloc_id; } diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h index 081eb2ff4c226c4dd5079f16cc6c2a102d0d2d63..9b703e46938b3355ed769045cdb3f298b48bb922 100644 --- a/tensorflow/core/graph/costmodel.h +++ b/tensorflow/core/graph/costmodel.h @@ -183,8 +183,8 @@ class CostModel { const bool is_global_; - // Resizes vectors so that they are large enough for "id". - void Ensure(int id); + // Resizes vectors so that they are large enough for "id" and id's outputs. + void Ensure(int id, int num_outputs); // Nodes and Edges whose count is < this value // get type/byte estimates of 0. @@ -198,7 +198,7 @@ class CostModel { // Cumulative execution time. std::vector time_; // Cumulative Bytes output on each channel. - std::vector > slot_bytes_; + std::vector> slot_bytes_; // Maximum execution time std::vector max_exec_time_; @@ -217,7 +217,7 @@ class CostModel { }; std::vector max_mem_usage_; - std::vector > output_port_alloc_ids_; + std::vector> output_port_alloc_ids_; std::set persistent_alloc_ids_; std::map> persistent_alloc_ids_by_devices_; diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index fd1b5d33b93d0e2685cd7a909bbcc9909d7d3f87..9b56216f1f97a9598dd7ae8b70786e32bb7e0f4b 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -522,6 +522,12 @@ void Graph::ToGraphDef(GraphDef* graph_def) const { ToGraphDefSubRange(graph_def, 0); } +GraphDef Graph::ToGraphDefDebug() const { + GraphDef ret; + ToGraphDef(&ret); + return ret; +} + void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const { graph_def->Clear(); *graph_def->mutable_versions() = versions(); diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index b620127d9072a845721f97112f4bad107412b06f..9d96cd4654bbf1fd65c5135d6a8bdc271c6e9443 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -62,8 +62,8 @@ class Node; class VersionDef; class WhileContext; -class NeighborIter; // Declared below -class NodeIter; // Declared below +class NeighborIter; // Declared below +class NodeIter; // Declared below class NodeProperties; // Defined in .cc class Node { @@ -494,6 +494,13 @@ class Graph { // Serialize to a GraphDef. void ToGraphDef(GraphDef* graph_def) const; + // This version can be called from debugger to inspect the graph content. + // Use the previous version outside debug context for efficiency reasons. + // + // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is + // not defined in some TensorFlow builds. + GraphDef ToGraphDefDebug() const; + // Generate new node name with the specified prefix that is unique // across this graph. string NewName(StringPiece prefix); diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 2a52c7516e539d78d4534239868c5fae7f804e17..0629ff32d00cf7fad00c39f07810aa4a9d57f14f 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -374,15 +374,8 @@ Status GraphConstructor::EnsureNoNameCollisions() { return errors::InvalidArgument("Imported node name prefix '", prefix_, "' would lead to invalid node names"); } - if (NameExistsInGraph(prefix_no_slash)) { - if (opts_.uniquify_prefix) { - prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/"); - } else { - return errors::InvalidArgument("Import node name prefix '", - prefix_no_slash, - "' conflicts with " - "name already used in the graph"); - } + if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) { + prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/"); } } return Status::OK(); @@ -990,7 +983,10 @@ Status GraphConstructor::Convert() { if (opts_.importing) { if (!prefix_.empty()) { AddPrefixToNodeDef(input_already_exists, &imported_node_def); - } else if (opts_.uniquify_names) { + } + // Note: no need to uniquify names if the prefix already guarantees + // uniqueness + if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) { UniquifyNames(input_already_exists, &imported_node_def); } TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def)); diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 01bb1ac748fd512dcd1d715d949de8eb6e77142d..963c1dc024b4265e14314c610399fc92331f053c 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -160,9 +160,7 @@ class GraphConstructorTest : public ::testing::Test { } string GraphDebugString() const { - GraphDef def; - graph_.ToGraphDef(&def); - return def.DebugString(); + return graph_.ToGraphDefDebug().DebugString(); } Graph graph_; @@ -1836,7 +1834,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) { EXPECT_EQ(results.return_nodes[1]->name(), "B_2"); EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_2:0"); - // Import with an already-used prefix + // Import with an already-used prefix and uniquify_prefix = true opts.prefix = "A"; opts.uniquify_prefix = true; results = ImportGraphDefResults(); @@ -1848,9 +1846,27 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) { EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_3/A"); // Create B_3 node to keep the A/B numbering in sync - opts = ImportGraphDefOptions(); ExpectOK("node { name: 'B_3' op: 'TestInput' }"); + // Import with an already-used prefix and uniquify_prefix = false + opts.uniquify_prefix = false; + results = ImportGraphDefResults(); + ExpectOK(graph_def_str, opts, &refiner, &results); + + ASSERT_EQ(results.return_nodes.size(), 2); + EXPECT_EQ(results.return_nodes[0]->name(), "A/A"); + EXPECT_EQ(results.return_nodes[1]->name(), "A/B"); + EXPECT_EQ(results.return_nodes[1]->def().input(0), "A/A"); + + // Repeat the same import + results = ImportGraphDefResults(); + ExpectOK(graph_def_str, opts, &refiner, &results); + + ASSERT_EQ(results.return_nodes.size(), 2); + EXPECT_EQ(results.return_nodes[0]->name(), "A/A_1"); + EXPECT_EQ(results.return_nodes[1]->name(), "A/B_1"); + EXPECT_EQ(results.return_nodes[1]->def().input(0), "A/A_1:0"); + // Import with existing de-duped node names opts = ImportGraphDefOptions(); opts.uniquify_names = true; diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc index 33d2021f3819e7781a0a488a04e7459eaf14a0d7..7a58347bd1ba44d822f5c52d2686e4b3c6e43d9b 100644 --- a/tensorflow/core/graph/graph_def_builder.cc +++ b/tensorflow/core/graph/graph_def_builder.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" @@ -72,16 +71,6 @@ Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const { return status_; } -Status GraphDefBuilder::ToGraph(Graph* graph) const { - if (status_.ok()) { - GraphDef graph_def; - graph_.ToGraphDef(&graph_def); - GraphConstructorOptions opts; - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, graph)); - } - return status_; -} - string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const { if (name_.empty()) return graph_->NewName(op); return name_; diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h index a2c0c4d553e7229ae7e0f116691d8f717fe77f87..776a74c6d8821e53a26d73399105f55189f227df 100644 --- a/tensorflow/core/graph/graph_def_builder.h +++ b/tensorflow/core/graph/graph_def_builder.h @@ -161,14 +161,6 @@ class GraphDefBuilder { // successful, and if so fill *graph_def. Status ToGraphDef(GraphDef* graph_def) const; - // Like ToGraphDef(), but converts to a Graph (using the default - // GraphConstructorOptions). - // TODO(josh11b): Make this faster; right now it converts - // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds - // edges from the source and to the sink node, resolves back edges - // by name), and makes sure the resulting graph is valid. - Status ToGraph(Graph* graph) const; - // Adds the function and gradient definitions in `fdef_lib` to this graph's op // registry. Ignores duplicate functions, and returns a bad status if an // imported function differs from an existing function or op with the same diff --git a/tensorflow/core/graph/graph_def_builder_test.cc b/tensorflow/core/graph/graph_def_builder_test.cc index e85de71ef79988199cd194274f2ef9986e86d350..be3c2be8007a4539e111a6d2375cef87dc5dff8e 100644 --- a/tensorflow/core/graph/graph_def_builder_test.cc +++ b/tensorflow/core/graph/graph_def_builder_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -26,7 +27,6 @@ namespace tensorflow { namespace { TEST(GraphDefBuilderTest, Version) { - // Verify that our assertions will be nontrivial ASSERT_LT(0, TF_GRAPH_DEF_VERSION); @@ -35,7 +35,7 @@ TEST(GraphDefBuilderTest, Version) { // Check version when we convert to a Graph Graph graph(OpRegistry::Global()); - TF_EXPECT_OK(builder.ToGraph(&graph)); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, &graph)); ASSERT_EQ(graph.versions().producer(), TF_GRAPH_DEF_VERSION); ASSERT_EQ(graph.versions().min_consumer(), TF_GRAPH_DEF_VERSION_MIN_CONSUMER); diff --git a/tensorflow/core/graph/graph_def_builder_util.cc b/tensorflow/core/graph/graph_def_builder_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..102c72185f7c1d1a0f0370fcbad08b0fc473c237 --- /dev/null +++ b/tensorflow/core/graph/graph_def_builder_util.cc @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/graph/graph_def_builder_util.h" + +#include "tensorflow/core/graph/graph_constructor.h" + +namespace tensorflow { + +Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph) { + GraphDef graph_def; + TF_RETURN_IF_ERROR(builder.ToGraphDef(&graph_def)); + GraphConstructorOptions opts; + return ConvertGraphDefToGraph(opts, graph_def, graph); +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph_def_builder_util.h b/tensorflow/core/graph/graph_def_builder_util.h new file mode 100644 index 0000000000000000000000000000000000000000..4a157e5b71da48178139ff71da4d707901b955fe --- /dev/null +++ b/tensorflow/core/graph/graph_def_builder_util.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_ +#define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_ + +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class Graph; + +// Converts the `GraphDef` being built by `builder` to a `Graph` and +// stores it in `*graph`. +// TODO(josh11b): Make this faster; right now it converts +// Graph->GraphDef->Graph. This cleans up the graph (e.g. adds +// edges from the source and to the sink node, resolves back edges +// by name), and makes sure the resulting graph is valid. +Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_ diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h index 3df981437afed760744ef870fd542d7abdd6e25d..1b99d54e8e33fd5155913a78ee833343bf92b905 100644 --- a/tensorflow/core/graph/mkl_graph_util.h +++ b/tensorflow/core/graph/mkl_graph_util.h @@ -21,102 +21,101 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { - // Since our ops are going to produce and also consume N addition tensors - // (Mkl) for N Tensorflow tensors, we can have following different - // orderings among these 2N tensors. - // - // E.g., for Tensorflow tensors A, B, and C, our ops will produce and - // consume A_m, B_m, and C_m additionally. - // - // INTERLEAVED: in this case 2N tensors are interleaved. So for above - // example, the ordering looks like: A, A_m, B, B_m, C, C_m. - // - // CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed - // by N Mkl tensors. So for above example, the ordering looks - // like: A, B, C, A_m, B_m, C_m - // - // Following APIs map index of original Tensorflow tensors to their - // appropriate position based on selected ordering. For contiguous ordering, - // we need to know the total number of tensors (parameter total). - // - typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; - // NOTE: Currently, we use contiguous ordering. If you change this, then you - // would need to change Mkl op definitions in nn_ops.cc. - static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; +// Since our ops are going to produce and also consume N addition tensors +// (Mkl) for N Tensorflow tensors, we can have following different +// orderings among these 2N tensors. +// +// E.g., for Tensorflow tensors A, B, and C, our ops will produce and +// consume A_m, B_m, and C_m additionally. +// +// INTERLEAVED: in this case 2N tensors are interleaved. So for above +// example, the ordering looks like: A, A_m, B, B_m, C, C_m. +// +// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed +// by N Mkl tensors. So for above example, the ordering looks +// like: A, B, C, A_m, B_m, C_m +// +// Following APIs map index of original Tensorflow tensors to their +// appropriate position based on selected ordering. For contiguous ordering, +// we need to know the total number of tensors (parameter total). +// +typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; +// NOTE: Currently, we use contiguous ordering. If you change this, then you +// would need to change Mkl op definitions in nn_ops.cc. +static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; - // Get index of MetaData tensor from index 'n' of Data tensor. - inline int DataIndexToMetaDataIndex(int n, int total_tensors) { - if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { - // For interleaved ordering, Mkl tensor follows immediately after - // Tensorflow tensor. - return n + 1; - } else { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away. - return n + total_tensors / 2; - } +// Get index of MetaData tensor from index 'n' of Data tensor. +inline int DataIndexToMetaDataIndex(int n, int total_tensors) { + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + // For interleaved ordering, Mkl tensor follows immediately after + // Tensorflow tensor. + return n + 1; + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away. + return n + total_tensors / 2; } +} - int inline GetTensorDataIndex(int n, int total_tensors) { - if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { - return 2 * n; // index corresponding to nth input/output tensor - } else { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - return n; - } - } +int inline GetTensorDataIndex(int n, int total_tensors) { + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + return 2 * n; // index corresponding to nth input/output tensor + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + return n; + } +} - int inline GetTensorMetaDataIndex(int n, int total_tensors) { - // Get index for TensorData first and then use mapping function - // to get TensorMetaData index from TensorData index. - int tidx = GetTensorDataIndex(n, total_tensors); - return DataIndexToMetaDataIndex(tidx, total_tensors); - } +int inline GetTensorMetaDataIndex(int n, int total_tensors) { + // Get index for TensorData first and then use mapping function + // to get TensorMetaData index from TensorData index. + int tidx = GetTensorDataIndex(n, total_tensors); + return DataIndexToMetaDataIndex(tidx, total_tensors); +} namespace mkl_op_registry { - static const char* kMklOpLabel = "MklOp"; - static const char* kMklOpLabelPattern = "label='MklOp'"; - // Prefix that we add to Tensorflow op name to construct Mkl op name. - static const char* const kMklOpPrefix = "_Mkl"; +static const char* kMklOpLabel = "MklOp"; +static const char* kMklOpLabelPattern = "label='MklOp'"; +// Prefix that we add to Tensorflow op name to construct Mkl op name. +static const char* const kMklOpPrefix = "_Mkl"; - // Get the name of Mkl op from original TensorFlow op - // We prefix 'Mkl' to the original op to get Mkl op. - inline string GetMklOpName(const string& name) { - return string(kMklOpPrefix) + name; - } +// Get the name of Mkl op from original TensorFlow op +// We prefix 'Mkl' to the original op to get Mkl op. +inline string GetMklOpName(const string& name) { + return string(kMklOpPrefix) + name; +} - // Check whether opname with type T is registered as MKL-compliant. - // - // @input: name of the op - // @input: T datatype to be used for checking op - // @return: true if opname is registered as Mkl op; false otherwise - static inline bool IsMklOp(const std::string& op_name, DataType T) { - string kernel = KernelsRegisteredForOp(op_name); - bool result = - kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); - return result; - } +// Check whether opname with type T is registered as MKL-compliant. +// +// @input: name of the op +// @input: T datatype to be used for checking op +// @return: true if opname is registered as Mkl op; false otherwise +static inline bool IsMklOp(const std::string& op_name, DataType T) { + string kernel = KernelsRegisteredForOp(op_name); + bool result = + kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); + return result; +} - // Check whether opname with type T is registered as MKL-compliant and - // is element-wise. - // - // @input: name of the op - // @input: T datatype to be used for checking op - // @return: true if opname is registered as element-wise Mkl op; - // false otherwise - static inline bool IsMklElementWiseOp(const std::string& op_name, - DataType T) { - if (!IsMklOp(op_name, T)) { - return false; - } - bool result = (0 == op_name.compare(GetMklOpName("Add")) || - 0 == op_name.compare(GetMklOpName("Sub")) || - 0 == op_name.compare(GetMklOpName("Mul")) || - 0 == op_name.compare(GetMklOpName("Maximum")) || - 0 == op_name.compare(GetMklOpName("SquaredDifference"))); - - return result; +// Check whether opname with type T is registered as MKL-compliant and +// is element-wise. +// +// @input: name of the op +// @input: T datatype to be used for checking op +// @return: true if opname is registered as element-wise Mkl op; +// false otherwise +static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) { + if (!IsMklOp(op_name, T)) { + return false; } + bool result = (0 == op_name.compare(GetMklOpName("Add")) || + 0 == op_name.compare(GetMklOpName("Sub")) || + 0 == op_name.compare(GetMklOpName("Mul")) || + 0 == op_name.compare(GetMklOpName("Maximum")) || + 0 == op_name.compare(GetMklOpName("SquaredDifference"))); + + return result; +} } // namespace mkl_op_registry } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 55bc401b9d61d43e1908faf0ac7e24639ec04c44..7d3be152991351533a6185ea088503032f720b47 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -37,12 +37,12 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/tensor_format.h" -#include "tensorflow/core/graph/mkl_layout_pass.h" #include "tensorflow/core/graph/mkl_graph_util.h" +#include "tensorflow/core/graph/mkl_layout_pass.h" namespace tensorflow { -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML // This pass implements rewriting of graph to support following scenarios: // (A) Merging nodes in the graph @@ -281,7 +281,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; csinfo_.mkl_conv2d_with_bias_backprop_bias = - "_MklConv2DWithBiasBackpropBias"; + "_MklConv2DWithBiasBackpropBias"; csinfo_.relu = "Relu"; csinfo_.relu_grad = "ReluGrad"; csinfo_.reshape = "Reshape"; @@ -297,10 +297,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // End - element-wise ops. See note above. // NOTE: names are alphabetically sorted. - rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN, - AddNRewrite, nullptr}); - rinfo_.push_back({csinfo_.add, - mkl_op_registry::GetMklOpName(csinfo_.add), + rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), + CopyAttrsAddN, AddNRewrite, nullptr}); + rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool), @@ -337,14 +336,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.fused_batch_norm, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.fused_batch_norm_grad, - mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), - CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); + rinfo_.push_back( + {csinfo_.fused_batch_norm_grad, + mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), + CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity), CopyAttrsIdentity, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.lrn, - mkl_op_registry::GetMklOpName(csinfo_.lrn), + rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), CopyAttrsLRN, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.lrn_grad, mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), @@ -358,11 +357,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum), CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.mul, - mkl_op_registry::GetMklOpName(csinfo_.mul), + rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.relu, - mkl_op_registry::GetMklOpName(csinfo_.relu), + rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.relu_grad, mkl_op_registry::GetMklOpName(csinfo_.relu_grad), @@ -373,8 +370,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.squared_difference, mkl_op_registry::GetMklOpName(csinfo_.squared_difference), CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.sub, - mkl_op_registry::GetMklOpName(csinfo_.sub), + rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub), CopyAttrsDataType, AlwaysRewrite, nullptr}); // Add info about which ops to add workspace edge to and the slots. @@ -388,9 +384,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul, IsBiasAddGradInMatMulContext}; - biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad, - csinfo_.mkl_conv2d_with_bias, - IsBiasAddGradInConv2DWithBiasContext}; + biasaddgrad_conv2dwithbias_context_ = { + csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias, + IsBiasAddGradInConv2DWithBiasContext}; cinfo_.push_back(&biasaddgrad_matmul_context_); cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_); @@ -410,9 +406,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { /// Structure to specify the context information used in a node rewrite rule typedef struct { - string node; // Name of the node to be rewritten - string fwd; // Name of the node in the forward pass that this node - // corresponds to + string node; // Name of the node to be rewritten + string fwd; // Name of the node in the forward pass that this node + // corresponds to std::function context_match_fn; } ContextInfo; @@ -615,14 +611,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { std::vector ksize, strides; CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); - CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), - true); + CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); CHECK_EQ(FormatFromString(data_format_str, &data_format), true); // Condition that specifies non-batch-wise and non-depth-wise pooling. - if (GetTensorDim(ksize, data_format, 'N') == 1 && + if (GetTensorDim(ksize, data_format, 'N') == 1 && GetTensorDim(strides, data_format, 'N') == 1 && - GetTensorDim(ksize, data_format, 'C') == 1 && + GetTensorDim(ksize, data_format, 'C') == 1 && GetTensorDim(strides, data_format, 'C') == 1) { return true; } @@ -785,8 +780,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { for (const Edge* fe : first_inp_of_filter->out_edges()) { if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && fe->dst_input() == 0) { - VLOG(1) << "MklLayoutRewritePass: found " - << fe->dst()->DebugString() + VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString() << " as the forward node for matching context, backward" << " node is: " << n->DebugString(); *fwd_node = fe->dst(); @@ -803,13 +797,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // // @return - true (if BiasAddGrad is associated with MatMul); // false otherwise. - static bool IsBiasAddGradInMatMulContext(const Node* n, - const Node** fwd_node, + static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node, void* ci) { return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci)); } - // Rewrite rule that uses context-information for matching, // used in scenario 2. // @@ -880,10 +872,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @output output_nodes - the list of new nodes creating Mkl tensors // // @return None - void GetNodesProducingMklTensorList(std::unique_ptr* g, - Node* orig_node, const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes); + void GetNodesProducingMklTensorList( + std::unique_ptr* g, Node* orig_node, + const gtl::InlinedVector, 4>& inputs, + int* input_idx, int list_length, + std::vector* output_nodes); // Get a node that will feed an Mkl tensor to the new // node that we are constructing. The output node could be (1) 'n' @@ -900,7 +893,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // will feed the tensor // @return None void GetNodeProducingMklTensor(std::unique_ptr* g, Node* orig_node, - Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot); + Node* n, int n_output_slot, Node** mkl_node, + int* mkl_node_output_slot); // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are @@ -970,9 +964,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; MklLayoutRewritePass::ContextInfo - MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; + MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; MklLayoutRewritePass::ContextInfo - MklLayoutRewritePass::biasaddgrad_matmul_context_; + MklLayoutRewritePass::biasaddgrad_matmul_context_; std::vector MklLayoutRewritePass::cinfo_; // We register Mkl rewrite pass for phase 1 in post partitioning group. @@ -1041,13 +1035,13 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, TensorShape dummy_shape({8}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // the same device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // the same device as the + // device of the original + // node. + .Finalize(&**g, out)); // If number of inputs to the original node is > 0, then we add // control dependency between 1st input (index 0) of the original node and @@ -1060,8 +1054,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, // the same frame. if (orig_node->num_inputs() > 0) { Node* orig_input0 = nullptr; - TF_CHECK_OK(orig_node->input_node(0, - const_cast(&orig_input0))); + TF_CHECK_OK( + orig_node->input_node(0, const_cast(&orig_input0))); CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); } @@ -1069,11 +1063,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, } void MklLayoutRewritePass::GetNodesProducingMklTensorList( - std::unique_ptr* g, - Node* orig_node, - const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes) { + std::unique_ptr* g, Node* orig_node, + const gtl::InlinedVector, 4>& inputs, int* input_idx, + int list_length, std::vector* output_nodes) { CHECK_LT(*input_idx, inputs.size()); CHECK_GT(list_length, 0); CHECK_NOTNULL(output_nodes); @@ -1090,8 +1082,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( int mkl_node_output_slot = 0; GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, &mkl_node_output_slot); - output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, - mkl_node_output_slot)); + output_nodes->push_back( + NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); (*input_idx)++; list_length--; } @@ -1101,9 +1093,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( // node that we are constructing. An input node could be (1) 'n' // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor // if 'n' is not an Mkl layer. -void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr* g, - Node* orig_node, Node* n, - int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { +void MklLayoutRewritePass::GetNodeProducingMklTensor( + std::unique_ptr* g, Node* orig_node, Node* n, int n_output_slot, + Node** mkl_node, int* mkl_node_output_slot) { CHECK_NOTNULL(n); CHECK_NOTNULL(mkl_node); CHECK_NOTNULL(mkl_node_output_slot); @@ -1234,8 +1226,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs( if (ArgIsList(arg)) { std::vector new_node_inputs; int N = GetTensorListLength(arg, old_node); - GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, - N, &new_node_inputs); + GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, + &new_node_inputs); nb->Input(new_node_inputs); nn_slot_idx++; } else { @@ -1336,13 +1328,13 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( TensorShape dummy_shape({1}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // same the device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // same the device as the + // device of the original + // node. + .Finalize(&**g, out)); // If number of inputs to the original node is > 0, then we add // control dependency between 1st input (index 0) of the original node and @@ -1355,8 +1347,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( // the same frame. if (orig_node->num_inputs() > 0) { Node* orig_input0 = nullptr; - TF_CHECK_OK(orig_node->input_node(0, - const_cast(&orig_input0))); + TF_CHECK_OK( + orig_node->input_node(0, const_cast(&orig_input0))); CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); } @@ -1374,7 +1366,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); for (auto ws : wsinfo_) { if (orig_node->type_string() == ws.fwd_op && - mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { + mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { // If this op is a fwd op, then we need to check if there is an // edge from this node's fwd_slot to bwdop's bwd_slot. If there is // an edge, then we just add an attribute on this node for setting @@ -1400,8 +1393,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( nb->Attr("workspace_enabled", false); } } else if (orig_node->type_string() == ws.bwd_op && - mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), - T)) { + mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(orig_node->type_string()), + T)) { // If this op is a bwd op, then we need to add workspace edge and // it's Mkl tensor edge between its corresponding fwd op and this // op. Corresponding fwd op is specified in 'fwd_op' field of @@ -1416,7 +1410,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( if (e->src_output() == ws.fwd_slot && // We would have rewritten the forward op, so we need to use // GetMklOpName call to get its Mkl name. - e->src()->type_string() == mkl_op_registry::GetMklOpName(ws.fwd_op) && + e->src()->type_string() == + mkl_op_registry::GetMklOpName(ws.fwd_op) && e->dst_input() == ws.bwd_slot) { nb->Attr("workspace_enabled", true); CHECK_NOTNULL(ws_tensors); @@ -1593,7 +1588,7 @@ void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, } void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, - NodeBuilder* nb) { + NodeBuilder* nb) { DataType T; DataType Tshape; @@ -1869,8 +1864,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr* g, Node* succ, if (e->IsControlEdge()) { CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); } else { - CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(), - e->dst_input())); + CHECK_NOTNULL( + (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input())); } } @@ -1941,9 +1936,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, // and leave BiasAddGrad as it is. But we check for this condition // when we check for node rewrite rule. So we should not even come // here for MatMul. So we will fail now. - return Status( - error::Code::INVALID_ARGUMENT, - "No rewrite is required for BiasAddGrad for MatMul context."); + return Status( + error::Code::INVALID_ARGUMENT, + "No rewrite is required for BiasAddGrad for MatMul context."); } } @@ -2012,9 +2007,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, if (e->IsControlEdge()) { CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); } else { - CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), - e->src()->num_outputs()), - e->dst(), e->dst_input())); + CHECK_NOTNULL((*g)->AddEdge( + new_node, + GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), + e->dst(), e->dst_input())); } } @@ -2070,7 +2066,8 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // BiasAddGrad is not an Mkl layer, so we make an exception for it. if (n->type_string() != csinfo_.bias_add_grad) { - if (!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) { + if (!mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(n->type_string()), T)) { return nullptr; } } @@ -2186,8 +2183,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr* g) { return MklLayoutRewritePass().RunPass(g); } -Status MklLayoutRewritePass::Run( - const GraphOptimizationPassOptions& options) { +Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { if (options.graph == nullptr && options.partition_graphs == nullptr) { return Status::OK(); } @@ -2215,7 +2211,7 @@ Status MklLayoutRewritePass::Run( return Status::OK(); } -#else // INTEL_MKL_DNN +#else // INTEL_MKL_ML // This pass implements rewriting of graph to support following scenarios: // (A) Merging nodes in the graph @@ -2421,7 +2417,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; csinfo_.conv2d_grad_filter_with_bias = - "__MklDummyConv2DBackpropFilterWithBias"; + "__MklDummyConv2DBackpropFilterWithBias"; csinfo_.fused_batch_norm = "FusedBatchNorm"; csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; csinfo_.identity = "Identity"; @@ -2435,11 +2431,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; csinfo_.mkl_conv2d_grad_filter_with_bias = - "_MklConv2DBackpropFilterWithBias"; + "_MklConv2DBackpropFilterWithBias"; csinfo_.relu = "Relu"; csinfo_.relu_grad = "ReluGrad"; - csinfo_.tanh = "Tanh"; - csinfo_.tanh_grad = "TanhGrad"; + csinfo_.tanh = "Tanh"; + csinfo_.tanh_grad = "TanhGrad"; csinfo_.reshape = "Reshape"; csinfo_.softmax = "Softmax"; csinfo_.split = "Split"; @@ -2456,9 +2452,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // NOTE: names are alphabetically sorted. rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN, AddNRewrite}); - /* rinfo_.push_back({csinfo_.add, - mkl_op_registry::GetMklOpName(csinfo_.add), - CopyAttrsDataType, AlwaysRewrite}); */ + rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), + CopyAttrsDataType, AlwaysRewrite}); rinfo_.push_back({csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool), CopyAttrsPooling, AlwaysRewrite}); @@ -2474,29 +2469,28 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), CopyAttrsConv2D, AlwaysRewrite}); - rinfo_.push_back({csinfo_.conv2d_with_bias, - csinfo_.mkl_conv2d_with_bias, + rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias, CopyAttrsConv2D, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_filter, mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), CopyAttrsConv2D, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias, - csinfo_.mkl_conv2d_grad_filter_with_bias, - CopyAttrsConv2D, AlwaysRewrite}); + csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D, + AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_input, mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), CopyAttrsConv2D, AlwaysRewrite}); rinfo_.push_back({csinfo_.fused_batch_norm, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), CopyAttrsFusedBatchNorm, AlwaysRewrite}); - rinfo_.push_back({csinfo_.fused_batch_norm_grad, - mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), - CopyAttrsFusedBatchNorm, AlwaysRewrite}); + rinfo_.push_back( + {csinfo_.fused_batch_norm_grad, + mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), + CopyAttrsFusedBatchNorm, AlwaysRewrite}); rinfo_.push_back({csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity), CopyAttrsDataType, AlwaysRewrite}); - rinfo_.push_back({csinfo_.lrn, - mkl_op_registry::GetMklOpName(csinfo_.lrn), + rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), CopyAttrsLRN, AlwaysRewrite}); rinfo_.push_back({csinfo_.lrn_grad, mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), @@ -2507,16 +2501,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.max_pool_grad, mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), CopyAttrsPooling, AlwaysRewrite}); - /* + rinfo_.push_back({csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum), CopyAttrsDataType, AlwaysRewrite}); rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), CopyAttrsDataType, AlwaysRewrite}); - */ - rinfo_.push_back({csinfo_.relu, - mkl_op_registry::GetMklOpName(csinfo_.relu), + rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), CopyAttrsDataType, AlwaysRewrite}); rinfo_.push_back({csinfo_.relu_grad, mkl_op_registry::GetMklOpName(csinfo_.relu_grad), @@ -2535,14 +2527,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax), CopyAttrsDataType, AlwaysRewrite}); - /* + rinfo_.push_back({csinfo_.squared_difference, mkl_op_registry::GetMklOpName(csinfo_.squared_difference), CopyAttrsDataType, AlwaysRewrite}); rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub), CopyAttrsDataType, AlwaysRewrite}); - */ // Add info about which ops to add workspace edge to and the slots. wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); @@ -2550,8 +2541,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Add a rule for merging nodes minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add, - csinfo_.conv2d_with_bias, - GetConv2DOrBiasAdd}); + csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd}); minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad, csinfo_.conv2d_grad_filter_with_bias, @@ -2846,9 +2836,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Default rewrite rule to be used in scenario 1 for rewrite. // @return - true (since we want to always rewrite) - static bool AlwaysRewrite(const Node* n) { - return true; - } + static bool AlwaysRewrite(const Node* n) { return true; } // Check if we are performing pooling on depth or batch. If it is, then we // do not rewrite MaxPool node to Mkl version. @@ -2862,14 +2850,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { std::vector ksize, strides; CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); - CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), - true); + CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); CHECK_EQ(FormatFromString(data_format_str, &data_format), true); // Condition that specifies non-batch-wise and non-depth-wise pooling. - if (GetTensorDim(ksize, data_format, 'N') == 1 && + if (GetTensorDim(ksize, data_format, 'N') == 1 && GetTensorDim(strides, data_format, 'N') == 1 && - GetTensorDim(ksize, data_format, 'C') == 1 && + GetTensorDim(ksize, data_format, 'C') == 1 && GetTensorDim(strides, data_format, 'C') == 1) { return true; } @@ -2941,10 +2928,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @output output_nodes - the list of new nodes creating Mkl tensors // // @return None - void GetNodesProducingMklTensorList(std::unique_ptr* g, - Node* orig_node, const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes); + void GetNodesProducingMklTensorList( + std::unique_ptr* g, Node* orig_node, + const gtl::InlinedVector, 4>& inputs, + int* input_idx, int list_length, + std::vector* output_nodes); // Get a node that will feed an Mkl tensor to the new // node that we are constructing. The output node could be (1) 'n' @@ -2961,7 +2949,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // will feed the tensor // @return None void GetNodeProducingMklTensor(std::unique_ptr* g, Node* orig_node, - Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot); + Node* n, int n_output_slot, Node** mkl_node, + int* mkl_node_output_slot); // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are @@ -3096,13 +3085,13 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, TensorShape dummy_shape({8}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // the same device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // the same device as the + // device of the original + // node. + .Finalize(&**g, out)); // If number of inputs to the original node is > 0, then we add // control dependency between 1st input (index 0) of the original node and @@ -3115,8 +3104,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, // the same frame. if (orig_node->num_inputs() > 0) { Node* orig_input0 = nullptr; - TF_CHECK_OK(orig_node->input_node(0, - const_cast(&orig_input0))); + TF_CHECK_OK( + orig_node->input_node(0, const_cast(&orig_input0))); // Allow duplicate while adding control edge as it would fail (return // NULL) if we try to add duplicate edge. CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); @@ -3126,11 +3115,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr* g, } void MklLayoutRewritePass::GetNodesProducingMklTensorList( - std::unique_ptr* g, - Node* orig_node, - const gtl::InlinedVector, 4>& inputs, - int* input_idx, int list_length, - std::vector* output_nodes) { + std::unique_ptr* g, Node* orig_node, + const gtl::InlinedVector, 4>& inputs, int* input_idx, + int list_length, std::vector* output_nodes) { CHECK_LT(*input_idx, inputs.size()); CHECK_GT(list_length, 0); CHECK_NOTNULL(output_nodes); @@ -3147,8 +3134,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( int mkl_node_output_slot = 0; GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, &mkl_node_output_slot); - output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, - mkl_node_output_slot)); + output_nodes->push_back( + NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); (*input_idx)++; list_length--; } @@ -3158,9 +3145,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( // node that we are constructing. An input node could be (1) 'n' // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor // if 'n' is not an Mkl layer. -void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr* g, - Node* orig_node, Node* n, - int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { +void MklLayoutRewritePass::GetNodeProducingMklTensor( + std::unique_ptr* g, Node* orig_node, Node* n, int n_output_slot, + Node** mkl_node, int* mkl_node_output_slot) { CHECK_NOTNULL(n); CHECK_NOTNULL(mkl_node); CHECK_NOTNULL(mkl_node_output_slot); @@ -3292,8 +3279,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs( if (ArgIsList(arg)) { std::vector new_node_inputs; int N = GetTensorListLength(arg, old_node); - GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, - N, &new_node_inputs); + GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, + &new_node_inputs); nb->Input(new_node_inputs); nn_slot_idx++; } else { @@ -3394,13 +3381,13 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( TensorShape dummy_shape({1}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // same the device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // same the device as the + // device of the original + // node. + .Finalize(&**g, out)); // If number of inputs to the original node is > 0, then we add // control dependency between 1st input (index 0) of the original node and @@ -3413,8 +3400,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( // the same frame. if (orig_node->num_inputs() > 0) { Node* orig_input0 = nullptr; - TF_CHECK_OK(orig_node->input_node(0, - const_cast(&orig_input0))); + TF_CHECK_OK( + orig_node->input_node(0, const_cast(&orig_input0))); // Allow duplicate while adding control edge as it would fail (return // NULL) if we try to add duplicate edge. CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); @@ -3434,8 +3421,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); for (auto ws : wsinfo_) { if (orig_node->type_string() == ws.fwd_op && - mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( - orig_node->type_string()), T)) { + mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { // If this op is a fwd op, then we need to check if there is an // edge from this node's fwd_slot to bwdop's bwd_slot. If there is // an edge, then we just add an attribute on this node for setting @@ -3461,8 +3448,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( nb->Attr("workspace_enabled", false); } } else if (orig_node->type_string() == ws.bwd_op && - mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( - orig_node->type_string()), T)) { + mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(orig_node->type_string()), + T)) { // If this op is a bwd op, then we need to add workspace edge and // it's Mkl tensor edge between its corresponding fwd op and this // op. Corresponding fwd op is specified in 'fwd_op' field of @@ -3477,8 +3465,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( if (e->src_output() == ws.fwd_slot && // We would have rewritten the forward op, so we need to use // GetMklOpName call to get its Mkl name. - e->src()->type_string() == mkl_op_registry::GetMklOpName( - ws.fwd_op) && + e->src()->type_string() == + mkl_op_registry::GetMklOpName(ws.fwd_op) && e->dst_input() == ws.bwd_slot) { nb->Attr("workspace_enabled", true); CHECK_NOTNULL(ws_tensors); @@ -3645,7 +3633,7 @@ void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, } void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, - NodeBuilder* nb) { + NodeBuilder* nb) { DataType T; DataType Tshape; @@ -3776,8 +3764,9 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr* g, Node* m, Node* n) { CHECK_EQ(((m->type_string() == csinfo_.bias_add && n->type_string() == csinfo_.conv2d)) || - ((n->type_string() == csinfo_.bias_add && - m->type_string() == csinfo_.conv2d)), true); + ((n->type_string() == csinfo_.bias_add && + m->type_string() == csinfo_.conv2d)), + true); // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd, // BiasAdd is successor node, and Conv2D predecessor node. @@ -3796,8 +3785,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr* g, TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides)); TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred)); TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ)); - TF_CHECK_OK( - GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu)); + TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu)); // We check to ensure that data formats of both succ and pred are same. // We expect them to be same, so we can enforce this as assert. // But assert can be too strict, so we enforce this as a check. @@ -3900,8 +3888,8 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr* g, // BiasAdd has only 1 output (at slot 0) and merged node also has only 1 // output (at slot 0). const int kConv2DWithBiasOutputSlot = 0; - CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, - e->dst(), e->dst_input())); + CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(), + e->dst_input())); } } @@ -3924,8 +3912,9 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( std::unique_ptr* g, Node* m, Node* n) { CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad && n->type_string() == csinfo_.conv2d_grad_filter)) || - ((n->type_string() == csinfo_.bias_add_grad && - m->type_string() == csinfo_.conv2d_grad_filter)), true); + ((n->type_string() == csinfo_.bias_add_grad && + m->type_string() == csinfo_.conv2d_grad_filter)), + true); // If 'm' is BiasAddGrad, then 'n' is BackpropFilter. Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n; @@ -4132,9 +4121,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, // NULL) if we try to add duplicate edge. CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); } else { - CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), - e->src()->num_outputs()), - e->dst(), e->dst_input())); + CHECK_NOTNULL((*g)->AddEdge( + new_node, + GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), + e->dst(), e->dst_input())); } } @@ -4166,9 +4156,9 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // names. if (n->type_string() != csinfo_.conv2d_with_bias && n->type_string() != csinfo_.conv2d_grad_filter_with_bias && - !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( - n->type_string()), T)) { - return nullptr; + !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), + T)) { + return nullptr; } // For elementwise node, we reuse the Eigen implementation and pass the MKL @@ -4184,29 +4174,30 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // eigen code to reduce cross-library dependency. VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string(); if (mkl_op_registry::IsMklElementWiseOp( - mkl_op_registry::GetMklOpName(n->type_string()), T) || + mkl_op_registry::GetMklOpName(n->type_string()), T) || n->type_string().find("Identity") != string::npos) { VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string(); bool incoming_mkl_edge = false; int num_parent = 0; for (auto parent : n->in_edges()) { if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) { - VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is MKL op: " - << parent->src()->type_string(); + VLOG(1) << "ELEMENTWISE: parent " << num_parent++ + << " is MKL op: " << parent->src()->type_string(); incoming_mkl_edge = true; break; } else { - VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is NON-MKL op: " - << parent->src()->type_string(); + VLOG(1) << "ELEMENTWISE: parent " << num_parent++ + << " is NON-MKL op: " << parent->src()->type_string(); } } if (incoming_mkl_edge == false) { - VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which has no MKL " + VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which " + "has no MKL " "parents."; return nullptr; } else { - VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string() << - " which has MKL parents"; + VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string() + << " which has MKL parents"; } } @@ -4214,8 +4205,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // for this op, then we rewrite it to Mkl op. // Find matching RewriteInfo and then check that rewrite rule applies. for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { - if (n->type_string().compare(ri->name) == 0 && - ri->rewrite_rule(n)) { + if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { return &*ri; } } @@ -4297,8 +4287,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr* g) { return MklLayoutRewritePass().RunPass(g); } -Status MklLayoutRewritePass::Run( - const GraphOptimizationPassOptions& options) { +Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { if (options.graph == nullptr && options.partition_graphs == nullptr) { return Status::OK(); } @@ -4325,7 +4314,7 @@ Status MklLayoutRewritePass::Run( return Status::OK(); } -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML } // namespace tensorflow #endif diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index 75f7ca2d4d7ce7c86858a40fe34fed6aa707c9e5..5e2a465e22c7cbe45cbea40ea7a11491e2b2ad24 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -38,7 +38,7 @@ limitations under the License. namespace tensorflow { -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML namespace { @@ -125,8 +125,10 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful(); REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); -REGISTER_OP("_MklInput2").Output("o: uint8") - .Output("o1: uint8").SetIsStateful(); +REGISTER_OP("_MklInput2") + .Output("o: uint8") + .Output("o1: uint8") + .SetIsStateful(); ///////////////////////////////////////////////////////////////////// // Unit tests related to node merge optiimization @@ -498,7 +500,6 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) { "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5"); } - // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) { InitGraph( @@ -874,11 +875,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { " input: ['A', 'B:0', 'B:1']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" - "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); + EXPECT_EQ( + DoMklLayoutOptimizationPass(), + "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" + "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" + "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); } // Concat with 2 Mkl layers feeding it @@ -1273,7 +1275,8 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { "node { name: 'H' op: 'Input'}" "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['H', 'G'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), + EXPECT_EQ( + DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" @@ -1640,7 +1643,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B']}" "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['B', 'C'] }", kGPUDevice); + " input: ['B', 'C'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); } @@ -1666,7 +1670,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { "node { name: 'F' op: 'BiasAddGrad'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['E'] }", kGPUDevice); + " input: ['E'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" @@ -1687,7 +1692,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'D'] }", kGPUDevice); + " input: ['A', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" "A->D;A->E;B->D:1;C->D:2;D->E:1"); @@ -1700,7 +1706,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); } @@ -1713,7 +1720,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }" "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'C'] }", kGPUDevice); + " input: ['A', 'C'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); } @@ -1729,7 +1737,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } @@ -1745,7 +1754,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); } @@ -1766,7 +1776,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { " attr { key: 'N' value { i: 2 } }" " input: ['A', 'B:0', 'B:1']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }", kGPUDevice); + " input: ['C', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" "B->D:1;B:1->D:2;C->E;D->E:1"); @@ -1788,7 +1799,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { " attr { key: 'N' value { i: 2 } }" " input: ['B:0', 'B:1', 'A']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }", kGPUDevice); + " input: ['C', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); @@ -1808,7 +1820,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { " attr { key: 'is_training' value { b: true } }" " input: ['A', 'B', 'C', 'D', 'E'] }" "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'F'] }", kGPUDevice); + " input: ['A', 'F'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);E(Input);" "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" @@ -1837,7 +1850,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { "node { name: 'Y' op: 'Input'}" "node { name: 'Z' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['E', 'Y']}", kGPUDevice); + " input: ['E', 'Y']}", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" @@ -1885,7 +1899,7 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); } // namespace -#else // INTEL_MKL_DNN +#else // INTEL_MKL_ML namespace { @@ -1972,8 +1986,10 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful(); REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); -REGISTER_OP("_MklInput2").Output("o: uint8") - .Output("o1: uint8").SetIsStateful(); +REGISTER_OP("_MklInput2") + .Output("o: uint8") + .Output("o1: uint8") + .SetIsStateful(); ///////////////////////////////////////////////////////////////////// // Unit tests related to node merge optiimization @@ -2492,11 +2508,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { " input: ['A', 'B:0', 'B:1']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" - "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); + EXPECT_EQ( + DoMklLayoutOptimizationPass(), + "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" + "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" + "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" + "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); } // Concat with 2 Mkl layers feeding it @@ -2891,7 +2908,8 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { "node { name: 'H' op: 'Input'}" "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['H', 'G'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), + EXPECT_EQ( + DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" @@ -3258,7 +3276,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B']}" "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['B', 'C'] }", kGPUDevice); + " input: ['B', 'C'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); } @@ -3284,7 +3303,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { "node { name: 'F' op: 'BiasAddGrad'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['E'] }", kGPUDevice); + " input: ['E'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" @@ -3305,7 +3325,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'D'] }", kGPUDevice); + " input: ['A', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" "A->D;A->E;B->D:1;C->D:2;D->E:1"); @@ -3318,7 +3339,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); } @@ -3331,7 +3353,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }" "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'C'] }", kGPUDevice); + " input: ['A', 'C'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); } @@ -3347,7 +3370,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } @@ -3363,7 +3387,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", kGPUDevice); + " input: ['A', 'B'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); } @@ -3384,7 +3409,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { " attr { key: 'N' value { i: 2 } }" " input: ['A', 'B:0', 'B:1']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }", kGPUDevice); + " input: ['C', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" "B->D:1;B:1->D:2;C->E;D->E:1"); @@ -3406,7 +3432,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { " attr { key: 'N' value { i: 2 } }" " input: ['B:0', 'B:1', 'A']}" "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }", kGPUDevice); + " input: ['C', 'D'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); @@ -3426,7 +3453,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { " attr { key: 'is_training' value { b: true } }" " input: ['A', 'B', 'C', 'D', 'E'] }" "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'F'] }", kGPUDevice); + " input: ['A', 'F'] }", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);E(Input);" "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" @@ -3455,7 +3483,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { "node { name: 'Y' op: 'Input'}" "node { name: 'Z' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['E', 'Y']}", kGPUDevice); + " input: ['E', 'Y']}", + kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" @@ -3503,7 +3532,7 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); } // namespace -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML } // namespace tensorflow diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index 599bb88f015bfc035b7666747571a652a954139d..e9ced4d2b6b2e7bffa0fbe61f546bef0aa9db974 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -33,8 +33,8 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/graph/mkl_tfconversion_pass.h" #include "tensorflow/core/graph/mkl_graph_util.h" +#include "tensorflow/core/graph/mkl_tfconversion_pass.h" namespace tensorflow { @@ -152,12 +152,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge( string data_format; TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype)); - bool dst_dtype_found = GetNodeAttr(dst->def(), "T", &dst_datatype) == - Status::OK(); + bool dst_dtype_found = + GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK(); // We compare source and destination datatypes only when both are found. if (dst_dtype_found && (src_datatype != dst_datatype)) { - string err_msg = "T attribute of " + src->name() + " and " + - dst->name() + " do not match. Will not insert" + + string err_msg = "T attribute of " + src->name() + " and " + dst->name() + + " do not match. Will not insert" + " MklToTf node in such case."; return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str()); } @@ -222,7 +222,7 @@ Status MklToTfConversionPass::InsertInputConversionNode( BaseType(n->input_type(0))); // Check ordering of edges - for (uint i = 0; i < 4; i++) { + for (uint32 i = 0; i < 4; i++) { CHECK_EQ((edges[i]->dst_input() == i), true); } @@ -325,12 +325,12 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr* g) { // may not be Mkl node. DataType src_datatype; DataType dst_datatype; - bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) == - Status::OK() && - IsMklSupportedOp(src->type_string(), src_datatype)); - bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) == - Status::OK() && - IsMklSupportedOp(dst->type_string(), dst_datatype)); + bool src_is_mkl_op = + (GetNodeAttr(src->def(), "T", &src_datatype) == Status::OK() && + IsMklSupportedOp(src->type_string(), src_datatype)); + bool dst_is_mkl_op = + (GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK() && + IsMklSupportedOp(dst->type_string(), dst_datatype)); // Check if src with is Mkl-compliant, while dst is not Mkl-compliant. if (src_is_mkl_op && !dst_is_mkl_op) { diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index fde1ea17437e86d01054a1b153055170bda51e8b..7219d9812f3e4a01cffa4b6b17d38781f7d5e2b0 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -361,7 +362,7 @@ static void BM_SubgraphHelper(int iters, int num_nodes, last_node = ops::SourceOp("In", b.opts().WithName(name)); } } - TF_CHECK_OK(b.ToGraph(&g)); + TF_CHECK_OK(GraphDefBuilderToGraph(b, &g)); } std::vector fed; diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index 172471e34bc5ce344a4a8db2d404b77b7406c99f..0d88d1ff723b94783693559926c51c6726a2341b 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -40,7 +40,7 @@ REGISTER_KERNEL_BUILDER( #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER( Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Register the HostConst Op // Returns a constant tensor on the host. Useful for writing C++ tests @@ -273,6 +273,16 @@ Node* Reverse(Graph* g, Node* tensor, Node* axis) { return Binary(g, "ReverseV2", tensor, axis); } +Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry()) + .Input(input) + .Input(shift) + .Input(axis) + .Finalize(g, &ret)); + return ret; +} + Node* Error(Graph* g, Node* input, const string& errmsg) { Node* ret; TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error") diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h index 06597778bb204c83dae7699e1ffe0e2b196ac160..eb9038d619ed273bbfd2596bce964fda005b4ec1 100644 --- a/tensorflow/core/graph/testlib.h +++ b/tensorflow/core/graph/testlib.h @@ -117,6 +117,10 @@ Node* RandomGamma(Graph* g, Node* shape, Node* alpha); // Output dtype determined by lam. Node* RandomPoisson(Graph* g, Node* shape, Node* lam); +// Rolls tensor by an offset of along the corresponding +// dimensions. +Node* Roll(Graph* g, Node* input, Node* shift, Node* axis); + // Generates random parameters from the truncated standard normal distribution // of the nput shape Node* TruncatedNormal(Graph* g, Node* input, DataType dtype); diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index 5b8ce373bcf87a10875e764ba5cdbec96d58c080..b8f8e13c9a6830658e2b53388e1f91fbc8a22eab 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -114,7 +114,10 @@ tf_cc_test( name = "single_machine_test", srcs = ["single_machine_test.cc"], args = ["--heap_check=local"], # The GPU tracer leaks memory - tags = ["no_gpu"], + tags = [ + "no_cuda_on_cpu_tap", + "no_gpu", + ], deps = [ ":single_machine", "//tensorflow/cc:cc_ops", diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc index 01a618ed7775eee64ce40e283394c09622353157..39bfca244ed2d40544dd2a17a019dadbe50f6d29 100644 --- a/tensorflow/core/grappler/clusters/cluster.cc +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -23,8 +23,7 @@ Cluster::Cluster(int timeout_s) : timeout_s_(timeout_s) { DisableDetailedStats(false); } -Cluster::~Cluster() { -} +Cluster::~Cluster() {} void Cluster::AllowSoftPlacement(bool soft_placement_state) { options_.config.set_allow_soft_placement(soft_placement_state); diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index 2712c5b67910c2d10a13237673cc671222955fbb..cc7f418d49816d64ffc51704d2f127a441815d7b 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" namespace tensorflow { @@ -36,10 +37,7 @@ namespace grappler { static std::atomic already_provisioned(false); SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus) - : Cluster(timeout_s), - num_gpus_(num_gpus), - expected_init_time_s_(0), - closing_(false) { + : Cluster(timeout_s), expected_init_time_s_(0), closing_(false) { VLOG(1) << "Number of CPU cores: " << num_cpu_cores << " Number of GPUs: " << num_gpus; thread_pool_.reset(new thread::ThreadPool( @@ -89,7 +87,9 @@ Status SingleMachine::Provision() { attr = GetLocalCPUInfo(); } else if (dev.device_type() == "GPU") { attr = GetLocalGPUInfo(gpu_id++); - } else { + } else if (dev.device_type().find("XLA") == string::npos) { + // Filter out the fake XLA devices to avoid double counting the actual + // hardware resources that are available. attr.set_type(dev.device_type()); } // Overwrite the memory size since users might have requested to use only a diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h index a254f72f0c7719e49d4f52e8cc42181a09071801..90d6a04cab650178db0dc14ac94564690b0d7bbb 100644 --- a/tensorflow/core/grappler/clusters/single_machine.h +++ b/tensorflow/core/grappler/clusters/single_machine.h @@ -64,7 +64,6 @@ class SingleMachine : public Cluster { Status ClearAllocatorStats() const; - const int num_gpus_; std::unique_ptr session_; std::vector queue_runner_defs_; string last_graph_id_; diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc index 592e4b789d0dcb7369e2f0c6db447eb9daa92870..aacd2ccb72df07ac6b31c9bd5b96deca499038e4 100644 --- a/tensorflow/core/grappler/clusters/utils.cc +++ b/tensorflow/core/grappler/clusters/utils.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/mem.h" namespace tensorflow { namespace grappler { @@ -48,6 +49,11 @@ DeviceProperties GetLocalCPUInfo() { device.set_l2_cache_size(Eigen::l2CacheSize()); device.set_l3_cache_size(Eigen::l3CacheSize()); + int64 free_mem = port::AvailableRam(); + if (free_mem < INT64_MAX) { + device.set_memory_size(free_mem); + } + (*device.mutable_environment())["cpu_instruction_set"] = Eigen::SimdInstructionSetsInUse(); diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc index 1c2c1713834a11d0a7c85247e9a7e4cdf779c592..f24192247113bfe91884a9c557f46cc29986ff9a 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc @@ -102,7 +102,7 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) { Costs summary; TF_ASSERT_OK(estimator.PredictCosts(item.graph, &cost_graph, &summary)); - EXPECT_EQ(Costs::NanoSeconds(9150), summary.execution_time); + EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time); // Make this estimate accurate: // TODO(http://b/70031255): Accurate estimator for RandomUniform op needed diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h index 852e69737baa14e0d05de1fdcb6fc24a143f6a2d..9e01ec5ff5b48b9f979695b0a4b7b089245145c0 100644 --- a/tensorflow/core/grappler/costs/cost_estimator.h +++ b/tensorflow/core/grappler/costs/cost_estimator.h @@ -78,6 +78,9 @@ struct Costs { MilliSeconds asMilliSeconds() const { return std::chrono::duration_cast(*this); } + static NanoSeconds infinity() { + return NanoSeconds(std::chrono::nanoseconds::max()); + } }; // We store all our times in nanoseconds. If needs be, we can always switch to @@ -85,10 +88,7 @@ struct Costs { typedef NanoSeconds Duration; // Overall cost of running the graph; latency. - // Mean Duration execution_time; - Duration min_execution_time; - Duration max_execution_time; // Computation cost of running the graph. Duration compute_time; @@ -100,6 +100,8 @@ struct Costs { // requirements of a graph. For example, it might assume that all activations // are live for all of a graph's execution. int64 max_memory; // Maximum main memory requirement in bytes over all ops. + int64 persistent_memory; + int64 temporary_memory; // These fields are used for TPU-related estimations. They are per-op // maximums, so each op is evaluated independently, but we want the maximum of @@ -132,6 +134,8 @@ Costs::Costs() { compute_time = Duration::zero(); memory_time = Duration::zero(); max_memory = kMemoryUnknown; + persistent_memory = kMemoryUnknown; + temporary_memory = kMemoryUnknown; max_per_op_buffers = kMemoryUnknown; max_per_op_streaming = kMemoryUnknown; } @@ -142,6 +146,8 @@ Costs Costs::ZeroCosts() { costs.compute_time = Duration::zero(); costs.memory_time = Duration::zero(); costs.max_memory = kZeroMemory; + costs.persistent_memory = kZeroMemory; + costs.temporary_memory = kZeroMemory; costs.max_per_op_buffers = kZeroMemory; costs.max_per_op_streaming = kZeroMemory; return costs; diff --git a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc index 8fd1801863ad9aadd6e9f1bbde4b90600189d77c..ea4320687af366ccdd82e46cf28adf4ee9c100c0 100644 --- a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc @@ -117,8 +117,6 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph, LOG(ERROR) << "Failed to measure graph performance: " << status.error_message(); costs->execution_time = Costs::Duration::max(); - costs->max_execution_time = Costs::Duration::max(); - costs->min_execution_time = 0; return status; } @@ -126,8 +124,6 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph, // to filter out outliers. RobustStats stats(times); costs->execution_time = Costs::Duration(stats.mean()); - costs->max_execution_time = Costs::Duration(stats.hi()); - costs->min_execution_time = Costs::Duration(stats.lo()); return Status::OK(); } diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 6bc136a3f89c9a1dbfd4be15c143d4c893897494..a57cfdd9891b1d654092f9b896af248fa40eb88f 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -47,6 +47,8 @@ constexpr char kSize[] = "Size"; constexpr char kStopGradient[] = "StopGradient"; constexpr char kPreventGradient[] = "PreventGradient"; +static const Costs::Duration kMinComputeTime(1); + namespace { string GetDataFormat(const OpInfo& op_features) { @@ -163,18 +165,20 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}, - {kPlaceholder, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kRefIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kStopGradient, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kPreventGradient, wrap(&OpLevelCostEstimator::PredictNoOp)}, {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kSend, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kConst, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)}, - {kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)}, + + {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)}, + {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)}, + {kRefIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)}, + {kStopGradient, wrap(&OpLevelCostEstimator::PredictIdentity)}, + {kPreventGradient, wrap(&OpLevelCostEstimator::PredictIdentity)}, + {kReshape, wrap(&OpLevelCostEstimator::PredictIdentity)}, + {kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)}, + {kSend, wrap(&OpLevelCostEstimator::PredictIdentity)}, + + {kConst, wrap(&OpLevelCostEstimator::PredictVariable)}, + {kVariable, wrap(&OpLevelCostEstimator::PredictVariable)}, + {kVariableV2, wrap(&OpLevelCostEstimator::PredictVariable)}, {kRank, wrap(&OpLevelCostEstimator::PredictMetadata)}, {kShape, wrap(&OpLevelCostEstimator::PredictMetadata)}, @@ -349,6 +353,9 @@ OpLevelCostEstimator::DeviceInfo OpLevelCostEstimator::GetDeviceInfo( VLOG(1) << "Device: " << device.type() << " gflops: " << gflops << " gb_per_sec: " << gb_per_sec; + DCHECK_LT(0, gflops) << device.DebugString(); + DCHECK_LT(0, gb_per_sec) << device.DebugString(); + return {gflops, gb_per_sec}; } @@ -404,6 +411,12 @@ Costs OpLevelCostEstimator::PredictCostOfAnUnknownOp( Costs OpLevelCostEstimator::PredictOpCountBasedCost( double operations, const OpInfo& op_features) const { DeviceInfo device_perf = GetDeviceInfo(op_features.device()); + if (device_perf.gigaops <= 0 || device_perf.gb_per_sec <= 0) { + VLOG(1) << "BAD DEVICE. Op:" << op_features.op() + << " device type:" << op_features.device().type() + << " device model:" << op_features.device().model(); + } + Costs::NanoSeconds compute_cost(std::ceil(operations / device_perf.gigaops)); VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9 << " Execution Time (ns):" << compute_cost.count(); @@ -429,6 +442,7 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( costs.execution_time = compute_cost + memory_cost; } costs.inaccurate = found_unknown_shapes; + costs.max_memory = total_output_size; return costs; } @@ -443,10 +457,15 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs( const TensorShapeProto& original_image_shape, const TensorShapeProto& original_filter_shape, const OpInfo& op_features, bool* found_unknown_shapes) { + VLOG(2) << "op features: " << op_features.DebugString(); + VLOG(2) << "Original image shape: " << original_image_shape.DebugString(); + VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString(); auto image_shape = MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes); auto filter_shape = MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes); + VLOG(2) << "Image shape: " << image_shape.DebugString(); + VLOG(2) << "Filter shape: " << filter_shape.DebugString(); int x_index, y_index, channel_index; const string& data_format = GetDataFormat(op_features); @@ -705,18 +724,35 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations( bool* found_unknown_shapes) const { int64 ops = 0; - if (op_features.op() != kConv2dBackpropInput) { - LOG(ERROR) << "Invalid Operation"; + DCHECK_EQ(kConv2dBackpropInput, op_features.op()); + + if (op_features.inputs_size() < 2) { + *found_unknown_shapes = true; return ops; } - if (op_features.outputs_size() != 1) { - // Need _output_shapes for input shape. - LOG(ERROR) << "No output shape in Conv2DBackpropInput op."; - return ops; + TensorShapeProto input_shape; + if (op_features.inputs(0).has_value()) { + const TensorProto& value = op_features.inputs(0).value(); + if (value.int64_val_size() > 0) { + for (int i = 0; i < value.int64_val_size(); ++i) { + input_shape.add_dim()->set_size(value.int64_val(i)); + } + } else { + for (int i = 0; i < value.int_val_size(); ++i) { + input_shape.add_dim()->set_size(value.int_val(i)); + } + } + } else if (op_features.outputs_size() == 1) { + input_shape = op_features.outputs(0).shape(); + } else { + // Set the minimum filter size that's feasible. + for (int i = 0; i < 4; ++i) { + input_shape.add_dim()->set_size(1); + } + *found_unknown_shapes = true; } - const auto& input_shape = op_features.outputs(0).shape(); ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( input_shape, op_features.inputs(1).shape(), op_features, found_unknown_shapes); @@ -739,18 +775,34 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations( const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims, bool* found_unknown_shapes) const { int64 ops = 0; - if (op_features.op() != kConv2dBackpropFilter) { - LOG(ERROR) << "Invalid Operation"; - return ops; + DCHECK_EQ(kConv2dBackpropFilter, op_features.op()); + + TensorShapeProto filter_shape; + if (op_features.inputs_size() >= 2 && op_features.inputs(1).has_value()) { + const TensorProto& value = op_features.inputs(1).value(); + if (value.int64_val_size() > 0) { + for (int i = 0; i < value.int64_val_size(); ++i) { + filter_shape.add_dim()->set_size(value.int64_val(i)); + } + } else { + for (int i = 0; i < value.int_val_size(); ++i) { + filter_shape.add_dim()->set_size(value.int_val(i)); + } + } + } else if (op_features.outputs_size() == 1) { + filter_shape = op_features.outputs(0).shape(); + } else { + // Set the minimum filter size that's feasible. + for (int i = 0; i < 4; ++i) { + filter_shape.add_dim()->set_size(1); + } + *found_unknown_shapes = true; } - if (op_features.outputs_size() != 1) { - // Need _output_shapes for input shape. - LOG(ERROR) << "No output shape in Conv2DBackpropFilter op."; + if (op_features.inputs_size() < 1) { + *found_unknown_shapes = true; return ops; } - - const auto& filter_shape = op_features.outputs(0).shape(); ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( op_features.inputs(0).shape(), filter_shape, op_features, found_unknown_shapes); @@ -885,6 +937,30 @@ Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const { return Costs::ZeroCosts(); } +Costs OpLevelCostEstimator::PredictIdentity(const OpContext& op_context) const { + const auto& op_features = op_context.op_info; + VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)"; + Costs result = Costs::ZeroCosts(); + result.max_memory = CalculateOutputSize(op_features, &result.inaccurate); + // Assign the minimum amount of time we can represent to the identity op since + // it tends to be really cheap. + result.compute_time = kMinComputeTime; + result.execution_time = result.compute_time; + return result; +} + +Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const { + const auto& op_features = op_context.op_info; + VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)"; + Costs result = Costs::ZeroCosts(); + result.persistent_memory = + CalculateOutputSize(op_features, &result.inaccurate); + + result.compute_time = kMinComputeTime; + result.execution_time = result.execution_time; + return result; +} + Costs OpLevelCostEstimator::PredictBatchMatMul( const OpContext& op_context) const { const auto& op_features = op_context.op_info; @@ -898,13 +974,12 @@ Costs OpLevelCostEstimator::PredictBatchMatMul( Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const { const auto& op_features = op_context.op_info; - Costs costs; + Costs costs = Costs::ZeroCosts(); costs.max_memory = CalculateOutputSize(op_features, &costs.inaccurate); // Metadata operations are so cheap we assume they take the minimum amount of // time we can represent (1 ns). - costs.execution_time = 1; - costs.compute_time = 1; - costs.memory_time = 0; + costs.compute_time = kMinComputeTime; + costs.execution_time = costs.compute_time; return costs; } diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index 5f541ccf04dc74eb868d26365a50d2e3542ea7d9..a292e5e97fe52383648d74b08bb7a384b6278446 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -132,6 +132,8 @@ class OpLevelCostEstimator { Costs PredictConv2DBackpropFilter(const OpContext& op_context) const; Costs PredictMatMul(const OpContext& op_context) const; Costs PredictNoOp(const OpContext& op_context) const; + Costs PredictIdentity(const OpContext& op_context) const; + Costs PredictVariable(const OpContext& op_context) const; Costs PredictBatchMatMul(const OpContext& op_context) const; Costs PredictMetadata(const OpContext& op_context) const; diff --git a/tensorflow/core/grappler/costs/op_performance_data.proto b/tensorflow/core/grappler/costs/op_performance_data.proto index 1d623b8db8e5cc3b4e7e6b32d83695ab4ed4c0ec..37f9ebd6a146c8c0089857c7a41ba863b4c2fb1f 100644 --- a/tensorflow/core/grappler/costs/op_performance_data.proto +++ b/tensorflow/core/grappler/costs/op_performance_data.proto @@ -58,11 +58,18 @@ message LogNormalDistribution { double sigma = 2; } +message SessionInfo { + int64 intra_op_parallelism = 1; +} + // Performance data for tensorflow operations message OpPerformance { // The op OpInfo op = 1; + // Information about the session configs. + SessionInfo session_info = 12; + // The node name (optional). Makes it easier to associate the performance data // with a specific graph node. string node = 5; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index d7d07ee7a55665a2d809588f45fbfd166bd2f76a..14b4ed7507f6237ea6255f46e060aa3d0f60b34d 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" @@ -323,8 +324,13 @@ Status VirtualScheduler::Init() { } // Get the nodes that would run to output fetch_nodes. + bool ill_formed = false; std::vector nodes = - ComputeTransitiveFanin(graph, fetch_nodes); + ComputeTransitiveFanin(graph, fetch_nodes, &ill_formed); + if (ill_formed) { + return errors::InvalidArgument( + "Ill formed graph or invalid set of fetch nodes specified"); + } // TODO(dyoon): this is a bit inefficient as name_to_node is already built in // ComputeTransitiveFanin(). @@ -441,13 +447,14 @@ Status VirtualScheduler::Init() { } if (ready_nodes_->Empty()) { - return Status(error::UNAVAILABLE, "No ready nodes in the graph."); + return errors::InvalidArgument("No ready nodes in the graph."); } - if (!feed_nodes.empty()) - LOG(ERROR) << "Some feed nodes were not found in the graph: " - << str_util::Join(feed_nodes, ","); - + if (!feed_nodes.empty()) { + return errors::InvalidArgument( + strings::StrCat("Some feed nodes were not found in the graph: ", + str_util::Join(feed_nodes, ","))); + } initialized_ = true; return Status::OK(); } diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 8ccc51f5451bb2b5052fd04100ba7684b0956cea..5116c8183cb4c51dc833988cbeb75a4a184e4c40 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -139,8 +139,8 @@ class FIFOManager : public ReadyNodeManager { public: FIFOManager() : ReadyNodeManager() {} ~FIFOManager() override {} - virtual void Init( - const std::unordered_map* node_state) {} + void Init(const std::unordered_map* node_state) + override {} void AddNode(const NodeDef* node) override { nodes_.push_back(node); } const NodeDef* GetCurrNode() override { CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; @@ -325,7 +325,7 @@ class VirtualScheduler { // Boolean field for whether the cost is accurate. std::map> op_costs_; - Costs graph_costs_; // Graph cost. + Costs graph_costs_; // Graph cost. std::map op_to_cost_; // Per-op cost. // Auxilliary data structures for constructing NodeState and DeviceState. diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index f4e2de75a60182f3b2bbc366c076052bd0fae118..173ce9c09c2fd98d855a801131ed16a796d9caac 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -46,6 +46,7 @@ class GraphView { }; explicit GraphView(GraphDef* graph); + GraphDef* GetGraph() const { return graph_; } NodeDef* GetNode(const string& node_name) const; // Get the specified input port. Note that the special '-1' port_id can be // used to access the controlling nodes (i.e. the nodes connected to node_name diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 149f6fc7353b3c96e9d780c20697873c15bccaa8..2f8549cf395f6b78154f7a6faf3fea06ea6c56c4 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -134,6 +134,7 @@ std::vector ComputeTransitiveFanin( const NodeDef* node = name_to_node[NodeName(root)]; if (!node) { *ill_formed = true; + VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root; return {}; } queue.push_back(node); @@ -153,6 +154,7 @@ std::vector ComputeTransitiveFanin( for (const string& input : node->input()) { const NodeDef* in = name_to_node[NodeName(input)]; if (!in) { + VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input; *ill_formed = true; return {}; } diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 7a9ad50519c6ace696cb615d1b6c5855589a429f..7ba498dd06409635d7dfc282ab29f1133e299c9b 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -511,14 +511,15 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( } std::unique_ptr GrapplerItemFromFunctionDef( - const string& id, const FunctionDef& func, - const std::unordered_map& func_attr) { - if (id.empty()) { - LOG(ERROR) << "id must be non-empty."; + const FunctionDef& func, + const std::unordered_map& func_attr, + const FunctionDefLibrary& library) { + if (func.signature().name().empty()) { + LOG(ERROR) << "function name must be specified."; return nullptr; } std::unique_ptr new_item(new GrapplerItem()); - new_item->id = id; + new_item->id = func.signature().name(); std::unordered_map port_map; @@ -543,6 +544,8 @@ std::unique_ptr GrapplerItemFromFunctionDef( } // Add the function body to the graph. + FunctionLibraryDefinition func_def(OpRegistry::Global(), library); + for (const NodeDef& node : func.node_def()) { NodeDef* new_node = new_item->graph.add_node(); *new_node = node; @@ -557,15 +560,17 @@ std::unique_ptr GrapplerItemFromFunctionDef( // Functions use a custom format to encode connectivity. Map these custom // strings to regular ones. - const OpDef* op_def = nullptr; - Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + const OpRegistrationData* registration; + Status status = func_def.LookUp(node.op(), ®istration); if (!status.ok()) { LOG(ERROR) << "Op " << node.op() << " not registered: " << status; return nullptr; } + tensorflow::NameRangeMap inputs; tensorflow::NameRangeMap outputs; - status = tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs); + status = tensorflow::NameRangesForNode(node, registration->op_def, &inputs, + &outputs); if (!status.ok()) { LOG(ERROR) << "Op " << node.op() << " invalid: " << status; return nullptr; @@ -587,12 +592,17 @@ std::unique_ptr GrapplerItemFromFunctionDef( // Rewrite the inputs to use the normal naming convention. for (int i = 0; i < node.input_size(); ++i) { const string& input = node.input(i); - auto it = port_map.find(input); - if (it == port_map.end()) { - LOG(ERROR) << "Unknown input: " << input; - return nullptr; + if (IsControlInput(input)) { + // No need to remap control dependencies. + continue; + } else { + auto it = port_map.find(input); + if (it == port_map.end()) { + LOG(ERROR) << "Unknown input: " << input; + return nullptr; + } + node.set_input(i, it->second); } - node.set_input(i, it->second); } } diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index fa6f9faa099cafb6e1fe235bfd36fc8ad0d15c14..e892a3f556f7e9ccba91d5ce672a12d2eac49f5a 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -61,8 +61,9 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( // Factory method for creating a GrapplerItem from a FunctionDef. // Returns nullptr if the given function def cannot be converted. std::unique_ptr GrapplerItemFromFunctionDef( - const string& id, const FunctionDef& func, - const std::unordered_map& func_attr); + const FunctionDef& func, + const std::unordered_map& func_attr, + const FunctionDefLibrary& library); } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index 87377a02583d816ec87900750c54f99c666f24c9..68437b60419f73419bca4467b409818bc0b11650 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -300,9 +300,11 @@ TEST_F(GrapplerItemBuilderTest, FromSimpleFunctionDef) { std::unordered_map func_attr; func_attr["T"].set_type(DT_FLOAT); + FunctionDefLibrary library; std::unique_ptr item = - GrapplerItemFromFunctionDef("test", func, func_attr); + GrapplerItemFromFunctionDef(func, func_attr, library); CHECK(item); + EXPECT_EQ("XTimesTwo", item->id); EXPECT_EQ(4, item->graph.node_size()); EXPECT_EQ(std::vector({"y"}), item->fetch); EXPECT_EQ(1, item->feed.size()); @@ -365,9 +367,11 @@ TEST_F(GrapplerItemBuilderTest, FromFunctionDefWithMultiOutputNodes) { std::unordered_map func_attr; func_attr["T"].set_type(DT_FLOAT); + FunctionDefLibrary library; std::unique_ptr item = - GrapplerItemFromFunctionDef("test", func, func_attr); + GrapplerItemFromFunctionDef(func, func_attr, library); CHECK(item); + EXPECT_EQ("SubGrad", item->id); EXPECT_EQ(12, item->graph.node_size()); EXPECT_EQ(std::vector({"dx", "dy"}), item->fetch); EXPECT_EQ(3, item->feed.size()); @@ -399,6 +403,82 @@ TEST_F(GrapplerItemBuilderTest, FromFunctionDefWithMultiOutputNodes) { } } +TEST_F(GrapplerItemBuilderTest, FromFunctionDefWithNestedFuncs) { + FunctionDefLibrary library; + *library.add_function() = FunctionDefHelper::Define( + // Name + "Swap", + // Args + {"i0: T", "i1: T"}, + // Return values + {"o0: T", "o1: T"}, + // Attr def + {"T: {float, double}"}, + // Nodes + {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}}, + {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}}); + + FunctionDef func = FunctionDefHelper::Create( + // Name + "ManySwapsFirst", + // Args + {"x: float", "y: float"}, + // Return values + {"o: float"}, + // attr def + {}, + // Nodes + // o = x*x + y*y. Furthermore, The 1st swap depends on x2, and + // y2 depends on the 2nd swap. The 2nd swap has data dependency + // on the 1st swap. + {{{"a0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}}, + {{"a1"}, "Swap", {"a0:o0:0", "a0:o1:0"}, {{"T", DT_FLOAT}}}, + {{"x2"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}}, + {{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}}, + {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}}, + {{"o", "o:z:0"}}); + + std::unordered_map func_attr; + func_attr["T"].set_type(DT_FLOAT); + std::unique_ptr item = + GrapplerItemFromFunctionDef(func, func_attr, library); + + for (const NodeDef &node : item->graph.node()) { + if (node.name() == "x" || node.name() == "y") { + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "a0") { + EXPECT_EQ("Swap", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("y", node.input(1)); + EXPECT_EQ("^x2", node.input(2)); + } else if (node.name() == "a1") { + EXPECT_EQ("Swap", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("a0:0", node.input(0)); + EXPECT_EQ("a0:1", node.input(1)); + } else if (node.name() == "x2") { + EXPECT_EQ("Mul", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("x", node.input(1)); + } else if (node.name() == "y2") { + EXPECT_EQ("Mul", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("y", node.input(1)); + EXPECT_EQ("^a1", node.input(2)); + } else if (node.name() == "o") { + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x2:0", node.input(0)); + EXPECT_EQ("y2:0", node.input(1)); + } + } +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/inputs/file_input_yielder.h b/tensorflow/core/grappler/inputs/file_input_yielder.h index a17e1c9ff2a5e1521250e604192d21650732e795..b597319261011e2537848a34167f69cf1e3002f0 100644 --- a/tensorflow/core/grappler/inputs/file_input_yielder.h +++ b/tensorflow/core/grappler/inputs/file_input_yielder.h @@ -18,8 +18,8 @@ limitations under the License. // that may be stored in the checkpoint are not restored in order to speedup the // initialization. -#ifndef LEARNING_BRAIN_EXPERIMENTAL_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ -#define LEARNING_BRAIN_EXPERIMENTAL_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ +#ifndef TENSORFLOW_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ +#define TENSORFLOW_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ #include #include @@ -53,4 +53,4 @@ class FileInputYielder : public InputYielder { } // end namespace grappler } // end namespace tensorflow -#endif // LEARNING_BRAIN_EXPERIMENTAL_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ +#endif // TENSORFLOW_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 791ad34bbed6a4c7d270f3a06ac34ed0f08b9b1a..e839630605a96f1528114f98b88e90a7a20b0a3a 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -125,6 +125,7 @@ tf_cc_test( "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/utils:grappler_test", ], ) @@ -139,6 +140,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", ], ) @@ -285,9 +287,11 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:virtual_cluster", "//tensorflow/core/grappler/costs:graph_memory", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:topological_sort", + "//tensorflow/core/grappler/utils:traversal", ], ) @@ -299,11 +303,13 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/utils:grappler_test", ], ) @@ -365,6 +371,7 @@ cc_library( ":dependency_optimizer", ":graph_optimizer", ":layout_optimizer", + ":loop_optimizer", ":memory_optimizer", ":model_pruner", "//tensorflow/core:framework", @@ -374,3 +381,39 @@ cc_library( "//tensorflow/core/grappler/utils:topological_sort", ], ) + +cc_library( + name = "loop_optimizer", + srcs = ["loop_optimizer.cc"], + hdrs = [ + "loop_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/costs:graph_properties", + ], +) + +tf_cc_test( + name = "loop_optimizer_test", + size = "small", + srcs = ["loop_optimizer_test.cc"], + deps = [ + ":loop_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.h b/tensorflow/core/grappler/optimizers/auto_parallel.h index c5d2d47782f0d5515e65e1f99b212315dcc13c0e..8d1098d87755c1257dfebe016a3baf86bfece677 100644 --- a/tensorflow/core/grappler/optimizers/auto_parallel.h +++ b/tensorflow/core/grappler/optimizers/auto_parallel.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ #define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ -#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/framework/variable.pb.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 0aeff6222c291455c04cf3fb68a90298724385dd..b8a21ea5a15ec1556db47b13db43d19bf070c266 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -808,20 +808,26 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name, // Use the packed representation whenever possible to avoid generating large // graphdefs. Moreover, avoid repeating the last values if they're equal. if (tensor->NumElements() > 4) { -#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, NAME) \ - optimized = true; \ - TYPE last = tensor->flat()(0); \ - int last_index = 0; \ - for (int i = 0; i < tensor->NumElements(); ++i) { \ - TYPE cur = tensor->flat()(i); \ - t->add_##NAME##_val(cur); \ - if (cur != last) { \ - last = cur; \ - last_index = i; \ - } \ - } \ - /* Remove all identical trailing values to save memory. */ \ - t->mutable_##NAME##_val()->Truncate(last_index + 1); +#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, NAME) \ + const TYPE* val_ptr = tensor->flat().data(); \ + TYPE last = *val_ptr; \ + int64 last_index = 0; \ + for (int64 i = 0; i < tensor->NumElements(); ++i) { \ + TYPE cur = *val_ptr++; \ + if (cur != last) { \ + last = cur; \ + last_index = i; \ + } \ + } \ + if (last_index < kint32max) { \ + optimized = true; \ + t->mutable_##NAME##_val()->Reserve(last_index + 1); \ + t->mutable_##NAME##_val()->AddNAlreadyReserved(last_index + 1); \ + val_ptr = tensor->flat().data(); \ + for (int64 i = 0; i <= last_index; ++i) { \ + t->set_##NAME##_val(i, *val_ptr++); \ + } \ + } if (tensor->dtype() == DT_FLOAT) { POPULATE_TENSOR_PROTO(tensor, t, float, float) @@ -1369,6 +1375,29 @@ void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward, graph_modified_ = true; } +void ConstantFolding::ReplaceOperationWithSnapshot(int input_to_forward, + NodeDef* node, + GraphDef* graph) { + node->set_op("Snapshot"); + DataType dtype = node->attr().at("T").type(); + node->clear_attr(); + (*node->mutable_attr())["T"].set_type(dtype); + + // Propagate the designated input through the Snapshot. + node->mutable_input()->SwapElements(0, input_to_forward); + // Add all other inputs as control dependencies. + for (int i = 1; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) { + break; + } + const string ctrl_dep = + AddControlDependency(node->input(i), graph, node_map_.get()); + node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep); + node->set_input(i, ctrl_dep); + } + graph_modified_ = true; +} + void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph) { node->set_op("Reciprocal"); @@ -1437,15 +1466,14 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, graph_modified_ = true; continue; } - const bool safe_to_use_shapes = - use_shape_info && (feed_nodes_.empty() || is_aggressive); + const bool is_mul = IsMul(*node); const bool is_matmul = IsMatMul(*node); const bool is_add = IsAdd(*node) || IsBiasAdd(*node); const bool is_sub = IsSub(*node); const bool is_any_div = IsAnyDiv(*node); // Simplify arithmetic operations with ones or zeros. - if (safe_to_use_shapes && + if (use_shape_info && (is_mul || is_matmul || is_add || is_sub || is_any_div) && properties.HasInputProperties(node->name()) && properties.HasOutputProperties(node->name())) { @@ -1469,7 +1497,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, ((is_mul && x_is_one) || (is_add && x_is_zero))) { // TODO(rmlarsen): Handle subtraction 0 - y. // 1 * y = y or 0 + y = y. - ReplaceOperationWithIdentity(1, node, output); + ReplaceOperationWithSnapshot(1, node, output); continue; } @@ -1489,9 +1517,9 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape); if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) || - ((is_add || is_sub) && y_is_zero && is_aggressive))) { + ((is_add || is_sub) && y_is_zero))) { // x * 1 = x or x / 1 = x or x +/- 0 = x - ReplaceOperationWithIdentity(0, node, output); + ReplaceOperationWithSnapshot(0, node, output); continue; } @@ -1658,7 +1686,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, // more with the original node name. for (const auto& fetch : item.fetch) { const NodeDef* fetch_node = node_map_->GetNode(fetch); - if (fetch_node && NumOutputs(*fetch_node) == 1) { + if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) { nodes_whitelist_.insert(fetch_node->name()); } } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 18acc91e8a18f4bf2eb77c7e5171eaca4ff5bec5..e4078514af11174788bc5a436125efeb3fa37177 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -79,6 +79,8 @@ class ConstantFolding : public GraphOptimizer { bool IsZeros(const NodeDef& node) const; void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node, GraphDef* graph); + void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node, + GraphDef* graph); Status ReplaceOperationWithConstant(double value, const TensorShapeProto& shape, NodeDef* node, GraphDef* graph); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 849a88770ae6127c6f2e3fac968a976c0a523a0b..d8df19fe6a0a5daafefd11e5ac39c8e3bc50e6e1 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -20,30 +20,15 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/public/session.h" namespace tensorflow { namespace grappler { namespace { -class ConstantFoldingTest : public ::testing::Test { - protected: - std::vector EvaluateNodes(const GraphDef& graph, - const std::vector& fetch) { - SessionOptions options; - std::unique_ptr session(NewSession(options)); - TF_CHECK_OK(session->Create(graph)); - RunOptions run_options; - std::vector output_tensors; - TF_CHECK_OK( - session->Run(run_options, {}, fetch, fetch, &output_tensors, nullptr)); - TF_CHECK_OK(session->Close()); - return output_tensors; - } -}; +class ConstantFoldingTest : public GrapplerTest {}; TEST_F(ConstantFoldingTest, SimpleFolding) { // Build a simple graph with a few trivially prunable ops. @@ -210,8 +195,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch = {"addn", "matmul3", "matmul4"}; - ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, - nullptr /* cpu_device */); + ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -229,11 +213,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("^zeros", node.input(0)); EXPECT_EQ("^y", node.input(1)); } else if (name == "mul3") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^ones", node.input(1)); } else if (name == "mul4") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("^ones", node.input(1)); } else if (name == "mul5") { @@ -245,7 +229,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("^zeros_1d", node.input(0)); EXPECT_EQ("^y", node.input(1)); } else if (name == "div1") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^ones", node.input(1)); } else if (name == "div2") { @@ -281,15 +265,15 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ(2, t.tensor_shape().dim(0).size()); EXPECT_EQ(3, t.tensor_shape().dim(1).size()); } else if (name == "add1") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^zeros", node.input(1)); } else if (name == "add2") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("^zeros", node.input(1)); } else if (name == "bias_add1") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^zeros_1d", node.input(1)); } else if (name == "bias_add2") { @@ -298,7 +282,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("zeros", node.input(0)); EXPECT_EQ("bias", node.input(1)); } else if (name == "sub1") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^zeros", node.input(1)); } else if (name == "sub2") { @@ -337,8 +321,7 @@ TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch = {"div_f", "div_i", "realdiv"}; - ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, - nullptr /* cpu_device */); + ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -428,8 +411,7 @@ TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, - nullptr /* cpu_device */); + ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -483,8 +465,7 @@ TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, - nullptr /* cpu_device */); + ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -1352,7 +1333,7 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */); + ConstantFolding fold(nullptr /* cpu_device */); GraphDef output; Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -1413,7 +1394,7 @@ TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch.push_back("reshape"); - ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */); + ConstantFolding fold(nullptr /* cpu_device */); GraphDef output; Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index d2da125236ab4f9b386ba2c6dc808e2b030c819c..edb0db65e987318e1e64bf0288b6ef18a7b9d662 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -36,41 +36,77 @@ namespace grappler { namespace { -int RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) { - int num_removed = 0; +bool RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) { + bool removed_input = false; int pos = 0; while (pos < node->input_size()) { if (node->input(pos) == input) { node->mutable_input()->SwapElements(pos, node->input_size() - 1); node->mutable_input()->RemoveLast(); node_map->RemoveOutput(NodeName(input), node->name()); + removed_input = true; } else { ++pos; } - ++num_removed; } - return num_removed; + return removed_input; } -// Remove duplicate control inputs. -void PruneControlInputs(NodeDef* node) { - std::unordered_set inputs; - int pos = 0; - while (pos < node->input_size()) { - const string& input = node->input(pos); - if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) { - VLOG(1) << "**** Removing duplicate control input: " << input - << " from node " << node->DebugString(); - node->mutable_input()->SwapElements(pos, node->input_size() - 1); - node->mutable_input()->RemoveLast(); - } else { - ++pos; - } +void DeleteNodes(const std::set& nodes_to_delete, GraphDef* graph) { + int last = graph->node_size() - 1; + for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) { + const int index = *it; + graph->mutable_node()->SwapElements(index, last); + last--; } + graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size()); } } // namespace +bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) { + if (!IsIdentity(node)) { + return true; + } + + if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { + return false; + } + if (!fetch_nodes_known_) { + // The output values of this node may be needed. + return false; + } + const NodeDef* input = node_map_->GetNode(NodeName(node.input(0))); + CHECK(input != nullptr) << "node = " << node.name() + << " input = " << node.input(0); + // Don't remove Identity nodes corresponding to Variable reads or following + // Recv. + if (IsVariable(*input) || IsRecv(*input)) { + return false; + } else if (IsSwitch(*input)) { + // Don't turn Identity nodes following Switch into NoOp or remove them + // if it requires anchoring a control dependencies the Switch node, which + // is not valid. + if (StringPiece(node.name()).starts_with(kConstantFoldingCtrl)) { + // TODO(rmlarsen): Try to remove this artificial contraint. + return false; + } + } + for (auto consumer : node_map_->GetOutputs(node.name())) { + if (node.input_size() > 1 && IsMerge(*consumer)) { + return false; + } + if (IsSwitch(*input)) { + for (const string& consumer_input : consumer->input()) { + if (consumer_input == AsControlDependency(node.name())) { + return false; + } + } + } + } + return true; +} + bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) { if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { return false; @@ -100,18 +136,8 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) { return false; } - // Don't turn Identity nodes inserted by Grappler after Switch into NoOp, - // since we cannot anchor control dependencies on Switch nodes. - // Don't remove Identity nodes corresponding to Variable reads. - if (IsIdentity(node)) { - const NodeDef* input = node_map_->GetNode(NodeName(node.input(0))); - if (input != nullptr) { - if (IsVariable(*input) || - (StringPiece(node.name()).starts_with(kConstantFoldingCtrl) && - IsSwitch(*input))) { - return false; - } - } + if (!SafeToRemoveIdentity(node)) { + return false; } const std::unordered_set do_not_rewrite_ops{ @@ -125,18 +151,20 @@ void DependencyOptimizer::OptimizeNode(int node_idx, SetVector* nodes_to_simplify, std::set* nodes_to_delete) { NodeDef* node = optimized_graph_->mutable_node(node_idx); - + const bool is_noop = IsNoOp(*node); + const bool is_identity = IsIdentity(*node); + const string node_name = node->name(); // Constant nodes with no input control dependency are always executed early, // so we can prune all their output control dependencies. if (IsConstant(*node) && node->input_size() == 0) { - const std::set output_nodes = node_map_->GetOutputs(node->name()); + const std::set output_nodes = node_map_->GetOutputs(node_name); for (NodeDef* fanout : output_nodes) { bool optimize_fanout = false; bool data_connection = false; for (int i = fanout->input_size() - 1; i >= 0; --i) { int pos; string input_name = ParseNodeName(fanout->input(i), &pos); - if (input_name == node->name()) { + if (input_name == node_name) { if (pos < 0) { fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1); fanout->mutable_input()->RemoveLast(); @@ -149,22 +177,21 @@ void DependencyOptimizer::OptimizeNode(int node_idx, if (optimize_fanout) { nodes_to_simplify->PushBack(node_to_idx_[fanout]); if (!data_connection) { - node_map_->RemoveOutput(node->name(), fanout->name()); + node_map_->RemoveOutput(node_name, fanout->name()); } } } - if (node_map_->GetOutputs(node->name()).empty() && fetch_nodes_known_ && - nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) { + if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ && + nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) { // Mark the node for deletion. nodes_to_delete->insert(node_to_idx_[node]); } - return; } // Change ops that only have control dependencies as outputs to NoOps. - if (node->op() != "NoOp" && SafeToConvertToNoOp(*node)) { - VLOG(1) << "***** Replacing " << node->name() << " (" << node->op() + if (!is_noop && SafeToConvertToNoOp(*node)) { + VLOG(1) << "***** Replacing " << node_name << " (" << node->op() << ") with NoOp."; // The outputs of this node are not consumed. Replace its inputs with // control dependencies and replace the op itself with the NoOp op. @@ -186,7 +213,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx, old_input, optimized_graph_, node_map_.get()); if (ctrl_inputs.insert(ctrl_input).second) { node->set_input(pos, ctrl_input); - node_map_->UpdateInput(node->name(), old_input, ctrl_input); + node_map_->UpdateInput(node_name, old_input, ctrl_input); const NodeDef* old_input_node = node_map_->GetNode(old_input); nodes_to_simplify->PushBack(node_to_idx_[old_input_node]); } @@ -194,6 +221,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx, } node->set_op("NoOp"); node->clear_attr(); + nodes_to_simplify->PushBack(node_to_idx_[node]); + return; } // Remove NoOp nodes if the product of their fan-in and fan-out is less than @@ -222,9 +251,30 @@ void DependencyOptimizer::OptimizeNode(int node_idx, // a and x, respectively, are on the same device. Control edges across device // boundaries require inter-device communication (Send/Recv pairs to be // inserted in the graph), which is very costly. + // + // We also remove identity nodes, subject to the same constraints on number of + // resulting control edges and device boundary crossings: + // + // Case a) + // +----------+ ---> a +---+ ---> a + // x --> | Identity | --^> b ==> | x | --^> b + // | | ... | | ... + // +----------+ --^> c +---+ --^> c + // + // Case b) + // x ---> +----------+ ---> a x ---> +---+ + // y --^> | Identity | ==> y --^> | a | + // ... | | ... | | + // z --^> +----------+ z --^> +---+ + // + // Case c) + // +----------+ x ---> +---+ + // x ---> | Identity | ---> a ==> \--^> | a | + // y --^> | | --^> b /\ +---+ + // +----------+ y --^> b - if (node->op() == "NoOp") { - const auto& output_node_set = node_map_->GetOutputs(node->name()); + if (is_noop || is_identity) { + const auto& output_node_set = node_map_->GetOutputs(node_name); const std::vector output_nodes(output_node_set.begin(), output_node_set.end()); const int num_outputs = output_nodes.size(); @@ -233,15 +283,14 @@ void DependencyOptimizer::OptimizeNode(int node_idx, if (num_inputs * num_outputs > num_inputs + num_outputs) { return; } - VLOG(1) << "***** Rerouting input around " << node->name(); std::vector input_nodes; for (int i = 0; i < num_inputs; ++i) { - NodeDef* tmp = node_map_->GetNode(node->input(i)); - CHECK_NE(tmp, nullptr); - input_nodes.push_back(tmp); + NodeDef* input_node = node_map_->GetNode(node->input(i)); + CHECK_NE(input_node, nullptr); + input_nodes.push_back(input_node); } - // Make sure that we don't increase the number of control edges that cross + // Make sure that we don't increase the number of edges that cross // device boundaries. if ((num_inputs == 1 && num_outputs > 1 && input_nodes[0]->device() != node->device()) || @@ -266,40 +315,75 @@ void DependencyOptimizer::OptimizeNode(int node_idx, if (num_cross_after > num_cross_before) { return; } + // To avoid potentially removing Identity nodes following _Recv nodes, + // we require that no device crossings occur in that case. + // TODO(rmlarsen): See if we can relax this condition. + if (is_identity && (num_cross_after > 0 || num_cross_before > 0)) { + return; + } + } + if (is_identity && !SafeToRemoveIdentity(*node)) { + return; } + + VLOG(1) << "***** Rerouting input around\n" << node->DebugString(); + // Now remove the node and re-wire its inputs to its outputs. for (auto consumer : output_nodes) { bool updated_consumer = false; - VLOG(1) << "***** Considering consumer " << consumer->name() << "\n" - << consumer->DebugString(); + VLOG(1) << "consumer before:\n" << consumer->DebugString(); for (int i = 0; i < num_inputs; ++i) { const NodeDef* input = input_nodes[i]; // Forward dependency from input to consumer if it doesn't already // depend on it. - if (node_map_->GetOutputs(input->name()).count(consumer) == 0) { - consumer->add_input(AsControlDependency(input->name())); + if (is_identity && i == 0) { + // Replace regular input from Identity node. + bool found_input = false; + string new_input; + const string& input_to_forward = node->input(0); + CHECK(!IsControlInput(input_to_forward)); + for (int j = 0; j < consumer->input_size(); ++j) { + const string& old_input = consumer->input(j); + if (old_input == node_name) { + new_input = input_to_forward; + node_map_->UpdateInput(consumer->name(), old_input, new_input); + consumer->set_input(j, new_input); + found_input = true; + } else if (old_input == AsControlDependency(NodeName(node_name))) { + new_input = AsControlDependency(NodeName(input_to_forward)); + node_map_->UpdateInput(consumer->name(), old_input, new_input); + consumer->set_input(j, new_input); + found_input = true; + } + } + CHECK(found_input); updated_consumer = true; - node_map_->AddOutput(input->name(), consumer->name()); - nodes_to_simplify->PushBack(node_to_idx_[input]); + } else { + // Forward dependency from input to consumer if it doesn't already + // depend on it. + if (node_map_->GetOutputs(input->name()).count(consumer) == 0) { + consumer->add_input(AsControlDependency(input->name())); + node_map_->AddOutput(input->name(), consumer->name()); + nodes_to_simplify->PushBack(node_to_idx_[input]); + updated_consumer = true; + } } } // Remove dependency on node from consumer. - updated_consumer |= RemoveInput( - consumer, AsControlDependency(node->name()), node_map_.get()); + updated_consumer |= RemoveInput(consumer, AsControlDependency(node_name), + node_map_.get()); if (updated_consumer) { - VLOG(1) << "***** Updated consumer " << consumer->name() << " (" - << consumer->op() << ")"; nodes_to_simplify->PushBack(node_to_idx_[consumer]); } + VLOG(1) << "consumer after:\n" << consumer->DebugString(); } - - node_map_->RemoveOutputs(node->name()); + node_map_->RemoveOutputs(node_name); if (fetch_nodes_known_ && - nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) { + nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) { // Mark the node for deletion. nodes_to_delete->insert(node_idx); - // Unconnect the node from its inputs to enable further optimizations. - node_map_->RemoveInputs(node->name()); + // Disconnect the node from its inputs to enable further optimizations. + node_map_->RemoveInputs(node_name); node->clear_input(); } } @@ -307,22 +391,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx, void DependencyOptimizer::CleanControlInputs() { for (int i = 0; i < optimized_graph_->node_size(); ++i) { - PruneControlInputs(optimized_graph_->mutable_node(i)); - } -} - -void DependencyOptimizer::DeleteNodes(const std::set& nodes_to_delete) { - int last = optimized_graph_->node_size() - 1; - for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) { - const int index = *it; - optimized_graph_->mutable_node()->SwapElements(index, last); - last--; + DedupControlInputs(optimized_graph_->mutable_node(i)); } - optimized_graph_->mutable_node()->DeleteSubrange(last + 1, - nodes_to_delete.size()); - // Rebuild the NodeMap which was invalidated by the node swapping above. - node_map_.reset(new NodeMap(optimized_graph_)); - BuildNodeToIdx(); } Status DependencyOptimizer::OptimizeDependencies() { @@ -330,19 +400,26 @@ Status DependencyOptimizer::OptimizeDependencies() { std::set nodes_to_delete; for (int i = 0; i < optimized_graph_->node_size(); ++i) { const NodeDef& node = optimized_graph_->node(i); - if (node.op() == "NoOp" || IsConstant(node) || SafeToConvertToNoOp(node)) { + if (IsNoOp(node) || IsIdentity(node) || IsConstant(node) || + SafeToConvertToNoOp(node)) { nodes_to_simplify.PushBack(i); } } while (!nodes_to_simplify.Empty()) { - OptimizeNode(nodes_to_simplify.PopBack(), &nodes_to_simplify, - &nodes_to_delete); + int node_to_simplify = nodes_to_simplify.PopBack(); + // Discard nodes that were marked for deletion already. + while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) { + node_to_simplify = nodes_to_simplify.PopBack(); + } + OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete); } if (fetch_nodes_known_) { VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of " << optimized_graph_->node_size() << " nodes."; - DeleteNodes(nodes_to_delete); + DeleteNodes(nodes_to_delete, optimized_graph_); + node_map_.reset(new NodeMap(optimized_graph_)); + BuildNodeToIdx(); } return Status::OK(); } @@ -431,9 +508,10 @@ Status DependencyOptimizer::TransitiveReduction() { if (longest_distance[target] > 1) { const int input_slot = control_output.second; control_edges_to_remove[target].emplace(input_slot, source); - VLOG(1) << "Removing edge from:\n" - << optimized_graph_->node(source).DebugString() << "\n\nto:\n\n" - << optimized_graph_->node(target).DebugString(); + // VLOG(1) << "Removing edge from:\n" + // << optimized_graph_->node(source).DebugString() << + // "\n\nto:\n\n" + // << optimized_graph_->node(target).DebugString(); } } } @@ -473,14 +551,13 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, *optimized_graph_ = item.graph; nodes_to_preserve_ = item.NodesToPreserve(); fetch_nodes_known_ = !item.fetch.empty(); - CleanControlInputs(); + const int num_iterations = 2; for (int iteration = 0; iteration < num_iterations; ++iteration) { Status topo_sort_status; // Perform topological sort to prepare the graph for transitive reduction. topo_sort_status = TopologicalSort(optimized_graph_); - // Set up index-based graph datastructures to speed up analysis steps below. node_map_.reset(new NodeMap(optimized_graph_)); BuildNodeToIdx(); @@ -491,9 +568,12 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } else { LOG(ERROR) << topo_sort_status.error_message(); } - - // Turn nodes with only control outputs into NoOps, prune NoOps. + // Turn nodes with only control outputs into NoOps, prune NoOp and Identity + // nodes. TF_RETURN_IF_ERROR(OptimizeDependencies()); + + // Dedup control inputs. + CleanControlInputs(); } return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h index 02d8a0f32a9bbe4e49c484ece601e219257908c0..61ed15479370614bc79c15b450039f0cbf30908d 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h @@ -43,14 +43,15 @@ class DependencyOptimizer : public GraphOptimizer { const GraphDef& optimized_graph, double result) override; private: + // Returns true if node is not an Identity node or if it is an Identity + // that is safe to remove. + bool SafeToRemoveIdentity(const NodeDef& node); // Returns true if it is safe to convert node to NoOp. bool SafeToConvertToNoOp(const NodeDef& node); // Removes all duplicate control dependencies. void CleanControlInputs(); // Builds a map from the &optimized_graph_->node(i) to i. void BuildNodeToIdx(); - // Removes the given set of nodes from the graph. - void DeleteNodes(const std::set& nodes_to_delete); // Tries to optimize the node with the given index, possibly additional // optimizations by inserting nodes in nodes_to_simplify, and pruning nodes by // inserting them in nodes_to_delete. diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index f5027a4a99e4f28b4b49df914e9247a008036c20..33d6b992d21212fe325c642b87d3c3736185c445 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -167,12 +167,14 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) { ops::Const(scope.WithOpName("c2").WithControlDependencies(ctrl_dep_id), {1.0f, 2.0f}, {1, 2}); Output neg1 = ops::Neg(scope.WithOpName("neg1"), s.output_false); + Output neg2 = ops::Neg(scope.WithOpName("neg2"), ctrl_dep_id); GrapplerItem item; TF_CHECK_OK(scope.ToGraphDef(&item.graph)); item.fetch.push_back("c1"); item.fetch.push_back("c2"); item.fetch.push_back("neg1"); + item.fetch.push_back("neg2"); DependencyOptimizer optimizer; GraphDef output; @@ -323,25 +325,148 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) { } } +TEST_F(DependencyOptimizerTest, RemoveIdentity) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT); + Output y = ops::RandomUniform(s.WithOpName("y"), {1, 2}, DT_FLOAT); + Output z = ops::RandomUniform(s.WithOpName("z"), {1, 2}, DT_FLOAT); + + // Identity nodes to be removed. + // Case a) with a single input- and multiple outputs. + auto id_a = ops::Identity(s.WithOpName("id_a"), x); + // Case b) with multiple inputs and a single output. + auto id_b = ops::Identity( + s.WithOpName("id_b").WithControlDependencies(y).WithControlDependencies( + z), + x); + // Case c) with two inputs and two outputs. + auto id_c = ops::Identity(s.WithOpName("id_c").WithControlDependencies(y), x); + + // Output for Case a. + Output a_a = ops::Identity(s.WithOpName("a_a"), id_a); + Output a_b = ops::Identity(s.WithOpName("a_b"), id_a); + Output a_c = + ops::Identity(s.WithOpName("a_c").WithControlDependencies(id_a), z); + Output a_d = + ops::Identity(s.WithOpName("a_d").WithControlDependencies(id_a), z); + // Output for Case b. + Output b_a = ops::Identity(s.WithOpName("b_a"), id_b); + // Output for Case c. + Output c_a = ops::Identity(s.WithOpName("c_a"), id_c); + Output c_b = + ops::Identity(s.WithOpName("c_b").WithControlDependencies(id_c), z); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"a_a", "a_b", "a_c", "a_d", "b_a", "c_a", "c_b"}; + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size() - 3, output.node_size()); + for (const NodeDef& node : output.node()) { + EXPECT_NE("id_a", node.name()); + EXPECT_NE("id_b", node.name()); + EXPECT_NE("id_c", node.name()); + if (node.name() == "a_a" || node.name() == "a_b") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("x", node.input(0)); + } + if (node.name() == "a_c" || node.name() == "a_d") { + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("z", node.input(0)); + EXPECT_EQ("^x", node.input(1)); + } + if (node.name() == "b_a") { + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^y", node.input(1)); + EXPECT_EQ("^z", node.input(2)); + } + if (node.name() == "c_a") { + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^y", node.input(1)); + } + if (node.name() == "c_b") { + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("z", node.input(0)); + EXPECT_EQ("^x", node.input(1)); + EXPECT_EQ("^y", node.input(2)); + } + } +} + +TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) { + // Corner cases with repeated inputs. + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + ops::Variable x(scope.WithOpName("x"), {}, DT_BOOL); + ops::Variable y(scope.WithOpName("y"), {}, DT_BOOL); + ops::Switch sw(scope.WithOpName("switch"), x, x); + // id0 should be removed. + Output id0 = ops::Identity(scope.WithOpName("id0"), sw.output_true); + // id1 should not be removed, since it would anchor a control dependency + // on the switch. + Output id1 = ops::Identity(scope.WithOpName("id1"), sw.output_false); + Output or0 = ops::LogicalOr(scope.WithOpName("or0"), id0, id0); + Output or1 = ops::LogicalOr(scope.WithOpName("or1"), id0, y); + Output or2 = ops::LogicalOr( + scope.WithOpName("or2").WithControlDependencies(id1), y, y); + + GrapplerItem item; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + item.fetch.push_back("or0"); + item.fetch.push_back("or1"); + item.fetch.push_back("or2"); + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size() - 1, output.node_size()); + for (const NodeDef& node : output.node()) { + EXPECT_NE("id0", node.name()); + if (node.name() == "or0") { + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("switch:1", node.input(0)); + EXPECT_EQ("switch:1", node.input(1)); + } + if (node.name() == "or1") { + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("switch:1", node.input(0)); + EXPECT_EQ("y", node.input(1)); + } + if (node.name() == "or2") { + // or1 should be unchanged. + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("y", node.input(1)); + EXPECT_EQ("^id1", node.input(2)); + } + } +} + TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); Output x = ops::Square(s.WithOpName("x"), c); - Output id1 = ops::Identity(s.WithOpName("id1"), x); - Output id2 = - ops::Identity(s.WithOpName("id2").WithControlDependencies({x}), id1); + Output neg1 = ops::Neg(s.WithOpName("neg1"), x); + Output neg2 = + ops::Neg(s.WithOpName("neg2").WithControlDependencies({x}), neg1); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch.push_back("id2"); + item.fetch.push_back("neg2"); DependencyOptimizer optimizer; GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(4, output.node_size()); - EXPECT_EQ("id2", output.node(3).name()); + EXPECT_EQ("neg2", output.node(3).name()); EXPECT_EQ(1, output.node(3).input_size()); - EXPECT_EQ("id1", output.node(3).input(0)); + EXPECT_EQ("neg1", output.node(3).input(0)); } TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) { @@ -356,17 +481,18 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) { Output grappler_added_id = ops::Identity( scope.WithOpName("ConstantFoldingCtrl/switch_1"), s.output_true); Output c1 = ops::Const(scope.WithOpName("c1") - .WithControlDependencies(id0) .WithControlDependencies(id_after_var) .WithControlDependencies(grappler_added_id), {1.0f, 2.0f}, {1, 2}); Output id1 = ops::Identity(scope.WithOpName("id1"), c1); + Output id2 = ops::Identity(scope.WithOpName("id2"), id0); Output fetch = ops::Identity(scope.WithOpName("fetch").WithControlDependencies(id1), c1); GrapplerItem item; TF_CHECK_OK(scope.ToGraphDef(&item.graph)); item.fetch.push_back("c1"); + item.fetch.push_back("id2"); item.fetch.push_back("fetch"); DependencyOptimizer optimizer; @@ -377,8 +503,8 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) { EXPECT_EQ(item.graph.node_size() - 2, output.node_size()); for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); - // "id0" and "id1" but neither "ConstantFoldingCtrl/switch_1" nor - // "id_after_var" should be eliminated. + // "id0" and "id1" but neither "ConstantFoldingCtrl/switch_1", + // "id_after_var, nor "id2"" should be eliminated. EXPECT_NE("id0", node.name()); EXPECT_NE("id1", node.name()); if (node.name() == "c1") { diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.cc b/tensorflow/core/grappler/optimizers/graph_rewriter.cc index 2d47ded156048480f243c01e8a706829578438c5..b45ceb12a7972d8e0fb15c0562d0e4ceeeeeef1c 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.cc +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" namespace tensorflow { @@ -61,10 +62,19 @@ void GraphRewriter::ForwardInputs( const NodeDef& original_node, const std::unordered_set& nodes_to_delete, NodeDef* new_node) { - ForwardInputsInternal(original_node, nodes_to_delete, new_node); + ForwardInputsInternal(original_node, nodes_to_delete, false, new_node); if (!new_node->name().empty()) { optimized_nodes_[new_node->name()] = new_node; } + // Reorder inputs such that control inputs come after regular inputs. + int pos = 0; + for (int i = 0; i < new_node->input_size(); ++i) { + if (!IsControlInput(new_node->input(i))) { + new_node->mutable_input()->SwapElements(pos, i); + ++pos; + } + } + DedupControlInputs(new_node); } bool GraphRewriter::DrivesControlDependency(const NodeDef& node) const { @@ -72,6 +82,10 @@ bool GraphRewriter::DrivesControlDependency(const NodeDef& node) const { control_dependency_drivers_.end(); } +bool GraphRewriter::FeedsMerge(const NodeDef& node) const { + return merge_feeders_.find(&node) != merge_feeders_.end(); +} + bool GraphRewriter::IsDrivenByControlDependency(const NodeDef& node) const { for (const auto& input : node.input()) { CHECK(!input.empty()); @@ -94,12 +108,27 @@ bool GraphRewriter::ReceivesRefValue(const NodeDef& node) const { return ref_receivers_.find(&node) != ref_receivers_.end(); } +bool GraphRewriter::IsDrivenBySwitch(const NodeDef& node) const { + return switch_receivers_.find(&node) != switch_receivers_.end(); +} + +bool GraphRewriter::RemovalIncreasesEdgeCount(const NodeDef& node) const { + const int in_degree = node.input_size(); + auto itr = nodes_.find(node.name()); + if (itr == nodes_.end()) { + return true; + } + const int out_degree = itr->second->out_degree; + return in_degree * out_degree > in_degree + out_degree; +} + void GraphRewriter::RecordConnectivity( const NodeDef& node, const std::unordered_set& function_names) { const bool is_function = function_names.find(node.op()) != function_names.end(); bool ref_receiver = false; + bool switch_receiver = false; for (const auto& input : node.input()) { int position = 0; string input_node_name = ParseNodeName(input, &position); @@ -107,8 +136,14 @@ void GraphRewriter::RecordConnectivity( if (itr == nodes_.end()) { continue; } - const NodeInfo* fanin_info = itr->second.get(); + + NodeInfo* fanin_info = itr->second.get(); const NodeDef* fanin = fanin_info->def; + if (IsMerge(node)) { + merge_feeders_.insert(fanin); + } + // Update out_degree of fanin. + ++fanin_info->out_degree; if (position < 0) { // This is a control edge control_dependency_drivers_.insert(fanin); @@ -120,7 +155,9 @@ void GraphRewriter::RecordConnectivity( if (is_function) { function_neighbors_.insert(fanin); } - + if (IsSwitch(*fanin)) { + switch_receiver = true; + } if (position < fanin_info->outputs.size() && IsRefType(fanin_info->outputs[position])) { ref_receiver = true; @@ -134,34 +171,41 @@ void GraphRewriter::RecordConnectivity( if (ref_receiver) { ref_receivers_.insert(&node); } + if (switch_receiver) { + switch_receivers_.insert(&node); + } } void GraphRewriter::ForwardInputsInternal( const NodeDef& node, const std::unordered_set& nodes_to_delete, - NodeDef* new_node) { + bool add_as_control, NodeDef* new_node) { // To speed things up, use the optimized version of the node if // available. auto itr = optimized_nodes_.find(node.name()); if (itr != optimized_nodes_.end()) { for (const string& input : itr->second->input()) { - *new_node->add_input() = input; + *new_node->add_input() = + add_as_control ? AsControlDependency(NodeName(input)) : input; } return; } for (const auto& input : node.input()) { - string input_node_name = NodeName(input); + const string input_node_name = NodeName(input); auto itr = nodes_.find(input_node_name); if (itr == nodes_.end()) { // Invalid input, preserve it as is. - *new_node->add_input() = input; + *new_node->add_input() = + add_as_control ? AsControlDependency(NodeName(input)) : input; continue; } const NodeDef* input_node = itr->second->def; if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) { - ForwardInputsInternal(*input_node, nodes_to_delete, new_node); + ForwardInputsInternal(*input_node, nodes_to_delete, + add_as_control || IsControlInput(input), new_node); } else { - *new_node->add_input() = input; + *new_node->add_input() = + add_as_control ? AsControlDependency(NodeName(input)) : input; } } } diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.h b/tensorflow/core/grappler/optimizers/graph_rewriter.h index 4b9c9feef8f7a4456183a00c8c64f6a0d0991ad4..3d48d628e203e3d1ab6c8ee3bda9575facbd129f 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.h +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.h @@ -58,15 +58,27 @@ class GraphRewriter { // Returns true if the node has input from a stateful op. bool ReceivesRefValue(const NodeDef& node) const; + // Returns true if the node is driven by a Switch node. + bool IsDrivenBySwitch(const NodeDef& node) const; + + // Returns true if the node feeds a Merge node. + bool FeedsMerge(const NodeDef& node) const; + + // Returns true if removal of this degree would increase edge count, i.e. if + // in-degree * out-degree > in-degree + out-degree or if the condition could + // not be verified. + bool RemovalIncreasesEdgeCount(const NodeDef& node) const; + private: void RecordConnectivity(const NodeDef& node, const std::unordered_set& function_names); void ForwardInputsInternal( const NodeDef& original_node, const std::unordered_set& nodes_to_delete, - NodeDef* new_node); + bool add_as_control, NodeDef* new_node); struct NodeInfo { + int out_degree = 0; const NodeDef* def; // These are filled in when the NodeInfo is built, but not that they @@ -80,6 +92,8 @@ class GraphRewriter { std::unordered_set function_neighbors_; std::unordered_set cross_device_receivers_; std::unordered_set ref_receivers_; + std::unordered_set switch_receivers_; + std::unordered_set merge_feeders_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 50e6ba4a6483cf55e32e3d04f1b3af42c48d9f87..826f00209b15705f2a9b8b43f78134498a19d167 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -1400,39 +1400,43 @@ class HistogramSummaryProcessor : public AgnosticNodeProcessor { class IdentityNProcessor : public AgnosticNodeProcessor { public: explicit IdentityNProcessor(const OptimizeContext& opt_cxt) - : AgnosticNodeProcessor(opt_cxt) {} - - protected: - bool ShouldProcess() const override { - return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && - IsOnGPU(); - } - - std::vector GetInputPos() const override { - std::vector input_pos; + : AgnosticNodeProcessor(opt_cxt) { + std::set ops_format_agnostic = GetOpsFormatAgnostic(); for (int i = 0; i < node_->input_size(); i++) { auto input = node_map_->GetNode(node_->input(i)); int port; ParseNodeName(node_->input(i), &port); // Skip control input. if (port != -1) { + bool is_agnostic = + ops_format_agnostic.find(input->op()) != ops_format_agnostic.end(); if (IsPortDimsFour(*input, port) && - (IsNodeAfterNCHWToNHWC(*input) || + ((IsNodeAfterNCHWToNHWC(*input) && is_agnostic) || IsTransposeNCHWToNHWC(input->name()))) { - input_pos.push_back(i); + input_pos_.push_back(i); } } } - return input_pos; } + protected: + bool ShouldProcess() const override { + return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && + IsOnGPU(); + } + + std::vector GetInputPos() const override { return input_pos_; } + std::set GetOutputPos() const override { std::set output_pos{}; - for (const auto& input_pos : GetInputPos()) { + for (const auto& input_pos : input_pos_) { output_pos.insert(input_pos); } return output_pos; } + + private: + std::vector input_pos_; }; class ShapeProcessor : public IdentityNProcessor { @@ -1471,10 +1475,16 @@ class MergeProcessor : public AgnosticNodeProcessor { private: bool IsEveryInputAfterNCHWToNHWC() const { + std::set ops_format_agnostic = GetOpsFormatAgnostic(); for (const auto& input : node_->input()) { auto input_node = node_map_->GetNode(input); - if (IsNodeAfterNCHWToNHWC(*input_node) || - IsTransposeNCHWToNHWC(input_node->name())) { + int port; + ParseNodeName(input, &port); + bool is_agnostic = ops_format_agnostic.find(input_node->op()) != + ops_format_agnostic.end(); + if (IsPortDimsFour(*input_node, port) && + ((IsNodeAfterNCHWToNHWC(*input_node) && is_agnostic) || + IsTransposeNCHWToNHWC(input_node->name()))) { continue; } return false; @@ -1561,6 +1571,16 @@ class SelectProcessor : public AgnosticNodeProcessor { : AgnosticNodeProcessor(opt_cxt) {} protected: + bool ShouldProcess() const override { + auto input0 = node_map_->GetNode(node_->input(0)); + int input0_port; + ParseNodeName(node_->input(0), &input0_port); + bool is_input0_scalar_vector_4d = IsPortDimsN(*input0, input0_port, 0) || + IsPortDimsN(*input0, input0_port, 1) || + IsPortDimsN(*input0, input0_port, 4); + return AgnosticNodeProcessor::ShouldProcess() && is_input0_scalar_vector_4d; + } + std::vector GetInputPos() const override { auto input0 = node_map_->GetNode(node_->input(0)); int input0_port; @@ -1697,13 +1717,28 @@ class SqueezeProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { - return !MustPreserve() && IsPortZeroDimsN(*node_, 2) && HasOutputs() && - IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW() && - IsOnGPU(); + bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) || + (IsPortZeroDimsN(*node_, 1) && IsAlongNHW()); + return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && + IsInputConvertible() && is_dims_supported && IsOnGPU(); } Status AddLayoutTransposeToOutputs() override { return Status::OK(); } + Status CustomizedProcessing() override { + TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims")); + auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list(); + if (list->i_size() == 2) { + list->set_i(0, 2); + list->set_i(1, 3); + } else if (list->i_size() == 3) { + list->set_i(1, 2); + list->set_i(2, 3); + } + return Status::OK(); + } + + private: bool IsInputConvertible() const { int input_port; auto input = node_map_->GetNode(node_->input(0)); @@ -1716,33 +1751,31 @@ class SqueezeProcessor : public AgnosticNodeProcessor { if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) { return true; } + if (shape.dim(0).size() == 1 && shape.dim(1).size() == 1 && + shape.dim(2).size() == 1) { + return true; + } } return false; } - bool IsAlongDimHW() const { + bool IsAlongAxis(const std::vector& axis) const { if (node_->attr().find("squeeze_dims") != node_->attr().end()) { auto list = node_->attr().at("squeeze_dims").list(); // If list is empty, Squeeze op will squeeze all dimensions of size 1. if (list.i_size() == 0) return true; - if (list.i_size() == 2) { - if (list.i(0) == 1 && list.i(1) == 2) { - return true; + if (list.i_size() == axis.size()) { + bool along_axis = true; + for (int i = 0; i < axis.size(); i++) { + along_axis = along_axis && (list.i(i) == axis[i]); } + if (along_axis) return true; } } return false; } - - Status CustomizedProcessing() override { - TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims")); - auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list(); - if (list->i_size() == 2) { - list->set_i(0, 2); - list->set_i(1, 3); - } - return Status::OK(); - } + bool IsAlongHW() const { return IsAlongAxis({1, 2}); } + bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); } }; class ReduceProcessor : public AgnosticNodeProcessor { @@ -1761,7 +1794,7 @@ class ReduceProcessor : public AgnosticNodeProcessor { } Status CustomizedProcessing() override { - if (IsAlongNHW() || IsAlongHW() || IsAlongC()) { + if (IsReduceAxisSupported()) { DataType dtype = node_->attr().at("Tidx").type(); TF_RETURN_IF_ERROR( UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype)); @@ -1769,12 +1802,18 @@ class ReduceProcessor : public AgnosticNodeProcessor { return Status::OK(); } - Status AddLayoutTransposeToOutputs() override { return Status::OK(); } + Status AddLayoutTransposeToOutputs() override { + if (KeepDims()) { + return AddTransformToOutputs("Transpose"); + } + return Status::OK(); + } private: bool IsReduceAxisSupported() const { - return IsAlongAllFourDims() || IsAlongHWC() || - ((IsAlongNHW() || IsAlongHW() || IsAlongC()) && !KeepDims()); + return KeepDims() || ((IsAlongAllFourDims() || IsAlongHWC() || + IsAlongNHW() || IsAlongHW() || IsAlongC()) && + !KeepDims()); } bool IsAlongAxis(const std::vector& axis) const { @@ -2041,17 +2080,6 @@ class DataLayoutOptimizer : GraphProcessor { const LayoutOptimizer::TuningConfig& config_; }; -int GetNumTranspose(const GraphDef& graph) { - int number = 0; - for (const auto& node : graph.node()) { - if (IsTranspose(node)) { - number++; - } - } - VLOG(1) << "Number of Transpose nodes: " << number; - return number; -} - int GetNumGPUs(const Cluster& cluster) { auto devices = cluster.GetDevices(); int num_gpus = 0; @@ -2076,6 +2104,7 @@ Status LayoutOptimizer::Tune(const GrapplerItem& item, const TuningConfig& config, GraphDef* output) { auto status = graph_properties.AnnotateOutputShapes(output); if (!status.ok()) { + VLOG(1) << "Annotate shape return status: " << status.ToString(); *output = item.graph; return status; } @@ -2100,6 +2129,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphProperties graph_properties(item); auto status = graph_properties.InferStatically(false); if (!status.ok()) { + VLOG(1) << "Infer shape return status: " << status.ToString(); *output = item.graph; return status; } diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..102526e22f4742cb90757a1daf55467dd16afc3e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/loop_optimizer.h" + +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace grappler { + +Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + *optimized_graph = item.graph; + + return Status::OK(); +} + +void LoopOptimizer::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/, + const GraphDef& /*optimized_graph*/, + double /*result*/) { + // Nothing to do for LoopOptimizer. +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..106d4628ae68f3c92ab597f903f96a6af8a64b8d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_ + +#include +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +class LoopOptimizer : public GraphOptimizer { + public: + LoopOptimizer() : opt_level_(RewriterConfig::ON) {} + explicit LoopOptimizer(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level) {} + ~LoopOptimizer() override {} + + string name() const override { return "loop_optimizer"; }; + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override; + + private: + RewriterConfig::Toggle opt_level_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c09434f60916b9bf269b0f5006b8a3732afaa5fc --- /dev/null +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/loop_optimizer.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class LoopOptimizerTest : public ::testing::Test {}; + +void VerifyGraphsEqual(const GraphDef& original_graph, + const GraphDef& optimized_graph, const string& func) { + EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func; + for (int i = 0; i < original_graph.node_size(); ++i) { + const NodeDef& original = original_graph.node(i); + const NodeDef& optimized = optimized_graph.node(i); + EXPECT_EQ(original.name(), optimized.name()) << func; + EXPECT_EQ(original.op(), optimized.op()) << func; + EXPECT_EQ(original.input_size(), optimized.input_size()) << func; + for (int j = 0; j < original.input_size(); ++j) { + EXPECT_EQ(original.input(j), optimized.input(j)) << func; + } + } +} + +TEST_F(LoopOptimizerTest, NoOp) { + // This trivial graph is so basic there's nothing to optimize. + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + LoopOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + VerifyGraphsEqual(item.graph, output, __FUNCTION__); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index f537ecc41b964fb6c5f2e24891891c9407fcffef..3057ee5fa14bd209ad4bb6a9ad690d57435601f4 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/graph_memory.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/graph_view.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/static_schedule.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/grappler/utils/traversal.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { @@ -488,15 +490,15 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level, } bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { - // Look for AddN nodes and record input names. + // Look for AddN nodes (and equivalent) and record input names. GraphView view(&item->graph); std::unordered_map> addn_list; for (NodeDef& node : *item->graph.mutable_node()) { - if (!IsAddN(node)) { + if (!IsAddN(node) && node.op() != "AccumulateNV2") { continue; } - // There is nothing to gain by optimizing nodes with 2 inputs of fewer. + // There is nothing to gain by optimizing nodes with 2 or fewer inputs. if (view.NumFanins(node, false) <= 2) { continue; } @@ -509,6 +511,10 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { } } + if (addn_list.empty()) { + return false; + } + GraphMemory memory(*item); const std::unordered_map& devices = cluster->GetDevices(); @@ -560,6 +566,59 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { } const TensorShapeProto& shape = properties.GetOutputProperties(node->name())[0].shape(); + PartialTensorShape shp(shape); + if (!shp.IsFullyDefined()) { + VLOG(1) << "Shape not fully known for " << node->name(); + continue; + } + + // Compute a topological ordering for the node fanin. + std::unordered_map topo_order; + ReverseDfs(view, {node}, nullptr, + [&topo_order](NodeDef* n) { + int topo_index = topo_order.size(); + topo_order[n] = topo_index; + }, + nullptr); + + std::vector input_topo_index; + + for (int i = 0; i < node->input_size(); ++i) { + const string& input = node->input(i); + const string node_name = NodeName(input); + NodeDef* node = view.GetNode(node_name); + input_topo_index.push_back(topo_order.at(node)); + } + int min_input_topo_index = INT_MAX; + int min_input_id = -1; + for (int i = 0; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) { + // control inputs are always last. + break; + } + const int current = input_topo_index[i]; + if (current < min_input_topo_index) { + min_input_topo_index = current; + min_input_id = i; + } + } + CHECK_LE(0, min_input_id); + std::vector pre_ctrl_deps; + std::vector post_ctrl_deps; + for (int i = node->input_size() - 1; i >= 0; --i) { + if (!IsControlInput(node->input(i))) { + // control inputs are always last. + break; + } + if (input_topo_index[i] < min_input_topo_index) { + // These control dependencies can be executed before the node. + pre_ctrl_deps.push_back(node->input(i)); + } else { + // These control dependencies should be executed after the node. + post_ctrl_deps.push_back(node->input(i)); + } + } + DataType dtype = node->attr().at("T").type(); const string& device = node->device(); @@ -572,19 +631,27 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { *(*tmp_var->mutable_attr())["shape"].mutable_shape() = shape; (*tmp_var->mutable_attr())["var_name"].set_s(tmp_var->name()); + for (const string& ctrl_dep : pre_ctrl_deps) { + *tmp_var->add_input() = ctrl_dep; + } + *tmp_var->add_input() = + AsControlDependency(NodeName(node->input(min_input_id))); + // Initialize it to zero NodeDef* zeros = item->graph.add_node(); zeros->set_name(strings::StrCat(node->name(), "/tmp_var_zeros")); zeros->set_op("ZerosLike"); zeros->set_device(device); (*zeros->mutable_attr())["T"].set_type(dtype); - *zeros->add_input() = node->input(0); + *zeros->add_input() = node->input(min_input_id); NodeDef* initialize = item->graph.add_node(); initialize->set_name(strings::StrCat(node->name(), "/tmp_var_initializer")); initialize->set_op("Assign"); initialize->set_device(device); (*initialize->mutable_attr())["T"].set_type(dtype); + (*initialize->mutable_attr())["use_locking"].set_b(false); + (*initialize->mutable_attr())["validate_shape"].set_b(false); *initialize->add_input() = tmp_var->name(); *initialize->add_input() = zeros->name(); @@ -592,15 +659,14 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { std::vector accumulates; for (int i = 0; i < node->input_size(); ++i) { const string& input = node->input(i); - if (IsControlInput(input)) { - *zeros->add_input() = input; - } else { + if (!IsControlInput(input)) { NodeDef* accumulate = item->graph.add_node(); accumulate->set_name( strings::StrCat(node->name(), "/tmp_var_accum_", i)); accumulate->set_op("AssignAdd"); accumulate->set_device(device); (*accumulate->mutable_attr())["T"].set_type(dtype); + (*accumulate->mutable_attr())["use_locking"].set_b(true); *accumulate->add_input() = initialize->name(); *accumulate->add_input() = input; accumulates.push_back(accumulate); @@ -617,6 +683,10 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { for (const NodeDef* accum : accumulates) { *node->add_input() = AsControlDependency(accum->name()); } + for (const string& ctrl_dep : post_ctrl_deps) { + *node->add_input() = ctrl_dep; + } + updated_graph = true; } @@ -660,6 +730,7 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap, *swap_in_node->add_input() = swap_out_node->name(); // Colocate the swap_in_ node with the node itself. + swap_in_node->set_device(node->device()); string coloc_group = strings::StrCat("loc@", tensor_to_swap); (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); @@ -828,8 +899,7 @@ static NodeDef* FindSwapOutTrigger( const std::unordered_set& fanout = view.GetFanout(generator); NodeDef* trigger = nullptr; - Costs::NanoSeconds earliest_fanout( - static_cast(std::numeric_limits::max() >> 2)); + Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity()); for (const auto& port : fanout) { if (port.node == node) { @@ -861,6 +931,15 @@ static bool IsSwappable(GraphView::InputPort input) { return !IsRefType(dtype); } +struct MemInfo { + GraphView::OutputPort port; + int64 memory_used; + std::vector uses_left; + double fitness; + + bool operator<(const MemInfo& other) const { return fitness < other.fitness; } +}; + static bool IdentifySwappingCandidates( Cluster* cluster, GrapplerItem* item, std::unordered_set* skip_list, std::unordered_map* nodes_to_swap) { @@ -890,31 +969,56 @@ static bool IdentifySwappingCandidates( continue; } int64 required_savings = mem_usage.used_memory - prop.memory_size(); - // TODO(bsteiner): sort the tensors by how long they're live. - std::unordered_map execution_times; + std::unordered_map op_completion_times; { - std::unordered_map - tmp_execution_times; - if (!EstimateEarliestExecutionTimes(*item, cluster, &tmp_execution_times) - .ok()) { + VirtualCluster vcluster(cluster->GetDevices()); + if (!vcluster.Provision().ok()) { return false; } - for (const auto& exec_time : tmp_execution_times) { - execution_times.emplace(exec_time.first->name(), exec_time.second); + if (!vcluster.Initialize(*item).ok()) { + return false; + } + RunMetadata metadata; + Status s = vcluster.Run(item->graph, item->feed, item->fetch, &metadata); + if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) { + return false; + } + + for (const auto& dev_stats : metadata.step_stats().dev_stats()) { + for (const auto& node_stats : dev_stats.node_stats()) { + Costs::NanoSeconds exec_time = + Costs::NanoSeconds(1) + + Costs::MicroSeconds(node_stats.all_start_micros() + + node_stats.op_end_rel_micros()); + op_completion_times.emplace(node_stats.node_name(), exec_time); + } + } + } + + Costs::Duration peak_time = -1; + for (const auto& live_tensor : mem_usage.live_tensors) { + if (live_tensor.allocation_time > peak_time) { + peak_time = live_tensor.allocation_time; } } + std::vector mem_state; + GraphView graph(&item->graph); for (const auto& live_tensor : mem_usage.live_tensors) { + if (live_tensor.memory_used <= 1024) { + // Don't bother with small tensors. + continue; + } if (live_tensor.deallocation_time - live_tensor.allocation_time <= Costs::Duration(1e6)) { // Not enough time to swap. VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node; continue; } - if (live_tensor.memory_used <= 1024) { - // Don't bother with small tensors. + + if (skip_list->find(live_tensor.node) != skip_list->end()) { continue; } GraphView::OutputPort port = @@ -922,56 +1026,77 @@ static bool IdentifySwappingCandidates( if (!IsSwappable(graph, port)) { continue; } - Costs::NanoSeconds execution_time(-1); - GraphView::InputPort fanout_to_swap; + MemInfo mem_info; + mem_info.port = port; + mem_info.memory_used = live_tensor.memory_used; + Costs::Duration allocation_time = live_tensor.allocation_time; + Costs::Duration earliest_use(Costs::Duration::infinity()); + bool valid = true; for (GraphView::InputPort input : graph.GetFanout(port)) { - if (skip_list->find(input.node->name()) != skip_list->end()) { + // Get execution time. + auto it = op_completion_times.find(input.node->name()); + if (it == op_completion_times.end()) { + valid = false; + break; + } + if (it->second <= peak_time) { continue; } + + if (skip_list->find(input.node->name()) != skip_list->end()) { + valid = false; + break; + } string input_name = strings::StrCat(input.node->name(), ":", input.port_id); if (skip_list->find(input_name) != skip_list->end()) { - continue; + valid = false; + break; } if (!IsSwappable(input)) { - continue; - } - auto it = execution_times.find(input.node->name()); - if (it != execution_times.end()) { - if (it->second > execution_time) { - fanout_to_swap = input; - execution_time = it->second; - } + valid = false; + break; } + + // Set earliest use time that's after peak. + mem_info.uses_left.emplace_back(input); + earliest_use = std::min(earliest_use, it->second); } - // Annotate the fanout to request the tensor to be swapped if it's not - // already been done. - bool found = false; - if (!fanout_to_swap.node) { - continue; - } - auto it = fanout_to_swap.node->attr().find("_swap_to_host"); - if (it != fanout_to_swap.node->attr().end()) { - const AttrValue& val = it->second; - for (int port_id : val.list().i()) { - if (port_id == fanout_to_swap.port_id) { - found = true; - break; - } - } + if (valid && !mem_info.uses_left.empty()) { + // Compute the fitness: we need the tensor to be generated way away of + // the time of peak memory usage (to ensure there is enough time to swap + // it out). We also need to ensure it's used way after the peak time, to + // ensure that swapping the tensor back in won't recreate the memory + // bottleneck. Last but not least, we want the tensor to have as few + // remaining uses as possible. + mem_info.fitness = std::pow((earliest_use - peak_time).count(), 2); + mem_info.fitness /= std::pow(mem_info.uses_left.size(), 2); + mem_info.fitness += std::pow((allocation_time - peak_time).count(), 2); + mem_info.fitness = -mem_info.fitness; + mem_state.push_back(mem_info); } - if (!found) { + } + + // Sort by fitness + std::sort(mem_state.begin(), mem_state.end()); + + for (const MemInfo& mem_info : mem_state) { + for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) { + VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":" + << fanout_to_swap.port_id << " of tensor " + << mem_info.port.node->name() << ":" << mem_info.port.port_id + << " of size " << mem_info.memory_used; + (*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back( fanout_to_swap.port_id); - required_savings -= live_tensor.memory_used; - updated_graph = true; - if (required_savings < 0) { - break; - } + } + required_savings -= mem_info.memory_used; + updated_graph = true; + if (required_savings < 0) { + break; } } } - return updated_graph; } @@ -1011,7 +1136,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level, } for (auto& swap : nodes_to_swap) { const NodeDef* node = swap.first; - std::vector props = + const std::vector& props = properties.GetInputProperties(node->name()); SwapInfo& swap_info = swap.second; int64 bytes_to_swap = 0; @@ -1108,7 +1233,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, bool updated_graph = true; for (int i = 0; i < 25 && updated_graph; ++i) { updated_graph = false; - if ((optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS || + if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT || + optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS || optimization_level_ == RewriterConfig::HEURISTICS) && cluster != nullptr) { updated_graph |= SchedulingPass(cluster, &optimized_item); diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc index dd2d20d8d682856a8a94f99e4ca2aa706331d9d4..5d7913e0c018ecf14cc09ab91d3a71125c720aa5 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc @@ -19,17 +19,18 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" namespace tensorflow { namespace grappler { namespace { -class RecomputeSubgraphTest : public ::testing::Test {}; +class RecomputeSubgraphTest : public GrapplerTest {}; TEST_F(RecomputeSubgraphTest, SimpleSubgraph) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -193,7 +194,7 @@ TEST_F(RecomputeSubgraphTest, MultiNode) { EXPECT_EQ("^gradients/BN1Grad", recompute_trigger_c->input(0)); } -class MemoryOptimizerTest : public ::testing::Test { +class MemoryOptimizerTest : public GrapplerTest { public: static std::unique_ptr CreateVirtualCluster() { DeviceProperties cpu_device; @@ -201,6 +202,7 @@ class MemoryOptimizerTest : public ::testing::Test { cpu_device.set_frequency(1000); cpu_device.set_num_cores(4); cpu_device.set_bandwidth(32); + cpu_device.set_memory_size(1024 * 1024); DeviceProperties gpu_device; gpu_device.set_type("GPU"); gpu_device.set_frequency(1000); @@ -337,25 +339,27 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) { for (const auto& node : output.node()) { if (node.name() == "e") { // The d node isn't swappable. - EXPECT_EQ(4, node.input_size()); + EXPECT_EQ(5, node.input_size()); EXPECT_EQ("d", node.input(2)); + EXPECT_EQ("^swap_out_d_2", node.input(4)); } } } TEST_F(MemoryOptimizerTest, AccumulationRewrites) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"), - {128, 128, 8}, DT_FLOAT); - Output b = ops::Variable(s.WithOpName("b").WithDevice("/gpu:0"), - {128, 128, 8}, DT_FLOAT); - Output c = ops::Variable(s.WithOpName("c").WithDevice("/gpu:0"), - {128, 128, 8}, DT_FLOAT); - Output d = ops::AddN(s.WithOpName("d").WithDevice("/gpu:0"), {a, b, c}); + Output a = ops::RandomNormal(s.WithOpName("a").WithDevice("/cpu:0"), + {128, 128, 8}, DT_FLOAT); + Output b = ops::RandomNormal(s.WithOpName("b").WithDevice("/cpu:0"), + {128, 128, 8}, DT_FLOAT); + Output c = ops::RandomNormal(s.WithOpName("c").WithDevice("/cpu:0"), + {128, 128, 8}, DT_FLOAT); + Output d = ops::AddN(s.WithOpName("d").WithDevice("/cpu:0"), {a, b, c}); + Output e = ops::Square(s.WithOpName("e").WithDevice("/cpu:0"), d); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"d"}; + item.fetch = {"e"}; std::unique_ptr cluster(CreateVirtualCluster()); MemoryOptimizer optimizer(RewriterConfig::SCHEDULING_HEURISTICS); @@ -374,9 +378,27 @@ TEST_F(MemoryOptimizerTest, AccumulationRewrites) { } else if (node.name() == "d/tmp_var") { EXPECT_EQ("TemporaryVariable", node.op()); count++; + } else if (node.name() == "e") { + EXPECT_EQ("Square", node.op()); + EXPECT_EQ("d", node.input(0)); + count++; + } + } + EXPECT_EQ(4, count); + + std::vector fetch = {"a", "b", "c", "e"}; + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(4, tensors.size()); + + for (int i = 0; i < tensors[0].NumElements(); ++i) { + float actual = tensors[3].flat()(i); + float expected = 0.0f; + for (int j = 0; j < 3; ++j) { + expected += tensors[j].flat()(i); } + expected *= expected; + EXPECT_NEAR(actual, expected, 1e-4); } - EXPECT_EQ(3, count); } } // namespace diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 4228e7baba9741cf9160d4789d6bef04c50a7409..e27b9df6206c652e4503bb064366201a2b90f13a 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" +#include "tensorflow/core/grappler/optimizers/loop_optimizer.h" #include "tensorflow/core/grappler/optimizers/memory_optimizer.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/grappler/utils/topological_sort.h" @@ -75,6 +76,9 @@ std::unique_ptr MetaOptimizer::NewOptimizer( graph_optimizer.reset( new DependencyOptimizer(cfg_.dependency_optimization())); } + if (optimizer == "loop") { + graph_optimizer.reset(new LoopOptimizer(cfg_.loop_optimization())); + } return graph_optimizer; } @@ -97,11 +101,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizers.push_back(std::unique_ptr( new DependencyOptimizer(cfg_.dependency_optimization()))); } - if (cfg_.layout_optimizer() == RewriterConfig::ON) { + if (cfg_.loop_optimization() != RewriterConfig::OFF) { + optimizers.push_back(std::unique_ptr( + new LoopOptimizer(cfg_.loop_optimization()))); + } + if (cfg_.layout_optimizer() != RewriterConfig::OFF) { optimizers.push_back( std::unique_ptr(new LayoutOptimizer())); } - if (cfg_.memory_optimization() > 1) { + if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) { if (cfg_.memory_optimizer_target_node_name_prefix().empty()) { optimizers.push_back(std::unique_ptr( // Use the default target node name prefix "gradients/" @@ -119,8 +127,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } } else { std::set available_optimizers = { - "pruning", "constfold", "layout", "memory", - "autoparallel", "arithmetic", "dependency"}; + "pruning", "constfold", "layout", "memory", + "autoparallel", "arithmetic", "dependency", "loop"}; for (const auto& optimizer : cfg_.optimizers()) { if (available_optimizers.find(optimizer) != available_optimizers.end()) { optimizers.push_back(NewOptimizer(optimizer)); @@ -136,7 +144,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, bool already_optimized = false; for (const auto& optimizer : optimizers) { if (!already_optimized) { - auto status = optimizer->Optimize(cluster, item, optimized_graph); + Status status = optimizer->Optimize(cluster, item, optimized_graph); string result; if (!status.ok()) { VLOG(1) << "Not able to apply optimizer " << optimizer->name() @@ -152,7 +160,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, << " return status: " << result; } else { GrapplerItem optimized_item(item, std::move(*optimized_graph)); - auto status = + Status status = optimizer->Optimize(cluster, optimized_item, optimized_graph); string result; if (!status.ok()) { @@ -201,11 +209,13 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, bool MetaOptimizerEnabled(const RewriterConfig& cfg) { return !cfg.disable_model_pruning() || - cfg.layout_optimizer() == RewriterConfig::ON || + cfg.layout_optimizer() != RewriterConfig::OFF || cfg.constant_folding() != RewriterConfig::OFF || cfg.dependency_optimization() != RewriterConfig::OFF || + cfg.loop_optimization() == RewriterConfig::ON || cfg.arithmetic_optimization() != RewriterConfig::OFF || - cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 || + cfg.auto_parallel().enable() || + cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT || !cfg.optimizers().empty(); } diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index c9bec7890e6af008859d21555fb7ed74451c72c6..f52a2ab86288adacefec6796ceed4cea73d9b632 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -26,12 +26,21 @@ limitations under the License. namespace tensorflow { namespace grappler { -bool IsTrivialOp(const NodeDef& node) { +bool IsTrivialOp(const NodeDef& node, const GraphRewriter& rewriter) { // Remove the stop gradient nodes since they serve no purpose once the graph // is built. Also remove Identity ops. - if (IsStopGradient(node) || IsIdentity(node)) { + if (IsStopGradient(node)) { return true; } + if (IsIdentity(node)) { + if (rewriter.FeedsMerge(node) || rewriter.IsDrivenBySwitch(node) || + rewriter.IsDrivenByControlDependency(node) || + rewriter.DrivesControlDependency(node)) { + return false; + } else { + return true; + } + } if (IsAddN(node) && NumNonControlInputs(node) <= 1) { return true; } @@ -41,7 +50,7 @@ bool IsTrivialOp(const NodeDef& node) { Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* pruned_graph) { - std::unordered_set nodes_to_preserve = item.NodesToPreserve(); + const std::unordered_set& nodes_to_preserve = item.NodesToPreserve(); // Prune all the nodes that won't be executed, ie all the nodes that aren't in // the fanin of a fetch node. If fetch nodes aren't specified, we'll assume @@ -58,7 +67,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, // let's be conservative and preserve the graph as is. return errors::InvalidArgument("Invalid input graph."); } - // Try to keep the nodes ordored somewhat topologically since this helps + // Try to keep the nodes ordered somewhat topologically since this helps // further optimizations perform better. for (int i = keep.size() - 1; i >= 0; --i) { *runnable_item.graph.add_node() = *keep[i]; @@ -72,7 +81,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, // Check if we can further prune the graph, by removing the trivial ops. std::unordered_set nodes_to_delete; for (auto& node : runnable_item.graph.node()) { - if (!IsTrivialOp(node)) { + if (!IsTrivialOp(node, rewriter)) { continue; } @@ -95,8 +104,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, // converting references to non-references. It is important to preserve // these non-references since the partitioner will avoid sending // non-references across partitions more than once. - if (!rewriter.DrivesControlDependency(node) && - !rewriter.IsDrivenByControlDependency(node) && + if (!rewriter.RemovalIncreasesEdgeCount(node) && !rewriter.IsConnectedToFunction(node) && !rewriter.IsDrivenByAnotherDevice(node) && !rewriter.ReceivesRefValue(node)) { @@ -112,13 +120,16 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, return Status::OK(); } + const bool fetches_are_known = !item.fetch.empty(); for (auto& node : runnable_item.graph.node()) { - NodeDef* new_node = pruned_graph->add_node(); - *new_node = node; - new_node->clear_input(); - rewriter.ForwardInputs(node, nodes_to_delete, new_node); + if (!fetches_are_known || + nodes_to_delete.find(&node) == nodes_to_delete.end()) { + NodeDef* new_node = pruned_graph->add_node(); + *new_node = node; + new_node->clear_input(); + rewriter.ForwardInputs(node, nodes_to_delete, new_node); + } } - VLOG(1) << "Pruned " << nodes_to_delete.size() << " nodes from the graph. The graph now contains " << pruned_graph->node_size() << " nodes."; diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc index ee722f311edbb55fbb19044df57cfdfd0b29b1b8..8480a74572883a4657e11606b4cb8dcd5532ea3a 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc @@ -156,47 +156,42 @@ TEST_F(ModelPrunerTest, NoOpPruning) { const NodeDef& new_e = output.node(4); EXPECT_EQ(NodeName(e.name()), new_e.name()); - EXPECT_EQ(1, new_e.input_size()); - EXPECT_EQ(NodeName(d.name()), new_e.input(0)); - EXPECT_EQ(2, new_d.input_size()); - EXPECT_EQ(NodeName(b.name()), new_d.input(0)); - EXPECT_EQ(1, new_c.input_size()); - EXPECT_EQ(NodeName(b.name()), new_c.input(0)); + for (const auto& new_node : output.node()) { + if (new_node.name() != "a") { + EXPECT_EQ(1, new_node.input_size()); + EXPECT_EQ("a", new_node.input(0)); + } + } } -TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { - // Build a simple graph with a few trivially prunable ops. - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - - Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); - Output b = ops::Sqrt(s.WithOpName("b"), {a}); - Output c = ops::Identity(s.WithOpName("c"), b); - Output d = ops::Identity(s.WithOpName("d"), c); - Output e = ops::Sqrt(s.WithOpName("e").WithControlDependencies(c), {d}); +TEST_F(ModelPrunerTest, PreserveIdentities) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT); + ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL); + ops::Switch s(scope.WithOpName("switch"), v_in, v_ctrl); + // id0 is preserved because it is fed by a Switch and drives a + // control dependency. + Output id0 = ops::Identity(scope.WithOpName("id0"), s.output_true); + // id1 is preserved because it feeds a Merge. + Output id1 = ops::Identity( + scope.WithOpName("id1").WithControlDependencies(v_ctrl), s.output_false); + Output id2 = ops::Identity(scope.WithOpName("id2"), id0); + Output id3 = + ops::Identity(scope.WithOpName("id3").WithControlDependencies(id0), id1); + auto merge = ops::Merge(scope.WithOpName("merge"), {id0, id1}); GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + item.fetch.push_back("id2"); + item.fetch.push_back("id3"); + item.fetch.push_back("merge"); ModelPruner pruner; GraphDef output; Status status = pruner.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - EXPECT_EQ(5, output.node_size()); - const NodeDef& new_a = output.node(0); - EXPECT_EQ(NodeName(a.name()), new_a.name()); - const NodeDef& new_b = output.node(1); - EXPECT_EQ(NodeName(b.name()), new_b.name()); - const NodeDef& new_c = output.node(2); - EXPECT_EQ(NodeName(c.name()), new_c.name()); - const NodeDef& new_d = output.node(3); - EXPECT_EQ(NodeName(d.name()), new_d.name()); - const NodeDef& new_e = output.node(4); - EXPECT_EQ(NodeName(e.name()), new_e.name()); - - EXPECT_EQ(2, new_e.input_size()); - EXPECT_EQ(NodeName(c.name()), new_e.input(0)); - EXPECT_EQ("^c", new_e.input(1)); + TF_EXPECT_OK(status); + EXPECT_EQ(item.graph.node_size(), output.node_size()); } TEST_F(ModelPrunerTest, PruningSkipsRefOutputs) { @@ -239,55 +234,53 @@ TEST_F(ModelPrunerTest, PruningSkipsRefOutputs) { EXPECT_EQ("b", new_e.input(0)); } -TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) { +// TODO(rmlarsen): Reenable this test when the issues with +// //robotics/learning/sensor_predict:utils_multi_sensor_rnn_test +// have been resolved. +/* +TEST_F(ModelPrunerTest, PruningForwardsCtrlDependencies) { // Build a simple graph with a few trivially prunable ops. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); Output b = ops::Sqrt(s.WithOpName("b"), {a}); Output c = ops::Sqrt(s.WithOpName("c"), {a}); - Output d = ops::Identity(s.WithOpName("d"), c); - Output e = ops::Identity(s.WithOpName("e"), d); - Output f = ops::Sqrt(s.WithOpName("f"), {e}); + Output d = ops::Identity(s.WithOpName("d").WithControlDependencies(b), c); + Output e = ops::Identity(s.WithOpName("e").WithControlDependencies(c), d); + Output f = ops::Sqrt(s.WithOpName("f"), {d}); + Output g = ops::Sqrt(s.WithOpName("g"), {e}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - - // Add a control dependency between b and d and another one between c and e. - // They should be properly forwarded. - EXPECT_EQ("d", item.graph.node(3).name()); - EXPECT_EQ("e", item.graph.node(4).name()); - *item.graph.mutable_node(3)->add_input() = "^b"; - *item.graph.mutable_node(4)->add_input() = "^c"; + item.fetch.push_back("f"); + item.fetch.push_back("g"); ModelPruner pruner; GraphDef output; Status status = pruner.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + LOG(INFO) << "After: " << output.DebugString(); - EXPECT_EQ(6, output.node_size()); - const NodeDef& new_a = output.node(0); - EXPECT_EQ(NodeName(a.name()), new_a.name()); - const NodeDef& new_b = output.node(1); - EXPECT_EQ(NodeName(b.name()), new_b.name()); - const NodeDef& new_c = output.node(2); - EXPECT_EQ(NodeName(c.name()), new_c.name()); - const NodeDef& new_d = output.node(3); - EXPECT_EQ(NodeName(d.name()), new_d.name()); - const NodeDef& new_e = output.node(4); - EXPECT_EQ(NodeName(e.name()), new_e.name()); - const NodeDef& new_f = output.node(5); - EXPECT_EQ(NodeName(f.name()), new_f.name()); - - EXPECT_EQ(1, new_f.input_size()); - EXPECT_EQ(NodeName(e.name()), new_f.input(0)); - EXPECT_EQ(2, new_e.input_size()); - EXPECT_EQ(NodeName(d.name()), new_e.input(0)); - EXPECT_EQ("^c", new_e.input(1)); - EXPECT_EQ(2, new_d.input_size()); - EXPECT_EQ(NodeName(c.name()), new_d.input(0)); - EXPECT_EQ("^b", new_d.input(1)); + EXPECT_EQ(5, output.node_size()); + for (const auto& new_node : output.node()) { + // "d" and "e" should be removed. + EXPECT_NE("d", new_node.name()); + EXPECT_NE("e", new_node.name()); + if (new_node.name() == "g") { + EXPECT_EQ(2, new_node.input_size()); + // The input from switch should be forwarded to id3. + EXPECT_EQ("c", new_node.input(0)); + EXPECT_EQ("^b", new_node.input(1)); + } + if (new_node.name() == "f") { + EXPECT_EQ(2, new_node.input_size()); + // The input from switch should be forwarded to id3. + EXPECT_EQ("c", new_node.input(0)); + EXPECT_EQ("^b", new_node.input(1)); + } + } } +*/ TEST_F(ModelPrunerTest, PruningPerservesFetch) { // Build a simple graph with a few trivially prunable ops. @@ -296,6 +289,7 @@ TEST_F(ModelPrunerTest, PruningPerservesFetch) { Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); Output b = ops::Sqrt(s.WithOpName("b"), {a}); Output c = ops::Identity(s.WithOpName("c"), b); + Output d = ops::Identity(s.WithOpName("d"), c); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 8099214c2bd81e642bbcc8fc913d1ec3307d6251..eb5a2c48dc8b12f7b4090e80c403e238a526e122 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" @@ -131,7 +132,7 @@ string ParseNodeName(const string& name, int* position) { strings::Scanner scan(name); scan.ZeroOrOneLiteral("^") .RestartCapture() - .One(strings::Scanner::LETTER_DIGIT_DOT) + .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); StringPiece capture; StringPiece remaining; @@ -207,7 +208,7 @@ string AsControlDependency(const string& node_name) { : strings::StrCat("^", node_name); } -int NumOutputs(const NodeDef& node) { +int NumOutputs(const NodeDef& node, GraphDef* graph) { int num_outputs = 0; const OpDef* op_def = nullptr; auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); @@ -222,6 +223,12 @@ int NumOutputs(const NodeDef& node) { num_outputs++; } } + } else { + FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library()); + auto status = fdef.LookUpOpDef(node.op(), &op_def); + if (status.ok()) { + num_outputs = op_def->output_arg_size(); + } } return num_outputs; } @@ -305,6 +312,20 @@ void PermuteNodesInPlace(GraphDef* graph, std::vector* permutation, } } +void DedupControlInputs(NodeDef* node) { + std::unordered_set inputs; + int pos = 0; + while (pos < node->input_size()) { + const string& input = node->input(pos); + if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) { + node->mutable_input()->SwapElements(pos, node->input_size() - 1); + node->mutable_input()->RemoveLast(); + } else { + ++pos; + } + } +} + namespace { template inline void STLSortAndRemoveDuplicates(T* v) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index c04a9a666dd68c42f378543bd2fc997a4bde872c..4ecb28f681507f50ad5909f15cf1b408ed6e2979 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -135,7 +135,7 @@ string AsControlDependency(const string& node); // Returns the number of outputs of a node according to its OpDef. Note that // some of the outputs may be unconnected. -int NumOutputs(const NodeDef& node); +int NumOutputs(const NodeDef& node, GraphDef* graph); // Number of connected non-control inputs. int NumNonControlInputs(const NodeDef& node); @@ -143,6 +143,9 @@ int NumNonControlInputs(const NodeDef& node); // Number of connected non-control outputs. int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map); +// Removes redundant control inputs from node. +void DedupControlInputs(NodeDef* node); + // Returns the data type in attribute `attr_name` of `node`. If that attribute // doesn't exist, returns DT_INVALID. DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 534f7a063fe90bf72f8a2afba7ae8f75b8472a36..0a9dbe22cfe3cd01c2c61661adcdd4839a957f03 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -99,3 +99,49 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "traversal", + srcs = ["traversal.cc"], + hdrs = ["traversal.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:graph_view", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + ], +) + +tf_cc_test( + name = "traversal_test", + srcs = ["traversal_test.cc"], + deps = [ + ":traversal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "grappler_test", + testonly = 1, + srcs = [ + "grappler_test.cc", + ], + hdrs = ["grappler_test.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:direct_session", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core/grappler:utils", + ], +) diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..813f65f825759ca22dba2bdfd8433d946b7dd852 --- /dev/null +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace grappler { + +std::vector GrapplerTest::EvaluateNodes( + const GraphDef& graph, const std::vector& node_names) { + SessionOptions options; + std::unique_ptr session(NewSession(options)); + TF_CHECK_OK(session->Create(graph)); + RunOptions run_options; + std::vector output_tensors; + TF_CHECK_OK(session->Run(run_options, {}, node_names, node_names, + &output_tensors, nullptr)); + TF_CHECK_OK(session->Close()); + return output_tensors; +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h new file mode 100644 index 0000000000000000000000000000000000000000..46ce47c8c3b6bc18b6eac76bbdb8ec1f8a58fab2 --- /dev/null +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_GRAPPLER_TEST_H_ +#define TENSORFLOW_GRAPPLER_GRAPPLER_TEST_H_ + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +class GrapplerTest : public ::testing::Test { + protected: + std::vector EvaluateNodes(const GraphDef& graph, + const std::vector& node_names); +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_GRAPPLER_TEST_H_ diff --git a/tensorflow/core/grappler/utils/traversal.cc b/tensorflow/core/grappler/utils/traversal.cc new file mode 100644 index 0000000000000000000000000000000000000000..f44f53c4e63805544fa480628e805303064edb3d --- /dev/null +++ b/tensorflow/core/grappler/utils/traversal.cc @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/traversal.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { +namespace grappler { + +void ReverseDfs(const GraphView& graph_view, const std::vector& from, + const std::function& pre_order, + const std::function& post_order, + const std::function& on_back_edge) { + // Stack of work to do. + struct StackElem { + NodeDef* node; + bool children_visited; + NodeDef* src; + }; + std::vector stack; + + stack.reserve(from.size()); + for (NodeDef* node : from) { + stack.push_back(StackElem{node, false}); + } + + enum NodeState { NOT_VISITED = 0, VISITING = 1, DONE = 2 }; + std::unordered_map node_state; + while (!stack.empty()) { + StackElem w = stack.back(); + stack.pop_back(); + + if (w.children_visited) { + // We've processed all the children of this node + node_state[w.node] = DONE; + if (post_order) { + post_order(w.node); + } + continue; + } + + auto& rslt = node_state[w.node]; + if (rslt == DONE) { + continue; + } else if (rslt == VISITING) { + // Loop detected + if (on_back_edge) { + on_back_edge(w.src, w.node); + } + continue; + } + rslt = VISITING; + if (pre_order) { + pre_order(w.node); + } + + // Enqueue the node again with the children_visited flag set to true. + stack.push_back(StackElem{w.node, true, w.src}); + + // Now enqueu the node children. + for (const auto fanin : graph_view.GetFanins(*w.node, true)) { + stack.push_back(StackElem{fanin.node, false, w.node}); + } + } +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/traversal.h b/tensorflow/core/grappler/utils/traversal.h new file mode 100644 index 0000000000000000000000000000000000000000..bb3fa090e8fdaf12ed6dcb18eb1511c55496a125 --- /dev/null +++ b/tensorflow/core/grappler/utils/traversal.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ + +#include +#include "tensorflow/core/grappler/graph_view.h" + +namespace tensorflow { +namespace grappler { + +// Traverse the graph in reverse dfs order, starting from the list of nodes +// specified in the 'from' argument. The pre_order and post_order functors will +// be called on each reachable node (including the 'from' nodes) in pre and post +// order. If loops are found, the on_back_edge functor will be called on the +// corresponding back edges. Moreover, the pre and post order will assume that +// these back edges will be cut. +void ReverseDfs(const GraphView& graph_view, const std::vector& from, + const std::function& pre_order, + const std::function& post_order, + const std::function& on_back_edge); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ diff --git a/tensorflow/core/grappler/utils/traversal_test.cc b/tensorflow/core/grappler/utils/traversal_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cc68bd1a9637cb6f61955e8fa5d495a34f19cb09 --- /dev/null +++ b/tensorflow/core/grappler/utils/traversal_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/traversal.h" +//#include "tensorflow/core/framework/node_def.pb.h" +//#include "tensorflow/core/lib/core/status_test_util.h" +//#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class TraversalTest : public ::testing::Test { + protected: + static NodeDef CreateNode(const string& name, + const std::vector& inputs) { + return CreateNode(name, "", inputs); + } + static NodeDef CreateNode(const string& name, const string& op, + const std::vector& inputs) { + NodeDef node; + node.set_name(name); + if (!op.empty()) { + node.set_op(op); + } + for (const string& input : inputs) { + node.add_input(input); + } + return node; + } +}; + +TEST_F(TraversalTest, ReverseDfsNoLoop) { + GraphDef graph; + *graph.add_node() = CreateNode("2", {"5"}); + *graph.add_node() = CreateNode("0", {"5", "4"}); + *graph.add_node() = CreateNode("1", {"4", "3"}); + *graph.add_node() = CreateNode("3", {"2"}); + *graph.add_node() = CreateNode("5", {}); + *graph.add_node() = CreateNode("4", {}); + + std::vector start_nodes = {graph.mutable_node(1), + graph.mutable_node(2)}; + std::vector pre_order; + std::vector post_order; + bool found_back_edge = false; + ReverseDfs( + GraphView(&graph), start_nodes, + [&pre_order](NodeDef* n) { pre_order.push_back(n->name()); }, + [&post_order](NodeDef* n) { post_order.push_back(n->name()); }, + [&found_back_edge](NodeDef*, NodeDef*) { found_back_edge = true; }); + + EXPECT_EQ(std::vector({"1", "4", "3", "2", "5", "0"}), pre_order); + EXPECT_EQ(std::vector({"4", "5", "2", "3", "1", "0"}), post_order); + EXPECT_FALSE(found_back_edge); +} + +TEST_F(TraversalTest, ReverseDfsWithLoop) { + GraphDef graph; + // Create a loop + *graph.add_node() = CreateNode("2", "Merge", {"1", "5"}); + *graph.add_node() = CreateNode("3", "Switch", {"2"}); + *graph.add_node() = CreateNode("4", "Identity", {"3"}); + *graph.add_node() = CreateNode("5", "NextIteration", {"4"}); + *graph.add_node() = CreateNode("1", "Enter", {}); + *graph.add_node() = CreateNode("6", "Exit", {"3"}); + + std::vector start_nodes = {graph.mutable_node(5)}; + std::vector pre_order; + std::vector post_order; + std::vector back_edges; + ReverseDfs( + GraphView(&graph), start_nodes, + [&pre_order](NodeDef* n) { pre_order.push_back(n->name()); }, + [&post_order](NodeDef* n) { post_order.push_back(n->name()); }, + [&back_edges](NodeDef* src, NodeDef* dst) { + back_edges.push_back(strings::StrCat(src->name(), "->", dst->name())); + }); + + EXPECT_EQ(std::vector({"6", "3", "2", "1", "5", "4"}), pre_order); + EXPECT_EQ(std::vector({"1", "4", "5", "2", "3", "6"}), post_order); + EXPECT_EQ(std::vector({"4->3"}), back_edges); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 77371c399e5fc7321f7c2b271aae32ce9655244b..eabce5b5ee7b037b7bc429abfa86ee8735bdbede 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -29,83 +29,84 @@ namespace { class UtilsTest : public ::testing::Test { protected: NodeDef CreateConcatOffsetNode() const { - const string gdef_ascii = R"EOF( -name: "gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/ConcatOffset" -op: "ConcatOffset" -input: "InceptionV3/Mixed_7c/Branch_1/concat_v2/axis" -input: "gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape" -input: "gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape_1" -attr { - key: "N" - value { - i: 2 - } -} - )EOF"; + const string gdef_ascii = + " name: 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/" + "ConcatOffset'" + " op: 'ConcatOffset'" + " input: 'InceptionV3/Mixed_7c/Branch_1/concat_v2/axis'" + " input: 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape'" + " input: " + " 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape_1'" + " attr {" + " key: 'N'" + " value {" + " i: 2" + " }" + " }"; NodeDef node; CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node)); return node; } NodeDef CreateDequeueNode() const { - const string gdef_ascii = R"EOF( -name: "Train/TrainInput/input_producer_Dequeue" -op: "QueueDequeueV2" -input: "Train/TrainInput/input_producer" -attr { - key: "component_types" - value { - list { - type: DT_INT32 - } - } -} -attr { - key: "timeout_ms" - value { - i: -1 - } -} - )EOF"; + const string gdef_ascii = + " name: 'Train/TrainInput/input_producer_Dequeue'" + " op: 'QueueDequeueV2'" + " input: 'Train/TrainInput/input_producer'" + " attr {" + " key: 'component_types'" + " value {" + " list {" + " type: DT_INT32" + " }" + " }" + " }" + " attr {" + " key: 'timeout_ms'" + " value {" + " i: -1" + " }" + " }"; + NodeDef node; CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node)); return node; } NodeDef CreateFusedBatchNormNode() const { - const string gdef_ascii = R"EOF( -name: "InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm" -op: "FusedBatchNorm" -input: "InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm" -input: "InceptionV3/Conv2d_1a_3x3/BatchNorm/gamma/read" -input: "InceptionV3/Conv2d_1a_3x3/BatchNorm/beta/read" -input: "InceptionV3/Conv2d_1a_3x3/BatchNorm/Const" -input: "InceptionV3/Conv2d_1a_3x3/BatchNorm/Const_1" -attr { - key: "T" - value { - type: DT_FLOAT - } -} -attr { - key: "data_format" - value { - s: "NHWC" - } -} -attr { - key: "epsilon" - value { - f: 0.001 - } -} -attr { - key: "is_training" - value { - b: true - } -} - )EOF"; + const string gdef_ascii = + " name: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm'" + " op: 'FusedBatchNorm'" + " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm'" + " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/gamma/read'" + " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/beta/read'" + " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/Const'" + " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/Const_1'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " attr {" + " key: 'data_format'" + " value {" + " s: 'NHWC'" + " }" + " }" + " attr {" + " key: 'epsilon'" + " value {" + " f: 0.001" + " }" + " }" + " attr {" + " key: 'is_training'" + " value {" + " b: true" + " }" + " }"; + NodeDef node; CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node)); return node; @@ -177,9 +178,10 @@ TEST_F(UtilsTest, ExecuteWithTimeout) { } TEST_F(UtilsTest, NumOutputs) { - EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode())); - EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode())); - EXPECT_EQ(1, NumOutputs(CreateDequeueNode())); + GraphDef graph; + EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode(), &graph)); + EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode(), &graph)); + EXPECT_EQ(1, NumOutputs(CreateDequeueNode(), &graph)); } TEST_F(UtilsTest, AsControlDependency) { @@ -249,6 +251,49 @@ TEST_F(UtilsTest, GetTailOfChain) { EXPECT_EQ("noop", tail->name()); } +TEST_F(UtilsTest, DedupControlInputs) { + NodeDef foo; + foo.set_name("foo"); + foo.add_input("bar"); + DedupControlInputs(&foo); + EXPECT_EQ(1, foo.input_size()); + EXPECT_EQ("bar", foo.input(0)); + + foo.set_input(0, "^bar"); + DedupControlInputs(&foo); + EXPECT_EQ(1, foo.input_size()); + EXPECT_EQ("^bar", foo.input(0)); + + foo.set_input(0, "bar"); + foo.add_input("bar"); + DedupControlInputs(&foo); + EXPECT_EQ(2, foo.input_size()); + EXPECT_EQ("bar", foo.input(0)); + EXPECT_EQ("bar", foo.input(1)); + + foo.set_input(1, "^bar"); + DedupControlInputs(&foo); + EXPECT_EQ(1, foo.input_size()); + EXPECT_EQ("bar", foo.input(0)); + + foo.set_input(0, "^bar"); + foo.add_input("^bar"); + DedupControlInputs(&foo); + EXPECT_EQ(1, foo.input_size()); + EXPECT_EQ("^bar", foo.input(0)); + + foo.set_input(0, "bar"); + foo.add_input("gnu"); + foo.add_input("^bar"); + foo.add_input("^gnu"); + DedupControlInputs(&foo); + EXPECT_EQ(2, foo.input_size()); + EXPECT_EQ("bar", foo.input(0)); + EXPECT_EQ("gnu", foo.input(1)); +} + +TEST_F(UtilsTest, DeleteNodes) {} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index fd99409c9b35ae0ee2a3cbd9da9067fdc6434a8f..dc93c76eaee6c3408453a74bac98f5e365364247 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -629,6 +629,7 @@ cc_library( ":transpose_op", ":unique_op", ":unpack_op", + ":unravel_index_op", ":where_op", ], ) @@ -883,6 +884,12 @@ tf_kernel_library( deps = ARRAY_DEPS + [":split_lib"], ) +tf_kernel_library( + name = "unravel_index_op", + prefix = "unravel_index_op", + deps = ARRAY_DEPS, +) + tf_kernel_library( name = "where_op", srcs = ["where_op.cc"], @@ -980,6 +987,7 @@ tf_cuda_cc_test( name = "constant_op_test", size = "small", srcs = ["constant_op_test.cc"], + tags = ["no_cuda_on_cpu_tap"], deps = [ ":constant_op", ":ops_testutil", @@ -1033,7 +1041,7 @@ tf_cc_test( tf_cc_test( name = "conv_ops_test", - size = "small", + size = "medium", srcs = ["conv_ops_test.cc"], deps = [ ":conv_ops", @@ -1935,6 +1943,17 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "functional_ops", + prefix = "functional_ops", + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:functional_ops_op_lib", + "//tensorflow/core:lib", + "//third_party/eigen3", + ], +) + cc_library( name = "image", deps = [ @@ -2582,6 +2601,45 @@ tf_cc_tests( ], ) +cc_library( + name = "manip", + deps = [ + ":roll_op", + ], +) + +MANIP_DEPS = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:manip_ops_op_lib", + "//third_party/eigen3", +] + +tf_kernel_library( + name = "roll_op", + prefix = "roll_op", + deps = MANIP_DEPS, +) + +tf_cc_test( + name = "roll_op_test", + size = "small", + srcs = ["roll_op_test.cc"], + deps = [ + ":ops_testutil", + ":ops_util", + ":roll_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + MATH_DEPS = [ ":bounds_check", ":fill_functor", diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc index 37976f71837cb365cd9d232c7c1e102ec5bfe338..72155fd037378fc3d93c02e9b893a6671e9659a6 100644 --- a/tensorflow/core/kernels/adjust_contrast_op.cc +++ b/tensorflow/core/kernels/adjust_contrast_op.cc @@ -40,8 +40,8 @@ typedef Eigen::SyclDevice SYCLDevice; template class AdjustContrastOp : public OpKernel { public: - explicit AdjustContrastOp(OpKernelConstruction* context) : OpKernel(context) { - } + explicit AdjustContrastOp(OpKernelConstruction* context) + : OpKernel(context) {} void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); diff --git a/tensorflow/core/kernels/adjust_contrast_op_test.cc b/tensorflow/core/kernels/adjust_contrast_op_test.cc index 0fc03b5a236b2d63fc731f232acebdcbd1ca2532..7522b320400b034aa882efb82efab8d0419d8144 100644 --- a/tensorflow/core/kernels/adjust_contrast_op_test.cc +++ b/tensorflow/core/kernels/adjust_contrast_op_test.cc @@ -29,8 +29,7 @@ limitations under the License. namespace tensorflow { -class AdjustContrastOpTest : public OpsTestBase { -}; +class AdjustContrastOpTest : public OpsTestBase {}; TEST_F(AdjustContrastOpTest, Simple_1113) { TF_EXPECT_OK(NodeDefBuilder("adjust_contrast_op", "AdjustContrastv2") diff --git a/tensorflow/core/kernels/adjust_saturation_op.cc b/tensorflow/core/kernels/adjust_saturation_op.cc index 4643d4e6efda2157458a557819873c8cb7546e1a..f0c6ae499d4c209ef1556890e87f63085de7ea75 100644 --- a/tensorflow/core/kernels/adjust_saturation_op.cc +++ b/tensorflow/core/kernels/adjust_saturation_op.cc @@ -192,8 +192,9 @@ class AdjustSaturationOp : public AdjustSaturationOpBase { const DeviceBase::CpuWorkerThreads& worker_threads = *context->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, channel_count, - kCostPerChannel, [channel_count, &input_data, &output_data, scale_h]( - int64 start_channel, int64 end_channel) { + kCostPerChannel, + [channel_count, &input_data, &output_data, scale_h]( + int64 start_channel, int64 end_channel) { const float* p = input_data.data() + start_channel * kChannelSize; float* q = output_data.data() + start_channel * kChannelSize; for (int i = start_channel; i < end_channel; i++) { diff --git a/tensorflow/core/kernels/aggregate_ops_cpu.h b/tensorflow/core/kernels/aggregate_ops_cpu.h index dfa3fe585e375ada0c5d3d0b3061d05d8a4efabd..aa1cead928aa25e9cf8d9c8d6d43091bf93583ee 100644 --- a/tensorflow/core/kernels/aggregate_ops_cpu.h +++ b/tensorflow/core/kernels/aggregate_ops_cpu.h @@ -25,7 +25,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace tensorflow { @@ -201,7 +201,7 @@ struct Add7Functor { typename TTypes::ConstFlat in6, typename TTypes::ConstFlat in7) { Add7EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, - in7); + in7); } }; @@ -214,7 +214,7 @@ struct Add8Functor { typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { Add8EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, - in7, in8); + in7, in8); } }; @@ -227,7 +227,7 @@ struct Add8pFunctor { typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { Add8pEigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, - in7, in8); + in7, in8); } }; @@ -241,10 +241,10 @@ struct Add9Functor { typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8, typename TTypes::ConstFlat in9) { Add9EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, - in7, in8, in9); + in7, in8, in9); } }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace functor diff --git a/tensorflow/core/kernels/assign_op.h b/tensorflow/core/kernels/assign_op.h index 1d2e1c8c9aaeb8a722646916ea691fa4e5c23771..a312e8e8a420f7f909b20b28f84bf55597a58aba 100644 --- a/tensorflow/core/kernels/assign_op.h +++ b/tensorflow/core/kernels/assign_op.h @@ -109,6 +109,9 @@ class AssignOp : public OpKernel { OP_REQUIRES_OK( context, context->allocate_persistent(old_lhs.dtype(), rhs.shape(), ©, ©Tensor, attr)); + // We track memory of variables in variable ops instead of in this + // assign op. + context->clear_recorded_memory(); context->replace_ref_input(0, *copyTensor, /* lock_held */ true); if (use_exclusive_lock_) { Copy(context, copyTensor, rhs); diff --git a/tensorflow/core/kernels/attention_ops.cc b/tensorflow/core/kernels/attention_ops.cc index cc8f122cab357ed0c8243ba990b3b85dd7ddcb2f..ce2fce92e4ee8cbd7bdc578d92103a5bd5da0629 100644 --- a/tensorflow/core/kernels/attention_ops.cc +++ b/tensorflow/core/kernels/attention_ops.cc @@ -52,8 +52,9 @@ class ExtractGlimpseOp : public OpKernel { const int64 batch_size = input_shape.dim_size(0); const Tensor& window_size = context->input(1); - OP_REQUIRES(context, (window_size.shape().dims() == 1) && - window_size.shape().dim_size(0) == 2, + OP_REQUIRES(context, + (window_size.shape().dims() == 1) && + window_size.shape().dim_size(0) == 2, errors::InvalidArgument( "input must be a vector of size 2 (height, width)", window_size.shape().DebugString())); diff --git a/tensorflow/core/kernels/avgpooling_op.h b/tensorflow/core/kernels/avgpooling_op.h index dea2683184a06308bff7ead2b772aab466b90b34..f5e81dbc0930888ab9258d5d5b5d52fdeb0afc01 100644 --- a/tensorflow/core/kernels/avgpooling_op.h +++ b/tensorflow/core/kernels/avgpooling_op.h @@ -48,9 +48,8 @@ struct SpatialAvgPooling { typedef Eigen::GpuDevice GPUDevice; -// Launch a custom GPU kernels from Yanqing for the avgpooling backward operation -// that works NHWC data formats. -// Arguments: +// Launch a custom GPU kernels from Yanqing for the avgpooling backward +// operation that works NHWC data formats. Arguments: // top_diff: backprop to the output of the pooling layer // num: number of input batches // height: input height diff --git a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc index 2be330d1427b28a01635cc1db5fd10096f2a8abe..6537b42f1ed8856a5f701023eb5fc55ded278ec8 100644 --- a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc @@ -71,8 +71,8 @@ __global__ void AvePoolBackwardNHWC(const int nthreads, hstart = max(hstart, 0); wstart = max(wstart, 0); int pool_size = (hend - hstart) * (wend - wstart); - gradient += - top_diff_slice[(ph * pooled_width + pw) * channels] / dtype(pool_size); + gradient += top_diff_slice[(ph * pooled_width + pw) * channels] / + dtype(pool_size); } } bottom_diff[index] = gradient; @@ -90,11 +90,11 @@ bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num, const GPUDevice& d) { int x_size = num * height * width * channels; CudaLaunchConfig config = GetCudaLaunchConfig(x_size, d); - AvePoolBackwardNHWC< - T><<>>( - config.virtual_thread_count, top_diff, num, height, width, channels, - pooled_height, pooled_width, kernel_h, kernel_w, stride_h, stride_w, - pad_t, pad_t, bottom_diff); + AvePoolBackwardNHWC + <<>>( + config.virtual_thread_count, top_diff, num, height, width, channels, + pooled_height, pooled_width, kernel_h, kernel_w, stride_h, stride_w, + pad_t, pad_t, bottom_diff); return d.ok(); } diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc index d0bbea9fe27856cc0dedb4570d285bd872741099..944564dfba62f257ae45b3c5c25d0de64fa0b773 100644 --- a/tensorflow/core/kernels/barrier_ops.cc +++ b/tensorflow/core/kernels/barrier_ops.cc @@ -111,13 +111,14 @@ class Barrier : public ResourceBase { mutex_lock lock(mu_); if (closed_) { OP_REQUIRES_ASYNC( - ctx, !cancel_pending_enqueues_ && - (num_inserted == 0 || !incomplete_.empty()), + ctx, + !cancel_pending_enqueues_ && + (num_inserted == 0 || !incomplete_.empty()), errors::Cancelled( "Barrier ", name_, " is closed. Pending enqueues cancelled: ", - cancel_pending_enqueues_, ". Number of new insertions: ", - num_inserted, ". Number of incomplete keys: ", - incomplete_.size(), "."), + cancel_pending_enqueues_, + ". Number of new insertions: ", num_inserted, + ". Number of incomplete keys: ", incomplete_.size(), "."), callback); } @@ -128,9 +129,10 @@ class Barrier : public ResourceBase { for (int i = 0; i < num_inserted; ++i) { OP_REQUIRES_OK_ASYNC( - ctx, InsertOneLocked(ctx, keys, values, element_shape, - component_index, i, &ready_tuples, - &new_elements), + ctx, + InsertOneLocked(ctx, keys, values, element_shape, + component_index, i, &ready_tuples, + &new_elements), callback); } @@ -317,8 +319,9 @@ class Barrier : public ResourceBase { return errors::Cancelled( "Barrier ", name_, " is closed, but attempted to insert a brand new key: ", - keys_vec(i), ". Pending enqueues cancelled: ", - cancel_pending_enqueues_, ". Insertion index: ", i, + keys_vec(i), + ". Pending enqueues cancelled: ", cancel_pending_enqueues_, + ". Insertion index: ", i, ". Number of incomplete keys: ", incomplete_.size(), "."); } } else { @@ -532,13 +535,14 @@ class InsertManyOp : public BarrierOpKernel { OP_REQUIRES_ASYNC( ctx, component_index_ < barrier->num_components(), errors::InvalidArgument("The component ID is out of range ", - component_index_, " > num_components", " (= ", - barrier->num_components(), ")"), + component_index_, " > num_components", + " (= ", barrier->num_components(), ")"), callback); OP_REQUIRES_OK_ASYNC( - ctx, ctx->MatchSignature({DT_STRING_REF, DT_STRING, - barrier->component_type(component_index_)}, - {}), + ctx, + ctx->MatchSignature({DT_STRING_REF, DT_STRING, + barrier->component_type(component_index_)}, + {}), callback); const Tensor* keys; diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 5b4e1a809fa4b9e3d5c5e1b877233b31826bd386..546e51be53cee1833e8e1d4a15ea9b5be8a31506 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -13,22 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/periodic_function.h" +#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/split_lib.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/macros.h" - namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -209,7 +207,7 @@ Status Split(OpKernelContext* context, const Tensor& input, class BatchResource : public ResourceBase { public: static Status Create(int32 num_batch_threads, int32 max_batch_size, - int32 batch_timeout_micros, + int32 batch_timeout_micros, int32 max_enqueued_batches, const std::vector& allowed_batch_sizes, std::unique_ptr* resource) { std::unique_ptr new_resource(new BatchResource); @@ -220,6 +218,8 @@ class BatchResource : public ResourceBase { Batcher::Create(batcher_options, &new_resource->batcher_)); new_resource->batcher_queue_options_.max_batch_size = max_batch_size; + new_resource->batcher_queue_options_.max_enqueued_batches = + max_enqueued_batches; new_resource->batcher_queue_options_.batch_timeout_micros = batch_timeout_micros; @@ -515,6 +515,8 @@ class BatchKernel : public AsyncOpKernel { OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_)); OP_REQUIRES_OK(c, c->GetAttr("batch_timeout_micros", &batch_timeout_micros_)); + OP_REQUIRES_OK(c, + c->GetAttr("max_enqueued_batches", &max_enqueued_batches_)); OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_)); OP_REQUIRES_OK(c, ValidateAllowedBatchSizes()); } @@ -526,7 +528,7 @@ class BatchKernel : public AsyncOpKernel { std::unique_ptr new_resource; TF_RETURN_IF_ERROR(BatchResource::Create( num_batch_threads_, max_batch_size_, batch_timeout_micros_, - allowed_batch_sizes_, &new_resource)); + max_enqueued_batches_, allowed_batch_sizes_, &new_resource)); *r = new_resource.release(); return Status::OK(); }; @@ -572,6 +574,7 @@ class BatchKernel : public AsyncOpKernel { int32 num_batch_threads_; int32 max_batch_size_; int32 batch_timeout_micros_; + int32 max_enqueued_batches_; std::vector allowed_batch_sizes_; }; diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h index 93c391831982c529fb8e270f6eb0cac8063bffbf..43e716c542ac42835baabde057e45534d5442010 100644 --- a/tensorflow/core/kernels/batch_matmul_op_impl.h +++ b/tensorflow/core/kernels/batch_matmul_op_impl.h @@ -41,7 +41,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace { @@ -429,14 +429,13 @@ template struct LaunchBatchMatMul { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) { - - // Number of matrix multiplies i.e. size of the batch. - const int64 batch_size = in_x.dim_size(0); - ParallelMatMulKernelSYCL::Run(context, in_x, in_y, adj_x, adj_y, out, - 0, batch_size); + // Number of matrix multiplies i.e. size of the batch. + const int64 batch_size = in_x.dim_size(0); + ParallelMatMulKernelSYCL::Run(context, in_x, in_y, adj_x, adj_y, + out, 0, batch_size); } }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class BatchMatMul : public OpKernel { @@ -462,10 +461,10 @@ class BatchMatMul : public OpKernel { TensorShape out_shape; for (int i = 0; i < ndims - 2; ++i) { OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i), - errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(", - i, ") must be the same: ", - in0.shape().DebugString(), " vs ", - in1.shape().DebugString())); + errors::InvalidArgument( + "In[0].dim(", i, ") and In[1].dim(", i, + ") must be the same: ", in0.shape().DebugString(), " vs ", + in1.shape().DebugString())); out_shape.AddDim(in0.dim_size(i)); } auto n = (ndims == 2) ? 1 : out_shape.num_elements(); @@ -507,12 +506,12 @@ class BatchMatMul : public OpKernel { bool adj_y_; }; -#define REGISTER_BATCH_MATMUL_CPU(TYPE) \ +#define REGISTER_BATCH_MATMUL_CPU(TYPE) \ REGISTER_KERNEL_BUILDER( \ Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ BatchMatMul) -#define REGISTER_BATCH_MATMUL_GPU(TYPE) \ +#define REGISTER_BATCH_MATMUL_GPU(TYPE) \ REGISTER_KERNEL_BUILDER( \ Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint("T"), \ BatchMatMul) @@ -522,5 +521,5 @@ class BatchMatMul : public OpKernel { REGISTER_KERNEL_BUILDER( \ Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint("T"), \ BatchMatMul) -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // end namespace tensorflow diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc index 8d155ca62b297a4bf59f62159d6b62b01f777721..7e1e2aa4ec135872993f2e7738c7e863416eee87 100644 --- a/tensorflow/core/kernels/batch_matmul_op_real.cc +++ b/tensorflow/core/kernels/batch_matmul_op_real.cc @@ -35,5 +35,5 @@ TF_CALL_half(REGISTER_BATCH_MATMUL_GPU); #ifdef TENSORFLOW_USE_SYCL TF_CALL_float(REGISTER_BATCH_MATMUL_SYCL); TF_CALL_double(REGISTER_BATCH_MATMUL_SYCL); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_matmul_op_test.cc b/tensorflow/core/kernels/batch_matmul_op_test.cc index 7923f34155b57cb79894936cb4ea0f485f92d99b..c3932cd7b9023482316807c73bfd52da3a4a3f7a 100644 --- a/tensorflow/core/kernels/batch_matmul_op_test.cc +++ b/tensorflow/core/kernels/batch_matmul_op_test.cc @@ -53,9 +53,10 @@ static Graph* BatchMatmul(int b, int m, int k, int n, bool adjoint_a, /* Uncomment to enable benchmarks for double & complex types: */ // BM_BatchMatmulDev(B, M, K, N, TA, TB, std::complex, DT_COMPLEX64, // gpu); -// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \ -// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex, DT_COMPLEX128, cpu); \ -// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \ +// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \ +// BM_BatchMatmulDev(M, K, N, TA, TB, std::complex, DT_COMPLEX128, cpu); +// \ +// BM_BatchMatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \ // BM_BatchMatmulDev(M, K, N, TA, TB, std::complex, DT_COMPLEX128, gpu); // Typical fully connected layers diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc index d3ed617f713094cb94c1a87dc0c36c3d44d97918..c34ea14bf6007f6951733990c0a01999ac838b75 100644 --- a/tensorflow/core/kernels/batch_norm_op.cc +++ b/tensorflow/core/kernels/batch_norm_op.cc @@ -30,7 +30,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class BatchNormOp : public OpKernel { diff --git a/tensorflow/core/kernels/batch_norm_op_test.cc b/tensorflow/core/kernels/batch_norm_op_test.cc index 5e3fcd2114a12709fb306ebadfd21a56b514e0c0..45ddc8532955578b5fca7ea372703f88b6b84f77 100644 --- a/tensorflow/core/kernels/batch_norm_op_test.cc +++ b/tensorflow/core/kernels/batch_norm_op_test.cc @@ -54,7 +54,7 @@ TEST_F(BatchNormOpTest, Simple) { Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2})); test::FillValues( &expected, {-17.86f, -22.00f, -15.87f, -20.59f, -13.87f, -19.18f, -21.86f, - -33.31f, -23.85f, -34.72f, -25.85f, -36.13f }); + -33.31f, -23.85f, -34.72f, -25.85f, -36.13f}); test::ExpectTensorNear(expected, *GetOutput(0), 0.01); } diff --git a/tensorflow/core/kernels/batch_util.cc b/tensorflow/core/kernels/batch_util.cc index 7f2df95e2d55ac93f8a934010244dcbd1dcd28c8..1a45212ad29a7b8a578ce176db20eaf3d2193afd 100644 --- a/tensorflow/core/kernels/batch_util.cc +++ b/tensorflow/core/kernels/batch_util.cc @@ -19,6 +19,8 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" +#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m) + namespace tensorflow { namespace batch_util { @@ -61,6 +63,21 @@ Status HandleElementToSlice(Tensor element, Tensor* parent, int64 index, return Status::OK(); } +template <> +Status HandleElementToSlice(Tensor element, Tensor* parent, + int64 index, bool can_move) { + auto parent_as_matrix = parent->flat_outer_dims(); + auto element_flat = element.flat(); + if (can_move) { + for (int64 i = 0; i < element.NumElements(); ++i) { + parent_as_matrix(index, i) = std::move(element_flat(i)); + } + } else { + parent_as_matrix.chip(index, 0) = element_flat; + } + return Status::OK(); +} + // TODO(jsimsa): Add HandleElementToSlice specialization that moves // the data when possible. @@ -87,7 +104,6 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) { switch (element.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); - TF_CALL_variant(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopyElementToSlice Unhandled data type: ", @@ -107,7 +123,6 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) { switch (parent.dtype()) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); - TF_CALL_variant(HANDLE_TYPE); #undef HANDLE_TYPE default: return errors::Unimplemented("CopySliceToElement Unhandled data type: ", @@ -115,5 +130,101 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) { } } +// The following five functions are copied from padding_fifo_queue.cc. +// TODO(mrry): Reconcile these functions with the similar methods in the +// queue implementation. +Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) { + DCHECK_NE(parent->dim_size(0), 0); + if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { + TensorShape chip_shape = parent->shape(); + chip_shape.RemoveDim(0); + return errors::Internal( + "HandleElementToLargerSlice Cannot copy slice: number of entries in " + "element is greater than number of elements in parent slice. ", + "Shapes are: [element]: ", element.shape().DebugString(), + ", [parent slice]: ", chip_shape.DebugString()); + } + return Status::OK(); +} + +template +Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, + int index) { + TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent)); + if (element.NumElements() == 0) { + return Status::OK(); + } + auto element_t = element.tensor(); + auto parent_t = parent->tensor(); + Eigen::DSizes slice_indices; + slice_indices[0] = index; + Eigen::DSizes slice_size; + slice_size[0] = 1; + for (size_t i = 1; i < slice_size.size(); ++i) { + slice_size[i] = element_t.dimension(i - 1); + } + parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size); + return Status::OK(); +} + +template +Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, + int index) { +#define HANDLE_TYPE(T) \ + case DataTypeToEnum::value: { \ + return HandleElementToLargerSlice(element, parent, index); \ + } + + switch (element.dtype()) { + TF_CALL_DATASET_TYPES(HANDLE_TYPE); +#undef HANDLE_TYPE + default: + return errors::Unimplemented( + "HandleElementToLargerSliceWithRank Unhandled data type: ", + element.dtype()); + } +} + +Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, + int index) { + if (parent->dims() != element.dims() + 1) { + return errors::Internal( + "Mismatched ranks. Element's rank is: ", element.dims(), + " but element is meant to be a slice in output Tensor having rank: ", + parent->dims(), " (should be: ", element.dims() + 1, ")"); + } + +#define HANDLE_DIMS(NDIMS) \ + case NDIMS: { \ + TF_RETURN_IF_ERROR( \ + HandleElementToLargerSliceWithRank(element, parent, index)); \ + return Status::OK(); \ + } + + switch (element.dims()) { + HANDLE_DIMS(0); + HANDLE_DIMS(1); + HANDLE_DIMS(2); + HANDLE_DIMS(3); + HANDLE_DIMS(4); +#undef HANDLE_DIMS + default: + return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ", + element.dims()); + } +} + +Status SetElementZero(Tensor* element, const Tensor& padding) { +#define HANDLE_TYPE(T) \ + if (element->dtype() == DataTypeToEnum::value) { \ + element->flat().setConstant(padding.scalar()()); \ + return Status::OK(); \ + } + TF_CALL_DATASET_TYPES(HANDLE_TYPE); +#undef HANDLE_TYPE + return errors::Unimplemented("SetElementZero Unhandled data type: ", + element->dtype()); +} + } // namespace batch_util } // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_util.h b/tensorflow/core/kernels/batch_util.h index 0d634ae7b07ee641eb13167d6f9fcb9ed5f0d974..a47bf1935db611417cea1d98ed8aff496efbf689 100644 --- a/tensorflow/core/kernels/batch_util.h +++ b/tensorflow/core/kernels/batch_util.h @@ -32,6 +32,16 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index); // Copies the index^th slice of parent (in the 0th dimension) into element. Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index); +// Zero-initializes the tensor `element` using the scalar stored in `padding`. +// Both `element` and `padding` must have matching `dtype`. +Status SetElementZero(Tensor* element, const Tensor& padding); + +// Copies `element` into a (0th dimension) slice of `parent`, assuming +// the shape of `element` is strictly not larger along any axis than a +// slice. +Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, + int index); + } // namespace batch_util } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/periodic_function.h b/tensorflow/core/kernels/batching_util/periodic_function.h index dbf1733dcc399522a673e5724dfeb62446f72a0f..36a4019002aa55c26fb5419c7a4d17562a367de8 100644 --- a/tensorflow/core/kernels/batching_util/periodic_function.h +++ b/tensorflow/core/kernels/batching_util/periodic_function.h @@ -114,7 +114,7 @@ class PeriodicFunction { void RunLoop(int64 start) LOCKS_EXCLUDED(mutex_); const std::function function_; // Actual client function - const int64 interval_micros_; // Interval between calls. + const int64 interval_micros_; // Interval between calls. const Options options_; // Protects state below. diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc index d73dcf0fa0e1b2b387b3ed53acd63d5c65683fd4..d5ea2b648f35efd03c04d00abc838edadd37570e 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc @@ -55,15 +55,14 @@ Status ScheduleTask(size_t task_size, BatchScheduler* scheduler) { // use the clock to be destroyed. std::unique_ptr CreateFakeClockAdvancerThread( test_util::FakeClockEnv* env, Notification* start, Notification* stop) { - return std::unique_ptr( - Env::Default()->StartThread({}, "FakeClockAdvancerThread", - [env, start, stop] { - start->WaitForNotification(); - while (!stop->HasBeenNotified()) { - env->AdvanceByMicroseconds(10); - Env::Default()->SleepForMicroseconds(10); - } - })); + return std::unique_ptr(Env::Default()->StartThread( + {}, "FakeClockAdvancerThread", [env, start, stop] { + start->WaitForNotification(); + while (!stop->HasBeenNotified()) { + env->AdvanceByMicroseconds(10); + Env::Default()->SleepForMicroseconds(10); + } + })); } TEST(SharedBatchSchedulerTest, Basic) { @@ -258,7 +257,7 @@ TEST(SharedBatchSchedulerTest, ObeysTimeout) { TEST(SharedBatchSchedulerTest, ObeysTimeoutWithRealClock) { Notification first_batch_processed, second_batch_processed; auto callback = [&first_batch_processed, &second_batch_processed]( - std::unique_ptr> batch) { + std::unique_ptr> batch) { ASSERT_TRUE(batch->IsClosed()); if (batch->size() == 1) { first_batch_processed.Notify(); @@ -301,7 +300,7 @@ TEST(SharedBatchSchedulerTest, { Notification first_batch_processed, second_batch_processed; auto callback = [&first_batch_processed, &second_batch_processed]( - std::unique_ptr> batch) { + std::unique_ptr> batch) { ASSERT_TRUE(batch->IsClosed()); if (batch->size() == 1) { first_batch_processed.Notify(); @@ -349,7 +348,7 @@ TEST(SharedBatchSchedulerTest, Fairness) { auto queue_0_callback = [&queue_0_first_batch_scheduled, &queue_0_first_batch_proceed, &queue_0_second_batch_scheduled]( - std::unique_ptr> batch) { + std::unique_ptr> batch) { if (!queue_0_first_batch_scheduled.HasBeenNotified()) { queue_0_first_batch_scheduled.Notify(); queue_0_first_batch_proceed.WaitForNotification(); @@ -467,7 +466,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) { TEST(SharedBatchSchedulerTest, OneFullQueueDoesntBlockOtherQueues) { Notification queue_0_processing, queue_0_proceed; auto queue_0_callback = [&queue_0_processing, &queue_0_proceed]( - std::unique_ptr> batch) { + std::unique_ptr> batch) { if (!queue_0_processing.HasBeenNotified()) { queue_0_processing.Notify(); queue_0_proceed.WaitForNotification(); diff --git a/tensorflow/core/kernels/batchtospace_op.cc b/tensorflow/core/kernels/batchtospace_op.cc index c1c0d6d329206088acaa009b3ffe695661527e44..b07c5fd718daea802a08650f97ccff393914e208 100644 --- a/tensorflow/core/kernels/batchtospace_op.cc +++ b/tensorflow/core/kernels/batchtospace_op.cc @@ -56,9 +56,10 @@ static void BatchToSpaceOpCompute(OpKernelContext* context, errors::InvalidArgument("input rank should be >= ", 1 + block_dims, " instead of ", orig_input_tensor.dims())); - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(orig_crops.shape()) && - block_dims == orig_crops.dim_size(0) && - 2 == orig_crops.dim_size(1), + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(orig_crops.shape()) && + block_dims == orig_crops.dim_size(0) && + 2 == orig_crops.dim_size(1), errors::InvalidArgument("crops should have shape [", block_dims, ", 2] instead of ", orig_crops.shape().DebugString())); diff --git a/tensorflow/core/kernels/bcast_ops.cc b/tensorflow/core/kernels/bcast_ops.cc index 7fc4b1762d0e56271bef586f0f8db0a2a66ff87d..8e4f08e473060b50d387d53aab89c10d0a26b93a 100644 --- a/tensorflow/core/kernels/bcast_ops.cc +++ b/tensorflow/core/kernels/bcast_ops.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/util/bcast.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/bcast.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc index 42f3db1d79d4e0b0406f8c5c9abb423c03f30ab6..754b93b073a36d0925a0339956b8224878b849e1 100644 --- a/tensorflow/core/kernels/bias_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc @@ -77,14 +77,14 @@ void BiasGPU::compute(const GPUDevice& d, const T* input, const T* bias, } CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); if (data_format == FORMAT_NHWC) { - BiasNHWCKernel< - T><<>>( - config.virtual_thread_count, input, bias, output, bias_size); + BiasNHWCKernel + <<>>( + config.virtual_thread_count, input, bias, output, bias_size); } else { - BiasNCHWKernel< - T><<>>( - config.virtual_thread_count, input, bias, output, bias_size, - image_size); + BiasNCHWKernel + <<>>( + config.virtual_thread_count, input, bias, output, bias_size, + image_size); } } @@ -173,19 +173,13 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop, // Accumulate the results in the shared memory into the first element. // No syncthreads is needed since this is only in the same warp. int32 thread_index = threadIdx.x; - if (thread_index < 16) { - s_data[thread_index] += s_data[thread_index + 16]; - __syncwarp(0xFFFF); - if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8]; - __syncwarp(0xFF); - if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4]; - __syncwarp(0xF); - if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2]; - __syncwarp(0x3); + if (thread_index < 32) { + AccT data = s_data[thread_index]; + for (int32 delta = warpSize / 2; delta > 0; delta /= 2) { + data += CudaShuffleXorSync(kCudaWarpAll, data, delta); + } if (thread_index == 0) { - T val = T(s_data[0] + s_data[1]); - // The first thread writes out the accumulated result to global location. - CudaAtomicAdd(bias_backprop + bias_index, val); + CudaAtomicAdd(bias_backprop + bias_index, T(data)); } } } @@ -212,10 +206,10 @@ void BiasGradGPU::compute(const GPUDevice& d, const T* output_backprop, // Check if we have enough shared memory. if (shared_memory_size <= max_shared_memory_size) { if (data_format == FORMAT_NHWC) { - BiasGradNHWC_SharedAtomics< - T><<>>(total_count, output_backprop, bias_backprop, - bias_size); + BiasGradNHWC_SharedAtomics + <<>>(total_count, output_backprop, bias_backprop, + bias_size); } else { // Round up the block count to multiple of bias_size. int group_size = (config.block_count + bias_size - 1) / bias_size; @@ -223,23 +217,24 @@ void BiasGradGPU::compute(const GPUDevice& d, const T* output_backprop, if (config.thread_per_block < kWarpSize) { config.thread_per_block = kWarpSize; } - BiasGradNCHW_SharedAtomics< - T><<>>( - output_backprop, bias_backprop, batch, bias_size, image_size, - group_size); + BiasGradNCHW_SharedAtomics + <<>>( + output_backprop, bias_backprop, batch, bias_size, image_size, + group_size); } } else { // Note that even if we don't have enough shared memory to fit the entire // output block, it is possible to process one group of elements at a time. // But for now, we simply fall back to the naive implementation. if (data_format == FORMAT_NHWC) { - BiasGradNHWC_Naive< - T><<>>( - total_count, output_backprop, bias_backprop, bias_size); + BiasGradNHWC_Naive + <<>>( + total_count, output_backprop, bias_backprop, bias_size); } else { - BiasGradNCHW_Naive< - T><<>>( - total_count, output_backprop, bias_backprop, bias_size, image_size); + BiasGradNCHW_Naive + <<>>( + total_count, output_backprop, bias_backprop, bias_size, + image_size); } } } diff --git a/tensorflow/core/kernels/bounds_check.h b/tensorflow/core/kernels/bounds_check.h index e35f42ad4173348f63445030aef6c6de2b1de9a7..c8c60c55241ab2b1b3a426560959fed7ea893129 100644 --- a/tensorflow/core/kernels/bounds_check.h +++ b/tensorflow/core/kernels/bounds_check.h @@ -48,7 +48,7 @@ EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC const T SubtleMustCopy(const T &x) { auto *to_x = reinterpret_cast(&x); return *to_x; } -} // namespace tensorflow::internal +} // namespace internal } // namespace tensorflow #endif // TENSORFLOW_UTIL_BOUNDS_CHECK_H_ diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc index e937c4f11ba34e16d319b7b4dec317e81b6b8b2c..654d99301af5f528e4360d70edf4cadd4165382d 100644 --- a/tensorflow/core/kernels/candidate_sampler_ops.cc +++ b/tensorflow/core/kernels/candidate_sampler_ops.cc @@ -126,13 +126,13 @@ REGISTER_KERNEL_BUILDER(Name("UniformCandidateSampler").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("LogUniformCandidateSampler").Device(DEVICE_CPU), SimpleCandidateSamplerOp); -REGISTER_KERNEL_BUILDER(Name("LearnedUnigramCandidateSampler") - .Device(DEVICE_CPU), - SimpleCandidateSamplerOp); +REGISTER_KERNEL_BUILDER( + Name("LearnedUnigramCandidateSampler").Device(DEVICE_CPU), + SimpleCandidateSamplerOp); -REGISTER_KERNEL_BUILDER(Name("ThreadUnsafeUnigramCandidateSampler") - .Device(DEVICE_CPU), - SimpleCandidateSamplerOp); +REGISTER_KERNEL_BUILDER( + Name("ThreadUnsafeUnigramCandidateSampler").Device(DEVICE_CPU), + SimpleCandidateSamplerOp); class AllCandidateSamplerOp : public BaseCandidateSamplerOp { public: @@ -197,8 +197,9 @@ class ComputeAccidentalHitsOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& in_true_candidates = context->input(0); const TensorShape& in_true_candidates_shape = in_true_candidates.shape(); - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(in_true_candidates_shape) && - in_true_candidates_shape.dim_size(1) == num_true_, + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(in_true_candidates_shape) && + in_true_candidates_shape.dim_size(1) == num_true_, errors::InvalidArgument( "true_candidates must be a batch_size * num_true matrix")); diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index f16abb2b79fe24bfbe2711de03c7dfd0847b3003..626db9131aee28be13391ff9c1c92bf9f2d35dd0 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -36,7 +36,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #define CURRY_TYPES2(FN, arg0) \ FN(arg0, bool); \ @@ -223,11 +223,11 @@ class SyclCastOp : public CastOpBase { } }; -#define REGISTER_CAST_SYCL(srctype, dsttype) \ - REGISTER_KERNEL_BUILDER(Name("Cast") \ - .TypeConstraint("SrcT") \ - .TypeConstraint("DstT") \ - .Device(DEVICE_SYCL), \ +#define REGISTER_CAST_SYCL(srctype, dsttype) \ + REGISTER_KERNEL_BUILDER(Name("Cast") \ + .TypeConstraint("SrcT") \ + .TypeConstraint("DstT") \ + .Device(DEVICE_SYCL), \ SyclCastOp) CURRY_TYPES2(REGISTER_CAST_SYCL, bool); CURRY_TYPES2(REGISTER_CAST_SYCL, int32); @@ -237,7 +237,7 @@ CURRY_TYPES2(REGISTER_CAST_SYCL, double); #undef REGISTER_CAST_SYCL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #undef CURRY_TYPES2 @@ -250,6 +250,5 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER( Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"), CpuCastOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // end namespace tensorflow - diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h index 8fedf2c271c2caf60a83fb1f4146dd94821c4643..fd4e75d26f02dc75e13c8781049c904587d10afd 100644 --- a/tensorflow/core/kernels/cast_op.h +++ b/tensorflow/core/kernels/cast_op.h @@ -131,7 +131,8 @@ struct scalar_cast_op<::tensorflow::bfloat16, float> { p[0] = a.value; p[1] = 0; #else - static_assert(::tensorflow::port::kLittleEndian, "Not a little endian system!"); + static_assert(::tensorflow::port::kLittleEndian, + "Not a little endian system!"); p[0] = 0; p[1] = a.value; #endif diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h index 470e9e08041e808f7459b3c654d55b82fde629a9..3ae9f2ab4d9c102941927215441b4c02625387f0 100644 --- a/tensorflow/core/kernels/cast_op_impl.h +++ b/tensorflow/core/kernels/cast_op_impl.h @@ -41,25 +41,25 @@ struct CastFunctor { o.device(d) = i.template cast(); } }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace functor -#define CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ - FN(arg0, arg1, bool); \ - FN(arg0, arg1, uint8); \ - FN(arg0, arg1, int8); \ - FN(arg0, arg1, uint16); \ - FN(arg0, arg1, int16); \ - FN(arg0, arg1, int32); \ - FN(arg0, arg1, int64); \ - FN(arg0, arg1, float); \ - FN(arg0, arg1, double); \ - FN(arg0, arg1, std::complex); \ +#define CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ + FN(arg0, arg1, bool); \ + FN(arg0, arg1, uint8); \ + FN(arg0, arg1, int8); \ + FN(arg0, arg1, uint16); \ + FN(arg0, arg1, int16); \ + FN(arg0, arg1, int32); \ + FN(arg0, arg1, int64); \ + FN(arg0, arg1, float); \ + FN(arg0, arg1, double); \ + FN(arg0, arg1, std::complex); \ FN(arg0, arg1, std::complex) -#define CURRY_TYPES3(FN, arg0, arg1) \ - CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ +#define CURRY_TYPES3(FN, arg0, arg1) \ + CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ FN(arg0, arg1, Eigen::half); #define CAST_CASE(DEVICE, IN, OUT) \ diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc index a106f287c1845a108f596b960b65a6392c35b071..057e209a71903ad24e2d4f757e4d2a3bc4357a76 100644 --- a/tensorflow/core/kernels/cast_op_test.cc +++ b/tensorflow/core/kernels/cast_op_test.cc @@ -107,10 +107,10 @@ static void BM_gpu_float_int64(int iters, int num) { testing::UseRealTime(); #if GOOGLE_CUDA test::Benchmark("gpu", Cast(num)).Run(iters); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL test::Benchmark("sycl", Cast(num)).Run(iters); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } BENCHMARK(BM_gpu_float_int64)->Arg(64 << 10)->Arg(32 << 20); @@ -130,10 +130,10 @@ static void BM_gpu_bool_float(int iters, int num) { testing::UseRealTime(); #if GOOGLE_CUDA test::Benchmark("gpu", Cast(num)).Run(iters); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL test::Benchmark("sycl", Cast(num)).Run(iters); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } BENCHMARK(BM_gpu_bool_float)->Arg(64 << 10)->Arg(32 << 20); @@ -180,7 +180,7 @@ static void BM_gpu_float_half(int iters, int num) { testing::UseRealTime(); #if GOOGLE_CUDA test::Benchmark("gpu", Cast(num)).Run(iters); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA } BENCHMARK(BM_gpu_float_half)->Arg(64 << 10)->Arg(32 << 20); @@ -191,7 +191,7 @@ static void BM_gpu_half_float(int iters, int num) { testing::UseRealTime(); #if GOOGLE_CUDA test::Benchmark("gpu", Cast(num)).Run(iters); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA } BENCHMARK(BM_gpu_half_float)->Arg(64 << 10)->Arg(32 << 20); diff --git a/tensorflow/core/kernels/colorspace_op.cc b/tensorflow/core/kernels/colorspace_op.cc index ba100b32e7d8cfcd6a0138a09062910743d6d2eb..f4402a245d6c3848430126b3250731008c954df0 100644 --- a/tensorflow/core/kernels/colorspace_op.cc +++ b/tensorflow/core/kernels/colorspace_op.cc @@ -71,7 +71,7 @@ class RGBToHSVOp : public OpKernel { TensorShape({input_data.dimension(0)}), &trange)); - typename TTypes::Tensor range = trange.tensor(); + typename TTypes::Tensor range(trange.tensor()); functor::RGBToHSV()(context->eigen_device(), input_data, range, output_data); @@ -107,14 +107,14 @@ class HSVToRGBOp : public OpKernel { } }; -#define REGISTER_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - RGBToHSVOp); \ - template class RGBToHSVOp; \ - REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - HSVToRGBOp); \ +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("RGBToHSV").Device(DEVICE_CPU).TypeConstraint("T"), \ + RGBToHSVOp); \ + template class RGBToHSVOp; \ + REGISTER_KERNEL_BUILDER( \ + Name("HSVToRGB").Device(DEVICE_CPU).TypeConstraint("T"), \ + HSVToRGBOp); \ template class HSVToRGBOp; TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); @@ -123,40 +123,39 @@ TF_CALL_double(REGISTER_CPU); // Forward declarations of the function specializations for GPU (to prevent // building the GPU versions here, they will be built compiling _gpu.cu.cc). namespace functor { -#define DECLARE_GPU(T) \ - template <> \ - void RGBToHSV::operator()(const GPUDevice& d, \ - TTypes::ConstTensor input_data, \ - TTypes::Tensor range, \ - TTypes::Tensor output_data); \ - extern template struct RGBToHSV; \ - template <> \ - void HSVToRGB::operator()(const GPUDevice& d, \ - TTypes::ConstTensor input_data, \ - TTypes::Tensor output_data); \ +#define DECLARE_GPU(T) \ + template <> \ + void RGBToHSV::operator()( \ + const GPUDevice& d, TTypes::ConstTensor input_data, \ + TTypes::Tensor range, TTypes::Tensor output_data); \ + extern template struct RGBToHSV; \ + template <> \ + void HSVToRGB::operator()( \ + const GPUDevice& d, TTypes::ConstTensor input_data, \ + TTypes::Tensor output_data); \ extern template struct HSVToRGB; TF_CALL_float(DECLARE_GPU); TF_CALL_double(DECLARE_GPU); } // namespace functor -#define REGISTER_GPU(T) \ - REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_GPU) \ - .TypeConstraint("T"), \ - RGBToHSVOp); \ - REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_GPU) \ - .TypeConstraint("T"), \ - HSVToRGBOp); +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("RGBToHSV").Device(DEVICE_GPU).TypeConstraint("T"), \ + RGBToHSVOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("HSVToRGB").Device(DEVICE_GPU).TypeConstraint("T"), \ + HSVToRGBOp); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #endif #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL(T) \ - REGISTER_KERNEL_BUILDER(Name("RGBToHSV").Device(DEVICE_SYCL) \ - .TypeConstraint("T"), \ - RGBToHSVOp); \ - REGISTER_KERNEL_BUILDER(Name("HSVToRGB").Device(DEVICE_SYCL) \ - .TypeConstraint("T"), \ - HSVToRGBOp); +#define REGISTER_SYCL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("RGBToHSV").Device(DEVICE_SYCL).TypeConstraint("T"), \ + RGBToHSVOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("HSVToRGB").Device(DEVICE_SYCL).TypeConstraint("T"), \ + HSVToRGBOp); TF_CALL_float(REGISTER_SYCL); TF_CALL_double(REGISTER_SYCL); #endif diff --git a/tensorflow/core/kernels/colorspace_op.h b/tensorflow/core/kernels/colorspace_op.h index c5721ef6dd067e9df0b1c23ac471667edee06fb3..90bfce14194bb04a3ebe8418fcc4d1beaab4fc2b 100644 --- a/tensorflow/core/kernels/colorspace_op.h +++ b/tensorflow/core/kernels/colorspace_op.h @@ -54,10 +54,9 @@ struct RGBToHSV { // TODO(wicke): all these assignments are only necessary because a combined // expression is larger than kernel parameter space. A custom kernel is // probably in order. - H.device(d) = (R == V).select(norm * (G - B), - (G == V).select( - norm * (B - R) + T(2) / T(6), - norm * (R - G) + T(4) / T(6))); + H.device(d) = (R == V).select( + norm * (G - B), (G == V).select(norm * (B - R) + T(2) / T(6), + norm * (R - G) + T(4) / T(6))); H.device(d) = (range > T(0)).select(H, H.constant(T(0))); H.device(d) = (H < T(0)).select(H + T(1), H); } diff --git a/tensorflow/core/kernels/colorspace_op_gpu.cu.cc b/tensorflow/core/kernels/colorspace_op_gpu.cu.cc index e19d0b14d5df5c125c3fb071ea6ae6580fba8c6a..61f9ba44c46f1cee87a72349f8e4ebdd6d2e750f 100644 --- a/tensorflow/core/kernels/colorspace_op_gpu.cu.cc +++ b/tensorflow/core/kernels/colorspace_op_gpu.cu.cc @@ -17,8 +17,8 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/colorspace_op.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/colorspace_op.h" namespace tensorflow { @@ -29,6 +29,6 @@ typedef Eigen::GpuDevice GPUDevice; template class functor::HSVToRGB; TF_CALL_float(INSTANTIATE_GPU); TF_CALL_double(INSTANTIATE_GPU); -} +} // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/colorspace_op_test.cc b/tensorflow/core/kernels/colorspace_op_test.cc index 8c6fb732abf40c52c0a9e9a5c338de859c669838..bd82826770f192acd50ca4212a475881fe5c34fc 100644 --- a/tensorflow/core/kernels/colorspace_op_test.cc +++ b/tensorflow/core/kernels/colorspace_op_test.cc @@ -224,34 +224,34 @@ class HSVToRGBOpTest : public OpsTestBase { } }; -#define TEST_COLORSPACE(test, dt) \ - TEST_F(test, CheckBlack) { \ - MakeOp(dt); \ - CheckBlack(dt); \ - } \ - TEST_F(test, CheckGray) { \ - MakeOp(dt); \ - CheckGray(dt); \ - } \ - TEST_F(test, CheckWhite) { \ - MakeOp(dt); \ - CheckWhite(dt); \ - } \ - TEST_F(test, CheckRedMax) { \ - MakeOp(dt); \ - CheckRedMax(dt); \ - } \ - TEST_F(test, CheckGreenMax) { \ - MakeOp(dt); \ - CheckGreenMax(dt); \ - } \ - TEST_F(test, CheckBlueMax) { \ - MakeOp(dt); \ - CheckBlueMax(dt); \ - } \ - TEST_F(test, CheckNegativeDifference) { \ - MakeOp(dt); \ - CheckNegativeDifference(dt); \ +#define TEST_COLORSPACE(test, dt) \ + TEST_F(test, CheckBlack) { \ + MakeOp(dt); \ + CheckBlack(dt); \ + } \ + TEST_F(test, CheckGray) { \ + MakeOp(dt); \ + CheckGray(dt); \ + } \ + TEST_F(test, CheckWhite) { \ + MakeOp(dt); \ + CheckWhite(dt); \ + } \ + TEST_F(test, CheckRedMax) { \ + MakeOp(dt); \ + CheckRedMax(dt); \ + } \ + TEST_F(test, CheckGreenMax) { \ + MakeOp(dt); \ + CheckGreenMax(dt); \ + } \ + TEST_F(test, CheckBlueMax) { \ + MakeOp(dt); \ + CheckBlueMax(dt); \ + } \ + TEST_F(test, CheckNegativeDifference) { \ + MakeOp(dt); \ + CheckNegativeDifference(dt); \ } typedef RGBToHSVOpTest rgb_to_hsv_float; diff --git a/tensorflow/core/kernels/compare_and_bitpack_op.cc b/tensorflow/core/kernels/compare_and_bitpack_op.cc index 9f626a274a4d36b568cc6e25af2e572a35ae3694..224fe534e3392f29e4fab2caa640883d055cb341 100644 --- a/tensorflow/core/kernels/compare_and_bitpack_op.cc +++ b/tensorflow/core/kernels/compare_and_bitpack_op.cc @@ -110,7 +110,19 @@ struct ComputeShard::ConstMatrix input, typename TTypes::Matrix output, bool /*thresh*/, int64 start, int64 limit) { - // NOTE(ebrevdo): This assumes memory is little-endian. +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + for (int64 i = start; i < limit; ++i) { + uint8* out = output.data() + i; + const int64 block = *reinterpret_cast(input.data() + 8 * i); + *out = ((((block & (1LL << (7 * 8))) >> (7 * 8 - 7))) | + (((block & (1LL << (6 * 8))) >> (6 * 8 - 6))) | + (((block & (1LL << (5 * 8))) >> (5 * 8 - 5))) | + (((block & (1LL << (4 * 8))) >> (4 * 8 - 4))) | + (((block & (1LL << (3 * 8))) >> (3 * 8 - 3))) | + (((block & (1LL << (2 * 8))) >> (2 * 8 - 2))) | + (((block & (1LL << 8)) >> (1 * 8 - 1))) | (((block & (1LL))))); + } +#else for (int64 i = start; i < limit; ++i) { uint8* out = output.data() + i; const int64 block = *reinterpret_cast(input.data() + 8 * i); @@ -123,6 +135,7 @@ struct ComputeShard> (2 * 8 - 5))) | (((block & (1LL << 8)) >> (1 * 8 - 6))) | (((block & (1LL)) << 7))); } +#endif } }; diff --git a/tensorflow/core/kernels/concat_lib.h b/tensorflow/core/kernels/concat_lib.h index 526f9420d72fa25ff21bf60b3594649fa1faa0ba..16784c4770eb8626c11dc47104fea3af6c5edc07 100644 --- a/tensorflow/core/kernels/concat_lib.h +++ b/tensorflow/core/kernels/concat_lib.h @@ -41,10 +41,11 @@ namespace tensorflow { // Assumes all inputs are nonempty template -void ConcatCPU(DeviceBase* d, - const std::vector< - std::unique_ptr::ConstMatrix>>& inputs, - typename TTypes::Matrix* output); +void ConcatCPU( + DeviceBase* d, + const std::vector::ConstMatrix>>& + inputs, + typename TTypes::Matrix* output); #if GOOGLE_CUDA template void ConcatGPU( @@ -57,11 +58,12 @@ void ConcatGPU( #ifdef TENSORFLOW_USE_SYCL template -void ConcatSYCL(const Eigen::SyclDevice& d, - const std::vector< - std::unique_ptr::ConstMatrix>>& inputs, - typename TTypes::Matrix* output); -#endif // TENSORFLOW_USE_SYCL +void ConcatSYCL( + const Eigen::SyclDevice& d, + const std::vector::ConstMatrix>>& + inputs, + typename TTypes::Matrix* output); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow #endif // TENSORFLOW_KERNELS_CONCAT_LIB_H_ diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc index 43731114c0b9a87598da19466c0fd9c7e05644bb..547a7b40b9245d4b10c12830a0189b09c9dacc76 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.cc +++ b/tensorflow/core/kernels/concat_lib_cpu.cc @@ -48,10 +48,11 @@ struct MemCpyCopier { } // namespace template -void ConcatCPU(DeviceBase* d, - const std::vector< - std::unique_ptr::ConstMatrix>>& inputs, - typename TTypes::Matrix* output) { +void ConcatCPU( + DeviceBase* d, + const std::vector::ConstMatrix>>& + inputs, + typename TTypes::Matrix* output) { if (std::is_same::value) { // use a large cost here to force strings to be handled by separate threads ConcatCPUImpl(d, inputs, 100000, MemCpyCopier(), output); @@ -72,7 +73,6 @@ REGISTER(qint8) REGISTER(quint16) REGISTER(qint16) REGISTER(qint32) -TF_CALL_variant(REGISTER) #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \ !defined(__ANDROID_TYPES_FULL__) @@ -86,21 +86,22 @@ TF_CALL_variant(REGISTER) #ifdef TENSORFLOW_USE_SYCL template -void ConcatSYCL(const Eigen::SyclDevice& d, - const std::vector< - std::unique_ptr::ConstMatrix>>& inputs, - typename TTypes::Matrix* output) { +void ConcatSYCL( + const Eigen::SyclDevice& d, + const std::vector::ConstMatrix>>& + inputs, + typename TTypes::Matrix* output) { ConcatSYCLImpl(d, inputs, sizeof(T) /* cost_per_unit */, MemCpyCopier(), - output); + output); } -#define REGISTER_SYCL(T) \ - template void ConcatSYCL( \ - const Eigen::SyclDevice&, \ - const std::vector::ConstMatrix>>&, \ - typename TTypes::Matrix* output); +#define REGISTER_SYCL(T) \ + template void ConcatSYCL( \ + const Eigen::SyclDevice&, \ + const std::vector::ConstMatrix>>&, \ + typename TTypes::Matrix* output); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL) #undef REGISTER_SYCL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/concat_lib_cpu.h b/tensorflow/core/kernels/concat_lib_cpu.h index 6a933efde4b6ababf35c83c94d233e4aa2552d84..720b5065377b49859fdecc2634d14fe308432fe3 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.h +++ b/tensorflow/core/kernels/concat_lib_cpu.h @@ -15,9 +15,9 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "tensorflow/core/kernels/concat_lib.h" #include #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { @@ -73,7 +73,7 @@ void ConcatCPUImpl( // Sharded mode. auto work = [&row_size, &sizes, &inputs, &output, &copier, &num_inputs]( - int64 start, int64 end) { + int64 start, int64 end) { int64 skipped_rows = start / row_size; T* out = output->data() + skipped_rows * row_size; T* out_start = output->data() + start; @@ -160,5 +160,5 @@ void ConcatSYCLImpl( } } } -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc index ae1b5da32ea12d94a01ae67563f03dda42d6ead4..7011550f7e161c9727b8d31eff0917964b09044e 100644 --- a/tensorflow/core/kernels/concat_op.cc +++ b/tensorflow/core/kernels/concat_op.cc @@ -37,7 +37,7 @@ typedef Eigen::GpuDevice GPUDevice; #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM }; @@ -71,8 +71,9 @@ class ConcatBaseOp : public OpKernel { const TensorShape& input_shape = values[0].shape(); int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; - OP_REQUIRES(c, (0 <= axis && axis < input_dims) || - (allow_legacy_scalars() && concat_dim == 0), + OP_REQUIRES(c, + (0 <= axis && axis < input_dims) || + (allow_legacy_scalars() && concat_dim == 0), errors::InvalidArgument( "ConcatOp : Expected concatenating dimensions in the range " "[", @@ -97,8 +98,8 @@ class ConcatBaseOp : public OpKernel { c, in.dims() == input_dims || (input_is_scalar && in_is_scalar), errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, "] = ", - in.shape().DebugString())); + input_shape.DebugString(), " vs. shape[", i, + "] = ", in.shape().DebugString())); for (int j = 0; j < input_dims; ++j) { if (j == axis) { continue; @@ -107,8 +108,8 @@ class ConcatBaseOp : public OpKernel { c, in.dim_size(j) == input_shape.dim_size(j), errors::InvalidArgument( "ConcatOp : Dimensions of inputs should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, "] = ", - in.shape().DebugString())); + input_shape.DebugString(), " vs. shape[", i, + "] = ", in.shape().DebugString())); } if (in.NumElements() > 0) { int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0; @@ -142,7 +143,7 @@ class ConcatBaseOp : public OpKernel { ConcatSYCL(c->eigen_sycl_device(), inputs_flat, &output_flat); return; } -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL ConcatCPU(c->device(), inputs_flat, &output_flat); } } @@ -252,7 +253,7 @@ REGISTER_KERNEL_BUILDER(Name("ConcatV2") ConcatV2Op); #undef REGISTER_SYCL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL class ConcatOffsetOp : public OpKernel { public: @@ -347,5 +348,5 @@ REGISTER_KERNEL_BUILDER(Name("ConcatOffset") .HostMemory("shape") .HostMemory("offset"), ConcatOffsetOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/concat_op_test.cc b/tensorflow/core/kernels/concat_op_test.cc index c5bded9dafcdaf7264649e00db2ea2766db8eea9..e3ba8ae9f691c8ec9be79952d7f97801552b2a56 100644 --- a/tensorflow/core/kernels/concat_op_test.cc +++ b/tensorflow/core/kernels/concat_op_test.cc @@ -157,7 +157,8 @@ BENCHMARK(BM_MemcpyAlternativeDim0)->Arg(1000)->Arg(100000)->Arg(1000000); BENCHMARK(BM_MemcpyAlternativeDim1)->Arg(1000)->Arg(100000)->Arg(1000000); typedef Eigen::TensorMap, - Eigen::Unaligned> EigenMap; + Eigen::Unaligned> + EigenMap; static void MemcpyManyAlternative1(int iters, int dim2) { testing::StopTiming(); diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h index 794ac6fa6de1eb06fcfa614bbfa472814d630d99..c7c7c983691c6f5257622940d183d06304ee74f1 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.h +++ b/tensorflow/core/kernels/conditional_accumulator_base.h @@ -160,7 +160,7 @@ class ConditionalAccumulatorBase : public ResourceBase { * Modifications to convenience macros defined in core/framework/op_kernel.h. * The below macros return a boolean if the test fails, so that the calling * function can get an indication that a failure has occurred. -*/ + */ #define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS) \ do { \ if (!TF_PREDICT_TRUE(EXP)) { \ diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc index fa37916eaba4106fe8067b739e77e7f91631b1e9..e13bf8a4c63ebe86fbf3fcf2fdd50f928298d01b 100644 --- a/tensorflow/core/kernels/conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/conditional_accumulator_op.cc @@ -99,9 +99,10 @@ class AccumulatorTakeGradientOp ConditionalAccumulatorBase* accumulator, DoneCallback callback) override { // Check signature - OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32}, - {accumulator->dtype()}), - callback); + OP_REQUIRES_OK_ASYNC( + ctx, + ctx->MatchSignature({DT_STRING_REF, DT_INT32}, {accumulator->dtype()}), + callback); } private: @@ -111,5 +112,4 @@ class AccumulatorTakeGradientOp REGISTER_KERNEL_BUILDER(Name("AccumulatorTakeGradient").Device(DEVICE_CPU), AccumulatorTakeGradientOp); - } // namespace tensorflow diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 59f9f69315e1a1b8740ee787fa93df686dfa01d8..fdb03a5aae8bd9dbe180e9b722fc5847c740801b 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/kernels/constant_op.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -41,8 +42,33 @@ limitations under the License. namespace tensorflow { +namespace { + +std::unique_ptr StripTensorDataFromNodeDef( + OpKernelConstruction* ctx) { +#ifndef __ANDROID__ + DCHECK_EQ(NodeDef::descriptor()->field_count(), 5) + << "The NodeDef format has changed, and the attr-stripping code may need " + << "to be updated."; +#endif + const NodeDef& original = ctx->def(); + NodeDef* ret = new NodeDef; + ret->set_name(original.name()); + ret->set_op(original.op()); + ret->set_device(original.device()); + // Strip the "value" attr from the returned NodeDef. + // NOTE(mrry): The present implementation of `OpKernel::OpKernel()` only uses + // attrs that affect the cardinality of list-typed inputs and outputs, so it + // is safe to drop other attrs from the NodeDef. + AddNodeAttr("dtype", ctx->output_type(0), ret); + return std::unique_ptr(ret); +} + +} // namespace + ConstantOp::ConstantOp(OpKernelConstruction* ctx) - : OpKernel(ctx), tensor_(ctx->output_type(0)) { + : OpKernel(ctx, StripTensorDataFromNodeDef(ctx)), + tensor_(ctx->output_type(0)) { const TensorProto* proto = nullptr; OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto( @@ -76,6 +102,7 @@ REGISTER_KERNEL(GPU, float); REGISTER_KERNEL(GPU, double); REGISTER_KERNEL(GPU, uint8); REGISTER_KERNEL(GPU, int8); +REGISTER_KERNEL(GPU, qint8); REGISTER_KERNEL(GPU, uint16); REGISTER_KERNEL(GPU, int16); REGISTER_KERNEL(GPU, int64); @@ -146,7 +173,6 @@ typedef Eigen::GpuDevice GPUDevice; typedef Eigen::SyclDevice SYCLDevice; #endif // TENSORFLOW_USE_SYCL - template class FillOp : public OpKernel { public: diff --git a/tensorflow/core/kernels/constant_op_test.cc b/tensorflow/core/kernels/constant_op_test.cc index 7a05d9371d8c19e2cfe943f7a44e458c8baf634a..a6baae73d876d511f1e8d81792fe4cecea160bfd 100644 --- a/tensorflow/core/kernels/constant_op_test.cc +++ b/tensorflow/core/kernels/constant_op_test.cc @@ -77,7 +77,7 @@ void ConstantOpTest::PersistentMemoryTrackingTest(bool on_gpu) { EXPECT_EQ(ctx.persistent_memory_allocated(), 480); } - // Remove memry leak errors. + // Remove memory leak errors. for (auto allocator_pair : ctx.wrapped_allocators()) { allocator_pair.second->GetRecordsAndUnRef(); } diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 8fe82d118a702ec6809d6f4f4385fa3dc0949037..7d5d54e5bece7d448e7c11c6061109e9e8554008 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -113,47 +113,47 @@ REGISTER_GPU_HOST_REF_KERNEL(string); #undef REGISTER_GPU_HOST_REF_KERNEL #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_SWITCH(type) \ - REGISTER_KERNEL_BUILDER(Name("Switch") \ - .Device(DEVICE_SYCL) \ - .HostMemory("pred") \ - .TypeConstraint("T"),\ +#define REGISTER_SYCL_SWITCH(type) \ + REGISTER_KERNEL_BUILDER(Name("Switch") \ + .Device(DEVICE_SYCL) \ + .HostMemory("pred") \ + .TypeConstraint("T"), \ SwitchOp) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH); -#define REGISTER_SYCL_REF_SWITCH(type) \ - REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ - .Device(DEVICE_SYCL) \ - .HostMemory("pred") \ - .TypeConstraint("T"), \ +#define REGISTER_SYCL_REF_SWITCH(type) \ + REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ + .Device(DEVICE_SYCL) \ + .HostMemory("pred") \ + .TypeConstraint("T"), \ SwitchOp) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH); #undef REGISTER_SYCL_SWITCH #undef REGISTER_SYCL_REF_SWITCH -#define REGISTER_SYCL_HOST_KERNEL(type) \ - REGISTER_KERNEL_BUILDER(Name("Switch") \ - .Device(DEVICE_SYCL) \ - .HostMemory("data") \ - .HostMemory("pred") \ - .HostMemory("output_false")\ - .HostMemory("output_true") \ - .TypeConstraint("T"),\ +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Switch") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("pred") \ + .HostMemory("output_false") \ + .HostMemory("output_true") \ + .TypeConstraint("T"), \ SwitchOp) REGISTER_SYCL_HOST_KERNEL(bool); REGISTER_SYCL_HOST_KERNEL(string); REGISTER_SYCL_HOST_KERNEL(int32); -#define REGISTER_SYCL_HOST_REF_KERNEL(type) \ - REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ - .Device(DEVICE_SYCL) \ - .HostMemory("data") \ - .HostMemory("pred") \ - .HostMemory("output_false") \ - .HostMemory("output_true") \ - .TypeConstraint("T"), \ +#define REGISTER_SYCL_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefSwitch") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("pred") \ + .HostMemory("output_false") \ + .HostMemory("output_true") \ + .TypeConstraint("T"), \ SwitchOp) REGISTER_SYCL_HOST_REF_KERNEL(int32); @@ -162,7 +162,7 @@ REGISTER_SYCL_HOST_REF_KERNEL(string); #undef REGISTER_SYCL_HOST_KERNEL #undef REGISTER_SYCL_HOST_REF_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL class RefSelectOp : public OpKernel { public: @@ -282,7 +282,7 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); #undef REGISTER_SYCL_KERNEL #undef REGISTER_SYCL_REF_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel @@ -331,7 +331,7 @@ REGISTER_SYCL_HOST_KERNEL(string); REGISTER_SYCL_HOST_KERNEL(ResourceHandle); #undef REGISTER_SYCL_HOST_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL void EnterOp::Compute(OpKernelContext* context) { if (IsRefType(context->input_dtype(0))) { @@ -360,14 +360,14 @@ REGISTER_GPU_REF_KERNEL(bool); #undef REGISTER_GPU_REF_KERNEL #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ Name("Enter").Device(DEVICE_SYCL).TypeConstraint("T"), EnterOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); -#define REGISTER_SYCL_REF_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint("T"), EnterOp) REGISTER_SYCL_REF_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); @@ -398,7 +398,7 @@ REGISTER_SYCL_HOST_KERNEL(ResourceHandle); #undef REGISTER_SYCL_HOST_KERNEL #undef REGISTER_SYCL_HOST_REF_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel @@ -455,10 +455,10 @@ REGISTER_GPU_REF_KERNEL(bool); #undef REGISTER_GPU_REF_KERNEL #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Exit").Device(DEVICE_SYCL).TypeConstraint("T"), ExitOp); \ - REGISTER_KERNEL_BUILDER( \ +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Exit").Device(DEVICE_SYCL).TypeConstraint("T"), ExitOp); \ + REGISTER_KERNEL_BUILDER( \ Name("RefExit").Device(DEVICE_SYCL).TypeConstraint("T"), ExitOp); REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); @@ -483,7 +483,7 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); REGISTER_SYCL_HOST_KERNEL(int32); REGISTER_SYCL_HOST_KERNEL(string); #undef REGISTER_SYCL_HOST_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Special GPU kernels for int32 and string. // TODO(b/25387198): Also enable int32 in device memory. This kernel @@ -556,12 +556,12 @@ REGISTER_GPU_HOST_KERNEL(string); #undef REGISTER_GPU_HOST_KERNEL #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint("T"), \ - NextIterationOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("RefNextIteration").Device(DEVICE_SYCL).TypeConstraint("T"),\ +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint("T"), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("RefNextIteration").Device(DEVICE_SYCL).TypeConstraint("T"), \ NextIterationOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); @@ -585,7 +585,7 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); REGISTER_SYCL_HOST_KERNEL(int32); REGISTER_SYCL_HOST_KERNEL(string); #undef REGISTER_SYCL_HOST_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // A LoopCond op has one input and one output. The input is a boolean // scalar representing the taken branches of the "pivot" Switch that @@ -619,7 +619,7 @@ REGISTER_KERNEL_BUILDER(Name("LoopCond") .HostMemory("input") .HostMemory("output"), LoopCondOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // ControlTrigger kernels REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU), @@ -631,7 +631,7 @@ REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU), #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_SYCL), ControlTriggerOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // When called, abort op will abort the current process. This can be used to // abort remote PSs when needed. diff --git a/tensorflow/core/kernels/control_flow_ops_test.cc b/tensorflow/core/kernels/control_flow_ops_test.cc index affa0e8ca6b9d053702f8b203321d6ee2954878e..a2f7bd406929ec516d67dfc76767532cf2bac28c 100644 --- a/tensorflow/core/kernels/control_flow_ops_test.cc +++ b/tensorflow/core/kernels/control_flow_ops_test.cc @@ -91,6 +91,7 @@ class KilledBySignal { public: explicit KilledBySignal(int signum) : signum_(signum) {} bool operator()(int exit_status) const { return exit_status == signum_; } + private: const int signum_; }; diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h index 2142207b0d89a4b2f02c7f7b5d320c3b4b48462c..6949e5b5fd85f399473095f26314e9d58fa65464 100644 --- a/tensorflow/core/kernels/conv_2d.h +++ b/tensorflow/core/kernels/conv_2d.h @@ -54,10 +54,12 @@ struct InflatePadAndShuffle { template void SpatialConvolutionFunc(const Device& d, Output output, Input input, Filter filter, int row_stride, int col_stride, + int row_dilation, int col_dilation, const Eigen::PaddingType& padding) { // Need to swap row/col when calling Eigen. output.device(d) = - Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding); + Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding, + col_dilation, row_dilation); } template @@ -65,9 +67,10 @@ struct SpatialConvolution { void operator()(const Device& d, typename TTypes::Tensor output, typename TTypes::ConstTensor input, typename TTypes::ConstTensor filter, int row_stride, - int col_stride, const Eigen::PaddingType& padding) { + int col_stride, int row_dilation, int col_dilation, + const Eigen::PaddingType& padding) { SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride, - padding); + row_dilation, col_dilation, padding); } }; @@ -77,11 +80,12 @@ struct SpatialConvolution { typename TTypes::Tensor output, typename TTypes::ConstTensor input, typename TTypes::ConstTensor filter, - int row_stride, int col_stride, - const Eigen::PaddingType& padding) { + int row_stride, int col_stride, int row_dilation, + int col_dilation, const Eigen::PaddingType& padding) { output.device(d) = Eigen::SpatialConvolution(input.cast(), filter.cast(), - col_stride, row_stride, padding) + col_stride, row_stride, padding, col_dilation, + row_dilation) .cast(); } }; @@ -91,11 +95,13 @@ struct SpatialConvolutionBackwardInput { void operator()(const Device& d, typename TTypes::Tensor input_backward, typename TTypes::ConstTensor kernel, typename TTypes::ConstTensor output_backward, - int row_stride, int col_stride) { + int row_stride, int col_stride, int row_dilation, + int col_dilation) { // Need to swap row/col when calling Eigen. input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput( kernel, output_backward, input_backward.dimension(2), - input_backward.dimension(1), col_stride, row_stride); + input_backward.dimension(1), col_stride, row_stride, col_dilation, + row_dilation); } }; @@ -105,11 +111,13 @@ struct SpatialConvolutionBackwardFilter { typename TTypes::Tensor kernel_backward, typename TTypes::ConstTensor input, typename TTypes::ConstTensor output_backward, - int row_stride, int col_stride) { + int row_stride, int col_stride, int row_dilation, + int col_dilation) { // Need to swap row/col when calling Eigen. kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel( input, output_backward, kernel_backward.dimension(1), - kernel_backward.dimension(0), col_stride, row_stride); + kernel_backward.dimension(0), col_stride, row_stride, col_dilation, + row_dilation); } }; diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 512bcc6c01bf3eb4aed92f90eebb060abda8a7fc..b8a5ae6a08e5c22fb5d69112b216b3c342b1bb1a 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -101,7 +101,8 @@ struct LaunchConv2DBackpropFilterOp { const CPUDevice& d = ctx->eigen_device(); functor::SpatialConvolutionBackwardFilter()( d, filter_backprop->tensor(), input.tensor(), - out_backprop.tensor(), row_stride, col_stride); + out_backprop.tensor(), row_stride, col_stride, + /*row_dilation=*/1, /*col_dilation=*/1); } }; diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index 0356ff4c0f4240ec806d1e337546cfce6771d92f..b87c7899c00ab79c60bdbd85ce28399d103d271d 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -106,7 +106,8 @@ struct LaunchConv2DBackpropInputOp { const CPUDevice& d = ctx->eigen_device(); functor::SpatialConvolutionBackwardInput()( d, in_backprop->tensor(), filter.tensor(), - out_backprop.tensor(), row_stride, col_stride); + out_backprop.tensor(), row_stride, col_stride, + /*row_dilation=*/1, /*col_dilation=*/1); } }; diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 985586d6262b18e89b5fc5246cc00b10ba4924a7..2b81e14f95b1b3f4c04e02d50180a9adda9e51e0 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -60,8 +60,8 @@ template struct LaunchGeneric { void operator()(OpKernelContext* ctx, const Tensor& input, const Tensor& filter, int row_stride, int col_stride, - const Padding& padding, Tensor* output, - TensorFormat data_format) { + int row_dilation, int col_dilation, const Padding& padding, + Tensor* output, TensorFormat data_format) { CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only " "supports NHWC tensor format for now."; if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 && @@ -86,7 +86,8 @@ struct LaunchGeneric { filter.shaped({filter.dim_size(2), filter.dim_size(3)}), dim_pair); } else if (filter.dim_size(0) == input.dim_size(1) && - filter.dim_size(1) == input.dim_size(2) && padding == VALID) { + filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 && + col_dilation == 1 && padding == VALID) { // If the input data and filter have the same height/width, // the 2D convolution is reduced to matrix multiplication. const int k = // Length of reduction dimension. @@ -103,7 +104,7 @@ struct LaunchGeneric { functor::SpatialConvolution()( ctx->eigen_device(), output->tensor(), input.tensor(), filter.tensor(), row_stride, col_stride, - BrainPadding2EigenPadding(padding)); + row_dilation, col_dilation, BrainPadding2EigenPadding(padding)); } } }; @@ -122,15 +123,9 @@ struct LaunchConv2DOp { "NHWC tensor format for now.")); return; } - // TODO(yangzihao): Add the CPU implementation of dilated conv 2D. - if (row_dilation > 1 || col_dilation > 1) { - ctx->SetStatus( - errors::Unimplemented("Generic conv implementation only supports " - "dilated rate of 1 for now.")); - return; - } LaunchGeneric()(ctx, input, filter, row_stride, col_stride, - padding, output, data_format); + row_dilation, col_dilation, padding, output, + data_format); } }; @@ -688,7 +683,7 @@ void LaunchConv2DOp::operator()( static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( // default value is in bytes despite the name of the environment variable "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB - ); + ); int device_id = stream->parent()->device_ordinal(); DataType dtype = input.dtype(); @@ -792,7 +787,8 @@ namespace functor { const GPUDevice& d, typename TTypes::Tensor output, \ typename TTypes::ConstTensor input, \ typename TTypes::ConstTensor filter, int row_stride, \ - int col_stride, const Eigen::PaddingType& padding); \ + int col_stride, int row_dilation, int col_dilation, \ + const Eigen::PaddingType& padding); \ extern template struct SpatialConvolution; \ template <> \ void MatMulConvFunctor::operator()( \ diff --git a/tensorflow/core/kernels/conv_ops_fused.cc b/tensorflow/core/kernels/conv_ops_fused.cc index 291ebf2298762d25e2d44aa5b82ffd495ea92c0e..1b40ad81f413a726d14c5496f669923ab9254dce 100644 --- a/tensorflow/core/kernels/conv_ops_fused.cc +++ b/tensorflow/core/kernels/conv_ops_fused.cc @@ -679,8 +679,9 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { const int dims = resized_shape.dims(); OP_REQUIRES( - context, TensorShapeUtils::IsMatrix(paddings.shape()) && - paddings.dim_size(1) == 2, + context, + TensorShapeUtils::IsMatrix(paddings.shape()) && + paddings.dim_size(1) == 2, errors::InvalidArgument("paddings must be a matrix with 2 columns: ", paddings.shape().DebugString())); const int fixed_dims = @@ -715,20 +716,22 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { const int32 after = paddings_matrix(d, 1); // Pad after existing elements. OP_REQUIRES(context, before >= 0 && after >= 0, - errors::InvalidArgument("paddings must be non-negative: ", - before, " ", after)); + errors::InvalidArgument( + "paddings must be non-negative: ", before, " ", after)); if (offset_ == 0) { // SYMMETRIC mode. OP_REQUIRES( - context, before <= resized_shape.dim_size(d) && - after <= resized_shape.dim_size(d), + context, + before <= resized_shape.dim_size(d) && + after <= resized_shape.dim_size(d), errors::InvalidArgument("paddings must be no greater " "than the dimension size: ", before, ", ", after, " greater than ", resized_shape.dim_size(d))); } else if (offset_ == 1) { // REFLECT mode. OP_REQUIRES( - context, before < resized_shape.dim_size(d) && - after < resized_shape.dim_size(d), + context, + before < resized_shape.dim_size(d) && + after < resized_shape.dim_size(d), errors::InvalidArgument("paddings must be less than" " the dimension size: ", before, ", ", after, " not less than ", @@ -767,18 +770,19 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { // We only check the first three dims, since the depth is accessed as an // int64 below. for (int i = 0; i < 3; i++) { - OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i), - std::numeric_limits::max()), - errors::InvalidArgument("filter too large")); + OP_REQUIRES( + context, + FastBoundsCheck(filter.dim_size(i), std::numeric_limits::max()), + errors::InvalidArgument("filter too large")); } // The last dimension for input is in_depth. It must be the same as the // filter's in_depth. const int64 in_depth = padded_shape.dim_size(3); - OP_REQUIRES( - context, in_depth == filter.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - in_depth, " vs ", filter.dim_size(2))); + OP_REQUIRES(context, in_depth == filter.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, + " vs ", filter.dim_size(2))); // The last dimension for filter is out_depth. const int out_depth = static_cast(filter.dim_size(3)); @@ -786,9 +790,10 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { // The second dimension for input is rows/height. // The first dimension for filter is rows/height. const int64 padded_rows_raw = padded_shape.dim_size(1); - OP_REQUIRES(context, FastBoundsCheck(padded_rows_raw, - std::numeric_limits::max()), - errors::InvalidArgument("Input rows too large")); + OP_REQUIRES( + context, + FastBoundsCheck(padded_rows_raw, std::numeric_limits::max()), + errors::InvalidArgument("Input rows too large")); const int padded_rows = static_cast(padded_rows_raw); const int filter_rows = static_cast(filter.dim_size(0)); const int resized_rows = static_cast(resized_shape.dim_size(1)); @@ -796,9 +801,10 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { // The third dimension for input is columns/width. // The second dimension for filter is columns/width. const int64 padded_cols_raw = padded_shape.dim_size(2); - OP_REQUIRES(context, FastBoundsCheck(padded_cols_raw, - std::numeric_limits::max()), - errors::InvalidArgument("Input cols too large")); + OP_REQUIRES( + context, + FastBoundsCheck(padded_cols_raw, std::numeric_limits::max()), + errors::InvalidArgument("Input cols too large")); const int padded_cols = static_cast(padded_cols_raw); const int filter_cols = static_cast(filter.dim_size(1)); const int resized_cols = static_cast(resized_shape.dim_size(2)); @@ -864,24 +870,26 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp); }; -#define REGISTER_FUSED(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("FusedResizeAndPadConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - FusedResizeConv2DUsingGemmOp< \ - T, FusedResizeAndPadConvFunctor, \ - BILINEAR>, \ +#define REGISTER_FUSED(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("FusedResizeAndPadConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + FusedResizeConv2DUsingGemmOp< \ + T, \ + FusedResizeAndPadConvFunctor, \ + BILINEAR>, \ true>); TF_CALL_float(REGISTER_FUSED); -#define REGISTER_PAD_ONLY_FUSED(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint("T"), \ - FusedResizeConv2DUsingGemmOp< \ - T, FusedResizeAndPadConvFunctor, \ - NEAREST>, \ +#define REGISTER_PAD_ONLY_FUSED(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint("T"), \ + FusedResizeConv2DUsingGemmOp< \ + T, \ + FusedResizeAndPadConvFunctor, \ + NEAREST>, \ false>); TF_CALL_float(REGISTER_PAD_ONLY_FUSED); diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 57e196c67cf067bc716d8253f05fc759eaeeba8d..f0085be3a53b71af85d4c5f4bbcc6b07cd982ca8 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -27,7 +27,6 @@ limitations under the License. namespace tensorflow { - // Get the Cudnn workspace limit from the environment variable, which is in MB. // Return the workspace memory limit in bytes. If no value is set, return the // default value. diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index af6013c9747a717b95138c960abcdcc96f4dac73..a376534badc73065e3ec01972dde85da7bbdb0f8 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -25,9 +25,9 @@ limitations under the License. #include "cuda/include/cuda.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/util/cuda_kernel_helper.h" #include "tensorflow/core/util/tensor_format.h" -#include "tensorflow/core/lib/math/math_util.h" namespace tensorflow { @@ -252,11 +252,14 @@ __global__ void SwapDimension1And2InTensor3UsingTiles( int x = threadIdx.x; Dimension<3> output_dims = { - input_dims[0], input_dims[2], input_dims[1], + input_dims[0], + input_dims[2], + input_dims[1], }; Dimension<3> input_dims_in_tiles = { - input_dims[0], (input_dims[1] + TileSizeI - 1) / TileSizeI, + input_dims[0], + (input_dims[1] + TileSizeI - 1) / TileSizeI, (input_dims[2] + TileSizeJ - 1) / TileSizeJ, }; @@ -264,7 +267,8 @@ __global__ void SwapDimension1And2InTensor3UsingTiles( FlatToTensorIndex(blockIdx.x, input_dims_in_tiles); Index<3> input_tile_origin = { - input_tile_index[0], input_tile_index[1] * TileSizeI, + input_tile_index[0], + input_tile_index[1] * TileSizeI, input_tile_index[2] * TileSizeJ, }; @@ -322,11 +326,14 @@ __global__ void SwapDimension1And2InTensor3UsingTiles( __syncthreads(); Index<3> output_tile_index = { - input_tile_index[0], input_tile_index[2], input_tile_index[1], + input_tile_index[0], + input_tile_index[2], + input_tile_index[1], }; Index<3> output_tile_origin = { - output_tile_index[0], output_tile_index[1] * TileSizeJ, + output_tile_index[0], + output_tile_index[1] * TileSizeJ, output_tile_index[2] * TileSizeI, }; @@ -641,8 +648,9 @@ struct BatchNarrowMatrixTransposeDispatcher { static_assert( (TileLongSide & (TileLongSide - 1)) == 0, "The length of the longer side of the tile is always a power of 2."); - bool request_satisfied = max(tile_size_i, tile_size_j) <= TileLongSide && - min(tile_size_i, tile_size_j) <= TileShortSide; + bool request_satisfied = + std::max(tile_size_i, tile_size_j) <= TileLongSide && + std::min(tile_size_i, tile_size_j) <= TileShortSide; if (request_satisfied) { LaunchBatchNarrowMatrixTransposeKernel( @@ -655,7 +663,7 @@ struct BatchNarrowMatrixTransposeDispatcher { // determine whether it is the long side or the short side that falls short // of the request and increase that parameter accordingly. const bool long_side_request_not_satisfied = - max(tile_size_i, tile_size_j) > TileLongSide; + std::max(tile_size_i, tile_size_j) > TileLongSide; if (long_side_request_not_satisfied) { BatchNarrowMatrixTransposeDispatcher< @@ -683,8 +691,9 @@ struct BatchNarrowMatrixTransposeDispatcher< static_assert( (TileLongSide & (TileLongSide - 1)) == 0, "The length of the longer side of the tile is always a power of 2."); - bool request_satisfied = max(tile_size_i, tile_size_j) <= TileLongSide && - min(tile_size_i, tile_size_j) <= TileShortSide; + bool request_satisfied = + std::max(tile_size_i, tile_size_j) <= TileLongSide && + std::min(tile_size_i, tile_size_j) <= TileShortSide; if (request_satisfied) { LaunchBatchNarrowMatrixTransposeKernel( @@ -799,7 +808,7 @@ struct TransposeElemType<16> { // A helper function to make RunSwapDimension1And2InTensor3 concise. This // helper function looks at the data type and input matrix sizes and decides // the thread numbers and tile sizes to use. -template +template void SwapDimension1And2InTensor3WithNarrowMatrices( const GPUDevice& d, const T* input, const Dimension<3>& input_dims, T* output, const int kMinDimensionToUseTiles) { @@ -809,7 +818,7 @@ void SwapDimension1And2InTensor3WithNarrowMatrices( int tile_long_side_len = 0; int tile_short_side_len = 0; float lowest_cost = std::numeric_limits::max(); - int data_long_side = max(input_dims[1], input_dims[2]); + int data_long_side = std::max(input_dims[1], input_dims[2]); for (auto tile_size_pair : tile_spec) { int proposed_tile_long_side_len = tile_size_pair.first; @@ -854,12 +863,14 @@ void SwapDimension1And2InTensor3WithNarrowMatrices( // Truncate the shorter size requested according to the manual limit set in // tile_spec to make sure that we do not launch configurations violating // hardware limits. - requested_tile_size_i = requested_tile_size_i == tile_long_side_len - ? tile_long_side_len - : min(requested_tile_size_i, tile_short_side_len); - requested_tile_size_j = requested_tile_size_j == tile_long_side_len - ? tile_long_side_len - : min(requested_tile_size_j, tile_short_side_len); + requested_tile_size_i = + requested_tile_size_i == tile_long_side_len + ? tile_long_side_len + : std::min(requested_tile_size_i, tile_short_side_len); + requested_tile_size_j = + requested_tile_size_j == tile_long_side_len + ? tile_long_side_len + : std::min(requested_tile_size_j, tile_short_side_len); Dimension<3> input_dims_in_tiles = { input_dims[0], @@ -902,19 +913,21 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input, constexpr int kNumThreads = 256; Dimension<3> input_dims_in_tiles = { - input_dims[0], MathUtil::CeilOfRatio(input_dims[1], kTileSize), + input_dims[0], + MathUtil::CeilOfRatio(input_dims[1], kTileSize), MathUtil::CeilOfRatio(input_dims[2], kTileSize), }; int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2]; - SwapDimension1And2InTensor3UsingTiles + SwapDimension1And2InTensor3UsingTiles <<>>(input, input_dims, output); } else if (narrow_matrix) { - SwapDimension1And2InTensor3WithNarrowMatrices(d, input, input_dims, output, - kMinDimensionToUseTiles); + SwapDimension1And2InTensor3WithNarrowMatrices( + d, input, input_dims, output, kMinDimensionToUseTiles); } else { int total_element_count = input_dims[0] * input_dims[1] * input_dims[2]; CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d); diff --git a/tensorflow/core/kernels/conv_ops_using_gemm.cc b/tensorflow/core/kernels/conv_ops_using_gemm.cc index 20da77c36f64173f2dd40fe8e4a608e39c128447..af0a9fa82ee5778fa9e18cea59cf759fa468224f 100644 --- a/tensorflow/core/kernels/conv_ops_using_gemm.cc +++ b/tensorflow/core/kernels/conv_ops_using_gemm.cc @@ -468,18 +468,19 @@ class Conv2DUsingGemmOp : public BinaryOp { filter.shape().DebugString())); for (int i = 0; i < 3; i++) { - OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i), - std::numeric_limits::max()), - errors::InvalidArgument("filter too large")); + OP_REQUIRES( + context, + FastBoundsCheck(filter.dim_size(i), std::numeric_limits::max()), + errors::InvalidArgument("filter too large")); } // The last dimension for input is in_depth. It must be the same as the // filter's in_depth. const int64 in_depth = GetTensorDim(input, data_format_, 'C'); - OP_REQUIRES( - context, in_depth == filter.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - in_depth, " vs ", filter.dim_size(2))); + OP_REQUIRES(context, in_depth == filter.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, + " vs ", filter.dim_size(2))); // The last dimension for filter is out_depth. const int out_depth = static_cast(filter.dim_size(3)); @@ -487,18 +488,20 @@ class Conv2DUsingGemmOp : public BinaryOp { // The second dimension for input is rows/height. // The first dimension for filter is rows/height. const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H'); - OP_REQUIRES(context, FastBoundsCheck(input_rows_raw, - std::numeric_limits::max()), - errors::InvalidArgument("Input rows too large")); + OP_REQUIRES( + context, + FastBoundsCheck(input_rows_raw, std::numeric_limits::max()), + errors::InvalidArgument("Input rows too large")); const int input_rows = static_cast(input_rows_raw); const int filter_rows = static_cast(filter.dim_size(0)); // The third dimension for input is columns/width. // The second dimension for filter is columns/width. const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W'); - OP_REQUIRES(context, FastBoundsCheck(input_cols_raw, - std::numeric_limits::max()), - errors::InvalidArgument("Input cols too large")); + OP_REQUIRES( + context, + FastBoundsCheck(input_cols_raw, std::numeric_limits::max()), + errors::InvalidArgument("Input cols too large")); const int input_cols = static_cast(input_cols_raw); const int filter_cols = static_cast(filter.dim_size(1)); diff --git a/tensorflow/core/kernels/cross_op_gpu.cu.cc b/tensorflow/core/kernels/cross_op_gpu.cu.cc index 7ea0b3be0ca6b8c7df1ba5c311c7949f3672bda1..4a37f6cfbbc4c60e0a2e3cbf280b09acccc0a98c 100644 --- a/tensorflow/core/kernels/cross_op_gpu.cu.cc +++ b/tensorflow/core/kernels/cross_op_gpu.cu.cc @@ -17,8 +17,8 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/cross_op.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/cross_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/ctc_decoder_ops.cc b/tensorflow/core/kernels/ctc_decoder_ops.cc index 73ee3106048f1435f65d435405282574aa0cffda..96bdb6a241b1d88c7b14f22fc618ea9c95fb7642 100644 --- a/tensorflow/core/kernels/ctc_decoder_ops.cc +++ b/tensorflow/core/kernels/ctc_decoder_ops.cc @@ -19,13 +19,13 @@ limitations under the License. #include -#include "tensorflow/core/util/ctc/ctc_beam_search.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/ctc/ctc_beam_search.h" #include "tensorflow/core/util/sparse/sparse_tensor.h" namespace tensorflow { @@ -80,16 +80,17 @@ class CTCDecodeHelper { if (!(batch_size == (*seq_len)->dim_size(0))) { return errors::FailedPrecondition( - "len(sequence_length) != batch_size. ", "len(sequence_length): ", - (*seq_len)->dim_size(0), " batch_size: ", batch_size); + "len(sequence_length) != batch_size. ", + "len(sequence_length): ", (*seq_len)->dim_size(0), + " batch_size: ", batch_size); } auto seq_len_t = (*seq_len)->vec(); for (int b = 0; b < batch_size; ++b) { if (!(seq_len_t(b) <= max_time)) { - return errors::FailedPrecondition("sequence_length(", b, ") <= ", - max_time); + return errors::FailedPrecondition("sequence_length(", b, + ") <= ", max_time); } } diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc index fb03adb7a5336919c85c4685f4cc7e7a8180892d..b38d838bf1ebdabad85ee3c70a936844f96f106a 100644 --- a/tensorflow/core/kernels/ctc_loss_op.cc +++ b/tensorflow/core/kernels/ctc_loss_op.cc @@ -113,8 +113,8 @@ class CTCLossOp : public OpKernel { const int64 batch_indices = g.group()[0]; OP_REQUIRES(ctx, FastBoundsCheck(batch_indices, batch_size), errors::InvalidArgument("labels batch index must be between ", - 0, " and ", batch_size, " but saw: ", - batch_indices)); + 0, " and ", batch_size, + " but saw: ", batch_indices)); auto values = g.values(); std::vector* b_values = &labels_t[batch_indices]; diff --git a/tensorflow/core/kernels/cwise_op_abs.cc b/tensorflow/core/kernels/cwise_op_abs.cc index 5fd38d9dc25c13e20766d1fed86c3f7af9912905..1466f24202fea4200f752985d620f1fbea61d35a 100644 --- a/tensorflow/core/kernels/cwise_op_abs.cc +++ b/tensorflow/core/kernels/cwise_op_abs.cc @@ -45,5 +45,5 @@ REGISTER_KERNEL_BUILDER(Name("Abs") .HostMemory("y") .TypeConstraint("T"), UnaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_acos.cc b/tensorflow/core/kernels/cwise_op_acos.cc index 12cc6c8bdd43b64aa1be2860b54e90aaf5e4c05e..4919122607426f719c660b23baf3a8c7cc38e076 100644 --- a/tensorflow/core/kernels/cwise_op_acos.cc +++ b/tensorflow/core/kernels/cwise_op_acos.cc @@ -24,5 +24,5 @@ REGISTER2(UnaryOp, GPU, "Acos", functor::acos, float, double); #if TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Acos", functor::acos, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_acosh.cc b/tensorflow/core/kernels/cwise_op_acosh.cc index 39c8814073382566bc3551fdf6d5afc7f1ef0012..c2b355ab7f4fb11cdc89d8f98a8ca1e293818966 100644 --- a/tensorflow/core/kernels/cwise_op_acosh.cc +++ b/tensorflow/core/kernels/cwise_op_acosh.cc @@ -17,12 +17,12 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_gradients.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Acosh", functor::acosh, float, double, - complex64, complex128); +REGISTER4(UnaryOp, CPU, "Acosh", functor::acosh, float, double, complex64, + complex128); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Acosh", functor::acosh, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA REGISTER2(UnaryOp, GPU, "Acosh", functor::acosh, float, double); diff --git a/tensorflow/core/kernels/cwise_op_add_1.cc b/tensorflow/core/kernels/cwise_op_add_1.cc index 608a6dce3d223d522776c59a3a1b2ad0d0c14147..bf32c8a54b34586e43d34cf8890ed37fe64b8c34 100644 --- a/tensorflow/core/kernels/cwise_op_add_1.cc +++ b/tensorflow/core/kernels/cwise_op_add_1.cc @@ -44,7 +44,6 @@ REGISTER_KERNEL_BUILDER(Name("AddV2") BinaryOp>); #endif - #if TENSORFLOW_USE_SYCL #define REGISTER_KERNEL(type) \ REGISTER(BinaryOp, SYCL, "Add", functor::add, type); \ @@ -66,5 +65,5 @@ REGISTER_KERNEL_BUILDER(Name("AddV2") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_add_2.cc b/tensorflow/core/kernels/cwise_op_add_2.cc index ac21ca06c929662271ad99b3756b8a22fc62a0cf..e8acbac28533ae36a5af8ce527529927f5fe4129 100644 --- a/tensorflow/core/kernels/cwise_op_add_2.cc +++ b/tensorflow/core/kernels/cwise_op_add_2.cc @@ -22,8 +22,8 @@ namespace tensorflow { // sharded files, only make its register calls when not __ANDROID_TYPES_SLIM__. #if !defined(__ANDROID_TYPES_SLIM__) -REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64, - uint8, complex128, string); +REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64, uint8, + complex128, string); // Notice: String is excluded to allow marking AddV2 is_commutative and // is_aggregate. REGISTER5(BinaryOp, CPU, "AddV2", functor::add, int8, int16, complex64, uint8, diff --git a/tensorflow/core/kernels/cwise_op_asin.cc b/tensorflow/core/kernels/cwise_op_asin.cc index c28e27d95ae661bdc02a905bb6efd5bdd79f23e5..fe8dfea1173ca6ec6727f2fb475c011176cacad4 100644 --- a/tensorflow/core/kernels/cwise_op_asin.cc +++ b/tensorflow/core/kernels/cwise_op_asin.cc @@ -24,5 +24,5 @@ REGISTER2(UnaryOp, GPU, "Asin", functor::asin, float, double); #if TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Asin", functor::asin, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_asinh.cc b/tensorflow/core/kernels/cwise_op_asinh.cc index 0aec6aac3442a98309e352cf1431b920a87f62fe..7cf0405f5244a1a5a7e7e09719da25d0e714a7da 100644 --- a/tensorflow/core/kernels/cwise_op_asinh.cc +++ b/tensorflow/core/kernels/cwise_op_asinh.cc @@ -1,10 +1,10 @@ - /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 +http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_gradients.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Asinh", functor::asinh, float, double, - complex64, complex128); +REGISTER4(UnaryOp, CPU, "Asinh", functor::asinh, float, double, complex64, + complex128); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Asinh", functor::asinh, float, double); diff --git a/tensorflow/core/kernels/cwise_op_atan.cc b/tensorflow/core/kernels/cwise_op_atan.cc index 7d73de48102189f5c0d92ce811fa639ce6ba2cf4..09f0448874f7dc2bc7140e03cbe38d42246c3087 100644 --- a/tensorflow/core/kernels/cwise_op_atan.cc +++ b/tensorflow/core/kernels/cwise_op_atan.cc @@ -24,5 +24,5 @@ REGISTER2(UnaryOp, GPU, "Atan", functor::atan, float, double); #if TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Atan", functor::atan, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_atanh.cc b/tensorflow/core/kernels/cwise_op_atanh.cc index 7b688db4c585b0f8d92f289cae598a78df7e379c..6170683fa64bdd50c00c8c774d6a1f137e60fa71 100644 --- a/tensorflow/core/kernels/cwise_op_atanh.cc +++ b/tensorflow/core/kernels/cwise_op_atanh.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_gradients.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double, - complex64, complex128); +REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double, complex64, + complex128); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Atanh", functor::atanh, float, double); diff --git a/tensorflow/core/kernels/cwise_op_ceil.cc b/tensorflow/core/kernels/cwise_op_ceil.cc index 0111e9d5fd18f1d94e8d39c5e67d16e04f21e854..816eadc80eb802de46ad4bb22521cbe6a7adf6b2 100644 --- a/tensorflow/core/kernels/cwise_op_ceil.cc +++ b/tensorflow/core/kernels/cwise_op_ceil.cc @@ -24,5 +24,5 @@ REGISTER3(UnaryOp, GPU, "Ceil", functor::ceil, float, Eigen::half, double); #if TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Ceil", functor::ceil, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_cos.cc b/tensorflow/core/kernels/cwise_op_cos.cc index d4b3b0e3935deeded3a0e07bd04056476c4cc29c..71ad0ff0dc2e3031df6177e4d067ad905c23169f 100644 --- a/tensorflow/core/kernels/cwise_op_cos.cc +++ b/tensorflow/core/kernels/cwise_op_cos.cc @@ -25,5 +25,5 @@ REGISTER3(UnaryOp, GPU, "Cos", functor::cos, float, Eigen::half, double); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Cos", functor::cos, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_cosh.cc b/tensorflow/core/kernels/cwise_op_cosh.cc index bca99a4f897d1cc601a082cc17ca6725929942a2..31b4bb3cadd9b2df5d0ae35b2c8ea4a155278a32 100644 --- a/tensorflow/core/kernels/cwise_op_cosh.cc +++ b/tensorflow/core/kernels/cwise_op_cosh.cc @@ -16,20 +16,18 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Cosh", functor::cosh, float, double, - complex64, complex128); +REGISTER4(UnaryOp, CPU, "Cosh", functor::cosh, float, double, complex64, + complex128); #if TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("Cosh") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint("T"), \ - UnaryOp>); +#define REGISTER_SYCL_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Cosh").Device(DEVICE_SYCL).TypeConstraint("T"), \ + UnaryOp>); REGISTER_SYCL_KERNEL(float); REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA REGISTER2(UnaryOp, GPU, "Cosh", functor::cosh, float, double); diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc index d44c1bf473e2e778a7d31890a25359e782e1dc94..c71c756e4461d4ed36628ea8a4f8a0922896302c 100644 --- a/tensorflow/core/kernels/cwise_op_div.cc +++ b/tensorflow/core/kernels/cwise_op_div.cc @@ -54,5 +54,5 @@ REGISTER_KERNEL_BUILDER(Name("Div") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_exp.cc b/tensorflow/core/kernels/cwise_op_exp.cc index 66d7b7d22ebe63bf42da848aa028fcbafc26864b..8f4ac98016cb252c9c952bbc3c67eb2ea3a92f21 100644 --- a/tensorflow/core/kernels/cwise_op_exp.cc +++ b/tensorflow/core/kernels/cwise_op_exp.cc @@ -26,5 +26,5 @@ REGISTER5(UnaryOp, GPU, "Exp", functor::exp, float, Eigen::half, double, #if TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Exp", functor::exp, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_expm1.cc b/tensorflow/core/kernels/cwise_op_expm1.cc index 4f723080060041f1223dbd86aa95f1cc64f5452c..ce03ad5de6285cfa64b56e3e5357e8c916f8baf3 100644 --- a/tensorflow/core/kernels/cwise_op_expm1.cc +++ b/tensorflow/core/kernels/cwise_op_expm1.cc @@ -23,5 +23,5 @@ REGISTER3(UnaryOp, GPU, "Expm1", functor::expm1, float, Eigen::half, double); #endif #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Expm1", functor::expm1, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_floor.cc b/tensorflow/core/kernels/cwise_op_floor.cc index 5a142b9ce9f8a32fe0569a78452cf710b2317760..d554d41c412bca4a8415852427190fb16f7f8f82 100644 --- a/tensorflow/core/kernels/cwise_op_floor.cc +++ b/tensorflow/core/kernels/cwise_op_floor.cc @@ -23,5 +23,5 @@ REGISTER3(UnaryOp, GPU, "Floor", functor::floor, float, Eigen::half, double); #endif #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Floor", functor::floor, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_floor_div.cc b/tensorflow/core/kernels/cwise_op_floor_div.cc index fa81ef0872d4ed6545c312b865e305ee430fdccb..fecbf859897bd1560da00f54756d4a1ffb7660d4 100644 --- a/tensorflow/core/kernels/cwise_op_floor_div.cc +++ b/tensorflow/core/kernels/cwise_op_floor_div.cc @@ -49,5 +49,5 @@ REGISTER_KERNEL_BUILDER(Name("FloorDiv") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_floor_mod.cc b/tensorflow/core/kernels/cwise_op_floor_mod.cc index 55f8a30461f16ebd52f27792f2d3b4a05fbf6977..29340b88506147eb9535893939cf28842c671cd9 100644 --- a/tensorflow/core/kernels/cwise_op_floor_mod.cc +++ b/tensorflow/core/kernels/cwise_op_floor_mod.cc @@ -40,5 +40,5 @@ REGISTER_KERNEL_BUILDER(Name("FloorMod") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc index e7dff5d0ac521cbe6d80efd1f591a9f23a0c650d..77723b3169fa137f0059ffd80a27e84115cb94ca 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc @@ -19,8 +19,8 @@ limitations under the License. namespace tensorflow { namespace functor { - DEFINE_UNARY1(conj, complex64); - DEFINE_UNARY1(conj, complex128); +DEFINE_UNARY1(conj, complex64); +DEFINE_UNARY1(conj, complex128); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc index 3675398126f3ce13722e41b43f382c7fa1eaf111..26748ef0e724903c95f6665a5d7c00bdbd298a28 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc @@ -20,7 +20,7 @@ limitations under the License. namespace tensorflow { namespace functor { DEFINE_BINARY10(equal_to, float, Eigen::half, double, uint8, int8, int16, int64, - complex64, complex128, bool); + complex64, complex128, bool); DEFINE_APPROXIMATE_EQUAL2(float, double); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_invert.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_invert.cu.cc index 62f33612db079377729d8d0edde0c37d43fb9cfb..1072ef3aa687ac75dde4d1bacb60897775e74021 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_invert.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_invert.cu.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_UNARY6(invert, int8, int16, int32, int64, uint8, uint16); +DEFINE_UNARY8(invert, int8, int16, int32, int64, uint8, uint16, uint32, uint64); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc index a54dbdfc247dfcbba370852f525f0ca686b6c1b4..627ecc8c802a2bbd428f9cc2160bec379d7b654b 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc @@ -15,8 +15,10 @@ limitations under the License. #if GOOGLE_CUDA -#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" +#define EIGEN_USE_GPU + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" namespace tensorflow { namespace functor { @@ -38,19 +40,17 @@ struct SelectScalarFunctor { typename TTypes::ConstScalar cond, typename TTypes::ConstFlat then_flat, typename TTypes::ConstFlat else_flat) { - #if !defined(EIGEN_HAS_INDEX_LIST) - Eigen::array rank1{1}; + Eigen::array rank1{1}; #else - Eigen::IndexList> rank1; + Eigen::IndexList > rank1; #endif - const int size = then_flat.dimension(0); - Eigen::array broadcast_dims{size}; - - To32Bit(out).device(d) = cond.reshape(rank1) - .broadcast(broadcast_dims) - .select(then_flat, else_flat); + const int size = then_flat.dimension(0); + Eigen::array broadcast_dims{size}; + To32Bit(out).device(d) = cond.reshape(rank1) + .broadcast(broadcast_dims) + .select(then_flat, else_flat); } }; @@ -89,8 +89,8 @@ struct BatchSelectFunctor { } }; -#define SELECT_FUNCTOR(T) \ - template struct SelectFunctor; \ +#define SELECT_FUNCTOR(T) \ + template struct SelectFunctor; \ template struct SelectScalarFunctor; \ template struct BatchSelectFunctor; diff --git a/tensorflow/core/kernels/cwise_op_greater.cc b/tensorflow/core/kernels/cwise_op_greater.cc index ba89899fb323c58f0a0045f3ef32a897f5f2680a..a4ea40883694540903ac80683d3a7151fac4a583 100644 --- a/tensorflow/core/kernels/cwise_op_greater.cc +++ b/tensorflow/core/kernels/cwise_op_greater.cc @@ -43,5 +43,5 @@ REGISTER_KERNEL_BUILDER(Name("Greater") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_greater_equal.cc b/tensorflow/core/kernels/cwise_op_greater_equal.cc index 8f0c483aecd7f84bbb8ac47e4c8b5877b40335d4..3f34d6269ef4a1ab0da3dae1d08da037c5507bdd 100644 --- a/tensorflow/core/kernels/cwise_op_greater_equal.cc +++ b/tensorflow/core/kernels/cwise_op_greater_equal.cc @@ -35,7 +35,8 @@ REGISTER_KERNEL_BUILDER(Name("GreaterEqual") #endif #ifdef TENSORFLOW_USE_SYCL -REGISTER2(BinaryOp, SYCL, "GreaterEqual", functor::greater_equal, float, double); +REGISTER2(BinaryOp, SYCL, "GreaterEqual", functor::greater_equal, float, + double); REGISTER_KERNEL_BUILDER(Name("GreaterEqual") .Device(DEVICE_SYCL) @@ -44,5 +45,5 @@ REGISTER_KERNEL_BUILDER(Name("GreaterEqual") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_invert.cc b/tensorflow/core/kernels/cwise_op_invert.cc index df2c02e42e17f5bbcb74b637adcfb1dbd5cac3c1..98c8d7e9b2e7b727e4662d1fce2efee12a1d7663 100644 --- a/tensorflow/core/kernels/cwise_op_invert.cc +++ b/tensorflow/core/kernels/cwise_op_invert.cc @@ -16,17 +16,17 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER6(UnaryOp, CPU, "Invert", functor::invert, int8, int16, int32, int64, - uint8, uint16); +REGISTER8(UnaryOp, CPU, "Invert", functor::invert, int8, int16, int32, int64, + uint8, uint16, uint32, uint64); #ifdef TENSORFLOW_USE_SYCL REGISTER6(UnaryOp, SYCL, "Invert", functor::invert, int8, int16, int32, int64, - uint8, uint16); + uint8, uint16, uint32, uint64); #endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA -REGISTER6(UnaryOp, GPU, "Invert", functor::invert, int8, int16, int32, int64, - uint8, uint16); +REGISTER8(UnaryOp, GPU, "Invert", functor::invert, int8, int16, int32, int64, + uint8, uint16, uint32, uint64); #endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_isfinite.cc b/tensorflow/core/kernels/cwise_op_isfinite.cc index 53ec1c1c63f17a03218535c974e591b4eec62a72..ae1e590d24290a397096cbdfdf08b7e2d348f362 100644 --- a/tensorflow/core/kernels/cwise_op_isfinite.cc +++ b/tensorflow/core/kernels/cwise_op_isfinite.cc @@ -26,5 +26,5 @@ REGISTER3(UnaryOp, GPU, "IsFinite", functor::isfinite, float, Eigen::half, #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "IsFinite", functor::isfinite, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_isinf.cc b/tensorflow/core/kernels/cwise_op_isinf.cc index 4b34744304f6c856fb98d39fbadc1e1958c84238..f22ca21e1ca425978b23910c27881eed626626e4 100644 --- a/tensorflow/core/kernels/cwise_op_isinf.cc +++ b/tensorflow/core/kernels/cwise_op_isinf.cc @@ -24,5 +24,5 @@ REGISTER3(UnaryOp, GPU, "IsInf", functor::isinf, float, Eigen::half, double); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "IsInf", functor::isinf, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_isnan.cc b/tensorflow/core/kernels/cwise_op_isnan.cc index ad2dd3f722cebba926dd04748ca146c2ecfc0848..aa180c247e7d01ef0f2898b4a50a71c3c3bc6941 100644 --- a/tensorflow/core/kernels/cwise_op_isnan.cc +++ b/tensorflow/core/kernels/cwise_op_isnan.cc @@ -24,5 +24,5 @@ REGISTER3(UnaryOp, GPU, "IsNan", functor::isnan, float, Eigen::half, double); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "IsNan", functor::isnan, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc index 136c3666dfc351fa0485eeff060a6ea3a7d48c08..00cdecdbd184b84b6601eda76dd5dfded5aa1e1b 100644 --- a/tensorflow/core/kernels/cwise_op_less.cc +++ b/tensorflow/core/kernels/cwise_op_less.cc @@ -42,5 +42,5 @@ REGISTER_KERNEL_BUILDER(Name("Less") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc index 97a2508d1290c5afe758db9ff54a22a22b6dcac0..11806c5fc774dc3a37abc733127e4b6660f27f9c 100644 --- a/tensorflow/core/kernels/cwise_op_less_equal.cc +++ b/tensorflow/core/kernels/cwise_op_less_equal.cc @@ -44,5 +44,5 @@ REGISTER_KERNEL_BUILDER(Name("LessEqual") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc index 7fdfdff0e38ea2bfe18acac86b148a4e1e944117..98936e0f960f1f407c2187746ca80d3db0a93412 100644 --- a/tensorflow/core/kernels/cwise_op_log.cc +++ b/tensorflow/core/kernels/cwise_op_log.cc @@ -25,5 +25,5 @@ REGISTER3(UnaryOp, GPU, "Log", functor::log, float, Eigen::half, double); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Log", functor::log, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_log1p.cc b/tensorflow/core/kernels/cwise_op_log1p.cc index 25ad7b24bb1cee3a09c4ea81cccf79b6a4dabeb9..162ca9e07cdc862e04276aca0dce0ad2f4cfc70e 100644 --- a/tensorflow/core/kernels/cwise_op_log1p.cc +++ b/tensorflow/core/kernels/cwise_op_log1p.cc @@ -25,5 +25,5 @@ REGISTER3(UnaryOp, GPU, "Log1p", functor::log1p, float, Eigen::half, double); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Log1p", functor::log1p, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_maximum.cc b/tensorflow/core/kernels/cwise_op_maximum.cc index 87d54e380b4b923f72aff1eb33d56dd7d8a0dd11..e8a58eea80e611d29886af773be5f1ee061d6f66 100644 --- a/tensorflow/core/kernels/cwise_op_maximum.cc +++ b/tensorflow/core/kernels/cwise_op_maximum.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER5(BinaryOp, CPU, "Maximum", functor::maximum, float, Eigen::half, - double, int32, int64); +REGISTER6(BinaryOp, CPU, "Maximum", functor::maximum, float, Eigen::half, + bfloat16, double, int32, int64); #if GOOGLE_CUDA REGISTER4(BinaryOp, GPU, "Maximum", functor::maximum, float, Eigen::half, double, int64); @@ -43,5 +43,5 @@ REGISTER_KERNEL_BUILDER(Name("Maximum") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_minimum.cc b/tensorflow/core/kernels/cwise_op_minimum.cc index 442171193bfeb41e8594bf708590fc4d52291685..dff83df828f076a076a8f220d04974344d8ffafc 100644 --- a/tensorflow/core/kernels/cwise_op_minimum.cc +++ b/tensorflow/core/kernels/cwise_op_minimum.cc @@ -43,6 +43,6 @@ REGISTER_KERNEL_BUILDER(Name("Minimum") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_mul_1.cc b/tensorflow/core/kernels/cwise_op_mul_1.cc index 023eb07ca3f52f49c95b5b6450e3417b7cbeabe4..0e8d2e37350dbbb942bd5ed6b16392b6288313fe 100644 --- a/tensorflow/core/kernels/cwise_op_mul_1.cc +++ b/tensorflow/core/kernels/cwise_op_mul_1.cc @@ -17,8 +17,8 @@ limitations under the License. namespace tensorflow { -REGISTER5(BinaryOp, CPU, "Mul", functor::mul, float, Eigen::half, double, - uint8, int32); +REGISTER5(BinaryOp, CPU, "Mul", functor::mul, float, Eigen::half, double, uint8, + int32); #if defined(__ANDROID_TYPES_SLIM__) // We only register the first type when we have multi-argument calls in the // case where we're trying to reduce executable size, but it turns out that the @@ -28,7 +28,7 @@ REGISTER(BinaryOp, CPU, "Mul", functor::mul, int32); #if GOOGLE_CUDA REGISTER4(BinaryOp, GPU, "Mul", functor::mul, float, Eigen::half, double, - uint8); + uint8); // A special GPU kernel for int32. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. @@ -50,5 +50,5 @@ REGISTER_KERNEL_BUILDER(Name("Mul") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_mul_2.cc b/tensorflow/core/kernels/cwise_op_mul_2.cc index 7be5857cc06d0f6755d3f4cba2ca67f009740d46..6aa8f8836406ab4f350bc7b6cc1e88bd612ad933 100644 --- a/tensorflow/core/kernels/cwise_op_mul_2.cc +++ b/tensorflow/core/kernels/cwise_op_mul_2.cc @@ -22,11 +22,11 @@ namespace tensorflow { // sharded files, only make its register calls when not __ANDROID_TYPES_SLIM__. #if !defined(__ANDROID_TYPES_SLIM__) -REGISTER6(BinaryOp, CPU, "Mul", functor::mul, - int8, uint16, int16, int64, complex64, complex128); +REGISTER6(BinaryOp, CPU, "Mul", functor::mul, int8, uint16, int16, int64, + complex64, complex128); #if GOOGLE_CUDA REGISTER6(BinaryOp, GPU, "Mul", functor::mul, int8, uint16, int16, int64, - complex64, complex128); + complex64, complex128); #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_neg.cc b/tensorflow/core/kernels/cwise_op_neg.cc index 536891b548f043cb25726d70bfdd362ed0294512..a136769b912718a5749273050a2226da3fa9e3cf 100644 --- a/tensorflow/core/kernels/cwise_op_neg.cc +++ b/tensorflow/core/kernels/cwise_op_neg.cc @@ -27,7 +27,7 @@ REGISTER_KERNEL_BUILDER(Name("Neg") .HostMemory("y") .TypeConstraint("T"), UnaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA REGISTER6(UnaryOp, GPU, "Neg", functor::neg, float, Eigen::half, double, int64, diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc index 7bd81ee12719618181a75907ce547815b1076b84..02cd298745795294bfb8117a24ba930a7f471788 100644 --- a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc +++ b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc @@ -17,7 +17,7 @@ limitations under the License. namespace tensorflow { REGISTER6(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half, - double, uint8, int8, int16); + double, uint8, int8, int16); #if GOOGLE_CUDA REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half, double, uint8); diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_2.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_2.cc index 7d4ecec59f1564c90c11bb05d6e96c7e1b52a60d..05bdea66367c6d525469dd9cdc28b56d3e4c2adc 100644 --- a/tensorflow/core/kernels/cwise_op_not_equal_to_2.cc +++ b/tensorflow/core/kernels/cwise_op_not_equal_to_2.cc @@ -30,5 +30,5 @@ REGISTER6(BinaryOp, GPU, "NotEqual", functor::not_equal_to, int8, int16, int64, #endif // GOOGLE_CUDA -#endif // !defined(__ANDROID_TYPES_SLIM__) +#endif // !defined(__ANDROID_TYPES_SLIM__) } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_reciprocal.cc b/tensorflow/core/kernels/cwise_op_reciprocal.cc index 8c0e21f9cf3535dd5f62657de165150f9efcae2e..aee25747b866c910a799b76e3b00b699bef41566 100644 --- a/tensorflow/core/kernels/cwise_op_reciprocal.cc +++ b/tensorflow/core/kernels/cwise_op_reciprocal.cc @@ -38,7 +38,7 @@ REGISTER4(UnaryOp, GPU, "Reciprocal", functor::inverse, float, Eigen::half, #endif #ifdef TENSORFLOW_USE_SYCL REGISTER(UnaryOp, SYCL, "Reciprocal", functor::inverse, float); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL REGISTER5(SimpleBinaryOp, CPU, "ReciprocalGrad", functor::inverse_grad, float, Eigen::half, double, complex64, complex128); @@ -48,5 +48,5 @@ REGISTER3(SimpleBinaryOp, GPU, "ReciprocalGrad", functor::inverse_grad, float, #endif #ifdef TENSORFLOW_USE_SYCL REGISTER(SimpleBinaryOp, SYCL, "ReciprocalGrad", functor::inverse_grad, float); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 3dd9de8d897479456c462ea068c5eda6354b199b..e259daaba47e2d0ab434e47b39376f7b723bdc9d 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -30,7 +30,7 @@ typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class SelectOp : public OpKernel { @@ -185,7 +185,7 @@ REGISTER_SELECT_SYCL(double); REGISTER_SELECT_SYCL(int32); REGISTER_SELECT_SYCL(int64); #undef REGISTER_SELECT_SYCL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace functor { @@ -201,13 +201,11 @@ struct SelectFunctorBase { }; template -struct SelectFunctor - : SelectFunctorBase {}; +struct SelectFunctor : SelectFunctorBase {}; #ifdef TENSORFLOW_USE_SYCL template -struct SelectFunctor - : SelectFunctorBase {}; -#endif // TENSORFLOW_USE_SYCL +struct SelectFunctor : SelectFunctorBase {}; +#endif // TENSORFLOW_USE_SYCL template struct SelectScalarFunctorBase { @@ -222,12 +220,12 @@ struct SelectScalarFunctorBase { // CPU Specializations of Select functors with scalar template struct SelectScalarFunctor - : SelectScalarFunctorBase {}; + : SelectScalarFunctorBase {}; #ifdef TENSORFLOW_USE_SYCL template struct SelectScalarFunctor - : SelectScalarFunctorBase {}; -#endif // TENSORFLOW_USE_SYCL + : SelectScalarFunctorBase {}; +#endif // TENSORFLOW_USE_SYCL template struct BatchSelectFunctorBase { @@ -240,8 +238,8 @@ struct BatchSelectFunctorBase { const Eigen::DenseIndex all_but_batch = then_flat_outer_dims.dimension(1); #if !defined(EIGEN_HAS_INDEX_LIST) - Eigen::array broadcast_dims{{ 1, all_but_batch }}; - Eigen::Tensor::Dimensions reshape_dims{{ batch, 1 }}; + Eigen::array broadcast_dims{{1, all_but_batch}}; + Eigen::Tensor::Dimensions reshape_dims{{batch, 1}}; #else Eigen::IndexList, Eigen::DenseIndex> broadcast_dims; broadcast_dims.set(1, all_but_batch); @@ -257,13 +255,13 @@ struct BatchSelectFunctorBase { }; template -struct BatchSelectFunctor - : BatchSelectFunctorBase {}; +struct BatchSelectFunctor : BatchSelectFunctorBase { +}; #ifdef TENSORFLOW_USE_SYCL template struct BatchSelectFunctor - : BatchSelectFunctorBase {}; -#endif // TENSORFLOW_USE_SYCL + : BatchSelectFunctorBase {}; +#endif // TENSORFLOW_USE_SYCL } // namespace functor diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc index a76a088ac8f762a1aa980170ba4617b0c66c6e47..c132fdb63f2b8669294de63ec6cb8567002e9bdd 100644 --- a/tensorflow/core/kernels/cwise_op_sigmoid.cc +++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc @@ -25,7 +25,7 @@ REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half, #endif #ifdef TENSORFLOW_USE_SYCL REGISTER(UnaryOp, SYCL, "Sigmoid", functor::sigmoid, float); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL REGISTER5(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, float, Eigen::half, double, complex64, complex128); @@ -35,6 +35,6 @@ REGISTER3(SimpleBinaryOp, GPU, "SigmoidGrad", functor::sigmoid_grad, float, #endif #ifdef TENSORFLOW_USE_SYCL REGISTER(SimpleBinaryOp, SYCL, "SigmoidGrad", functor::sigmoid_grad, float); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sign.cc b/tensorflow/core/kernels/cwise_op_sign.cc index a4084d5ad1796f5af1ce1a62e76c9dc6b473586d..02915ff4ce4547516e6e12bc250b605135d70521 100644 --- a/tensorflow/core/kernels/cwise_op_sign.cc +++ b/tensorflow/core/kernels/cwise_op_sign.cc @@ -41,6 +41,6 @@ REGISTER_KERNEL_BUILDER(Name("Sign") .HostMemory("y") .TypeConstraint("T"), UnaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sin.cc b/tensorflow/core/kernels/cwise_op_sin.cc index b91ff1ac30ba8e7259223e011aa1e70b0a05f623..16c6057864073596592b62f4463cfd1229d3a415 100644 --- a/tensorflow/core/kernels/cwise_op_sin.cc +++ b/tensorflow/core/kernels/cwise_op_sin.cc @@ -25,5 +25,5 @@ REGISTER3(UnaryOp, GPU, "Sin", functor::sin, float, Eigen::half, double); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Sin", functor::sin, float, double); -#endif // TENSORFLOW_USE_SYC +#endif // TENSORFLOW_USE_SYC } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sinh.cc b/tensorflow/core/kernels/cwise_op_sinh.cc index 055f0b12e14b1e1059600b968584a2ff9924237f..26b7a940aa8dd4fd6ce439eac17b6fd44d0fe3fd 100644 --- a/tensorflow/core/kernels/cwise_op_sinh.cc +++ b/tensorflow/core/kernels/cwise_op_sinh.cc @@ -16,20 +16,18 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER4(UnaryOp, CPU, "Sinh", functor::sinh, float, double, - complex64, complex128); +REGISTER4(UnaryOp, CPU, "Sinh", functor::sinh, float, double, complex64, + complex128); #if TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("Sinh") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint("T"), \ - UnaryOp>); +#define REGISTER_SYCL_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Sinh").Device(DEVICE_SYCL).TypeConstraint("T"), \ + UnaryOp>); REGISTER_SYCL_KERNEL(float); REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL -#endif // TENSORFLOW_USE_SYC +#endif // TENSORFLOW_USE_SYC #if GOOGLE_CUDA REGISTER2(UnaryOp, GPU, "Sinh", functor::sinh, float, double); diff --git a/tensorflow/core/kernels/cwise_op_sqrt.cc b/tensorflow/core/kernels/cwise_op_sqrt.cc index 00efbb00f1501669b221682c565b4843c0497128..497756133d05249141823481e6ef43b73a84660b 100644 --- a/tensorflow/core/kernels/cwise_op_sqrt.cc +++ b/tensorflow/core/kernels/cwise_op_sqrt.cc @@ -25,7 +25,7 @@ REGISTER3(UnaryOp, GPU, "Sqrt", functor::sqrt, float, Eigen::half, double); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Sqrt", functor::sqrt, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL REGISTER5(SimpleBinaryOp, CPU, "SqrtGrad", functor::sqrt_grad, float, Eigen::half, double, complex64, complex128); @@ -36,5 +36,5 @@ REGISTER3(SimpleBinaryOp, GPU, "SqrtGrad", functor::sqrt_grad, float, #ifdef TENSORFLOW_USE_SYCL REGISTER2(SimpleBinaryOp, SYCL, "SqrtGrad", functor::sqrt_grad, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_square.cc b/tensorflow/core/kernels/cwise_op_square.cc index 07a4b0b084d804c46a8a4a0bc272f78b22d7e845..7fc2f6bf08b2c825f471123e1ab58bd060f6070a 100644 --- a/tensorflow/core/kernels/cwise_op_square.cc +++ b/tensorflow/core/kernels/cwise_op_square.cc @@ -42,5 +42,5 @@ REGISTER_KERNEL_BUILDER(Name("Square") .HostMemory("y") .TypeConstraint("T"), UnaryOp>); -#endif // TENSORFLOW_USE_SYC +#endif // TENSORFLOW_USE_SYC } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc index 6adaecba04bfcf1b42a760d712eece493131ade2..025041946ac71f0e8f4724f9432d5e2901e348cc 100644 --- a/tensorflow/core/kernels/cwise_op_sub.cc +++ b/tensorflow/core/kernels/cwise_op_sub.cc @@ -53,5 +53,5 @@ REGISTER_KERNEL_BUILDER(Name("Sub") .HostMemory("z") .TypeConstraint("T"), BinaryOp>); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_tan.cc b/tensorflow/core/kernels/cwise_op_tan.cc index 7891b1183dd56b9809ef7f5dc76c3f04fe605b02..c1a25767d3146abc43442cc25b48378c74f8e984 100644 --- a/tensorflow/core/kernels/cwise_op_tan.cc +++ b/tensorflow/core/kernels/cwise_op_tan.cc @@ -24,5 +24,5 @@ REGISTER2(UnaryOp, GPU, "Tan", functor::tan, float, double); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Tan", functor::tan, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_tanh.cc b/tensorflow/core/kernels/cwise_op_tanh.cc index 8b3900892c300ee266b1a7fb066ef79c88c3d087..c5005f5ea8aa3e0b392bd038983d1658c8c56520 100644 --- a/tensorflow/core/kernels/cwise_op_tanh.cc +++ b/tensorflow/core/kernels/cwise_op_tanh.cc @@ -26,7 +26,7 @@ REGISTER3(UnaryOp, GPU, "Tanh", functor::tanh, float, Eigen::half, double); #ifdef TENSORFLOW_USE_SYCL REGISTER2(UnaryOp, SYCL, "Tanh", functor::tanh, float, double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL REGISTER5(SimpleBinaryOp, CPU, "TanhGrad", functor::tanh_grad, float, Eigen::half, double, complex64, complex128); diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc index e561e59cf5a23d6d4881c7c5fcf289ccff4c21cb..980edffceb35ee3f3d7f3557093baec1487a9b5a 100644 --- a/tensorflow/core/kernels/cwise_ops_common.cc +++ b/tensorflow/core/kernels/cwise_ops_common.cc @@ -57,9 +57,9 @@ BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx) in1(ctx->input(1)), bcast(BCast::FromShape(in0.shape()), BCast::FromShape(in1.shape())) { if (!bcast.IsValid()) { - ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ", - in0.shape().DebugString(), " vs. ", - in1.shape().DebugString())); + ctx->SetStatus(errors::InvalidArgument( + "Incompatible shapes: ", in0.shape().DebugString(), " vs. ", + in1.shape().DebugString())); return; } const TensorShape output_shape = BCast::ToShape(bcast.output_shape()); diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h index 6dd108f7226ab5a64b8c074afa9ab219f045158a..965e42dcce1b24460d28e24cd33c520598ecfc41 100644 --- a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h +++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h @@ -136,6 +136,9 @@ struct ApproximateEqual { #define DEFINE_UNARY7(F, T0, T1, T2, T3, T4, T5, T6) \ DEFINE_UNARY2(F, T0, T1); \ DEFINE_UNARY5(F, T2, T3, T4, T5, T6) +#define DEFINE_UNARY8(F, T0, T1, T2, T3, T4, T5, T6, T7) \ + DEFINE_UNARY4(F, T0, T1, T2, T3); \ + DEFINE_UNARY4(F, T4, T5, T6, T7) // Macros to explicitly instantiate kernels on GPU for multiple types // (T0, T1, etc.) for BinaryFunctor. diff --git a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h index 439477070893d37a9fcb7b662e379cce2955b07a..e81b840a509ada73e62a763b203763d9e4e65363 100644 --- a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h +++ b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h @@ -50,16 +50,16 @@ struct SimpleBinaryFunctor { // Macros to explicitly instantiate kernels on GPU for multiple types // (T0, T1, etc.) for SimpleBiaryFunctor (e.g., functor::tanh_grad). -#define DEFINE_SIMPLE_BINARY1(F, T) \ +#define DEFINE_SIMPLE_BINARY1(F, T) \ template struct SimpleBinaryFunctor > -#define DEFINE_SIMPLE_BINARY2(F, T0, T1) \ - DEFINE_SIMPLE_BINARY1(F, T0); \ +#define DEFINE_SIMPLE_BINARY2(F, T0, T1) \ + DEFINE_SIMPLE_BINARY1(F, T0); \ DEFINE_SIMPLE_BINARY1(F, T1) -#define DEFINE_SIMPLE_BINARY3(F, T0, T1, T2) \ - DEFINE_SIMPLE_BINARY2(F, T0, T1); \ +#define DEFINE_SIMPLE_BINARY3(F, T0, T1, T2) \ + DEFINE_SIMPLE_BINARY2(F, T0, T1); \ DEFINE_SIMPLE_BINARY1(F, T2) -#define DEFINE_SIMPLE_BINARY4(F, T0, T1, T2, T3) \ - DEFINE_SIMPLE_BINARY2(F, T0, T1); \ +#define DEFINE_SIMPLE_BINARY4(F, T0, T1, T2, T3) \ + DEFINE_SIMPLE_BINARY2(F, T0, T1); \ DEFINE_SIMPLE_BINARY2(F, T2, T3) #define DEFINE_SIMPLE_BINARY5(F, T0, T1, T2, T3, T4) \ DEFINE_SIMPLE_BINARY2(F, T0, T1); \ diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h index 77b330f5899815d5784659515e43ee497bdca58e..82cdae9a348aaf3625e1e4cf9f80ea7768694062 100644 --- a/tensorflow/core/kernels/cwise_ops_gradients.h +++ b/tensorflow/core/kernels/cwise_ops_gradients.h @@ -171,7 +171,6 @@ struct SimpleBinaryFunctor { } }; - #ifdef TENSORFLOW_USE_SYCL // Partial specialization of BinaryFunctor for SYCL devices typedef Eigen::SyclDevice SYCLDevice; @@ -184,7 +183,7 @@ struct SimpleBinaryFunctor { } }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template struct tanh_grad : base> {}; diff --git a/tensorflow/core/kernels/cwise_ops_sycl_common.h b/tensorflow/core/kernels/cwise_ops_sycl_common.h index 3f6ff7303d627ca64abd0f93658bf1b40ce4d71e..3e107cee04c787d71326bbe4799565f8609f6f4e 100644 --- a/tensorflow/core/kernels/cwise_ops_sycl_common.h +++ b/tensorflow/core/kernels/cwise_ops_sycl_common.h @@ -51,7 +51,8 @@ struct BinaryFunctor { void operator()(const SYCLDevice& d, typename Functor::tout_type out, typename Functor::tin_type in0, typename Functor::tin_type in1, bool* error) { - To32Bit(out).device(d) = To32Bit(in0).binaryExpr(To32Bit(in1), typename Functor::func()); + To32Bit(out).device(d) = + To32Bit(in0).binaryExpr(To32Bit(in1), typename Functor::func()); } void Left(const SYCLDevice& d, typename Functor::tout_type out, @@ -61,7 +62,9 @@ struct BinaryFunctor { constexpr int NumDims = Functor::tin_type::NumDimensions; static_assert(NumDims == 1, "Unexpected size"); Eigen::Sizes<1> scalar_dim; - out.device(d) = scalar.reshape(scalar_dim).broadcast(in.dimensions()).binaryExpr(in, Binary()); + out.device(d) = scalar.reshape(scalar_dim) + .broadcast(in.dimensions()) + .binaryExpr(in, Binary()); } void Right(const SYCLDevice& d, typename Functor::tout_type out, @@ -71,7 +74,8 @@ struct BinaryFunctor { constexpr int NumDims = Functor::tin_type::NumDimensions; static_assert(NumDims == 1, "Unexpected size"); Eigen::Sizes<1> scalar_dim; - out.device(d) = in.binaryExpr(scalar.reshape(scalar_dim).broadcast(in.dimensions()), Binary()); + out.device(d) = in.binaryExpr( + scalar.reshape(scalar_dim).broadcast(in.dimensions()), Binary()); } void BCast(const SYCLDevice& d, diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc index bca0f1004d5f41fd3c8fd8b4eebd44c981053520..39f497e71612fc08a085e410edae73669fc9993a 100644 --- a/tensorflow/core/kernels/cwise_ops_test.cc +++ b/tensorflow/core/kernels/cwise_ops_test.cc @@ -54,36 +54,36 @@ int ColsFromArg(int arg) { return (arg % kRows); } BM_UNARY(cpu, Floor, float, DT_FLOAT); #if GOOGLE_CUDA BM_UNARY(gpu, Floor, float, DT_FLOAT); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL BM_UNARY(sycl, Floor, float, DT_FLOAT); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL BM_UNARY(cpu, Floor, double, DT_DOUBLE); #if GOOGLE_CUDA BM_UNARY(gpu, Floor, double, DT_DOUBLE); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL BM_UNARY(sycl, Floor, double, DT_DOUBLE); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL BM_UNARY(cpu, Conj, std::complex, DT_COMPLEX64); #if GOOGLE_CUDA BM_UNARY(gpu, Conj, std::complex, DT_COMPLEX64); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA BM_UNARY(cpu, Conj, std::complex, DT_COMPLEX128); #if GOOGLE_CUDA BM_UNARY(gpu, Conj, std::complex, DT_COMPLEX128); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA BM_UNARY(cpu, Rint, double, DT_DOUBLE); #if GOOGLE_CUDA BM_UNARY(gpu, Rint, double, DT_DOUBLE); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA BM_UNARY(cpu, Rint, float, DT_FLOAT); #if GOOGLE_CUDA BM_UNARY(gpu, Rint, float, DT_FLOAT); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA // data func scalar. Graph* BinaryScalar(int num, const string& func) { @@ -113,18 +113,18 @@ Graph* BinaryScalar(int num, const string& func) { BM_BINARY_SCALAR(cpu, Less); #if GOOGLE_CUDA BM_BINARY_SCALAR(gpu, Less); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL BM_BINARY_SCALAR(sycl, Less); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL BM_BINARY_SCALAR(cpu, Add); #if GOOGLE_CUDA BM_BINARY_SCALAR(gpu, Add); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL BM_BINARY_SCALAR(sycl, Add); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #undef BM_BINARY_SCALAR template @@ -163,11 +163,11 @@ using Eigen::half; BM_BIAS_ADD_ALL(cpu, float, DT_FLOAT); #if GOOGLE_CUDA BM_BIAS_ADD_ALL(gpu, float, DT_FLOAT); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA BM_BIAS_ADD_ALL(cpu, half, DT_HALF); #if GOOGLE_CUDA BM_BIAS_ADD_ALL(gpu, half, DT_HALF); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #undef BM_BIAS_ADD_ALL #undef BM_BIAS_ADD @@ -217,15 +217,15 @@ using Eigen::half; #if GOOGLE_CUDA BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, float, DT_FLOAT); BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, half, DT_HALF); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, float, DT_FLOAT); #if GOOGLE_CUDA BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, float, DT_FLOAT); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, half, DT_HALF); #if GOOGLE_CUDA BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, half, DT_HALF); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #undef BM_BIAS_ADD_GRAD_ALL #undef BM_BIAS_ADD_GRAD @@ -265,10 +265,10 @@ Graph* BcastAdd(int rows, int cols, int dim) { BM_BCAST_ADD_ROW_ALL(cpu); #if GOOGLE_CUDA BM_BCAST_ADD_ROW_ALL(gpu); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL BM_BCAST_ADD_ROW_ALL(sycl); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #undef BM_BCAST_ADD_ROW_ALL #undef BM_BCAST_ADD_ROW @@ -291,10 +291,10 @@ BM_BCAST_ADD_ROW_ALL(sycl); BM_BCAST_ADD_COL_ALL(cpu); #if GOOGLE_CUDA BM_BCAST_ADD_COL_ALL(gpu); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL BM_BCAST_ADD_COL_ALL(sycl); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #undef BM_BCAST_ADD_COL_ALL #undef BM_BCAST_ADD_COL diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 500ee7b43f2fbd730ae38c3820ed28ec67b9036c..1e3b0c231f35c12d2e9e23d8d503b3a7492ab676 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -44,11 +44,13 @@ tf_kernel_library( ], ) +# TODO(mrry): Remove this empty forwarding library. cc_library( name = "dataset", - srcs = ["dataset.cc"], + srcs = [], hdrs = ["dataset.h"], deps = [ + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -81,9 +83,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:session_options", "//tensorflow/core/kernels:variable_ops", ], ) @@ -122,6 +122,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:batch_util", ], ) @@ -317,18 +318,6 @@ tf_kernel_library( ], ) -tf_kernel_library( - name = "ignore_errors_dataset_op", - srcs = ["ignore_errors_dataset_op.cc"], - deps = [ - ":dataset", - "//tensorflow/core:dataset_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - tf_kernel_library( name = "stats_dataset_ops", srcs = ["stats_dataset_ops.cc"], @@ -402,6 +391,19 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "tensor_queue_dataset_op", + srcs = ["tensor_queue_dataset_op.cc"], + deps = [ + ":dataset", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:batch_util", + ], +) + tf_kernel_library( name = "tensor_slice_dataset_op", srcs = ["tensor_slice_dataset_op.cc"], @@ -518,7 +520,6 @@ tf_kernel_library( ":filter_dataset_op", ":flat_map_dataset_op", ":group_by_window_dataset_op", - ":ignore_errors_dataset_op", ":interleave_dataset_op", ":iterator_ops", ":map_and_batch_dataset_op", @@ -540,6 +541,7 @@ tf_kernel_library( ":stats_dataset_ops", ":take_dataset_op", ":tensor_dataset_op", + ":tensor_queue_dataset_op", ":tensor_slice_dataset_op", ":unique_dataset_op", ":zip_dataset_op", diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 2d6e06398f66c0b07ae17d4fd25d7ba6b5cfef03..7fa67efb9e22e6877b97524150b9024521619dbc 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -92,7 +92,6 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { } private: - class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) @@ -145,7 +144,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { const Tensor& first_element = batch_elements[0][component_index]; TensorShape batch_component_shape({num_batch_elements}); batch_component_shape.AppendShape(first_element.shape()); - Tensor batch_component(cpu_allocator(), first_element.dtype(), + Tensor batch_component(ctx->allocator({}), first_element.dtype(), batch_component_shape); // Build the output tuple component by copying one slice // from each input element in the batch. diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 1f6d32f8df39948a4529bdf53091ff742ba88edb..c4aa9ec26545a2792c1e741af69f61a292fcc216 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/notification.h" - namespace tensorflow { /* static */ @@ -33,7 +32,11 @@ Status CapturedFunction::Create( return Status::OK(); } -CapturedFunction::~CapturedFunction() {} +CapturedFunction::~CapturedFunction() { + if (lib_ != nullptr && f_handle_ != kInvalidHandle) { + lib_->ReleaseHandle(f_handle_).IgnoreError(); + } +} namespace { class CallFrameBase : public CallFrameInterface { @@ -185,8 +188,7 @@ Status CapturedFunction::MaybeInstantiate( return Status::OK(); } -Status CapturedFunction::Run(IteratorContext* ctx, - std::vector&& args, +Status CapturedFunction::Run(IteratorContext* ctx, std::vector&& args, std::vector* rets) { FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle)); diff --git a/tensorflow/core/kernels/data/dataset.h b/tensorflow/core/kernels/data/dataset.h index 2ef31ddfaaa2fd1bd6a4898726d788d1ceece82e..2c6fc8d5b4f607c026e683b3086ef0cf5e9e8e76 100644 --- a/tensorflow/core/kernels/data/dataset.h +++ b/tensorflow/core/kernels/data/dataset.h @@ -15,595 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ #define TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ -#include - -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/framework/variant_encode_decode.h" -#include "tensorflow/core/framework/variant_tensor_data.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/tracing.h" - -// Polymorphic datasets should support all primitive TensorFlow -// types. Use this macro to expand `m(T)` once for each primitive type -// `T`, e.g. to build a `switch` statement. -#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m) - -namespace tensorflow { - -// Interface for reading values from a key-value store. -// Used for restoring iterator state. -class IteratorStateReader { - public: - virtual Status ReadScalar(StringPiece key, int64* val) = 0; - virtual Status ReadScalar(StringPiece key, string* val) = 0; - virtual Status ReadTensor(StringPiece key, Tensor* val) = 0; - virtual bool Contains(StringPiece key) = 0; - - virtual ~IteratorStateReader() {} -}; - -// Interface for writing values to a key-value store. -// Used for saving iterator state. -class IteratorStateWriter { - public: - virtual Status WriteScalar(StringPiece key, const int64 val) = 0; - virtual Status WriteScalar(StringPiece key, const string& val) = 0; - virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0; - - virtual ~IteratorStateWriter() {} -}; - -// Forward declarations to avoid introducing a dependency on headers in -// "tensorflow/core/graph/...". -class GraphDefBuilder; -class GraphDatasetBase; -class Node; - -// Wrapper around GraphDefBuilder. Used to serialize Dataset graph. -class GraphDefBuilderWrapper { - public: - explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {} - - // Adds a Const node with scalar value to the Graph. - // `*output` contains a pointer to the output `Node`. It is guaranteed to be - // non-null if the method returns with an OK status. - // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - template - Status AddScalar(const T& val, Node** output) { - Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); - val_t.scalar()() = val; - AddTensorInternal(val_t, output); - if (*output == nullptr) { - return errors::Internal("AddScalar: Failed to build Const op."); - } - return Status::OK(); - } - - // Adds a Const node with vector value to the Graph. - // `*output` contains a pointer to the output `Node`. It is guaranteed to be - // non-null if the method returns with an OK status. - // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice? - template - Status AddVector(const std::vector& val, Node** output) { - Tensor val_t = Tensor(DataTypeToEnum::v(), - TensorShape({static_cast(val.size())})); - for (int i = 0; i < val.size(); i++) { - val_t.flat()(i) = val[i]; - } - AddTensorInternal(val_t, output); - if (*output == nullptr) { - return errors::Internal("AddVector: Failed to build Const op."); - } - return Status::OK(); - } - - // Adds a Const node with Tensor value to the Graph. - // `*output` contains a pointer to the output `Node`. It is guaranteed to be - // non-null if the method returns with an OK status. - // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - Status AddTensor(const Tensor& val, Node** output) { - AddTensorInternal(val, output); - if (*output == nullptr) { - return errors::Internal("AddTensor: Failed to build Const op."); - } - return Status::OK(); - } - - Status AddDataset(const GraphDatasetBase* dataset, - const std::vector& inputs, Node** output) { - return AddDataset(dataset, inputs, {}, output); - } - - // Adds a node corresponding to the `DatasetType` to the Graph. - // Return value of `DatasetType::op_name()` is used as the op type for the - // node. - // Values for the output_types and output_shapes node attributes are also - // written if those attributes are defined in the OpDef. - // `*output` contains a pointer to the output `Node`. It is guaranteed to be - // non-null if the method returns with an OK status. - // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. - Status AddDataset(const GraphDatasetBase* dataset, - const std::vector& inputs, - const std::vector>& attrs, - Node** output) { - std::vector> enumerated_inputs(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - enumerated_inputs[i] = std::make_pair(i, inputs[i]); - } - return AddDataset(dataset, enumerated_inputs, {}, attrs, output); - } - - Status AddDataset( - const GraphDatasetBase* dataset, - const std::vector>& inputs, - const std::vector>>& list_inputs, - const std::vector>& attrs, - Node** output); - - // Adds a user-defined function with name `function_name` to the graph and - // recursively adds all functions it references. If a function with a matching - // name has already been added, returns with OK status. If a user-defined with - // name `function_name` is not found in the FunctionLibraryDefinition, returns - // an InvalidArgumentError. If the function with name `function_name` or any - // of its dependent functions are stateful, returns an InvalidArgument error. - Status AddFunction(OpKernelContext* ctx, const string& function_name); - - template - void BuildAttrValue(const T& value, AttrValue* attr) { - SetAttrValue(value, attr); - } - - private: - void AddTensorInternal(const Tensor& val, Node** output); - - Status EnsureFunctionIsStateless(OpKernelContext* ctx, - const string& function_name) const { - const FunctionLibraryDefinition* lib_def = - ctx->function_library()->GetFunctionLibraryDefinition(); - const FunctionDef* function_def = lib_def->Find(function_name); - if (!function_def) { - return errors::InvalidArgument("Unable to find FunctionDef for ", - function_name, " in registry."); - } - for (const NodeDef& node_def : function_def->node_def()) { - const OpDef* op_def; - TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def)); - // TODO(b/65524810): Hack to allow functions to capture Dataset op - // nodes needed for FlatMap. Currently, source datasets nodes have been - // marked stateful to avoid constant folding since we do not have a - // good way of serializing them. - if (IsOpWhitelisted(op_def)) { - continue; - } - if (op_def->is_stateful()) { - return errors::InvalidArgument( - "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ", - "in function ", function_name, " is stateful. ", - "Saving stateful functions is not supported yet."); - } - } - return Status::OK(); - } - - // Returns whether an op has been whitelisted for use inside map_fns. - // Uses a heuristic to whitelist source dataset ops which have been - // marked stateful due to b/65524810. - // Also looks up the `op_def->name` in the global - // `WhitelistedStatefulOpRegistry`. - bool IsOpWhitelisted(const OpDef* op_def) const { - return (StringPiece(op_def->name()).ends_with("Dataset") && - op_def->output_arg_size() == 1 && - op_def->output_arg(0).type() == DT_VARIANT) || - dataset::WhitelistedStatefulOpRegistry::Global()->Contains( - op_def->name()); - } - - bool HasAttr(const string& op_type_name, const string& attr_name) const; - - bool HasAttr(const OpDef* op_def, const string& attr_name) const { - for (auto attr : op_def->attr()) { - if (attr.name() == attr_name) { - return true; - } - } - return false; - } - - Status AddAttrFunctions(const AttrValue& attr_value, OpKernelContext* ctx) { - if (attr_value.has_func()) { - TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name())); - } else if (attr_value.has_list()) { - for (const NameAttrList& name_attr_list : attr_value.list().func()) { - TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name())); - } - } - return Status::OK(); - } - - GraphDefBuilder* b_; -}; - -class StatsAggregator; - -// A cut-down version of OpKernelContext for running computations in -// iterators. Note that we cannot simply use OpKernelContext here -// because we might run computation in an iterator whose lifetime is -// not nested within the lifetime of a single OpKernelContext -// (e.g. asynchronous prefetching). -// -// TODO(mrry): We will probably need to support more of -// OpKernelContext here. For example, should allocation be handled by -// the IteratorContext? -// TODO(mrry): We're making some daring assumptions about the lifetime -// of the runner passed in here. A runner will be deleted when the original -// step ends, but all existing runners only close over session-lifetime (or -// longer-lived) state, so we can make a copy of the function. There's nothing -// in the definition of the API from which we took the runner to guarantee that -// what we are doing is safe. We should formalize the properties here. -class IteratorContext { - public: - struct Params { - // Interface to operating system functionality. - Env* env; - - // Function call support. - std::function)> runner = nullptr; - - // A function that returns the current `StatsAggregator` instance to be - // used when recording statistics about the iterator. - // - // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator` - // is a property of the `IteratorResource` (which this class does not know - // about), and (ii) it can change after the `IteratorContext` has been - // created. Better suggestions are welcome! - std::function()> stats_aggregator_getter = - nullptr; - - // The FunctionLibraryRuntime object to be used to make function calls. - FunctionLibraryRuntime* lib = nullptr; - std::shared_ptr function_library = nullptr; - }; - - explicit IteratorContext(Params params) : params_(std::move(params)) {} - - Env* env() const { return params_.env; } - - std::function)>* runner() { - return ¶ms_.runner; - } - - std::shared_ptr stats_aggregator() { - if (params_.stats_aggregator_getter) { - return params_.stats_aggregator_getter(); - } else { - return nullptr; - } - } - - std::shared_ptr function_library() { - return params_.function_library; - } - - FunctionLibraryRuntime* lib() { return params_.lib; } - - void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; } - - private: - Params params_; -}; - -// Represents the current position in a range of outputs, where the -// range of outputs is typically represented by an `DatasetBase`, -// defined below. -class IteratorBase { - public: - virtual ~IteratorBase() {} - - // Gets the next output from the range that this iterator is traversing. - // - // If at least one output remains in this iterator's range, that - // output will be stored in `*out_tensors` and `false` will be - // stored in `*end_of_sequence`. - // - // If no more outputs remain in this iterator's range, `true` will - // be stored in `*end_of_sequence`, and the content of - // `*out_tensors` will be undefined. - // - // This method is thread-safe. - // - // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and - // potentially remove this method. - virtual Status GetNext(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) = 0; - - // Returns a vector of DataType values, representing the respective - // element types of each tuple component in the outputs of this - // iterator. - virtual const DataTypeVector& output_dtypes() const = 0; - - // Returns a vector of tensor shapes, representing the respective - // (and possibly partially defined) shapes of each tuple component - // in the outputs of this iterator. - virtual const std::vector& output_shapes() const = 0; - - // Saves the state of this iterator. - virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { - return SaveInternal(writer); - } - - // Restores the state of this iterator. - virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) { - return RestoreInternal(ctx, reader); - } - - protected: - // This is needed so that sub-classes of IteratorBase can call - // `SaveInternal` on their parent iterators, e.g., in - // `RepeatDataasetOp::Dataset`. - Status SaveParent(IteratorStateWriter* writer, - const std::unique_ptr& parent) { - return parent->SaveInternal(writer); - } - - // This is needed so that sub-classes of IteratorBase can call - // `RestoreInternal` on their parent iterators, e.g., in - // `RepeatDataasetOp::Dataset`. - Status RestoreParent(IteratorContext* ctx, IteratorStateReader* reader, - const std::unique_ptr& parent) { - return parent->RestoreInternal(ctx, reader); - } - - // Saves the state of this iterator recursively. - virtual Status SaveInternal(IteratorStateWriter* writer) { - return errors::Unimplemented("SaveInternal"); - } - - // Restores the state of this iterator recursively. - virtual Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) { - return errors::Unimplemented("RestoreInternal"); - } -}; - -// Represents a (potentially infinite) range of outputs, where each -// output is a tuple of tensors. -class DatasetBase : public core::RefCounted { - public: - // Returns a new iterator for iterating over the range of elements in - // this dataset. - // - // This method may be called multiple times on the same instance, - // and the resulting iterators will have distinct state. Each - // iterator will traverse all elements in this dataset from the - // start. - // - // Ownership of the created iterator will be transferred to the caller. - // - // The prefix identifies the sequence of iterators leading up to the newly - // created iterator. - virtual std::unique_ptr MakeIterator( - const string& prefix) const = 0; - - // Returns a vector of DataType values, representing the respective - // element types of each tuple component in the outputs of this - // dataset. - virtual const DataTypeVector& output_dtypes() const = 0; - - // Returns a vector of tensor shapes, representing the respective - // (and possibly partially defined) shapes of each tuple component - // in the outputs of this dataset. - virtual const std::vector& output_shapes() const = 0; - - // A human-readable debug string for this dataset. - virtual string DebugString() = 0; - - // Serializes the dataset and writes it to the `writer`. - virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const { - return errors::Unimplemented("DatasetBase::Save"); - } - - protected: - // TODO(srbs): Ideally all graph related logic should reside in - // GraphDatasetBase. However, that would require Datasets defined in all ops - // to derive from GraphDatasetBase. Once that is done we can move - // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase. - class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { - public: - DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} - Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset, - Node** output) { - return dataset->AsGraphDefInternal(ctx, this, output); - } - }; - - virtual Status AsGraphDefInternal(OpKernelContext* ctx, - DatasetGraphDefBuilder* b, - Node** node) const { - return AsGraphDefInternal(b, node); - } - - virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b, - Node** node) const { - return errors::Unimplemented("AsGraphDefInternal"); - } -}; - -// Base-class for datasets that are built by ops. -class GraphDatasetBase : public DatasetBase { - public: - GraphDatasetBase(OpKernelContext* ctx) - : op_name_(ctx->op_kernel().type_string()) {} - - const string op_name() const { return op_name_; } - - Status Save(OpKernelContext* ctx, - IteratorStateWriter* writer) const override { - string serialized_graph_def; - string output_node; - TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node)); - return Status::OK(); - } - - // Key for storing the Dataset graph in the serialized format. - static const char kDatasetGraphKey[]; - - // Key for storing the output node of the Dataset graph in the serialized - // format. - static const char kDatasetGraphOutputNodeKey[]; - - private: - Status Serialize(OpKernelContext* ctx, string* serialized_graph_def, - string* output_node) const; - - const string op_name_; -}; - -// Represents an iterator that is associated with a particular parent dataset. -template -class DatasetIterator : public IteratorBase { - public: - struct Params { - // Owns one reference on the shared dataset resource. - const DatasetType* dataset; - - // Identifies the sequence of iterators leading up to this iterator. - const string prefix; - }; - - explicit DatasetIterator(const Params& params) : params_(params) { - params_.dataset->Ref(); - } - - ~DatasetIterator() override { params_.dataset->Unref(); } - - // The dataset from which this iterator was created. - const DatasetType* dataset() const { return params_.dataset; } - - // The sequence of iterators leading up to this iterator. - const string prefix() const { return params_.prefix; } - - const DataTypeVector& output_dtypes() const override { - return params_.dataset->output_dtypes(); - } - - const std::vector& output_shapes() const override { - return params_.dataset->output_shapes(); - } - - Status GetNext(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) final { - port::Tracing::TraceMe activity(params_.prefix); - Status s = GetNextInternal(ctx, out_tensors, end_of_sequence); - if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) { - s = errors::Internal( - "Iterator \"", params_.prefix, - "\" returned OutOfRange without setting `*end_of_sequence`. This " - "indicates that an error may have occurred. Original message: ", - s.error_message()); - LOG(ERROR) << s; - } - return s; - } - - Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final { - TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer)); - return IteratorBase::Save(ctx, writer); - } - - protected: - // Internal implementation of GetNext that is wrapped in tracing logic. - virtual Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) = 0; - - string full_name(const string& name) const { - return strings::StrCat(prefix(), ":", name); - } - - private: - Params params_; -}; - -// Encapsulates the work required to plug a DatasetBase into the core TensorFlow -// graph execution engine. -class DatasetOpKernel : public OpKernel { - public: - DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} - void Compute(OpKernelContext* ctx) final; - - protected: - // Subclasses should implement this method. It will be called during Compute - // execution. - virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0; - - template - Status ParseScalarArgument(OpKernelContext* ctx, - const StringPiece& argument_name, T* output) { - const Tensor* argument_t; - TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); - if (!TensorShapeUtils::IsScalar(argument_t->shape())) { - return errors::InvalidArgument(argument_name, " must be a scalar"); - } - *output = argument_t->scalar()(); - return Status::OK(); - } -}; - -// Encapsulates the work required to plug unary Datasets into the core -// TensorFlow graph execution engine. -class UnaryDatasetOpKernel : public DatasetOpKernel { - public: - UnaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} - - protected: - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; - virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) = 0; -}; - -// Encapsulates the work required to plug binary Datasets into the core -// TensorFlow graph execution engine. -class BinaryDatasetOpKernel : public DatasetOpKernel { - public: - BinaryDatasetOpKernel(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} - - protected: - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; - virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase* another_input, - DatasetBase** output) = 0; -}; - -// Validates and extracts a `DatasetBase` object from `tensor`. -// -// `tensor` must have been written by a call to SetVariantTensorToDataset(). -// -// The retrieved pointer is a borrowed reference to the dataset, which is owned -// by the tensor. The consumer must either acquire its own reference to the -// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not -// destroyed or mutated while the retrieved pointer is in use. -Status GetDatasetFromVariantTensor(const Tensor& tensor, - DatasetBase** out_dataset); - -// Stores a `DatasetBase` object in `tensor`. -// -// The ownership of `dataset` is transferred to `tensor`. -Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor); - -} // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_ diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index e7224bb547f60f943c7c91c37edfbbf561f5351a..132808a5f140a31fc3c1852cb83e5cd8579b6d95 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -155,7 +155,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { // Determine the size of the output tensors: // * dense_shape will be [`row_shape + 1`]. - Tensor dense_shape(cpu_allocator(), DT_INT64, {row_ndims + 1}); + Tensor dense_shape(ctx->allocator({}), DT_INT64, {row_ndims + 1}); auto dense_shape_vec = dense_shape.vec(); for (size_t i = 0; i < row_ndims; ++i) { if (row_shape.dim_size(i) == -1) { @@ -215,10 +215,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { // * indices will be [`total_elements`, `row_shape + 1`]. // * values will be [`total_elements`]. - Tensor indices(cpu_allocator(), DT_INT64, + Tensor indices(ctx->allocator({}), DT_INT64, {total_elements, row_ndims + 1}); Tensor values( - cpu_allocator(), + ctx->allocator({}), DatasetIterator>::dataset()->input_->output_dtypes()[0], {total_elements}); auto indices_matrix = indices.matrix(); diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index eb047e10ecf738c90c18b9fea25f1b49fdf441c4..834c06bb930d1c723c5b3f880dcc13a892bb44f7 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -510,10 +509,6 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - // A resource name for the temporary window dataset that is - // created as the input to the reduce function. - static constexpr const char* kWindowResourceName = "__window_dataset"; - const DatasetBase* const input_; const NameAttrList key_func_; const NameAttrList reduce_func_; @@ -537,5 +532,4 @@ REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU), GroupByWindowDatasetOp); } // namespace - } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 56044a3d41a9f8f2af3c3a72344845e3a59151af..d7d4ad5cf7f6d5a3386be524c7a227006da0b3f4 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/common_runtime/threadpool_device.h" #include "tensorflow/core/framework/iterator.pb.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -82,7 +83,7 @@ class IteratorResource : public ResourceBase { public: IteratorResource(const DataTypeVector& output_dtypes, const std::vector& output_shapes, - const int graph_def_version, + const int /*unused: graph_def_version*/, std::unique_ptr device_mgr, std::unique_ptr flib_def, std::unique_ptr pflr, @@ -93,8 +94,7 @@ class IteratorResource : public ResourceBase { lib_(lib), iterator_(nullptr), output_dtypes_(output_dtypes), - output_shapes_(output_shapes), - graph_def_version_(graph_def_version) {} + output_shapes_(output_shapes) {} Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) { @@ -160,6 +160,10 @@ class IteratorResource : public ResourceBase { params.runner = *(ctx->runner()); params.function_library = flib_def; params.lib = lib_; + DeviceBase* device = lib_->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; IteratorContext iter_ctx(std::move(params)); TF_RETURN_IF_ERROR(captured_iterator->Restore(&iter_ctx, reader)); @@ -223,7 +227,6 @@ class IteratorResource : public ResourceBase { std::shared_ptr lib_def_ GUARDED_BY(mu_); const DataTypeVector output_dtypes_; const std::vector output_shapes_; - const int graph_def_version_; }; // Helper class for reading data from a VariantTensorData object. @@ -430,13 +433,10 @@ class IteratorStateVariant { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, kIteratorVariantTypeName); -// TODO(mrry): Can we simply use the template kernel here? class IteratorHandleOp : public OpKernel { public: explicit IteratorHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { - OP_REQUIRES_OK(ctx, ctx->allocate_persistent(DT_STRING, TensorShape({2}), - &handle_, nullptr)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); @@ -460,56 +460,54 @@ class IteratorHandleOp : public OpKernel { } void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - FunctionLibraryRuntime* lib = context->function_library(); - std::unique_ptr device_mgr(nullptr); - std::unique_ptr flib_def(nullptr); - std::unique_ptr pflr(nullptr); - // If the iterator is shared then we construct a new FLR, and pass that in. - // NOTE(mrry,rohanj): In this case it is not possible to call remote - // functions from the iterator. We may add this functionality if there - // is sufficient demand, but it will require a significant refactoring. - if (!name_.empty()) { - lib = CreateFLR(context, &device_mgr, &flib_def, &pflr); - } + { + mutex_lock l(mu_); + if (resource_ == nullptr) { + FunctionLibraryRuntime* lib; + std::unique_ptr device_mgr(nullptr); + std::unique_ptr flib_def(nullptr); + std::unique_ptr pflr(nullptr); + // If the iterator is shared then we construct a new FLR, and pass that + // in. NOTE(mrry,rohanj): In this case it is not possible to call remote + // functions from the iterator. We may add this functionality if there + // is sufficient demand, but it will require a significant refactoring. + if (!name_.empty()) { + lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr); + } else { + OP_REQUIRES_OK(context, context->function_library()->Clone( + &flib_def, &pflr, &lib)); + } - if (resource_ == nullptr) { - ResourceMgr* mgr = context->resource_manager(); - OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); + ResourceMgr* mgr = context->resource_manager(); + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); + + IteratorResource* resource; + OP_REQUIRES_OK( + context, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [lib, &device_mgr, &flib_def, &pflr, + this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new IteratorResource( + output_dtypes_, output_shapes_, graph_def_version_, + std::move(device_mgr), std::move(flib_def), + std::move(pflr), lib); + return Status::OK(); + })); + + Status s = VerifyResource(resource); + if (TF_PREDICT_FALSE(!s.ok())) { + resource->Unref(); + context->SetStatus(s); + return; + } - IteratorResource* resource; - OP_REQUIRES_OK( - context, - mgr->LookupOrCreate( - cinfo_.container(), cinfo_.name(), &resource, - [lib, &device_mgr, &flib_def, &pflr, this](IteratorResource** ret) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new IteratorResource( - output_dtypes_, output_shapes_, graph_def_version_, - std::move(device_mgr), std::move(flib_def), - std::move(pflr), lib); - return Status::OK(); - })); - - Status s = VerifyResource(resource); - if (TF_PREDICT_FALSE(!s.ok())) { - resource->Unref(); - context->SetStatus(s); - return; + resource_ = resource; } - - auto h = handle_.AccessTensor(context)->template flat(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); - resource_ = resource; - } - if (context->expected_output_dtype(0) == DT_RESOURCE) { - OP_REQUIRES_OK(context, MakeResourceHandleToOutput( - context, 0, cinfo_.container(), cinfo_.name(), - MakeTypeIndex())); - } else { - context->set_output_ref(0, &mu_, handle_.AccessTensor(context)); } + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); } private: @@ -526,15 +524,32 @@ class IteratorHandleOp : public OpKernel { return Status::OK(); } - FunctionLibraryRuntime* CreateFLR( + template // use like this: down_cast(foo); + static inline To down_cast(From* f) { // so we only accept pointers + static_assert( + (std::is_base_of::type>::value), + "target type not derived from source type"); + + // We skip the assert and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. + assert(f == nullptr || dynamic_cast(f) != nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + return static_cast(f); + } + + FunctionLibraryRuntime* CreatePrivateFLR( OpKernelContext* ctx, std::unique_ptr* device_mgr, std::unique_ptr* flib_def, std::unique_ptr* pflr) { - Device* device = new ThreadPoolDevice( - SessionOptions(), ctx->device()->attributes().name(), Bytes(256 << 20), - DeviceLocality(), cpu_allocator()); - - device_mgr->reset(new DeviceMgr({device})); + // Wrap the existing device in order to see any captured resources + // in its resource manager. The existing device will outlive the + // IteratorResource, because we are storing the IteratorResource + // in that device's resource manager. + Device* wrapped_device = RenamedDevice::NewRenamedDevice( + ctx->device()->name(), down_cast(ctx->device()), + false /* owns_underlying */, false /* isolate_session_state */); + device_mgr->reset(new DeviceMgr({wrapped_device})); flib_def->reset(new FunctionLibraryDefinition( *ctx->function_library()->GetFunctionLibraryDefinition())); pflr->reset(new ProcessFunctionLibraryRuntime( @@ -542,13 +557,12 @@ class IteratorHandleOp : public OpKernel { {} /* TODO(mrry): OptimizerOptions? */, nullptr /* TODO(mrry): ClusterFLR */)); - return (*pflr)->GetFLR(device->name()); + return (*pflr)->GetFLR(ctx->device()->name()); } mutex mu_; - ContainerInfo cinfo_ GUARDED_BY(mu_); + ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. IteratorResource* resource_ GUARDED_BY(mu_) = nullptr; - PersistentTensor handle_ GUARDED_BY(mu_); DataTypeVector output_dtypes_; std::vector output_shapes_; const int graph_def_version_; @@ -595,6 +609,11 @@ class ToSingleElementOp : public AsyncOpKernel { params.env = ctx->env(); params.runner = *(ctx->runner()); params.lib = ctx->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + IteratorContext iter_ctx(std::move(params)); std::vector components; @@ -715,18 +734,23 @@ class OneShotIteratorOp : public AsyncOpKernel { Status TryInit(OpKernelContext* ctx, IteratorResource** iterator, ContainerInfo* cinfo) { TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def())); - FunctionLibraryRuntime* lib = ctx->function_library(); + + FunctionLibraryRuntime* lib; + std::unique_ptr flib_def(nullptr); + std::unique_ptr pflr(nullptr); + TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def, &pflr, &lib)); // Create an IteratorResource that will hold the iterator for this op. TF_RETURN_IF_ERROR( ctx->resource_manager()->LookupOrCreate( cinfo->container(), cinfo->name(), iterator, - [lib, this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new IteratorResource(output_dtypes_, output_shapes_, - graph_def_version_, nullptr, nullptr, - nullptr, lib); - return Status::OK(); - })); + [lib, this, &flib_def, &pflr](IteratorResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new IteratorResource( + output_dtypes_, output_shapes_, graph_def_version_, + nullptr, std::move(flib_def), std::move(pflr), lib); + return Status::OK(); + })); core::ScopedUnref unref_iterator(*iterator); @@ -848,6 +872,10 @@ class IteratorGetNextOp : public AsyncOpKernel { }; params.runner = *(ctx->runner()); params.function_library = iterator->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; IteratorContext iter_ctx(std::move(params)); OP_REQUIRES_OK_ASYNC( @@ -890,6 +918,10 @@ class IteratorGetNextSyncOp : public OpKernel { }; params.runner = *(ctx->runner()); params.function_library = iterator->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; IteratorContext iter_ctx(std::move(params)); OP_REQUIRES_OK(ctx, diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index c529f671f2bb7fd3eb5277c23867e25ba70fd046..9ce263732f6e6c907dfdc89692455daa5dca86d1 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -183,7 +183,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { TensorShape component_shape( batch_results_[current_batch_index_].output[i].shape()); component_shape.set_dim(0, num_elements); - Tensor component(cpu_allocator(), output[i].dtype(), + Tensor component(ctx->allocator({}), output[i].dtype(), component_shape); TF_RETURN_IF_ERROR( CopyPartialBatch(&component, output[i], num_elements)); @@ -244,7 +244,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - void EnsureOutputAllocated(BatchResult* batch_result, + void EnsureOutputAllocated(IteratorContext* ctx, + BatchResult* batch_result, const std::vector& return_values) { mutex_lock l(batch_result->mu); if (batch_result->output_allocated) { @@ -254,7 +255,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { for (size_t i = 0; i < num_components; ++i) { TensorShape component_shape({dataset()->batch_size_}); component_shape.AppendShape(return_values[i].shape()); - Tensor component(cpu_allocator(), return_values[i].dtype(), + Tensor component(ctx->allocator({}), return_values[i].dtype(), component_shape); batch_result->output.emplace_back(std::move(component)); } @@ -285,10 +286,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { dataset()->captured_func_->RunAsync( ctx, std::move(input_element), &result->return_values, [this, ctx, result, batch_result, offset](Status ret_status) { - delete ctx; result->status.Update(ret_status); if (ret_status.ok()) { - EnsureOutputAllocated(batch_result, + EnsureOutputAllocated(ctx, batch_result, result->return_values); const size_t num_components = result->return_values.size(); @@ -318,6 +318,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } } } + delete ctx; // NOTE(mrry): We clear the return values here to release // any memory associated with them and to paralellize the // destruction of the tensors (which can be surprisingly diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index 346eca0bb2ab1c7a82ddba98063c0ccb71b4e58f..cfb4efda9a56fde04994201f509cf3d9fb45ea82 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/kernels/batch_util.h" #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { @@ -24,102 +25,6 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -// The following five functions are copied from padding_fifo_queue.cc. -// TODO(mrry): Reconcile these functions with the similar methods in the -// queue implementation. -Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) { - DCHECK_NE(parent->dim_size(0), 0); - if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { - TensorShape chip_shape = parent->shape(); - chip_shape.RemoveDim(0); - return errors::Internal( - "HandleElementToLargerSlice Cannot copy slice: number of entries in " - "element is greater than number of elements in parent slice. ", - "Shapes are: [element]: ", element.shape().DebugString(), - ", [parent slice]: ", chip_shape.DebugString()); - } - return Status::OK(); -} - -template -Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, - int index) { - TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent)); - if (element.NumElements() == 0) { - return Status::OK(); - } - auto element_t = element.tensor(); - auto parent_t = parent->tensor(); - Eigen::DSizes slice_indices; - slice_indices[0] = index; - Eigen::DSizes slice_size; - slice_size[0] = 1; - for (size_t i = 1; i < slice_size.size(); ++i) { - slice_size[i] = element_t.dimension(i - 1); - } - parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size); - return Status::OK(); -} - -template -Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, - int index) { -#define HANDLE_TYPE(T) \ - case DataTypeToEnum::value: { \ - return HandleElementToLargerSlice(element, parent, index); \ - } - - switch (element.dtype()) { - TF_CALL_DATASET_TYPES(HANDLE_TYPE); -#undef HANDLE_TYPE - default: - return errors::Unimplemented( - "HandleElementToLargerSliceWithRank Unhandled data type: ", - element.dtype()); - } -} - -Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, - int index) { - if (parent->dims() != element.dims() + 1) { - return errors::Internal( - "Mismatched ranks. Element's rank is: ", element.dims(), - " but element is meant to be a slice in output Tensor having rank: ", - parent->dims(), " (should be: ", element.dims() + 1, ")"); - } - -#define HANDLE_DIMS(NDIMS) \ - case NDIMS: { \ - TF_RETURN_IF_ERROR( \ - HandleElementToLargerSliceWithRank(element, parent, index)); \ - return Status::OK(); \ - } - - switch (element.dims()) { - HANDLE_DIMS(0); - HANDLE_DIMS(1); - HANDLE_DIMS(2); - HANDLE_DIMS(3); - HANDLE_DIMS(4); -#undef HANDLE_DIMS - default: - return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ", - element.dims()); - } -} - -Status SetElementZero(Tensor* element, const Tensor& padding) { -#define HANDLE_TYPE(T) \ - if (element->dtype() == DataTypeToEnum::value) { \ - element->flat().setConstant(padding.scalar()()); \ - return Status::OK(); \ - } - TF_CALL_DATASET_TYPES(HANDLE_TYPE); -#undef HANDLE_TYPE - return errors::Unimplemented("SetElementZero Unhandled data type: ", - element->dtype()); -} - class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { public: explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx) @@ -376,20 +281,27 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { // 2. Copy each batch element to the appropriate location in // the output component tensor. - Tensor batch_component(cpu_allocator(), + Tensor batch_component(ctx->allocator({}), output_dtypes()[component_index], batch_component_shape); - TF_RETURN_IF_ERROR(SetElementZero( + TF_RETURN_IF_ERROR(batch_util::SetElementZero( &batch_component, dataset()->padding_values_[component_index])); // Build the output tuple component by copying one slice // from each input element in the batch. + TensorShape component_shape({}); + for (int i = 1; i < batch_component_shape.dims(); ++i) { + component_shape.AddDim(batch_component_shape.dim_size(i)); + } for (int64 i = 0; i < num_batch_elements; ++i) { - TF_RETURN_IF_ERROR(ValidateElementToLargerSlice( - batch_elements[i][component_index], &batch_component)); - - TF_RETURN_IF_ERROR(CopyElementToLargerSlice( - batch_elements[i][component_index], &batch_component, i)); + // Take the fast path if possible. + if (batch_elements[i][component_index].shape() == component_shape) { + TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice( + batch_elements[i][component_index], &batch_component, i)); + } else { + TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice( + batch_elements[i][component_index], &batch_component, i)); + } } out_tensors->push_back(std::move(batch_component)); } diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc index bc638864b0147f4d71b3382ea320453e972ba8d7..210b9ad1b84eeb0c106b0ee538b4957aba7ce1b2 100644 --- a/tensorflow/core/kernels/data/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/random_dataset_op.cc @@ -99,7 +99,7 @@ class RandomDatasetOp : public DatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - Tensor value_tensor(cpu_allocator(), DT_INT64, {}); + Tensor value_tensor(ctx->allocator({}), DT_INT64, {}); value_tensor.scalar()() = Random(); out_tensors->emplace_back(std::move(value_tensor)); *end_of_sequence = false; diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index d0bc61acd99afae14ddc8a3e678acb4197fcea71..b57518e678ed185a183e0413d6e90f2a9f85e9fc 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -100,7 +100,7 @@ class RangeDatasetOp : public DatasetOpKernel { *end_of_sequence = true; return Status::OK(); } - Tensor value_tensor(cpu_allocator(), DT_INT64, {}); + Tensor value_tensor(ctx->allocator({}), DT_INT64, {}); value_tensor.scalar()() = next_; out_tensors->emplace_back(std::move(value_tensor)); *end_of_sequence = false; diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc index aa39fffc2e344db8143b700cbba4c29bdb134964..34d7d9f914d7a726135febabb1fbe35b0146977c 100644 --- a/tensorflow/core/kernels/data/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc @@ -141,7 +141,7 @@ class TextLineDatasetOp : public DatasetOpKernel { if (s.ok()) { // Produce the line as output. - Tensor line_tensor(cpu_allocator(), DT_STRING, {}); + Tensor line_tensor(ctx->allocator({}), DT_STRING, {}); line_tensor.scalar()() = line_contents; out_tensors->emplace_back(std::move(line_tensor)); *end_of_sequence = false; @@ -384,7 +384,7 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR( input_buffer_->ReadNBytes(dataset()->record_bytes_, &record)); // Produce the record as output. - Tensor record_tensor(cpu_allocator(), DT_STRING, {}); + Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); record_tensor.scalar()() = record; out_tensors->emplace_back(std::move(record_tensor)); *end_of_sequence = false; @@ -589,7 +589,7 @@ class TFRecordDatasetOp : public DatasetOpKernel { do { // We are currently processing a file, so try to read the next record. if (reader_) { - Tensor result_tensor(cpu_allocator(), DT_STRING, {}); + Tensor result_tensor(ctx->allocator({}), DT_STRING, {}); Status s = reader_->ReadRecord(&result_tensor.scalar()()); if (s.ok()) { out_tensors->emplace_back(std::move(result_tensor)); diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 1cb533158bb5b8bd4b950192ce67e17c0f9d5447..d37086541dc4714162e00cc6d022b3bd300e3a1c 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -187,12 +187,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { } else { input_impl_.reset(); if (first_call) { - // If the first call to GetNext() fails because the end of - // sequence has been reached, we return an OutOfRange error to - // terminate the iteration. (Otherwise, this iterator would loop - // infinitely and never produce a value.) - return errors::OutOfRange( - "Attempted to repeat an empty dataset infinitely."); + // If the first call to GetNext() fails because the end + // of sequence has been reached, we terminate the + // iteration immediately. (Otherwise, this iterator + // would loop infinitely and never produce a value.) + return Status::OK(); } } } while (true); diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 1dde236c1711afd794ff397859631a48984b5ba8..2f6bf83da5d4f1d4b431e6849fd6571f56539dfe 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -104,13 +104,12 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { break; } if (first_call && dataset()->count_ == -1) { - // If the first call to GetNext() fails because the end of - // sequence has been reached, we return an OutOfRange error to - // terminate the iteration. (Otherwise, this iterator may loop - // infinitely and never produce a value.) + // If the first call to GetNext() fails because the end + // of sequence has been reached, we terminate the + // iteration immediately. (Otherwise, this iterator + // would loop infinitely and never produce a value.) *end_of_sequence = true; - return errors::OutOfRange( - "Attempted to repeat an empty dataset infinitely."); + return Status::OK(); } epoch_++; int64 n = slices_.back()->end; diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index 13c2501bbbd43bdb6c3c521db4c3830934ee91db..d636c37afe2aa0566df7d4a38a8d393c34fd0195 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -128,8 +128,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { while (i_ < dataset()->count_) { // Fetch and throw away Tensors. std::vector dummy_out_tensors; - TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &dummy_out_tensors, - end_of_sequence)); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &dummy_out_tensors, end_of_sequence)); if (*end_of_sequence) { // We reached the end before the count was reached. input_impl_.reset(); @@ -140,8 +140,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { } // Return GetNext() on the underlying iterator. - TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, - end_of_sequence)); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); if (*end_of_sequence) { input_impl_.reset(); } @@ -184,8 +184,7 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { }; }; -REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), - SkipDatasetOp); +REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp); } // namespace diff --git a/tensorflow/core/kernels/data/sql/BUILD b/tensorflow/core/kernels/data/sql/BUILD index 0286825af3ef7c04fff6911ddf7daec76479a715..f4698bdaf7ae9767e068e49dad61d2a3d9f739a8 100644 --- a/tensorflow/core/kernels/data/sql/BUILD +++ b/tensorflow/core/kernels/data/sql/BUILD @@ -33,6 +33,7 @@ cc_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/kernels/data:dataset", "//tensorflow/core/lib/db:sqlite", ], ) diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h index f31017bd1981c3809d9b7daaa2dc56256d19d914..e9ffca202ff32f0c0130427c2699ce0449a0903a 100644 --- a/tensorflow/core/kernels/data/sql/query_connection.h +++ b/tensorflow/core/kernels/data/sql/query_connection.h @@ -19,6 +19,8 @@ limitations under the License. namespace tensorflow { +class IteratorContext; + namespace sql { // This interface allows a user to connect to a database, execute a query, and // iterate over the result set, putting the results into an output tensor. @@ -56,7 +58,7 @@ class QueryConnection { // If there are no more rows in the result set, then instead `true` will be // stored in `*end_of_sequence`, and the content of `*out_tensors` will be // undefined. - virtual Status GetNext(std::vector* out_tensors, + virtual Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) = 0; }; diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc index 029a0aab97290e30783e415274323a1e43f9740b..7cd07bd8eca160bfc62e15adc568742c84711779 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { @@ -48,14 +49,16 @@ Status SqliteQueryConnection::Close() { return Status::OK(); } -Status SqliteQueryConnection::GetNext(std::vector* out_tensors, +Status SqliteQueryConnection::GetNext(IteratorContext* ctx, + std::vector* out_tensors, bool* end_of_sequence) { if (!stmt_) TF_RETURN_IF_ERROR(PrepareQuery()); TF_RETURN_IF_ERROR(stmt_.Step(end_of_sequence)); if (!*end_of_sequence) { for (int i = 0; i < column_count_; i++) { DataType dt = output_types_[i]; - Tensor tensor(cpu_allocator(), dt, {}); + // TODO(mrry): Pass in the `IteratorContext::allocator()`. + Tensor tensor(ctx->allocator({}), dt, {}); FillTensorWithResultSetEntry(dt, i, &tensor); out_tensors->emplace_back(std::move(tensor)); } diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h index 787c17d6c00d99afad3d7814c3c2daaf4295b1b3..81b19530b7d5964e17bde996de9fa7766af318b7 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h @@ -32,7 +32,7 @@ class SqliteQueryConnection : public QueryConnection { Status Open(const string& data_source_name, const string& query, const DataTypeVector& output_types) override; Status Close() override; - Status GetNext(std::vector* out_tensors, + Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override; private: diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc index 72302190802d17f2cb1ed5471017180238aedff3..d50e9c9cf9739044379c7bbe753fc4acc2de311e 100644 --- a/tensorflow/core/kernels/data/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc @@ -116,7 +116,7 @@ class SqlDatasetOp : public DatasetOpKernel { } } - Status GetNextInternal(IteratorContext* /*ctx*/, + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); @@ -132,7 +132,7 @@ class SqlDatasetOp : public DatasetOpKernel { return s; } } - return query_connection_->GetNext(out_tensors, end_of_sequence); + return query_connection_->GetNext(ctx, out_tensors, end_of_sequence); } private: diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff412a4671bd0307e4975027ebd1e098353de238 --- /dev/null +++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc @@ -0,0 +1,646 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/kernels/batch_util.h" +#include "tensorflow/core/kernels/data/dataset.h" + +namespace tensorflow { + +namespace { + +bool IsGreaterEqualToOrCompatibleWith(const PartialTensorShape& a, + const PartialTensorShape& b) { + // Returns true if dims[a] >= dims[b], or are compatible. + if (a.unknown_rank()) return true; + if (a.dims() != b.dims()) return false; + for (int d = 0; d < a.dims(); ++d) { + if (a.dim_size(d) == -1 || b.dim_size(d) == -1) continue; + if (a.dim_size(d) < b.dim_size(d)) return false; + } + return true; +} + +DataTypeVector PrependQueueType(const DataTypeVector& dtypes) { + DataTypeVector out; + out.reserve(dtypes.size() + 1); + out.push_back(DT_VARIANT); // The queue component. + for (const DataType& d : dtypes) out.push_back(d); + return out; +} + +std::vector PrependQueueShapeWithBatch( + const std::vector& shapes) { + std::vector out; + out.reserve(shapes.size() + 1); + out.emplace_back(PartialTensorShape({-1})); // The queue component. + for (PartialTensorShape s : shapes) { + s.InsertDim(0, -1); // Unknown batch size. + out.push_back(std::move(s)); + } + return out; +} + +class EnqueueInQueueDatasetOp; + +class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase { + public: + PrependFromQueueAndPaddedBatchDataset( + OpKernelContext* ctx, const int64 batch_size, const DatasetBase* input, + const DataTypeVector& dtypes, + const std::vector& shapes, + std::vector padding_values) + : GraphDatasetBase(ctx), + batch_size_(batch_size), + input_(input), + dtypes_(dtypes), + shapes_(shapes), + padding_values_(std::move(padding_values)), + dtypes_with_queue_(PrependQueueType(dtypes)), + batched_shapes_with_queue_(PrependQueueShapeWithBatch(shapes)) { + input_->Ref(); + } + + ~PrependFromQueueAndPaddedBatchDataset() override { input_->Unref(); } + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::PrependFromQueueAndPaddedBatch")})); + } + + const DataTypeVector& output_dtypes() const override { + return dtypes_with_queue_; + } + const std::vector& output_shapes() const override { + return batched_shapes_with_queue_; + } + + string DebugString() override { + return "PrependFromQueueAndPaddedBatchDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph)); + Node* batch_size = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); + + std::vector padded_shapes; + padded_shapes.reserve(shapes_.size()); + for (int i = 0; i < shapes_.size(); i++) { + Node* node; + Tensor t(DT_INT64, TensorShape({shapes_[i].dims()})); + for (int j = 0; j < shapes_[i].dims(); j++) { + t.vec()(j) = shapes_[i].dim_size(j); + } + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + padded_shapes.emplace_back(node); + } + + std::vector padding_values; + padding_values.reserve(padding_values_.size()); + for (const Tensor& t : padding_values_) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + padding_values.emplace_back(node); + } + + AttrValue output_types; + b->BuildAttrValue(dtypes_, &output_types); + + AttrValue output_shapes; + b->BuildAttrValue(batched_shapes_with_queue_, &output_shapes); + + AttrValue N; + b->BuildAttrValue(shapes_.size(), &N); + + TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, input_graph}, {1, batch_size}}, + {{2, padded_shapes}, {3, padding_values}}, + {{"Toutput_types", output_types}, + {"output_shapes", output_shapes}, + {"N", N}}, + output)); + + return Status::OK(); + } + + private: + friend class EnqueueInQueueDatasetOp; + + class Iterator + : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + queue_(new TensorQueue(/*input_impl*/ + params.dataset->input_->MakeIterator( + params.prefix), + params.dataset->dtypes_, + params.dataset->shapes_)) {} + + ~Iterator() override { queue_->Unref(); } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + std::vector> batch; + TF_RETURN_IF_ERROR(queue_->GetNext(ctx, dataset()->batch_size_, &batch, + end_of_sequence)); + const auto& dtypes = dataset()->dtypes_; + const auto& shapes = dataset()->shapes_; + const auto& input_shapes = dataset()->input_->output_shapes(); + const auto& padding_values = dataset()->padding_values_; + const int64 batch_size = batch.size(); + out_tensors->reserve(dtypes.size()); + + std::vector max_shapes; // Of non-queue components. + for (int i = 0; i < dtypes.size(); ++i) { + const PartialTensorShape& shape = shapes[i]; + TensorShape out_shape({batch_size}); + for (int r = 0; r < shape.dims(); ++r) { + if (shape.dim_size(r) >= 0) { + // padded_shape[r] is known. + out_shape.AddDim(shape.dim_size(r)); + } else { + // padded_shape[r] is unknown, find the maximum across + // the batch. + int64 dim = 0; + for (int b = 0; b < batch.size(); ++b) { + dim = std::max(dim, batch[b][i].dim_size(r)); + } + out_shape.AddDim(dim); + } + } + max_shapes.push_back(std::move(out_shape)); + } + + Tensor queues_t(cpu_allocator(), DT_VARIANT, TensorShape({batch_size})); + if (!batch.empty()) { + auto queues = queues_t.flat(); + Variant& queue_inserter = queues(0); + queue_inserter = TensorQueueInserter(); + queue_inserter.get()->set_queue(queue_); + for (int b = 1; b < batch.size(); ++b) { + // Copy the TensorQueueInserter. Each copy increments the + // Ref on the queue_. + queues(b) = queues(0); + } + } + out_tensors->push_back(std::move(queues_t)); + + for (int i = 0; i < max_shapes.size(); ++i) { + Tensor component(cpu_allocator(), dtypes[i], max_shapes[i]); + // Try hard to take the fast path. + if (shapes[i].IsFullyDefined() && + shapes[i].IsIdenticalTo(input_shapes[i])) { + // Take the fast path if we know all the shapes statically. + for (int64 b = 0; b < batch.size(); ++b) { + TF_RETURN_IF_ERROR( + batch_util::CopyElementToSlice(batch[b][i], &component, b)); + } + } else { + TF_RETURN_IF_ERROR( + batch_util::SetElementZero(&component, padding_values[i])); + for (int64 b = 0; b < batch.size(); ++b) { + if (batch[b][i].shape() == max_shapes[i]) { + TF_RETURN_IF_ERROR( + batch_util::CopyElementToSlice(batch[b][i], &component, b)); + } else { + TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice( + batch[b][i], &component, b)); + } + } + } + out_tensors->push_back(std::move(component)); + } + + // end_of_sequence was set before we populated out_tensors, so + // it's ok to return now. + return Status::OK(); + } + + protected: + // Work around bug in MSVC that disallows access to protected + // members of Iterator from within TensorQueue. + class TensorQueue; + friend class TensorQueue; + + class TensorQueue : public core::RefCounted { + public: + TensorQueue(std::unique_ptr input_impl, + const DataTypeVector& dtypes, + const std::vector& shapes) + : dtypes_(dtypes), + shapes_(shapes), + input_impl_(std::move(input_impl)) {} + + void MaybeWaitForNotificationLocked(mutex_lock* lock) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // This essentially just releases the lock and immediately relocks. + cv_.wait_for(*lock, std::chrono::milliseconds(0)); + } + + void NotifyLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { cv_.notify_all(); } + + Status GetNext(IteratorContext* ctx, const int64 batch_size, + std::vector>* batch, + bool* end_of_sequence) { + mutex_lock lock(mu_); + + *end_of_sequence = false; + + for (int64 b = 0; b < batch_size;) { + if (!entries_.empty()) { + batch->push_back(std::move(entries_.front())); + entries_.pop_front(); + ++b; + continue; + } else { + if (input_impl_) { + // There's still input coming in. + std::vector tensors; + bool input_end; + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &tensors, &input_end)); + if (!input_end) { + batch->push_back(std::move(tensors)); + ++b; + continue; + } else { + input_impl_.reset(); + } + } + if (!input_impl_) { + // There's no more input coming in. + if (RefCountIsOne()) { + // No TensorQueueInserters in the wild. + if (batch->empty()) { + *end_of_sequence = true; + } + break; + } else { + MaybeWaitForNotificationLocked(&lock); + // If there's data available, try to add entries again. + // Otherwise return a smaller batch and hope the next + // iterator request has a non-empty or unused queue_. + if (entries_.empty()) { + break; + } + } + } + } + } // for (int64 b = ... batch_size) + return Status::OK(); + } + + Status Insert(const std::vector& tensors) { + if (tensors.size() != dtypes_.size()) { + return errors::InvalidArgument( + "TensorQueue::Insert: mismatched number of tensors. Queue " + "expects ", + dtypes_.size(), " tensors but tried to insert ", tensors.size()); + } + for (int i = 0; i < tensors.size(); ++i) { + if (tensors[i].dtype() != dtypes_[i]) { + return errors::InvalidArgument( + "TensorQueue::Insert: mismatched dtypes at component ", i, + ". Attempted " + "to insert tensor of type ", + DataTypeString(tensors[i].dtype()), + " but queue expected type: ", DataTypeString(dtypes_[i])); + } + if (!shapes_[i].IsCompatibleWith(tensors[i].shape())) { + return errors::InvalidArgument( + "TensorQueue::Insert: mismatched shapes at component ", i, + ". Attempted " + "to insert tensor with shape ", + tensors[i].shape().DebugString(), + " but queue expected shape: ", shapes_[i].DebugString()); + } + } + mutex_lock lock(mu_); + entries_.push_back(tensors); + NotifyLocked(); + return Status::OK(); + } + + Status Save(Iterator* iter, IteratorStateWriter* writer) { + mutex_lock lock(mu_); + if (input_impl_) { + TF_RETURN_IF_ERROR(iter->SaveParent(writer, input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(iter->full_name("input_exhausted"), "")); + } + TF_RETURN_IF_ERROR(writer->WriteScalar(iter->full_name("entries_size"), + entries_.size())); + for (int64 b = 0; b < entries_.size(); ++b) { + for (int i = 0; i < dtypes_.size(); ++i) { + TF_RETURN_IF_ERROR( + writer->WriteTensor(strings::StrCat(iter->full_name("entries"), + "[", b, "][", i, "]"), + entries_[b][i])); + } + } + return Status::OK(); + } + + Status Restore(Iterator* iter, IteratorContext* ctx, + IteratorStateReader* reader) { + mutex_lock l(mu_); + if (reader->Contains(iter->full_name("input_exhausted"))) { + input_impl_.reset(); + } else { + input_impl_ = iter->dataset_input()->MakeIterator(iter->prefix()); + TF_RETURN_IF_ERROR(iter->RestoreParent(ctx, reader, input_impl_)); + } + entries_.clear(); + int64 entries_size = -1; + TF_RETURN_IF_ERROR( + reader->ReadScalar(iter->full_name("entries_size"), &entries_size)); + if (entries_size < 0) { + return errors::DataLoss( + "Expected entries_size key '", iter->full_name("entries_size"), + "' to have nonnegative value, but saw: ", entries_size); + } + for (int64 b = 0; b < entries_size; ++b) { + std::vector entry; + for (int i = 0; i < dtypes_.size(); ++i) { + Tensor value; + TF_RETURN_IF_ERROR( + reader->ReadTensor(strings::StrCat(iter->full_name("entries"), + "[", b, "][", i, "]"), + &value)); + entry.push_back(std::move(value)); + } + entries_.push_back(std::move(entry)); + } + return Status::OK(); + } + + mutex* mu() { return &mu_; } + + private: + DataTypeVector dtypes_; + std::vector shapes_; + + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); + std::deque> entries_ GUARDED_BY(mu_); + condition_variable cv_ GUARDED_BY(mu_); + }; + + const DatasetBase* dataset_input() const { return dataset()->input_; } + + Status SaveInternal(IteratorStateWriter* writer) override { + return queue_->Save(this, writer); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return queue_->Restore(this, ctx, reader); + } + + public: + class TensorQueueInserter { + public: + TensorQueueInserter() : queue_(nullptr) {} + + void set_queue(TensorQueue* queue) { + queue_ = queue; + queue_->Ref(); + } + + TensorQueueInserter(const TensorQueueInserter& rhs) { + queue_ = rhs.queue_; + queue_->Ref(); + }; + + TensorQueueInserter(TensorQueueInserter&& rhs) { + queue_ = rhs.queue_; + rhs.queue_ = nullptr; + } + + TensorQueueInserter& operator=(const TensorQueueInserter& rhs) = delete; + + string TypeName() const { return "tensorflow::TensorQueueInserter"; } + string DebugString() const { return TypeName(); } + + void Encode(VariantTensorData*) const {} + bool Decode(const VariantTensorData&) { return false; } + + ~TensorQueueInserter() { + if (queue_) { + mutex_lock lock(*queue_->mu()); + queue_->Unref(); + queue_->NotifyLocked(); + queue_ = nullptr; + } + } + + Status Insert(const std::vector& tensors) const { + CHECK(queue_); + return queue_->Insert(tensors); + } + + private: + mutable TensorQueue* queue_; + }; + + private: + TensorQueue* const queue_; + }; + + private: + const int64 batch_size_; + const DatasetBase* input_; + const DataTypeVector dtypes_; + const std::vector shapes_; + const std::vector padding_values_; + const DataTypeVector dtypes_with_queue_; + const std::vector batched_shapes_with_queue_; +}; + +class PrependFromQueueAndPaddedBatchDatasetOp : public UnaryDatasetOpKernel { + public: + explicit PrependFromQueueAndPaddedBatchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutput_types", &output_types_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + int64 batch_size = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "batch_size", &batch_size)); + OP_REQUIRES( + ctx, batch_size > 0, + errors::InvalidArgument("Batch size must be greater than zero.")); + + OpInputList padded_shape_tensors; + OP_REQUIRES_OK(ctx, + ctx->input_list("padded_shapes", &padded_shape_tensors)); + std::vector padded_shapes; + padded_shapes.reserve(padded_shape_tensors.size()); + OP_REQUIRES(ctx, + padded_shape_tensors.size() == input->output_shapes().size(), + errors::InvalidArgument("Number of padded shapes (", + padded_shape_tensors.size(), + ") must match the number of components " + "in the input dataset's elements (", + input->output_shapes().size(), ")")); + for (const Tensor& padded_shape_t : padded_shape_tensors) { + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(padded_shape_t.shape()), + errors::InvalidArgument("All padded shapes must be vectors")); + PartialTensorShape padded_shape; + OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape( + padded_shape_t.vec().data(), + padded_shape_t.NumElements(), &padded_shape)); + padded_shapes.push_back(std::move(padded_shape)); + } + + OP_REQUIRES( + ctx, input->output_dtypes() == output_types_, + errors::InvalidArgument("Input dataset and this dataset " + "have different output_types: ", + DataTypeVectorString(input->output_dtypes()), + " and ", DataTypeVectorString(output_types_))); + + for (int i = 0; i < input->output_shapes().size(); ++i) { + // Exclude the queue from the tensor_shapes calculation. + const PartialTensorShape& tensor_shape = padded_shapes[i]; + OP_REQUIRES( + ctx, + IsGreaterEqualToOrCompatibleWith(tensor_shape, + input->output_shapes()[i]), + errors::InvalidArgument("Incompatible input shapes at component ", i, + " between input dataset this dataset: ", + input->output_shapes()[i].DebugString(), + " vs. ", tensor_shape.DebugString())); + } + + OpInputList padding_values_list; + OP_REQUIRES_OK(ctx, + ctx->input_list("padding_values", &padding_values_list)); + std::vector padding_values; + OP_REQUIRES(ctx, + padding_values_list.size() == input->output_shapes().size(), + errors::InvalidArgument( + "Number of padding values (", padding_values_list.size(), + ") must match the number of components in the input " + "dataset's elements (", + input->output_shapes().size(), ")")); + for (int i = 0; i < padding_values_list.size(); ++i) { + const Tensor& padding_value_t = padding_values_list[i]; + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(padding_value_t.shape()), + errors::InvalidArgument( + "All padding values must be scalars; but at component ", i, + " saw shape: ", padding_value_t.shape().DebugString())); + OP_REQUIRES(ctx, padding_value_t.dtype() == input->output_dtypes()[i], + errors::InvalidArgument( + "Mismatched type between padding value ", i, + " and input dataset's component ", i, ": ", + DataTypeString(padding_value_t.dtype()), " vs. ", + DataTypeString(input->output_dtypes()[i]))); + padding_values.push_back(padding_value_t); + } + + *output = new PrependFromQueueAndPaddedBatchDataset( + ctx, batch_size, input, output_types_, padded_shapes, + std::move(padding_values)); + } + + private: + DataTypeVector output_types_; +}; + +REGISTER_KERNEL_BUILDER( + Name("PrependFromQueueAndPaddedBatchDataset").Device(DEVICE_CPU), + PrependFromQueueAndPaddedBatchDatasetOp); + +class EnqueueInQueueDatasetOp : public OpKernel { + public: + explicit EnqueueInQueueDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + using TensorQueueInserter = + PrependFromQueueAndPaddedBatchDataset::Iterator::TensorQueueInserter; + + // TODO(ebrevdo): accept list of sequence lengths to do proper + // sub-slicing of tensors for placement into the queue? + const Tensor& tensor_queue_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(tensor_queue_t.shape()), + errors::InvalidArgument("queue must be a vector, saw shape: ", + tensor_queue_t.shape().DebugString())); + std::vector inserters; + const int64 batch_size = tensor_queue_t.NumElements(); + inserters.reserve(batch_size); + const Variant* variants = tensor_queue_t.flat().data(); + for (int i = 0; i < batch_size; ++i) { + const auto* inserter = variants[i].get(); + OP_REQUIRES(ctx, inserter != nullptr, + errors::InvalidArgument( + "Could not access TensorQueueInserter from queue[", i, + "]. Received variant: ", variants[i].DebugString())); + inserters.push_back(inserter); + } + + OpInputList components; + OP_REQUIRES_OK(ctx, ctx->input_list("components", &components)); + for (int i = 0; i < components.size(); ++i) { + OP_REQUIRES( + ctx, + components[i].dims() > 0 && components[i].dim_size(0) == batch_size, + errors::InvalidArgument( + "Expected component ", i, " to have batched shape [", batch_size, + ",...], but saw shape: ", components[i].shape().DebugString())); + } + std::vector element_shapes; + for (int i = 0; i < components.size(); ++i) { + TensorShape element_shape = components[i].shape(); + element_shape.RemoveDim(0); + element_shapes.push_back(std::move(element_shape)); + } + for (int64 b = 0; b < batch_size; ++b) { + std::vector tensors; + tensors.reserve(components.size()); + for (int i = 0; i < components.size(); ++i) { + Tensor t(components[i].dtype(), element_shapes[i]); + OP_REQUIRES_OK(ctx, + batch_util::CopySliceToElement(components[i], &t, b)); + tensors.push_back(std::move(t)); + } + // TODO(ebrevdo): Acquire the lock once for all inserters with + // the same underlying queue? Add InsertLocked? + OP_REQUIRES_OK(ctx, inserters[b]->Insert(tensors)); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("EnqueueInQueueDataset").Device(DEVICE_CPU), + EnqueueInQueueDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index 18adae1ea32316ffd995a95fb25198309fda3361..d5be4c778074e406122dc3a1a9c23681fca491d0 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -117,7 +117,7 @@ class TensorSliceDatasetOp : public DatasetOpKernel { out_tensors->reserve(dataset()->tensors_.size()); for (int i = 0; i < dataset()->tensors_.size(); ++i) { const Tensor& t = dataset()->tensors_[i]; - Tensor t_slice(cpu_allocator(), t.dtype(), + Tensor t_slice(ctx->allocator({}), t.dtype(), TensorShape(dataset()->shapes_[i].dim_sizes())); TF_RETURN_IF_ERROR(batch_util::CopySliceToElement(t, &t_slice, i_)); out_tensors->emplace_back(std::move(t_slice)); diff --git a/tensorflow/core/kernels/debug_ops.cc b/tensorflow/core/kernels/debug_ops.cc index 965a60c7e05297d7aa7125bfcb7eed062af7a058..1b94ea05440516ff458c1785edd27589d18ffe61 100644 --- a/tensorflow/core/kernels/debug_ops.cc +++ b/tensorflow/core/kernels/debug_ops.cc @@ -46,7 +46,7 @@ REGISTER_KERNEL_BUILDER(Name("CopyHost") .HostMemory("input") .HostMemory("output"), CopyOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Register debug identity (non-ref and ref) ops. REGISTER_KERNEL_BUILDER(Name("DebugIdentity").Device(DEVICE_CPU), @@ -66,7 +66,7 @@ REGISTER_KERNEL_BUILDER(Name("DebugIdentity") .HostMemory("input") .HostMemory("output"), DebugIdentityOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Register debug NaN-counter (non-ref and ref) ops. #define REGISTER_DEBUG_NAN_COUNT(type) \ @@ -98,7 +98,7 @@ REGISTER_GPU_DEBUG_NAN_COUNT(double); DebugNanCountOp); REGISTER_GPU_DEBUG_NAN_COUNT(float); REGISTER_GPU_DEBUG_NAN_COUNT(double); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Register debug numeric summary ops. #define REGISTER_DEBUG_NUMERIC_SUMMARY_COUNT(type) \ diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index 381add3fb3bd57ebf068212cdd32a640bf60dd9b..53a23b130609f8b1f4d2dd9f7665d02154f47364 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -21,7 +21,7 @@ limitations under the License. #endif #ifdef TENSORFLOW_USE_SYCL #include "tensorflow/core/common_runtime/sycl/sycl_util.h" -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #include "tensorflow/core/debug/debug_io_utils.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" @@ -91,7 +91,7 @@ class CopyOp : public OpKernel { Device* device = static_cast(context->device()); // Determine if the input tensor is not on CPU (e.g., on GPU). const bool off_host_input = device->device_type() == DEVICE_SYCL && - !context->input_alloc_attr(0).on_host(); + !context->input_alloc_attr(0).on_host(); if (off_host_input) { SYCLmemcpy(context->eigen_sycl_device(), src_tensor, copied_tensor); diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc index c778278e8fbbec67a0255ea7d257c19da4f3612f..b4dcf0a74b336e6173843be233c370d624a9a8e2 100644 --- a/tensorflow/core/kernels/decode_bmp_op.cc +++ b/tensorflow/core/kernels/decode_bmp_op.cc @@ -39,6 +39,13 @@ class DecodeBmpOp : public OpKernel { errors::InvalidArgument("channels must be 0, 1, 3 or 4, got ", channels_)); } + inline int32 ByteSwapInt32ForBigEndian(int32 x) { +#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) + return le32toh(x); +#else + return x; +#endif + } void Compute(OpKernelContext* context) override { const Tensor& contents = context->input(0); @@ -56,14 +63,18 @@ class DecodeBmpOp : public OpKernel { input.size(), " bytes")); const uint8* img_bytes = reinterpret_cast(input.data()); - const int32 header_size = internal::SubtleMustCopy( + int32 header_size_ = internal::SubtleMustCopy( *(reinterpret_cast(img_bytes + 10))); - const int32 width = internal::SubtleMustCopy( + const int32 header_size = ByteSwapInt32ForBigEndian(header_size_); + int32 width_ = internal::SubtleMustCopy( *(reinterpret_cast(img_bytes + 18))); - const int32 height = internal::SubtleMustCopy( + const int32 width = ByteSwapInt32ForBigEndian(width_); + int32 height_ = internal::SubtleMustCopy( *(reinterpret_cast(img_bytes + 22))); - const int32 bpp = internal::SubtleMustCopy( + const int32 height = ByteSwapInt32ForBigEndian(height_); + int32 bpp_ = internal::SubtleMustCopy( *(reinterpret_cast(img_bytes + 28))); + const int32 bpp = ByteSwapInt32ForBigEndian(bpp_); if (channels_) { OP_REQUIRES(context, (channels_ == bpp / 8), @@ -80,15 +91,32 @@ class DecodeBmpOp : public OpKernel { errors::InvalidArgument( "Number of channels must be 1, 3 or 4, was ", channels_)); + OP_REQUIRES(context, width > 0 && header_size >= 0, + errors::InvalidArgument("Width must be positive")); + OP_REQUIRES(context, header_size >= 0, + errors::InvalidArgument("header size must be nonnegative")); + + // The real requirement is < 2^31 minus some headers and channel data, + // so rounding down to something that's still ridiculously big. + OP_REQUIRES( + context, + (static_cast(width) * std::abs(static_cast(height))) < + static_cast(std::numeric_limits::max() / 8), + errors::InvalidArgument( + "Total possible pixel bytes must be less than 2^30")); + + const int32 abs_height = abs(height); + // there may be padding bytes when the width is not a multiple of 4 bytes // 8 * channels == bits per pixel const int row_size = (8 * channels_ * width + 31) / 32 * 4; - const int last_pixel_offset = - header_size + (abs(height) - 1) * row_size + (width - 1) * channels_; + const int64 last_pixel_offset = static_cast(header_size) + + (abs_height - 1) * row_size + + (width - 1) * channels_; // [expected file size] = [last pixel offset] + [last pixel size=channels] - const int expected_file_size = last_pixel_offset + channels_; + const int64 expected_file_size = last_pixel_offset + channels_; OP_REQUIRES( context, (expected_file_size <= input.size()), @@ -104,12 +132,12 @@ class DecodeBmpOp : public OpKernel { Tensor* output = nullptr; OP_REQUIRES_OK( context, context->allocate_output( - 0, TensorShape({abs(height), width, channels_}), &output)); + 0, TensorShape({abs_height, width, channels_}), &output)); const uint8* bmp_pixels = &img_bytes[header_size]; Decode(bmp_pixels, row_size, output->flat().data(), width, - abs(height), channels_, top_down); + abs_height, channels_, top_down); } uint8* Decode(const uint8* input, const int row_size, uint8* const output, diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc index c4555db453ba1549601cbf9a4bbf096fc3db22b2..0c42f632521dd86760e791626c8978c0b1e82709 100644 --- a/tensorflow/core/kernels/decode_csv_op.cc +++ b/tensorflow/core/kernels/decode_csv_op.cc @@ -91,9 +91,9 @@ class DecodeCSVOp : public OpKernel { } else { int32 value; OP_REQUIRES(ctx, strings::safe_strto32(fields[f], &value), - errors::InvalidArgument("Field ", f, " in record ", i, - " is not a valid int32: ", - fields[f])); + errors::InvalidArgument( + "Field ", f, " in record ", i, + " is not a valid int32: ", fields[f])); output[f]->flat()(i) = value; } break; @@ -111,9 +111,9 @@ class DecodeCSVOp : public OpKernel { } else { int64 value; OP_REQUIRES(ctx, strings::safe_strto64(fields[f], &value), - errors::InvalidArgument("Field ", f, " in record ", i, - " is not a valid int64: ", - fields[f])); + errors::InvalidArgument( + "Field ", f, " in record ", i, + " is not a valid int64: ", fields[f])); output[f]->flat()(i) = value; } break; @@ -130,9 +130,9 @@ class DecodeCSVOp : public OpKernel { } else { float value; OP_REQUIRES(ctx, strings::safe_strtof(fields[f].c_str(), &value), - errors::InvalidArgument("Field ", f, " in record ", i, - " is not a valid float: ", - fields[f])); + errors::InvalidArgument( + "Field ", f, " in record ", i, + " is not a valid float: ", fields[f])); output[f]->flat()(i) = value; } break; @@ -150,9 +150,9 @@ class DecodeCSVOp : public OpKernel { } else { double value; OP_REQUIRES(ctx, strings::safe_strtod(fields[f].c_str(), &value), - errors::InvalidArgument("Field ", f, " in record ", i, - " is not a valid double: ", - fields[f])); + errors::InvalidArgument( + "Field ", f, " in record ", i, + " is not a valid double: ", fields[f])); output[f]->flat()(i) = value; } break; @@ -208,9 +208,10 @@ class DecodeCSVOp : public OpKernel { if (!quoted) { while (static_cast(current_idx) < input.size() && input[current_idx] != delim_) { - OP_REQUIRES(ctx, (!use_quote_delim_ || input[current_idx] != '"') && - input[current_idx] != '\n' && - input[current_idx] != '\r', + OP_REQUIRES(ctx, + (!use_quote_delim_ || input[current_idx] != '"') && + input[current_idx] != '\n' && + input[current_idx] != '\r', errors::InvalidArgument( "Unquoted fields cannot have quotes/CRLFs inside")); field += input[current_idx]; @@ -238,10 +239,11 @@ class DecodeCSVOp : public OpKernel { } OP_REQUIRES( - ctx, (static_cast(current_idx) < input.size() && - input[current_idx] == '"' && - (static_cast(current_idx) == input.size() - 1 || - input[current_idx + 1] == delim_)), + ctx, + (static_cast(current_idx) < input.size() && + input[current_idx] == '"' && + (static_cast(current_idx) == input.size() - 1 || + input[current_idx + 1] == delim_)), errors::InvalidArgument("Quoted field has to end with quote " "followed by delim or end")); diff --git a/tensorflow/core/kernels/decode_image_op.cc b/tensorflow/core/kernels/decode_image_op.cc index 44dcbf834ce838e3b25957f88bfcded645104957..912d04c1536600348e8263f03709f2305607d11f 100644 --- a/tensorflow/core/kernels/decode_image_op.cc +++ b/tensorflow/core/kernels/decode_image_op.cc @@ -87,10 +87,11 @@ class DecodeImageOp : public OpKernel { channels_ = 3; } else { OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_)); - OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3 || - channels_ == 4, - errors::InvalidArgument( - "channels must be 0, 1, 3, or 4, got ", channels_)); + OP_REQUIRES( + context, + channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4, + errors::InvalidArgument("channels must be 0, 1, 3, or 4, got ", + channels_)); } flags_.components = channels_; @@ -114,8 +115,9 @@ class DecodeImageOp : public OpKernel { if (format_ == kJpgFormat) { OP_REQUIRES_OK(context, context->GetAttr("ratio", &flags_.ratio)); - OP_REQUIRES(context, flags_.ratio == 1 || flags_.ratio == 2 || - flags_.ratio == 4 || flags_.ratio == 8, + OP_REQUIRES(context, + flags_.ratio == 1 || flags_.ratio == 2 || flags_.ratio == 4 || + flags_.ratio == 8, errors::InvalidArgument("ratio must be 1, 2, 4, or 8, got ", flags_.ratio)); OP_REQUIRES_OK(context, context->GetAttr("fancy_upscaling", @@ -130,8 +132,9 @@ class DecodeImageOp : public OpKernel { string dct_method; OP_REQUIRES_OK(context, context->GetAttr("dct_method", &dct_method)); OP_REQUIRES( - context, (dct_method.empty() || dct_method == "INTEGER_FAST" || - dct_method == "INTEGER_ACCURATE"), + context, + (dct_method.empty() || dct_method == "INTEGER_FAST" || + dct_method == "INTEGER_ACCURATE"), errors::InvalidArgument("dct_method must be one of " "{'', 'INTEGER_FAST', 'INTEGER_ACCURATE'}")); if (dct_method == "INTEGER_FAST") { @@ -157,9 +160,9 @@ class DecodeImageOp : public OpKernel { errors::InvalidArgument("Expected image (JPEG, PNG, or GIF), got ", FileFormatString(magic, input))); OP_REQUIRES(context, input.size() <= std::numeric_limits::max(), - errors::InvalidArgument(FileFormatString(magic, input), - " contents are too large for int: ", - input.size())); + errors::InvalidArgument( + FileFormatString(magic, input), + " contents are too large for int: ", input.size())); OP_REQUIRES(context, magic == kPngFormat || channel_bits_ == 8, errors::InvalidArgument(FileFormatString(magic, input), " does not support uint16 output")); @@ -212,9 +215,10 @@ class DecodeImageOp : public OpKernel { input.data(), input.size(), flags, nullptr /* nwarn */, [=, &output](int width, int height, int channels) -> uint8* { Status status(context->allocate_output( - 0, format_ == kGifFormat - ? TensorShape({1, height, width, channels}) - : TensorShape({height, width, channels}), + 0, + format_ == kGifFormat + ? TensorShape({1, height, width, channels}) + : TensorShape({height, width, channels}), &output)); if (!status.ok()) { VLOG(1) << status; diff --git a/tensorflow/core/kernels/deep_conv2d.cc b/tensorflow/core/kernels/deep_conv2d.cc index 8e9b8a7e2e7be8e55deeacd4de3f77033499387f..829155fb313bd354d28432be6212af0760630c44 100644 --- a/tensorflow/core/kernels/deep_conv2d.cc +++ b/tensorflow/core/kernels/deep_conv2d.cc @@ -120,9 +120,9 @@ bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows, VLOG(2) << "CanUseDeepConv2D" << " deep_conv_cost: " << deep_conv_cost - << " direct_conv_cost: " << direct_conv_cost - << " deep_direct_ratio: " << (static_cast(deep_conv_cost) / - static_cast(direct_conv_cost)) + << " direct_conv_cost: " << direct_conv_cost << " deep_direct_ratio: " + << (static_cast(deep_conv_cost) / + static_cast(direct_conv_cost)) << " use_deep_conv: " << (deep_conv_cost < direct_conv_cost); return deep_conv_cost < direct_conv_cost; } diff --git a/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc b/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc index c9c97dc072c93e3ab840a8a9c9d81eadd2adaa3c..9a3b2303a3bf6718009b5055c4ef25464ec01136 100644 --- a/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc @@ -57,6 +57,7 @@ struct DenseUpdate { template struct functor::DenseUpdate; \ template struct functor::DenseUpdate; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); +TF_CALL_int64(DEFINE_GPU_KERNELS); #undef DEFINE_GPU_KERNELS #define DEFINE_GPU_KERNELS(T) \ diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc index 6d44a92fa3c2d22ade6293d30b4f008a62eb8e0f..0de97de20523ad54c08aa7b4190438c1da6ebde7 100644 --- a/tensorflow/core/kernels/dense_update_ops.cc +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -89,7 +89,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #define REGISTER_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ @@ -109,18 +109,19 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); AssignOpT); TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNELS(type) \ -REGISTER_KERNEL_BUILDER( \ - Name("Assign").Device(DEVICE_SYCL).TypeConstraint("T"), \ - AssignOpT); +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Assign").Device(DEVICE_SYCL).TypeConstraint("T"), \ + AssignOpT); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); #undef REGISTER_SYCL_KERNELS -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #define REGISTER_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ @@ -142,11 +143,12 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); Name("AssignSub").Device(DEVICE_GPU).TypeConstraint("T"), \ DenseUpdateOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // end GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNELS(type) \ +#define REGISTER_SYCL_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint("T"), \ DenseUpdateOp); \ @@ -156,5 +158,5 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); #undef REGISTER_SYCL_KERNELS -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc index 9347978d515b9244dde2b50b2fcfaa3c91ab9c94..91a9587174be4c047f8a21ea9222219def42d5f1 100644 --- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc @@ -400,7 +400,7 @@ struct LaunchDepthwiseConvBackpropInputOp { // Computes one shard of depthwise conv2d backprop input. auto shard = [&ctx, &args, &out_backprop, &filter_data, &in_backprop]( - int64 start, int64 limit) { + int64 start, int64 limit) { static const int64 kPacketSize = (sizeof(Packet) / sizeof(T)); const int64 input_image_size = @@ -750,7 +750,7 @@ struct LaunchDepthwiseConvBackpropFilterOp { // Computes one shard of depthwise conv2d backprop filter. auto shard = [&ctx, &args, &out_backprop, &input, &output_buffer_data]( - int64 start, int64 limit) { + int64 start, int64 limit) { static const int64 kPacketSize = (sizeof(Packet) / sizeof(T)); const int64 filter_spatial_size = args.filter_rows * args.filter_cols; const int64 padded_out_depth_size = diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index a5fd07fbe177f2206ef9b6b3252556211b9e3905..c060b2e14d2f03f990af5267260bd88fa01a2c81 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -308,10 +308,10 @@ class DepthwiseConv2dNativeOp : public BinaryOp { // in_depth for input and filter must match. const int64 in_depth = GetTensorDim(input, data_format_, 'C'); - OP_REQUIRES( - context, in_depth == filter.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - in_depth, " vs ", filter.dim_size(2))); + OP_REQUIRES(context, in_depth == filter.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, + " vs ", filter.dim_size(2))); // The last dimension for filter is depth multiplier. const int32 depth_multiplier = filter.dim_size(3); @@ -430,9 +430,10 @@ TF_CALL_double(REGISTER_CPU_KERNEL); #endif #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER( - Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint("T"), - DepthwiseConv2dNativeOp); +REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative") + .Device(DEVICE_GPU) + .TypeConstraint("T"), + DepthwiseConv2dNativeOp); REGISTER_KERNEL_BUILDER( Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint("T"), diff --git a/tensorflow/core/kernels/depthwise_conv_op.h b/tensorflow/core/kernels/depthwise_conv_op.h index ba262d56eef62eed3abf23b34da2a3c4727795d4..b2d5898891370321f7e97f19f2382eb1d55985f7 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.h +++ b/tensorflow/core/kernels/depthwise_conv_op.h @@ -83,7 +83,7 @@ struct LaunchDepthwiseConvBackpropFilterOp { #if GOOGLE_CUDA template struct LaunchDepthwiseConvOp { - void operator()(OpKernelContext* ctx, const DepthwiseArgs args, + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, const T* input, const T* filter, T* output, TensorFormat data_format); }; diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index 903aac5d68baeb8c37b009a54863a084dcb75147..94989089ec9cdf9314860b43f67691f39f33c31f 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -17,19 +17,19 @@ limitations under the License. #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "external/cub_archive/cub/util_ptx.cuh" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/kernels/depthwise_conv_op.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/cuda_kernel_helper.h" #include "tensorflow/core/util/tensor_format.h" -#include "external/cub_archive/cub/util_ptx.cuh" -#if !defined(_MSC_VER) -#define UNROLL _Pragma("unroll") -#define NOUNROLL _Pragma("nounroll") -#else +#if defined(_MSC_VER) && !defined(__clang__) #define UNROLL #define NOUNROLL +#else +#define UNROLL _Pragma("unroll") +#define NOUNROLL _Pragma("nounroll") #endif namespace tensorflow { @@ -39,7 +39,7 @@ using Eigen::GpuDevice; // Returns whether depthwise convolution forward or backward input pass can be // performed using the faster ('Small') variant of the kernel. EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall( - const DepthwiseArgs args) { + const DepthwiseArgs& args) { return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 && args.in_cols <= 32 && args.in_rows == args.out_rows && args.in_cols == args.out_cols && args.pad_rows >= 0 && @@ -52,13 +52,13 @@ EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall( // Returns whether depthwise convolution backward filter pass can be performed // using the faster ('Small') variant of the kernel. EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const DepthwiseArgs args, const int block_rows) { + const DepthwiseArgs& args, const int block_height) { return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 && args.in_cols <= 32 && args.in_rows == args.out_rows && args.in_cols == args.out_cols && args.pad_rows >= 0 && args.pad_rows < args.filter_rows && args.pad_cols >= 0 && - args.pad_cols < args.filter_cols && block_rows <= args.in_rows && - args.filter_rows * args.filter_cols <= args.in_cols * block_rows; + args.pad_cols < args.filter_cols && block_height <= args.in_rows && + args.filter_rows * args.filter_cols <= args.in_cols * block_height; } // The DepthwiseConv2dGPUKernels perform either forward or backprop input @@ -72,72 +72,81 @@ template (0); - const int input_offset_temp = in_rows * OB; + const int input_offset_temp = in_height * batch; if (input_row_start >= 0 && input_col_start >= 0 && - input_row_end < in_rows && input_col_end < in_cols) { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = input_row_start + f_r; - const int filter_offset_temp = filter_cols * f_r; - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = input_col_start + f_c; + input_row_end < in_height && input_col_end < in_width) { + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; const int input_offset = - in_d + in_depth * (in_c + in_cols * (in_r + input_offset_temp)); + in_channel + + in_depth * (in_col + in_width * (in_row + input_offset_temp)); const int filter_offset = multiplier + - depth_multiplier * (in_d + in_depth * (f_c + filter_offset_temp)); + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); sum += ldg(input + input_offset) * ldg(filter + filter_offset); } } } else { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = input_row_start + f_r; - const int filter_offset_temp = filter_cols * f_r; - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = input_col_start + f_c; - if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols) { - const int in_c = input_col_start + f_c; + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { + const int in_col = input_col_start + filter_col; const int input_offset = - in_d + in_depth * (in_c + in_cols * (in_r + input_offset_temp)); + in_channel + + in_depth * (in_col + in_width * (in_row + input_offset_temp)); const int filter_offset = - multiplier + depth_multiplier * - (in_d + in_depth * (f_c + filter_offset_temp)); + multiplier + + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); sum += ldg(input + input_offset) * ldg(filter + filter_offset); } } @@ -157,8 +166,8 @@ __global__ void __launch_bounds__(1024, 2) // Backprop input direction is the same as forward direction with the filter // rotated by 180°. template + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth, + bool kKnownEvenHeight> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( const DepthwiseArgs args, const T* input, const T* filter, T* output) { assert(CanLaunchDepthwiseConv2dGPUSmall(args)); @@ -166,45 +175,47 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[]; T* const shared_data = reinterpret_cast(shared_memory); - const int batches = args.batch; - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; - const int block_rows = blockDim.z; + assert(blockDim.x == kBlockDepth); + assert(blockDim.y == args.in_cols); + const int block_height = blockDim.z; // These values are the same for all threads and could // be precomputed on the CPU. - const int block_size = block_rows * in_cols * kBlockSlices; - const int in_row_size = in_cols * in_depth; - const int in_size = in_rows * in_row_size; - const int in_increment = (in_cols - 1) * kBlockSlices; - const int filter_pixels = filter_rows * filter_cols; - const int tile_cols = in_cols + filter_cols - 1; - const int even_rows = kKnownEvenRows || (1 & ~in_rows); - const int tile_rows = in_rows + filter_rows - even_rows; - const int tile_row_size = tile_cols * kBlockSlices; - const int tile_size = tile_rows * tile_row_size; - const int tile_offset = block_rows * tile_row_size; - const int pad_offset = pad_rows * tile_cols + pad_cols; - const int batch_blocks = (in_depth + kBlockSlices - 1) / kBlockSlices; - const int in_blocks = batch_blocks * batches; + const int block_size = block_height * in_width * kBlockDepth; + const int in_row_size = in_width * in_depth; + const int in_size = in_height * in_row_size; + const int in_increment = (in_width - 1) * kBlockDepth; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int even_height = kKnownEvenHeight || (1 & ~in_height); + const int tile_height = in_height + filter_height - even_height; + const int tile_row_size = tile_width * kBlockDepth; + const int tile_size = tile_height * tile_row_size; + const int tile_offset = block_height * tile_row_size; + const int pad_offset = pad_height * tile_width + pad_width; + const int batch_blocks = (in_depth + kBlockDepth - 1) / kBlockDepth; + const int in_blocks = batch_blocks * num_batches; const int tensor_offset = - kKnownEvenRows ? in_size / 2 : block_rows * in_row_size; + kKnownEvenHeight ? in_size / 2 : block_height * in_row_size; const int thread_depth = threadIdx.x; const int thread_col = threadIdx.y; const int thread_row = threadIdx.z; // Position in block. - const int thread_pix = thread_row * in_cols + thread_col; - const int thread_idx = thread_pix * kBlockSlices + thread_depth; + const int thread_pix = thread_row * in_width + thread_col; + const int thread_idx = thread_pix * kBlockDepth + thread_depth; // Initialize tile, in particular the padding. for (int i = thread_idx; i < tile_size; i += block_size) { @@ -216,32 +227,32 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( const int tensor_idx = thread_pix * in_depth + thread_depth; // Position in (padded) shared memory. - const int data_pix = thread_row * tile_cols + thread_col; - const int data_idx = data_pix * kBlockSlices + thread_depth; + const int data_pix = thread_row * tile_width + thread_col; + const int data_idx = data_pix * kBlockDepth + thread_depth; - // Position in shared memory, offset by pad_rows / pad_cols. + // Position in shared memory, offset by pad_height / pad_width. const int tile_pix = data_pix + pad_offset; - const int tile_idx = tile_pix * kBlockSlices + thread_depth; + const int tile_idx = tile_pix * kBlockDepth + thread_depth; - const int max_depth = in_depth - thread_depth; + const int max_channel = in_depth - thread_depth; const int filter_write_offset = thread_pix < filter_pixels ? tile_size + thread_idx : 0; const int filter_read_offset = tile_size + thread_depth + - (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockSlices); + (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockDepth); const bool skip_second = - !kKnownEvenRows && thread_row + (in_rows & 1) == block_rows; + !kKnownEvenHeight && thread_row + (in_height & 1) == block_height; for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { const int batch = b / batch_blocks; - const int stack = b - batch * batch_blocks; + const int block = b - batch * batch_blocks; - const int start_depth = stack * kBlockSlices; - const int filter_offset = tensor_idx + start_depth; + const int start_channel = block * kBlockDepth; + const int filter_offset = tensor_idx + start_channel; const int inout_offset = batch * in_size + filter_offset; - const bool depth_in_range = start_depth < max_depth; + const bool channel_in_range = start_channel < max_channel; - if (depth_in_range) { + if (channel_in_range) { const T* const in_ptr = inout_offset + input; T* const tile_ptr = tile_idx + shared_data; tile_ptr[0] = ldg(in_ptr); @@ -257,23 +268,23 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( // Note: the condition to reach this is uniform across the entire block. __syncthreads(); - if (depth_in_range) { + if (channel_in_range) { T sum1 = static_cast(0); T sum2 = static_cast(0); int shared_offset = data_idx; const T* filter_ptr = filter_read_offset + shared_data; - UNROLL for (int r = 0; r < filter_rows; ++r) { - UNROLL for (int c = 0; c < filter_cols; ++c) { + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { if (kDirection == DIRECTION_BACKWARD) { - filter_ptr -= kBlockSlices; + filter_ptr -= kBlockDepth; } const T filter_value = *filter_ptr; const T* const tile_ptr = shared_offset + shared_data; sum1 += filter_value * tile_ptr[0]; sum2 += filter_value * tile_ptr[tile_offset]; - shared_offset += kBlockSlices; + shared_offset += kBlockDepth; if (kDirection == DIRECTION_FORWARD) { - filter_ptr += kBlockSlices; + filter_ptr += kBlockDepth; } } shared_offset += in_increment; @@ -297,20 +308,20 @@ template (0); if (input_row_start >= 0 && input_col_start >= 0 && - input_row_end < in_rows && input_col_end < in_cols) { + input_row_end < in_height && input_col_end < in_width) { // Loop that doesn't need to check for boundary conditions. - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = input_row_start + f_r; - const int filter_offset_temp = filter_cols * f_r; - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = input_col_start + f_c; + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; const int input_offset = - (input_offset_temp) + (in_r * in_cols) + in_c; + (input_offset_temp) + (in_row * in_width) + in_col; const int filter_offset = multiplier + - depth_multiplier * (in_d + in_depth * (f_c + filter_offset_temp)); + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); sum += ldg(input + input_offset) * ldg(filter + filter_offset); } } } else { // Loop that needs to check for boundary conditions. - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = input_row_start + f_r; - const int filter_offset_temp = filter_cols * f_r; - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = input_col_start + f_c; - // TODO(vrv): the in_r check can be done outside of this loop; + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; + // TODO(vrv): the in_row check can be done outside of this loop; // benchmark both methods to determine the better decision. - if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols) { - const int in_c = input_col_start + f_c; + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { + const int in_col = input_col_start + filter_col; // input_offset_temp indexes into the start of memory // where the spatial data starts. const int input_offset = - (input_offset_temp) + (in_r * in_cols) + in_c; + (input_offset_temp) + (in_row * in_width) + in_col; const int filter_offset = - multiplier + depth_multiplier * - (in_d + in_depth * (f_c + filter_offset_temp)); + multiplier + + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); sum += ldg(input + input_offset) * ldg(filter + filter_offset); } } @@ -427,8 +446,8 @@ __global__ void __launch_bounds__(1024, 2) // Backprop input direction is the same as forward direction with the filter // rotated by 180°. template + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth, + bool kKnownEvenHeight> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( const DepthwiseArgs args, const T* input, const T* filter, T* output) { assert(CanLaunchDepthwiseConv2dGPUSmall(args)); @@ -436,43 +455,45 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[]; T* const shared_data = reinterpret_cast(shared_memory); - const int batches = args.batch; - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; // Fixed blockDim.z, tailored for maximum grid size for images of size 16x16. - const int block_rows = blockDim.y; + assert(blockDim.x == args.in_cols); + assert(blockDim.z == kBlockDepth); + const int block_height = blockDim.y; // These values are the same for all threads and could // be precomputed on the CPU. - const int block_pixels = in_cols * block_rows; - const int block_size = block_pixels * kBlockSlices; - const int in_pixels = in_cols * in_rows; - const int in_increment = in_cols - 1; - const int filter_pixels = filter_rows * filter_cols; - const int tile_cols = in_cols + filter_cols - 1; - const int even_rows = kKnownEvenRows || (1 & ~in_rows); - const int tile_rows = in_rows + filter_rows - even_rows; - const int tile_pixels = tile_cols * tile_rows; - const int tile_size = tile_pixels * kBlockSlices; - const int tile_offset = block_rows * tile_cols; - const int pad_offset = pad_rows * tile_cols + pad_cols; - const int in_slices = in_depth * batches; - const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices; + const int block_pixels = in_width * block_height; + const int block_size = block_pixels * kBlockDepth; + const int in_pixels = in_width * in_height; + const int in_increment = in_width - 1; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int even_height = kKnownEvenHeight || (1 & ~in_height); + const int tile_height = in_height + filter_height - even_height; + const int tile_pixels = tile_width * tile_height; + const int tile_size = tile_pixels * kBlockDepth; + const int tile_offset = block_height * tile_width; + const int pad_offset = pad_height * tile_width + pad_width; + const int in_total_depth = in_depth * num_batches; + const int in_blocks = (in_total_depth + kBlockDepth - 1) / kBlockDepth; const int thread_col = threadIdx.x; const int thread_row = threadIdx.y; const int thread_depth = threadIdx.z; // Position in block. - const int thread_pix = thread_row * in_cols + thread_col; + const int thread_pix = thread_row * in_width + thread_col; const int thread_idx = thread_depth * block_pixels + thread_pix; // Initialize tile, in particular the padding. @@ -485,33 +506,33 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( const int tensor_idx = thread_depth * in_pixels + thread_pix; // Position in (padded) shared memory. - const int data_pix = thread_row * tile_cols + thread_col; + const int data_pix = thread_row * tile_width + thread_col; const int data_idx = thread_depth * tile_pixels + data_pix; - // Position in shared memory, offset by pad_rows / pad_cols. + // Position in shared memory, offset by pad_height / pad_width. const int tile_idx = data_idx + pad_offset; // Filter is always in HWCK format, irrespective of the input/output format. - const int filter_pix = thread_idx / kBlockSlices; - const int filter_depth = thread_idx % kBlockSlices; + const int filter_pix = thread_idx / kBlockDepth; + const int filter_channel = thread_idx % kBlockDepth; const int filter_idx = filter_pix * in_depth; - const int max_slice = in_slices - thread_depth; + const int max_channel = in_total_depth - thread_depth; const int filter_write_offset = filter_pix < filter_pixels ? tile_size + thread_idx : 0; const int filter_read_offset = tile_size + thread_depth + - (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockSlices); + (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockDepth); const bool skip_second = - !kKnownEvenRows && thread_row + (in_rows & 1) == block_rows; + !kKnownEvenHeight && thread_row + (in_height & 1) == block_height; for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { - const int slice = b * kBlockSlices; + const int channel = b * kBlockDepth; - const int inout_offset = slice * in_pixels + tensor_idx; - const bool slice_in_range = slice < max_slice; + const int inout_offset = channel * in_pixels + tensor_idx; + const bool channel_in_range = channel < max_channel; - if (slice_in_range) { + if (channel_in_range) { const T* const in_ptr = inout_offset + input; T* const tile_ptr = tile_idx + shared_data; tile_ptr[0] = ldg(in_ptr); @@ -521,22 +542,23 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( } if (filter_write_offset != 0) { - const int filter_offset = filter_idx + (slice + filter_depth) % in_depth; + const int filter_offset = + filter_idx + (channel + filter_channel) % in_depth; shared_data[filter_write_offset] = ldg(filter_offset + filter); } // Note: the condition to reach this is uniform across the entire block. __syncthreads(); - if (slice_in_range) { + if (channel_in_range) { T sum1 = static_cast(0); T sum2 = static_cast(0); int shared_offset = data_idx; const T* filter_ptr = filter_read_offset + shared_data; - UNROLL for (int r = 0; r < filter_rows; ++r) { - UNROLL for (int c = 0; c < filter_cols; ++c) { + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { if (kDirection == DIRECTION_BACKWARD) { - filter_ptr -= kBlockSlices; + filter_ptr -= kBlockDepth; } const T filter_value = *filter_ptr; const T* const tile_ptr = shared_offset + shared_data; @@ -544,7 +566,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( sum2 += filter_value * tile_ptr[tile_offset]; ++shared_offset; if (kDirection == DIRECTION_FORWARD) { - filter_ptr += kBlockSlices; + filter_ptr += kBlockDepth; } } shared_offset += in_increment; @@ -562,146 +584,164 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( } template -void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, - const T* input, const T* filter, T* output, + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth, + bool kKnownEvenHeight> +void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, TensorFormat data_format) { - const int block_rows = (args.in_rows + 1) / 2; + const int block_height = (args.in_rows + 1) / 2; dim3 block_dim; + int block_count; void (*kernel)(const DepthwiseArgs, const T*, const T*, T*); - if (data_format == FORMAT_NHWC) { - block_dim = dim3(kBlockSlices, args.in_cols, block_rows); - kernel = DepthwiseConv2dGPUKernelNHWCSmall; - } else if (data_format == FORMAT_NCHW) { - block_dim = dim3(args.in_cols, block_rows, kBlockSlices); - kernel = DepthwiseConv2dGPUKernelNCHWSmall; - } else { - assert(false && "Incorrect data format"); - return; + switch (data_format) { + case FORMAT_NHWC: + block_dim = dim3(kBlockDepth, args.in_cols, block_height); + block_count = + args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth; + kernel = + DepthwiseConv2dGPUKernelNHWCSmall; + break; + case FORMAT_NCHW: + block_dim = dim3(args.in_cols, block_height, kBlockDepth); + block_count = + DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth; + kernel = + DepthwiseConv2dGPUKernelNCHWSmall; + break; + case FORMAT_NCHW_VECT_C: + LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + return; } - const int tile_cols = args.in_cols + args.filter_cols - 1; - const int tile_rows = block_rows * 2 + args.filter_rows - 1; - const int tile_pixels = tile_rows * tile_cols; + const int tile_width = args.in_cols + args.filter_cols - 1; + const int tile_height = block_height * 2 + args.filter_rows - 1; + const int tile_pixels = tile_height * tile_width; const int filter_pixels = args.filter_rows * args.filter_cols; const int shared_memory_size = - kBlockSlices * (tile_pixels + filter_pixels) * sizeof(T); - const int num_outputs = - args.batch * args.out_rows * args.out_cols * args.out_depth; - CudaLaunchConfig config = - GetCudaLaunchConfig(num_outputs, d, kernel, shared_memory_size, - block_dim.x * block_dim.y * block_dim.z); - kernel<<>>( - args, input, filter, output); + kBlockDepth * (tile_pixels + filter_pixels) * sizeof(T); + const int num_outputs = args.out_rows * args.out_cols * block_count; + CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( + num_outputs, device, kernel, shared_memory_size, + block_dim.x * block_dim.y * block_dim.z); + kernel<<>>(args, input, filter, output); } template -void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, - const T* input, const T* filter, T* output, + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth> +void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, TensorFormat data_format) { if (args.in_rows & 1) { LaunchDepthwiseConv2dGPUSmall( - d, args, input, filter, output, data_format); + kKnownFilterHeight, kBlockDepth, false>( + device, args, input, filter, output, data_format); } else { LaunchDepthwiseConv2dGPUSmall( - d, args, input, filter, output, data_format); + kKnownFilterHeight, kBlockDepth, true>( + device, args, input, filter, output, data_format); } } template -void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, - const T* input, const T* filter, T* output, +void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, TensorFormat data_format) { - // Maximize (power of two) kBlockSlices while keeping a block within 1024 + // Maximize (power of two) kBlockDepth while keeping a block within 1024 // threads (2 pixels per thread). const int block_pixels = (args.in_rows + 1) / 2 * args.in_cols; if (block_pixels > 256) { LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, - output, data_format); + kKnownFilterHeight, 2>( + device, args, input, filter, output, data_format); } else if (block_pixels > 128) { LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, - output, data_format); + kKnownFilterHeight, 4>( + device, args, input, filter, output, data_format); } else { LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, - output, data_format); + kKnownFilterHeight, 8>( + device, args, input, filter, output, data_format); } } template -void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args, - const T* input, const T* filter, T* output, +void LaunchDepthwiseConv2dGPU(const GpuDevice& device, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, TensorFormat data_format) { void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); - if (data_format == FORMAT_NHWC) { - kernel = - DepthwiseConv2dGPUKernelNHWC; - } else if (data_format == FORMAT_NCHW) { - kernel = - DepthwiseConv2dGPUKernelNCHW; - } else { - assert(false && "Incorrect data format"); - return; + switch (data_format) { + case FORMAT_NHWC: + kernel = + DepthwiseConv2dGPUKernelNHWC; + break; + case FORMAT_NCHW: + kernel = + DepthwiseConv2dGPUKernelNCHW; + break; + case FORMAT_NCHW_VECT_C: + LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + return; } const int num_outputs = args.batch * args.out_rows * args.out_cols * args.out_depth; - CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d, kernel, 0, 0); + CudaLaunchConfig config = + GetCudaLaunchConfig(num_outputs, device, kernel, 0, 0); // The compile-time constant version runs faster with a single block. const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 || kKnownDepthMultiplier < 0 ? std::numeric_limits::max() - : d.getNumCudaMultiProcessors(); + : device.getNumCudaMultiProcessors(); kernel<<>>(args, input, filter, - output, num_outputs); + config.thread_per_block, 0, device.stream()>>>(args, input, filter, + output, num_outputs); } template -void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args, - const T* input, const T* filter, T* output, +void LaunchDepthwiseConv2dGPU(const GpuDevice& device, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, TensorFormat data_format) { if (args.depth_multiplier == 1) { if (CanLaunchDepthwiseConv2dGPUSmall(args)) { LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, - output, data_format); + kKnownFilterHeight>( + device, args, input, filter, output, data_format); return; } LaunchDepthwiseConv2dGPU( - d, args, input, filter, output, data_format); + device, args, input, filter, output, data_format); } else { LaunchDepthwiseConv2dGPU( - d, args, input, filter, output, data_format); + device, args, input, filter, output, data_format); } } // A simple launch pad to launch the Cuda kernel for depthwise convolution. template -void LaunchDepthwiseConvOp::operator()(OpKernelContext* ctx, - const DepthwiseArgs args, +void LaunchDepthwiseConvOp::operator()(OpKernelContext* ctx, + const DepthwiseArgs& args, const T* input, const T* filter, T* output, TensorFormat data_format) { - const GPUDevice& d = ctx->eigen_device(); + const GpuDevice& device = ctx->eigen_device(); if (args.filter_rows == 3 && args.filter_cols == 3) { - LaunchDepthwiseConv2dGPU(d, args, input, filter, output, + LaunchDepthwiseConv2dGPU(device, args, input, filter, output, data_format); } else { - LaunchDepthwiseConv2dGPU(d, args, input, filter, output, + LaunchDepthwiseConv2dGPU(device, args, input, filter, output, data_format); } auto stream = ctx->op_device_context()->stream(); @@ -710,9 +750,9 @@ void LaunchDepthwiseConvOp::operator()(OpKernelContext* ctx, "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed")); } -template struct LaunchDepthwiseConvOp; -template struct LaunchDepthwiseConvOp; -template struct LaunchDepthwiseConvOp; +template struct LaunchDepthwiseConvOp; +template struct LaunchDepthwiseConvOp; +template struct LaunchDepthwiseConvOp; // A Cuda kernel to compute the depthwise convolution backprop w.r.t. input. template (0); - const int out_r_start = - tf_max(0, (in_r - filter_rows + pad_rows + stride) / stride); - const int out_r_end = tf_min(out_rows - 1, (in_r + pad_rows) / stride); - const int out_c_start = - tf_max(0, (in_c - filter_cols + pad_cols + stride) / stride); - const int out_c_end = tf_min(out_cols - 1, (in_c + pad_cols) / stride); - - NOUNROLL for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) { - const int f_r = in_r + pad_rows - out_r * stride; + const int out_row_start = + tf_max(0, (in_row - filter_height + pad_height + stride) / stride); + const int out_row_end = + tf_min(out_height - 1, (in_row + pad_height) / stride); + const int out_col_start = + tf_max(0, (in_col - filter_width + pad_width + stride) / stride); + const int out_col_end = + tf_min(out_width - 1, (in_col + pad_width) / stride); + + NOUNROLL for (int out_row = out_row_start; out_row <= out_row_end; + ++out_row) { + const int filter_row = in_row + pad_height - out_row * stride; const int temp_out_backprop_offset = - out_depth * out_cols * (out_r + out_rows * b); - const int temp_filter_offset = filter_cols * f_r; - NOUNROLL for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) { - const int f_c = in_c + pad_cols - out_c * stride; + out_depth * out_width * (out_row + out_height * batch); + const int temp_filter_offset = filter_width * filter_row; + NOUNROLL for (int out_col = out_col_start; out_col <= out_col_end; + ++out_col) { + const int filter_col = in_col + pad_width - out_col * stride; int filter_offset = - depth_multiplier * (in_d + in_depth * (f_c + temp_filter_offset)); + depth_multiplier * + (in_channel + in_depth * (filter_col + temp_filter_offset)); const int out_backprop_offset = - out_depth * out_c + temp_out_backprop_offset; + out_depth * out_col + temp_out_backprop_offset; #pragma unroll 6 for (int i = 0; i < depth_multiplier; ++i) { sum += ldg(out_backprop + out_backprop_offset + - in_d * depth_multiplier + i) * + in_channel * depth_multiplier + i) * ldg(filter + filter_offset + i); } } } const int in_backprop_offset = - in_d + in_depth * (in_c + in_cols * (in_r + in_rows * b)); + in_channel + + in_depth * (in_col + in_width * (in_row + in_height * batch)); in_backprop[in_backprop_offset] = sum; } } @@ -786,99 +832,108 @@ __global__ void __launch_bounds__(640, 2) const T* out_backprop, const T* filter, T* in_backprop, int num_in_backprop) { - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int in_height = args.in_rows; + const int in_width = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; const int depth_multiplier = kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; - const int out_rows = args.out_rows; - const int out_cols = args.out_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_height = args.out_rows; + const int out_width = args.out_cols; const int out_depth = args.out_depth; // TODO(vrv): Consider assigning threads to output and using // atomics for accumulation, similar to the filter case. CUDA_1D_KERNEL_LOOP(thread_id, num_in_backprop) { // Compute the indexes of this thread in the input. - const int in_c = thread_id % in_cols; - const int in_r = (thread_id / in_cols) % in_rows; - const int in_d = (thread_id / in_cols / in_rows) % in_depth; - const int b = thread_id / in_depth / in_cols / in_rows; + const int in_col = thread_id % in_width; + const int in_row = (thread_id / in_width) % in_height; + const int in_channel = (thread_id / in_width / in_height) % in_depth; + const int batch = thread_id / in_depth / in_width / in_height; T sum = static_cast(0); - const int out_d_start = in_d * depth_multiplier; - const int out_d_end = out_d_start + depth_multiplier; - - const int out_r_start = - tf_max(0, (in_r - filter_rows + pad_rows + stride) / stride); - const int out_r_end = tf_min(out_rows - 1, (in_r + pad_rows) / stride); - const int out_c_start = - tf_max(0, (in_c - filter_cols + pad_cols + stride) / stride); - const int out_c_end = tf_min(out_cols - 1, (in_c + pad_cols) / stride); - - UNROLL for (int out_d = out_d_start; out_d < out_d_end; ++out_d) { - UNROLL for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) { - const int f_r = in_r + pad_rows - out_r * stride; - const int filter_dm = out_d - out_d_start; - - const int temp_filter_offset = filter_cols * f_r; - for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) { - const int f_c = in_c + pad_cols - out_c * stride; + const int out_channel_start = in_channel * depth_multiplier; + const int out_channel_end = out_channel_start + depth_multiplier; + + const int out_row_start = + tf_max(0, (in_row - filter_height + pad_height + stride) / stride); + const int out_row_end = + tf_min(out_height - 1, (in_row + pad_height) / stride); + const int out_col_start = + tf_max(0, (in_col - filter_width + pad_width + stride) / stride); + const int out_col_end = + tf_min(out_width - 1, (in_col + pad_width) / stride); + + UNROLL for (int out_channel = out_channel_start; + out_channel < out_channel_end; ++out_channel) { + UNROLL for (int out_row = out_row_start; out_row <= out_row_end; + ++out_row) { + const int filter_row = in_row + pad_height - out_row * stride; + const int filter_dm = out_channel - out_channel_start; + + const int temp_filter_offset = filter_width * filter_row; + for (int out_col = out_col_start; out_col <= out_col_end; ++out_col) { + const int filter_col = in_col + pad_width - out_col * stride; const int filter_offset = - filter_dm + args.depth_multiplier * - (in_d + in_depth * (f_c + temp_filter_offset)); + filter_dm + + args.depth_multiplier * + (in_channel + in_depth * (filter_col + temp_filter_offset)); const int out_backprop_offset = - (b * out_depth * out_rows * out_cols) + - (out_d * out_rows * out_cols) + (out_r * out_cols) + (out_c); + (batch * out_depth * out_height * out_width) + + (out_channel * out_height * out_width) + (out_row * out_width) + + (out_col); sum += ldg(out_backprop + out_backprop_offset) * ldg(filter + filter_offset); } } } - const int in_backprop_offset = (b * in_rows * in_cols * in_depth) + - (in_d * in_rows * in_cols) + - (in_r * in_cols) + (in_c); + const int in_backprop_offset = (batch * in_height * in_width * in_depth) + + (in_channel * in_height * in_width) + + (in_row * in_width) + (in_col); in_backprop[in_backprop_offset] = sum; } } template -void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, - const DepthwiseArgs args, +void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device, + const DepthwiseArgs& args, const T* out_backprop, const T* filter, T* in_backprop, TensorFormat data_format) { void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); - if (data_format == FORMAT_NHWC) { - kernel = DepthwiseConv2dBackpropInputGPUKernelNHWC< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; - } else if (data_format == FORMAT_NCHW) { - kernel = DepthwiseConv2dBackpropInputGPUKernelNCHW< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; - } else { - assert(false && "Incorrect data format"); - return; + switch (data_format) { + case FORMAT_NHWC: + kernel = DepthwiseConv2dBackpropInputGPUKernelNHWC< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + break; + case FORMAT_NCHW: + kernel = DepthwiseConv2dBackpropInputGPUKernelNCHW< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + break; + case FORMAT_NCHW_VECT_C: + LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + return; } const int num_in_backprop = args.batch * args.in_rows * args.in_cols * args.in_depth; CudaLaunchConfig config = - GetCudaLaunchConfig(num_in_backprop, d, kernel, 0, 0); - kernel<<>>( + GetCudaLaunchConfig(num_in_backprop, device, kernel, 0, 0); + kernel<<>>( args, out_backprop, filter, in_backprop, num_in_backprop); } template -void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, - const DepthwiseArgs args, +void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device, + const DepthwiseArgs& args, const T* out_backprop, const T* filter, T* in_backprop, TensorFormat data_format) { @@ -886,32 +941,32 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, if (CanLaunchDepthwiseConv2dGPUSmall(args)) { LaunchDepthwiseConv2dGPUSmall( - d, args, out_backprop, filter, in_backprop, data_format); + device, args, out_backprop, filter, in_backprop, data_format); return; } LaunchDepthwiseConv2dBackpropInputGPU( - d, args, out_backprop, filter, in_backprop, data_format); + device, args, out_backprop, filter, in_backprop, data_format); } else { LaunchDepthwiseConv2dBackpropInputGPU( - d, args, out_backprop, filter, in_backprop, data_format); + device, args, out_backprop, filter, in_backprop, data_format); } } // A simple launch pad to launch the Cuda kernel for depthwise convolution. template -void LaunchDepthwiseConvBackpropInputOp::operator()( +void LaunchDepthwiseConvBackpropInputOp::operator()( OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, const T* filter, T* in_backprop, TensorFormat data_format) { - const GPUDevice& d = ctx->eigen_device(); + const GpuDevice& device = ctx->eigen_device(); if (args.filter_rows == 3 && args.filter_cols == 3) { LaunchDepthwiseConv2dBackpropInputGPU( - d, args, out_backprop, filter, in_backprop, data_format); + device, args, out_backprop, filter, in_backprop, data_format); } else { LaunchDepthwiseConv2dBackpropInputGPU( - d, args, out_backprop, filter, in_backprop, data_format); + device, args, out_backprop, filter, in_backprop, data_format); } auto stream = ctx->op_device_context()->stream(); OP_REQUIRES(ctx, stream->ok(), @@ -920,9 +975,9 @@ void LaunchDepthwiseConvBackpropInputOp::operator()( "utGPULaunch failed")); } -template struct LaunchDepthwiseConvBackpropInputOp; -template struct LaunchDepthwiseConvBackpropInputOp; -template struct LaunchDepthwiseConvBackpropInputOp; +template struct LaunchDepthwiseConvBackpropInputOp; +template struct LaunchDepthwiseConvBackpropInputOp; +template struct LaunchDepthwiseConvBackpropInputOp; // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. template = 0 && in_c_start >= 0 && in_r_end < in_rows && - in_c_end < in_cols) { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = in_r_start + f_r; + if (in_row_start >= 0 && in_col_start >= 0 && in_row_end < in_height && + in_col_end < in_width) { + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = in_row_start + filter_row; // Avoid repeated computation. - const int input_offset_temp = in_cols * (in_r + in_rows * b); - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = in_c_start + f_c; + const int input_offset_temp = in_width * (in_row + in_height * batch); + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = in_col_start + filter_col; - const int input_offset = in_d + in_depth * (in_c + input_offset_temp); + const int input_offset = + in_channel + in_depth * (in_col + input_offset_temp); T partial_sum = ldg(input + input_offset) * out_bp; - T* addr = filter_backprop + - (dm + depth_multiplier * - (in_d + in_depth * (f_c + filter_cols * f_r))); + T* addr = + filter_backprop + + (dm + depth_multiplier * + (in_channel + + in_depth * (filter_col + filter_width * filter_row))); CudaAtomicAdd(addr, partial_sum); } } } else { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = in_r_start + f_r; + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = in_row_start + filter_row; // Avoid repeated computation. - const int input_offset_temp = in_cols * (in_r + in_rows * b); - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = in_c_start + f_c; - const int addr_temp = filter_cols * f_r; - - if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols) { + const int input_offset_temp = in_width * (in_row + in_height * batch); + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = in_col_start + filter_col; + const int addr_temp = filter_width * filter_row; + + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { const int input_offset = - in_d + in_depth * (in_c + input_offset_temp); + in_channel + in_depth * (in_col + input_offset_temp); T partial_sum = ldg(input + input_offset) * out_bp; T* addr = filter_backprop + - (dm + depth_multiplier * (in_d + in_depth * (f_c + addr_temp))); + (dm + depth_multiplier * + (in_channel + in_depth * (filter_col + addr_temp))); // Potentially many threads can add to the same address so we have // to use atomic add here. // TODO(jmchen): If atomic add turns out to be slow, we can: @@ -1020,7 +1085,7 @@ __global__ void __launch_bounds__(640, 2) // Device function to compute sub-warp sum reduction for a power-of-two group of // neighboring threads. -template +template __device__ __forceinline__ T WarpSumReduce(T val) { // support only power-of-two widths. assert(__popc(kWidth) == 1); @@ -1028,7 +1093,7 @@ __device__ __forceinline__ T WarpSumReduce(T val) { int zeros = sub_warp * kWidth; unsigned mask = ((1UL << kWidth) - 1) << zeros; for (int delta = kWidth / 2; delta > 0; delta /= 2) { - val += CudaShuffleXor(mask, val, delta); + val += CudaShuffleXorSync(mask, val, delta); } return val; } @@ -1045,9 +1110,9 @@ __device__ __forceinline__ T WarpSumReduce(T val) { // memory are warp-accumulated (in chunks of kAccumPixels elements) and summed // up in global memory using atomics. // Requirements: threads per block must be multiple of 32 and <= launch_bounds, -// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockSlices. +// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth. template + int kBlockDepth, int kAccumPixels> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const DepthwiseArgs args, const T* output, const T* input, T* filter) { @@ -1056,40 +1121,42 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[]; T* const shared_data = reinterpret_cast(shared_memory); - const int batches = args.batch; - const int in_rows = args.in_rows; - const int in_cols = blockDim.y; // slower (see b/62280718): args.in_cols; + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = blockDim.y; // slower (see b/62280718): args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; - const int block_rows = blockDim.z; + assert(blockDim.x == kBlockDepth); + assert(blockDim.y == args.in_cols); + const int block_height = blockDim.z; // These values are the same for all threads and could // be precomputed on the CPU. - const int block_size = block_rows * in_cols * kBlockSlices; + const int block_size = block_height * in_width * kBlockDepth; assert((block_size & 31) == 0); - const int in_row_size = in_cols * in_depth; - const int in_size = in_rows * in_row_size; - const int in_increment = (in_cols - 1) * kBlockSlices; - const int filter_pixels = filter_rows * filter_cols; - const int tile_cols = in_cols + filter_cols - 1; - const int tile_rows = 2 * block_rows + filter_rows - 1; - const int tile_row_size = tile_cols * kBlockSlices; - const int tile_size = tile_rows * tile_row_size; - const int tile_offset = block_rows * tile_row_size; - const int pad_offset = pad_rows * tile_cols + pad_cols; - const int batch_blocks = (in_depth + kBlockSlices - 1) / kBlockSlices; - const int in_blocks = batch_blocks * batches; - const int tensor_offset = block_rows * in_row_size; + const int in_row_size = in_width * in_depth; + const int in_size = in_height * in_row_size; + const int in_increment = (in_width - 1) * kBlockDepth; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int tile_height = 2 * block_height + filter_height - 1; + const int tile_row_size = tile_width * kBlockDepth; + const int tile_size = tile_height * tile_row_size; + const int tile_offset = block_height * tile_row_size; + const int pad_offset = pad_height * tile_width + pad_width; + const int batch_blocks = (in_depth + kBlockDepth - 1) / kBlockDepth; + const int in_blocks = batch_blocks * num_batches; + const int tensor_offset = block_height * in_row_size; // The accumulator has a fixed number of pixels that can be reduced by one - // warp. Pixels beyond ceil(in_pixels * kBlockSlices / 64) are never written. - assert(kAccumPixels * 64 >= in_rows * in_cols * kBlockSlices); - const int accum_increment = kAccumPixels * kBlockSlices; + // warp. Pixels beyond ceil(in_pixels * kBlockDepth / 64) are never written. + assert(kAccumPixels * 64 >= in_height * in_width * kBlockDepth); + const int accum_increment = kAccumPixels * kBlockDepth; const int accum_size = filter_pixels * accum_increment; const int thread_depth = threadIdx.x; @@ -1097,8 +1164,8 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const int thread_row = threadIdx.z; // Position in block. - const int thread_pix = thread_row * in_cols + thread_col; - const int thread_idx = thread_pix * kBlockSlices + thread_depth; + const int thread_pix = thread_row * in_width + thread_col; + const int thread_idx = thread_pix * kBlockDepth + thread_depth; // Initialize tile, in particular the padding and accumulator. for (int i = thread_idx; i < tile_size + accum_size; i += block_size) { @@ -1110,31 +1177,31 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const int tensor_idx = thread_pix * in_depth + thread_depth; // Position in (padded) shared memory. - const int data_pix = thread_row * tile_cols + thread_col; - const int data_idx = data_pix * kBlockSlices + thread_depth; + const int data_pix = thread_row * tile_width + thread_col; + const int data_idx = data_pix * kBlockDepth + thread_depth; - // Position in shared memory, offset by pad_rows / pad_cols. + // Position in shared memory, offset by pad_height / pad_width. const int tile_pix = data_pix + pad_offset; - const int tile_idx = tile_pix * kBlockSlices + thread_depth; + const int tile_idx = tile_pix * kBlockDepth + thread_depth; - // Position in accumulator (kBlockSlices per warp, depth major). - const int accum_pix = thread_pix / (32 / kBlockSlices); + // Position in accumulator (kBlockDepth per warp, depth major). + const int accum_pix = thread_pix / (32 / kBlockDepth); const int accum_idx = thread_depth * kAccumPixels + accum_pix; - const int max_depth = in_depth - thread_depth; + const int max_channel = in_depth - thread_depth; const int accum_offset = tile_size + accum_idx; - const bool skip_second = block_rows + thread_row >= in_rows; + const bool skip_second = block_height + thread_row >= in_height; for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { const int batch = b / batch_blocks; - const int stack = b - batch * batch_blocks; + const int block = b - batch * batch_blocks; - const int start_depth = stack * kBlockSlices; - const int filter_offset = tensor_idx + start_depth; + const int start_channel = block * kBlockDepth; + const int filter_offset = tensor_idx + start_channel; const int inout_offset = batch * in_size + filter_offset; - const bool depth_in_range = start_depth < max_depth; + const bool channel_in_range = start_channel < max_channel; - if (depth_in_range) { + if (channel_in_range) { const T* const in_ptr = inout_offset + input; T* const tile_ptr = tile_idx + shared_data; tile_ptr[0] = ldg(in_ptr); @@ -1145,26 +1212,26 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( // Note: the condition to reach this is uniform across the entire block. __syncthreads(); - unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range); + unsigned active_threads = CudaBallotSync(kCudaWarpAll, channel_in_range); - if (depth_in_range) { + if (channel_in_range) { const T* const out_ptr = inout_offset + output; const T out1 = ldg(out_ptr); const T out2 = skip_second ? T(0) : ldg(tensor_offset + out_ptr); int shared_offset = data_idx; T* accum_ptr = accum_offset + shared_data; - UNROLL for (int r = 0; r < filter_rows; ++r) { - UNROLL for (int c = 0; c < filter_cols; ++c) { + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { const T* const tile_ptr = shared_offset + shared_data; T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset]; // Warp-accumulate pixels of the same depth and write to accumulator. - for (int delta = 16; delta >= kBlockSlices; delta /= 2) { - val += CudaShuffleDown(active_threads, val, delta); + for (int delta = 16; delta >= kBlockDepth; delta /= 2) { + val += CudaShuffleXorSync(active_threads, val, delta); } - if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) { + if (!(thread_idx & 32 - kBlockDepth) /* lane_idx < kBlockDepth */) { *accum_ptr = val; } - shared_offset += kBlockSlices; + shared_offset += kBlockDepth; accum_ptr += accum_increment; } shared_offset += in_increment; @@ -1177,10 +1244,10 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const T* const accum_data = tile_size + shared_data; for (int i = thread_idx; i < accum_size; i += block_size) { const int filter_idx = i / kAccumPixels; - const int filter_pix = filter_idx / kBlockSlices; - const int filter_depth = filter_idx % kBlockSlices + start_depth; - const int filter_offset = filter_pix * in_depth + filter_depth; - if (filter_depth < in_depth) { + const int filter_pix = filter_idx / kBlockDepth; + const int filter_channel = filter_idx % kBlockDepth + start_channel; + const int filter_offset = filter_pix * in_depth + filter_channel; + if (filter_channel < in_depth) { T val = accum_data[i]; // Warp-accumulate the pixels of the same depth from the accumulator. val = WarpSumReduce(val); @@ -1201,81 +1268,90 @@ __global__ void __launch_bounds__(640, 2) const T* input, T* filter_backprop, int num_out_backprop) { - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int in_height = args.in_rows; + const int in_width = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; const int depth_multiplier = kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; - const int out_rows = args.out_rows; - const int out_cols = args.out_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_height = args.out_rows; + const int out_width = args.out_cols; const int out_depth = args.out_depth; CUDA_1D_KERNEL_LOOP(thread_id, num_out_backprop) { // Compute the indexes of this thread in the output. - const int out_c = thread_id % out_cols; - const int out_r = (thread_id / out_cols) % out_rows; - const int out_d = (thread_id / out_cols / out_rows) % out_depth; + const int out_col = thread_id % out_width; + const int out_row = (thread_id / out_width) % out_height; + const int out_channel = (thread_id / out_width / out_height) % out_depth; - const int b = thread_id / out_depth / out_cols / out_rows; + const int batch = thread_id / out_depth / out_width / out_height; // Compute the input depth and the index of depth multiplier. - const int in_d = out_d / depth_multiplier; - const int dm = out_d % depth_multiplier; + const int in_channel = out_channel / depth_multiplier; + const int dm = out_channel % depth_multiplier; // Decide if all input is valid, if yes, we can skip the boundary checks // for each input. - const int in_r_start = out_r * stride - pad_rows; - const int in_c_start = out_c * stride - pad_cols; - const int in_r_end = in_r_start + filter_rows; - const int in_c_end = in_c_start + filter_cols; + const int in_row_start = out_row * stride - pad_height; + const int in_col_start = out_col * stride - pad_width; + const int in_row_end = in_row_start + filter_height; + const int in_col_end = in_col_start + filter_width; - const int out_backprop_offset = (b * out_depth * out_rows * out_cols) + - (out_d * out_rows * out_cols) + - (out_r * out_cols) + (out_c); + const int out_backprop_offset = + (batch * out_depth * out_height * out_width) + + (out_channel * out_height * out_width) + (out_row * out_width) + + (out_col); const T out_bp = ldg(out_backprop + out_backprop_offset); - if (in_r_start >= 0 && in_c_start >= 0 && in_r_end < in_rows && - in_c_end < in_cols) { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = in_r_start + f_r; + if (in_row_start >= 0 && in_col_start >= 0 && in_row_end < in_height && + in_col_end < in_width) { + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = in_row_start + filter_row; // Avoid repeated computation. - const int input_offset_temp = (b * in_depth * in_rows * in_cols) + - (in_d * in_rows * in_cols) + - (in_r * in_cols); - - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = in_c_start + f_c; - const int input_offset = input_offset_temp + in_c; + const int input_offset_temp = + (batch * in_depth * in_height * in_width) + + (in_channel * in_height * in_width) + (in_row * in_width); + + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = in_col_start + filter_col; + const int input_offset = input_offset_temp + in_col; T partial_sum = ldg(input + input_offset) * out_bp; - T* addr = filter_backprop + - (dm + depth_multiplier * - (in_d + in_depth * (f_c + filter_cols * f_r))); + T* addr = + filter_backprop + + (dm + depth_multiplier * + (in_channel + + in_depth * (filter_col + filter_width * filter_row))); CudaAtomicAdd(addr, partial_sum); } } } else { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = in_r_start + f_r; + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = in_row_start + filter_row; // Avoid repeated computation. - const int input_offset_temp = (b * in_depth * in_rows * in_cols) + - (in_d * in_rows * in_cols) + - (in_r * in_cols); - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = in_c_start + f_c; - const int addr_temp = filter_cols * f_r; - - if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols) { - const int input_offset = input_offset_temp + in_c; + const int input_offset_temp = + (batch * in_depth * in_height * in_width) + + (in_channel * in_height * in_width) + (in_row * in_width); + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = in_col_start + filter_col; + const int addr_temp = filter_width * filter_row; + + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { + const int input_offset = input_offset_temp + in_col; T partial_sum = ldg(input + input_offset) * out_bp; T* addr = filter_backprop + - (dm + depth_multiplier * (in_d + in_depth * (f_c + addr_temp))); + (dm + depth_multiplier * + (in_channel + in_depth * (filter_col + addr_temp))); // Potentially many threads can add to the same address so we have // to use atomic add here. // TODO(jmchen): If atomic add turns out to be slow, we can: @@ -1304,9 +1380,9 @@ __global__ void __launch_bounds__(640, 2) // memory are warp-accumulated (in chunks of kAccumPixels elements) and summed // up in global memory using atomics. // Requirements: threads per block must be multiple of 32 and <= launch_bounds, -// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockSlices. +// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth. template + int kBlockDepth, int kAccumPixels> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const DepthwiseArgs args, const T* output, const T* input, T* filter) { @@ -1315,39 +1391,41 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[]; T* const shared_data = reinterpret_cast(shared_memory); - const int batches = args.batch; - const int in_rows = args.in_rows; - const int in_cols = blockDim.x; // slower (see b/62280718): args.in_cols; + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = blockDim.x; // slower (see b/62280718): args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; - const int block_rows = blockDim.y; + assert(blockDim.x == args.in_cols); + assert(blockDim.z == kBlockDepth); + const int block_height = blockDim.y; // These values are the same for all threads and could // be precomputed on the CPU. - const int block_pixels = in_cols * block_rows; - const int block_size = block_pixels * kBlockSlices; + const int block_pixels = in_width * block_height; + const int block_size = block_pixels * kBlockDepth; assert((block_size & 31) == 0); - const int in_pixels = in_cols * in_rows; - const int in_increment = in_cols - 1; - const int filter_pixels = filter_rows * filter_cols; - const int tile_cols = in_cols + filter_cols - 1; - const int tile_rows = 2 * block_rows + filter_rows - 1; - const int tile_pixels = tile_cols * tile_rows; - const int tile_size = tile_pixels * kBlockSlices; - const int tile_offset = block_rows * tile_cols; - const int pad_offset = pad_rows * tile_cols + pad_cols; - const int in_slices = in_depth * batches; - const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices; + const int in_pixels = in_width * in_height; + const int in_increment = in_width - 1; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int tile_height = 2 * block_height + filter_height - 1; + const int tile_pixels = tile_width * tile_height; + const int tile_size = tile_pixels * kBlockDepth; + const int tile_offset = block_height * tile_width; + const int pad_offset = pad_height * tile_width + pad_width; + const int in_total_depth = in_depth * num_batches; + const int in_blocks = (in_total_depth + kBlockDepth - 1) / kBlockDepth; // The accumulator has a fixed number of pixels that can be reduced by one - // warp. Pixels beyond ceil(in_pixels * kBlockSlices / 64) are never written. - assert(kAccumPixels * 64 >= in_rows * in_cols * kBlockSlices); - const int accum_increment = kAccumPixels * kBlockSlices; + // warp. Pixels beyond ceil(in_pixels * kBlockDepth / 64) are never written. + assert(kAccumPixels * 64 >= in_height * in_width * kBlockDepth); + const int accum_increment = kAccumPixels * kBlockDepth; const int accum_size = filter_pixels * accum_increment; const int thread_col = threadIdx.x; @@ -1355,7 +1433,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const int thread_depth = threadIdx.z; // Position in block. - const int thread_pix = thread_row * in_cols + thread_col; + const int thread_pix = thread_row * in_width + thread_col; const int thread_idx = thread_depth * block_pixels + thread_pix; // Initialize tile, in particular the padding and accumulator. @@ -1368,27 +1446,27 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const int tensor_idx = thread_depth * in_pixels + thread_pix; // Position in (padded) shared memory. - const int data_pix = thread_row * tile_cols + thread_col; + const int data_pix = thread_row * tile_width + thread_col; const int data_idx = thread_depth * tile_pixels + data_pix; - // Position in shared memory, offset by pad_rows / pad_cols. + // Position in shared memory, offset by pad_height / pad_width. const int tile_idx = data_idx + pad_offset; - // Position in accumulator (kBlockSlices per warp, depth major). - const int accum_pix = thread_pix / (32 / kBlockSlices); + // Position in accumulator (kBlockDepth per warp, depth major). + const int accum_pix = thread_pix / (32 / kBlockDepth); const int accum_idx = thread_depth * kAccumPixels + accum_pix; - const int max_slice = in_slices - thread_depth; + const int max_channel = in_total_depth - thread_depth; const int accum_offset = tile_size + accum_idx; - const bool skip_second = block_rows + thread_row >= in_rows; + const bool skip_second = block_height + thread_row >= in_height; for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { - const int slice = b * kBlockSlices; + const int channel = b * kBlockDepth; - const int inout_offset = slice * in_pixels + tensor_idx; - const bool slice_in_range = slice < max_slice; + const int inout_offset = channel * in_pixels + tensor_idx; + const bool channel_in_range = channel < max_channel; - if (slice_in_range) { + if (channel_in_range) { const T* const in_ptr = inout_offset + input; T* const tile_ptr = tile_idx + shared_data; tile_ptr[0] = ldg(in_ptr); @@ -1399,24 +1477,24 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( // Note: the condition to reach this is uniform across the entire block. __syncthreads(); - unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range); + unsigned active_threads = CudaBallotSync(kCudaWarpAll, channel_in_range); - if (slice_in_range) { + if (channel_in_range) { const T* const out_ptr = inout_offset + output; const T out1 = ldg(out_ptr); const T out2 = skip_second ? T(0) : ldg(block_pixels + out_ptr); int shared_offset = data_idx; T* accum_ptr = accum_offset + shared_data; - UNROLL for (int r = 0; r < filter_rows; ++r) { - UNROLL for (int c = 0; c < filter_cols; ++c) { + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { const T* const tile_ptr = shared_offset + shared_data; T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset]; // Warp-accumulate pixels of the same depth and write to accumulator. - for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) { - val += CudaShuffleDown(active_threads, val, delta); + for (int delta = 16 / kBlockDepth; delta > 0; delta /= 2) { + val += CudaShuffleXorSync(active_threads, val, delta); } - if (!(thread_idx & 32 / kBlockSlices - 1)) { - *accum_ptr = val; + if (!(thread_idx & 32 / kBlockDepth - 1)) { + *accum_ptr = val; // kBlockDepth threads per warp. } ++shared_offset; accum_ptr += accum_increment; @@ -1431,10 +1509,11 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const T* const accum_data = tile_size + shared_data; for (int i = thread_idx; i < accum_size; i += block_size) { const int filter_idx = i / kAccumPixels; - const int filter_pix = filter_idx / kBlockSlices; - const int filter_depth = (slice + filter_idx % kBlockSlices) % in_depth; - const int filter_offset = filter_pix * in_depth + filter_depth; - if (filter_depth < in_depth) { + const int filter_pix = filter_idx / kBlockDepth; + const int filter_channel = + (channel + filter_idx % kBlockDepth) % in_depth; + const int filter_offset = filter_pix * in_depth + filter_channel; + if (filter_channel < in_depth) { T val = accum_data[i]; // Warp-accumulate pixels of the same depth from the accumulator. val = WarpSumReduce(val); @@ -1447,109 +1526,119 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( } template + int kBlockDepth, int kAccumPixels> bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs args, const int block_rows, + const GpuDevice& device, const DepthwiseArgs& args, const int block_height, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { - const int tile_cols = args.in_cols + args.filter_cols - 1; - const int tile_rows = block_rows * 2 + args.filter_rows - 1; - const int tile_pixels = tile_rows * tile_cols; + const int tile_width = args.in_cols + args.filter_cols - 1; + const int tile_height = block_height * 2 + args.filter_rows - 1; + const int tile_pixels = tile_height * tile_width; const int filter_pixels = args.filter_rows * args.filter_cols; const int shared_memory_size = - kBlockSlices * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(T); - if (shared_memory_size > d.sharedMemPerBlock()) { + kBlockDepth * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(T); + if (shared_memory_size > device.sharedMemPerBlock()) { return false; } dim3 block_dim; + int block_count; void (*kernel)(const DepthwiseArgs, const T*, const T*, T*); - if (data_format == FORMAT_NHWC) { - block_dim = dim3(kBlockSlices, args.in_cols, block_rows); - kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kAccumPixels>; - } else if (data_format == FORMAT_NCHW) { - block_dim = dim3(args.in_cols, block_rows, kBlockSlices); - kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kAccumPixels>; - } else { - assert(false && "Incorrect data format"); - return false; + switch (data_format) { + case FORMAT_NHWC: + block_dim = dim3(kBlockDepth, args.in_cols, block_height); + block_count = + args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth; + kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>; + break; + case FORMAT_NCHW: + block_dim = dim3(args.in_cols, block_height, kBlockDepth); + block_count = + DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth; + kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>; + break; + case FORMAT_NCHW_VECT_C: + LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + return false; } - const int num_out_backprop = - args.batch * args.out_rows * args.out_cols * args.out_depth; - CudaLaunchConfig config = - GetCudaLaunchConfig(num_out_backprop, d, kernel, shared_memory_size, - block_dim.x * block_dim.y * block_dim.z); - kernel<<>>( - args, out_backprop, input, filter_backprop); + const int num_out_backprop = args.out_rows * args.out_cols * block_count; + CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( + num_out_backprop, device, kernel, shared_memory_size, + block_dim.x * block_dim.y * block_dim.z); + kernel<<>>(args, out_backprop, input, filter_backprop); return true; } template + int kBlockDepth> bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs args, const int block_rows, + const GpuDevice& device, const DepthwiseArgs& args, const int block_height, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { // Minimize (power of two) kAccumPixels, while satisfying - // kAccumPixels * 32 >= block_rows * in_cols * kBlockSlices. - const int block_pixels = block_rows * args.in_cols * kBlockSlices; + // kAccumPixels * 32 >= block_height * in_width * kBlockDepth. + const int block_pixels = block_height * args.in_cols * kBlockDepth; if (block_pixels > 512) { return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 32>( - d, args, block_rows, out_backprop, input, filter_backprop, data_format); + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 32>( + device, args, block_height, out_backprop, input, filter_backprop, + data_format); } else if (block_pixels > 256) { return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 16>( - d, args, block_rows, out_backprop, input, filter_backprop, data_format); + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 16>( + device, args, block_height, out_backprop, input, filter_backprop, + data_format); } else { return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 8>( - d, args, block_rows, out_backprop, input, filter_backprop, data_format); + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 8>( + device, args, block_height, out_backprop, input, filter_backprop, + data_format); } } template bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs args, const T* out_backprop, + const GpuDevice& device, const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { - // Maximize (power of two) kBlockSlices while keeping a block within 1024 + // Maximize (power of two) kBlockDepth while keeping a block within 1024 // threads (2 pixels per thread). - int block_slices = 8; - int block_rows = (args.in_rows + 1) / 2; + int block_depth = 8; + int block_height = (args.in_rows + 1) / 2; int round_mask = 1; - for (; block_slices > 1; block_slices /= 2) { - // args.in_cols * block_rows * kBlockSlices must be multiple of 32. - for (; block_rows * args.in_cols * block_slices & 31; + for (; block_depth > 1; block_depth /= 2) { + // args.in_cols * block_height * kBlockDepth must be multiple of 32. + for (; block_height * args.in_cols * block_depth & 31; round_mask = round_mask * 2 + 1) { - block_rows = block_rows + round_mask & ~round_mask; + block_height = block_height + round_mask & ~round_mask; } - int block_size = block_rows * args.in_cols * block_slices; + int block_size = block_height * args.in_cols * block_depth; if (block_size <= 1024) { break; } } - if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_rows)) { + if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_height)) { return false; } - switch (block_slices) { + switch (block_depth) { case 8: return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< T, kKnownFilterWidth, kKnownFilterHeight, 8>( - d, args, block_rows, out_backprop, input, filter_backprop, + device, args, block_height, out_backprop, input, filter_backprop, data_format); case 4: return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< T, kKnownFilterWidth, kKnownFilterHeight, 4>( - d, args, block_rows, out_backprop, input, filter_backprop, + device, args, block_height, out_backprop, input, filter_backprop, data_format); case 2: return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< T, kKnownFilterWidth, kKnownFilterHeight, 2>( - d, args, block_rows, out_backprop, input, filter_backprop, + device, args, block_height, out_backprop, input, filter_backprop, data_format); default: return false; @@ -1558,59 +1647,62 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( template -void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, - const DepthwiseArgs args, +void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device, + const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); - if (data_format == FORMAT_NHWC) { - kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWC< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; - } else if (data_format == FORMAT_NCHW) { - kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHW< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; - } else { - assert(false && "Incorrect data format"); - return; + switch (data_format) { + case FORMAT_NHWC: + kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWC< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + break; + case FORMAT_NCHW: + kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHW< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + break; + case FORMAT_NCHW_VECT_C: + LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + return; } const int num_out_backprop = args.batch * args.out_rows * args.out_cols * args.out_depth; CudaLaunchConfig config = - GetCudaLaunchConfig(num_out_backprop, d, kernel, 0, 0); - kernel<<>>( + GetCudaLaunchConfig(num_out_backprop, device, kernel, 0, 0); + kernel<<>>( args, out_backprop, input, filter_backprop, num_out_backprop); } template -void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, - const DepthwiseArgs args, +void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device, + const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { if (args.depth_multiplier == 1) { if (TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - d, args, out_backprop, input, filter_backprop, data_format)) { + device, args, out_backprop, input, filter_backprop, data_format)) { return; } LaunchDepthwiseConv2dBackpropFilterGPU( - d, args, out_backprop, input, filter_backprop, data_format); + device, args, out_backprop, input, filter_backprop, data_format); } else { LaunchDepthwiseConv2dBackpropFilterGPU( - d, args, out_backprop, input, filter_backprop, data_format); + device, args, out_backprop, input, filter_backprop, data_format); } } // A simple launch pad to launch the Cuda kernel for depthwise convolution. template -void LaunchDepthwiseConvBackpropFilterOp::operator()( +void LaunchDepthwiseConvBackpropFilterOp::operator()( OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { - const GPUDevice& d = ctx->eigen_device(); + const GpuDevice& device = ctx->eigen_device(); auto stream = ctx->op_device_context()->stream(); // Initialize the results to 0. @@ -1622,10 +1714,10 @@ void LaunchDepthwiseConvBackpropFilterOp::operator()( if (args.filter_rows == 3 && args.filter_cols == 3) { LaunchDepthwiseConv2dBackpropFilterGPU( - d, args, out_backprop, input, filter_backprop, data_format); + device, args, out_backprop, input, filter_backprop, data_format); } else { LaunchDepthwiseConv2dBackpropFilterGPU( - d, args, out_backprop, input, filter_backprop, data_format); + device, args, out_backprop, input, filter_backprop, data_format); } OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for " @@ -1633,8 +1725,8 @@ void LaunchDepthwiseConvBackpropFilterOp::operator()( "terGPULaunch failed")); } -template struct LaunchDepthwiseConvBackpropFilterOp; -template struct LaunchDepthwiseConvBackpropFilterOp; -template struct LaunchDepthwiseConvBackpropFilterOp; +template struct LaunchDepthwiseConvBackpropFilterOp; +template struct LaunchDepthwiseConvBackpropFilterOp; +template struct LaunchDepthwiseConvBackpropFilterOp; } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/diag_op.cc b/tensorflow/core/kernels/diag_op.cc index 86fa7dce36afff121dc6ff0642f45c809bc63a3d..d228153d4c76dedd74a4b1db1059bc25ff0a6f77 100644 --- a/tensorflow/core/kernels/diag_op.cc +++ b/tensorflow/core/kernels/diag_op.cc @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { @@ -47,8 +47,9 @@ class DiagOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& diagonal = context->input(0); const int num_dims = diagonal.dims(); - OP_REQUIRES(context, 0 != num_dims, errors::InvalidArgument( - "Input must be at least rank 1, got 0")); + OP_REQUIRES( + context, 0 != num_dims, + errors::InvalidArgument("Input must be at least rank 1, got 0")); TensorShape out_shape; for (int i = 0; i < num_dims; ++i) { out_shape.AddDim(diagonal.dim_size(i)); @@ -60,10 +61,9 @@ class DiagOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output_tensor)); functor::DiagFunctor diagFunc; - Status s = diagFunc(context, - diagonal.NumElements(), - diagonal.flat().data(), - output_tensor->flat().data()); + Status s = + diagFunc(context, diagonal.NumElements(), diagonal.flat().data(), + output_tensor->flat().data()); OP_REQUIRES_OK(context, s); } }; @@ -82,12 +82,12 @@ class DiagPartOp : public OpKernel { errors::InvalidArgument("The rank of the tensor should be \ even and positive, got shape ", tensor.shape().DebugString())); - for (int i = 0; i < out_dims; i++){ - OP_REQUIRES(context, tensor.dim_size(i) == tensor.dim_size(i + out_dims), - errors::InvalidArgument( - "Invalid shape ", tensor.shape().DebugString(), - ": dimensions ", i, " and ", i + out_dims, " do not match.") - ); + for (int i = 0; i < out_dims; i++) { + OP_REQUIRES( + context, tensor.dim_size(i) == tensor.dim_size(i + out_dims), + errors::InvalidArgument("Invalid shape ", + tensor.shape().DebugString(), ": dimensions ", + i, " and ", i + out_dims, " do not match.")); } TensorShape out_shape; @@ -96,13 +96,10 @@ class DiagPartOp : public OpKernel { } Tensor* output = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, out_shape, &output)); + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); functor::DiagPartFunctor diagPartFunc; - Status s = diagPartFunc(context, - out_shape.num_elements(), - tensor.flat().data(), - output->flat().data()); + Status s = diagPartFunc(context, out_shape.num_elements(), + tensor.flat().data(), output->flat().data()); OP_REQUIRES_OK(context, s); } }; @@ -129,9 +126,8 @@ class DiagPartOp : public OpKernel { namespace functor { template struct DiagFunctor { - EIGEN_ALWAYS_INLINE Status - operator() (OpKernelContext* context, const int64 size, - const T* in, T* out) { + EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context, + const int64 size, const T* in, T* out) { // This subprocess is responsible for writing values in index range // [start*size, limit*size) auto subDiag = [in, out, size](int64 start, int64 limit) { @@ -143,17 +139,16 @@ struct DiagFunctor { // Here, 5 is a empirical factor of cost_per_unit. auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); - Shard(worker_threads.num_threads, worker_threads.workers, size, - 5 * size, subDiag); + Shard(worker_threads.num_threads, worker_threads.workers, size, 5 * size, + subDiag); return Status::OK(); } }; template struct DiagPartFunctor { - EIGEN_ALWAYS_INLINE Status - operator() (OpKernelContext* context, const int64 size, - const T* in, T* out) { + EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context, + const int64 size, const T* in, T* out) { // This subprocess is responsible for extracting values in index range // [start, limit) auto subDiagPart = [in, out, size](int64 start, int64 limit) { @@ -164,14 +159,13 @@ struct DiagPartFunctor { // Here, 5 is a empirical factor of cost_per_unit. auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); - Shard(worker_threads.num_threads, worker_threads.workers, size, - 5, subDiagPart); + Shard(worker_threads.num_threads, worker_threads.workers, size, 5, + subDiagPart); return Status::OK(); } }; } // namespace functor - // Register the CPU kernels. #define REGISTER_DIAGOP(T) \ REGISTER_KERNEL_BUILDER( \ @@ -250,6 +244,4 @@ TF_CALL_complex128(REGISTER_DIAGPARTOP_GPU); #endif // GOOGLE_CUDA - } // namespace tensorflow - diff --git a/tensorflow/core/kernels/diag_op.h b/tensorflow/core/kernels/diag_op.h index c6ca6a2047455649b5197da27a58cb068476e928..baf16ddb4b987fa09de113c0316ec0014c884980 100644 --- a/tensorflow/core/kernels/diag_op.h +++ b/tensorflow/core/kernels/diag_op.h @@ -26,14 +26,14 @@ namespace functor { template struct DiagFunctor { - Status operator() (OpKernelContext* context, const int64 size, - const T* in, T* out); + Status operator()(OpKernelContext* context, const int64 size, const T* in, + T* out); }; template struct DiagPartFunctor { - Status operator() (OpKernelContext* context, const int64 size, - const T* in, T* out); + Status operator()(OpKernelContext* context, const int64 size, const T* in, + T* out); }; } // namespace functor diff --git a/tensorflow/core/kernels/diag_op_gpu.cu.cc b/tensorflow/core/kernels/diag_op_gpu.cu.cc index d3c529d784e3a9ba4a793cd98cff9eb5e74d6090..910f3093b2307526e36bdfad9ac6746dd861d4fd 100644 --- a/tensorflow/core/kernels/diag_op_gpu.cu.cc +++ b/tensorflow/core/kernels/diag_op_gpu.cu.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" #include "tensorflow/core/kernels/diag_op.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { namespace functor { @@ -28,10 +28,8 @@ namespace functor { typedef Eigen::GpuDevice GPUDevice; template -__global__ void DiagCudaKernel(const int num_threads, - const int64 size, - const T* in, - T* out) { +__global__ void DiagCudaKernel(const int num_threads, const int64 size, + const T* in, T* out) { CUDA_1D_KERNEL_LOOP(index, num_threads) { // Fill the diagonal elements or set to zero in other place. if (index % (1 + size) == 0) { @@ -44,9 +42,8 @@ __global__ void DiagCudaKernel(const int num_threads, template struct DiagFunctor { - EIGEN_ALWAYS_INLINE Status - operator() (OpKernelContext* context, const int64 size, - const T* in, T* out) { + EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context, + const int64 size, const T* in, T* out) { // Empty tensor couldn't launch the kernel. if (size == 0) { return Status::OK(); @@ -56,25 +53,22 @@ struct DiagFunctor { // so this may overflow for `size*size` in extreme cases, // here is checking the multiplication overflow for integer. if (size && (int(size * size) / size) != size) { - return errors::Internal( - "DiagOp got input size too large."); + return errors::Internal("DiagOp got input size too large."); } int virtual_thread_count = int(size * size); // Launch the GPU kernel. const GPUDevice& device = context->eigen_device(); - CudaLaunchConfig diag_config = GetCudaLaunchConfig( - virtual_thread_count, device); - DiagCudaKernel<<>>( - diag_config.virtual_thread_count, size, in, out); + CudaLaunchConfig diag_config = + GetCudaLaunchConfig(virtual_thread_count, device); + DiagCudaKernel<<>>(diag_config.virtual_thread_count, size, + in, out); auto err = cudaGetLastError(); if (err != cudaSuccess) { return errors::Internal( - "Could not launch DiagOp kernel: ", - cudaGetErrorString(err), "."); + "Could not launch DiagOp kernel: ", cudaGetErrorString(err), "."); } return Status::OK(); } @@ -87,12 +81,9 @@ template struct DiagFunctor; template struct DiagFunctor; template struct DiagFunctor; - template -__global__ void DiagPartCudaKernel(const int num_threads, - const int64 size, - const T* in, - T* out) { +__global__ void DiagPartCudaKernel(const int num_threads, const int64 size, + const T* in, T* out) { CUDA_1D_KERNEL_LOOP(index, num_threads) { out[index] = in[(1 + size) * index]; } @@ -100,9 +91,8 @@ __global__ void DiagPartCudaKernel(const int num_threads, template struct DiagPartFunctor { - EIGEN_ALWAYS_INLINE Status - operator() (OpKernelContext* context, const int64 size, - const T* in, T* out) { + EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context, + const int64 size, const T* in, T* out) { // Empty tensor couldn't launch the kernel. if (size == 0) { return Status::OK(); @@ -111,16 +101,14 @@ struct DiagPartFunctor { // Extract the diagonal elements. CudaLaunchConfig diag_config = GetCudaLaunchConfig(size, device); - DiagPartCudaKernel<<>>( - diag_config.virtual_thread_count, size, in, out); + DiagPartCudaKernel<<>>(diag_config.virtual_thread_count, + size, in, out); auto err = cudaGetLastError(); if (err != cudaSuccess) { return errors::Internal( - "Could not launch DiagPartOp kernel: ", - cudaGetErrorString(err), "."); + "Could not launch DiagPartOp kernel: ", cudaGetErrorString(err), "."); } return Status::OK(); } diff --git a/tensorflow/core/kernels/diag_op_test.cc b/tensorflow/core/kernels/diag_op_test.cc index 2d1417854cc06a138a803169495196ac70e70e5d..a708e53dd016d9a004a0cd2ddcdc285b0e6ad6fd 100644 --- a/tensorflow/core/kernels/diag_op_test.cc +++ b/tensorflow/core/kernels/diag_op_test.cc @@ -30,8 +30,8 @@ static Graph* Diag(int n, DataType type) { return g; } -#define BM_DiagDev(N, T, TFTYPE, DEVICE) \ - static void BM_Diag##_##N##_##TFTYPE##_##DEVICE(int iters) { \ +#define BM_DiagDev(N, T, TFTYPE, DEVICE) \ + static void BM_Diag##_##N##_##TFTYPE##_##DEVICE(int iters) { \ testing::UseRealTime(); \ testing::ItemsProcessed(static_cast(iters) * N * N); \ test::Benchmark(#DEVICE, Diag(N, TFTYPE)).Run(iters); \ @@ -51,4 +51,3 @@ BM_Diag(128); BM_Diag(512); } // end namespace tensorflow - diff --git a/tensorflow/core/kernels/dilation_ops.cc b/tensorflow/core/kernels/dilation_ops.cc index 6f5c0e91569eb5d44069a452632ad108e5df7d0d..441a63465c8246e09a8e70535f4b95a94d7acdb3 100644 --- a/tensorflow/core/kernels/dilation_ops.cc +++ b/tensorflow/core/kernels/dilation_ops.cc @@ -91,10 +91,10 @@ void ParseSizes(OpKernelContext* context, const std::vector& strides, filter.shape().DebugString())); const int filter_rows = filter.dim_size(0); const int filter_cols = filter.dim_size(1); - OP_REQUIRES( - context, depth == filter.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - depth, " vs ", filter.dim_size(2))); + OP_REQUIRES(context, depth == filter.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", depth, " vs ", + filter.dim_size(2))); // Effective filter size, after introducing rate - 1 zeros between each // non-zero filter element. @@ -234,10 +234,11 @@ class DilationBackpropInputOp : public OpKernel { // [ batch, out_rows, out_cols, depth ] const int batch = input.dim_size(0); const int depth = input.dim_size(3); - OP_REQUIRES(context, batch == out_backprop.dim_size(0) && - out_rows == out_backprop.dim_size(1) && - out_cols == out_backprop.dim_size(2) && - depth == out_backprop.dim_size(3), + OP_REQUIRES(context, + batch == out_backprop.dim_size(0) && + out_rows == out_backprop.dim_size(1) && + out_cols == out_backprop.dim_size(2) && + depth == out_backprop.dim_size(3), errors::InvalidArgument("out_backprop has incompatible size.")); // The computed in_backprop has the same dimensions as the input: @@ -353,10 +354,11 @@ class DilationBackpropFilterOp : public OpKernel { // [ batch, out_rows, out_cols, depth ] const int batch = input.dim_size(0); const int depth = input.dim_size(3); - OP_REQUIRES(context, batch == out_backprop.dim_size(0) && - out_rows == out_backprop.dim_size(1) && - out_cols == out_backprop.dim_size(2) && - depth == out_backprop.dim_size(3), + OP_REQUIRES(context, + batch == out_backprop.dim_size(0) && + out_rows == out_backprop.dim_size(1) && + out_cols == out_backprop.dim_size(2) && + depth == out_backprop.dim_size(3), errors::InvalidArgument("out_backprop has incompatible size.")); // The computed filter_backprop has the same dimensions as the filter: diff --git a/tensorflow/core/kernels/dilation_ops_gpu.cu.cc b/tensorflow/core/kernels/dilation_ops_gpu.cu.cc index ac0775fbefe601e53aaa6c67529cf9a67a0562c2..c63806a7f68c6981dd0e83373c6bfd598788e338 100644 --- a/tensorflow/core/kernels/dilation_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/dilation_ops_gpu.cu.cc @@ -61,9 +61,8 @@ __global__ void DilationKernel(const int32 nthreads, const T* input_ptr, const int w_in = w_beg + w * rate_cols; if (w_in >= 0 && w_in < input_cols) { const T val = - input_ptr[d + - depth * - (w_in + input_cols * (h_in + input_rows * b))] + + input_ptr[d + depth * (w_in + + input_cols * (h_in + input_rows * b))] + filter_ptr[d + depth * (w + filter_cols * h)]; if (val > cur_val) { cur_val = val; @@ -106,9 +105,8 @@ __global__ void DilationBackpropInputKernel( const int w_in = w_beg + w * rate_cols; if (w_in >= 0 && w_in < input_cols) { const T val = - input_ptr[d + - depth * - (w_in + input_cols * (h_in + input_rows * b))] + + input_ptr[d + depth * (w_in + + input_cols * (h_in + input_rows * b))] + filter_ptr[d + depth * (w + filter_cols * h)]; if (val > cur_val) { cur_val = val; @@ -156,9 +154,8 @@ __global__ void DilationBackpropFilterKernel( const int w_in = w_beg + w * rate_cols; if (w_in >= 0 && w_in < input_cols) { const T val = - input_ptr[d + - depth * - (w_in + input_cols * (h_in + input_rows * b))] + + input_ptr[d + depth * (w_in + + input_cols * (h_in + input_rows * b))] + filter_ptr[d + depth * (w + filter_cols * h)]; if (val > cur_val) { cur_val = val; diff --git a/tensorflow/core/kernels/draw_bounding_box_op.cc b/tensorflow/core/kernels/draw_bounding_box_op.cc index a8818b7385d9d5253588ec40f425b85180c79006..b5d5b880bbbacab07c51fc395b86b4fbbb343d36 100644 --- a/tensorflow/core/kernels/draw_bounding_box_op.cc +++ b/tensorflow/core/kernels/draw_bounding_box_op.cc @@ -29,8 +29,7 @@ template class DrawBoundingBoxesOp : public OpKernel { public: explicit DrawBoundingBoxesOp(OpKernelConstruction* context) - : OpKernel(context) { - } + : OpKernel(context) {} void Compute(OpKernelContext* context) override { const Tensor& images = context->input(0); @@ -94,35 +93,28 @@ class DrawBoundingBoxesOp : public OpKernel { int64 color_index = bb % color_table_length; const int64 min_box_row = static_cast(tboxes(b, bb, 0)) * (height - 1); - const int64 min_box_row_clamp = - std::max(min_box_row, 0); + const int64 min_box_row_clamp = std::max(min_box_row, 0); const int64 max_box_row = static_cast(tboxes(b, bb, 2)) * (height - 1); const int64 max_box_row_clamp = std::min(max_box_row, height - 1); const int64 min_box_col = static_cast(tboxes(b, bb, 1)) * (width - 1); - const int64 min_box_col_clamp = - std::max(min_box_col, 0); + const int64 min_box_col_clamp = std::max(min_box_col, 0); const int64 max_box_col = static_cast(tboxes(b, bb, 3)) * (width - 1); - const int64 max_box_col_clamp = - std::min(max_box_col, width - 1); + const int64 max_box_col_clamp = std::min(max_box_col, width - 1); if (min_box_row > max_box_row || min_box_col > max_box_col) { - LOG(WARNING) << "Bounding box (" << min_box_row - << "," << min_box_col - << "," << max_box_row - << "," << max_box_col + LOG(WARNING) << "Bounding box (" << min_box_row << "," << min_box_col + << "," << max_box_row << "," << max_box_col << ") is inverted and will not be drawn."; continue; } - if (min_box_row >= height || max_box_row < 0 || - min_box_col >= width || max_box_col < 0) { - LOG(WARNING) << "Bounding box (" << min_box_row - << "," << min_box_col - << "," << max_box_row - << "," << max_box_col + if (min_box_row >= height || max_box_row < 0 || min_box_col >= width || + max_box_col < 0) { + LOG(WARNING) << "Bounding box (" << min_box_row << "," << min_box_col + << "," << max_box_row << "," << max_box_col << ") is completely outside the image" << " and will not be drawn."; continue; diff --git a/tensorflow/core/kernels/dynamic_partition_op.cc b/tensorflow/core/kernels/dynamic_partition_op.cc index 861e16b2fd02001e913f548a5b48ca6b7497a8f2..3c988db5e618b976b5b2d45a9bfc386485249826 100644 --- a/tensorflow/core/kernels/dynamic_partition_op.cc +++ b/tensorflow/core/kernels/dynamic_partition_op.cc @@ -103,7 +103,8 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared { // Walk through data and copy the data to the appropriate output tensor const auto data_flat = data->flat(); std::vector, - Eigen::Aligned> > out_vec; + Eigen::Aligned> > + out_vec; out_vec.reserve(num_partitions_); for (int p = 0; p < num_partitions_; p++) { out_vec.push_back(outputs[p]->vec()); @@ -124,7 +125,8 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared { } else { // If data has extra dimensions, use Eigen slices std::vector, - Eigen::Aligned> > out_flat; + Eigen::Aligned> > + out_flat; out_flat.reserve(num_partitions_); for (int p = 0; p < num_partitions_; p++) { out_flat.push_back(outputs[p]->flat_outer_dims()); diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc index 9bb58b13f382970c60b551f448243a2b75e30df3..9dfeccff0e8d2488fec5a1dc7b93f83d2cfedca5 100644 --- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc +++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc @@ -79,9 +79,9 @@ template void RangeInit(const GPUDevice& d, const T start, const T delta, const int32 size, typename TTypes::Flat out) { CudaLaunchConfig config = GetCudaLaunchConfig(size, d); - RangeInitKernel< - T><<>>( - start, delta, size, out.data()); + RangeInitKernel + <<>>( + start, delta, size, out.data()); } // Given *num_runs pairs (key, value), this function moves the value @@ -103,11 +103,10 @@ void CallGatherKernel(const GPUDevice& d, const T* params, const int32* indices, T* out, int64 gather_dim_size, int64 indices_size, int64 slice_size, int64 out_size) { CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d); - GatherOpKernel< - T, int32, - true><<>>( - params, indices, out, gather_dim_size, indices_size, slice_size, - out_size); + GatherOpKernel + <<>>( + params, indices, out, gather_dim_size, indices_size, slice_size, + out_size); } struct IdentityOp { @@ -231,10 +230,10 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { OP_REQUIRES_ASYNC( c, TensorShapeUtils::StartsWith(data.shape(), partitions.shape()), - errors::InvalidArgument("data.shape must start with partitions.shape, ", - "got data.shape = ", data.shape().DebugString(), - ", partitions.shape = ", - partitions.shape().DebugString()), + errors::InvalidArgument( + "data.shape must start with partitions.shape, ", + "got data.shape = ", data.shape().DebugString(), + ", partitions.shape = ", partitions.shape().DebugString()), done); Tensor partition_count; @@ -245,8 +244,9 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { AllocatorAttributes alloc_attr; alloc_attr.set_on_host(true); OP_REQUIRES_OK_ASYNC( - c, c->allocate_temp(DT_INT32, TensorShape({num_partitions_}), - &partition_count, alloc_attr), + c, + c->allocate_temp(DT_INT32, TensorShape({num_partitions_}), + &partition_count, alloc_attr), done); auto e_part_count = partition_count.flat(); for (int i = 0; i < num_partitions_; i++) e_part_count(i) = 0; @@ -259,8 +259,9 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { // Prepare for counting. OP_REQUIRES_OK_ASYNC( - c, c->allocate_temp(DT_INT32, TensorShape({num_partitions_}), - &partition_count), + c, + c->allocate_temp(DT_INT32, TensorShape({num_partitions_}), + &partition_count), done); Tensor indices_out; // Count how many times each partition index occurs. @@ -280,8 +281,9 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { alloc_attr.set_on_host(true); alloc_attr.set_gpu_compatible(true); OP_REQUIRES_OK_ASYNC( - c, c->allocate_temp(partition_count.dtype(), partition_count.shape(), - &cpu_tensor, alloc_attr), + c, + c->allocate_temp(partition_count.dtype(), partition_count.shape(), + &cpu_tensor, alloc_attr), done); perftools::gputools::DeviceMemoryBase wrapped( partition_count.flat().data(), num_partitions_ * sizeof(int32)); @@ -340,9 +342,10 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { indices_in_ptr, indices_out_ptr, N, 0, sizeof(int32) * 8, cu_stream); // Allocate temporary storage. OP_REQUIRES_OK_ASYNC( - c, c->allocate_temp( - DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), - &cub_temp_storage), + c, + c->allocate_temp(DT_INT8, + TensorShape({static_cast(temp_storage_bytes)}), + &cub_temp_storage), done); // Radix-sort the partition information. cub::DeviceRadixSort::SortPairs( @@ -376,8 +379,9 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { zero_functor(device, partition_count->flat()); // Allocate memory for aggregates_out. OP_REQUIRES_OK_ASYNC( - c, c->allocate_temp(DT_INT32, TensorShape({num_partitions_}), - &aggregates_out), + c, + c->allocate_temp(DT_INT32, TensorShape({num_partitions_}), + &aggregates_out), done); // Obtain the pointers to inner buffers. int32* keys_in_ptr = partitions_out.flat().data(); @@ -408,9 +412,10 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { num_runs_ptr, reduction_op, N, cu_stream); // Allocate temporary storage. OP_REQUIRES_OK_ASYNC( - c, c->allocate_temp( - DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), - &cub_temp_storage), + c, + c->allocate_temp(DT_INT8, + TensorShape({static_cast(temp_storage_bytes)}), + &cub_temp_storage), done); // Run reduce-by-key. The effect is that we count how many times // each index appears in partitions. The distinct indices are stored diff --git a/tensorflow/core/kernels/eigen_activations.h b/tensorflow/core/kernels/eigen_activations.h index 99b4b2abe66d9f372f99af1ef6164774e7ebfabc..302033e47c59db2d87483a8e2f1e70d0572b21f9 100644 --- a/tensorflow/core/kernels/eigen_activations.h +++ b/tensorflow/core/kernels/eigen_activations.h @@ -21,13 +21,13 @@ limitations under the License. namespace Eigen { /** scalar_sigmoid_fast_derivative_op - * \ingroup CXX11_NeuralNetworks_Module - * \brief Template functor to compute the fast derivative of a sigmoid - * - * Input should be the backpropagated gradient. - * - * \sa class CwiseUnaryOp, Cwise::sigmoid_fast_derivative() - */ + * \ingroup CXX11_NeuralNetworks_Module + * \brief Template functor to compute the fast derivative of a sigmoid + * + * Input should be the backpropagated gradient. + * + * \sa class CwiseUnaryOp, Cwise::sigmoid_fast_derivative() + */ template struct scalar_sigmoid_fast_derivative_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_fast_derivative_op) @@ -55,13 +55,13 @@ struct functor_traits > { } // namespace internal /** scalar_tanh_fast_derivative_op - * \ingroup CXX11_NeuralNetworks_Module - * \brief Template functor to compute the fast derivative of a tanh - * - * Input should be the backpropagated gradient. - * - * \sa class CwiseUnaryOp, Cwise::tanh_fast_derivative() - */ + * \ingroup CXX11_NeuralNetworks_Module + * \brief Template functor to compute the fast derivative of a tanh + * + * Input should be the backpropagated gradient. + * + * \sa class CwiseUnaryOp, Cwise::tanh_fast_derivative() + */ template struct scalar_tanh_fast_derivative_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_fast_derivative_op) @@ -89,11 +89,11 @@ struct functor_traits > { } // namespace internal /** - * \ingroup CXX11_NeuralNetworks_Module - * \brief Template functor to clip the magnitude of the first scalar. - * - * \sa class CwiseBinaryOp, MatrixBase::Clip - */ + * \ingroup CXX11_NeuralNetworks_Module + * \brief Template functor to clip the magnitude of the first scalar. + * + * \sa class CwiseBinaryOp, MatrixBase::Clip + */ template struct scalar_clip_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_clip_op) diff --git a/tensorflow/core/kernels/eigen_activations_test.cc b/tensorflow/core/kernels/eigen_activations_test.cc index 907233103d8244749410c3198f0ca92ad44769b8..34952f5abb8526f0317ba8a674948fada4dc0ce7 100644 --- a/tensorflow/core/kernels/eigen_activations_test.cc +++ b/tensorflow/core/kernels/eigen_activations_test.cc @@ -23,7 +23,7 @@ namespace { void EigenApprox(float a, float b) { ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3); } -} +} // namespace TEST(EigenBackwardSpatialConvolutionsTest, SigmoidFastDerivative) { const ptrdiff_t depth = 3; diff --git a/tensorflow/core/kernels/eigen_attention.h b/tensorflow/core/kernels/eigen_attention.h index 3a94b8c9933ddbf262552044c73206e1deb9828d..4d86f9deb9902a64764e29ca0371bb68ad4f3370 100644 --- a/tensorflow/core/kernels/eigen_attention.h +++ b/tensorflow/core/kernels/eigen_attention.h @@ -21,35 +21,47 @@ limitations under the License. namespace Eigen { /** ExtractGlimpses - * \ingroup CXX11_NeuralNetworks_Module - * - * \brief Extract glimpses from an input tensor. - * - * The input parameter is expected to be a col-major tensor with a rank of 4 (depth, x, y, and batch). - * The width and height parameters specify the extension of the returned glimpses. - * The offsets parameter specifies the x, y locations of the center of the glimpses relative to the center of the input image. The vector is expected to contain one IndexPair for each image in the batch dimension. - * The normalized boolean indicates if incoming coordinates are normalized so that 0.0 and 1.0 correspond to the minimum and maximum of each height and width dimension. - * The centered boolean indicates if incoming coordinates are centered relative to the image, in which case -1.0 and 1.0 correspond to minimum and maximum of each dimension while 0.0 corresponds to the center. - * - * The result can be assigned to a tensor of rank equal to that of the input. The result will be laid out in col-major order (depth, x, y, batch). - * The dimensions of the result will be equal to the dimensions of the input except for width and height which will be equal to the requested glimpse size. - */ + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Extract glimpses from an input tensor. + * + * The input parameter is expected to be a col-major tensor with a rank of 4 + * (depth, x, y, and batch). The width and height parameters specify the + * extension of the returned glimpses. The offsets parameter specifies the x, y + * locations of the center of the glimpses relative to the center of the input + * image. The vector is expected to contain one IndexPair for each image in the + * batch dimension. The normalized boolean indicates if incoming coordinates are + * normalized so that 0.0 and 1.0 correspond to the minimum and maximum of each + * height and width dimension. The centered boolean indicates if incoming + * coordinates are centered relative to the image, in which case -1.0 and 1.0 + * correspond to minimum and maximum of each dimension while 0.0 corresponds to + * the center. + * + * The result can be assigned to a tensor of rank equal to that of the input. + * The result will be laid out in col-major order (depth, x, y, batch). The + * dimensions of the result will be equal to the dimensions of the input except + * for width and height which will be equal to the requested glimpse size. + */ namespace { template struct GlimpseExtractionOp { GlimpseExtractionOp(const Index width, const Index height, const std::vector >& offsets, - const bool normalized, - const bool centered, - const bool uniform_noise) : - width_(width), height_(height), offsets_(offsets), - normalized_(normalized), centered_(centered), uniform_noise_(uniform_noise) { } + const bool normalized, const bool centered, + const bool uniform_noise) + : width_(width), + height_(height), + offsets_(offsets), + normalized_(normalized), + centered_(centered), + uniform_noise_(uniform_noise) {} template DSizes dimensions(const Input& input) const { typedef typename internal::traits::Index IndexType; typedef TensorRef::Scalar, 4, - internal::traits::Layout, IndexType> > Ref; + internal::traits::Layout, IndexType> > + Ref; Ref in(input); DSizes dims = in.dimensions(); @@ -62,12 +74,12 @@ struct GlimpseExtractionOp { } template - EIGEN_DEVICE_FUNC - void eval(const Input& input, Output& output, const Device& device) const - { + EIGEN_DEVICE_FUNC void eval(const Input& input, Output& output, + const Device& device) const { typedef typename internal::traits::Index IndexType; typedef TensorRef::Scalar, 4, - internal::traits::Layout, IndexType> > Ref; + internal::traits::Layout, IndexType> > + Ref; Ref in(input); const Index num_channels = in.dimension(0); const Index input_width = in.dimension(1); @@ -97,8 +109,8 @@ struct GlimpseExtractionOp { x -= width_ / 2.0f; y -= height_ / 2.0f; - const Index offset_x = (Index) x; - const Index offset_y = (Index) y; + const Index offset_x = (Index)x; + const Index offset_y = (Index)y; Index glimpse_width = width_; Index glimpse_height = height_; bool partial_overlap = false; @@ -135,7 +147,7 @@ struct GlimpseExtractionOp { if (uniform_noise_) { // Initialize the glimpse with uniform noise. typedef typename internal::remove_const< - typename internal::traits::Scalar>::type Scalar; + typename internal::traits::Scalar>::type Scalar; TensorFixedSize > mini; mini.device(device) = input.template chip<3>(i).minimum(); TensorFixedSize > range; @@ -215,21 +227,22 @@ struct GlimpseExtractionOp { const bool centered_; const bool uniform_noise_; }; -} - +} // namespace template -EIGEN_ALWAYS_INLINE -static const TensorCustomUnaryOp::Index>, const Input> +EIGEN_ALWAYS_INLINE static const TensorCustomUnaryOp< + const GlimpseExtractionOp::Index>, + const Input> ExtractGlimpses(const Input& input, const typename internal::traits::Index width, const typename internal::traits::Index height, const std::vector >& offsets, const bool normalized = true, const bool centered = true, - const bool uniform_noise = true) -{ - EIGEN_STATIC_ASSERT(internal::traits::Layout == ColMajor, YOU_MADE_A_PROGRAMMING_MISTAKE); - EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == 4, YOU_MADE_A_PROGRAMMING_MISTAKE); + const bool uniform_noise = true) { + EIGEN_STATIC_ASSERT(internal::traits::Layout == ColMajor, + YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == 4, + YOU_MADE_A_PROGRAMMING_MISTAKE); typedef typename internal::traits::Index Index; const GlimpseExtractionOp op(width, height, offsets, normalized, @@ -237,6 +250,6 @@ ExtractGlimpses(const Input& input, return input.customOp(op); } -} // end namespace Eigen +} // end namespace Eigen #endif // TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ diff --git a/tensorflow/core/kernels/eigen_attention_test.cc b/tensorflow/core/kernels/eigen_attention_test.cc index 3a2eeb05959e8844903eb3b910a893760bb02e74..08f61877182cce36316752b7dd17dee3bd2efaac 100644 --- a/tensorflow/core/kernels/eigen_attention_test.cc +++ b/tensorflow/core/kernels/eigen_attention_test.cc @@ -23,7 +23,7 @@ namespace { void EigenApprox(float a, float b) { ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3); } -} +} // namespace TEST(EigenAttentionTest, Simple) { const ptrdiff_t depth = 3; diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h index aec76978102ed4d5e8d0cca18f1ae4422acc1515..099696105b61c19b7fcc9694fe1d7a3021cb97dc 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h @@ -21,29 +21,29 @@ limitations under the License. namespace Eigen { /** SpatialConvolutionBackwardInput - * \ingroup CXX11_NeuralNetworks_Module - * - * \brief Computes the backprop for the input of a 2D convolution. - * - * The output_backward parameter is expected to be a tensor with a rank of 3 or + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Computes the backprop for the input of a 2D convolution. + * + * The output_backward parameter is expected to be a tensor with a rank of 3 or * more (channels, height, width, and optionally others) - * The kernel parameter is expected to be a 4D tensor (filters, channels, + * The kernel parameter is expected to be a 4D tensor (filters, channels, * kernel_height, kernel_width) - * The output_backward and the kernel must both be in col-major layout. The + * The output_backward and the kernel must both be in col-major layout. The * result will also be in col-major layout. - * - * If row_in_stride, col_in_stride > 1, then applies convolution with holes + * + * If row_in_stride, col_in_stride > 1, then applies convolution with holes * (aka atrous convolution), sampling every row_in_stride, col_in_stride input * pixels. - * - * The result can be assigned to a tensor of rank equal to the rank of the + * + * The result can be assigned to a tensor of rank equal to the rank of the * output_backward. The dimensions of the result will be filters, height, width * (and others if applicable). - * - * It is possible to swap the order of the width and height dimensions provided + * + * It is possible to swap the order of the width and height dimensions provided * that the same order is used in the input, the kernel, and the output. - * - */ + * + */ #ifdef EIGEN_HAS_INDEX_LIST typedef IndexList, type2index<0>, type2index<1>, type2index<1> > ReverseColMajor; @@ -293,29 +293,29 @@ SpatialConvolutionBackwardInput( } /** SpatialConvolutionBackwardKernel - * \ingroup CXX11_NeuralNetworks_Module - * - * \brief Computes the backprop for the filter of a 2D convolution. - * - * The output_backward parameter is expected to be a tensor with a rank of 3 or + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Computes the backprop for the filter of a 2D convolution. + * + * The output_backward parameter is expected to be a tensor with a rank of 3 or * more (channels, height, width, and optionally others) - * The kernel parameter is expected to be a 4D tensor (filters, channels, + * The kernel parameter is expected to be a 4D tensor (filters, channels, * kernel_height, kernel_width) - * The output_backward and the kernel must both be in col-major layout. The + * The output_backward and the kernel must both be in col-major layout. The * result will also be in col-major layout. - * - * If row_in_stride, col_stride > 1, then applies convolution with holes (aka + * + * If row_in_stride, col_stride > 1, then applies convolution with holes (aka * atrous convolution), sampling every row_in_stride, col_in_stride input * pixels. - * - * The result can be assigned to a tensor of rank equal to the rank of the + * + * The result can be assigned to a tensor of rank equal to the rank of the * output_backward. The dimensions of the result will be filters, height, width * (and others if applicable). - * - * It is possible to swap the order of the width and height dimensions provided + * + * It is possible to swap the order of the width and height dimensions provided * that the same order is used in the input, the kernel, and the output. - * - */ + * + */ template EIGEN_ALWAYS_INLINE static const typename internal::conditional< diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc index 1758067829e5b577477c1d86f9cdb4396b46e047..2229ec9659472daee3158c593252907f288d829f 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc @@ -25,7 +25,7 @@ void EigenApprox(float a, float b) { ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3); } static int ceil_div(int a, int b) { return (a + b - 1) / b; } -} +} // namespace TEST(EigenBackwardSpatialConvolutionsTest, test_simple_spatial_convolution_backward_input_valid) { diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h index 972036833fff6753031e97216d524a014bb81cbb..896c9957616037da4ead2dbda8cb2393eaea226f 100644 --- a/tensorflow/core/kernels/eigen_pooling.h +++ b/tensorflow/core/kernels/eigen_pooling.h @@ -309,10 +309,10 @@ struct AvgPoolMeanReducer { _mm512_castsi512_ps( \ _mm512_maskz_set1_epi32(_mm512_cmp_ps_mask(a, b, _CMP_EQ_UQ), -1)) -// The ternarylogic function immediate determines the values in the result -// In the case below, 0xd8 implies (false_mask) ? (b) : (a) -// For details, refer to the vpternlogd instruction table at -// http://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-software-developer-vol-2c-manual.pdf + // The ternarylogic function immediate determines the values in the result + // In the case below, 0xd8 implies (false_mask) ? (b) : (a) + // For details, refer to the vpternlogd instruction table at + // http://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-software-developer-vol-2c-manual.pdf #define psel(a, b, false_mask) \ _mm512_castsi512_ps(_mm512_ternarylogic_epi32( \ diff --git a/tensorflow/core/kernels/eigen_pooling_test.cc b/tensorflow/core/kernels/eigen_pooling_test.cc index 9383972b9fff39deb130d5cecac6f0c7abec5566..47b6665e680268793df18d50395d0b6c6aca0ad0 100644 --- a/tensorflow/core/kernels/eigen_pooling_test.cc +++ b/tensorflow/core/kernels/eigen_pooling_test.cc @@ -23,7 +23,7 @@ namespace { void EigenApprox(float a, float b) { ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3); } -} +} // namespace TEST(EigenPoolingTest, Simple) { const int depth = 10; diff --git a/tensorflow/core/kernels/eigen_softmax.h b/tensorflow/core/kernels/eigen_softmax.h index a2930a726f908ac4862a47104e379e6d30e88477..12148c54b364bbc5ef1dff9b9645303534e7ea12 100644 --- a/tensorflow/core/kernels/eigen_softmax.h +++ b/tensorflow/core/kernels/eigen_softmax.h @@ -21,19 +21,21 @@ limitations under the License. namespace Eigen { /** SoftMax - * \ingroup CXX11_NeuralNetworks_Module - * - * \brief Applies a softmax - * - * The input parameter is expected to be a col-major tensor with a rank of 2 (depth and other). - * - * The result can be assigned to a tensor of rank and dimensions equal to that of the input. The result will be laid out in col-major order. - * -*/ + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies a softmax + * + * The input parameter is expected to be a col-major tensor with a rank of 2 + * (depth and other). + * + * The result can be assigned to a tensor of rank and dimensions equal to that + * of the input. The result will be laid out in col-major order. + * + */ namespace { struct SoftmaxOp { - SoftmaxOp(const float beta) : beta_(beta) { } + SoftmaxOp(const float beta) : beta_(beta) {} template typename Input::Dimensions dimensions(const Input& input) const { @@ -41,8 +43,7 @@ struct SoftmaxOp { } template - void eval(const Input& input, Output& output, const Device& device) const - { + void eval(const Input& input, Output& output, const Device& device) const { #if !defined(EIGEN_HAS_INDEX_LIST) // nvcc doesn't support cxx11 Eigen::array::Index, 1> depth_dim; @@ -56,35 +57,43 @@ struct SoftmaxOp { #else // Take advantage of cxx11 to give the compiler information it can use to // optimize the code. - Eigen::IndexList> depth_dim; - Eigen::IndexList> bcast; + Eigen::IndexList > depth_dim; + Eigen::IndexList > bcast; bcast.set(0, dimensions(input)[0]); - Eigen::IndexList, typename internal::traits::Index> dims2d; + Eigen::IndexList, + typename internal::traits::Index> + dims2d; dims2d.set(1, dimensions(input)[1]); #endif - output.device(device) = ((input - input.maximum(depth_dim).eval().reshape(dims2d).broadcast(bcast)) * beta_).exp(); - output.device(device) = output / (output.sum(depth_dim).eval().reshape(dims2d).broadcast(bcast)); + output.device(device) = + ((input - + input.maximum(depth_dim).eval().reshape(dims2d).broadcast(bcast)) * + beta_) + .exp(); + output.device(device) = + output / + (output.sum(depth_dim).eval().reshape(dims2d).broadcast(bcast)); } private: const float beta_; }; -} - +} // namespace template -EIGEN_ALWAYS_INLINE -static const TensorCustomUnaryOp -SoftMax(const Input& input, const float beta) -{ - EIGEN_STATIC_ASSERT(internal::traits::Layout == ColMajor, YOU_MADE_A_PROGRAMMING_MISTAKE); - EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == 2, YOU_MADE_A_PROGRAMMING_MISTAKE); +EIGEN_ALWAYS_INLINE static const TensorCustomUnaryOp +SoftMax(const Input& input, const float beta) { + EIGEN_STATIC_ASSERT(internal::traits::Layout == ColMajor, + YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == 2, + YOU_MADE_A_PROGRAMMING_MISTAKE); const SoftmaxOp op(beta); return input.customOp(op); } -} // end namespace Eigen +} // end namespace Eigen #endif // TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_ diff --git a/tensorflow/core/kernels/eigen_softmax_test.cc b/tensorflow/core/kernels/eigen_softmax_test.cc index ba681d68ab0d416cd2c7bae9065df9b95638a3e8..7f985d71366487e0426e25e064764c196979b114 100644 --- a/tensorflow/core/kernels/eigen_softmax_test.cc +++ b/tensorflow/core/kernels/eigen_softmax_test.cc @@ -23,7 +23,7 @@ namespace { void EigenApprox(float a, float b) { ASSERT_TRUE(std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * 1e-3); } -} +} // namespace TEST(EigenSoftmaxTest, Simple) { const int depth = 1024; diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h index 2fe64cd72ac06e86cccea31145079451d0b28f88..1acbe3a658070222e99ff874815db9a6b07d4565 100644 --- a/tensorflow/core/kernels/eigen_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h @@ -877,29 +877,29 @@ struct gemm_pack_rhs< } // end namespace internal /** SpatialConvolution - * \ingroup CXX11_NeuralNetworks_Module - * - * \brief Applies a 2D convolution over a multichannel input image. - * - * The input parameter is expected to be a tensor with a rank of 3 or more + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies a 2D convolution over a multichannel input image. + * + * The input parameter is expected to be a tensor with a rank of 3 or more * (channels, height, width, and optionally others) - * The kernel parameter is expected to be a 4D tensor (filters, channels, + * The kernel parameter is expected to be a 4D tensor (filters, channels, * kernel_height, kernel_width) - * The input and the kernel must both be in col-major layout. The result will + * The input and the kernel must both be in col-major layout. The result will * also be in col-major layout. - * - * If col_in_stride, row_in_stride > 1, then applies convolution with holes + * + * If col_in_stride, row_in_stride > 1, then applies convolution with holes * (aka atrous convolution), sampling every col_in_stride, row_in_stride input * pixels. - * - * The result can be assigned to a tensor of rank equal to the rank of the + * + * The result can be assigned to a tensor of rank equal to the rank of the * input. The dimensions of the result will be filters, height, width (and * others if applicable). - * - * It is possible to swap the order of the width and height dimensions provided + * + * It is possible to swap the order of the width and height dimensions provided * that the same order is used in the input, the kernel, and the output. - * - */ + * + */ template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static const typename internal::conditional< @@ -993,7 +993,7 @@ EIGEN_DEVICE_FUNC default: // Initialize unused variables to avoid a compiler warning out_height = 0; - out_width = 0; + out_width = 0; eigen_assert(false && "unexpected padding"); } diff --git a/tensorflow/core/kernels/encode_jpeg_op.cc b/tensorflow/core/kernels/encode_jpeg_op.cc index 4fcae25aa6eac8b31f78e1d5ae964aed427fc0f4..1a5b0f2b675a85ba2c1dbf0356c3e42b03db22b4 100644 --- a/tensorflow/core/kernels/encode_jpeg_op.cc +++ b/tensorflow/core/kernels/encode_jpeg_op.cc @@ -80,10 +80,11 @@ class EncodeJpegOp : public OpKernel { errors::InvalidArgument("image must be 3-dimensional", image.shape().DebugString())); - OP_REQUIRES(context, FastBoundsCheck(image.NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument( - "Cannot encode images with >= max int32 elements")); + OP_REQUIRES( + context, + FastBoundsCheck(image.NumElements(), std::numeric_limits::max()), + errors::InvalidArgument( + "Cannot encode images with >= max int32 elements")); const int32 dim_size0 = static_cast(image.dim_size(0)); const int32 dim_size1 = static_cast(image.dim_size(1)); @@ -100,9 +101,10 @@ class EncodeJpegOp : public OpKernel { } else if (channels == 3) { adjusted_flags.format = jpeg::FORMAT_RGB; } else { - OP_REQUIRES(context, false, errors::InvalidArgument( - "image must have 1 or 3 channels, got ", - image.shape().DebugString())); + OP_REQUIRES( + context, false, + errors::InvalidArgument("image must have 1 or 3 channels, got ", + image.shape().DebugString())); } } else { if (flags_.format == jpeg::FORMAT_GRAYSCALE) { diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc index 268a059275acc62432d59df239abd5869f546064..83cd0e9b47e5480cd562452213aa81c7a4a64a95 100644 --- a/tensorflow/core/kernels/example_parsing_ops.cc +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -346,8 +346,9 @@ class SingleSequenceExampleParserOp : public OpKernel { feature_list_sparse_keys[di].scalar()(); } OP_REQUIRES( - ctx, TensorShapeUtils::IsVector( - feature_list_dense_missing_assumed_empty->shape()), + ctx, + TensorShapeUtils::IsVector( + feature_list_dense_missing_assumed_empty->shape()), errors::InvalidArgument( "Expected feature_list_dense_missing_assumed_empty ", "to be a vector, got shape: ", @@ -386,12 +387,12 @@ class SingleSequenceExampleParserOp : public OpKernel { required[d] = (def_value.NumElements() == 0); // No default provided. if (def_value.NumElements() > 0) { - OP_REQUIRES( - ctx, def_value.shape() == attrs_.context_dense_shapes[d], - errors::InvalidArgument( - "def_value[", d, "].shape() == ", - def_value.shape().DebugString(), " != context_dense_shapes_[", - d, "] == ", attrs_.context_dense_shapes[d].DebugString())); + OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d], + errors::InvalidArgument( + "def_value[", d, + "].shape() == ", def_value.shape().DebugString(), + " != context_dense_shapes_[", d, + "] == ", attrs_.context_dense_shapes[d].DebugString())); OP_REQUIRES( ctx, def_value.dtype() == attrs_.context_dense_types[d], errors::InvalidArgument( @@ -576,12 +577,12 @@ class SingleSequenceExampleParserOp : public OpKernel { const Feature& f = fl.feature(t); bool types_match; OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); - OP_REQUIRES( - ctx, types_match, - errors::InvalidArgument( - "Name: ", name, ", Feature list: ", key, ", Index: ", t, - ". Data types don't match. ", "Expected type: ", - DataTypeString(dtype), " Feature is: ", ProtoDebugString(f))); + OP_REQUIRES(ctx, types_match, + errors::InvalidArgument( + "Name: ", name, ", Feature list: ", key, ", Index: ", t, + ". Data types don't match. ", + "Expected type: ", DataTypeString(dtype), + " Feature is: ", ProtoDebugString(f))); OP_REQUIRES_OK(ctx, FeatureDenseCopy(t, name, key, dtype, shape, f, feature_list_dense_values[d])); } diff --git a/tensorflow/core/kernels/fact_op.cc b/tensorflow/core/kernels/fact_op.cc index 4fbf76d2d0d0470c0529353003eb7e086451d57f..4a1aa433bc94e5f190ce75c1b991eaf91210eedf 100644 --- a/tensorflow/core/kernels/fact_op.cc +++ b/tensorflow/core/kernels/fact_op.cc @@ -122,13 +122,9 @@ static string D(const char* s) { return ret; } -REGISTER_KERNEL_BUILDER(Name("Fact") - .Device(DEVICE_CPU) - .Label(D("Yoxmos").c_str()), - FactOpKernel2); -REGISTER_KERNEL_BUILDER(Name("Fact") - .Device(DEVICE_CPU) - .Label(D("yoxmos").c_str()), - FactOpKernel2); +REGISTER_KERNEL_BUILDER( + Name("Fact").Device(DEVICE_CPU).Label(D("Yoxmos").c_str()), FactOpKernel2); +REGISTER_KERNEL_BUILDER( + Name("Fact").Device(DEVICE_CPU).Label(D("yoxmos").c_str()), FactOpKernel2); } // namespace tensorflow diff --git a/tensorflow/core/kernels/fake_quant_ops.cc b/tensorflow/core/kernels/fake_quant_ops.cc index 68762af8cf1e76211c0229163d9dce44fc0ad153..f5e279eca4c6d3492419a507c7d070613e169b64 100644 --- a/tensorflow/core/kernels/fake_quant_ops.cc +++ b/tensorflow/core/kernels/fake_quant_ops.cc @@ -45,7 +45,7 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; namespace { -bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 8; } +bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 16; } } // namespace // ----------------------------------------------------------------------------- @@ -65,8 +65,9 @@ class FakeQuantWithMinMaxArgsOp " >= ", max_)); int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; @@ -104,8 +105,9 @@ class FakeQuantWithMinMaxArgsGradientOp " >= ", max_)); int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; @@ -175,8 +177,9 @@ class FakeQuantWithMinMaxVarsOp : public OpKernel { : OpKernel::OpKernel(context) { int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; @@ -213,8 +216,9 @@ class FakeQuantWithMinMaxVarsGradientOp : public OpKernel { : OpKernel::OpKernel(context) { int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; @@ -302,8 +306,9 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel { : OpKernel::OpKernel(context) { int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; @@ -348,8 +353,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel { : OpKernel::OpKernel(context) { int num_bits; OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(context, IsNumBitsValid(num_bits), - InvalidArgument("num_bits must be between 2 and 8, inclusive")); + OP_REQUIRES( + context, IsNumBitsValid(num_bits), + InvalidArgument("num_bits must be between 2 and 16, inclusive")); bool narrow_range; OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range)); quant_min_ = narrow_range ? 1 : 0; diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h index 81189866c34819306231edc2073fbdc23fbb9baf..d51acc38ef7e5a865f51ac319a3ad16198714dd9 100644 --- a/tensorflow/core/kernels/fake_quant_ops_functor.h +++ b/tensorflow/core/kernels/fake_quant_ops_functor.h @@ -45,16 +45,16 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void Nudge( const float quant_max_float = static_cast(quant_max); *scale = (max - min) / (quant_max_float - quant_min_float); const float zero_point_from_min = quant_min_float - min / *scale; - const uint8 nudged_zero_point = [zero_point_from_min, quant_min, - quant_min_float, quant_max, - quant_max_float] { + const uint16 nudged_zero_point = [zero_point_from_min, quant_min, + quant_min_float, quant_max, + quant_max_float] { if (zero_point_from_min < quant_min_float) { - return static_cast(quant_min); + return static_cast(quant_min); } if (zero_point_from_min > quant_max_float) { - return static_cast(quant_max); + return static_cast(quant_max); } - return static_cast(StdRound(zero_point_from_min)); + return static_cast(StdRound(zero_point_from_min)); }(); *nudged_min = (quant_min_float - nudged_zero_point) * (*scale); *nudged_max = (quant_max_float - nudged_zero_point) * (*scale); diff --git a/tensorflow/core/kernels/fake_quant_ops_test.cc b/tensorflow/core/kernels/fake_quant_ops_test.cc index 5953db14768fd4e8d6c8537a2bea91c2ca211b17..af3a42135d1fe99da87c1cfafbc2b8eb932a7d2c 100644 --- a/tensorflow/core/kernels/fake_quant_ops_test.cc +++ b/tensorflow/core/kernels/fake_quant_ops_test.cc @@ -378,9 +378,8 @@ TEST_F(QuantOpsTest, WithArgsGradient_RegularRange) { Tensor* output = GetOutput(0); auto input_flat = GetInput(0).flat(); Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); - FillValues(&expected, - {0.0f, input_flat(1), input_flat(2), - input_flat(3), input_flat(4), 0.0f}); + FillValues(&expected, {0.0f, input_flat(1), input_flat(2), + input_flat(3), input_flat(4), 0.0f}); ExpectClose(expected, *output); } @@ -2167,21 +2166,19 @@ TEST_F(QuantOpsTest, Tensor* output_bprop_wrt_input = GetOutput(0); Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({2, 3})); auto grad_flat = GetInput(0).flat(); - FillValues(&expected_bprop_wrt_input, - {0.0f, grad_flat(1), grad_flat(2), - grad_flat(3), grad_flat(4), 0.0f}); + FillValues( + &expected_bprop_wrt_input, + {0.0f, grad_flat(1), grad_flat(2), grad_flat(3), grad_flat(4), 0.0f}); ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input); Tensor* output_bprop_wrt_min = GetOutput(1); Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({3})); - FillValues(&expected_bprop_wrt_min, - {grad_flat(0), 0.0f, 0.0f}); + FillValues(&expected_bprop_wrt_min, {grad_flat(0), 0.0f, 0.0f}); ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min); Tensor* output_bprop_wrt_max = GetOutput(2); Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({3})); - FillValues(&expected_bprop_wrt_max, - {0.0f, 0.0f, grad_flat(5)}); + FillValues(&expected_bprop_wrt_max, {0.0f, 0.0f, grad_flat(5)}); ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max); } @@ -2215,21 +2212,19 @@ TEST_F(QuantOpsTest, WithVarsPerChannelDim2GradientNudgedUp_4Bits_NarrowRange) { Tensor* output_bprop_wrt_input = GetOutput(0); Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({2, 3})); auto grad_flat = GetInput(0).flat(); - FillValues(&expected_bprop_wrt_input, - {0.0f, grad_flat(1), grad_flat(2), - grad_flat(3), grad_flat(4), 0.0f}); + FillValues( + &expected_bprop_wrt_input, + {0.0f, grad_flat(1), grad_flat(2), grad_flat(3), grad_flat(4), 0.0f}); ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input); Tensor* output_bprop_wrt_min = GetOutput(1); Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({3})); - FillValues(&expected_bprop_wrt_min, - {grad_flat(0), 0.0f, 0.0f}); + FillValues(&expected_bprop_wrt_min, {grad_flat(0), 0.0f, 0.0f}); ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min); Tensor* output_bprop_wrt_max = GetOutput(2); Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({3})); - FillValues(&expected_bprop_wrt_max, - {0.0f, 0.0f, grad_flat(5)}); + FillValues(&expected_bprop_wrt_max, {0.0f, 0.0f, grad_flat(5)}); ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max); } @@ -2270,14 +2265,13 @@ TEST_F(QuantOpsTest, Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({1, 2, 3, 4})); auto grad_flat = GetInput(0).flat(); - FillValues( - &expected_bprop_wrt_input, - {0.0f, grad_flat(1), grad_flat(2), 0.0f, - 0.0f, grad_flat(5), grad_flat(6), 0.0f, - 0.0f, grad_flat(9), grad_flat(10), 0.0f, - 0.0f, grad_flat(13), grad_flat(14), 0.0f, - 0.0f, grad_flat(17), grad_flat(18), 0.0f, - 0.0f, grad_flat(21), grad_flat(22), 0.0f}); + FillValues(&expected_bprop_wrt_input, + {0.0f, grad_flat(1), grad_flat(2), 0.0f, + 0.0f, grad_flat(5), grad_flat(6), 0.0f, + 0.0f, grad_flat(9), grad_flat(10), 0.0f, + 0.0f, grad_flat(13), grad_flat(14), 0.0f, + 0.0f, grad_flat(17), grad_flat(18), 0.0f, + 0.0f, grad_flat(21), grad_flat(22), 0.0f}); ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input); Tensor* output_bprop_wrt_min = GetOutput(1); diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc index 82ec87911985abe714490ad74fa19105f850b536..479f7be4b506e4f8721216fb00ea0eff7e0394c2 100644 --- a/tensorflow/core/kernels/fifo_queue.cc +++ b/tensorflow/core/kernels/fifo_queue.cc @@ -255,97 +255,96 @@ void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, // TODO(josh11b): This makes two copies of callback, avoid this if possible. dequeue_attempts_.emplace_back( num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token, - [callback, allow_small_batch, this](Attempt* attempt) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - int64 queue_size = queues_[0].size(); + [callback, allow_small_batch, + this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 queue_size = queues_[0].size(); - if (closed_ && queue_size < attempt->elements_requested) { - // If we don't have enough for a full dequeue, we have - // to reset the attempt tuple. - if (!attempt->tuple.empty()) { - // Restore already-dequeued elements to the front of the - // queue. - for (int64 i = attempt->tuple[0].dim_size(0) - - attempt->elements_requested - 1; - i >= 0; --i) { - for (int j = 0; j < num_components(); ++j) { - PersistentTensor element; - Status s = GetElementComponentFromBatch( - attempt->tuple, i, j, attempt->context, &element); - if (!s.ok()) { - attempt->context->SetStatus( - errors::DataLoss("Failed to restore element from " - "partially-dequeued batch " - "to FIFOQueue: ", - s.error_message())); - } - queues_[j].push_front(element); - } - } - } - if (allow_small_batch && !queues_[0].empty()) { - // Request all remaining elements in the queue. - queue_size = queues_[0].size(); - attempt->tuple.clear(); - attempt->elements_requested = queue_size; - } else { - if (allow_small_batch) { - // There may be some other attempts containing - // values. If so, we'll yield and wait for them - // to add elements to the queue. - if (!enqueue_attempts_.empty()) return kProgress; - } - if (attempt->context->status().ok()) { - attempt->context->SetStatus(errors::OutOfRange( - "FIFOQueue '", name_, "' is closed and has ", - "insufficient elements (requested ", - attempt->elements_requested, ", current size ", - queue_size, ")")); + if (closed_ && queue_size < attempt->elements_requested) { + // If we don't have enough for a full dequeue, we have + // to reset the attempt tuple. + if (!attempt->tuple.empty()) { + // Restore already-dequeued elements to the front of the + // queue. + for (int64 i = attempt->tuple[0].dim_size(0) - + attempt->elements_requested - 1; + i >= 0; --i) { + for (int j = 0; j < num_components(); ++j) { + PersistentTensor element; + Status s = GetElementComponentFromBatch( + attempt->tuple, i, j, attempt->context, &element); + if (!s.ok()) { + attempt->context->SetStatus( + errors::DataLoss("Failed to restore element from " + "partially-dequeued batch " + "to FIFOQueue: ", + s.error_message())); } - return kComplete; + queues_[j].push_front(element); } } + } + if (allow_small_batch && !queues_[0].empty()) { + // Request all remaining elements in the queue. + queue_size = queues_[0].size(); + attempt->tuple.clear(); + attempt->elements_requested = queue_size; + } else { + if (allow_small_batch) { + // There may be some other attempts containing + // values. If so, we'll yield and wait for them + // to add elements to the queue. + if (!enqueue_attempts_.empty()) return kProgress; + } + if (attempt->context->status().ok()) { + attempt->context->SetStatus(errors::OutOfRange( + "FIFOQueue '", name_, "' is closed and has ", + "insufficient elements (requested ", + attempt->elements_requested, ", current size ", + queue_size, ")")); + } + return kComplete; + } + } - RunResult result = kNoProgress; - for (; queue_size > 0; --queue_size) { - if (attempt->tuple.empty()) { - // Only allocate tuple when we have something to dequeue - // so we don't use excessive memory when there are many - // blocked dequeue attempts waiting. - attempt->tuple.reserve(num_components()); - for (int i = 0; i < num_components(); ++i) { - const TensorShape shape = - ManyOutShape(i, attempt->elements_requested); - Tensor element; - attempt->context->SetStatus( - attempt->context->allocate_temp(component_dtypes_[i], - shape, &element)); - if (!attempt->context->status().ok()) return kComplete; - attempt->tuple.emplace_back(element); - } - } - result = kProgress; - Tuple tuple; - DequeueLocked(attempt->context, &tuple); - const int64 index = attempt->tuple[0].dim_size(0) - - attempt->elements_requested; - for (int i = 0; i < num_components(); ++i) { - attempt->context->SetStatus(batch_util::CopyElementToSlice( - std::move(tuple[i]), &attempt->tuple[i], index)); - if (!attempt->context->status().ok()) return kComplete; - } - tuple.clear(); - --attempt->elements_requested; - if (attempt->elements_requested == 0) { - tuple = attempt->tuple; - attempt->done_callback = [callback, tuple]() { - callback(tuple); - }; - return kComplete; - } + RunResult result = kNoProgress; + for (; queue_size > 0; --queue_size) { + if (attempt->tuple.empty()) { + // Only allocate tuple when we have something to dequeue + // so we don't use excessive memory when there are many + // blocked dequeue attempts waiting. + attempt->tuple.reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + const TensorShape shape = + ManyOutShape(i, attempt->elements_requested); + Tensor element; + attempt->context->SetStatus(attempt->context->allocate_temp( + component_dtypes_[i], shape, &element)); + if (!attempt->context->status().ok()) return kComplete; + attempt->tuple.emplace_back(element); } - return result; - }); + } + result = kProgress; + Tuple tuple; + DequeueLocked(attempt->context, &tuple); + const int64 index = + attempt->tuple[0].dim_size(0) - attempt->elements_requested; + for (int i = 0; i < num_components(); ++i) { + attempt->context->SetStatus(batch_util::CopyElementToSlice( + std::move(tuple[i]), &attempt->tuple[i], index)); + if (!attempt->context->status().ok()) return kComplete; + } + tuple.clear(); + --attempt->elements_requested; + if (attempt->elements_requested == 0) { + tuple = attempt->tuple; + attempt->done_callback = [callback, tuple]() { + callback(tuple); + }; + return kComplete; + } + } + return result; + }); } } if (!already_cancelled) { diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc index bde39770dee0a3e66746bb47562f799ab8bb1224..7090417dfdb2d7e433025b1a0f1cdeb5eece10a8 100644 --- a/tensorflow/core/kernels/fill_functor.cc +++ b/tensorflow/core/kernels/fill_functor.cc @@ -18,8 +18,8 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant_encode_decode.h" @@ -60,7 +60,7 @@ DEFINE_SETZERO_CPU(Variant); template void SetZeroFunctor::operator()( const Eigen::SyclDevice& d, typename TTypes::Flat out) { - To32Bit(out).device(d) = To32Bit(out).constant(T(0)); + To32Bit(out).device(d) = To32Bit(out).constant(T(0)); } #define DEFINE_SETZERO_SYCL(T) \ @@ -118,7 +118,8 @@ DEFINE_SETONE_SYCL(double); template struct FillFunctor { - void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes::Flat out, + void operator()(const Eigen::ThreadPoolDevice& d, + typename TTypes::Flat out, typename TTypes::ConstScalar in) { out.device(d) = out.constant(in()); } @@ -150,8 +151,7 @@ struct FillFunctor { } }; -#define DEFINE_FILL_SYCL(T) \ - template struct FillFunctor; +#define DEFINE_FILL_SYCL(T) template struct FillFunctor; DEFINE_FILL_SYCL(float); DEFINE_FILL_SYCL(double); TF_CALL_INTEGRAL_TYPES(DEFINE_FILL_SYCL) diff --git a/tensorflow/core/kernels/fractional_avg_pool_op.cc b/tensorflow/core/kernels/fractional_avg_pool_op.cc index 47f4189c30f10644ca7b040677ebadf439a9dc75..135d0023458b1ef393ab0bc296dc07310347e7ff 100644 --- a/tensorflow/core/kernels/fractional_avg_pool_op.cc +++ b/tensorflow/core/kernels/fractional_avg_pool_op.cc @@ -232,8 +232,9 @@ class FractionalAvgPoolGradOp : public OpKernel { // Grab the inputs. const Tensor& orig_input_tensor_shape = context->input(0); - OP_REQUIRES(context, orig_input_tensor_shape.dims() == 1 && - orig_input_tensor_shape.NumElements() == 4, + OP_REQUIRES(context, + orig_input_tensor_shape.dims() == 1 && + orig_input_tensor_shape.NumElements() == 4, errors::InvalidArgument("original input tensor shape must be" "1-dimensional and 4 elements")); const Tensor& out_backprop = context->input(1); diff --git a/tensorflow/core/kernels/fractional_pool_common.h b/tensorflow/core/kernels/fractional_pool_common.h index df0bbbfa066bca4705ff371d1823f789a1c4e9ef..2d7a230fc00613d91d147d4927403ba270a4d562 100644 --- a/tensorflow/core/kernels/fractional_pool_common.h +++ b/tensorflow/core/kernels/fractional_pool_common.h @@ -57,7 +57,7 @@ static inline void RandomShuffle(Iter first, Iter last, const Random& uniform) { // * sum(generated_diff_pooling_sequence) = input_length // * Let's define floor(input_length / output_length) = K, then // K <= generated_diff_pooling_sequence[i] <= K+1 -// For example, when input_length = 10, output_length = 6, the followings are +// For example, when input_length = 10, output_length = 6, the following are // valid pooling sequence: // * [1, 2, 2, 1, 2, 2] // * [1, 1, 2, 2, 2, 2] diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index ef9e8484132d25e517367862364518ca0baf38af..9d4bc35ba890c251b0800f266e7845e411e7a835 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -253,22 +253,21 @@ class SymbolicGradientOp : public AsyncOpKernel { args.push_back(ctx->input(i)); } std::vector* rets = new std::vector; - lib->Run( - opts, handle, args, rets, [ctx, done, rets](const Status& status) { - if (!status.ok()) { - ctx->SetStatus(status); - } else if (rets->size() != ctx->num_outputs()) { - ctx->SetStatus(errors::InvalidArgument( - "SymGrad expects to return ", ctx->num_outputs(), - " tensor(s), but get ", rets->size(), " tensor(s) instead.")); - } else { - for (size_t i = 0; i < rets->size(); ++i) { - ctx->set_output(i, (*rets)[i]); - } - } - delete rets; - done(); - }); + lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) { + if (!status.ok()) { + ctx->SetStatus(status); + } else if (rets->size() != ctx->num_outputs()) { + ctx->SetStatus(errors::InvalidArgument( + "SymGrad expects to return ", ctx->num_outputs(), + " tensor(s), but get ", rets->size(), " tensor(s) instead.")); + } else { + for (size_t i = 0; i < rets->size(); ++i) { + ctx->set_output(i, (*rets)[i]); + } + } + delete rets; + done(); + }); } private: diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..b687088db16a31d8ecb74a7a483c35d2c65a74f9 --- /dev/null +++ b/tensorflow/core/kernels/functional_ops.cc @@ -0,0 +1,322 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef FunctionLibraryRuntime::Handle FHandle; +typedef std::vector TensorVec; + +namespace { + +// Helper to instantiate function "func" in the library "lib". +Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func, + FunctionLibraryRuntime::Handle* handle) { + return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle); +} + +// If "t" is a scalar of a supported type, returns t != 0 in "*v". +Status ToBool(gtl::ArraySlice t, bool* v) { + if (t.size() != 1) { + return errors::InvalidArgument( + "Expected a single scalar which can be converted to a boolean, got ", + t.size(), " tensors."); + } + if (TensorShapeUtils::IsScalar(t[0].shape())) { + switch (t[0].dtype()) { +#define CASE(T) \ + case DataTypeToEnum::value: \ + *v = t[0].scalar()() != 0; \ + break; + + CASE(float); + CASE(double); + CASE(int32); + CASE(uint8); + CASE(int16); + CASE(int8); + CASE(int64); +#undef CASE + case DT_BOOL: + *v = t[0].scalar()(); + break; + case DT_STRING: + *v = !t[0].scalar()().empty(); + break; + default: + return errors::InvalidArgument(DataTypeString(t[0].dtype()), + " cannot be converted to a boolean"); + } + } else { + *v = t[0].NumElements() > 0; + } + return Status::OK(); +} + +// Sets "rets" to be the output of "ctx". Validates rets' types based +// on "kernel". +Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx, + gtl::ArraySlice rets) { + if (rets.size() != ctx->num_outputs()) { + return errors::Internal("Expect to produce ", ctx->num_outputs(), + " tensors, but only get ", rets.size()); + } + for (int i = 0; i < rets.size(); ++i) { + if (rets[i].dtype() != kernel->output_type(i)) { + return errors::Internal("Expect ", i, "-th output is of type ", + DataTypeString(kernel->output_type(i)), + " but get ", DataTypeString(rets[i].dtype())); + } + ctx->set_output(i, rets[i]); + } + return Status::OK(); +} + +void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, + bool always_collect_stats) { + opts->step_id = ctx->step_id(); + opts->rendezvous = ctx->rendezvous(); + opts->cancellation_manager = ctx->cancellation_manager(); + if (always_collect_stats) { + opts->stats_collector = ctx->stats_collector(); + } + opts->runner = ctx->runner(); +} + +} // end namespace + +class FunctionalIf : public AsyncOpKernel { + public: + explicit FunctionalIf(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + auto lib = ctx->function_library(); + OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library")); + const NameAttrList* func; + OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &func)); + OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &then_handle_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &func)); + OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_)); + } + + ~FunctionalIf() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + bool cond; + OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond)); + (new State(this, ctx, cond, done))->Start(); + } + + private: + FHandle then_handle_; + FHandle else_handle_; + + class State { + public: + State(FunctionalIf* kernel, OpKernelContext* ctx, bool cond, + DoneCallback done) + : kernel_(kernel), + ctx_(ctx), + cond_(cond), + done_(done), + lib_(CHECK_NOTNULL(ctx_->function_library())) { + SetRunOptions(ctx_, &opts_, true /* always_collect_stats */); + for (int i = 1; i < ctx_->num_inputs(); ++i) { + args_.push_back(ctx_->input(i)); + } + } + + ~State() {} + + void Start() { + FHandle handle = cond_ ? kernel_->then_handle_ : kernel_->else_handle_; + rets_.clear(); + lib_->Run( + // Evaluate one of the branch. + opts_, handle, args_, &rets_, + // Done callback + [this](Status s) { + if (s.ok()) { + s = SetOutputs(kernel_, ctx_, rets_); + } + ctx_->SetStatus(s); + auto done = done_; + delete this; + done(); + }); + } + + private: + FunctionalIf* const kernel_; + OpKernelContext* const ctx_; + const bool cond_; + const DoneCallback done_; + FunctionLibraryRuntime* const lib_; + FunctionLibraryRuntime::Options opts_; + TensorVec args_; + TensorVec rets_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), FunctionalIf); +REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"), + FunctionalIf); + +class FunctionalWhile : public AsyncOpKernel { + public: + explicit FunctionalWhile(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_)); + } + + ~FunctionalWhile() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + auto lib = ctx->function_library(); + OP_REQUIRES_ASYNC(ctx, lib != nullptr, + errors::Internal("No function library"), done); + + // TODO(b/37549631): Because this op has `SetIsStateful()` in its + // op registration, this kernel may be shared by multiple + // subgraphs, which have different associated + // `FunctionLibraryRuntime` objects and hence different `FHandle` + // namespaces. We currently work around this by caching the map + // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two + // functions this op uses. + FHandle cond_handle; + FHandle body_handle; + { + mutex_lock l(mu_); + const auto iter = handles_.find(lib); + if (iter == handles_.end()) { + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), + done); + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), + done); + handles_[lib] = {cond_handle, body_handle}; + } else { + cond_handle = iter->second.first; + body_handle = iter->second.second; + } + } + + (new State(this, ctx, cond_handle, body_handle, done))->Start(); + } + + private: + NameAttrList cond_func_; + NameAttrList body_func_; + + mutex mu_; + std::unordered_map> + handles_ GUARDED_BY(mu_); + + class State { + public: + State(FunctionalWhile* kernel, OpKernelContext* ctx, FHandle cond_handle, + FHandle body_handle, DoneCallback done) + : kernel_(kernel), + ctx_(ctx), + cond_handle_(cond_handle), + body_handle_(body_handle), + done_(done), + lib_(CHECK_NOTNULL(ctx_->function_library())) { + SetRunOptions(ctx_, &opts_, false /* always_collect_stats */); + for (int i = 0; i < ctx_->num_inputs(); ++i) { + args_.push_back(ctx_->input(i)); + } + } + + ~State() {} + + void Start() { EvalCond(); } + + private: + FunctionalWhile* const kernel_; + OpKernelContext* const ctx_; + const FHandle cond_handle_; + const FHandle body_handle_; + const DoneCallback done_; + FunctionLibraryRuntime* const lib_; + FunctionLibraryRuntime::Options opts_; + TensorVec args_; + TensorVec rets_; + + void EvalCond() { + lib_->Run( + // Evaluate the condition. + opts_, cond_handle_, args_, &rets_, + // Done cb. + [this](const Status& s) { + if (!s.ok()) { + return Finish(s); + } + StartBody(); + }); + } + + void StartBody() { + bool cond; + Status s = ToBool(rets_, &cond); + if (!s.ok()) { + return Finish(s); + } + if (!cond) { + return Finish(Status::OK()); + } + rets_.clear(); + lib_->Run( + // Evaluate the body. + opts_, body_handle_, args_, &rets_, + // Done callback + [this](const Status& s) { + if (!s.ok()) { + return Finish(s); + } + if (args_.size() != rets_.size()) { + return Finish(errors::InvalidArgument( + "While loop body returned ", rets_.size(), + " arguments. Expected: ", args_.size())); + } + args_.clear(); + using std::swap; + swap(args_, rets_); + EvalCond(); + }); + } + + void Finish(Status s) { + if (s.ok()) { + s = SetOutputs(kernel_, ctx_, args_); + } + ctx_->SetStatus(s); + done_(); + delete this; + } + }; +}; +REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), FunctionalWhile); +REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), FunctionalWhile); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cu.cc b/tensorflow/core/kernels/fused_batch_norm_op.cu.cc index a8484390b928105cb51216e18e419957f12ad2ac..4a67b2b3a30463448ac97aff96402f6500eeb19a 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cu.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cu.cc @@ -68,7 +68,8 @@ void InvVarianceToVariance::operator()(const Eigen::GpuDevice& d, template void SetNanFunctor::operator()(const Eigen::GpuDevice& d, typename TTypes::Flat out) { - To32Bit(out).device(d) = To32Bit(out).constant(Eigen::NumTraits::quiet_NaN()); + To32Bit(out).device(d) = + To32Bit(out).constant(Eigen::NumTraits::quiet_NaN()); } template class VarianceToInvVariance; diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD index 41af950d7deca5a9ce1e2ca6496ccf40fd72dd87..9a7eca03ce276d26321f01f80ad7f1a0a254e4db 100644 --- a/tensorflow/core/kernels/fuzzing/BUILD +++ b/tensorflow/core/kernels/fuzzing/BUILD @@ -43,6 +43,8 @@ tf_ops_fuzz_target_lib("decode_base64") tf_ops_fuzz_target_lib("encode_jpeg") +tf_ops_fuzz_target_lib("decode_bmp") + tf_ops_fuzz_target_lib("decode_png") tf_ops_fuzz_target_lib("decode_jpeg") diff --git a/tensorflow/core/kernels/fuzzing/decode_base64_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_base64_fuzz.cc index 6d4a9dfdef4609a45d3a38e49a32492408043617..37edd1ce0f95d7f6d6a366f5b0d83bac7f6159d5 100644 --- a/tensorflow/core/kernels/fuzzing/decode_base64_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/decode_base64_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/decode_bmp_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_bmp_fuzz.cc new file mode 100644 index 0000000000000000000000000000000000000000..01c56ac6f67223108768c3e6922d7f193f93d52a --- /dev/null +++ b/tensorflow/core/kernels/fuzzing/decode_bmp_fuzz.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" + +namespace tensorflow { +namespace fuzzing { + +class FuzzDecodeBmp : public FuzzStringInputOp { + SINGLE_INPUT_OP_BUILDER(DT_STRING, DecodeBmp); +}; + +STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeBmp); + +} // end namespace fuzzing +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc index b084a972049cc2b1997df64a2f43a6d79b6b4e6d..f3b24b2341e590adfbeac1a18b6a65fbfd34f598 100644 --- a/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/decode_json_example_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_json_example_fuzz.cc index 9dd795b94e82c48ad037df67f3218ed62feb722e..e9ffad178616a7b0872d461653cb01c40b292d88 100644 --- a/tensorflow/core/kernels/fuzzing/decode_json_example_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/decode_json_example_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/decode_png_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_png_fuzz.cc index 4a68a5b5803f363ab93bf280df54fa8f14206a84..020f18b1895c480748cafbfb8f7f267887db1fba 100644 --- a/tensorflow/core/kernels/fuzzing/decode_png_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/decode_png_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc b/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc index 2d6c82826cf9dad1ca67d6e5ee1d13a059f9c8ea..a8f07f4bad3a7e7ccff4ebefd4c56c695d0b2573 100644 --- a/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/encode_jpeg_fuzz.cc b/tensorflow/core/kernels/fuzzing/encode_jpeg_fuzz.cc index 81b6e491248fda37f602c0365c1e90d4b08f7c2a..f5dd47a052cd098937d66394ed04c66831ee5972 100644 --- a/tensorflow/core/kernels/fuzzing/encode_jpeg_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/encode_jpeg_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc b/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc index d91a351c5969e71385348b76376202c14e86daac..4d736a21602b34b560ea1c8d9ede4645d806ca29 100644 --- a/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/example_proto_fast_parsing_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/fuzz_session.h b/tensorflow/core/kernels/fuzzing/fuzz_session.h index 0c0e548a909a0c87c622449c8ac6f66db29b5b8d..f1f3f199df137b83193c4d1e974dfb401d9ec9ff 100644 --- a/tensorflow/core/kernels/fuzzing/fuzz_session.h +++ b/tensorflow/core/kernels/fuzzing/fuzz_session.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef LEARNING_BRAIN_KERNELS_FUZZING_FUZZ_SESSION_H_ -#define LEARNING_BRAIN_KERNELS_FUZZING_FUZZ_SESSION_H_ +#ifndef TENSORFLOW_CORE_KERNELS_FUZZING_FUZZ_SESSION_H_ +#define TENSORFLOW_CORE_KERNELS_FUZZING_FUZZ_SESSION_H_ #include "tensorflow/cc/framework/scope.h" #include "tensorflow/core/graph/graph.h" @@ -153,4 +153,4 @@ class FuzzStringInputOp : public FuzzSession { } // end namespace fuzzing } // end namespace tensorflow -#endif // LEARNING_BRAIN_KERNELS_FUZZING_FUZZ_SESSION_H_ +#endif // TENSORFLOW_CORE_KERNELS_FUZZING_FUZZ_SESSION_H_ diff --git a/tensorflow/core/kernels/fuzzing/identity_fuzz.cc b/tensorflow/core/kernels/fuzzing/identity_fuzz.cc index ac3a12aa399a3efe532c71c49a092b6cecd6059b..5c3fc4a2795430d1f8f269f42131e882106db7b0 100644 --- a/tensorflow/core/kernels/fuzzing/identity_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/identity_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc index 978fcd102822a6a2690478eaca473eabc6ae83ab..c90ad2cfeb7222f4c75e718fcaea6955567f3a4a 100644 --- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc b/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc index 7d1aa1fbf3a149d25e82b454543a5add522145af..738d78e99a0081a2b9f0f59c94433372acec19e2 100644 --- a/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/fuzzing/string_to_number_fuzz.cc b/tensorflow/core/kernels/fuzzing/string_to_number_fuzz.cc index 94255d215e5292bf77ab1104eb1d36c0cc1d661c..e98363ffbf166782649f3fa12dc2ab70024908cf 100644 --- a/tensorflow/core/kernels/fuzzing/string_to_number_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/string_to_number_fuzz.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/kernels/gather_functor.cc b/tensorflow/core/kernels/gather_functor.cc index dde08b37eacb9edada92f98c5115f694015aad34..e6fefe643b72bd5a169f0c152ac2fee2568462aa 100644 --- a/tensorflow/core/kernels/gather_functor.cc +++ b/tensorflow/core/kernels/gather_functor.cc @@ -25,12 +25,12 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { // Forward declarations of the functor specializations for GPU. -#define DECLARE_GPU_SPECS_INDEX(T, Index) \ - template <> \ - int64 GatherFunctor::operator()( \ +#define DECLARE_GPU_SPECS_INDEX(T, Index) \ + template <> \ + int64 GatherFunctor::operator()( \ OpKernelContext* ctx, typename TTypes::ConstTensor Tparams, \ - typename TTypes::ConstFlat Tindices, \ - typename TTypes::Tensor Tout); \ + typename TTypes::ConstFlat Tindices, \ + typename TTypes::Tensor Tout); \ extern template struct GatherFunctor; #define DECLARE_GPU_SPECS(T) \ diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h index 1e429a037e8b16f5e01766125e1d10ec7567d78d..16ccb03b8502dd626c0dc4f0c10fcfe50224c7b8 100644 --- a/tensorflow/core/kernels/gather_functor.h +++ b/tensorflow/core/kernels/gather_functor.h @@ -18,12 +18,12 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/platform/prefetch.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { @@ -52,21 +52,23 @@ SliceIndex HandleCopies(OpKernelContext* ctx, const size_t slice_bytes = slice_elems * sizeof(T); auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); mutex mu; - // Store the value of invalidate index for printing error information, it's a shared variable. + // Store the value of invalidate index for printing error information, it's a + // shared variable. SliceIndex result = -1; - auto work = [&] (int64 start, int64 end) { + auto work = [&](int64 start, int64 end) { SliceIndex batch_idx = static_cast(start / indices_size); SliceIndex indices_idx = static_cast(start % indices_size); SliceIndex batch_idx_end = static_cast(end / indices_size); SliceIndex indices_idx_end = static_cast(end % indices_size); while ((batch_idx < batch_idx_end) || - (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) { + (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) { SliceIndex i_next = indices_idx + 1; SliceIndex b_next = batch_idx + 1; if ((batch_idx == batch_idx_end && i_next < indices_idx_end) || - (i_next < indices_size)) { - port::prefetch(¶ms(batch_idx, indices(i_next), 0)); + (i_next < indices_size)) { + port::prefetch( + ¶ms(batch_idx, indices(i_next), 0)); port::prefetch(&out(batch_idx, i_next, 0)); b_next = batch_idx; } else if (b_next <= batch_idx_end) { @@ -85,11 +87,12 @@ SliceIndex HandleCopies(OpKernelContext* ctx, // ahead-of-time compilation binary size). if (is_simple_type::value) { // Avoid auto-promotion to Index from SliceIndex by casting. - memcpy(out_base + (batch_idx * indices_size + indices_idx) * slice_elems, - params_base + (batch_idx * static_cast(limit) + - static_cast(index)) * - slice_elems, - slice_bytes); + memcpy( + out_base + (batch_idx * indices_size + indices_idx) * slice_elems, + params_base + (batch_idx * static_cast(limit) + + static_cast(index)) * + slice_elems, + slice_bytes); } else { // For non-"simple" types (e.g. strings). out.template chip<1>(indices_idx) = params.template chip<1>(index); @@ -99,8 +102,8 @@ SliceIndex HandleCopies(OpKernelContext* ctx, } }; - Shard(worker_threads->num_threads, worker_threads->workers, batch_size*indices_size, - slice_elems * sizeof(T), work); + Shard(worker_threads->num_threads, worker_threads->workers, + batch_size * indices_size, slice_elems * sizeof(T), work); return result; } @@ -117,16 +120,16 @@ struct GatherFunctorCPU { bool use_large = (slice_size > std::numeric_limits::max() || params.size() > std::numeric_limits::max() || N > std::numeric_limits::max()); -#define CALL(elems) \ - do { \ - if (use_large) { \ - bad_i = HandleCopies(ctx, params, indices, \ - slice_size, out); \ - } else { \ - const int32 small_slice = static_cast(slice_size); \ - bad_i = HandleCopies(ctx, params, indices, \ - small_slice, out); \ - } \ +#define CALL(elems) \ + do { \ + if (use_large) { \ + bad_i = HandleCopies(ctx, params, indices, \ + slice_size, out); \ + } else { \ + const int32 small_slice = static_cast(slice_size); \ + bad_i = HandleCopies(ctx, params, indices, \ + small_slice, out); \ + } \ } while (0) if (slice_size == 10) @@ -143,7 +146,8 @@ struct GatherFunctorCPU { template struct GatherFunctor { - int64 operator()(OpKernelContext* ctx, typename TTypes::ConstTensor params, + int64 operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, typename TTypes::ConstFlat indices, typename TTypes::Tensor out); }; diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index 239d5d2e990a88bbc8ca5949a07a2aa2a75de2ba..08adf4badbcd9c8baf664b13098f23dfb0584e24 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/mem.h" @@ -106,8 +108,7 @@ class GatherOp : public OpKernel { auto out_flat = out->shaped({outer_size, N, inner_size}); functor::GatherFunctor functor; - int64 bad_i = functor(c, params_flat, - indices_flat, out_flat); + int64 bad_i = functor(c, params_flat, indices_flat, out_flat); OP_REQUIRES( c, bad_i < 0, diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc index f0d7c670a62bf0a520cb37f01beda530d157d5c7..4040bf52bffe638d601f954f9a81d9eda78346a6 100644 --- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc +++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc @@ -46,7 +46,7 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data, GetTopNFloatResults(data, labels, element_count); LOG(INFO) << "=== Dump ranking ==="; for (int i = 0; i < top_n; ++i) { - const std::tuple &entry = queue.top(); + const std::tuple& entry = queue.top(); LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry) << ", " << std::get<0>(entry); queue.pop(); diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h index a360d188cc2246b87af348db9958152418742822..0d43d028cdbea02b820d8ac0c48378524e875e78 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.h +++ b/tensorflow/core/kernels/hexagon/graph_transferer.h @@ -181,8 +181,8 @@ class GraphTransferer { void AppendNodeInputParams(const int id, const Node& node, const std::vector& extra_inputs); - void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, - const int id, const Node& node); + void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const int id, + const Node& node); static std::array BuildShapeArray( const shape_inference::ShapeHandle& shape_handle, diff --git a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc index 536d295506c9669b0434059e26094cb70a4f1e87..20b09f144bab5482f2cf1bfa86cf22f0b7ff815e 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc @@ -42,8 +42,7 @@ constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f; class GraphTransfererTest : public ::testing::Test { protected: - void SetUp() final { - } + void SetUp() final {} GraphTransferer gt_; }; @@ -61,7 +60,7 @@ class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions { } } return -1; -} + } private: const std::vector op_types_{"INPUT", "OUTPUT", "Conv2D", diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc index 71bc4187b74cd6501d203aa3779c6d01e01f0d38..3f794dfb1a04cfdd6f7c114e0b2c7c0aac319a61 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc @@ -420,7 +420,7 @@ TEST(GraphTransferer, false, // is_text_proto false, // shape_inference_for_unknown_shape true // dry_run_for_unknown_shape - ); + ); ASSERT_TRUE(status.ok()) << status; prof.Stop(); prof.DumpStatistics("LoadGraphFromProtoFile"); @@ -487,7 +487,7 @@ TEST(GraphTransferer, false, // is_text_proto true, // shape_inference_for_unknown_shape false // dry_run_for_unknown_shape - ); + ); ASSERT_TRUE(status.ok()) << status; prof.Stop(); prof.DumpStatistics("LoadGraphFromProtoFile"); @@ -556,7 +556,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) { false, // is_text_proto false, // shape_inference_for_unknown_shape true // dry_run_for_unknown_shape - ); + ); const GraphTransferInfo& gfi0 = gt0.GetGraphTransferInfo(); ASSERT_TRUE(status.ok()); @@ -576,7 +576,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) { false, // is_text_proto true, // shape_inference_for_unknown_shape false // dry_run_for_unknown_shape - ); + ); const GraphTransferInfo& gfi1 = gt1.GetGraphTransferInfo(); ASSERT_TRUE(status.ok()); diff --git a/tensorflow/core/kernels/hinge-loss.h b/tensorflow/core/kernels/hinge-loss.h index 789a7ce7a3d8ec9e5d918dd75fce8d644a3b5682..d303e9c877e7b7be05205003c26cf66ef8273416 100644 --- a/tensorflow/core/kernels/hinge-loss.h +++ b/tensorflow/core/kernels/hinge-loss.h @@ -50,9 +50,8 @@ class HingeLossUpdater : public DualLossUpdater { // valid value for new dual = 0 // c. new optimal value > 1.0. Then new optimal value should be set to 1.0. const double candidate_optimal_dual = - current_dual + - (label - wx) / - (num_loss_partitions * example_weight * weighted_example_norm); + current_dual + (label - wx) / (num_loss_partitions * example_weight * + weighted_example_norm); if (label * candidate_optimal_dual < 0) { return 0.0; } diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc index c2bb958be8b29c4a6df99cf5533748d7db73179c..a88e9b0ddcdda660cf34a88253ef7c8d1e28029c 100644 --- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc +++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc @@ -17,16 +17,16 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/histogram_op.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "external/cub_archive/cub/device/device_histogram.cuh" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/histogram_op.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/cuda_kernel_helper.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -104,8 +104,8 @@ struct HistogramFixedWidthFunctor { /* num_samples */ num_samples, /* stream */ stream); if (err != cudaSuccess) { - return errors::Internal("Could not launch HistogramRange: ", - cudaGetErrorString(err), "."); + return errors::Internal( + "Could not launch HistogramRange: ", cudaGetErrorString(err), "."); } return Status::OK(); diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc index 1db9263e5d396b4cdb0920db18e5189149128758..a18a72c66dc659ffd372c231524dbf038df6ac22 100644 --- a/tensorflow/core/kernels/identity_op.cc +++ b/tensorflow/core/kernels/identity_op.cc @@ -128,6 +128,7 @@ REGISTER_GPU_KERNEL(Variant); REGISTER_GPU_HOST_KERNEL(int32); REGISTER_GPU_HOST_KERNEL(bool); +REGISTER_GPU_HOST_KERNEL(string); #undef REGISTER_GPU_HOST_KERNEL diff --git a/tensorflow/core/kernels/image_resizer_state.h b/tensorflow/core/kernels/image_resizer_state.h index f088315ff538e821666aa95d9a4c4ed49f7c0b59..faf997be05cccc366bcab618c99c8d39ff25e18b 100644 --- a/tensorflow/core/kernels/image_resizer_state.h +++ b/tensorflow/core/kernels/image_resizer_state.h @@ -109,8 +109,9 @@ struct ImageResizerState { ValidateAndCalculateOutputSize(context, input); if (!context->status().ok()) return; OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({input.dim_size(0), out_height, - out_width, input.dim_size(3)}), + 0, + TensorShape({input.dim_size(0), out_height, + out_width, input.dim_size(3)}), &output)); } @@ -168,8 +169,9 @@ struct ImageResizerGradientState { CalculateResizeScale(original_width, resized_width, align_corners_); output = nullptr; OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({batch_size, original_height, - original_width, channels}), + 0, + TensorShape({batch_size, original_height, + original_width, channels}), &output)); } diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc index e2861ae090ccd48c0408b83a7bc7c0230bf2c1a5..c37055239c28e0ab243ea30b05b2c8af0905766c 100644 --- a/tensorflow/core/kernels/in_topk_op.cc +++ b/tensorflow/core/kernels/in_topk_op.cc @@ -17,11 +17,11 @@ limitations under the License. #define EIGEN_USE_THREADS +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/bounds_check.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -98,36 +98,36 @@ class InTopK : public OpKernel { int k_; }; -REGISTER_KERNEL_BUILDER( - Name("InTopK").Device(DEVICE_CPU) - .HostMemory("predictions") - .HostMemory("targets") - .HostMemory("precision") - .TypeConstraint("T"), - InTopK); -REGISTER_KERNEL_BUILDER( - Name("InTopK").Device(DEVICE_CPU) - .HostMemory("predictions") - .HostMemory("targets") - .HostMemory("precision") - .TypeConstraint("T"), - InTopK); - -REGISTER_KERNEL_BUILDER( - Name("InTopKV2").Device(DEVICE_CPU) - .HostMemory("predictions") - .HostMemory("targets") - .HostMemory("k") - .HostMemory("precision") - .TypeConstraint("T"), - InTopK); -REGISTER_KERNEL_BUILDER( - Name("InTopKV2").Device(DEVICE_CPU) - .HostMemory("predictions") - .HostMemory("targets") - .HostMemory("k") - .HostMemory("precision") - .TypeConstraint("T"), - InTopK); +REGISTER_KERNEL_BUILDER(Name("InTopK") + .Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("precision") + .TypeConstraint("T"), + InTopK); +REGISTER_KERNEL_BUILDER(Name("InTopK") + .Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("precision") + .TypeConstraint("T"), + InTopK); + +REGISTER_KERNEL_BUILDER(Name("InTopKV2") + .Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("k") + .HostMemory("precision") + .TypeConstraint("T"), + InTopK); +REGISTER_KERNEL_BUILDER(Name("InTopKV2") + .Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("k") + .HostMemory("precision") + .TypeConstraint("T"), + InTopK); } // namespace tensorflow diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index 7728ba850c94aa79feb31d137712692df0f89176..a71d047ed1a381bfc0311f86987f585f51b02536 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -27,13 +27,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SyclDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace functor { template -Status DoParallelConcatUpdate(const Device& d, const Tensor& value, - int32 loc, Tensor* output) { +Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32 loc, + Tensor* output) { auto Tvalue = value.shaped({1, value.NumElements()}); auto Toutput = output->flat_outer_dims(); auto nrows = Toutput.dimension(0); @@ -74,7 +74,7 @@ Status DoParallelConcat(const SyclDevice& d, const Tensor& value, int32 loc, return errors::InvalidArgument("Unsupported data type: ", value.dtype()); } } -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // end namespace functor @@ -207,7 +207,7 @@ REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") .HostMemory("output") .TypeConstraint("T"), ParallelConcatUpdate); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA diff --git a/tensorflow/core/kernels/l2loss_op.cc b/tensorflow/core/kernels/l2loss_op.cc index f8ed9351579ff8cbeeb5f45030e8ff278fa75101..f561287f7a142f4cbcf74225c3f2fde3986c169a 100644 --- a/tensorflow/core/kernels/l2loss_op.cc +++ b/tensorflow/core/kernels/l2loss_op.cc @@ -17,8 +17,8 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/kernels/l2loss_op.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc index 36907fb5716fcde3b0efc28cc4edca543432c8f4..b58bcf583480cb50ee7a6be13465e6c6d301295b 100644 --- a/tensorflow/core/kernels/linalg_ops_common.cc +++ b/tensorflow/core/kernels/linalg_ops_common.cc @@ -108,7 +108,6 @@ void LinearAlgebraOp::Compute(OpKernelContext* context) { auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); Shard(worker_threads.num_threads, worker_threads.workers, batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard); - } template diff --git a/tensorflow/core/kernels/lmdb_reader_op.cc b/tensorflow/core/kernels/lmdb_reader_op.cc index 31a427f2c90ad8a321d6004bf7ef85772d8e951f..2474fe4d564b37a7de36a85a6af3820e2bc4ac65 100755 --- a/tensorflow/core/kernels/lmdb_reader_op.cc +++ b/tensorflow/core/kernels/lmdb_reader_op.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/reader_op_kernel.h" #include "tensorflow/core/framework/reader_base.h" +#include "tensorflow/core/framework/reader_op_kernel.h" #include "tensorflow/core/lib/core/errors.h" #include @@ -26,9 +26,8 @@ namespace tensorflow { class LMDBReader : public ReaderBase { public: - LMDBReader(const string& node_name, Env* env) + LMDBReader(const string& node_name, Env* /*unused*/) : ReaderBase(strings::StrCat("LMDBReader '", node_name, "'")), - env_(env), mdb_env_(nullptr), mdb_dbi_(0), mdb_txn_(nullptr), @@ -77,15 +76,13 @@ class LMDBReader : public ReaderBase { *at_end = true; return Status::OK(); } - } - else { + } else { if (Seek(MDB_NEXT) == false) { *at_end = true; return Status::OK(); } } - *key = string(static_cast(mdb_key_.mv_data), - mdb_key_.mv_size); + *key = string(static_cast(mdb_key_.mv_data), mdb_key_.mv_size); *value = string(static_cast(mdb_value_.mv_data), mdb_value_.mv_size); *produced = true; @@ -109,7 +106,6 @@ class LMDBReader : public ReaderBase { } } - Env* const env_; MDB_env* mdb_env_; MDB_dbi mdb_dbi_; @@ -123,13 +119,10 @@ class LMDBReaderOp : public ReaderOpKernel { explicit LMDBReaderOp(OpKernelConstruction* context) : ReaderOpKernel(context) { Env* env = context->env(); - SetReaderFactory([this, env]() { - return new LMDBReader(name(), env); - }); + SetReaderFactory([this, env]() { return new LMDBReader(name(), env); }); } }; -REGISTER_KERNEL_BUILDER(Name("LMDBReader").Device(DEVICE_CPU), - LMDBReaderOp); +REGISTER_KERNEL_BUILDER(Name("LMDBReader").Device(DEVICE_CPU), LMDBReaderOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/logistic-loss.h b/tensorflow/core/kernels/logistic-loss.h index 2765f42bbdc2d3bf3b9ec42f9f225166218fa9d0..6479e6f5dc3795451babd5675f1decc05b670251 100644 --- a/tensorflow/core/kernels/logistic-loss.h +++ b/tensorflow/core/kernels/logistic-loss.h @@ -122,10 +122,9 @@ class LogisticLossUpdater : public DualLossUpdater { num_loss_partitions * weighted_example_norm * example_weight * (0.5 * (1 + tanhx) / label - current_dual); - const double denominator = -2 * label - - num_loss_partitions * weighted_example_norm * - example_weight * (1 - tanhx * tanhx) * 0.5 / - label; + const double denominator = + -2 * label - num_loss_partitions * weighted_example_norm * + example_weight * (1 - tanhx * tanhx) * 0.5 / label; return x - numerator / denominator; } }; diff --git a/tensorflow/core/kernels/loss_test.cc b/tensorflow/core/kernels/loss_test.cc index 89f0677e1f5a7a0301c2d85700ee9954869c50bb..460d65c5c270c43aae4cb8b26b5258c7d4dd9a5f 100644 --- a/tensorflow/core/kernels/loss_test.cc +++ b/tensorflow/core/kernels/loss_test.cc @@ -32,14 +32,17 @@ namespace { TEST(LogisticLoss, ComputePrimalLoss) { LogisticLossUpdater loss_updater; - EXPECT_NEAR(0.693147, loss_updater.ComputePrimalLoss( - 0 /* wx */, 1 /* label */, 1 /* example weight */), + EXPECT_NEAR(0.693147, + loss_updater.ComputePrimalLoss(0 /* wx */, 1 /* label */, + 1 /* example weight */), 1e-3); - EXPECT_NEAR(0.0, loss_updater.ComputePrimalLoss(70 /* wx */, 1 /* label */, - 1 /* example weight */), + EXPECT_NEAR(0.0, + loss_updater.ComputePrimalLoss(70 /* wx */, 1 /* label */, + 1 /* example weight */), 1e-3); - EXPECT_NEAR(0.0, loss_updater.ComputePrimalLoss(-70 /* wx */, -1 /* label */, - 1 /* example weight */), + EXPECT_NEAR(0.0, + loss_updater.ComputePrimalLoss(-70 /* wx */, -1 /* label */, + 1 /* example weight */), 1e-3); } @@ -53,31 +56,35 @@ TEST(LogisticLoss, ComputeDualLoss) { loss_updater.ComputeDualLoss(1 /* current dual */, 1 /* label */, 1 /* example weight */), 1e-3); - EXPECT_NEAR(-0.693147, loss_updater.ComputeDualLoss(0.5 /* current dual */, - 1 /* label */, - 1 /* example weight */), - 1e-3); + EXPECT_NEAR( + -0.693147, + loss_updater.ComputeDualLoss(0.5 /* current dual */, 1 /* label */, + 1 /* example weight */), + 1e-3); } TEST(LogisticLoss, ComputeUpdatedDual) { LogisticLossUpdater loss_updater; - EXPECT_NEAR(0.479, loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, 1.0 /* label */, - 1.0 /* example weight */, 0.5 /* current_dual */, - 0.3 /* wx */, 10.0 /* weighted_example_norm */), + EXPECT_NEAR(0.479, + loss_updater.ComputeUpdatedDual( + 1 /* num partitions */, 1.0 /* label */, + 1.0 /* example weight */, 0.5 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */), 1e-3); - EXPECT_NEAR(-0.031, loss_updater.ComputeUpdatedDual( - 2 /* num partitions */, -1.0 /* label */, - 1.0 /* example weight */, 0.1 /* current_dual */, - -0.8 /* wx */, 10.0 /* weighted_example_norm */), + EXPECT_NEAR(-0.031, + loss_updater.ComputeUpdatedDual( + 2 /* num partitions */, -1.0 /* label */, + 1.0 /* example weight */, 0.1 /* current_dual */, + -0.8 /* wx */, 10.0 /* weighted_example_norm */), 1e-3); } TEST(SquaredLoss, ComputePrimalLoss) { SquaredLossUpdater loss_updater; - EXPECT_NEAR(0.5, loss_updater.ComputePrimalLoss(0.0 /* wx */, 1.0 /* label */, - 1.0 /* example weight */), + EXPECT_NEAR(0.5, + loss_updater.ComputePrimalLoss(0.0 /* wx */, 1.0 /* label */, + 1.0 /* example weight */), 1e-3); EXPECT_NEAR(40.5, loss_updater.ComputePrimalLoss(10.0 /* wx */, 1.0 /* label */, @@ -95,43 +102,50 @@ TEST(SquaredLoss, ComputePrimalLoss) { TEST(SquaredLoss, ComputeDualLoss) { SquaredLossUpdater loss_updater; - EXPECT_NEAR(0.0, loss_updater.ComputeDualLoss(0.0 /* current dual */, - -1.0 /* label */, - 1.0 /* example weight */), - 1e-3); - EXPECT_NEAR(0.66, loss_updater.ComputeDualLoss(0.2 /* current dual */, - -1.0 /* label */, - 3.0 /* example weight */), - 1e-3); - EXPECT_NEAR(-0.375, loss_updater.ComputeDualLoss(1.5 /* current dual */, - 1.0 /* label */, - 1.0 /* example weight */), - 1e-3); - EXPECT_NEAR(-1.125, loss_updater.ComputeDualLoss(0.5 /* current dual */, - 1.0 /* label */, - 3.0 /* example weight */), - 1e-3); + EXPECT_NEAR( + 0.0, + loss_updater.ComputeDualLoss(0.0 /* current dual */, -1.0 /* label */, + 1.0 /* example weight */), + 1e-3); + EXPECT_NEAR( + 0.66, + loss_updater.ComputeDualLoss(0.2 /* current dual */, -1.0 /* label */, + 3.0 /* example weight */), + 1e-3); + EXPECT_NEAR( + -0.375, + loss_updater.ComputeDualLoss(1.5 /* current dual */, 1.0 /* label */, + 1.0 /* example weight */), + 1e-3); + EXPECT_NEAR( + -1.125, + loss_updater.ComputeDualLoss(0.5 /* current dual */, 1.0 /* label */, + 3.0 /* example weight */), + 1e-3); } TEST(SquaredLoss, ComputeUpdatedDual) { SquaredLossUpdater loss_updater; - EXPECT_NEAR(0.336, loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, 1.0 /* label */, - 1.0 /* example weight */, 0.3 /* current_dual */, - 0.3 /* wx */, 10.0 /* weighted_example_norm */), + EXPECT_NEAR(0.336, + loss_updater.ComputeUpdatedDual( + 1 /* num partitions */, 1.0 /* label */, + 1.0 /* example weight */, 0.3 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */), 1e-3); - EXPECT_NEAR(-0.427, loss_updater.ComputeUpdatedDual( - 5 /* num partitions */, -1.0 /* label */, - 1.0 /* example weight */, -0.4 /* current_dual */, - 0.8 /* wx */, 10.0 /* weighted_example_norm */), + EXPECT_NEAR(-0.427, + loss_updater.ComputeUpdatedDual( + 5 /* num partitions */, -1.0 /* label */, + 1.0 /* example weight */, -0.4 /* current_dual */, + 0.8 /* wx */, 10.0 /* weighted_example_norm */), 1e-3); } TEST(HingeLoss, ComputePrimalLoss) { HingeLossUpdater loss_updater; - EXPECT_NEAR(1.0, loss_updater.ComputePrimalLoss(0.0 /* wx */, 1.0 /* label */, - 1.0 /* example weight */), + EXPECT_NEAR(1.0, + loss_updater.ComputePrimalLoss(0.0 /* wx */, 1.0 /* label */, + 1.0 /* example weight */), 1e-3); EXPECT_NEAR(0.0, loss_updater.ComputePrimalLoss(10.0 /* wx */, 1.0 /* label */, @@ -149,10 +163,11 @@ TEST(HingeLoss, ComputePrimalLoss) { TEST(HingeLoss, ComputeDualLoss) { HingeLossUpdater loss_updater; - EXPECT_NEAR(0.0, loss_updater.ComputeDualLoss(0.0 /* current dual */, - -1.0 /* label */, - 1.0 /* example weight */), - 1e-3); + EXPECT_NEAR( + 0.0, + loss_updater.ComputeDualLoss(0.0 /* current dual */, -1.0 /* label */, + 1.0 /* example weight */), + 1e-3); EXPECT_NEAR( std::numeric_limits::max(), loss_updater.ComputeDualLoss(0.2 /* current dual */, -1.0 /* label */, @@ -163,10 +178,11 @@ TEST(HingeLoss, ComputeDualLoss) { loss_updater.ComputeDualLoss(1.5 /* current dual */, 1.0 /* label */, 1.0 /* example weight */), 1e-3); - EXPECT_NEAR(-1.5, loss_updater.ComputeDualLoss(0.5 /* current dual */, - 1.0 /* label */, - 3.0 /* example weight */), - 1e-3); + EXPECT_NEAR( + -1.5, + loss_updater.ComputeDualLoss(0.5 /* current dual */, 1.0 /* label */, + 3.0 /* example weight */), + 1e-3); } TEST(HingeLoss, ConvertLabel) { @@ -195,28 +211,31 @@ TEST(HingeLoss, ComputeUpdatedDual) { // weighted_example_norm=100.0, it turns out that the optimal value to update // the dual to is 0.507 which is within the permitted range and thus should be // the value returned. - EXPECT_NEAR(0.507, loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, 1.0 /* label */, - 1.0 /* example weight */, 0.5 /* current_dual */, - 0.3 /* wx */, 100.0 /* weighted_example_norm */), + EXPECT_NEAR(0.507, + loss_updater.ComputeUpdatedDual( + 1 /* num partitions */, 1.0 /* label */, + 1.0 /* example weight */, 0.5 /* current_dual */, + 0.3 /* wx */, 100.0 /* weighted_example_norm */), 1e-3); // When label=-1.0, example_weight=1.0, current_dual=0.4, wx=0.6, // weighted_example_norm=10.0 and num_loss_partitions=10, it turns out that // the optimal value to update the dual to is 0.384 which is within the // permitted range and thus should be the value returned. - EXPECT_NEAR(-0.416, loss_updater.ComputeUpdatedDual( - 10 /* num partitions */, -1.0 /* label */, - 1.0 /* example weight */, -0.4 /* current_dual */, - 0.6 /* wx */, 10.0 /* weighted_example_norm */), + EXPECT_NEAR(-0.416, + loss_updater.ComputeUpdatedDual( + 10 /* num partitions */, -1.0 /* label */, + 1.0 /* example weight */, -0.4 /* current_dual */, + 0.6 /* wx */, 10.0 /* weighted_example_norm */), 1e-3); // When label=1.0, example_weight=1.0, current_dual=-0.5, wx=0.3 and // weighted_example_norm=10.0, it turns out that the optimal value to update // the dual to is -0.43. However, this is outside the allowed [0.0, 1.0] range // and hence the closest permitted value (0.0) should be returned instead. - EXPECT_NEAR(0.0, loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, 1.0 /* label */, - 1.0 /* example weight */, -0.5 /* current_dual */, - 0.3 /* wx */, 10.0 /* weighted_example_norm */), + EXPECT_NEAR(0.0, + loss_updater.ComputeUpdatedDual( + 1 /* num partitions */, 1.0 /* label */, + 1.0 /* example weight */, -0.5 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */), 1e-3); // When label=-1.0, example_weight=2.0, current_dual=-1.0, wx=0.3 and @@ -224,17 +243,19 @@ TEST(HingeLoss, ComputeUpdatedDual) { // the dual to is -1.065. However, this is outside the allowed [-1.0, 0.0] // range and hence the closest permitted value (-1.0) should be returned // instead. - EXPECT_NEAR(-1.0, loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, -1.0 /* label */, - 2.0 /* example weight */, -1.0 /* current_dual */, - 0.3 /* wx */, 10.0 /* weighted_example_norm */), + EXPECT_NEAR(-1.0, + loss_updater.ComputeUpdatedDual( + 1 /* num partitions */, -1.0 /* label */, + 2.0 /* example weight */, -1.0 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */), 1e-3); } TEST(SmoothHingeLoss, ComputePrimalLoss) { SmoothHingeLossUpdater loss_updater; - EXPECT_NEAR(0.5, loss_updater.ComputePrimalLoss(0.0 /* wx */, 1.0 /* label */, - 1.0 /* example weight */), + EXPECT_NEAR(0.5, + loss_updater.ComputePrimalLoss(0.0 /* wx */, 1.0 /* label */, + 1.0 /* example weight */), 1e-3); EXPECT_NEAR(0.0, loss_updater.ComputePrimalLoss(10.0 /* wx */, 1.0 /* label */, @@ -252,10 +273,11 @@ TEST(SmoothHingeLoss, ComputePrimalLoss) { TEST(SmoothHingeLoss, ComputeDualLoss) { SmoothHingeLossUpdater loss_updater; - EXPECT_NEAR(0.0, loss_updater.ComputeDualLoss(0.0 /* current dual */, - -1.0 /* label */, - 1.0 /* example weight */), - 1e-3); + EXPECT_NEAR( + 0.0, + loss_updater.ComputeDualLoss(0.0 /* current dual */, -1.0 /* label */, + 1.0 /* example weight */), + 1e-3); EXPECT_NEAR( std::numeric_limits::max(), loss_updater.ComputeDualLoss(0.2 /* current dual */, -1.0 /* label */, @@ -266,24 +288,27 @@ TEST(SmoothHingeLoss, ComputeDualLoss) { loss_updater.ComputeDualLoss(1.5 /* current dual */, 1.0 /* label */, 1.0 /* example weight */), 1e-3); - EXPECT_NEAR(-1.125, loss_updater.ComputeDualLoss(0.5 /* current dual */, - 1.0 /* label */, - 3.0 /* example weight */), - 1e-3); + EXPECT_NEAR( + -1.125, + loss_updater.ComputeDualLoss(0.5 /* current dual */, 1.0 /* label */, + 3.0 /* example weight */), + 1e-3); } TEST(SmoothHingeLoss, ComputeUpdatedDual) { SmoothHingeLossUpdater loss_updater; - EXPECT_NEAR(0.336, loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, 1.0 /* label */, - 1.0 /* example weight */, 0.3 /* current_dual */, - 0.3 /* wx */, 10.0 /* weighted_example_norm */), + EXPECT_NEAR(0.336, + loss_updater.ComputeUpdatedDual( + 1 /* num partitions */, 1.0 /* label */, + 1.0 /* example weight */, 0.3 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */), 1e-3); - EXPECT_NEAR(-0.427, loss_updater.ComputeUpdatedDual( - 5 /* num partitions */, -1.0 /* label */, - 1.0 /* example weight */, -0.4 /* current_dual */, - 0.8 /* wx */, 10.0 /* weighted_example_norm */), + EXPECT_NEAR(-0.427, + loss_updater.ComputeUpdatedDual( + 5 /* num partitions */, -1.0 /* label */, + 1.0 /* example weight */, -0.4 /* current_dual */, + 0.8 /* wx */, 10.0 /* weighted_example_norm */), 1e-3); } diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc index c905ebc84a6e9251a5e30be19b086d3fae215cad..c3a59c95762ad03f217768a9b14e31d6f501d789 100644 --- a/tensorflow/core/kernels/lrn_op.cc +++ b/tensorflow/core/kernels/lrn_op.cc @@ -229,10 +229,11 @@ class LRNOp : public OpKernel { explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) { int64 depth_radius64; OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64)); - OP_REQUIRES(context, FastBoundsCheck(depth_radius64, - std::numeric_limits::max()), - errors::InvalidArgument("depth_radius = ", depth_radius64, - " larger than int max")); + OP_REQUIRES( + context, + FastBoundsCheck(depth_radius64, std::numeric_limits::max()), + errors::InvalidArgument("depth_radius = ", depth_radius64, + " larger than int max")); depth_radius_ = static_cast(depth_radius64); float tmp; OP_REQUIRES_OK(context, context->GetAttr("bias", &tmp)); @@ -247,9 +248,10 @@ class LRNOp : public OpKernel { const Tensor& in = context->input(0); OP_REQUIRES(context, in.dims() == 4, errors::InvalidArgument("in must be 4-dimensional")); - OP_REQUIRES(context, FastBoundsCheck(in.NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument("argument to LRN too large")); + OP_REQUIRES( + context, + FastBoundsCheck(in.NumElements(), std::numeric_limits::max()), + errors::InvalidArgument("argument to LRN too large")); // Cast to platform-specific int to avoid conversion warnings. const int batch = static_cast(in.dim_size(0)); const int rows = static_cast(in.dim_size(1)); @@ -448,10 +450,11 @@ class LRNGradOp : public OpKernel { explicit LRNGradOp(OpKernelConstruction* context) : OpKernel(context) { int64 depth_radius64; OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64)); - OP_REQUIRES(context, FastBoundsCheck(depth_radius64, - std::numeric_limits::max()), - errors::InvalidArgument("depth_radius = ", depth_radius64, - " larger than int max")); + OP_REQUIRES( + context, + FastBoundsCheck(depth_radius64, std::numeric_limits::max()), + errors::InvalidArgument("depth_radius = ", depth_radius64, + " larger than int max")); depth_radius_ = static_cast(depth_radius64); float tmp; OP_REQUIRES_OK(context, context->GetAttr("bias", &tmp)); diff --git a/tensorflow/core/kernels/matching_files_op.cc b/tensorflow/core/kernels/matching_files_op.cc index 5eb060f6641d1565417dd074a95bf72e2a81e472..cdff7bad5fe222b6f0824a742caa0a4e5d939f71 100644 --- a/tensorflow/core/kernels/matching_files_op.cc +++ b/tensorflow/core/kernels/matching_files_op.cc @@ -45,15 +45,14 @@ class MatchingFilesOp : public OpKernel { int num_files = 0; std::vector> all_fnames(num_patterns); for (int i = 0; i < num_patterns; i++) { - OP_REQUIRES_OK( - context, - context->env()->GetMatchingPaths(patterns(i), &all_fnames[i])); + OP_REQUIRES_OK(context, context->env()->GetMatchingPaths(patterns(i), + &all_fnames[i])); num_files += all_fnames[i].size(); } Tensor* output_t = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output( - "filenames", TensorShape({num_files}), &output_t)); + OP_REQUIRES_OK( + context, context->allocate_output("filenames", TensorShape({num_files}), + &output_t)); auto output = output_t->vec(); int index = 0; for (int i = 0; i < num_patterns; ++i) { diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc index cb68690f2847709fe6ff38f3eecd974613856dcf..f499ce6519d097c7fea05e8175d08d102880f7fd 100644 --- a/tensorflow/core/kernels/matmul_op.cc +++ b/tensorflow/core/kernels/matmul_op.cc @@ -261,12 +261,12 @@ struct LaunchMatMul { std::vector* algorithms, bool use_autotune, Tensor* out) { using perftools::gputools::blas::AlgorithmConfig; using perftools::gputools::blas::ComputationType; - using perftools::gputools::blas::ProfileResult; - using perftools::gputools::blas::Transpose; using perftools::gputools::blas::kDefaultAlgorithm; using perftools::gputools::blas::kDefaultBlasGemm; using perftools::gputools::blas::kDefaultBlasGemv; using perftools::gputools::blas::kNoAlgorithm; + using perftools::gputools::blas::ProfileResult; + using perftools::gputools::blas::Transpose; Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose}; const uint64 m = a.dim_size(1 - dim_pair[0].first); const uint64 k = a.dim_size(dim_pair[0].first); diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h index 6398da2fb959b0bded9afad8c92be923e44c755c..628895ca86f9c86c5bda987dcade9a4a7af753d8 100644 --- a/tensorflow/core/kernels/matmul_op.h +++ b/tensorflow/core/kernels/matmul_op.h @@ -30,7 +30,8 @@ struct MatMulTypes { typedef Eigen::TensorMap, Eigen::Aligned> out_type; typedef Eigen::TensorMap, - Eigen::Aligned> in_type; + Eigen::Aligned> + in_type; }; template ()(); + + auto as_int64_scalar = [](const Tensor& tensor) -> int64 { + if (tensor.dtype() == DT_INT32) { + return tensor.scalar()(); + } else { + return tensor.scalar()(); + } + }; + const int64 num_lower = as_int64_scalar(num_lower_in); OP_REQUIRES( context, num_lower <= input_reshaped.dimension(1), errors::InvalidArgument( @@ -73,7 +81,7 @@ class MatrixBandPartOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()), errors::InvalidArgument("num_upper must be scalar, got shape ", num_upper_in.shape().DebugString())); - const int64 num_upper = num_upper_in.scalar()(); + const int64 num_upper = as_int64_scalar(num_upper_in); OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2), errors::InvalidArgument("num_upper must be negative or less or " "equal to number of columns (", diff --git a/tensorflow/core/kernels/matrix_exponential_op.cc b/tensorflow/core/kernels/matrix_exponential_op.cc index 4cc3f32f7e4a727fa2d9ec3c21a3750111f46392..99db898301378f7ad55f75b3a403a09a5f59bb3b 100644 --- a/tensorflow/core/kernels/matrix_exponential_op.cc +++ b/tensorflow/core/kernels/matrix_exponential_op.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" - namespace tensorflow { template @@ -40,7 +39,8 @@ class MatrixExponentialOp : public LinearAlgebraOp { MatrixMaps* outputs) final { const ConstMatrixMap& input = inputs[0]; if (input.rows() == 0) return; - using Matrix = Eigen::Matrix; + using Matrix = + Eigen::Matrix; Matrix tmp = input; outputs->at(0) = tmp.exp(); } @@ -51,9 +51,9 @@ class MatrixExponentialOp : public LinearAlgebraOp { REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp), float); REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp), double); -REGISTER_LINALG_OP("MatrixExponential", - (MatrixExponentialOp), complex64); -REGISTER_LINALG_OP("MatrixExponential", - (MatrixExponentialOp), complex128); +REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp), + complex64); +REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp), + complex128); } // namespace tensorflow diff --git a/tensorflow/core/kernels/matrix_logarithm_op.cc b/tensorflow/core/kernels/matrix_logarithm_op.cc index cf0007b5b6776d0c8a297067f3a49ca21a132ac0..22ca094e2432723a49afab8a255339fc8ac2512e 100644 --- a/tensorflow/core/kernels/matrix_logarithm_op.cc +++ b/tensorflow/core/kernels/matrix_logarithm_op.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" - namespace tensorflow { template @@ -40,7 +39,8 @@ class MatrixLogarithmOp : public LinearAlgebraOp { MatrixMaps* outputs) final { const ConstMatrixMap& input = inputs[0]; if (input.rows() == 0) return; - using Matrix = Eigen::Matrix; + using Matrix = + Eigen::Matrix; Matrix tmp = input; outputs->at(0) = tmp.log(); } @@ -53,9 +53,9 @@ class MatrixLogarithmOp : public LinearAlgebraOp { // logarithm. If all eigenvalues are positive, then this returns the correct // logarithm, however checking for positive definiteness adds significant // overhead. Therefore at present we only register this Op for complex types. -REGISTER_LINALG_OP("MatrixLogarithm", - (MatrixLogarithmOp), complex64); -REGISTER_LINALG_OP("MatrixLogarithm", - (MatrixLogarithmOp), complex128); +REGISTER_LINALG_OP("MatrixLogarithm", (MatrixLogarithmOp), + complex64); +REGISTER_LINALG_OP("MatrixLogarithm", (MatrixLogarithmOp), + complex128); } // namespace tensorflow diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc index 9dd665392bc33e1559d46d0e7be2277e8c22a20a..502d593474e06cc495854706a1d4d90014ea8f96 100644 --- a/tensorflow/core/kernels/matrix_set_diag_op.cc +++ b/tensorflow/core/kernels/matrix_set_diag_op.cc @@ -69,8 +69,8 @@ class MatrixSetDiagOp : public OpKernel { errors::InvalidArgument( "must have diagonal.shape == input.shape[:-2] + " "min(input.shape[-2:]), but received input shape: ", - input_shape.DebugString(), " and diagonal shape: ", - diag_shape.DebugString())); + input_shape.DebugString(), + " and diagonal shape: ", diag_shape.DebugString())); if (input.NumElements() == 0) { // This is a no-op. diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index 2eefadad4949fd8d78f6a27533ce0385c38d9c69..9be7408012bb81e80c73c29a6ee9bb6763c04490 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/kernels/maxpooling_op.h" #include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #if GOOGLE_CUDA #include "tensorflow/core/kernels/maxpooling_op_gpu.h" @@ -89,7 +89,6 @@ static void SpatialMaxPoolWithArgMaxHelper( // max value. auto shard = [¶ms, &in_mat, &out_mat, &out_arg_max_mat, &input_backprop, &output_arg_max, &out_backprop](int64 start, int64 limit) { - const int32 depth = params.depth; const int32 in_rows = params.tensor_in_rows; const int32 in_cols = params.tensor_in_cols; @@ -180,7 +179,6 @@ static void SpatialMaxPoolWithArgMaxHelper( input_backprop_flat(input_backprop_index) += out_backprop_flat(index); } } - }; const int64 shard_cost = params.tensor_in_rows * params.tensor_in_cols * @@ -567,7 +565,7 @@ class MaxPoolingGradGradOp : public OpKernel { // tensor_out_as_matrix with the corresponding values in // top_diff_as_matrix. auto shard = [¶ms, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat]( - int64 start, int64 limit) { + int64 start, int64 limit) { const int32 depth = params.depth; const int32 in_rows = params.tensor_in_rows; const int32 in_cols = params.tensor_in_cols; diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc index f8daaca4c94aada5dbae5e5582f0da075b7222d5..0c7a236b2ff0f0b5c6287d1dffb1e8ef9bac7cc0 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc @@ -450,10 +450,10 @@ bool MaxPoolBackwardWithArgmax::operator()( T* bottom_diff, const Eigen::GpuDevice& d) { const int kThreadsPerBlock = 1024; SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock, - kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff); + kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff); MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, d.stream()>>>( - output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff); + output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff); return d.ok(); } diff --git a/tensorflow/core/kernels/meta_support.cc b/tensorflow/core/kernels/meta_support.cc index 9fed01189fc3bfde4ad1e23ea8fda0c76311b3bc..39e60c9fcef174a4f9e2271600ed847f4e769625 100644 --- a/tensorflow/core/kernels/meta_support.cc +++ b/tensorflow/core/kernels/meta_support.cc @@ -98,9 +98,9 @@ typedef gemmlowp::meta::SimpleContext LocalContext; template void MultiThreadGemm(Context* context, const Params& params) { if (params.m <= 4) { - gemmlowp::meta::MultiThreadGemm< - Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params, - 1, 8, 8>(context, params); + gemmlowp::meta::MultiThreadGemm< + Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params, 1, + 8, 8>(context, params); } else { if (params.m >= params.n) { gemmlowp::meta::MultiThreadGemm< diff --git a/tensorflow/core/kernels/mfcc.cc b/tensorflow/core/kernels/mfcc.cc index 2793005aa2678b4017dc7a562b8362470e43b8ed..8c755e0df87546ab5f85c3ac5ce2d895d020de78 100644 --- a/tensorflow/core/kernels/mfcc.cc +++ b/tensorflow/core/kernels/mfcc.cc @@ -27,21 +27,19 @@ const double kFilterbankFloor = 1e-12; const int kDefaultFilterbankChannelCount = 40; const int kDefaultDCTCoefficientCount = 13; -Mfcc::Mfcc() : initialized_(false), - lower_frequency_limit_(kDefaultLowerFrequencyLimit), - upper_frequency_limit_(kDefaultUpperFrequencyLimit), - filterbank_channel_count_(kDefaultFilterbankChannelCount), - dct_coefficient_count_(kDefaultDCTCoefficientCount) { } +Mfcc::Mfcc() + : initialized_(false), + lower_frequency_limit_(kDefaultLowerFrequencyLimit), + upper_frequency_limit_(kDefaultUpperFrequencyLimit), + filterbank_channel_count_(kDefaultFilterbankChannelCount), + dct_coefficient_count_(kDefaultDCTCoefficientCount) {} -bool Mfcc::Initialize(int input_length, - double input_sample_rate) { - bool initialized = mel_filterbank_.Initialize(input_length, - input_sample_rate, - filterbank_channel_count_, - lower_frequency_limit_, - upper_frequency_limit_); - initialized &= dct_.Initialize(filterbank_channel_count_, - dct_coefficient_count_); +bool Mfcc::Initialize(int input_length, double input_sample_rate) { + bool initialized = mel_filterbank_.Initialize( + input_length, input_sample_rate, filterbank_channel_count_, + lower_frequency_limit_, upper_frequency_limit_); + initialized &= + dct_.Initialize(filterbank_channel_count_, dct_coefficient_count_); initialized_ = initialized; return initialized; } diff --git a/tensorflow/core/kernels/mfcc.h b/tensorflow/core/kernels/mfcc.h index 8268f4720348bbc820bd3f8863698d34999abb7b..8eee76f7f0cadad45cb223ab9fbb990e4c365a44 100644 --- a/tensorflow/core/kernels/mfcc.h +++ b/tensorflow/core/kernels/mfcc.h @@ -20,18 +20,17 @@ limitations under the License. #include +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/kernels/mfcc_dct.h" #include "tensorflow/core/kernels/mfcc_mel_filterbank.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { class Mfcc { public: Mfcc(); - bool Initialize(int input_length, - double input_sample_rate); + bool Initialize(int input_length, double input_sample_rate); // Input is a single squared-magnitude spectrogram frame. The input spectrum // is converted to linear magnitude and weighted into bands using a diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.cc b/tensorflow/core/kernels/mfcc_mel_filterbank.cc index 630de8a5a3362b77306ac76b70bbb63416d561d0..3db3b51e8b665f6e28ccb2bf8f3850785c7561fb 100644 --- a/tensorflow/core/kernels/mfcc_mel_filterbank.cc +++ b/tensorflow/core/kernels/mfcc_mel_filterbank.cc @@ -38,13 +38,12 @@ namespace tensorflow { MfccMelFilterbank::MfccMelFilterbank() : initialized_(false) {} -bool MfccMelFilterbank::Initialize(int input_length, - double input_sample_rate, - int output_channel_count, - double lower_frequency_limit, - double upper_frequency_limit) { +bool MfccMelFilterbank::Initialize(int input_length, double input_sample_rate, + int output_channel_count, + double lower_frequency_limit, + double upper_frequency_limit) { num_channels_ = output_channel_count; - sample_rate_ = input_sample_rate; + sample_rate_ = input_sample_rate; input_length_ = input_length; if (num_channels_ < 1) { @@ -85,10 +84,9 @@ bool MfccMelFilterbank::Initialize(int input_length, } // Always exclude DC; emulate HTK. - const double hz_per_sbin = 0.5 * sample_rate_ / - static_cast(input_length_ - 1); - start_index_ = static_cast(1.5 + (lower_frequency_limit / - hz_per_sbin)); + const double hz_per_sbin = + 0.5 * sample_rate_ / static_cast(input_length_ - 1); + start_index_ = static_cast(1.5 + (lower_frequency_limit / hz_per_sbin)); end_index_ = static_cast(upper_frequency_limit / hz_per_sbin); // Maps the input spectrum bin indices to filter bank channels/indices. For @@ -121,12 +119,12 @@ bool MfccMelFilterbank::Initialize(int input_length, weights_[i] = 0.0; } else { if (channel >= 0) { - weights_[i] = (center_frequencies_[channel + 1] - - FreqToMel(i * hz_per_sbin)) / + weights_[i] = + (center_frequencies_[channel + 1] - FreqToMel(i * hz_per_sbin)) / (center_frequencies_[channel + 1] - center_frequencies_[channel]); } else { weights_[i] = (center_frequencies_[0] - FreqToMel(i * hz_per_sbin)) / - (center_frequencies_[0] - mel_low); + (center_frequencies_[0] - mel_low); } } } @@ -152,16 +150,16 @@ bool MfccMelFilterbank::Initialize(int input_length, } } if (!bad_channels.empty()) { - LOG(ERROR) << "Missing " << bad_channels.size() << " bands " << - " starting at " << bad_channels[0] << - " in mel-frequency design. " << - "Perhaps too many channels or " << - "not enough frequency resolution in spectrum. (" << - "input_length: " << input_length << - " input_sample_rate: " << input_sample_rate << - " output_channel_count: " << output_channel_count << - " lower_frequency_limit: " << lower_frequency_limit << - " upper_frequency_limit: " << upper_frequency_limit; + LOG(ERROR) << "Missing " << bad_channels.size() << " bands " + << " starting at " << bad_channels[0] + << " in mel-frequency design. " + << "Perhaps too many channels or " + << "not enough frequency resolution in spectrum. (" + << "input_length: " << input_length + << " input_sample_rate: " << input_sample_rate + << " output_channel_count: " << output_channel_count + << " lower_frequency_limit: " << lower_frequency_limit + << " upper_frequency_limit: " << upper_frequency_limit; } initialized_ = true; return true; @@ -171,7 +169,7 @@ bool MfccMelFilterbank::Initialize(int input_length, // square root, then summing FFT magnitudes under triangular integration windows // whose widths increase with frequency. void MfccMelFilterbank::Compute(const std::vector &input, - std::vector *output) const { + std::vector *output) const { if (!initialized_) { LOG(ERROR) << "Mel Filterbank not initialized."; return; diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.h b/tensorflow/core/kernels/mfcc_mel_filterbank.h index 1bdc2dc93b80a2691d4adec219426b142ef24321..37c3936e80d893a3c12b153ea92749ec4b73f872 100644 --- a/tensorflow/core/kernels/mfcc_mel_filterbank.h +++ b/tensorflow/core/kernels/mfcc_mel_filterbank.h @@ -27,10 +27,8 @@ class MfccMelFilterbank { public: MfccMelFilterbank(); bool Initialize(int input_length, // Number of unique FFT bins fftsize/2+1. - double input_sample_rate, - int output_channel_count, - double lower_frequency_limit, - double upper_frequency_limit); + double input_sample_rate, int output_channel_count, + double lower_frequency_limit, double upper_frequency_limit); // Takes a squared-magnitude spectrogram slice as input, computes a // triangular-mel-weighted linear-magnitude filterbank, and places the result @@ -56,7 +54,7 @@ class MfccMelFilterbank { // FFT bin i contributes to the upper side of mel channel band_mapper_[i] std::vector band_mapper_; int start_index_; // Lowest FFT bin used to calculate mel spectrum. - int end_index_; // Highest FFT bin used to calculate mel spectrum. + int end_index_; // Highest FFT bin used to calculate mel spectrum. TF_DISALLOW_COPY_AND_ASSIGN(MfccMelFilterbank); }; diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc index 602dfeb4e5400143a10232219f02c8e5d8154a04..54f31e1699ef1843d942f952f540b2d657b2d063 100644 --- a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc +++ b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc @@ -34,11 +34,9 @@ TEST(MfccMelFilterbankTest, AgreesWithPythonGoldenValues) { input.push_back(i + 1); } const int kChannelCount = 20; - filterbank.Initialize(input.size(), - 22050 /* sample rate */, - kChannelCount /* channels */, - 20.0 /* lower frequency limit */, - 4000.0 /* upper frequency limit */); + filterbank.Initialize( + input.size(), 22050 /* sample rate */, kChannelCount /* channels */, + 20.0 /* lower frequency limit */, 4000.0 /* upper frequency limit */); std::vector output; filterbank.Compute(input, &output); @@ -65,13 +63,10 @@ TEST(MfccMelFilterbankTest, IgnoresExistingContentOfOutputVector) { std::vector input; std::vector output; - filterbank.Initialize(kSampleCount, - 22050 /* sample rate */, - 20 /* channels */, - 20.0 /* lower frequency limit */, + filterbank.Initialize(kSampleCount, 22050 /* sample rate */, + 20 /* channels */, 20.0 /* lower frequency limit */, 4000.0 /* upper frequency limit */); - // First call with nonzero input value, and an empty output vector, // will resize the output and fill it with the correct, nonzero outputs. input.assign(kSampleCount, 1.0); diff --git a/tensorflow/core/kernels/mfcc_test.cc b/tensorflow/core/kernels/mfcc_test.cc index cb32df8811ed04363fd61490e3253dd31539460d..72c1d331d6e7bd91385aa268d7b59bbd786859b4 100644 --- a/tensorflow/core/kernels/mfcc_test.cc +++ b/tensorflow/core/kernels/mfcc_test.cc @@ -36,11 +36,10 @@ TEST(MfccTest, AgreesWithPythonGoldenValues) { std::vector output; mfcc.Compute(input, &output); - std::vector expected = {29.13970072, -6.41568601, -0.61903012, - -0.96778652, -0.26819878, -0.40907028, - -0.15614748, -0.23203119, -0.10481487, - -0.1543029, -0.0769791, -0.10806114, - -0.06047613}; + std::vector expected = { + 29.13970072, -6.41568601, -0.61903012, -0.96778652, -0.26819878, + -0.40907028, -0.15614748, -0.23203119, -0.10481487, -0.1543029, + -0.0769791, -0.10806114, -0.06047613}; ASSERT_EQ(expected.size(), output.size()); for (int i = 0; i < output.size(); ++i) { diff --git a/tensorflow/core/kernels/mirror_pad_op.cc b/tensorflow/core/kernels/mirror_pad_op.cc index fbdeaf43ebbfdcf6b76f97046130f40cf8c8efd1..26e1082989f317a35d55826a466cb8d9ef306c4c 100644 --- a/tensorflow/core/kernels/mirror_pad_op.cc +++ b/tensorflow/core/kernels/mirror_pad_op.cc @@ -87,8 +87,8 @@ class MirrorPadOp : public OpKernel { const Tpaddings before = paddings(d, 0); // Pad before existing elements. const Tpaddings after = paddings(d, 1); // Pad after existing elements. OP_REQUIRES(context, before >= 0 && after >= 0, - errors::InvalidArgument("paddings must be non-negative: ", - before, " ", after)); + errors::InvalidArgument( + "paddings must be non-negative: ", before, " ", after)); if (offset_ == 0) { // SYMMETRIC mode. OP_REQUIRES(context, before <= in0.dim_size(d) && after <= in0.dim_size(d), @@ -296,8 +296,8 @@ class MirrorPadGradOp : public OpKernel { const Tpaddings before = paddings(d, 0); // Pad before existing elements. const Tpaddings after = paddings(d, 1); // Pad after existing elements. OP_REQUIRES(context, before >= 0 && after >= 0, - errors::InvalidArgument("Paddings must be non-negative: ", - before, ", ", after)); + errors::InvalidArgument( + "Paddings must be non-negative: ", before, ", ", after)); const int64 out_size = in0.dim_size(d) - (before + after); if (offset_ == 0) { // SYMMETRIC mode. diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc index 89d37d2f874c0b8fa7550b1c49c0e3c4106e2ee5..b539b00009eb5cdc383aa557881e32782dce5193 100644 --- a/tensorflow/core/kernels/mkl_aggregate_ops.cc +++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc @@ -28,7 +28,7 @@ limitations under the License. #include "mkl_dnn_types.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" using mkldnn::stream; using mkldnn::sum; @@ -37,7 +37,7 @@ using mkldnn::sum; namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML template class MklAddNOp : public OpKernel { @@ -285,7 +285,7 @@ class MklAddNOp : public OpKernel { } MklAddNOpContext; }; -#else // INTEL_MKL_DNN +#else // INTEL_MKL_ML template class MklAddNOp : public OpKernel { public: @@ -317,8 +317,11 @@ class MklAddNOp : public OpKernel { : src2_tensor.dims(); // if the shapes of two tensors are not same raise op error TensorShape src1_shape, src2_shape; - src1_shape = src1_tensor.shape(); - src2_shape = src2_tensor.shape(); + src1_shape = input1_in_mkl_format ? src1_mkl_shape.GetTfShape() + : src1_tensor.shape(); + src2_shape = input2_in_mkl_format ? src2_mkl_shape.GetTfShape() + : src2_tensor.shape(); + if (!src1_shape.IsSameSize(src2_shape)) { ctx->SetStatus(errors::InvalidArgument( "Inputs to operation ", this->name(), " of type ", diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc index d751a70fc86b40d8ca656322484848cf906359fd..d545d34fdfd8682b2e5b856d321579f675696e2f 100644 --- a/tensorflow/core/kernels/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc @@ -24,24 +24,23 @@ #include "tensorflow/core/kernels/mkl_pooling_ops_common.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" -using mkldnn::memory; +using mkldnn::algorithm; +using mkldnn::engine; using mkldnn::error; -using mkldnn::pooling_forward; -using mkldnn::pooling_backward; +using mkldnn::memory; using mkldnn::padding_kind; -using mkldnn::engine; +using mkldnn::pooling_backward; +using mkldnn::pooling_forward; using mkldnn::prop_kind; -using mkldnn::algorithm; #endif namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -// For now, MKL-ML is default. So making MKL-DNN not a default choice. -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML template class MklAvgPoolingOp : public OpKernel { @@ -358,10 +357,11 @@ class MklAvgPoolingGradOp : public OpKernel { if (!outbackprop_in_mkl_format) { // For avgpooling, tensor_in_shape should have 1 dimension, and 4 // elements. - OP_REQUIRES(context, tensor_in_shape.dims() == 1 && - tensor_in_shape.NumElements() == 4, - errors::InvalidArgument("original input shape must be " - "1-dimensional and 4 elements")); + OP_REQUIRES( + context, + tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4, + errors::InvalidArgument("original input shape must be " + "1-dimensional and 4 elements")); // For avgpooling, out_backprop should have 4 dimensions. OP_REQUIRES(context, out_backprop.dims() == 4, @@ -428,14 +428,13 @@ class MklAvgPoolingGradOp : public OpKernel { TensorFormat data_format_; }; // MklAvgPoolingGradOp - -#else // INTEL_MKL_DNN is defined +#else template class MklAvgPoolingOp : public MklPoolingForwardOpBase { public: explicit MklAvgPoolingOp(OpKernelConstruction* context) - : MklPoolingForwardOpBase(context) { + : MklPoolingForwardOpBase(context) { // Workspace is an MKLDNN construct that is only used in Max Pooling. // So set workspace_enabled_ to false. this->workspace_enabled_ = false; @@ -444,8 +443,8 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { void Compute(OpKernelContext* context) override { try { auto cpu_engine = engine(engine::cpu, 0); - const Tensor& input_tensor = MklGetInput(context, - this->kInputTensorIndexInput); + const Tensor& input_tensor = + MklGetInput(context, this->kInputTensorIndexInput); MklDnnShape dnn_shape_input; GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input); this->SanityCheckInput(context, input_tensor, dnn_shape_input); @@ -457,9 +456,8 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { // initialize variables for the pooling op MklPoolParameters pool_params; // Get the input tensor and initialize the pooling parameters - this->ConfigureInput(context, dnn_shape_input, - input_tensor, &pool_params, - &dnn_data_input); + this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params, + &dnn_data_input); OP_REQUIRES_OK(context, context->status()); // Declare output tensor @@ -467,59 +465,77 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { memory::dims output_dims_mkl_order; this->GetOutputDims(pool_params, &output_dims_mkl_order); + // If input is an empty tensor, allocate an empty output tensor and return + if (input_tensor.NumElements() == 0) { + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(false); + TensorShape output_tf_shape; + if (pool_params.data_format == TensorFormat::FORMAT_NCHW) { + output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order); + } else { + memory::dims output_dims_NHWC_order; + output_dims_NHWC_order = {pool_params.tensor_in_batch, + static_cast(pool_params.out_height), + static_cast(pool_params.out_width), + pool_params.out_depth}; + output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order); + } + const int kOutputIndex = 0; + AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor, + output_tf_shape, output_mkl_shape); + CHECK_NOTNULL(output_tensor); + return; + } + // If input is in Mkl layout, then just get the memory format from it // directly, instead of using input data_format to AvgPool. if (dnn_shape_input.IsMklTensor()) { - dnn_data_output.SetUsrMem(output_dims_mkl_order, - static_cast(dnn_data_input.GetUsrMemDesc() - .data.format)); + dnn_data_output.SetUsrMem( + output_dims_mkl_order, + static_cast( + dnn_data_input.GetUsrMemDesc().data.format)); } else { - dnn_data_output.SetUsrMem(output_dims_mkl_order, - this->data_format_mkldnn_); + dnn_data_output.SetUsrMem(output_dims_mkl_order, + this->data_format_mkldnn_); } - // describe the memory layout + // describe the memory layout dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any); // 3. create a pooling primitive descriptor - auto pool_desc = pooling_forward::desc(prop_kind::forward, - algorithm::pooling_avg_exclude_padding, - dnn_data_input.GetUsrMemDesc(), - dnn_data_output.GetUsrMemDesc(), - memory::dims({ pool_params.row_stride, - pool_params.col_stride}), - memory::dims({ pool_params.window_rows, - pool_params.window_cols}), - memory::dims({ static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({ static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, - cpu_engine); + auto pool_desc = pooling_forward::desc( + prop_kind::forward, algorithm::pooling_avg_exclude_padding, + dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(), + memory::dims({pool_params.row_stride, pool_params.col_stride}), + memory::dims({pool_params.window_rows, pool_params.window_cols}), + memory::dims({static_cast(pool_params.pad_top), + static_cast(pool_params.pad_left)}), + memory::dims({static_cast(pool_params.pad_bottom), + static_cast(pool_params.pad_right)}), + TFPaddingToMklDnnPadding(this->padding_)); + auto pool_prim_desc = + pooling_forward::primitive_desc(pool_desc, cpu_engine); this->AllocateOutputTensor(context, pool_prim_desc, output_dims_mkl_order, - this->data_format_mkldnn_, &output_tensor); + this->data_format_mkldnn_, &output_tensor); CHECK_NOTNULL(output_tensor); OP_REQUIRES_OK(context, context->status()); dnn_data_output.SetUsrMemDataHandle(output_tensor); - this->PrepareAndExecuteNet(pool_prim_desc, - &dnn_data_input, - &dnn_data_output); - } catch (mkldnn::error &e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Operation received an exception:", - error_msg)); + this->PrepareAndExecuteNet(pool_prim_desc, &dnn_data_input, + &dnn_data_output); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } // Compute -}; // MklAvgPoolingOp +}; // MklAvgPoolingOp //----------------------------------------------------------------------------- @@ -527,27 +543,23 @@ template class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { public: explicit MklAvgPoolingGradOp(OpKernelConstruction* context) - : MklPoolingBackwardOpBase(context) { - } + : MklPoolingBackwardOpBase(context) {} void Compute(OpKernelContext* context) override { try { auto cpu_engine = engine(engine::cpu, 0); MklDnnShape original_input_mkl_shape, input_gradient_mkl_shape; - const Tensor& tensor_in_shape = MklGetInput(context, - kInputTensorIndexInputShape); - const Tensor& input_gradient_tensor = MklGetInput(context, - kInputTensorIndexInputGradient); + const Tensor& tensor_in_shape = + MklGetInput(context, kInputTensorIndexInputShape); + const Tensor& input_gradient_tensor = + MklGetInput(context, kInputTensorIndexInputGradient); GetMklShape(context, kInputTensorIndexInputShape, - &original_input_mkl_shape); + &original_input_mkl_shape); GetMklShape(context, kInputTensorIndexInputGradient, - &input_gradient_mkl_shape); - + &input_gradient_mkl_shape); - SanityCheckInputs(context, tensor_in_shape, - input_gradient_tensor, - original_input_mkl_shape, - input_gradient_mkl_shape); + SanityCheckInputs(context, tensor_in_shape, input_gradient_tensor, + original_input_mkl_shape, input_gradient_mkl_shape); if (!context->status().ok()) return; // Used to allocate output_diff_src/diff_src @@ -562,90 +574,70 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { MklPoolParameters pool_params; memory::dims output_dims_mkl_order, original_input_dims_nchw; // Configure the original input memory descriptor - memory::desc original_input_md = ConfigureOriginalInput(context, - tensor_in_shape, - original_input_mkl_shape, - &original_input_dims_nchw, - &pool_params, - &original_input_shape); + memory::desc original_input_md = ConfigureOriginalInput( + context, tensor_in_shape, original_input_mkl_shape, + &original_input_dims_nchw, &pool_params, &original_input_shape); // configure the original output memory descriptor // by definition, the shape of the original output is the same // as the shape of the gradient diff_dst memory::desc original_output_md = this->ConfigureOriginalOutput( - pool_params, input_gradient_mkl_shape, output_dims_mkl_order); + pool_params, input_gradient_mkl_shape, output_dims_mkl_order); memory::desc target_diff_dst_md = this->ConfigureInputGradient( - input_gradient_mkl_shape, - input_gradient_tensor, - &input_gradient_diff_dst, - original_output_md); + input_gradient_mkl_shape, input_gradient_tensor, + &input_gradient_diff_dst, original_output_md); // The shape of the output diff src needs to be the same shape as the // original input. But we will set its format to be same as the format of // input gradient. We won't use format of original input since it will // always be in Tensorflow layout (given that AvgPoolGrad gets shape of // the input rather than actual input). - output_diff_src.SetUsrMem(original_input_dims_nchw, - static_cast( - target_diff_dst_md.data.format)); + output_diff_src.SetUsrMem( + original_input_dims_nchw, + static_cast(target_diff_dst_md.data.format)); // Create the forward pooling primitive descriptor so we can reference it // in the backward pooling primitive descriptor - auto pool_fwd_desc = pooling_forward::desc(prop_kind::forward, - algorithm::pooling_avg_exclude_padding, - original_input_md, - original_output_md, - memory::dims({ pool_params.row_stride, - pool_params.col_stride}), - memory::dims({ pool_params.window_rows, - pool_params.window_cols}), - memory::dims({ static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({ static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_fwd_prim_desc - = pooling_forward::primitive_desc(pool_fwd_desc, - cpu_engine); + auto pool_fwd_desc = pooling_forward::desc( + prop_kind::forward, algorithm::pooling_avg_exclude_padding, + original_input_md, original_output_md, + memory::dims({pool_params.row_stride, pool_params.col_stride}), + memory::dims({pool_params.window_rows, pool_params.window_cols}), + memory::dims({static_cast(pool_params.pad_top), + static_cast(pool_params.pad_left)}), + memory::dims({static_cast(pool_params.pad_bottom), + static_cast(pool_params.pad_right)}), + TFPaddingToMklDnnPadding(this->padding_)); + auto pool_fwd_prim_desc = + pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine); auto pool_bkwd_desc = pooling_backward::desc( - algorithm::pooling_avg_exclude_padding, - output_diff_src.GetUsrMemDesc(), - target_diff_dst_md, - memory::dims({ pool_params.row_stride, - pool_params.col_stride}), - memory::dims({ pool_params.window_rows, - pool_params.window_cols}), - memory::dims({ static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({ static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_bkwd_prim_desc - = pooling_backward::primitive_desc(pool_bkwd_desc, - cpu_engine, - pool_fwd_prim_desc); - this->AllocateOutputTensor(context, pool_bkwd_prim_desc, - original_input_dims_nchw, - this->data_format_mkldnn_, - &output_tensor_diff_src); + algorithm::pooling_avg_exclude_padding, + output_diff_src.GetUsrMemDesc(), target_diff_dst_md, + memory::dims({pool_params.row_stride, pool_params.col_stride}), + memory::dims({pool_params.window_rows, pool_params.window_cols}), + memory::dims({static_cast(pool_params.pad_top), + static_cast(pool_params.pad_left)}), + memory::dims({static_cast(pool_params.pad_bottom), + static_cast(pool_params.pad_right)}), + TFPaddingToMklDnnPadding(this->padding_)); + auto pool_bkwd_prim_desc = pooling_backward::primitive_desc( + pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc); + this->AllocateOutputTensor( + context, pool_bkwd_prim_desc, original_input_dims_nchw, + this->data_format_mkldnn_, &output_tensor_diff_src); output_diff_src.SetUsrMemDataHandle(output_tensor_diff_src); - this->PrepareAndExecuteNet(pool_bkwd_prim_desc, - &input_gradient_diff_dst, - &output_diff_src, - memory::primitive_desc( - target_diff_dst_md, - cpu_engine)); - } catch (mkldnn::error &e) { + this->PrepareAndExecuteNet( + pool_bkwd_prim_desc, &input_gradient_diff_dst, &output_diff_src, + memory::primitive_desc(target_diff_dst_md, cpu_engine)); + } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Compute received an exception:", - error_msg)); + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", + error_msg)); } } // Compute @@ -655,12 +647,11 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { const int kInputTensorIndexInputShape = 0; const int kInputTensorIndexInputGradient = 1; - memory::desc ConfigureOriginalInput(OpKernelContext* context, - const Tensor& tensor_original_input_shape, - const MklDnnShape& original_input_mkl_shape, - memory::dims* original_input_dims_mkl_order, - MklPoolParameters* pool_params, - TensorShape* input_tensor_shape) { + memory::desc ConfigureOriginalInput( + OpKernelContext* context, const Tensor& tensor_original_input_shape, + const MklDnnShape& original_input_mkl_shape, + memory::dims* original_input_dims_mkl_order, + MklPoolParameters* pool_params, TensorShape* input_tensor_shape) { CHECK_NOTNULL(original_input_dims_mkl_order); CHECK_NOTNULL(pool_params); CHECK_NOTNULL(input_tensor_shape); @@ -672,47 +663,43 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { } return MklPoolingBackwardOpBase::ConfigureOriginalInput( - context, - tensor_original_input_shape, - original_input_mkl_shape, - original_input_dims_mkl_order, - pool_params, - *input_tensor_shape); -} + context, tensor_original_input_shape, original_input_mkl_shape, + original_input_dims_mkl_order, pool_params, *input_tensor_shape); + } void SanityCheckInputs(OpKernelContext* context, - const Tensor& tensor_in_shape, - const Tensor& input_gradient_tensor, - const MklDnnShape& original_input_mkl_shape, - const MklDnnShape& input_gradient_mkl_shape) { + const Tensor& tensor_in_shape, + const Tensor& input_gradient_tensor, + const MklDnnShape& original_input_mkl_shape, + const MklDnnShape& input_gradient_mkl_shape) { if (!original_input_mkl_shape.IsMklTensor()) { - OP_REQUIRES(context, tensor_in_shape.dims() == 1 && - tensor_in_shape.NumElements() == 4, + OP_REQUIRES( + context, + tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4, errors::InvalidArgument("original input shape must be " - "1-dimensional and 4 elements")); + "1-dimensional and 4 elements")); } else { - OP_REQUIRES(context, original_input_mkl_shape.GetDimension() == 1 && - original_input_mkl_shape.DimSize(0) == 4, - errors::InvalidArgument("original input shape must be " - "1-dimensional and 4 elements")); + OP_REQUIRES(context, + original_input_mkl_shape.GetDimension() == 1 && + original_input_mkl_shape.DimSize(0) == 4, + errors::InvalidArgument("original input shape must be " + "1-dimensional and 4 elements")); } if (!input_gradient_mkl_shape.IsMklTensor()) { // For avgpooling, input_gradient_diff_dst should have 4 dimensions. OP_REQUIRES(context, input_gradient_tensor.dims() == 4, - errors::InvalidArgument("Gradient shape must be " - "4-dimensional")); + errors::InvalidArgument("Gradient shape must be " + "4-dimensional")); } else { OP_REQUIRES(context, input_gradient_mkl_shape.GetDimension() == 4, - errors::InvalidArgument("Gradient shape must be " - "4-dimensional")); + errors::InvalidArgument("Gradient shape must be " + "4-dimensional")); } } }; // MklAvgPoolingGradOp - - -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML REGISTER_KERNEL_BUILDER(Name("_MklAvgPool") .Device(DEVICE_CPU) @@ -728,4 +715,3 @@ REGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad") } // namespace tensorflow #endif // INTEL_MKL - diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc index 9fee94f946555480fce8acf904a7909622404524..c48a2038f921986635795575a69606cbab24f12a 100644 --- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc @@ -29,7 +29,6 @@ limitations under the License. #include #include "mkl_cblas.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -40,10 +39,6 @@ limitations under the License. #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - -#define MKL_Complex8 tensorflow::complex64 -#define MKL_Complex16 tensorflow::complex128 namespace tensorflow { @@ -181,16 +176,16 @@ class BatchMatMulMkl : public OpKernel { void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB, const MKL_INT *M_Array, const MKL_INT *N_Array, const MKL_INT *K_Array, - const MKL_Complex8 **A_Array, const MKL_INT *lda_Array, - const MKL_Complex8 **B_Array, const MKL_INT *ldb_Array, - MKL_Complex8 **C_Array, const MKL_INT *ldc_Array, + const complex64 **A_Array, const MKL_INT *lda_Array, + const complex64 **B_Array, const MKL_INT *ldb_Array, + complex64 **C_Array, const MKL_INT *ldc_Array, const MKL_INT group_count, const MKL_INT *group_size) { std::vector TransA_array( group_size[0], TransA ? CblasConjTrans : CblasNoTrans); std::vector TransB_array( group_size[0], TransB ? CblasConjTrans : CblasNoTrans); - std::vector alpha_Array(group_size[0], {1.0f, 0.0f}); - std::vector beta_Array(group_size[0], {0.0f, 0.0f}); + std::vector alpha_Array(group_size[0], {1.0f, 0.0f}); + std::vector beta_Array(group_size[0], {0.0f, 0.0f}); cblas_cgemm_batch( Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array, static_cast(&alpha_Array[0]), @@ -203,18 +198,18 @@ class BatchMatMulMkl : public OpKernel { void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB, const MKL_INT *M_Array, const MKL_INT *N_Array, const MKL_INT *K_Array, - const MKL_Complex16 **A_Array, + const complex128 **A_Array, const MKL_INT *lda_Array, - const MKL_Complex16 **B_Array, - const MKL_INT *ldb_Array, MKL_Complex16 **C_Array, + const complex128 **B_Array, + const MKL_INT *ldb_Array, complex128 **C_Array, const MKL_INT *ldc_Array, const MKL_INT group_count, const MKL_INT *group_size) { std::vector TransA_array( group_size[0], TransA ? CblasConjTrans : CblasNoTrans); std::vector TransB_array( group_size[0], TransB ? CblasConjTrans : CblasNoTrans); - std::vector alpha_Array(group_size[0], {1.0f, 0.0f}); - std::vector beta_Array(group_size[0], {0.0f, 0.0f}); + std::vector alpha_Array(group_size[0], {1.0f, 0.0f}); + std::vector beta_Array(group_size[0], {0.0f, 0.0f}); cblas_zgemm_batch( Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array, static_cast(&alpha_Array[0]), diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc index d109bb6bcfe6360af12086bad452752336357f35..f1f267e849aa39b43c153b857493160e0d103970 100644 --- a/tensorflow/core/kernels/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl_concat_op.cc @@ -30,11 +30,11 @@ limitations under the License. #include "mkl_dnn_types.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" -using mkldnn::stream; using mkldnn::concat; +using mkldnn::stream; #endif namespace tensorflow { @@ -45,7 +45,6 @@ typedef std::vector TensorShapeList; enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM }; - // TODO(intelft) Check if we can reuse existing EigenConcatOp using Mutable // reference inputs. // -------------------------------------------------------------------------- @@ -63,7 +62,7 @@ class EigenConcatBaseOp : public OpKernel { // we need to have empty Compute because Compute is pure virtual function. void Compute(OpKernelContext* c) {} -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML void Compute(OpKernelContext* c, const std::vector& values) { const Tensor* concat_dim_tensor; @@ -152,8 +151,8 @@ class EigenConcatBaseOp : public OpKernel { #else // MKL_DNN -void Compute(OpKernelContext* c, const std::vector& values, - const TensorShapeList& input_shapes) { + void Compute(OpKernelContext* c, const std::vector& values, + const TensorShapeList& input_shapes) { const Tensor* concat_dim_tensor; const char* axis_attribute_name = AxisArgName == NAME_IS_AXIS @@ -197,7 +196,8 @@ void Compute(OpKernelContext* c, const std::vector& values, const auto in = values[i]; const bool in_is_scalar = IsLegacyScalar(input_shapes[i]); OP_REQUIRES( - c, (input_shapes[i].dims() == input_dims) || + c, + (input_shapes[i].dims() == input_dims) || (input_is_scalar && in_is_scalar), errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", @@ -208,8 +208,8 @@ void Compute(OpKernelContext* c, const std::vector& values, inputs_flat.emplace_back(new typename TTypes::ConstMatrix( in.shaped({inputs_flat_dim0, inputs_flat_dim1}))); } - output_concat_dim += input_shapes[i].dims() > 0 ? - input_shapes[i].dim_size(axis) : 1; + output_concat_dim += + input_shapes[i].dims() > 0 ? input_shapes[i].dim_size(axis) : 1; } TensorShape output_shape(input_shape); @@ -230,7 +230,7 @@ void Compute(OpKernelContext* c, const std::vector& values, #endif }; -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML // -------------------------------------------------------------------------- // Mkl Concat Op @@ -418,7 +418,6 @@ class MklConcatOp : public OpKernel { OP_REQUIRES_OK(context, context->status()); } - private: typedef struct { TensorFormat data_format; @@ -590,39 +589,45 @@ class MklConcatOp : public OpKernel { GetMklShapeList(context, "values", &input_shapes); const Tensor& concat_dim_tensor = (AxisArgName == NAME_IS_CONCAT_DIM) - ? MklGetInput(context, 0) : MklGetInput(context, N); + ? MklGetInput(context, 0) + : MklGetInput(context, N); // Sanity checks - OP_REQUIRES(context, IsLegacyScalar(concat_dim_tensor.shape()), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_tensor.shape().DebugString())); - int32 concat_dim = internal::SubtleMustCopy( - concat_dim_tensor.scalar()()); + OP_REQUIRES( + context, IsLegacyScalar(concat_dim_tensor.shape()), + errors::InvalidArgument( + "Concat dim tensor should be a scalar integer, but got shape ", + concat_dim_tensor.shape().DebugString())); + int32 concat_dim = + internal::SubtleMustCopy(concat_dim_tensor.scalar()()); // check that ranks of all tensors match // and that their shapes match except for concat_dim. int i = 0; bool invoke_eigen = false; bool are_all_mkl_inputs = true, are_all_tf_inputs = true; - const TensorShape expected_shape = input_shapes[0].IsMklTensor() ? - input_shapes[0].GetTfShape() : - input_tensors[0].shape(); + const TensorShape expected_shape = input_shapes[0].IsMklTensor() + ? input_shapes[0].GetTfShape() + : input_tensors[0].shape(); size_t expected_dims = expected_shape.dims(); if (concat_dim < 0) concat_dim = expected_dims + concat_dim; for (auto& s : input_shapes) { - if (s == expected_shape) {++i; continue;} + if (s == expected_shape) { + ++i; + continue; + } - TensorShape s_shape = s.IsMklTensor() ? s.GetTfShape() : - input_tensors[i].shape(); + TensorShape s_shape = + s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape(); size_t s_dims = s_shape.dims(); - OP_REQUIRES(context, s_dims == expected_dims, - errors::InvalidArgument( - "_MklConcatOp : Ranks of all input tensors should match:" - " input dimensions = ", - s_dims, " vs. expected rank = ", expected_dims)); + OP_REQUIRES( + context, s_dims == expected_dims, + errors::InvalidArgument( + "_MklConcatOp : Ranks of all input tensors should match:" + " input dimensions = ", + s_dims, " vs. expected rank = ", expected_dims)); for (int d = 0; d < expected_dims; ++d) { if (d == concat_dim) continue; @@ -630,10 +635,11 @@ class MklConcatOp : public OpKernel { size_t expected_size = expected_shape.dim_size(d); size_t s_size = s_shape.dim_size(d); OP_REQUIRES( - context, expected_size == s_size, - errors::InvalidArgument("_MklConcatOp : Dimensions of inputs " - "should match: shape[0][", d, "]= ", expected_size, - " vs. shape[", i, "][", d, "] = ", s_size)); + context, expected_size == s_size, + errors::InvalidArgument("_MklConcatOp : Dimensions of inputs " + "should match: shape[0][", + d, "]= ", expected_size, " vs. shape[", i, + "][", d, "] = ", s_size)); } if (s.IsMklTensor()) @@ -657,8 +663,8 @@ class MklConcatOp : public OpKernel { TensorShapeList tf_input_shapes; i = 0; for (auto& s : input_shapes) { - TensorShape s_shape = s.IsMklTensor() ? s.GetTfShape() : - input_tensors[i].shape(); + TensorShape s_shape = + s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape(); tf_input_shapes.push_back(s_shape); ++i; } @@ -678,21 +684,22 @@ class MklConcatOp : public OpKernel { std::vector srcs_pd; std::vector> srcs(N, MklDnnData(&cpu_engine)); int64 dst_concat_dim_size = 0; - for (int k =0; k < N; k++) { + for (int k = 0; k < N; k++) { bool is_mkl_tensor = input_shapes[k].IsMklTensor(); memory::dims src_dims; // Same comment as dst_dims for src_dims. - src_dims = (is_mkl_tensor) ? - TFShapeToMklDnnDims(input_shapes[k].GetTfShape()) : - TFShapeToMklDnnDims(input_tensors[k].shape()); + src_dims = (is_mkl_tensor) + ? TFShapeToMklDnnDims(input_shapes[k].GetTfShape()) + : TFShapeToMklDnnDims(input_tensors[k].shape()); dst_concat_dim_size += src_dims[concat_dim]; - auto src_md = is_mkl_tensor ? input_shapes[k].GetMklLayout() : - // It does not matter what data format we use here (NHWC or NCHW). - // We just need to ensure that output of Concat uses same data format - // as input. - memory::desc(src_dims, MklDnnType(), memory::format::nchw); + auto src_md = + is_mkl_tensor ? input_shapes[k].GetMklLayout() : + // It does not matter what data format we use here + // (NHWC or NCHW). We just need to ensure that output + // of Concat uses same data format as input. + memory::desc(src_dims, MklDnnType(), memory::format::nchw); srcs[k].SetUsrMem(src_md, &input_tensors[k]); auto src_mpd = srcs[k].GetUsrMemPrimDesc(); @@ -707,14 +714,15 @@ class MklConcatOp : public OpKernel { // Since we are passing a specific format for destination, // we need to have dst_dims in MklDnn order (NCHW). auto orig_tf_format = input_shapes[0].GetTfDataFormat(); - dst_dims_in_nchw = MklDnnDimsInNCHW(dst_dims, - MklDnnDataFormatToTFDataFormat(orig_tf_format)); + dst_dims_in_nchw = MklDnnDimsInNCHW( + dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format)); // We will set the output in the same format as input to avoid layout // conversions. // Currently we are setting dst format same as input format. // See if we can make this choice in a better way. - dst_md = memory::desc(dst_dims_in_nchw, MklDnnType(), - (memory::format) input_shapes[0].GetMklLayout().data.format); + dst_md = memory::desc( + dst_dims_in_nchw, MklDnnType(), + (memory::format)input_shapes[0].GetMklLayout().data.format); } else { // Again, format does not matter here. We just need to make it same as // input format. @@ -722,7 +730,7 @@ class MklConcatOp : public OpKernel { } std::vector inputs; - for (int k=0; k < input_tensors.size(); k++) + for (int k = 0; k < input_tensors.size(); k++) inputs.push_back(srcs[k].GetOpMem()); // If all inputs are in MKL format, then meaning of concat_dim needs to @@ -732,8 +740,7 @@ class MklConcatOp : public OpKernel { // But ifinput tensors are in NHWC order, then semantics need to change. // E.g., if we are concatinating over Channel (dimension 3 for NHWC), // then since MklDnn order is NCHW, concat_dim needs to be 1. - if (are_all_mkl_inputs) - concat_dim = input_shapes[0].TfDimIdx(concat_dim); + if (are_all_mkl_inputs) concat_dim = input_shapes[0].TfDimIdx(concat_dim); auto concat_pd = concat::primitive_desc(dst_md, concat_dim, srcs_pd); @@ -752,24 +759,25 @@ class MklConcatOp : public OpKernel { dnn_shape_dst.SetMklTensor(false); tf_shape_dst = MklDnnDimsToTFShape(dst_dims); } - AllocateOutputSetMklShape(context, 0, &dst_tensor, - tf_shape_dst, dnn_shape_dst); + AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst, + dnn_shape_dst); CHECK_NOTNULL(dst_tensor); - dst_md = dnn_shape_dst.IsMklTensor() ? - dnn_shape_dst.GetMklLayout() : dst_md; + dst_md = + dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout() : dst_md; dst.SetUsrMem(dst_md, dst_tensor); auto concat_op = concat(concat_pd, inputs, dst.GetOpMem()); std::vector net; net.push_back(concat_op); stream(stream::kind::eager).submit(net).wait(); - } catch (mkldnn::error &e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); - OP_REQUIRES_OK(context, errors::Aborted( - "Operation received an exception:", error_msg)); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } @@ -790,11 +798,9 @@ class MklConcatOp : public OpKernel { dnn_shape_output.SetDimensions(4); Tensor* output_tensor = nullptr; TensorShape tf_shape_output; - tf_shape_output.AddDim( - dnn_shape_output.GetSerializeBufferSize()); - context->allocate_output( - GetTensorMetaDataIndex(0, context->num_outputs()), - tf_shape_output, &output_tensor); + tf_shape_output.AddDim(dnn_shape_output.GetSerializeBufferSize()); + context->allocate_output(GetTensorMetaDataIndex(0, context->num_outputs()), + tf_shape_output, &output_tensor); dnn_shape_output.SerializeMklDnnShape( output_tensor->flat().data(), output_tensor->flat().size() * sizeof(uint8)); diff --git a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc index 0f1a218fe62dd91160320254342828811e3aa458..25c2573741265d4d33c9c91474792be241dd3b32 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc @@ -38,9 +38,9 @@ limitations under the License. #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" -#include "tensorflow/core/util/mkl_util.h" #include "mkl_dnn.h" #include "mkl_dnn_types.h" +#include "tensorflow/core/util/mkl_util.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 54d4916d4943be4957bb60b273cdbf2d6ce1ffdc..1401bc65a45bd80ed78230840cf0b9958b1f012e 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -38,24 +38,24 @@ limitations under the License. #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" -#include "tensorflow/core/util/mkl_util.h" #include "mkl_dnn.h" #include "mkl_dnn_types.h" +#include "tensorflow/core/util/mkl_util.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" -using mkldnn::stream; -using mkldnn::prop_kind; using mkldnn::convolution_backward_weights; using mkldnn::memory; +using mkldnn::prop_kind; +using mkldnn::stream; #endif namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML template class MklConv2DCustomBackpropFilterOp : public OpKernel { @@ -360,8 +360,8 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input; const Tensor& out_backprop = MklGetInput(context, 2); - void* mkl_buf_out_backprop = const_cast(static_cast( - out_backprop.flat().data())); + void* mkl_buf_out_backprop = const_cast( + static_cast(out_backprop.flat().data())); CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop, prim_conv_bwdfilter, @@ -371,10 +371,11 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop, lt_out_backprop); if (mkl_convert_out_backprop) { CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop, - lt_out_backprop, mkl_lt_internal_out_backprop), + lt_out_backprop, + mkl_lt_internal_out_backprop), E_SUCCESS); AllocTmpBuffer(context, mkl_tmp_out_backprop_buf_tensor, - lt_out_backprop, &mkl_buf_convert_out_backprop); + lt_out_backprop, &mkl_buf_convert_out_backprop); CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop, mkl_buf_out_backprop, mkl_buf_convert_out_backprop), @@ -428,18 +429,18 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ - MklConv2DCustomBackpropFilterOp); + MklConv2DCustomBackpropFilterOp); TF_CALL_float(REGISTER_MKL_FILTER_KERNELS); #undef REGISTER_MKL_FILTER_KERNELS #else template -class MklConv2DCustomBackpropFilterOp : - public MklConv2DBackpropCommonOp { +class MklConv2DCustomBackpropFilterOp + : public MklConv2DBackpropCommonOp { public: explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context) - : MklConv2DBackpropCommonOp(context) { } + : MklConv2DBackpropCommonOp(context) {} ~MklConv2DCustomBackpropFilterOp() {} private: @@ -447,7 +448,7 @@ class MklConv2DCustomBackpropFilterOp : const MklDnnShape& filter_mkl_shape, const MklDnnShape& obp_mkl_shape) { CHECK(!filter_mkl_shape.IsMklTensor()) - << "Conv2DBackpropFilter: filter should not be in MKL Layout"; + << "Conv2DBackpropFilter: filter should not be in MKL Layout"; } size_t GetInputTensorIndexWithSizes() { return 1; /* filter index */ } @@ -462,8 +463,10 @@ class MklConv2DCustomBackpropFilterOp : const Tensor& filter_tensor) { TensorShape filter_tf_shape; CHECK_EQ(TensorShapeUtils::IsVector(filter_tensor.shape()), true); - CHECK_EQ(TensorShapeUtils::MakeShape( - filter_tensor.vec(), &filter_tf_shape).ok(), true); + CHECK_EQ(TensorShapeUtils::MakeShape(filter_tensor.vec(), + &filter_tf_shape) + .ok(), + true); return filter_tf_shape; } @@ -485,16 +488,13 @@ class MklConv2DCustomBackpropFilterOp : return memory::format::hwio; } - void CreatePrimitive(OpKernelContext* context, - const engine& cpu_engine, + void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine, const convolution_forward::primitive_desc& conv_fwd_pd, MklDnnData* input, MklDnnData* filter, MklDnnData* outbackprop, MklDnnData* output, - Tensor** output_tensor, - const memory::dims& strides, + Tensor** output_tensor, const memory::dims& strides, const memory::dims& padding_l, - const memory::dims& padding_r, - padding_kind padding, + const memory::dims& padding_r, padding_kind padding, const memory::dims& bwd_output_dims, memory::format bwd_output_format) { CHECK_NOTNULL(context); @@ -508,34 +508,35 @@ class MklConv2DCustomBackpropFilterOp : int depth = 0; if (biasEnabled) { // Data structure for bias_grad - bias_grad = new MklDnnData (&cpu_engine); + bias_grad = new MklDnnData(&cpu_engine); TensorShape obp_tf_shape = GetTfShape(context, 2); - depth = (MklConv2DBackpropCommonOp::GetTFDataFormat() - == FORMAT_NCHW) ? - obp_tf_shape.dim_size(1) : obp_tf_shape.dim_size(3); + depth = (MklConv2DBackpropCommonOp::GetTFDataFormat() == + FORMAT_NCHW) + ? obp_tf_shape.dim_size(1) + : obp_tf_shape.dim_size(3); memory::dims bias_grad_dims = {depth}; bias_grad->SetOpMemDesc(bias_grad_dims, memory::format::x); } // Create convolution backward weights primitive. - auto bwd_desc = (biasEnabled && (bias_grad != nullptr))? - convolution_backward_weights::desc(convolution_direct, - input->GetOpMemDesc(), output->GetOpMemDesc(), - bias_grad->GetOpMemDesc(), - outbackprop->GetOpMemDesc(), strides, padding_l, - padding_r, padding) : - convolution_backward_weights::desc(convolution_direct, - input->GetOpMemDesc(), output->GetOpMemDesc(), - outbackprop->GetOpMemDesc(), strides, padding_l, - padding_r, padding); - - auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc, - cpu_engine, - conv_fwd_pd); + auto bwd_desc = + (biasEnabled && (bias_grad != nullptr)) + ? convolution_backward_weights::desc( + convolution_direct, input->GetOpMemDesc(), + output->GetOpMemDesc(), bias_grad->GetOpMemDesc(), + outbackprop->GetOpMemDesc(), strides, padding_l, padding_r, + padding) + : convolution_backward_weights::desc( + convolution_direct, input->GetOpMemDesc(), + output->GetOpMemDesc(), outbackprop->GetOpMemDesc(), strides, + padding_l, padding_r, padding); + + auto bwd_pd = convolution_backward_weights::primitive_desc( + bwd_desc, cpu_engine, conv_fwd_pd); // Allocate output tensor. - AllocateOutputTensor(context, bwd_pd, bwd_output_dims, - bwd_output_format, output_tensor); + AllocateOutputTensor(context, bwd_pd, bwd_output_dims, bwd_output_format, + output_tensor); CHECK_NOTNULL(*output_tensor); // Set buffer handle using allocated output tensor. @@ -548,8 +549,8 @@ class MklConv2DCustomBackpropFilterOp : AllocateBiasGradTensor(context, bias_grad_shape, &bias_grad_tensor); memory::dims bias_grad_dims = {depth}; // Since Bias is 1D, we use format::x from MKLDNN to represent it. - auto bias_grad_md = memory::desc({bias_grad_dims}, MklDnnType(), - memory::format::x); + auto bias_grad_md = + memory::desc({bias_grad_dims}, MklDnnType(), memory::format::x); bias_grad->SetUsrMem(bias_grad_md, bias_grad_tensor); bias_grad->SetUsrMemDataHandle(bias_grad_tensor); } @@ -562,28 +563,29 @@ class MklConv2DCustomBackpropFilterOp : } // Allocate output tensor. - void AllocateOutputTensor(OpKernelContext* context, - const convolution_backward_weights::primitive_desc& conv_pd, - const memory::dims& output_dims_mkl_order, - memory::format output_tf_format, Tensor** output_tensor) { - CHECK_NOTNULL(output_tensor); - - // For BackpropFilter, we convert the output tensor back in Tensorflow - // layout. Because typically, BackpropFilter is the last operator in the - // graph that emit filter gradient that is provided to ApplyGradient - // method to update the filter. But it may be possible to eliminate this - // by forwarding filter in MKL layout if we support ApplyGradient method - // for MKL layout propagation. - MklDnnShape output_mkl_shape; - output_mkl_shape.SetMklTensor(false); - // output_dims_mkl_order is in OIHW format. - // Allocate shape of TF tensor in HWIO format. - TensorShape output_tf_shape({output_dims_mkl_order[MklDnnDims::Dim_H], - output_dims_mkl_order[MklDnnDims::Dim_W], - output_dims_mkl_order[MklDnnDims::Dim_I], - output_dims_mkl_order[MklDnnDims::Dim_O]}); - AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape, - output_mkl_shape); + void AllocateOutputTensor( + OpKernelContext* context, + const convolution_backward_weights::primitive_desc& conv_pd, + const memory::dims& output_dims_mkl_order, + memory::format output_tf_format, Tensor** output_tensor) { + CHECK_NOTNULL(output_tensor); + + // For BackpropFilter, we convert the output tensor back in Tensorflow + // layout. Because typically, BackpropFilter is the last operator in the + // graph that emit filter gradient that is provided to ApplyGradient + // method to update the filter. But it may be possible to eliminate this + // by forwarding filter in MKL layout if we support ApplyGradient method + // for MKL layout propagation. + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(false); + // output_dims_mkl_order is in OIHW format. + // Allocate shape of TF tensor in HWIO format. + TensorShape output_tf_shape({output_dims_mkl_order[MklDnnDims::Dim_H], + output_dims_mkl_order[MklDnnDims::Dim_W], + output_dims_mkl_order[MklDnnDims::Dim_I], + output_dims_mkl_order[MklDnnDims::Dim_O]}); + AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape, + output_mkl_shape); } // Allocate tensor for bias grad @@ -600,9 +602,9 @@ class MklConv2DCustomBackpropFilterOp : // Prepare and execute net - checks for input and output reorders. void PrepareAndExecutePrimitive( - const convolution_backward_weights::primitive_desc& conv_pd, - MklDnnData* input, MklDnnData* obp, - MklDnnData* output, MklDnnData* bias_grad = nullptr) { + const convolution_backward_weights::primitive_desc& conv_pd, + MklDnnData* input, MklDnnData* obp, MklDnnData* output, + MklDnnData* bias_grad = nullptr) { // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. std::vector net; @@ -612,15 +614,15 @@ class MklConv2DCustomBackpropFilterOp : // For BackpropFilter, we convert the output tensor back in Tensorflow // layout. bool output_reorder_required = output->PrepareReorderToUserMemIfReq( - conv_pd.diff_weights_primitive_desc()); + conv_pd.diff_weights_primitive_desc()); if (biasEnabled && (bias_grad != nullptr)) { - net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(), - obp->GetOpMem(), output->GetOpMem(), - bias_grad->GetOpMem())); + net.push_back(convolution_backward_weights( + conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem(), + bias_grad->GetOpMem())); } else { - net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(), - obp->GetOpMem(), output->GetOpMem())); + net.push_back(convolution_backward_weights( + conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem())); } if (output_reorder_required) { @@ -631,27 +633,29 @@ class MklConv2DCustomBackpropFilterOp : } }; -#define REGISTER_MKL_FILTER_KERNELS(T) \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConv2DCustomBackpropFilterOp);\ - REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilterWithBias") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConv2DCustomBackpropFilterOp); \ - REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklDummyOp); +#define REGISTER_MKL_FILTER_KERNELS(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklConv2DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConv2DCustomBackpropFilterOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklConv2DBackpropFilterWithBias") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConv2DCustomBackpropFilterOp); \ + REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklDummyOp); TF_CALL_float(REGISTER_MKL_FILTER_KERNELS); #undef REGISTER_MKL_FILTER_KERNELS -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index ef6db58d31f125487bd5beefb53710569b0584d8..eeed0095310280997ebb2ec3e848451df378c4fa 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -23,6 +23,8 @@ limitations under the License. #define EIGEN_USE_THREADS #include #include +#include "mkl_dnn.h" +#include "mkl_dnn_types.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -41,22 +43,20 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" -#include "mkl_dnn.h" -#include "mkl_dnn_types.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" -using mkldnn::stream; -using mkldnn::prop_kind; using mkldnn::convolution_backward_data; +using mkldnn::prop_kind; +using mkldnn::stream; #endif namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML template class MklConv2DCustomBackpropInputOp : public OpKernel { @@ -359,16 +359,15 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { #else template -class MklConv2DCustomBackpropInputOp : - public MklConv2DBackpropCommonOp { +class MklConv2DCustomBackpropInputOp + : public MklConv2DBackpropCommonOp { public: explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context) - : MklConv2DBackpropCommonOp(context) { } + : MklConv2DBackpropCommonOp(context) {} ~MklConv2DCustomBackpropInputOp() {} private: - const int kInputIndex_Filter = 1, - kInputIndex_InputSizes = 0, + const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0, kInputIndex_OutBackProp = 2; void ValidateMklShapes(const MklDnnShape& input_mkl_shape, const MklDnnShape& filter_mkl_shape, @@ -377,7 +376,7 @@ class MklConv2DCustomBackpropInputOp : // of the Tensor and never an actual tensor. So it will never be in MKL // layout. CHECK(!input_mkl_shape.IsMklTensor()) - << "Conv2DBackpropInput: input should not be in MKL Layout"; + << "Conv2DBackpropInput: input should not be in MKL Layout"; } size_t GetInputTensorIndexWithSizes() { return kInputIndex_InputSizes; } @@ -386,8 +385,10 @@ class MklConv2DCustomBackpropInputOp : const Tensor& input_tensor) { TensorShape input_tf_shape; CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true); - CHECK_EQ(TensorShapeUtils::MakeShape(input_tensor.vec(), - &input_tf_shape).ok(), true); + CHECK_EQ( + TensorShapeUtils::MakeShape(input_tensor.vec(), &input_tf_shape) + .ok(), + true); return input_tf_shape; } @@ -414,16 +415,13 @@ class MklConv2DCustomBackpropInputOp : return data_format; } - void CreatePrimitive(OpKernelContext* context, - const engine& cpu_engine, + void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine, const convolution_forward::primitive_desc& conv_fwd_pd, MklDnnData* input, MklDnnData* filter, MklDnnData* outbackprop, MklDnnData* output, - Tensor** output_tensor, - const memory::dims& strides, + Tensor** output_tensor, const memory::dims& strides, const memory::dims& padding_l, - const memory::dims& padding_r, - padding_kind padding, + const memory::dims& padding_r, padding_kind padding, const memory::dims& bwd_output_dims, memory::format bwd_output_format) { CHECK_NOTNULL(context); @@ -434,19 +432,16 @@ class MklConv2DCustomBackpropInputOp : CHECK_NOTNULL(output_tensor); // Create convolution backward data primitive. - auto bwd_desc = convolution_backward_data::desc(convolution_direct, - output->GetOpMemDesc(), filter->GetOpMemDesc(), - outbackprop->GetOpMemDesc(), strides, padding_l, - padding_r, padding); - - auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc, - cpu_engine, - conv_fwd_pd); + auto bwd_desc = convolution_backward_data::desc( + convolution_direct, output->GetOpMemDesc(), filter->GetOpMemDesc(), + outbackprop->GetOpMemDesc(), strides, padding_l, padding_r, padding); + auto bwd_pd = convolution_backward_data::primitive_desc( + bwd_desc, cpu_engine, conv_fwd_pd); // Allocate output tensor in TensorFlow and MKL layout. - AllocateOutputTensor(context, bwd_pd, bwd_output_dims, - bwd_output_format, output_tensor); + AllocateOutputTensor(context, bwd_pd, bwd_output_dims, bwd_output_format, + output_tensor); CHECK_NOTNULL(*output_tensor); // Set buffer handle using allocated output tensor. output->SetUsrMemDataHandle(*output_tensor); @@ -455,50 +450,50 @@ class MklConv2DCustomBackpropInputOp : } // Allocate output tensor. - void AllocateOutputTensor(OpKernelContext* context, - const convolution_backward_data::primitive_desc& conv_pd, - const memory::dims& output_dims_mkl_order, - memory::format output_tf_format, Tensor** output_tensor) { - CHECK_NOTNULL(output_tensor); - - // Output primitive descriptor for backward data is diff_src. - auto dst_pd = conv_pd.diff_src_primitive_desc(); - - // Allocate shape of Mkl tensor. - MklDnnShape output_mkl_shape; - output_mkl_shape.SetMklTensor(true); - output_mkl_shape.SetMklLayout(&dst_pd); - output_mkl_shape.SetElemType(MklDnnType()); - output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), - output_dims_mkl_order, output_tf_format); - - // Allocate shape of TF tensor. - TensorShape output_tf_shape; - output_tf_shape.AddDim(dst_pd.get_size() / sizeof(T)); - - AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape, - output_mkl_shape); + void AllocateOutputTensor( + OpKernelContext* context, + const convolution_backward_data::primitive_desc& conv_pd, + const memory::dims& output_dims_mkl_order, + memory::format output_tf_format, Tensor** output_tensor) { + CHECK_NOTNULL(output_tensor); + + // Output primitive descriptor for backward data is diff_src. + auto dst_pd = conv_pd.diff_src_primitive_desc(); + + // Allocate shape of Mkl tensor. + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SetMklLayout(&dst_pd); + output_mkl_shape.SetElemType(MklDnnType()); + output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), + output_dims_mkl_order, output_tf_format); + + // Allocate shape of TF tensor. + TensorShape output_tf_shape; + output_tf_shape.AddDim(dst_pd.get_size() / sizeof(T)); + + AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape, + output_mkl_shape); } // Prepare and execute net - checks for input and output reorders. void PrepareAndExecutePrimitive( - const convolution_backward_data::primitive_desc& conv_pd, - MklDnnData* filter, MklDnnData* obp, - MklDnnData* output) { + const convolution_backward_data::primitive_desc& conv_pd, + MklDnnData* filter, MklDnnData* obp, MklDnnData* output) { // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. std::vector net; filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net); obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net); - net.push_back(convolution_backward_data(conv_pd, obp->GetOpMem(), - filter->GetOpMem(), output->GetOpMem())); + net.push_back(convolution_backward_data( + conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem())); stream(stream::kind::eager).submit(net).wait(); } }; -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML #define REGISTER_MKL_CPU_KERNELS(T) \ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \ diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 0e77b45993c17815889005c4d313c5489ae2f14b..2953426d5824064952858124882126c154fe6725 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include #include +#include #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -41,15 +41,15 @@ limitations under the License. #include "tensorflow/core/util/mkl_util.h" +#ifndef INTEL_MKL_ML -#ifdef INTEL_MKL_DNN #include "mkldnn.hpp" -using mkldnn::stream; using mkldnn::prop_kind; +using mkldnn::stream; -using mkldnn::convolution_forward; using mkldnn::convolution_direct; +using mkldnn::convolution_forward; #else #include "mkl_dnn.h" #include "mkl_dnn_types.h" @@ -59,8 +59,8 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -// For now, MKL-ML is default. So making MKL-DNN not a default choice. -#ifndef INTEL_MKL_DNN +// MKL-DNN is now default. MKL-ML must be specified explicitly. +#ifdef INTEL_MKL_ML template class MklConv2DOp : public OpKernel { @@ -116,18 +116,19 @@ class MklConv2DOp : public OpKernel { filter.shape().DebugString())); for (int i = 0; i < 3; i++) { - OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i), - std::numeric_limits::max()), - errors::InvalidArgument("filter too large")); + OP_REQUIRES( + context, + FastBoundsCheck(filter.dim_size(i), std::numeric_limits::max()), + errors::InvalidArgument("filter too large")); } const int64 input_depth = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C') : GetTensorDim(input, data_format_, 'C'); - OP_REQUIRES( - context, input_depth == filter.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - input_depth, " vs ", filter.dim_size(2))); + OP_REQUIRES(context, input_depth == filter.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", input_depth, + " vs ", filter.dim_size(2))); // The last dimension for filter is out_depth. const int out_depth = static_cast(filter.dim_size(3)); @@ -136,9 +137,10 @@ class MklConv2DOp : public OpKernel { const int64 input_rows_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H') : GetTensorDim(input, data_format_, 'H'); - OP_REQUIRES(context, FastBoundsCheck(input_rows_raw, - std::numeric_limits::max()), - errors::InvalidArgument("Input rows too large")); + OP_REQUIRES( + context, + FastBoundsCheck(input_rows_raw, std::numeric_limits::max()), + errors::InvalidArgument("Input rows too large")); const int input_rows = static_cast(input_rows_raw); const int filter_rows = static_cast(filter.dim_size(0)); @@ -147,9 +149,10 @@ class MklConv2DOp : public OpKernel { const int64 input_cols_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W') : GetTensorDim(input, data_format_, 'W'); - OP_REQUIRES(context, FastBoundsCheck(input_cols_raw, - std::numeric_limits::max()), - errors::InvalidArgument("Input cols too large")); + OP_REQUIRES( + context, + FastBoundsCheck(input_cols_raw, std::numeric_limits::max()), + errors::InvalidArgument("Input cols too large")); const int input_cols = static_cast(input_cols_raw); const int filter_cols = static_cast(filter.dim_size(1)); @@ -157,9 +160,10 @@ class MklConv2DOp : public OpKernel { const int64 input_batch_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N') : GetTensorDim(input, data_format_, 'N'); - OP_REQUIRES(context, FastBoundsCheck(input_batch_raw, - std::numeric_limits::max()), - errors::InvalidArgument("batch is too large")); + OP_REQUIRES( + context, + FastBoundsCheck(input_batch_raw, std::numeric_limits::max()), + errors::InvalidArgument("batch is too large")); const int batch = static_cast(input_batch_raw); // For now we take the stride from the second and third dimensions only (we @@ -313,8 +317,7 @@ class MklConv2DOp : public OpKernel { // Temp tensor used to allocate tmp buffers Tensor mkl_tmp_input_buf_tensor, mkl_tmp_filter_buf_tensor, mkl_tmp_bias_buf_tensor; - mkl_context.MklPrepareConvolutionInputs(context, - &mkl_tmp_input_buf_tensor, + mkl_context.MklPrepareConvolutionInputs(context, &mkl_tmp_input_buf_tensor, &mkl_tmp_filter_buf_tensor, &mkl_tmp_bias_buf_tensor); @@ -398,8 +401,9 @@ class MklConv2DOp : public OpKernel { mkl_convert_input = !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input); if (mkl_convert_input) { - CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, - lt_input, mkl_lt_internal_input), E_SUCCESS); + CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input, + mkl_lt_internal_input), + E_SUCCESS); AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input, &mkl_buf_convert_input); CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input, @@ -517,8 +521,8 @@ class MklConv2DOp : public OpKernel { GetMklShape(context, kInputIndex_Src, &src_mkl_shape); GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape); OP_REQUIRES(context, filter_mkl_shape.IsMklTensor() == false, - errors::InvalidArgument("Filter should not be in " - "Mkl Layout")); + errors::InvalidArgument("Filter should not be in " + "Mkl Layout")); MklDnnData src(&cpu_engine); MklDnnData filter(&cpu_engine); @@ -531,11 +535,10 @@ class MklConv2DOp : public OpKernel { MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); auto src_tf_shape = GetTfShape(context, kInputIndex_Src); auto filter_tf_shape = GetTfShape(context, kInputIndex_Filter); - conv_utl.GetConvFwdSizesInMklOrder(src_tf_shape, filter_tf_shape, - &src_dims, &filter_dims, &strides, - &output_dims_tf_order, - &output_dims_mkl_order, &padding_l, - &padding_r); + conv_utl.GetConvFwdSizesInMklOrder( + src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides, + &output_dims_tf_order, &output_dims_mkl_order, &padding_l, + &padding_r); if (!context->status().ok()) return; // Check for corner case - if there is nothing to compute, return. @@ -543,21 +546,20 @@ class MklConv2DOp : public OpKernel { // Corner cases: output with 0 elements and 0 batch size. Tensor* output_tensor = nullptr; - if (output_tf_shape.num_elements() == 0 || - output_dims_tf_order[0] == 0) { + if (output_tf_shape.num_elements() == 0 || output_dims_tf_order[0] == 0) { // TODO(jbobba): Verify correctness here // Need semantics for Null MKL tensor MklDnnShape output_mkl_shape; output_mkl_shape.SetMklTensor(false); AllocateOutputSetMklShape(context, kOutputIndex_Dst, &output_tensor, - src_tf_shape, output_mkl_shape); + src_tf_shape, output_mkl_shape); // MklConv2D also outputs converted filter as 2nd output of Conv2D. filter_mkl_shape.SetMklTensor(false); Tensor* output_filter_tensor = nullptr; AllocateOutputSetMklShape(context, kOutputIndex_Filter, - &output_filter_tensor, - filter_tf_shape, filter_mkl_shape); + &output_filter_tensor, filter_tf_shape, + filter_mkl_shape); return; } @@ -570,14 +572,15 @@ class MklConv2DOp : public OpKernel { // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's // layout (NHWC or NCHW depending on data format). auto src_md = src_mkl_shape.IsMklTensor() - ? src_mkl_shape.GetMklLayout() - : memory::desc(src_dims, MklDnnType(), tf_fmt); + ? src_mkl_shape.GetMklLayout() + : memory::desc(src_dims, MklDnnType(), tf_fmt); src.SetUsrMem(src_md, &src_tensor); // Although filter shape (filter_dims) required is in MKL-DNN order, // the layout is Tensorflow's layout (HWIO). auto filter_md = filter_mkl_shape.IsMklTensor() // Should NEVER be true - ? filter_mkl_shape.GetMklLayout() - : memory::desc(filter_dims, MklDnnType(), memory::format::hwio); + ? filter_mkl_shape.GetMklLayout() + : memory::desc(filter_dims, MklDnnType(), + memory::format::hwio); filter.SetUsrMem(filter_md, &filter_tensor); // Set output shape (output_dims) required in MKL-DNN order. @@ -601,34 +604,34 @@ class MklConv2DOp : public OpKernel { bias.SetOpMemDesc(bias_size, memory::format::any); // Create convolution primitive with Bias. - auto conv_desc = convolution_forward::desc(prop_kind::forward, - convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(), - bias.GetOpMemDesc(), output.GetOpMemDesc(), strides, - padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); - - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, - cpu_engine); - AllocateOutputTensor(context, conv_prim_desc, - output_dims_mkl_order, tf_fmt, &output_tensor); + auto conv_desc = convolution_forward::desc( + prop_kind::forward, convolution_direct, src.GetOpMemDesc(), + filter.GetOpMemDesc(), bias.GetOpMemDesc(), output.GetOpMemDesc(), + strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); + + auto conv_prim_desc = + convolution_forward::primitive_desc(conv_desc, cpu_engine); + AllocateOutputTensor(context, conv_prim_desc, output_dims_mkl_order, + tf_fmt, &output_tensor); // Set data handle for output. output.SetUsrMemDataHandle(output_tensor); Tensor* filter_out_tensor = nullptr; AllocateFilterOutputTensor(context, conv_prim_desc, - TFShapeToMklDnnDims(filter_tf_shape), - &filter_out_tensor); + TFShapeToMklDnnDims(filter_tf_shape), + &filter_out_tensor); - PrepareAndExecuteNet(conv_prim_desc, &src, &filter, - &bias, &output, filter_out_tensor); + PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output, + filter_out_tensor); } else { // Create convolution primitive without Bias. - auto conv_desc = convolution_forward::desc(prop_kind::forward, - convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(), - output.GetOpMemDesc(), strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_)); + auto conv_desc = convolution_forward::desc( + prop_kind::forward, convolution_direct, src.GetOpMemDesc(), + filter.GetOpMemDesc(), output.GetOpMemDesc(), strides, padding_l, + padding_r, TFPaddingToMklDnnPadding(padding_)); - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, - cpu_engine); + auto conv_prim_desc = + convolution_forward::primitive_desc(conv_desc, cpu_engine); AllocateOutputTensor(context, conv_prim_desc, output_dims_mkl_order, tf_fmt, &output_tensor); // Set data handle for output. @@ -636,18 +639,18 @@ class MklConv2DOp : public OpKernel { Tensor* filter_out_tensor = nullptr; AllocateFilterOutputTensor(context, conv_prim_desc, - TFShapeToMklDnnDims(filter_tf_shape), - &filter_out_tensor); - PrepareAndExecuteNet(conv_prim_desc, &src, &filter, - nullptr, &output, filter_out_tensor); + TFShapeToMklDnnDims(filter_tf_shape), + &filter_out_tensor); + PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output, + filter_out_tensor); } - } catch (mkldnn::error &e) { + } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + std::string(e.message) + - ", in file " + std::string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Operation received an exception:", error_msg)); + ", message: " + std::string(e.message) + ", in file " + + std::string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } @@ -655,71 +658,67 @@ class MklConv2DOp : public OpKernel { std::vector strides_; Padding padding_; TensorFormat data_format_; - const int kInputIndex_Src = 0, - kInputIndex_Filter = 1, - kInputIndex_Bias = 2; + const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2; const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1; // Allocate output tensor. void AllocateOutputTensor( - OpKernelContext* context, - const convolution_forward::primitive_desc& conv_prim_desc, - const memory::dims& output_dims_mkl_order, - memory::format output_tf_format, Tensor** output_tensor) { - CHECK_NOTNULL(output_tensor); - auto dst_pd = conv_prim_desc.dst_primitive_desc(); - - // Allocate shape of Mkl tensor. - MklDnnShape output_mkl_shape; - output_mkl_shape.SetMklTensor(true); - output_mkl_shape.SetMklLayout(&dst_pd); - output_mkl_shape.SetElemType(MklDnnType()); - output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), - output_dims_mkl_order, output_tf_format); - - // Allocate shape of TF tensor. - TensorShape output_tf_shape; - output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T))); - - AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, - output_tf_shape, output_mkl_shape); + OpKernelContext* context, + const convolution_forward::primitive_desc& conv_prim_desc, + const memory::dims& output_dims_mkl_order, + memory::format output_tf_format, Tensor** output_tensor) { + CHECK_NOTNULL(output_tensor); + auto dst_pd = conv_prim_desc.dst_primitive_desc(); + + // Allocate shape of Mkl tensor. + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SetMklLayout(&dst_pd); + output_mkl_shape.SetElemType(MklDnnType()); + output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), + output_dims_mkl_order, output_tf_format); + + // Allocate shape of TF tensor. + TensorShape output_tf_shape; + output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T))); + + AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, + output_tf_shape, output_mkl_shape); } // Allocate output tensor. void AllocateFilterOutputTensor( - OpKernelContext* context, - const convolution_forward::primitive_desc& conv_prim_desc, - const memory::dims& filter_dims_tf_order, - Tensor** filter_tensor) { - CHECK_NOTNULL(filter_tensor); - auto filter_pd = conv_prim_desc.weights_primitive_desc(); - - // Allocate shape of Mkl tensor. - MklDnnShape filter_mkl_shape; - filter_mkl_shape.SetMklTensor(true); - filter_mkl_shape.SetMklLayout(&filter_pd); - filter_mkl_shape.SetElemType(MklDnnType()); - - // The format of the filter is actually OIhw8i8o, but TF doesn't support - // this format. Just use format::blocked for now because the layout - // is stored in the MKL data. - filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(), - filter_dims_tf_order, memory::format::blocked); - - // Allocate the data space for the filter to propagate as TF tensor. - TensorShape filter_tf_shape; - filter_tf_shape.AddDim((filter_pd.get_size() / sizeof(T))); - - AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor, - filter_tf_shape, filter_mkl_shape); + OpKernelContext* context, + const convolution_forward::primitive_desc& conv_prim_desc, + const memory::dims& filter_dims_tf_order, Tensor** filter_tensor) { + CHECK_NOTNULL(filter_tensor); + auto filter_pd = conv_prim_desc.weights_primitive_desc(); + + // Allocate shape of Mkl tensor. + MklDnnShape filter_mkl_shape; + filter_mkl_shape.SetMklTensor(true); + filter_mkl_shape.SetMklLayout(&filter_pd); + filter_mkl_shape.SetElemType(MklDnnType()); + + // The format of the filter is actually OIhw8i8o, but TF doesn't support + // this format. Just use format::blocked for now because the layout + // is stored in the MKL data. + filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(), + filter_dims_tf_order, memory::format::blocked); + + // Allocate the data space for the filter to propagate as TF tensor. + TensorShape filter_tf_shape; + filter_tf_shape.AddDim((filter_pd.get_size() / sizeof(T))); + + AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor, + filter_tf_shape, filter_mkl_shape); } // Prepare and execute net - checks for input and output reorders. void PrepareAndExecuteNet( - const convolution_forward::primitive_desc& conv_prim_desc, - MklDnnData* src, MklDnnData* filter, - MklDnnData* bias, MklDnnData* output, - Tensor* filter_out_tensor) { + const convolution_forward::primitive_desc& conv_prim_desc, + MklDnnData* src, MklDnnData* filter, MklDnnData* bias, + MklDnnData* output, Tensor* filter_out_tensor) { CHECK_NOTNULL(filter_out_tensor); // Create reorders between user layout and MKL layout if it is needed and @@ -731,18 +730,20 @@ class MklConv2DOp : public OpKernel { // rather than re-order to a temp buffer, reorder directly to the // filter output tensor filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(), - filter->GetTensorBuffer(filter_out_tensor), &net); + filter->GetTensorBuffer(filter_out_tensor), + &net); // Create convolution primitive and add it to net. if (bias) { CHECK_EQ(biasEnabled, true); net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(), - filter->GetOpMem(), bias->GetOpMem(), - output->GetOpMem())); + filter->GetOpMem(), bias->GetOpMem(), + output->GetOpMem())); } else { CHECK_EQ(biasEnabled, false); net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(), - filter->GetOpMem(), output->GetOpMem())); + filter->GetOpMem(), + output->GetOpMem())); } stream(stream::kind::eager).submit(net).wait(); diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index c6456bd5c330d8a5672a99dc7f649f3bab4d3519..9dd88221a84671e1f69df13cca1b62b2ce65bb4e 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ #define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ -#include #include #include +#include #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -27,8 +27,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -40,19 +40,19 @@ limitations under the License. #include "tensorflow/core/util/mkl_util.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" -using mkldnn::stream; using mkldnn::prop_kind; +using mkldnn::stream; -using mkldnn::convolution_forward; using mkldnn::convolution_direct; +using mkldnn::convolution_forward; #endif namespace tensorflow { -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML class MklDnnConvUtil { protected: @@ -63,13 +63,13 @@ class MklDnnConvUtil { public: MklDnnConvUtil(OpKernelContext* context, const std::vector& strides, - Padding pad, TensorFormat fm) : context_(context), - strides_(strides), padding_(pad), data_format_(fm) {} + Padding pad, TensorFormat fm) + : context_(context), strides_(strides), padding_(pad), data_format_(fm) {} virtual ~MklDnnConvUtil() { context_ = nullptr; } // Calculate Convolution strides - virtual inline void GetStridesInMklOrder(memory::dims *strides) { + virtual inline void GetStridesInMklOrder(memory::dims* strides) { // For now we take the stride from the second and third dimensions only // (we do not support striding on the batch or depth dimension). CHECK_NOTNULL(strides); @@ -82,14 +82,14 @@ class MklDnnConvUtil { // requires input in NCHW format. Function does not return anything. // But errors arising from sanity checks are returned in context's // status. - virtual inline void - GetInputSizeInMklOrder(const TensorShape& input_shape, - memory::dims *input_dims) { - #define CHECK_BOUNDS(val, err_msg) do { \ - OP_REQUIRES(context_, FastBoundsCheck(val, \ - std::numeric_limits::max()), \ - errors::InvalidArgument(err_msg)); \ - }while(0) + virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape, + memory::dims* input_dims) { +#define CHECK_BOUNDS(val, err_msg) \ + do { \ + OP_REQUIRES(context_, \ + FastBoundsCheck(val, std::numeric_limits::max()), \ + errors::InvalidArgument(err_msg)); \ + } while (0) CHECK_NOTNULL(input_dims); @@ -112,7 +112,7 @@ class MklDnnConvUtil { CHECK_BOUNDS(input_batch_raw, "Input batch too large"); int input_batch = static_cast(input_batch_raw); - #undef CHECK_BOUNDS +#undef CHECK_BOUNDS // MKL-DNN always requires input in NCHW format. std::vector mkldnn_sizes(4, -1); @@ -138,10 +138,9 @@ class MklDnnConvUtil { // forward gets actual tensor as input). // // TODO(nhasabni): Add similar function for input and filter in MklShape. - virtual inline void - GetFilterSizeInMklOrder(const TensorShape& input_shape, - const TensorShape& filter_shape, - memory::dims *filter_dims) { + virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape, + const TensorShape& filter_shape, + memory::dims* filter_dims) { CHECK_NOTNULL(filter_dims); OP_REQUIRES(context_, filter_shape.dims() == 4, @@ -149,17 +148,18 @@ class MklDnnConvUtil { filter_shape.DebugString())); for (int i = 0; i < 3; i++) { - OP_REQUIRES(context_, FastBoundsCheck(filter_shape.dim_size(i), - std::numeric_limits::max()), - errors::InvalidArgument("filter too large")); + OP_REQUIRES(context_, + FastBoundsCheck(filter_shape.dim_size(i), + std::numeric_limits::max()), + errors::InvalidArgument("filter too large")); } int input_depth = GetTensorDim(input_shape, data_format_, 'C'); - OP_REQUIRES( - context_, input_depth == filter_shape.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - input_depth, " vs ", filter_shape.dim_size(2))); + OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", input_depth, + " vs ", filter_shape.dim_size(2))); // TF filter is always in (rows, cols, in_depth, out_depth) order. int filter_rows = static_cast(filter_shape.dim_size(0)); @@ -182,25 +182,24 @@ class MklDnnConvUtil { // requires filter in OIHW format. Function does not return anything. // But errors arising from sanity checks are returned in context's // status. - virtual inline void - GetFilterSizeInMklOrder(size_t src_index, size_t filter_index, - memory::dims *filter_dims) { + virtual inline void GetFilterSizeInMklOrder(size_t src_index, + size_t filter_index, + memory::dims* filter_dims) { CHECK_NOTNULL(filter_dims); GetFilterSizeInMklOrder(GetTfShape(context_, src_index), - GetTfShape(context_, filter_index), - filter_dims); + GetTfShape(context_, filter_index), filter_dims); } // Calculate Bias size for 2D Convolution. Function does not return // anything, but sets error in context status. - virtual inline void - GetBiasSizeInMklOrder(size_t bias_index, memory::dims *bias_dims) { + virtual inline void GetBiasSizeInMklOrder(size_t bias_index, + memory::dims* bias_dims) { const Tensor& bias = MklGetInput(context_, bias_index); OP_REQUIRES(context_, bias.dims() == 1, errors::InvalidArgument("bias must be 1-dimensional: ", bias.shape().DebugString())); - *bias_dims = { static_cast(bias.dim_size(0)) }; + *bias_dims = {static_cast(bias.dim_size(0))}; } // Function to calculate output and padding size for 2D convolution. @@ -212,13 +211,11 @@ class MklDnnConvUtil { // status is returned via context status. // // TODO(nhasabni): Add similar function for input and filter in MklShape. - virtual inline void - GetOutputAndPadSizeInMklOrder(const TensorShape& input_shape, - const TensorShape& filter_shape, - const memory::dims& strides, - memory::dims *output_dims_tf_order, - memory::dims *output_dims_mkl_order, - memory::dims *pad_l, memory::dims *pad_r) { + virtual inline void GetOutputAndPadSizeInMklOrder( + const TensorShape& input_shape, const TensorShape& filter_shape, + const memory::dims& strides, memory::dims* output_dims_tf_order, + memory::dims* output_dims_mkl_order, memory::dims* pad_l, + memory::dims* pad_r) { CHECK_NOTNULL(output_dims_tf_order); CHECK_NOTNULL(output_dims_mkl_order); CHECK_NOTNULL(pad_l); @@ -244,16 +241,16 @@ class MklDnnConvUtil { int64 out_rows = 0, out_cols = 0; int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right; - OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerbose(input_rows, filter_rows, stride_rows, - padding_, &out_rows, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerbose(input_cols, filter_cols, stride_cols, - padding_, &out_cols, &pad_left, &pad_right)); + OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( + input_rows, filter_rows, stride_rows, padding_, + &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( + input_cols, filter_cols, stride_cols, padding_, + &out_cols, &pad_left, &pad_right)); // Tensorflow output is in data_format order. (NHWC or NCHW) - TensorShape out_shape = ShapeFromFormat(data_format_, out_batch, - out_rows, out_cols, out_depth); + TensorShape out_shape = + ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth); *output_dims_tf_order = TFShapeToMklDnnDims(out_shape); // MKL-DNN always needs output in NCHW format. @@ -273,12 +270,10 @@ class MklDnnConvUtil { // See comment on GetConvOutputAndPadSizeInMklOrder for parameters. // // Function does not return anything, but sets error in context status. - inline void - GetOutputAndPadSizeInMklOrder(size_t src_index, size_t filter_index, - const memory::dims& strides, - memory::dims *output_dims_tf_order, - memory::dims *output_dims_mkl_order, - memory::dims *pad_l, memory::dims *pad_r) { + inline void GetOutputAndPadSizeInMklOrder( + size_t src_index, size_t filter_index, const memory::dims& strides, + memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order, + memory::dims* pad_l, memory::dims* pad_r) { CHECK_NOTNULL(output_dims_tf_order); CHECK_NOTNULL(output_dims_mkl_order); CHECK_NOTNULL(pad_l); @@ -289,11 +284,11 @@ class MklDnnConvUtil { OP_REQUIRES(context_, input_tf_shape.dims() == 4, errors::InvalidArgument("input must be 4-dimensional", - input_tf_shape.DebugString())); + input_tf_shape.DebugString())); - GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, - strides, output_dims_tf_order, - output_dims_mkl_order, pad_l, pad_r); + GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides, + output_dims_tf_order, output_dims_mkl_order, + pad_l, pad_r); } // Wrapper function to calculate input, filter, and output sizes of @@ -302,15 +297,12 @@ class MklDnnConvUtil { // also calculates strides and paddings for 2D Convolution. // // Function does not return anything, but sets error in context status. - inline void GetConvFwdSizesInMklOrder(const TensorShape& input_shape, - const TensorShape& filter_shape, - memory::dims *input_dims, - memory::dims *filter_dims, - memory::dims *strides, - memory::dims *output_dims_tf_order, - memory::dims *output_dims_mkl_order, - memory::dims *pad_l, - memory::dims *pad_r) { + inline void GetConvFwdSizesInMklOrder( + const TensorShape& input_shape, const TensorShape& filter_shape, + memory::dims* input_dims, memory::dims* filter_dims, + memory::dims* strides, memory::dims* output_dims_tf_order, + memory::dims* output_dims_mkl_order, memory::dims* pad_l, + memory::dims* pad_r) { CHECK_NOTNULL(input_dims); CHECK_NOTNULL(filter_dims); CHECK_NOTNULL(strides); @@ -325,8 +317,7 @@ class MklDnnConvUtil { if (!context_->status().ok()) return; GetStridesInMklOrder(strides); GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides, - output_dims_tf_order, - output_dims_mkl_order, + output_dims_tf_order, output_dims_mkl_order, pad_l, pad_r); if (!context_->status().ok()) return; } @@ -337,7 +328,7 @@ class MklDnnConvUtil { ///////////////////////////////////////////////////////////////////// template -class MklConv2DBackpropCommonOp : public OpKernel { +class MklConv2DBackpropCommonOp : public OpKernel { public: ~MklConv2DBackpropCommonOp() {} explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context) @@ -397,12 +388,11 @@ class MklConv2DBackpropCommonOp : public OpKernel { outbprop_tf_shape.num_elements() == 0) { MklDnnShape output_mkl_shape; output_mkl_shape.SetMklTensor(false); - TensorShape output_tf_shape = GetOutputTfShape(input_tf_shape, - filter_tf_shape, - outbprop_tf_shape); + TensorShape output_tf_shape = GetOutputTfShape( + input_tf_shape, filter_tf_shape, outbprop_tf_shape); const int kOutputIdx = 0; AllocateOutputSetMklShape(context, kOutputIdx, &output_tensor, - output_tf_shape, output_mkl_shape); + output_tf_shape, output_mkl_shape); CHECK_NOTNULL(output_tensor); // if output tensor has more than 0 elements, we need to 0 them out. @@ -421,12 +411,10 @@ class MklConv2DBackpropCommonOp : public OpKernel { // Get forward convolution parameters. MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); - conv_utl.GetConvFwdSizesInMklOrder(input_tf_shape, filter_tf_shape, - &fwd_input_dims, &fwd_filter_dims, - &strides, - &fwd_output_dims_tf_order, - &fwd_output_dims, - &padding_l, &padding_r); + conv_utl.GetConvFwdSizesInMklOrder( + input_tf_shape, filter_tf_shape, &fwd_input_dims, &fwd_filter_dims, + &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l, + &padding_r); if (!context->status().ok()) return; // Create Convolution forward descriptor since Convolution backward @@ -437,20 +425,22 @@ class MklConv2DBackpropCommonOp : public OpKernel { // construct input TF layout. For TF layout, although input shape // required is in MKL-DNN order, the layout is Tensorflow's layout // (NHWC or NCHW depending on data format). - auto fwd_input_md = input_mkl_shape.IsMklTensor() ? - input_mkl_shape.GetMklLayout() : - memory::desc(fwd_input_dims, MklDnnType(), tf_fmt); + auto fwd_input_md = + input_mkl_shape.IsMklTensor() + ? input_mkl_shape.GetMklLayout() + : memory::desc(fwd_input_dims, MklDnnType(), tf_fmt); // If filter is in MKL layout, then simply grab filter layout; otherwise // construct filter in TF layout. For TF layout, filter is in HWIO format. - auto fwd_filter_md = filter_mkl_shape.IsMklTensor() ? - filter_mkl_shape.GetMklLayout() : - memory::desc(fwd_filter_dims, MklDnnType(), - memory::format::hwio); + auto fwd_filter_md = filter_mkl_shape.IsMklTensor() + ? filter_mkl_shape.GetMklLayout() + : memory::desc(fwd_filter_dims, MklDnnType(), + memory::format::hwio); // Tensorflow Output of Conv2D is in data_format order. auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType(), tf_fmt); - auto fwd_desc = convolution_forward::desc(prop_kind::forward, - convolution_direct, fwd_input_md, fwd_filter_md, fwd_out_md, - strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); + auto fwd_desc = convolution_forward::desc( + prop_kind::forward, convolution_direct, fwd_input_md, fwd_filter_md, + fwd_out_md, strides, padding_l, padding_r, + TFPaddingToMklDnnPadding(padding_)); auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine); // Create memory for user data. Describe how the inputs and outputs of @@ -495,17 +485,16 @@ class MklConv2DBackpropCommonOp : public OpKernel { // Operator-specific call to create and execute primitive. CreatePrimitive(context, cpu_engine, fwd_pd, &input, &filter, - &outbackprop, &output, &output_tensor, - strides, padding_l, padding_r, - TFPaddingToMklDnnPadding(padding_), + &outbackprop, &output, &output_tensor, strides, padding_l, + padding_r, TFPaddingToMklDnnPadding(padding_), bwd_output_dims, bwd_output_format); - } catch (mkldnn::error &e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:", - error_msg)); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } @@ -523,11 +512,11 @@ class MklConv2DBackpropCommonOp : public OpKernel { /// Get TensorFlow shape of input tensor. virtual TensorShape MakeInputTfShape(OpKernelContext* context, - const Tensor& input_tensor) = 0; + const Tensor& input_tensor) = 0; /// Get TensorFlow shape of filter tensor. virtual TensorShape MakeFilterTfShape(OpKernelContext* context, - const Tensor& filter_tensor) = 0; + const Tensor& filter_tensor) = 0; /// Get the TensorFlow shape of output tensor. virtual TensorShape GetOutputTfShape(const TensorShape& input_shape, @@ -536,9 +525,9 @@ class MklConv2DBackpropCommonOp : public OpKernel { /// Get shape of output in MKL-DNN order. Computes shape of output from /// input shape (fwd_input_dims) and filter shape (fwd_filter_dims). - virtual - const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims, - const memory::dims& fwd_filter_dims) = 0; + virtual const memory::dims& GetOutputDims( + const memory::dims& fwd_input_dims, + const memory::dims& fwd_filter_dims) = 0; /// Get data_format of output in MKL-DNN order. If output data format is /// same as input data format, then it simply returns value of data_format @@ -546,24 +535,25 @@ class MklConv2DBackpropCommonOp : public OpKernel { virtual memory::format GetOutputFormat(const memory::format data_format) = 0; /// Create and execute the primitive storing output in the output_tensor. - virtual void CreatePrimitive(OpKernelContext* context, - const engine& cpu_engine, - const convolution_forward::primitive_desc& conv_fwd_pd, - MklDnnData* input, MklDnnData* filter, MklDnnData* outbackprop, - MklDnnData* output, Tensor** output_tensor, const memory::dims& strides, - const memory::dims& padding_l, const memory::dims& padding_r, - padding_kind padding, const memory::dims& bwd_output_dims, - memory::format bwd_output_format) = 0; + virtual void CreatePrimitive( + OpKernelContext* context, const engine& cpu_engine, + const convolution_forward::primitive_desc& conv_fwd_pd, + MklDnnData* input, MklDnnData* filter, MklDnnData* outbackprop, + MklDnnData* output, Tensor** output_tensor, + const memory::dims& strides, const memory::dims& padding_l, + const memory::dims& padding_r, padding_kind padding, + const memory::dims& bwd_output_dims, + memory::format bwd_output_format) = 0; // Get the data_format {NCHW, NHWC} - TensorFormat GetTFDataFormat () { return data_format_; } + TensorFormat GetTFDataFormat() { return data_format_; } private: std::vector strides_; Padding padding_; TensorFormat data_format_; }; -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML ///////////////////////////////////////////////////////////////////// /// Dummy Mkl op that is just used for operators that are intermediate @@ -575,12 +565,12 @@ class MklDummyOp : public OpKernel { public: ~MklDummyOp() {} - explicit MklDummyOp(OpKernelConstruction* context) : - OpKernel(context) {} + explicit MklDummyOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - TF_CHECK_OK(errors::Unimplemented("This is a dummy op." - "It should not have been invoked.")); + TF_CHECK_OK( + errors::Unimplemented("This is a dummy op." + "It should not have been invoked.")); } }; diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc index c065724e0dbbe091d253eb2315c9a5f3c041d695..58f0c30f32b0eebd7ceff856b2e3bd881b28121c 100644 --- a/tensorflow/core/kernels/mkl_cwise_ops_common.cc +++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0(the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 8340a91d059de16dfbabf53067f24fbca1bc1385..8313224d7fe3e2d307d3642ced5b277b95c85cdb 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -25,15 +25,15 @@ limitations under the License. #include "mkl_dnn_types.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" -using mkldnn::stream; +using mkldnn::batch_normalization_backward; +using mkldnn::batch_normalization_forward; using mkldnn::prop_kind; -using mkldnn::use_scale_shift; +using mkldnn::stream; using mkldnn::use_global_stats; -using mkldnn::batch_normalization_forward; -using mkldnn::batch_normalization_backward; +using mkldnn::use_scale_shift; #endif // TODO(inteltf) Address comments from PR 8968. @@ -41,7 +41,7 @@ using mkldnn::batch_normalization_backward; namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML template class MklFusedBatchNormOp : public OpKernel { @@ -601,7 +601,7 @@ class MklFusedBatchNormGradOp : public OpKernel { mkl_res_batchnorm_bwd[dnnResourceSrc] = (mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input; - bool mkl_convert_out_backprop; + bool mkl_convert_out_backprop; dnnPrimitive_t mkl_prim_convert_out_backprop = nullptr; dnnLayout_t mkl_lt_internal_out_backprop = nullptr; void* mkl_buf_converted_out_backprop = nullptr; @@ -683,7 +683,7 @@ class MklFusedBatchNormGradOp : public OpKernel { }; #endif -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML template class MklFusedBatchNormOp : public OpKernel { @@ -709,12 +709,11 @@ class MklFusedBatchNormOp : public OpKernel { const size_t kMeanIndex = 3; // index of est_mean tensor const size_t kVarianceIndex = 4; // index of est_variance tensor - const Tensor& src_tensor = MklGetInput(context, kSrcIndex); - const Tensor& scale_tensor = MklGetInput(context, kScaleIndex); - const Tensor& shift_tensor = MklGetInput(context, kShiftIndex); - const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex); - const Tensor& est_variance_tensor = MklGetInput(context, - kVarianceIndex); + const Tensor& src_tensor = MklGetInput(context, kSrcIndex); + const Tensor& scale_tensor = MklGetInput(context, kScaleIndex); + const Tensor& shift_tensor = MklGetInput(context, kShiftIndex); + const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex); + const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex); TensorShape tf_shape_src; MklDnnShape dnn_shape_src; @@ -723,37 +722,34 @@ class MklFusedBatchNormOp : public OpKernel { if (dnn_shape_src.IsMklTensor()) { tf_shape_src = dnn_shape_src.GetTfShape(); OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4, - errors::InvalidArgument( - "input must be 4-dimensional", - src_tensor.shape().DebugString())); + errors::InvalidArgument("input must be 4-dimensional", + src_tensor.shape().DebugString())); } else { tf_shape_src = src_tensor.shape(); OP_REQUIRES(context, src_tensor.dims() == 4, - errors::InvalidArgument( - "input must be 4-dimensional", - src_tensor.shape().DebugString())); + errors::InvalidArgument("input must be 4-dimensional", + src_tensor.shape().DebugString())); } OP_REQUIRES(context, scale_tensor.dims() == 1, - errors::InvalidArgument( - "scale must be 1-dimensional", - scale_tensor.shape().DebugString())); + errors::InvalidArgument("scale must be 1-dimensional", + scale_tensor.shape().DebugString())); OP_REQUIRES(context, shift_tensor.dims() == 1, errors::InvalidArgument("offset must be 1-dimensional", - shift_tensor.shape().DebugString())); - OP_REQUIRES(context, est_mean_tensor.dims() == 1, - errors::InvalidArgument( - "estimated_mean must be 1-dimensional", - est_mean_tensor.shape().DebugString())); - OP_REQUIRES(context, est_variance_tensor.dims() == 1, - errors::InvalidArgument( - "estimated_variance must be 1-dimensional", - est_variance_tensor.shape().DebugString())); + shift_tensor.shape().DebugString())); + OP_REQUIRES( + context, est_mean_tensor.dims() == 1, + errors::InvalidArgument("estimated_mean must be 1-dimensional", + est_mean_tensor.shape().DebugString())); + OP_REQUIRES( + context, est_variance_tensor.dims() == 1, + errors::InvalidArgument("estimated_variance must be 1-dimensional", + est_variance_tensor.shape().DebugString())); if (is_training_) { - OP_REQUIRES(context, est_mean_tensor.dim_size(0) == 0, - errors::InvalidArgument( - "estimated_mean must be empty for training", - est_mean_tensor.shape().DebugString())); + OP_REQUIRES( + context, est_mean_tensor.dim_size(0) == 0, + errors::InvalidArgument("estimated_mean must be empty for training", + est_mean_tensor.shape().DebugString())); OP_REQUIRES(context, est_variance_tensor.dim_size(0) == 0, errors::InvalidArgument( "estimated_variance must be empty for training", @@ -763,11 +759,9 @@ class MklFusedBatchNormOp : public OpKernel { // special case: input with 0 element and 0 batch size Tensor* dst_tensor = nullptr; if (tf_shape_src.num_elements() == 0) { - HandleEmptyInput(context, - tf_shape_src, - scale_tensor.shape(), - &dst_tensor); - return; + HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), + &dst_tensor); + return; } if (dnn_shape_src.IsMklTensor()) @@ -783,11 +777,8 @@ class MklFusedBatchNormOp : public OpKernel { Tensor* batch_variance_tensor = nullptr; Tensor* saved_mean_tensor = nullptr; Tensor* saved_variance_tensor = nullptr; - AllocateTFOutputs(context, - scale_tensor.shape(), - &batch_mean_tensor, - &batch_variance_tensor, - &saved_mean_tensor, + AllocateTFOutputs(context, scale_tensor.shape(), &batch_mean_tensor, + &batch_variance_tensor, &saved_mean_tensor, &saved_variance_tensor); if (is_training_) @@ -815,69 +806,63 @@ class MklFusedBatchNormOp : public OpKernel { src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(), tensor_format_); } else { - src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), - tensor_format_); + src_dims = + TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); } auto src_md = dnn_shape_src.IsMklTensor() - ? dnn_shape_src.GetMklLayout() - : memory::desc(src_dims, MklDnnType(), format_m); + ? dnn_shape_src.GetMklLayout() + : memory::desc(src_dims, MklDnnType(), format_m); src.SetUsrMem(src_md, &src_tensor); // set weights primitive // MKL-DNN packs scale & shift as "weights": // ...... - auto weights_desc = memory::desc({2, depth_}, - MklDnnType(), - memory::format::nc); + auto weights_desc = + memory::desc({2, depth_}, MklDnnType(), memory::format::nc); auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine); auto weights_m = memory(weights_pd); - T* weights_data = reinterpret_cast( - weights_m.get_data_handle()); - T* scale_tf = reinterpret_cast( - const_cast(scale_tensor.flat().data())); - T* shift_tf = reinterpret_cast( - const_cast(shift_tensor.flat().data())); - - for (int k=0; k < depth_; k++) { + T* weights_data = reinterpret_cast(weights_m.get_data_handle()); + T* scale_tf = + reinterpret_cast(const_cast(scale_tensor.flat().data())); + T* shift_tf = + reinterpret_cast(const_cast(shift_tensor.flat().data())); + + for (int k = 0; k < depth_; k++) { weights_data[k] = scale_tf[k]; weights_data[k + depth_] = shift_tf[k]; } // set mean primitive - auto mean_desc = memory::desc({1, depth_}, - MklDnnType(), - memory::format::nc); + auto mean_desc = + memory::desc({1, depth_}, MklDnnType(), memory::format::nc); auto mean_pd = memory::primitive_desc(mean_desc, cpu_engine); - char* saved_mean_data_tf = reinterpret_cast - (saved_mean_tensor->flat().data()); - std::memcpy(saved_mean_data_tf, - reinterpret_cast(mean_values_), - depth_*sizeof(T)); - auto mean_m = memory(mean_pd, - reinterpret_cast(saved_mean_data_tf)); + char* saved_mean_data_tf = + reinterpret_cast(saved_mean_tensor->flat().data()); + std::memcpy(saved_mean_data_tf, reinterpret_cast(mean_values_), + depth_ * sizeof(T)); + auto mean_m = + memory(mean_pd, reinterpret_cast(saved_mean_data_tf)); // set variance primitive - auto variance_desc = memory::desc({1, depth_}, - MklDnnType(), - memory::format::nc); + auto variance_desc = + memory::desc({1, depth_}, MklDnnType(), memory::format::nc); auto variance_pd = memory::primitive_desc(variance_desc, cpu_engine); - char* saved_variance_data_tf = reinterpret_cast - (saved_variance_tensor->flat().data()); + char* saved_variance_data_tf = + reinterpret_cast(saved_variance_tensor->flat().data()); std::memcpy(saved_variance_data_tf, reinterpret_cast(variance_values_), - depth_*sizeof(T)); + depth_ * sizeof(T)); auto variance_m = memory(variance_pd, saved_variance_data_tf); - prop_kind pk = (is_training_) ? - prop_kind::forward_training : - prop_kind::forward_scoring; + prop_kind pk = (is_training_) ? prop_kind::forward_training + : prop_kind::forward_scoring; auto bnrm_fwd_desc = batch_normalization_forward::desc( - pk, src.GetUsrMemDesc(), epsilon_, - is_training_ ? use_scale_shift : - (use_scale_shift | use_global_stats)); + pk, src.GetUsrMemDesc(), epsilon_, + is_training_ ? use_scale_shift + : (use_scale_shift | use_global_stats)); auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc( - bnrm_fwd_desc, cpu_engine); + bnrm_fwd_desc, cpu_engine); // allocate dst tensor MklDnnShape dnn_shape_dst; @@ -887,47 +872,39 @@ class MklFusedBatchNormOp : public OpKernel { auto dst_pd = bnrm_fwd_pd.dst_primitive_desc(); dnn_shape_dst.SetMklLayout(&dst_pd); dnn_shape_dst.SetElemType(MklDnnType()); - dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), - src_dims, format_m); - tf_shape_dst.AddDim(dst_pd.get_size()/sizeof(T)); + dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), src_dims, + format_m); + tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); } else { dnn_shape_dst.SetMklTensor(false); tf_shape_dst = src_tensor.shape(); } - AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, - tf_shape_dst, dnn_shape_dst); + AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, + dnn_shape_dst); // Output of batchnorm has same shape as input. dst.SetUsrMem(src_md, dst_tensor); primitive bnrm_fwd_op; if (is_training_) { - bnrm_fwd_op = batch_normalization_forward( - bnrm_fwd_pd, - src.GetOpMem(), - weights_m, - dst.GetOpMem(), - mean_m, - variance_m); + bnrm_fwd_op = + batch_normalization_forward(bnrm_fwd_pd, src.GetOpMem(), weights_m, + dst.GetOpMem(), mean_m, variance_m); } else { bnrm_fwd_op = batch_normalization_forward( - bnrm_fwd_pd, - src.GetOpMem(), - mean_m, - variance_m, - (const primitive::at) weights_m, - dst.GetOpMem()); + bnrm_fwd_pd, src.GetOpMem(), mean_m, variance_m, + (const primitive::at)weights_m, dst.GetOpMem()); } std::vector net; net.push_back(bnrm_fwd_op); stream(stream::kind::eager).submit(net).wait(); // copy batch_mean data - T* batch_mean_data_tf = reinterpret_cast( - batch_mean_tensor->flat().data()); + T* batch_mean_data_tf = + reinterpret_cast(batch_mean_tensor->flat().data()); std::memcpy(reinterpret_cast(batch_mean_data_tf), reinterpret_cast(mean_m.get_data_handle()), - depth_*sizeof(T)); + depth_ * sizeof(T)); // copy batch_variance data with Bessel's correction // if training mode is on @@ -937,18 +914,17 @@ class MklFusedBatchNormOp : public OpKernel { size_t adjust_size = orig_size - 1; adjust_factor = (static_cast(orig_size)) / adjust_size; } - for (int k=0; k < depth_; k++) + for (int k = 0; k < depth_; k++) batch_variance_tensor->flat().data()[k] = - (reinterpret_cast(variance_m.get_data_handle()))[k] - * adjust_factor; - } catch (mkldnn::error &e) { + (reinterpret_cast(variance_m.get_data_handle()))[k] * + adjust_factor; + } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Operation received an exception:", - error_msg)); + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } @@ -958,7 +934,7 @@ class MklFusedBatchNormOp : public OpKernel { bool is_training_; T* mean_values_; T* variance_values_; - size_t depth_; // batch normalization is done for per channel. + size_t depth_; // batch normalization is done for per channel. void ExtractParams(OpKernelContext* context) { const Tensor& input = MklGetInput(context, 0); @@ -966,23 +942,20 @@ class MklFusedBatchNormOp : public OpKernel { } void SetMeanVariance(const Tensor& mean, const Tensor& variance) { - mean_values_ = reinterpret_cast( - const_cast(mean.flat().data())); - variance_values_ = reinterpret_cast( - const_cast(variance.flat().data())); + mean_values_ = reinterpret_cast(const_cast(mean.flat().data())); + variance_values_ = + reinterpret_cast(const_cast(variance.flat().data())); } - void HandleEmptyInput(OpKernelContext* context, - TensorShape tf_shape_src, - TensorShape tf_shape_scale, - Tensor** dst_tensor) { + void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src, + TensorShape tf_shape_scale, Tensor** dst_tensor) { CHECK_NOTNULL(dst_tensor); const size_t kDstIndex = 0; MklDnnShape dnn_shape_dst; dnn_shape_dst.SetMklTensor(false); - AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, - tf_shape_src, dnn_shape_dst); + AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src, + dnn_shape_dst); CHECK_NOTNULL(*dst_tensor); memset(const_cast((*dst_tensor)->tensor_data().data()), 0, (*dst_tensor)->tensor_data().size()); @@ -991,15 +964,12 @@ class MklFusedBatchNormOp : public OpKernel { Tensor* batch_variance_tensor = nullptr; Tensor* saved_mean_tensor = nullptr; Tensor* saved_variance_tensor = nullptr; - AllocateTFOutputs(context, tf_shape_scale, - &batch_mean_tensor, - &batch_variance_tensor, - &saved_mean_tensor, + AllocateTFOutputs(context, tf_shape_scale, &batch_mean_tensor, + &batch_variance_tensor, &saved_mean_tensor, &saved_variance_tensor); } - void AllocateTFOutputs(OpKernelContext* context, - TensorShape tf_shape_scale, + void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale, Tensor** batch_mean_tensor, Tensor** batch_variance_tensor, Tensor** saved_mean_tensor, @@ -1017,51 +987,43 @@ class MklFusedBatchNormOp : public OpKernel { // allocate batch mean output tensor MklDnnShape mkl_shape_batch_mean; mkl_shape_batch_mean.SetMklTensor(false); - AllocateOutputSetMklShape(context, - kBatchMeanIndex, - batch_mean_tensor, - tf_shape_scale, - mkl_shape_batch_mean); + AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor, + tf_shape_scale, mkl_shape_batch_mean); CHECK_NOTNULL(*batch_mean_tensor); // set NAN mean value in case of empty input tensor - for (int k=0; k < tf_shape_scale.num_elements(); k++) + for (int k = 0; k < tf_shape_scale.num_elements(); k++) (*batch_mean_tensor)->flat().data()[k] = NAN; // allocate batch variance output tensor MklDnnShape mkl_shape_batch_variance; mkl_shape_batch_variance.SetMklTensor(false); - AllocateOutputSetMklShape(context, - kBatchVarianceIndex, - batch_variance_tensor, - tf_shape_scale, + AllocateOutputSetMklShape(context, kBatchVarianceIndex, + batch_variance_tensor, tf_shape_scale, mkl_shape_batch_variance); CHECK_NOTNULL(*batch_variance_tensor); // set NAN variance value in case of empty input tensor - for (int k=0; k < tf_shape_scale.num_elements(); k++) + for (int k = 0; k < tf_shape_scale.num_elements(); k++) (*batch_variance_tensor)->flat().data()[k] = NAN; // Mean and variance (without Bessel's correction) saved for backward // computation to serve as pre-computed mean and variance. MklDnnShape mkl_shape_saved_mean; mkl_shape_saved_mean.SetMklTensor(false); - AllocateOutputSetMklShape(context, kSavedMeanIndex, - saved_mean_tensor, - tf_shape_scale, - mkl_shape_saved_mean); + AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor, + tf_shape_scale, mkl_shape_saved_mean); CHECK_NOTNULL(*saved_mean_tensor); // set NAN mean value in case of empty input tensor - for (int k=0; k < tf_shape_scale.num_elements(); k++) + for (int k = 0; k < tf_shape_scale.num_elements(); k++) (*saved_mean_tensor)->flat().data()[k] = NAN; MklDnnShape mkl_shape_saved_variance; mkl_shape_saved_variance.SetMklTensor(false); AllocateOutputSetMklShape(context, kSavedVarianceIndex, - saved_variance_tensor, - tf_shape_scale, + saved_variance_tensor, tf_shape_scale, mkl_shape_saved_variance); CHECK_NOTNULL(*saved_variance_tensor); // set NAN variance value in case of empty input tensor - for (int k=0; k < tf_shape_scale.num_elements(); k++) + for (int k = 0; k < tf_shape_scale.num_elements(); k++) (*saved_variance_tensor)->flat().data()[k] = NAN; } }; @@ -1093,8 +1055,8 @@ class MklFusedBatchNormGradOp : public OpKernel { const Tensor& src_tensor = MklGetInput(context, kSrcIndex); const Tensor& scale_tensor = MklGetInput(context, kScaleIndex); const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex); - const Tensor& saved_variance_tensor = MklGetInput(context, - kVarianceIndex); + const Tensor& saved_variance_tensor = + MklGetInput(context, kVarianceIndex); MklDnnShape dnn_shape_src, dnn_shape_diff_dst; GetMklShape(context, kSrcIndex, &dnn_shape_src); @@ -1103,53 +1065,49 @@ class MklFusedBatchNormGradOp : public OpKernel { if (dnn_shape_diff_dst.IsMklTensor()) { tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape(); - OP_REQUIRES(context, dnn_shape_diff_dst.GetDimension() == 4, - errors::InvalidArgument( - "input must be 4-dimensional", - diff_dst_tensor.shape().DebugString())); + OP_REQUIRES( + context, dnn_shape_diff_dst.GetDimension() == 4, + errors::InvalidArgument("input must be 4-dimensional", + diff_dst_tensor.shape().DebugString())); } else { tf_shape_diff_dst = diff_dst_tensor.shape(); - OP_REQUIRES(context, diff_dst_tensor.dims() == 4, - errors::InvalidArgument( - "input must be 4-dimensional", - diff_dst_tensor.shape().DebugString())); + OP_REQUIRES( + context, diff_dst_tensor.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + diff_dst_tensor.shape().DebugString())); } if (dnn_shape_src.IsMklTensor()) { tf_shape_src = dnn_shape_src.GetTfShape(); OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4, - errors::InvalidArgument( - "input must be 4-dimensional", - src_tensor.shape().DebugString())); + errors::InvalidArgument("input must be 4-dimensional", + src_tensor.shape().DebugString())); } else { tf_shape_src = src_tensor.shape(); OP_REQUIRES(context, src_tensor.dims() == 4, - errors::InvalidArgument( - "input must be 4-dimensional", - src_tensor.shape().DebugString())); + errors::InvalidArgument("input must be 4-dimensional", + src_tensor.shape().DebugString())); } OP_REQUIRES(context, scale_tensor.dims() == 1, - errors::InvalidArgument( - "scale must be 1-dimensional", - scale_tensor.shape().DebugString())); - OP_REQUIRES(context, saved_mean_tensor.dims() == 1, - errors::InvalidArgument( - "saved mean must be 1-dimensional", - saved_mean_tensor.shape().DebugString())); - - OP_REQUIRES(context, saved_variance_tensor.dims() == 1, - errors::InvalidArgument( - "saved variance must be 1-dimensional", - saved_variance_tensor.shape().DebugString())); + errors::InvalidArgument("scale must be 1-dimensional", + scale_tensor.shape().DebugString())); + OP_REQUIRES( + context, saved_mean_tensor.dims() == 1, + errors::InvalidArgument("saved mean must be 1-dimensional", + saved_mean_tensor.shape().DebugString())); + + OP_REQUIRES( + context, saved_variance_tensor.dims() == 1, + errors::InvalidArgument("saved variance must be 1-dimensional", + saved_variance_tensor.shape().DebugString())); Tensor* diff_src_tensor = nullptr; if (tf_shape_src.num_elements() == 0 || tf_shape_diff_dst.num_elements() == 0) { - HandleEmptyInput(context, tf_shape_src, - scale_tensor.shape(), - &diff_src_tensor); - return; + HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), + &diff_src_tensor); + return; } if (dnn_shape_src.IsMklTensor()) @@ -1175,20 +1133,18 @@ class MklFusedBatchNormGradOp : public OpKernel { memory::dims src_dims, diff_dst_dims; if (dnn_shape_src.IsMklTensor()) - src_dims = TFShapeToMklDnnDimsInNCHW( - dnn_shape_src.GetTfShape(), tensor_format_); + src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(), + tensor_format_); else - src_dims = TFShapeToMklDnnDimsInNCHW( - src_tensor.shape(), tensor_format_); + src_dims = + TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); if (dnn_shape_diff_dst.IsMklTensor()) diff_dst_dims = TFShapeToMklDnnDimsInNCHW( - dnn_shape_diff_dst.GetTfShape(), - tensor_format_); + dnn_shape_diff_dst.GetTfShape(), tensor_format_); else - diff_dst_dims = TFShapeToMklDnnDimsInNCHW( - diff_dst_tensor.shape(), - tensor_format_); + diff_dst_dims = + TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_); // set src and diff_dst primitives memory::desc src_md({}, memory::data_undef, memory::format_undef); @@ -1202,7 +1158,7 @@ class MklFusedBatchNormGradOp : public OpKernel { src_md = diff_dst_md; } } else { - src_md = memory::desc(src_dims, MklDnnType(), format_m); + src_md = memory::desc(src_dims, MklDnnType(), format_m); diff_dst_md = src_md; } src.SetUsrMem(src_md, &src_tensor); @@ -1210,55 +1166,47 @@ class MklFusedBatchNormGradOp : public OpKernel { // weights -- DNN packs scales/shifts as weights in order of // scale, ..., scale, shift, ..., shift - auto weights_desc = memory::desc({2, depth_}, - MklDnnType(), - memory::format::nc); + auto weights_desc = + memory::desc({2, depth_}, MklDnnType(), memory::format::nc); auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine); auto weights_m = memory(weights_pd); T* weights_data = reinterpret_cast(weights_m.get_data_handle()); - T* scale_tf = reinterpret_cast(const_cast - (scale_tensor.flat().data())); - for (int k=0; k < depth_; k++) { + T* scale_tf = + reinterpret_cast(const_cast(scale_tensor.flat().data())); + for (int k = 0; k < depth_; k++) { weights_data[k] = scale_tf[k]; weights_data[k + depth_] = 0; } // set mean primitive memory::dims mv_dims = GetMeanVarianceDims(); - mean.SetUsrMem(mv_dims, - memory::format::nc, - const_cast(static_cast - (saved_mean_tensor.flat().data()))); + mean.SetUsrMem(mv_dims, memory::format::nc, + const_cast(static_cast( + saved_mean_tensor.flat().data()))); mean.SetOpMemDesc(mv_dims, memory::format::nc); // set variance primitive - variance.SetUsrMem(mv_dims, memory::format::nc, - const_cast(static_cast - (saved_variance_tensor.flat().data()))); + variance.SetUsrMem(mv_dims, memory::format::nc, + const_cast(static_cast( + saved_variance_tensor.flat().data()))); variance.SetOpMemDesc(mv_dims, memory::format::nc); // set diff_weight primitive - auto diff_weights_desc = memory::desc( - {2, depth_}, - MklDnnType(), - memory::format::nc); - auto diff_weights_pd = memory::primitive_desc( - diff_weights_desc, - cpu_engine); + auto diff_weights_desc = + memory::desc({2, depth_}, MklDnnType(), memory::format::nc); + auto diff_weights_pd = + memory::primitive_desc(diff_weights_desc, cpu_engine); auto diff_weights_m = memory(diff_weights_pd); auto bnrm_fwd_desc = batch_normalization_forward::desc( - prop_kind::forward_training, - src.GetUsrMemDesc(), - epsilon_, - is_training_ ? use_scale_shift : - (use_scale_shift | use_global_stats)); + prop_kind::forward_training, src.GetUsrMemDesc(), epsilon_, + is_training_ ? use_scale_shift + : (use_scale_shift | use_global_stats)); auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc( - bnrm_fwd_desc, - cpu_engine); + bnrm_fwd_desc, cpu_engine); // Indices of output tensors - const size_t kDiffSrcIndex = 0; // index of diff_src tensor + const size_t kDiffSrcIndex = 0; // index of diff_src tensor // allocate diff_src tensor MklDnnShape dnn_shape_diff_src; @@ -1268,14 +1216,11 @@ class MklFusedBatchNormGradOp : public OpKernel { auto diff_src_pd = bnrm_fwd_pd.dst_primitive_desc(); dnn_shape_diff_src.SetMklLayout(&diff_src_pd); dnn_shape_diff_src.SetElemType(MklDnnType()); - dnn_shape_diff_src.SetTfLayout( - dnn_shape_src.GetDimension(), - src_dims, - format_m); - dnn_shape_diff_src.SetTfDimOrder( - dnn_shape_src.GetDimension(), - tensor_format_); - tf_shape_diff_src.AddDim(diff_src_pd.get_size()/sizeof(T)); + dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), src_dims, + format_m); + dnn_shape_diff_src.SetTfDimOrder(dnn_shape_src.GetDimension(), + tensor_format_); + tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); } else { dnn_shape_diff_src.SetMklTensor(false); tf_shape_diff_src = src_tensor.shape(); @@ -1287,33 +1232,22 @@ class MklFusedBatchNormGradOp : public OpKernel { prop_kind pk = prop_kind::backward; auto bnrm_bwd_desc = batch_normalization_backward::desc( - pk, - diff_src.GetUsrMemDesc(), - src.GetUsrMemDesc(), - epsilon_, - /* for inference, specify use_global_stats - 1. on fwd prop, use mean and variance - provided as inputs - 2. on bwd prop, mean and variance are - considered as constants. Thus, - reduce the amout of MKL computations - */ - is_training_ ? use_scale_shift : - (use_scale_shift | use_global_stats)); + pk, diff_src.GetUsrMemDesc(), src.GetUsrMemDesc(), epsilon_, + /* for inference, specify use_global_stats + 1. on fwd prop, use mean and variance + provided as inputs + 2. on bwd prop, mean and variance are + considered as constants. Thus, + reduce the amout of MKL computations + */ + is_training_ ? use_scale_shift + : (use_scale_shift | use_global_stats)); auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc( - bnrm_bwd_desc, - cpu_engine, - bnrm_fwd_pd); + bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd); auto bnrm_bwd_op = batch_normalization_backward( - bnrm_bwd_pd, - src.GetOpMem(), - mean.GetOpMem(), - variance.GetOpMem(), - diff_dst.GetOpMem(), - weights_m, - diff_src.GetOpMem(), - diff_weights_m); + bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(), + diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m); std::vector net; net.push_back(bnrm_bwd_op); @@ -1322,43 +1256,39 @@ class MklFusedBatchNormGradOp : public OpKernel { // allocate 4 output TF tensors Tensor* diff_scale_tensor = nullptr; Tensor* diff_shift_tensor = nullptr; - AllocateTFOutputs(context, scale_tensor.shape(), - &diff_scale_tensor, + AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor, &diff_shift_tensor); // copy data: diff_scale and diff_shift - T* diff_weights_data_dnn = reinterpret_cast - (diff_weights_m.get_data_handle()); + T* diff_weights_data_dnn = + reinterpret_cast(diff_weights_m.get_data_handle()); for (int i = 0; i < depth_; i++) { - diff_scale_tensor->flat().data()[i] = - diff_weights_data_dnn[i]; + diff_scale_tensor->flat().data()[i] = diff_weights_data_dnn[i]; diff_shift_tensor->flat().data()[i] = - diff_weights_data_dnn[i + depth_]; + diff_weights_data_dnn[i + depth_]; } - } catch (mkldnn::error &e) { + } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Operation received an exception:", - error_msg)); + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } private: T epsilon_; TensorFormat tensor_format_; - int depth_; // batch normalization is done for per channel. + int depth_; // batch normalization is done for per channel. bool is_training_; void ExtractParams(OpKernelContext* context) { - const Tensor& input = MklGetInput(context, 0); - depth_ = static_cast(GetTensorDim(input, tensor_format_, 'C')); + const Tensor& input = MklGetInput(context, 0); + depth_ = static_cast(GetTensorDim(input, tensor_format_, 'C')); } - void HandleEmptyInput(OpKernelContext* context, - TensorShape tf_shape_src, + void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src, TensorShape tf_shape_scale_shift, Tensor** diff_src_tensor) { const size_t kDiffSrcIndex = 0; @@ -1366,22 +1296,20 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnShape dnn_shape_diff_src; dnn_shape_diff_src.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor, - tf_shape_src, dnn_shape_diff_src); - for (size_t i=0; i < (*diff_src_tensor)->shape().num_elements(); i++) - (*diff_src_tensor)->flat().data()[i] = 0; + tf_shape_src, dnn_shape_diff_src); + for (size_t i = 0; i < (*diff_src_tensor)->shape().num_elements(); i++) + (*diff_src_tensor)->flat().data()[i] = 0; Tensor* diff_scale_tensor = nullptr; Tensor* diff_shift_tensor = nullptr; - AllocateTFOutputs(context, - tf_shape_scale_shift, - &diff_scale_tensor, + AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor, &diff_shift_tensor); } void AllocateTFOutputs(OpKernelContext* context, - TensorShape tf_shape_scale_shift, - Tensor** diff_scale_tensor, - Tensor** diff_shift_tensor) { + TensorShape tf_shape_scale_shift, + Tensor** diff_scale_tensor, + Tensor** diff_shift_tensor) { CHECK_NOTNULL(diff_scale_tensor); CHECK_NOTNULL(diff_shift_tensor); @@ -1396,31 +1324,29 @@ class MklFusedBatchNormGradOp : public OpKernel { AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor, tf_shape_scale_shift, mkl_shape_diff_scale); CHECK_NOTNULL(*diff_scale_tensor); - for (size_t i=0; i < (*diff_scale_tensor)->shape().num_elements(); i++) - (*diff_scale_tensor)->flat().data()[i] = 0; + for (size_t i = 0; i < (*diff_scale_tensor)->shape().num_elements(); i++) + (*diff_scale_tensor)->flat().data()[i] = 0; MklDnnShape mkl_shape_diff_shift; mkl_shape_diff_shift.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor, tf_shape_scale_shift, mkl_shape_diff_shift); CHECK_NOTNULL(*diff_shift_tensor); - for (size_t i=0; i < (*diff_shift_tensor)->shape().num_elements(); i++) - (*diff_shift_tensor)->flat().data()[i] = 0; + for (size_t i = 0; i < (*diff_shift_tensor)->shape().num_elements(); i++) + (*diff_shift_tensor)->flat().data()[i] = 0; // Placeholders for estimated_mean and estimated_variance, which are // used for inference and thus not needed here for gradient computation. - Tensor* p1_tensor = nullptr, *p2_tensor = nullptr; + Tensor *p1_tensor = nullptr, *p2_tensor = nullptr; MklDnnShape mkl_shape_p; mkl_shape_p.SetMklTensor(false); - AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, - TensorShape({}), mkl_shape_p); - AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, - TensorShape({}), mkl_shape_p); + AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}), + mkl_shape_p); + AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}), + mkl_shape_p); } - memory::dims GetMeanVarianceDims() { - return memory::dims({1, depth_}); - } + memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); } }; #endif diff --git a/tensorflow/core/kernels/mkl_identity_op.cc b/tensorflow/core/kernels/mkl_identity_op.cc index 9ee27ee21c8d23c8ce314a7687ac9b79a1d9ea30..6c027f8e728b8660d18a70ae58995fa104f0b375 100644 --- a/tensorflow/core/kernels/mkl_identity_op.cc +++ b/tensorflow/core/kernels/mkl_identity_op.cc @@ -28,14 +28,14 @@ limitations under the License. #include "mkl_dnn_types.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" #endif namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML template class MklIdentityOp : public OpKernel { diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc index 4b5f7b831001458c222536be30bc40fcf5d2899a..e9a2376b545fcec97e1ced5c592351203abadd69 100644 --- a/tensorflow/core/kernels/mkl_input_conversion_op.cc +++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/core/kernels/mkl_tfconv_op.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" using mkldnn::stream; @@ -59,7 +59,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; // convert the TF format input to MKL format /////////////////////////////////////////////////////////// -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML template class MklInputConversionOp : public OpKernel { public: @@ -145,8 +145,8 @@ class MklInputConversionOp : public OpKernel { const MklShape* mkl_shape; const Tensor* tf_tensor; MklShape* tf_mkl_shape; - uint mkl_tensor_index; - uint tf_tensor_index; + uint32 mkl_tensor_index; + uint32 tf_tensor_index; if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) { mkl_tensor = &input_tensor_0; mkl_shape = &input_shape_0; @@ -271,8 +271,8 @@ class MklInputConversionOp : public OpKernel { MklDnnShape input_shape_1; GetMklShape(context, 1, &input_shape_1); - bool tf_shapes_are_same = context->input(0).shape() == - context->input(1).shape(); + bool tf_shapes_are_same = + context->input(0).shape() == context->input(1).shape(); VLOG(1) << "MklInputConversionOp: Input shapes are " << (tf_shapes_are_same ? "*same*" : "*different*") << ": " @@ -293,14 +293,58 @@ class MklInputConversionOp : public OpKernel { // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - // If both inputs are in MKL format if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) { - // If both have the same shape, pass them through if (tf_shapes_are_same) { - VLOG(1) << "MklInputConversionOp: No conversion needed, " - << "copying MKL inputs with identical shapes to output"; - - ForwardMklTensorInToOut(context, 0, 0); - ForwardMklTensorInToOut(context, 1, 1); - return; + auto input0_md = input_shape_0.GetMklLayout(); + auto input1_md = input_shape_1.GetMklLayout(); + + // If both have the same shape and same format, pass them through + if (input0_md.data.format == input1_md.data.format) { + VLOG(1) << "MklInputConversionOp: No conversion needed, " + << "copying MKL inputs with identical shapes to output"; + + ForwardMklTensorInToOut(context, 0, 0); + ForwardMklTensorInToOut(context, 1, 1); + return; + } else { + VLOG(1) << "MklInputConversionOp: Shape is same, but format is " + "different, " + << "need to convert to same format"; + + // Convert input0, and keep input1 unchanged + // Create MklDnnShape for output mkl tensor based on input0 + Tensor* tensor_out; + MklDnnShape mkl_output_mkl_shape; + mkl_output_mkl_shape.SetMklTensor(true); + mkl_output_mkl_shape.SetElemType(MklDnnType()); + mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(), + input_shape_0.GetSizesAsMklDnnDims(), + input_shape_0.GetTfDataFormat()); + + // Get MKL layout from input1 as destination layout + mkl_output_mkl_shape.SetMklLayout(&input1_md); + + // Create output Mkl tensor for index 0 + AllocateOutputSetMklShape(context, 0, &tensor_out, + input_tensor_0.shape(), + mkl_output_mkl_shape); + + // Create MklDnnData object for input0 tesnsor + auto cpu_engine = engine(engine::cpu, 0); + MklDnnData input(&cpu_engine); + input.SetUsrMem(input0_md, &input_tensor_0); + + // Create reorder from input0's layout to input1's layout + std::vector net; + CHECK_EQ(input.CheckReorderToOpMem( + memory::primitive_desc(input1_md, cpu_engine), + tensor_out, &net), + true); + stream(stream::kind::eager).submit(net).wait(); + + // Input1 will be passed through + ForwardMklTensorInToOut(context, 1, 1); + return; + } } // Sanity check @@ -400,9 +444,9 @@ class MklInputConversionOp : public OpKernel { // Create reorder between tensorflow layout and Mkl layout. std::vector net; - CHECK_EQ(tf_input.CheckReorderToOpMem(memory::primitive_desc( - output_mkl_md, cpu_engine), - tensor_out, &net), + CHECK_EQ(tf_input.CheckReorderToOpMem( + memory::primitive_desc(output_mkl_md, cpu_engine), + tensor_out, &net), true); stream(stream::kind::eager).submit(net).wait(); diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc index 95e0404ba8ab7d305e530239be30c7a842edf16d..5f0a12a1fb9bff3086e05918e23b8396196eb389 100644 --- a/tensorflow/core/kernels/mkl_lrn_op.cc +++ b/tensorflow/core/kernels/mkl_lrn_op.cc @@ -22,6 +22,9 @@ limitations under the License. #define EIGEN_USE_THREADS #include +#include "mkl_dnn.h" +#include "mkl_dnn_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -30,20 +33,17 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "mkl_dnn.h" -#include "mkl_dnn_types.h" #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/util/work_sharder.h" #endif -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" -using mkldnn::lrn_forward; +using mkldnn::lrn_across_channels; using mkldnn::lrn_backward; +using mkldnn::lrn_forward; using mkldnn::prop_kind; -using mkldnn::lrn_across_channels; using mkldnn::stream; #endif @@ -67,7 +67,7 @@ void GetBandMatrix(int depth, int depth_radius, } // namespace -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML template class MklLRNOp : public OpKernel { @@ -77,10 +77,11 @@ class MklLRNOp : public OpKernel { explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) { int64 depth_radius64; OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64)); - OP_REQUIRES(context, FastBoundsCheck(depth_radius64, - std::numeric_limits::max()), - errors::InvalidArgument("depth_radius = ", depth_radius64, - " larger than int max")); + OP_REQUIRES( + context, + FastBoundsCheck(depth_radius64, std::numeric_limits::max()), + errors::InvalidArgument("depth_radius = ", depth_radius64, + " larger than int max")); depth_radius_ = static_cast(depth_radius64); OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_)); @@ -103,9 +104,10 @@ class MklLRNOp : public OpKernel { : input.dims(); OP_REQUIRES(context, mkl_context.in_dims == 4, errors::InvalidArgument("input must be 4-dimensional")); - OP_REQUIRES(context, FastBoundsCheck(input.NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument("argument to LRN too large")); + OP_REQUIRES( + context, + FastBoundsCheck(input.NumElements(), std::numeric_limits::max()), + errors::InvalidArgument("argument to LRN too large")); if (!input_in_mkl_format) { mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_, @@ -339,17 +341,17 @@ class MklLRNOp : public OpKernel { float beta_; }; - template class MklLRNGradOp : public OpKernel { public: explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) { int64 depth_radius64; OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64)); - OP_REQUIRES(context, FastBoundsCheck(depth_radius64, - std::numeric_limits::max()), - errors::InvalidArgument("depth_radius = ", depth_radius64, - " larger than int max")); + OP_REQUIRES( + context, + FastBoundsCheck(depth_radius64, std::numeric_limits::max()), + errors::InvalidArgument("depth_radius = ", depth_radius64, + " larger than int max")); depth_radius_ = static_cast(depth_radius64); OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_)); OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_)); @@ -740,10 +742,11 @@ class MklLRNOp : public OpKernel { explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) { int64 depth_radius64; OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64)); - OP_REQUIRES(context, FastBoundsCheck(depth_radius64, - std::numeric_limits::max()), - errors::InvalidArgument("depth_radius = ", depth_radius64, - " larger than int max")); + OP_REQUIRES( + context, + FastBoundsCheck(depth_radius64, std::numeric_limits::max()), + errors::InvalidArgument("depth_radius = ", depth_radius64, + " larger than int max")); depth_radius_ = static_cast(depth_radius64); OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_)); @@ -773,10 +776,10 @@ class MklLRNOp : public OpKernel { if (!src_dnn_shape.IsMklTensor()) { MklDefaultToEigen(context, src_tensor); return; - } else if (!src_dnn_shape.IsMklChannelDim( - src_dnn_shape.GetDimension() - 1) ) { + } else if (!src_dnn_shape.IsMklChannelDim(src_dnn_shape.GetDimension() - + 1)) { Tensor converted_tensor = - ConvertMklToTF(context, src_tensor, src_dnn_shape); + ConvertMklToTF(context, src_tensor, src_dnn_shape); MklDefaultToEigen(context, converted_tensor); return; } @@ -807,18 +810,16 @@ class MklLRNOp : public OpKernel { // Create LRN primitive descriptor. // Tensorflow's normalization semantics is across channels. // MKL-DNN also supports normalization within channel. - auto lrn_desc = lrn_forward::desc(prop_kind::forward, - lrn_across_channels, + auto lrn_desc = lrn_forward::desc(prop_kind::forward, lrn_across_channels, src_dnn_data.GetUsrMemDesc(), - kernel_size, - new_alpha, beta_, bias_); + kernel_size, new_alpha, beta_, bias_); auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine); // Allocate output_dnn_data tensor. Tensor* output_tensor = nullptr; memory::format input_format = src_dnn_shape.GetTfDataFormat(); - AllocateOutputTensor(context, lrn_prim_desc, input_dims, - input_format, &output_tensor); + AllocateOutputTensor(context, lrn_prim_desc, input_dims, input_format, + &output_tensor); OP_REQUIRES_OK(context, context->status()); CHECK_NOTNULL(output_tensor); dst_dnn_data.SetUsrMemDataHandle(output_tensor); @@ -827,25 +828,23 @@ class MklLRNOp : public OpKernel { AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data); OP_REQUIRES_OK(context, context->status()); - PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data, - &dst_dnn_data, &workspace_dnn_data); - } catch (mkldnn::error &e) { + PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data, &dst_dnn_data, + &workspace_dnn_data); + } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Operation received an exception:", - error_msg)); + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } private: - void PrepareAndExecuteNet( - const lrn_forward::primitive_desc& lrn_fwd_desc, - MklDnnData* src_dnn_data, - MklDnnData* dst_dnn_data, - MklDnnData* wksp_dnn_data = nullptr) { + void PrepareAndExecuteNet(const lrn_forward::primitive_desc& lrn_fwd_desc, + MklDnnData* src_dnn_data, + MklDnnData* dst_dnn_data, + MklDnnData* wksp_dnn_data = nullptr) { std::vector net; // Check for input reorder @@ -853,23 +852,21 @@ class MklLRNOp : public OpKernel { // Create pooling primitive and add it to net if (wksp_dnn_data != nullptr) { - net.push_back(lrn_forward(lrn_fwd_desc, - src_dnn_data->GetOpMem(), - wksp_dnn_data->GetOpMem(), - dst_dnn_data->GetOpMem())); + net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(), + wksp_dnn_data->GetOpMem(), + dst_dnn_data->GetOpMem())); } else { - net.push_back(lrn_forward(lrn_fwd_desc, - src_dnn_data->GetOpMem(), - dst_dnn_data->GetOpMem())); + net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(), + dst_dnn_data->GetOpMem())); } stream(stream::kind::eager).submit(net).wait(); } - void AllocateOutputTensor(OpKernelContext* context, - const lrn_forward::primitive_desc& lrn_fwd_prim_desc, - const memory::dims output_dims_mkl_order, - const memory::format& output_tf_format, - Tensor** output_tensor) { + void AllocateOutputTensor( + OpKernelContext* context, + const lrn_forward::primitive_desc& lrn_fwd_prim_desc, + const memory::dims output_dims_mkl_order, + const memory::format& output_tf_format, Tensor** output_tensor) { CHECK_NOTNULL(output_tensor); memory::primitive_desc dst_pd = lrn_fwd_prim_desc.dst_primitive_desc(); @@ -880,111 +877,106 @@ class MklLRNOp : public OpKernel { output_mkl_shape.SetMklLayout(&dst_pd); output_mkl_shape.SetElemType(MklDnnType()); output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), - output_dims_mkl_order, - output_tf_format); + output_dims_mkl_order, output_tf_format); TensorShape output_tf_shape; // only allocate enough space for the elements we need. size_t num_bytes = dst_pd.get_size(); CHECK_EQ(num_bytes % sizeof(T), 0); output_tf_shape.AddDim(num_bytes / sizeof(T)); - AllocateOutputSetMklShape(context, kIdxOutput, - output_tensor, - output_tf_shape, output_mkl_shape); - } - - // Fallback implementation - Taken from lrn_op.cc - // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a - // copy. - void MklDefaultToEigen(OpKernelContext* context, - const Tensor& input) { - const int batch = static_cast(input.dim_size(0)); - const int rows = static_cast(input.dim_size(1)); - const int cols = static_cast(input.dim_size(2)); - const int depth = static_cast(input.dim_size(3)); - const int nodes = cols * rows; - - auto in_shaped = input.shaped({nodes * batch, depth}); - // Multiplying the input with the band matrix has the effect of reducing - // the - // correct patch along the depth. - Eigen::Tensor multiplier(depth, depth); - GetBandMatrix(depth, depth_radius_, &multiplier); + AllocateOutputSetMklShape(context, kIdxOutput, output_tensor, + output_tf_shape, output_mkl_shape); + } - Tensor *output_dnn_data = nullptr; - MklDnnShape mkl_output_mkl_shape; - mkl_output_mkl_shape.SetMklTensor(false); - mkl_output_mkl_shape.SetDimensions(4); - AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data, - input.shape(), mkl_output_mkl_shape); - CHECK_NOTNULL(output_dnn_data); - - Tensor* workspace_tensor = nullptr; - MklDnnShape workspace_mkl_shape; - workspace_mkl_shape.SetMklTensor(false); - TensorShape workspace_tf_shape; - workspace_tf_shape.AddDim(0); - AllocateOutputSetMklShape(context, kIdxWorkspace, - &workspace_tensor, + // Fallback implementation - Taken from lrn_op.cc + // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a + // copy. + void MklDefaultToEigen(OpKernelContext* context, const Tensor& input) { + const int batch = static_cast(input.dim_size(0)); + const int rows = static_cast(input.dim_size(1)); + const int cols = static_cast(input.dim_size(2)); + const int depth = static_cast(input.dim_size(3)); + const int nodes = cols * rows; + + auto in_shaped = input.shaped({nodes * batch, depth}); + // Multiplying the input with the band matrix has the effect of reducing + // the + // correct patch along the depth. + Eigen::Tensor multiplier(depth, depth); + GetBandMatrix(depth, depth_radius_, &multiplier); + + Tensor* output_dnn_data = nullptr; + MklDnnShape mkl_output_mkl_shape; + mkl_output_mkl_shape.SetMklTensor(false); + mkl_output_mkl_shape.SetDimensions(4); + AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data, + input.shape(), mkl_output_mkl_shape); + CHECK_NOTNULL(output_dnn_data); + + Tensor* workspace_tensor = nullptr; + MklDnnShape workspace_mkl_shape; + workspace_mkl_shape.SetMklTensor(false); + TensorShape workspace_tf_shape; + workspace_tf_shape.AddDim(0); + AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor, workspace_tf_shape, workspace_mkl_shape); - CHECK_NOTNULL(workspace_tensor); - - auto out_shaped = output_dnn_data->shaped({nodes * batch, depth}); - Eigen::array dims = {{DimPair(1, 0)}}; - auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_; - if (beta_ == T(1)) { - out_shaped.device(context->eigen_cpu_device()) = - in_shaped * tmp.inverse(); - } else if (beta_ == T(0.5)) { - out_shaped.device(context->eigen_cpu_device()) = - in_shaped * tmp.rsqrt(); - } else { - out_shaped.device(context->eigen_cpu_device()) = - in_shaped * (tmp.log() * -beta_).exp(); - } + CHECK_NOTNULL(workspace_tensor); + + auto out_shaped = output_dnn_data->shaped({nodes * batch, depth}); + Eigen::array dims = {{DimPair(1, 0)}}; + auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_; + if (beta_ == T(1)) { + out_shaped.device(context->eigen_cpu_device()) = + in_shaped * tmp.inverse(); + } else if (beta_ == T(0.5)) { + out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt(); + } else { + out_shaped.device(context->eigen_cpu_device()) = + in_shaped * (tmp.log() * -beta_).exp(); } + } - void AllocateWorkspaceTensor(OpKernelContext* context, - const lrn_forward::primitive_desc& lrn_fwd_prim_desc, - MklDnnData* dnn_data_wksp) { - CHECK_NOTNULL(dnn_data_wksp); - Tensor* workspace_tensor = nullptr; - memory::primitive_desc workspace_pd - = lrn_fwd_prim_desc.workspace_primitive_desc(); - size_t workspace_bytes = workspace_pd.get_size(); - MklDnnShape workspace_mkl_shape; - // the workspace tensor is a uint8 tensor that has - // exactly the number of bytes necessary - workspace_mkl_shape.SetMklTensor(false); - TensorShape workspace_tf_shape; - workspace_tf_shape.AddDim(workspace_bytes); - AllocateOutputSetMklShape(context, kIdxWorkspace, - &workspace_tensor, + void AllocateWorkspaceTensor( + OpKernelContext* context, + const lrn_forward::primitive_desc& lrn_fwd_prim_desc, + MklDnnData* dnn_data_wksp) { + CHECK_NOTNULL(dnn_data_wksp); + Tensor* workspace_tensor = nullptr; + memory::primitive_desc workspace_pd = + lrn_fwd_prim_desc.workspace_primitive_desc(); + size_t workspace_bytes = workspace_pd.get_size(); + MklDnnShape workspace_mkl_shape; + // the workspace tensor is a uint8 tensor that has + // exactly the number of bytes necessary + workspace_mkl_shape.SetMklTensor(false); + TensorShape workspace_tf_shape; + workspace_tf_shape.AddDim(workspace_bytes); + AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor, workspace_tf_shape, workspace_mkl_shape); - CHECK_NOTNULL(workspace_tensor); - dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor); - } + CHECK_NOTNULL(workspace_tensor); + dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor); + } void SanityCheckInputs(OpKernelContext* context) { const Tensor& src_tensor = MklGetInput(context, kIdxInput); MklDnnShape src_dnn_shape; GetMklShape(context, kIdxInput, &src_dnn_shape); if (src_dnn_shape.IsMklTensor()) { - OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4, - errors::InvalidArgument("input must be 4-dimensional")); - OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument("argument to LRN too large")); + OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4, + errors::InvalidArgument("input must be 4-dimensional")); + OP_REQUIRES(context, + FastBoundsCheck(src_tensor.NumElements(), + std::numeric_limits::max()), + errors::InvalidArgument("argument to LRN too large")); } else { - OP_REQUIRES(context, src_tensor.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional")); - OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument("argument to LRN too large")); + OP_REQUIRES(context, src_tensor.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional")); + OP_REQUIRES(context, + FastBoundsCheck(src_tensor.NumElements(), + std::numeric_limits::max()), + errors::InvalidArgument("argument to LRN too large")); } } - const int kIdxInput = 0, - kIdxOutput = 0, - kIdxWorkspace = 1; + const int kIdxInput = 0, kIdxOutput = 0, kIdxWorkspace = 1; typedef typename Eigen::Tensor::DimensionPair DimPair; bool workspace_enabled_; @@ -994,17 +986,17 @@ class MklLRNOp : public OpKernel { float beta_; }; - template class MklLRNGradOp : public OpKernel { public: explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) { int64 depth_radius64; OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64)); - OP_REQUIRES(context, FastBoundsCheck(depth_radius64, - std::numeric_limits::max()), - errors::InvalidArgument("depth_radius = ", depth_radius64, - " larger than int max")); + OP_REQUIRES( + context, + FastBoundsCheck(depth_radius64, std::numeric_limits::max()), + errors::InvalidArgument("depth_radius = ", depth_radius64, + " larger than int max")); depth_radius_ = static_cast(depth_radius64); OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_)); OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_)); @@ -1025,7 +1017,7 @@ class MklLRNGradOp : public OpKernel { MklDnnData output_dnn_data(&cpu_engine); MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape, - orig_output_dnn_shape; + orig_output_dnn_shape; GetMklShape(context, kIdxGradient, &input_grad_dnn_shape); GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape); GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape); @@ -1037,16 +1029,16 @@ class MklLRNGradOp : public OpKernel { orig_input_dnn_shape.IsMklTensor() && orig_output_dnn_shape.IsMklTensor() && input_grad_dnn_shape.IsMklChannelDim( - input_grad_dnn_shape.GetDimension() - 1) && + input_grad_dnn_shape.GetDimension() - 1) && orig_input_dnn_shape.IsMklChannelDim( - orig_input_dnn_shape.GetDimension() - 1) && + orig_input_dnn_shape.GetDimension() - 1) && orig_output_dnn_shape.IsMklChannelDim( - orig_output_dnn_shape.GetDimension() - 1); + orig_output_dnn_shape.GetDimension() - 1); if (!can_use_mkldnn) { - // Fallback to eigen - MklDefaultToEigen(context); - return; + // Fallback to eigen + MklDefaultToEigen(context); + return; } // At this point, we have the all clear to use MklDnn constructs // Naming: diff_dst is input_gradient_tensor; src is orig_input_tensor. @@ -1059,13 +1051,11 @@ class MklLRNGradOp : public OpKernel { // NHWC format. memory::desc original_output_md = orig_output_dnn_shape.GetCurLayout(); memory::desc target_diff_dst_md = ConfigureInputGradient( - input_grad_tensor, - input_grad_dnn_shape, - &input_grad_dnn_data); + input_grad_tensor, input_grad_dnn_shape, &input_grad_dnn_data); memory::desc orig_input_md = orig_input_dnn_shape.GetCurLayout(); memory::dims orig_input_dims = - orig_input_dnn_shape.GetSizesAsMklDnnDims(); + orig_input_dnn_shape.GetSizesAsMklDnnDims(); orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor); orig_input_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc); @@ -1079,27 +1069,21 @@ class MklLRNGradOp : public OpKernel { // Create LRN backward primitive descriptor. It requires LRN forward // primitive descriptor also. - auto lrn_fwd_desc = lrn_forward::desc(prop_kind::forward, - lrn_across_channels, - orig_input_md, - kernel_size, - new_alpha, beta_, bias_); - auto lrn_fwd_prim_desc = lrn_forward::primitive_desc(lrn_fwd_desc, - cpu_engine); - auto lrn_bwd_desc = lrn_backward::desc(lrn_across_channels, - original_output_md, - target_diff_dst_md, - kernel_size, - new_alpha, beta_, bias_); - auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(lrn_bwd_desc, - cpu_engine, - lrn_fwd_prim_desc); + auto lrn_fwd_desc = lrn_forward::desc( + prop_kind::forward, lrn_across_channels, orig_input_md, kernel_size, + new_alpha, beta_, bias_); + auto lrn_fwd_prim_desc = + lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine); + auto lrn_bwd_desc = lrn_backward::desc( + lrn_across_channels, original_output_md, target_diff_dst_md, + kernel_size, new_alpha, beta_, bias_); + auto lrn_bwd_prim_desc = lrn_backward::primitive_desc( + lrn_bwd_desc, cpu_engine, lrn_fwd_prim_desc); Tensor* output_tensor = nullptr; - memory::format orig_input_format - = orig_input_dnn_shape.GetTfDataFormat(); - AllocateOutputTensor(context, lrn_bwd_prim_desc, - orig_input_dims, orig_input_format, &output_tensor); + memory::format orig_input_format = orig_input_dnn_shape.GetTfDataFormat(); + AllocateOutputTensor(context, lrn_bwd_prim_desc, orig_input_dims, + orig_input_format, &output_tensor); OP_REQUIRES_OK(context, context->status()); CHECK_NOTNULL(output_tensor); output_dnn_data.SetUsrMemDataHandle(output_tensor); @@ -1110,35 +1094,32 @@ class MklLRNGradOp : public OpKernel { const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace); MklDnnData workspace_dnn_data(&cpu_engine); ConfigureWorkspace(workspace_tensor, - lrn_fwd_prim_desc.workspace_primitive_desc(), - &workspace_dnn_data); - - PrepareAndExecuteNet(lrn_bwd_prim_desc, - lrn_fwd_prim_desc, - &orig_input_dnn_data, - &input_grad_dnn_data, - &output_dnn_data, - memory::primitive_desc(target_diff_dst_md, cpu_engine), - &workspace_dnn_data); - } catch (mkldnn::error &e) { + lrn_fwd_prim_desc.workspace_primitive_desc(), + &workspace_dnn_data); + + PrepareAndExecuteNet( + lrn_bwd_prim_desc, lrn_fwd_prim_desc, &orig_input_dnn_data, + &input_grad_dnn_data, &output_dnn_data, + memory::primitive_desc(target_diff_dst_md, cpu_engine), + &workspace_dnn_data); + } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Operation received an exception:", - error_msg)); + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } - void AllocateOutputTensor(OpKernelContext* context, - const lrn_backward::primitive_desc& lrn_bkwd_prim_desc, - const memory::dims output_dims_mkl_order, - const memory::format& output_tf_format, - Tensor** output_tensor) { + void AllocateOutputTensor( + OpKernelContext* context, + const lrn_backward::primitive_desc& lrn_bkwd_prim_desc, + const memory::dims output_dims_mkl_order, + const memory::format& output_tf_format, Tensor** output_tensor) { CHECK_NOTNULL(output_tensor); - memory::primitive_desc dst_pd - = lrn_bkwd_prim_desc.diff_src_primitive_desc(); + memory::primitive_desc dst_pd = + lrn_bkwd_prim_desc.diff_src_primitive_desc(); MklDnnShape output_mkl_shape; // We assume that all outputs at this point are MKL Tensors @@ -1146,170 +1127,153 @@ class MklLRNGradOp : public OpKernel { output_mkl_shape.SetMklLayout(&dst_pd); output_mkl_shape.SetElemType(MklDnnType()); output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), - output_dims_mkl_order, - output_tf_format); + output_dims_mkl_order, output_tf_format); TensorShape output_tf_shape; size_t num_bytes = dst_pd.get_size(); CHECK_EQ(num_bytes % sizeof(T), 0); output_tf_shape.AddDim(num_bytes / sizeof(T)); - AllocateOutputSetMklShape(context, kIdxOutput, - output_tensor, - output_tf_shape, output_mkl_shape); + AllocateOutputSetMklShape(context, kIdxOutput, output_tensor, + output_tf_shape, output_mkl_shape); } memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor, - const MklDnnShape& input_grad_dnn_shape, - MklDnnData *input_grad_dnn_data) { + const MklDnnShape& input_grad_dnn_shape, + MklDnnData* input_grad_dnn_data) { CHECK_NOTNULL(input_grad_dnn_data); // This shouldn't be necessary at this point, but just in case CHECK_EQ(input_grad_dnn_shape.IsMklTensor(), true); memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout(); - memory::dims orig_input_dims = - input_grad_dnn_shape.GetSizesAsMklDnnDims(); + memory::dims orig_input_dims = input_grad_dnn_shape.GetSizesAsMklDnnDims(); input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor); input_grad_dnn_data->SetOpMemDesc(orig_input_dims, memory::format::nhwc); return input_grad_md; } void PrepareAndExecuteNet( - const lrn_backward::primitive_desc& lrn_bkwd_desc, - const lrn_forward::primitive_desc& lrn_fwd_desc, - MklDnnData* src_dnn_data, - MklDnnData* input_gradient_diff_dst, - MklDnnData* output_diff_src, - const memory::primitive_desc& target_diff_dst_pd, - const MklDnnData* workspace_dnn_data = nullptr) { + const lrn_backward::primitive_desc& lrn_bkwd_desc, + const lrn_forward::primitive_desc& lrn_fwd_desc, + MklDnnData* src_dnn_data, MklDnnData* input_gradient_diff_dst, + MklDnnData* output_diff_src, + const memory::primitive_desc& target_diff_dst_pd, + const MklDnnData* workspace_dnn_data = nullptr) { std::vector net; // Check for input reordering on the diff dst input input_gradient_diff_dst->CheckReorderToOpMem( - lrn_bkwd_desc.diff_dst_primitive_desc(), &net); + lrn_bkwd_desc.diff_dst_primitive_desc(), &net); // Check for input reordering on the original input - src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(), - &net); + src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(), &net); // Create pooling primitive and add it to net if (nullptr == workspace_dnn_data) { - net.push_back(lrn_backward(lrn_bkwd_desc, - src_dnn_data->GetOpMem(), - input_gradient_diff_dst->GetOpMem(), - output_diff_src->GetOpMem())); + net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(), + input_gradient_diff_dst->GetOpMem(), + output_diff_src->GetOpMem())); } else { - net.push_back(lrn_backward(lrn_bkwd_desc, - src_dnn_data->GetOpMem(), - input_gradient_diff_dst->GetOpMem(), - workspace_dnn_data->GetOpMem(), - output_diff_src->GetOpMem())); + net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(), + input_gradient_diff_dst->GetOpMem(), + workspace_dnn_data->GetOpMem(), + output_diff_src->GetOpMem())); } stream(stream::kind::eager).submit(net).wait(); } void ConfigureWorkspace(const Tensor& workspace_tensor, - memory::primitive_desc workspace_pd, - MklDnnData *workspace_dnn_data) { + memory::primitive_desc workspace_pd, + MklDnnData* workspace_dnn_data) { CHECK_NOTNULL(workspace_dnn_data); workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor); } - // Fallback implementation - Taken from lrn_op.cc - // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a - // copy. - void MklDefaultToEigen(OpKernelContext* context) { - Tensor input_gradient_tensor; - Tensor orig_input_tensor; - Tensor orig_output_tensor; - - MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape, - orig_output_dnn_shape; - GetMklShape(context, kIdxGradient, &input_grad_dnn_shape); - GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape); - GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape); - - if (input_grad_dnn_shape.IsMklTensor()) { - input_gradient_tensor = - ConvertMklToTF(context, - MklGetInput(context, kIdxGradient), - input_grad_dnn_shape); - } else { - input_gradient_tensor = MklGetInput(context, kIdxGradient); - } - - if (orig_input_dnn_shape.IsMklTensor()) { - orig_input_tensor = - ConvertMklToTF(context, - MklGetInput(context, kIdxOrigInput), - orig_input_dnn_shape); - } else { - orig_input_tensor = MklGetInput(context, kIdxOrigInput); - } + // Fallback implementation - Taken from lrn_op.cc + // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a + // copy. + void MklDefaultToEigen(OpKernelContext* context) { + Tensor input_gradient_tensor; + Tensor orig_input_tensor; + Tensor orig_output_tensor; + + MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape, + orig_output_dnn_shape; + GetMklShape(context, kIdxGradient, &input_grad_dnn_shape); + GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape); + GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape); + + if (input_grad_dnn_shape.IsMklTensor()) { + input_gradient_tensor = ConvertMklToTF( + context, MklGetInput(context, kIdxGradient), input_grad_dnn_shape); + } else { + input_gradient_tensor = MklGetInput(context, kIdxGradient); + } - if (orig_output_dnn_shape.IsMklTensor()) { - orig_output_tensor = - ConvertMklToTF(context, - MklGetInput(context, kIdxOrigOutput), - orig_output_dnn_shape); - } else { - orig_output_tensor = MklGetInput(context, kIdxOrigOutput); - } + if (orig_input_dnn_shape.IsMklTensor()) { + orig_input_tensor = ConvertMklToTF( + context, MklGetInput(context, kIdxOrigInput), orig_input_dnn_shape); + } else { + orig_input_tensor = MklGetInput(context, kIdxOrigInput); + } - const int64 batch = static_cast(input_gradient_tensor.dim_size(0)); - const int64 rows = static_cast(input_gradient_tensor.dim_size(1)); - const int64 cols = static_cast(input_gradient_tensor.dim_size(2)); - const int64 depth = static_cast(input_gradient_tensor.dim_size(3)); - const auto nodes = cols * rows; + if (orig_output_dnn_shape.IsMklTensor()) { + orig_output_tensor = ConvertMklToTF( + context, MklGetInput(context, kIdxOrigOutput), orig_output_dnn_shape); + } else { + orig_output_tensor = MklGetInput(context, kIdxOrigOutput); + } - auto grads_shaped = - input_gradient_tensor.shaped({nodes * batch, depth}); + const int64 batch = static_cast(input_gradient_tensor.dim_size(0)); + const int64 rows = static_cast(input_gradient_tensor.dim_size(1)); + const int64 cols = static_cast(input_gradient_tensor.dim_size(2)); + const int64 depth = static_cast(input_gradient_tensor.dim_size(3)); + const auto nodes = cols * rows; - auto in_shaped = orig_input_tensor.shaped({nodes * batch, depth}); - auto activations = - orig_output_tensor.shaped({nodes * batch, depth}); + auto grads_shaped = + input_gradient_tensor.shaped({nodes * batch, depth}); - Tensor* output_dnn_data; - MklShape mkl_output_mkl_shape; - mkl_output_mkl_shape.SetMklTensor(false); - mkl_output_mkl_shape.SetDimensions(4); - AllocateOutputSetMklShape(context, kIdxOutput, - &output_dnn_data, - input_gradient_tensor.shape(), - mkl_output_mkl_shape); + auto in_shaped = orig_input_tensor.shaped({nodes * batch, depth}); + auto activations = orig_output_tensor.shaped({nodes * batch, depth}); - auto out_shaped = output_dnn_data->shaped({nodes * batch, depth}); - out_shaped.setZero(); - auto shard = [this, activations, in_shaped, grads_shaped, out_shaped, - depth](int64 begin, int64 end) { - for (int64 i = begin; i < end; ++i) { - for (int64 j = 0; j < depth; ++j) { - int64 depth_begin = std::max(0, j - depth_radius_); - int64 depth_end = std::min(depth, j + depth_radius_ + 1); + Tensor* output_dnn_data; + MklShape mkl_output_mkl_shape; + mkl_output_mkl_shape.SetMklTensor(false); + mkl_output_mkl_shape.SetDimensions(4); + AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data, + input_gradient_tensor.shape(), + mkl_output_mkl_shape); - T norm(0); - for (int64 k = depth_begin; k < depth_end; ++k) { - norm += in_shaped(i, k) * in_shaped(i, k); - } - norm = alpha_ * norm + bias_; - DCHECK_GT(norm, T(1e-6)); - for (int64 k = depth_begin; k < depth_end; ++k) { - T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) * - activations(i, j) / norm; - if (k == j) { - dyi += Eigen::numext::pow(norm, -beta_); - } - dyi *= grads_shaped(i, j); - const_cast::Tensor&>(out_shaped)(i, k) += - dyi; + auto out_shaped = output_dnn_data->shaped({nodes * batch, depth}); + out_shaped.setZero(); + auto shard = [this, activations, in_shaped, grads_shaped, out_shaped, + depth](int64 begin, int64 end) { + for (int64 i = begin; i < end; ++i) { + for (int64 j = 0; j < depth; ++j) { + int64 depth_begin = std::max(0, j - depth_radius_); + int64 depth_end = std::min(depth, j + depth_radius_ + 1); + + T norm(0); + for (int64 k = depth_begin; k < depth_end; ++k) { + norm += in_shaped(i, k) * in_shaped(i, k); + } + norm = alpha_ * norm + bias_; + DCHECK_GT(norm, T(1e-6)); + for (int64 k = depth_begin; k < depth_end; ++k) { + T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) * + activations(i, j) / norm; + if (k == j) { + dyi += Eigen::numext::pow(norm, -beta_); } + dyi *= grads_shaped(i, j); + const_cast::Tensor&>(out_shaped)(i, k) += dyi; } } - }; - auto worker_threads = - *(context->device()->tensorflow_cpu_worker_threads()); - Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch, - depth * depth, shard); - } + } + }; + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch, + depth * depth, shard); + } void SanityCheckInputs(OpKernelContext* context) { const Tensor& input_gradient_tensor = MklGetInput(context, kIdxGradient); @@ -1317,59 +1281,59 @@ class MklLRNGradOp : public OpKernel { const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput); const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace); MklDnnShape in_grads_dnn_shape, in_image_dnn_shape, out_image_dnn_shape, - workspace_dnn_shape; + workspace_dnn_shape; GetMklShape(context, kIdxGradient, &in_grads_dnn_shape); GetMklShape(context, kIdxOrigInput, &in_image_dnn_shape); GetMklShape(context, kIdxOrigOutput, &out_image_dnn_shape); GetMklShape(context, kIdxWorkspace, &workspace_dnn_shape); if (in_grads_dnn_shape.IsMklTensor()) { OP_REQUIRES(context, in_grads_dnn_shape.GetDimension() == 4, - errors::InvalidArgument("Input gradient must be " - "4-dimensional")); + errors::InvalidArgument("Input gradient must be " + "4-dimensional")); } else { - OP_REQUIRES(context, input_gradient_tensor.dims() == 4, - errors::InvalidArgument("input gradient must be 4-dimensional")); + OP_REQUIRES( + context, input_gradient_tensor.dims() == 4, + errors::InvalidArgument("input gradient must be 4-dimensional")); } if (in_image_dnn_shape.IsMklTensor()) { OP_REQUIRES(context, in_image_dnn_shape.GetDimension() == 4, - errors::InvalidArgument("input images must be " - "4-dimensional")); + errors::InvalidArgument("input images must be " + "4-dimensional")); } else { OP_REQUIRES(context, orig_input_tensor.dims() == 4, errors::InvalidArgument("input images must be " - "4-dimensional")); + "4-dimensional")); } if (out_image_dnn_shape.IsMklTensor()) { OP_REQUIRES(context, out_image_dnn_shape.GetDimension() == 4, - errors::InvalidArgument("Output image must be " - "4-dimensional")); + errors::InvalidArgument("Output image must be " + "4-dimensional")); } else { - OP_REQUIRES(context, orig_output_tensor.dims() == 4, - errors::InvalidArgument("Output image must be 4-dimensional")); + OP_REQUIRES( + context, orig_output_tensor.dims() == 4, + errors::InvalidArgument("Output image must be 4-dimensional")); } if (workspace_enabled_) { if (workspace_dnn_shape.IsMklTensor()) { - OP_REQUIRES(context, workspace_dnn_shape.IsMklTensor() == false, - errors::InvalidArgument("Workspace should not be MKL Tensor.")); + OP_REQUIRES( + context, workspace_dnn_shape.IsMklTensor() == false, + errors::InvalidArgument("Workspace should not be MKL Tensor.")); } else { OP_REQUIRES(context, workspace_tensor.dims() == 1, - errors::InvalidArgument("Workspace must be 1-dimensional")); + errors::InvalidArgument("Workspace must be 1-dimensional")); } } } -// Input("input_grads: T") -// Input("input_image: T") -// Input("output_image: T") -// Input("workspace: uint8") - const int kIdxGradient = 0, - kIdxOrigInput = 1, - kIdxOrigOutput = 2, - kIdxWorkspace = 3, - kIdxOutput = 0; + // Input("input_grads: T") + // Input("input_image: T") + // Input("output_image: T") + // Input("workspace: uint8") + const int kIdxGradient = 0, kIdxOrigInput = 1, kIdxOrigOutput = 2, + kIdxWorkspace = 3, kIdxOutput = 0; typedef typename Eigen::Tensor::DimensionPair DimPair; bool workspace_enabled_; @@ -1379,7 +1343,7 @@ class MklLRNGradOp : public OpKernel { float beta_; }; -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML #define REGISTER_MKL_LRN_CPU(T) \ REGISTER_KERNEL_BUILDER(Name("_MklLRN") \ @@ -1393,7 +1357,6 @@ class MklLRNGradOp : public OpKernel { .Label(mkl_op_registry::kMklOpLabel), \ MklLRNGradOp); - TF_CALL_float(REGISTER_MKL_LRN_CPU); } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 47598f443f76f17a6c0b4005327a4e7d00a6beba..25ad8c94a78a82cc7e4a6f98903aecf1d5a0d1b4 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -170,32 +170,32 @@ class MklMatMulOp : public OpKernel { // Matrix-Matrix Multiplication with Complex64 (std::complex) tensors. // For detailed info about parameters, look at FP32 function description. void MklBlasGemm(bool transa, bool transb, const int m, const int n, - const int k, const std::complex* a, const int lda, - const std::complex* b, const int ldb, - std::complex* c, int const ldc) { + const int k, const complex64* a, const int lda, + const complex64* b, const int ldb, + complex64* c, int const ldc) { const MKL_Complex8 alpha = {1.0f, 0.0f}; const MKL_Complex8 beta = {0.0f, 0.0f}; cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, - static_cast(&alpha), static_cast(a), - lda, static_cast(b), ldb, - static_cast(&beta), static_cast(c), ldc); + transb ? CblasTrans : CblasNoTrans, + m, n, k, &alpha, reinterpret_cast(a), lda, + reinterpret_cast(b), ldb, &beta, + reinterpret_cast(c), ldc); } // Matrix-Matrix Multiplication with Complex128 (std::complex) // tensors. For detailed info about parameters, look at FP32 function // description. void MklBlasGemm(bool transa, bool transb, const int m, const int n, - const int k, const std::complex* a, const int lda, - const std::complex* b, const int ldb, - std::complex* c, const int ldc) { + const int k, const complex128* a, const int lda, + const complex128* b, const int ldb, + complex128* c, const int ldc) { const MKL_Complex16 alpha = {1.0, 0.0}; const MKL_Complex16 beta = {0.0, 0.0}; cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, - static_cast(&alpha), static_cast(a), - lda, static_cast(b), ldb, - static_cast(&beta), static_cast(c), ldc); + transb ? CblasTrans : CblasNoTrans, + m, n, k, &alpha, reinterpret_cast(a), lda, + reinterpret_cast(b), ldb, &beta, + reinterpret_cast(c), ldc); } }; diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index 82c5229bab0cfef51799d521d6ced6fab804176c..14607f26e0ccd1028dd62343000d90ac8451d7bb 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -22,25 +22,25 @@ limitations under the License. #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include #include "mkldnn.hpp" -using mkldnn::memory; +using mkldnn::algorithm; +using mkldnn::engine; using mkldnn::error; -using mkldnn::pooling_forward; -using mkldnn::pooling_backward; +using mkldnn::memory; using mkldnn::padding_kind; -using mkldnn::engine; +using mkldnn::pooling_backward; +using mkldnn::pooling_forward; using mkldnn::prop_kind; -using mkldnn::algorithm; #endif namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -// For now, MKL-ML is default. So making MKL-DNN not a default choice. -#ifndef INTEL_MKL_DNN +// MKL-DNN is now default. MKL-ML must be specified explicitly. +#ifdef INTEL_MKL_ML // An implementation of MaxPooling (forward). template @@ -397,18 +397,19 @@ class MklMaxPoolingGradOp : public OpKernel { if (workspace_enabled == false) { if (convert_input != nullptr) { if (input_in_mkl_format == false) { - CHECK_EQ( - dnnConversionExecute_F32( - convert_input, const_cast(static_cast( - tensor_in.flat().data())), - input_buf), - E_SUCCESS); + CHECK_EQ(dnnConversionExecute_F32( + convert_input, + const_cast(static_cast( + tensor_in.flat().data())), + input_buf), + E_SUCCESS); CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS); convert_input = nullptr; } else { input_shape.GetConvertedFlatData( - lt_input_prim, const_cast(static_cast( - tensor_in.flat().data())), + lt_input_prim, + const_cast( + static_cast(tensor_in.flat().data())), input_buf); } pooling_resfwd[dnnResourceSrc] = input_buf; @@ -453,8 +454,9 @@ class MklMaxPoolingGradOp : public OpKernel { CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS); } else { output_backprop_shape.GetConvertedFlatData( - lt_outbackprop_prim, const_cast(static_cast( - out_backprop.flat().data())), + lt_outbackprop_prim, + const_cast( + static_cast(out_backprop.flat().data())), outbackprop_buf); } pooling_res[dnnResourceDiffDst] = outbackprop_buf; @@ -492,14 +494,14 @@ class MklMaxPoolingGradOp : public OpKernel { bool workspace_enabled_; }; // MklMaxPoolingGradOp -#else // INTEL_MKL_DNN is defined +#else // An implementation of MaxPooling (forward). template class MklMaxPoolingOp : public MklPoolingForwardOpBase { public: explicit MklMaxPoolingOp(OpKernelConstruction* context) - : MklPoolingForwardOpBase(context) { + : MklPoolingForwardOpBase(context) { // In Max Pooling, MKLDNN does not allow passing workspace as NULL. // So we set workspace_enabled_ to true. this->workspace_enabled_ = true; @@ -508,8 +510,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { void Compute(OpKernelContext* context) override { try { auto cpu_engine = engine(engine::cpu, 0); - const Tensor& input_tensor = MklGetInput(context, - this->kInputTensorIndexInput); + const Tensor& input_tensor = + MklGetInput(context, this->kInputTensorIndexInput); MklDnnShape dnn_shape_input; GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input); this->SanityCheckInput(context, input_tensor, dnn_shape_input); @@ -522,9 +524,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { // initialize variables for the pooling op MklPoolParameters pool_params; // Get the input tensor and initialize the pooling parameters - this->ConfigureInput(context, dnn_shape_input, - input_tensor, &pool_params, - &dnn_data_input); + this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params, + &dnn_data_input); OP_REQUIRES_OK(context, context->status()); // Declare output tensor @@ -535,9 +536,10 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { // If input is in Mkl layout, then just get the memory format from it // directly, instead of using input data_format to MaxPool. if (dnn_shape_input.IsMklTensor()) { - dnn_data_output.SetUsrMem(output_dims_mkl_order, - static_cast( - dnn_data_input.GetUsrMemDesc().data.format)); + dnn_data_output.SetUsrMem( + output_dims_mkl_order, + static_cast( + dnn_data_input.GetUsrMemDesc().data.format)); } else { dnn_data_output.SetUsrMem(output_dims_mkl_order, this->data_format_mkldnn_); @@ -546,24 +548,21 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { // describe the memory layout; let mkl-dnn choose the best for the op dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any); - auto pool_desc = pooling_forward::desc(prop_kind::forward, - algorithm::pooling_max, - dnn_data_input.GetUsrMemDesc(), - dnn_data_output.GetUsrMemDesc(), - memory::dims({ pool_params.row_stride, - pool_params.col_stride}), - memory::dims({ pool_params.window_rows, - pool_params.window_cols}), - memory::dims({ static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({ static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_fwd_desc = pooling_forward::primitive_desc(pool_desc, - cpu_engine); + auto pool_desc = pooling_forward::desc( + prop_kind::forward, algorithm::pooling_max, + dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(), + memory::dims({pool_params.row_stride, pool_params.col_stride}), + memory::dims({pool_params.window_rows, pool_params.window_cols}), + memory::dims({static_cast(pool_params.pad_top), + static_cast(pool_params.pad_left)}), + memory::dims({static_cast(pool_params.pad_bottom), + static_cast(pool_params.pad_right)}), + TFPaddingToMklDnnPadding(this->padding_)); + auto pool_fwd_desc = + pooling_forward::primitive_desc(pool_desc, cpu_engine); this->AllocateOutputTensor(context, pool_fwd_desc, output_dims_mkl_order, - this->data_format_mkldnn_, &output_tensor); + this->data_format_mkldnn_, &output_tensor); OP_REQUIRES_OK(context, context->status()); dnn_data_output.SetUsrMemDataHandle(output_tensor); @@ -571,39 +570,38 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { OP_REQUIRES_OK(context, context->status()); this->PrepareAndExecuteNet(pool_fwd_desc, &dnn_data_input, - &dnn_data_output, &dnn_data_wksp); - } catch (mkldnn::error &e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Compute received an exception:", - error_msg)); + &dnn_data_output, &dnn_data_wksp); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", + error_msg)); } } // Compute private: - const int kOutputTensorIndexWorkspace = 1; - - void AllocateWorkspaceTensor(OpKernelContext* context, - const pooling_forward::primitive_desc& pool_fwd_prim_desc, - MklDnnData* dnn_data_wksp) { - CHECK_NOTNULL(dnn_data_wksp); - Tensor* workspace_tensor = nullptr; - memory::primitive_desc workspace_pd - = pool_fwd_prim_desc.workspace_primitive_desc(); - size_t workspace_bytes = workspace_pd.get_size(); - MklDnnShape workspace_mkl_shape; - workspace_mkl_shape.SetMklTensor(false); - TensorShape workspace_tf_shape; - workspace_tf_shape.AddDim(workspace_bytes); - AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace, - &workspace_tensor, - workspace_tf_shape, workspace_mkl_shape); - CHECK_NOTNULL(workspace_tensor); - dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor); - } + const int kOutputTensorIndexWorkspace = 1; + + void AllocateWorkspaceTensor( + OpKernelContext* context, + const pooling_forward::primitive_desc& pool_fwd_prim_desc, + MklDnnData* dnn_data_wksp) { + CHECK_NOTNULL(dnn_data_wksp); + Tensor* workspace_tensor = nullptr; + memory::primitive_desc workspace_pd = + pool_fwd_prim_desc.workspace_primitive_desc(); + size_t workspace_bytes = workspace_pd.get_size(); + MklDnnShape workspace_mkl_shape; + workspace_mkl_shape.SetMklTensor(false); + TensorShape workspace_tf_shape; + workspace_tf_shape.AddDim(workspace_bytes); + AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace, + &workspace_tensor, workspace_tf_shape, + workspace_mkl_shape); + CHECK_NOTNULL(workspace_tensor); + dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor); + } }; // The operation to compute MaxPool gradients. @@ -616,221 +614,186 @@ template class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { public: explicit MklMaxPoolingGradOp(OpKernelConstruction* context) - : MklPoolingBackwardOpBase(context) { - } + : MklPoolingBackwardOpBase(context) {} void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); - const Tensor& orig_input_tensor = MklGetInput(context, - kInputTensorIndexOrigInput); - const Tensor& orig_output_tensor = MklGetInput(context, - kInputTensorIndexOrigOutput); - const Tensor& grad_tensor = MklGetInput(context, - kInputTensorIndexGradient); - const Tensor& workspace_tensor = MklGetInput(context, - kInputTensorIndexWorkspace); - MklDnnShape orig_input_mkl_shape, - orig_output_mkl_shape, - grad_mkl_shape, - workspace_mkl_shape; - GetMklShape(context, kInputTensorIndexOrigInput, - &orig_input_mkl_shape); - GetMklShape(context, kInputTensorIndexOrigOutput, - &orig_output_mkl_shape); - GetMklShape(context, kInputTensorIndexGradient, - &grad_mkl_shape); - GetMklShape(context, kInputTensorIndexWorkspace, - &workspace_mkl_shape); - - SanityCheckInputs(context, - orig_input_tensor, orig_output_tensor, - grad_tensor, workspace_tensor, - orig_input_mkl_shape, orig_output_mkl_shape, - grad_mkl_shape, workspace_mkl_shape); - if (!context->status().ok()) return; - - MklDnnData grad_dnn_data(&cpu_engine); - MklDnnData workspace_dnn_data(&cpu_engine); - MklDnnData output_dnn_data(&cpu_engine); - Tensor* output_tensor = nullptr; - MklPoolParameters pool_params; - TensorShape orig_input_shape; - memory::dims output_dims_mkl_order, orig_input_dims_mkl_order; - memory::desc original_input_md = ConfigureOriginalInput(context, - orig_input_tensor, - orig_input_mkl_shape, - &orig_input_dims_mkl_order, - &pool_params, - &orig_input_shape); - - memory::desc original_output_md = this->ConfigureOriginalOutput( - pool_params, - orig_output_mkl_shape, - output_dims_mkl_order); - - memory::desc target_diff_dst_md = this->ConfigureInputGradient( - grad_mkl_shape, - grad_tensor, - &grad_dnn_data, - original_output_md); - - output_dnn_data.SetUsrMem(original_input_md); - - // Create the forward pooling primitive descriptor so we can - // pass it as a hint to the backward pooling primitive descriptor - auto pool_fwd_desc = pooling_forward::desc(prop_kind::forward, - algorithm::pooling_max, - original_input_md, - original_output_md, - memory::dims({ pool_params.row_stride, - pool_params.col_stride}), - memory::dims({ pool_params.window_rows, - pool_params.window_cols}), - memory::dims({ static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({ static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_fwd_prim_desc - = pooling_forward::primitive_desc(pool_fwd_desc, - cpu_engine); - - auto pool_bkwd_desc = pooling_backward::desc( - algorithm::pooling_max, - output_dnn_data.GetUsrMemDesc(), - target_diff_dst_md, - memory::dims({ pool_params.row_stride, - pool_params.col_stride}), - memory::dims({ pool_params.window_rows, - pool_params.window_cols}), - memory::dims({ static_cast(pool_params.pad_top), - static_cast(pool_params.pad_left)}), - memory::dims({ static_cast(pool_params.pad_bottom), - static_cast(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_bkwd_prim_desc - = pooling_backward::primitive_desc(pool_bkwd_desc, - cpu_engine, - pool_fwd_prim_desc); - - this->AllocateOutputTensor(context, pool_bkwd_prim_desc, - orig_input_dims_mkl_order, - this->data_format_mkldnn_, - &output_tensor); - output_dnn_data.SetUsrMemDataHandle(output_tensor); - - ConfigureWorkspace(workspace_tensor, - pool_fwd_prim_desc.workspace_primitive_desc(), - &workspace_dnn_data); - this->PrepareAndExecuteNet(pool_bkwd_prim_desc, - &grad_dnn_data, - &output_dnn_data, - memory::primitive_desc( - target_diff_dst_md, - cpu_engine), - &workspace_dnn_data); - } catch (mkldnn::error &e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Compute received an exception:", - error_msg)); + auto cpu_engine = engine(engine::cpu, 0); + const Tensor& orig_input_tensor = + MklGetInput(context, kInputTensorIndexOrigInput); + const Tensor& orig_output_tensor = + MklGetInput(context, kInputTensorIndexOrigOutput); + const Tensor& grad_tensor = + MklGetInput(context, kInputTensorIndexGradient); + const Tensor& workspace_tensor = + MklGetInput(context, kInputTensorIndexWorkspace); + MklDnnShape orig_input_mkl_shape, orig_output_mkl_shape, grad_mkl_shape, + workspace_mkl_shape; + GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape); + GetMklShape(context, kInputTensorIndexOrigOutput, &orig_output_mkl_shape); + GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape); + GetMklShape(context, kInputTensorIndexWorkspace, &workspace_mkl_shape); + + SanityCheckInputs(context, orig_input_tensor, orig_output_tensor, + grad_tensor, workspace_tensor, orig_input_mkl_shape, + orig_output_mkl_shape, grad_mkl_shape, + workspace_mkl_shape); + if (!context->status().ok()) return; + + MklDnnData grad_dnn_data(&cpu_engine); + MklDnnData workspace_dnn_data(&cpu_engine); + MklDnnData output_dnn_data(&cpu_engine); + Tensor* output_tensor = nullptr; + MklPoolParameters pool_params; + TensorShape orig_input_shape; + memory::dims output_dims_mkl_order, orig_input_dims_mkl_order; + memory::desc original_input_md = ConfigureOriginalInput( + context, orig_input_tensor, orig_input_mkl_shape, + &orig_input_dims_mkl_order, &pool_params, &orig_input_shape); + + memory::desc original_output_md = this->ConfigureOriginalOutput( + pool_params, orig_output_mkl_shape, output_dims_mkl_order); + + memory::desc target_diff_dst_md = this->ConfigureInputGradient( + grad_mkl_shape, grad_tensor, &grad_dnn_data, original_output_md); + + output_dnn_data.SetUsrMem(original_input_md); + + // Create the forward pooling primitive descriptor so we can + // pass it as a hint to the backward pooling primitive descriptor + auto pool_fwd_desc = pooling_forward::desc( + prop_kind::forward, algorithm::pooling_max, original_input_md, + original_output_md, + memory::dims({pool_params.row_stride, pool_params.col_stride}), + memory::dims({pool_params.window_rows, pool_params.window_cols}), + memory::dims({static_cast(pool_params.pad_top), + static_cast(pool_params.pad_left)}), + memory::dims({static_cast(pool_params.pad_bottom), + static_cast(pool_params.pad_right)}), + TFPaddingToMklDnnPadding(this->padding_)); + auto pool_fwd_prim_desc = + pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine); + + auto pool_bkwd_desc = pooling_backward::desc( + algorithm::pooling_max, output_dnn_data.GetUsrMemDesc(), + target_diff_dst_md, + memory::dims({pool_params.row_stride, pool_params.col_stride}), + memory::dims({pool_params.window_rows, pool_params.window_cols}), + memory::dims({static_cast(pool_params.pad_top), + static_cast(pool_params.pad_left)}), + memory::dims({static_cast(pool_params.pad_bottom), + static_cast(pool_params.pad_right)}), + TFPaddingToMklDnnPadding(this->padding_)); + auto pool_bkwd_prim_desc = pooling_backward::primitive_desc( + pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc); + + this->AllocateOutputTensor(context, pool_bkwd_prim_desc, + orig_input_dims_mkl_order, + this->data_format_mkldnn_, &output_tensor); + output_dnn_data.SetUsrMemDataHandle(output_tensor); + + ConfigureWorkspace(workspace_tensor, + pool_fwd_prim_desc.workspace_primitive_desc(), + &workspace_dnn_data); + this->PrepareAndExecuteNet( + pool_bkwd_prim_desc, &grad_dnn_data, &output_dnn_data, + memory::primitive_desc(target_diff_dst_md, cpu_engine), + &workspace_dnn_data); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", + error_msg)); } } // Compute private: - // .Input("orig_input: T") - // .Input("orig_output: T") - // .Input("grad: T") - // .Input("workspace: T") - const int kInputTensorIndexOrigInput = 0; - const int kInputTensorIndexOrigOutput = 1; - const int kInputTensorIndexGradient = 2; - const int kInputTensorIndexWorkspace = 3; - // Output("output: T") in Base Class - - memory::desc ConfigureOriginalInput(OpKernelContext* context, - const Tensor& tensor_original_input, - const MklDnnShape& original_input_mkl_shape, - memory::dims* original_input_dims_mkl_order, - MklPoolParameters* pool_params, - TensorShape* input_tensor_shape) { - *input_tensor_shape = tensor_original_input.shape(); - return MklPoolingBackwardOpBase::ConfigureOriginalInput( - context, - tensor_original_input, - original_input_mkl_shape, - original_input_dims_mkl_order, - pool_params, - *input_tensor_shape); - } + // .Input("orig_input: T") + // .Input("orig_output: T") + // .Input("grad: T") + // .Input("workspace: T") + const int kInputTensorIndexOrigInput = 0; + const int kInputTensorIndexOrigOutput = 1; + const int kInputTensorIndexGradient = 2; + const int kInputTensorIndexWorkspace = 3; + // Output("output: T") in Base Class + + memory::desc ConfigureOriginalInput( + OpKernelContext* context, const Tensor& tensor_original_input, + const MklDnnShape& original_input_mkl_shape, + memory::dims* original_input_dims_mkl_order, + MklPoolParameters* pool_params, TensorShape* input_tensor_shape) { + *input_tensor_shape = tensor_original_input.shape(); + return MklPoolingBackwardOpBase::ConfigureOriginalInput( + context, tensor_original_input, original_input_mkl_shape, + original_input_dims_mkl_order, pool_params, *input_tensor_shape); + } - void ConfigureWorkspace(const Tensor& workspace_tensor, - memory::primitive_desc workspace_pd, - MklDnnData *workspace_dnn_data) { - CHECK_NOTNULL(workspace_dnn_data); + void ConfigureWorkspace(const Tensor& workspace_tensor, + memory::primitive_desc workspace_pd, + MklDnnData* workspace_dnn_data) { + CHECK_NOTNULL(workspace_dnn_data); - workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor); - } + workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor); + } - void SanityCheckInputs(OpKernelContext* context, - const Tensor& orig_input_tensor, - const Tensor& orig_output_tensor, - const Tensor& grad_tensor, - const Tensor& workspace_tensor, - const MklDnnShape& orig_input_mkl_shape, - const MklDnnShape& orig_output_mkl_shape, - const MklDnnShape& grad_mkl_shape, - const MklDnnShape& workspace_mkl_shape) { - if (!orig_input_mkl_shape.IsMklTensor()) { - OP_REQUIRES(context, orig_input_tensor.dims() == 4, - errors::InvalidArgument("Original input shape must be " - "4-dimensional")); - } else { - OP_REQUIRES(context, orig_input_mkl_shape.GetDimension() == 4, - errors::InvalidArgument("Original input shape must be " - "4-dimensional")); - } - if (!orig_output_mkl_shape.IsMklTensor()) { - OP_REQUIRES(context, orig_output_tensor.dims() == 4, - errors::InvalidArgument("Original output must be " - "4-dimensional")); - } else { - OP_REQUIRES(context, orig_output_mkl_shape.GetDimension() == 4, - errors::InvalidArgument("Original output must be " - "4-dimensional")); - } - if (!grad_mkl_shape.IsMklTensor()) { - OP_REQUIRES(context, grad_tensor.dims() == 4, - errors::InvalidArgument("Gradient must be 4-dimensional")); - } else { - OP_REQUIRES(context, grad_mkl_shape.GetDimension() == 4, - errors::InvalidArgument("Gradient must be " - "4-dimensional")); - } - if (this->workspace_enabled_) { - // The workspace should not be an MKL tensor - OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false, - errors::InvalidArgument("Workspace tensor should not" - " be an MKL Tensor.")); - // It should only have one dimension - OP_REQUIRES(context, workspace_tensor.dims() == 1, - errors::InvalidArgument("Workspace tensor must be " - "1-dimensional")); - } else { - OP_REQUIRES(context, this->workspace_enabled_, - errors::Unimplemented("MKL-DNN Max Pooling does not " + void SanityCheckInputs(OpKernelContext* context, + const Tensor& orig_input_tensor, + const Tensor& orig_output_tensor, + const Tensor& grad_tensor, + const Tensor& workspace_tensor, + const MklDnnShape& orig_input_mkl_shape, + const MklDnnShape& orig_output_mkl_shape, + const MklDnnShape& grad_mkl_shape, + const MklDnnShape& workspace_mkl_shape) { + if (!orig_input_mkl_shape.IsMklTensor()) { + OP_REQUIRES(context, orig_input_tensor.dims() == 4, + errors::InvalidArgument("Original input shape must be " + "4-dimensional")); + } else { + OP_REQUIRES(context, orig_input_mkl_shape.GetDimension() == 4, + errors::InvalidArgument("Original input shape must be " + "4-dimensional")); + } + if (!orig_output_mkl_shape.IsMklTensor()) { + OP_REQUIRES(context, orig_output_tensor.dims() == 4, + errors::InvalidArgument("Original output must be " + "4-dimensional")); + } else { + OP_REQUIRES(context, orig_output_mkl_shape.GetDimension() == 4, + errors::InvalidArgument("Original output must be " + "4-dimensional")); + } + if (!grad_mkl_shape.IsMklTensor()) { + OP_REQUIRES(context, grad_tensor.dims() == 4, + errors::InvalidArgument("Gradient must be 4-dimensional")); + } else { + OP_REQUIRES(context, grad_mkl_shape.GetDimension() == 4, + errors::InvalidArgument("Gradient must be " + "4-dimensional")); + } + if (this->workspace_enabled_) { + // The workspace should not be an MKL tensor + OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false, + errors::InvalidArgument("Workspace tensor should not" + " be an MKL Tensor.")); + // It should only have one dimension + OP_REQUIRES(context, workspace_tensor.dims() == 1, + errors::InvalidArgument("Workspace tensor must be " + "1-dimensional")); + } else { + OP_REQUIRES( + context, this->workspace_enabled_, + errors::Unimplemented("MKL-DNN Max Pooling does not " "yet support the use case " "where MaxPoolGrad is called without first" " calling MaxPool.")); - } } + } }; // MklMaxPoolingGradOp -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML REGISTER_KERNEL_BUILDER(Name("_MklMaxPool") .Device(DEVICE_CPU) diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index f7cadffd39c11bdedaca6a07e48f222e7ac5e0cb..5ef6ce2a5789034b338fe7308a6eca02f135befa 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -15,9 +15,9 @@ limitations under the License. #ifdef INTEL_MKL -#include -#include #include "tensorflow/core/kernels/mkl_pooling_ops_common.h" +#include +#include #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -42,7 +42,7 @@ void MklPoolParameters::Init(OpKernelContext* context, Init(context, ksize, stride, padding, data_format); } -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML // Initialization for MKL format void MklPoolParameters::Init(OpKernelContext* context, const std::vector& ksize, @@ -72,7 +72,7 @@ void MklPoolParameters::Init(OpKernelContext* context, Init(context, ksize, stride, padding, data_format); } -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML // Common Initialization for TensorFlow and MKL formats void MklPoolParameters::Init(OpKernelContext* context, const std::vector& ksize, @@ -107,21 +107,21 @@ void MklPoolParameters::Init(OpKernelContext* context, OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( tensor_in_cols, window_cols, col_stride, padding, &out_width, &pad_left, &pad_right)); -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML // TF can work with int64, but mkldnn only supports int32 // Fail if the height or width are greater than MAX_INT - OP_REQUIRES(context, FastBoundsCheck(out_height, - std::numeric_limits::max()), + OP_REQUIRES(context, + FastBoundsCheck(out_height, std::numeric_limits::max()), errors::InvalidArgument("output height is too large")); - OP_REQUIRES(context, FastBoundsCheck(out_width, - std::numeric_limits::max()), + OP_REQUIRES(context, + FastBoundsCheck(out_width, std::numeric_limits::max()), errors::InvalidArgument("output width is too large")); #endif out_depth = depth; // output will have the same depth as the input - } else { // we are pooling in the depth dimension + } else { // we are pooling in the depth dimension // Our current version of depthwise max pooling does not support // any padding, and expects the depth_window to equal the depth // stride (no overlapping). diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h index b974b2c59afe91b955af45f3851c7371d8a86610..279167aba24863441774b0665e9793e52d84ccfa 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h @@ -17,16 +17,16 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_ #ifdef INTEL_MKL -#include #include +#include #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" using mkldnn::memory; -using mkldnn::pooling_forward; using mkldnn::pooling_backward; +using mkldnn::pooling_forward; using mkldnn::stream; #endif @@ -61,19 +61,31 @@ struct MklPoolParameters { TensorFormat data_format; MklPoolParameters() - : depth(0) - , tensor_in_cols(0), tensor_in_rows(0), tensor_in_batch(0) - , window_rows(0), window_cols(0), depth_window(0) - , row_stride(0), col_stride(0), depth_stride(0) - , out_height(0), out_width(0), out_depth(0) - , pad_left(0), pad_right(0), pad_top(0), pad_bottom(0), pad_depth(0) - , data_format(TensorFormat::FORMAT_NCHW) {} + : depth(0), + tensor_in_cols(0), + tensor_in_rows(0), + tensor_in_batch(0), + window_rows(0), + window_cols(0), + depth_window(0), + row_stride(0), + col_stride(0), + depth_stride(0), + out_height(0), + out_width(0), + out_depth(0), + pad_left(0), + pad_right(0), + pad_top(0), + pad_bottom(0), + pad_depth(0), + data_format(TensorFormat::FORMAT_NCHW) {} // Updates context->status if there is an invalid input. void Init(OpKernelContext* context, const std::vector& ksize, const std::vector& stride, Padding padding, TensorFormat data_format, const TensorShape& tensor_in_shape); -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML void Init(OpKernelContext* context, const std::vector& ksize, const std::vector& stride, Padding padding, TensorFormat data_format, const MklShape* mkl_in_shape); @@ -90,39 +102,37 @@ struct MklPoolParameters { TensorFormat data_format); }; -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML template class MklPoolingOpBase : public OpKernel { public: explicit MklPoolingOpBase(OpKernelConstruction* context) - : OpKernel(context) - , workspace_enabled_(false) { - string data_format; - OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); - OP_REQUIRES(context, - FormatFromString(data_format, &this->data_format_tf_), - errors::InvalidArgument("Invalid data format")); - this->data_format_mkldnn_ - = TFDataFormatToMklDnnDataFormat(this->data_format_tf_); - OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_)); - OP_REQUIRES(context, this->ksize_.size() == 4, - errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_)); - OP_REQUIRES(context, this->stride_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_)); - OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1, - errors::Unimplemented("Pooling is not yet supported on the " - "batch dimension.")); - - // We may not get this attribute for this node if it does not go through - // graph rewrite pass. So we do not check for error while retrieving this - // attribute value. - context->GetAttr("workspace_enabled", &this->workspace_enabled_); - } + : OpKernel(context), workspace_enabled_(false) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_), + errors::InvalidArgument("Invalid data format")); + this->data_format_mkldnn_ = + TFDataFormatToMklDnnDataFormat(this->data_format_tf_); + OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_)); + OP_REQUIRES(context, this->ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_)); + OP_REQUIRES(context, this->stride_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_)); + OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1, + errors::Unimplemented("Pooling is not yet supported on the " + "batch dimension.")); + + // We may not get this attribute for this node if it does not go through + // graph rewrite pass. So we do not check for error while retrieving this + // attribute value. + context->GetAttr("workspace_enabled", &this->workspace_enabled_); + } void Compute(OpKernelContext* context) override = 0; protected: @@ -132,24 +142,24 @@ class MklPoolingOpBase : public OpKernel { // output height and output width to have already been int32 // bounds-checked void GetOutputDims(const MklPoolParameters& mkl_pool_params, - memory::dims* output_dims_mkl_order) { + memory::dims* output_dims_mkl_order) { // MKL-DNN always needs output in NCHW format. - *output_dims_mkl_order = { mkl_pool_params.tensor_in_batch, + *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch, mkl_pool_params.out_depth, static_cast(mkl_pool_params.out_height), static_cast(mkl_pool_params.out_width)}; } void InitMklPoolParameters(OpKernelContext* context, - MklPoolParameters* pool_params, - const MklDnnShape& original_input_mkl_shape, - const TensorShape& input_tensor_shape) { + MklPoolParameters* pool_params, + const MklDnnShape& original_input_mkl_shape, + const TensorShape& input_tensor_shape) { if (!original_input_mkl_shape.IsMklTensor()) { pool_params->Init(context, this->ksize_, this->stride_, this->padding_, - this->data_format_tf_, input_tensor_shape); + this->data_format_tf_, input_tensor_shape); } else { pool_params->Init(context, this->ksize_, this->stride_, this->padding_, - this->data_format_tf_, &original_input_mkl_shape); + this->data_format_tf_, &original_input_mkl_shape); } } @@ -159,13 +169,12 @@ class MklPoolingOpBase : public OpKernel { size_t GetNumTElements(const memory::primitive_desc& pd) { size_t num_bytes = pd.get_size(); size_t ret_val = num_bytes / sizeof(T); - if ( num_bytes % sizeof(T) != 0 ) { - ret_val++; + if (num_bytes % sizeof(T) != 0) { + ret_val++; } return ret_val; } - std::vector ksize_; std::vector stride_; Padding padding_; @@ -183,30 +192,29 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase { protected: void ConfigureInput(OpKernelContext* context, - const MklDnnShape& input_mkl_shape, - const Tensor& input_tensor, - MklPoolParameters* pool_params, - MklDnnData* dnn_data_input) { + const MklDnnShape& input_mkl_shape, + const Tensor& input_tensor, + MklPoolParameters* pool_params, + MklDnnData* dnn_data_input) { CHECK_NOTNULL(pool_params); CHECK_NOTNULL(dnn_data_input); TensorShape input_tensor_shape = input_tensor.shape(); - memory::desc input_md = input_mkl_shape.IsMklTensor() - ? input_mkl_shape.GetMklLayout() - : memory::desc( - TFShapeToMklDnnDimsInNCHW( - input_tensor_shape, this->data_format_tf_), - MklDnnType(), - this->data_format_mkldnn_); + memory::desc input_md = + input_mkl_shape.IsMklTensor() + ? input_mkl_shape.GetMklLayout() + : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape, + this->data_format_tf_), + MklDnnType(), this->data_format_mkldnn_); dnn_data_input->SetUsrMem(input_md, &input_tensor); - this->InitMklPoolParameters(context, pool_params, - input_mkl_shape, input_tensor_shape); + this->InitMklPoolParameters(context, pool_params, input_mkl_shape, + input_tensor_shape); } - void AllocateOutputTensor(OpKernelContext* context, - const pooling_forward::primitive_desc& pool_fwd_prim_desc, - const memory::dims output_dims_mkl_order, - const memory::format& output_tf_format, - Tensor** output_tensor) { + void AllocateOutputTensor( + OpKernelContext* context, + const pooling_forward::primitive_desc& pool_fwd_prim_desc, + const memory::dims output_dims_mkl_order, + const memory::format& output_tf_format, Tensor** output_tensor) { CHECK_NOTNULL(output_tensor); memory::primitive_desc dst_pd = pool_fwd_prim_desc.dst_primitive_desc(); @@ -215,50 +223,42 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase { output_mkl_shape.SetMklLayout(&dst_pd); output_mkl_shape.SetElemType(MklDnnType()); output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), - output_dims_mkl_order, - output_tf_format); + output_dims_mkl_order, output_tf_format); TensorShape output_tf_shape; // only allocate enough space for the elements we need. output_tf_shape.AddDim(this->GetNumTElements(dst_pd)); - AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, - output_tensor, - output_tf_shape, output_mkl_shape); + AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor, + output_tf_shape, output_mkl_shape); CHECK_NOTNULL(*output_tensor); } void PrepareAndExecuteNet( - const pooling_forward::primitive_desc& pool_fwd_desc, - const MklDnnData* src, - MklDnnData* dst, - MklDnnData* wksp = nullptr) { + const pooling_forward::primitive_desc& pool_fwd_desc, + const MklDnnData* src, MklDnnData* dst, + MklDnnData* wksp = nullptr) { std::vector net; // Create pooling primitive and add it to net if (wksp != nullptr) { - net.push_back(pooling_forward(pool_fwd_desc, - src->GetOpMem(), - dst->GetOpMem(), - wksp->GetOpMem())); + net.push_back(pooling_forward(pool_fwd_desc, src->GetOpMem(), + dst->GetOpMem(), wksp->GetOpMem())); } else { - net.push_back(pooling_forward(pool_fwd_desc, - src->GetOpMem(), - dst->GetOpMem())); + net.push_back( + pooling_forward(pool_fwd_desc, src->GetOpMem(), dst->GetOpMem())); } stream(stream::kind::eager).submit(net).wait(); } - - void SanityCheckInput(OpKernelContext* context, - const Tensor& input_tensor, - const MklDnnShape& input_mkl_shape) { + void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor, + const MklDnnShape& input_mkl_shape) { if (!input_mkl_shape.IsMklTensor()) { OP_REQUIRES(context, input_tensor.dims() == 4, - errors::InvalidArgument("Input must be 4-dimensional")); + errors::InvalidArgument("Input must be 4-dimensional")); } else { - OP_REQUIRES(context, input_mkl_shape.GetDimension() == 4, - errors::InvalidArgument("Input shape must be " - "4-dimensional")); + OP_REQUIRES(context, input_mkl_shape.GetDimension() == 4, + errors::InvalidArgument("Input shape must be " + "4-dimensional")); } } // .Input("value: T") @@ -267,66 +267,58 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase { const int kOutputTensorIndexOutput = 0; }; // MklPoolingForwardBaseOp - template class MklPoolingBackwardOpBase : public MklPoolingOpBase { public: explicit MklPoolingBackwardOpBase(OpKernelConstruction* context) - : MklPoolingOpBase(context) { } + : MklPoolingOpBase(context) {} void Compute(OpKernelContext* context) override = 0; protected: const int kOutputTensorIndexOutput = 0; - void AllocateOutputTensor(OpKernelContext* context, - const pooling_backward::primitive_desc& pool_bkwd_prim_desc, - const memory::dims output_dims_mkl_order, - const memory::format& output_tf_format, - Tensor** output_tensor) { + void AllocateOutputTensor( + OpKernelContext* context, + const pooling_backward::primitive_desc& pool_bkwd_prim_desc, + const memory::dims output_dims_mkl_order, + const memory::format& output_tf_format, Tensor** output_tensor) { CHECK_NOTNULL(output_tensor); - memory::primitive_desc dst_pd - = pool_bkwd_prim_desc.diff_src_primitive_desc(); + memory::primitive_desc dst_pd = + pool_bkwd_prim_desc.diff_src_primitive_desc(); MklDnnShape output_mkl_shape; output_mkl_shape.SetMklTensor(true); output_mkl_shape.SetMklLayout(&dst_pd); output_mkl_shape.SetElemType(MklDnnType()); output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), - output_dims_mkl_order, - output_tf_format); + output_dims_mkl_order, output_tf_format); TensorShape output_tf_shape; output_tf_shape.AddDim(this->GetNumTElements(dst_pd)); - AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, - output_tensor, - output_tf_shape, output_mkl_shape); + AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor, + output_tf_shape, output_mkl_shape); CHECK_NOTNULL(*output_tensor); } void PrepareAndExecuteNet( - const pooling_backward::primitive_desc& pool_bkwd_desc, - MklDnnData* input_gradient_diff_dst, - MklDnnData* output_diff_src, - const memory::primitive_desc& target_diff_dst_pd, - const MklDnnData* workspace = nullptr) { - + const pooling_backward::primitive_desc& pool_bkwd_desc, + MklDnnData* input_gradient_diff_dst, MklDnnData* output_diff_src, + const memory::primitive_desc& target_diff_dst_pd, + const MklDnnData* workspace = nullptr) { std::vector net; // If the input gradient isn't in the same format as the output // reorder it to the same format as the output - input_gradient_diff_dst->CheckReorderToOpMem( - target_diff_dst_pd, - &net); + input_gradient_diff_dst->CheckReorderToOpMem(target_diff_dst_pd, &net); // Create pooling primitive and add it to net if (nullptr == workspace) { net.push_back(pooling_backward(pool_bkwd_desc, - input_gradient_diff_dst->GetOpMem(), - output_diff_src->GetOpMem())); + input_gradient_diff_dst->GetOpMem(), + output_diff_src->GetOpMem())); } else { - net.push_back(pooling_backward(pool_bkwd_desc, - input_gradient_diff_dst->GetOpMem(), - workspace->GetOpMem(), - output_diff_src->GetOpMem())); + net.push_back( + pooling_backward(pool_bkwd_desc, input_gradient_diff_dst->GetOpMem(), + workspace->GetOpMem(), output_diff_src->GetOpMem())); } stream(stream::kind::eager).submit(net).wait(); } @@ -334,80 +326,76 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase { // Max Pooling and Avg Pooling have slightly different implementations // Takes the Tensor containing original input data and the original // mkl Dnn Shape and populates other data - memory::desc ConfigureOriginalInput(OpKernelContext* context, - const Tensor& tensor_original_input_shape, - const MklDnnShape& original_input_mkl_shape, - memory::dims* original_input_dims_nchw, - MklPoolParameters* pool_params, - const TensorShape& input_tensor_shape) { + memory::desc ConfigureOriginalInput( + OpKernelContext* context, const Tensor& tensor_original_input_shape, + const MklDnnShape& original_input_mkl_shape, + memory::dims* original_input_dims_nchw, MklPoolParameters* pool_params, + const TensorShape& input_tensor_shape) { CHECK_NOTNULL(original_input_dims_nchw); CHECK_NOTNULL(pool_params); - this->InitMklPoolParameters(context, pool_params, - original_input_mkl_shape, - input_tensor_shape); - - *original_input_dims_nchw - = original_input_mkl_shape.IsMklTensor() - ? original_input_mkl_shape.GetSizesAsMklDnnDims() - : TFShapeToMklDnnDimsInNCHW(input_tensor_shape, - this->data_format_tf_); - - return original_input_mkl_shape.IsMklTensor() - ? original_input_mkl_shape.GetMklLayout() - : memory::desc(*original_input_dims_nchw, - MklDnnType(), - this->data_format_mkldnn_); + this->InitMklPoolParameters(context, pool_params, original_input_mkl_shape, + input_tensor_shape); + + *original_input_dims_nchw = + original_input_mkl_shape.IsMklTensor() + ? original_input_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(input_tensor_shape, + this->data_format_tf_); + + return original_input_mkl_shape.IsMklTensor() + ? original_input_mkl_shape.GetMklLayout() + : memory::desc(*original_input_dims_nchw, MklDnnType(), + this->data_format_mkldnn_); } - memory::desc ConfigureOriginalOutput(const MklPoolParameters& pool_params, - const MklDnnShape& original_output_mkl_shape, - memory::dims output_dims_mkl_order) { + memory::desc ConfigureOriginalOutput( + const MklPoolParameters& pool_params, + const MklDnnShape& original_output_mkl_shape, + memory::dims output_dims_mkl_order) { this->GetOutputDims(pool_params, &output_dims_mkl_order); return original_output_mkl_shape.IsMklTensor() - ? original_output_mkl_shape.GetMklLayout() - : memory::desc(output_dims_mkl_order, - MklDnnType(), - this->data_format_mkldnn_); + ? original_output_mkl_shape.GetMklLayout() + : memory::desc(output_dims_mkl_order, MklDnnType(), + this->data_format_mkldnn_); } memory::desc ConfigureInputGradient( - const MklDnnShape& input_gradient_mkl_shape, - const Tensor& input_gradient_tensor, - MklDnnData* input_gradient_dnn_data, - const memory::desc& original_output_md) { + const MklDnnShape& input_gradient_mkl_shape, + const Tensor& input_gradient_tensor, + MklDnnData* input_gradient_dnn_data, + const memory::desc& original_output_md) { // Configure the gradient as is - memory::desc original_input_grad_md - = input_gradient_mkl_shape.IsMklTensor() - ? input_gradient_mkl_shape.GetMklLayout() - : memory::desc(TFShapeToMklDnnDimsInNCHW( - input_gradient_tensor.shape(), - this->data_format_tf_), - MklDnnType(), this->data_format_mkldnn_); + memory::desc original_input_grad_md = + input_gradient_mkl_shape.IsMklTensor() + ? input_gradient_mkl_shape.GetMklLayout() + : memory::desc( + TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(), + this->data_format_tf_), + MklDnnType(), this->data_format_mkldnn_); input_gradient_dnn_data->SetUsrMem(original_input_grad_md, - &input_gradient_tensor); + &input_gradient_tensor); // Check to see if input grad diff dst is in the right format // Create a new memory descriptor with the same shape as the // original, but the format of the other tensors. memory::format original_output_format = - static_cast(original_output_md.data.format); - bool grad_reorder_needed = input_gradient_dnn_data->IsReorderNeeded( - original_output_format); - memory::dims diff_dst_dims = input_gradient_mkl_shape.IsMklTensor() - ? input_gradient_mkl_shape.GetSizesAsMklDnnDims() - : TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(), - this->data_format_tf_); - memory::desc target_diff_dst_md = memory::desc(diff_dst_dims, - MklDnnType(), original_output_format); - - return grad_reorder_needed - ? target_diff_dst_md - : original_input_grad_md; + static_cast(original_output_md.data.format); + bool grad_reorder_needed = + input_gradient_dnn_data->IsReorderNeeded(original_output_format); + memory::dims diff_dst_dims = + input_gradient_mkl_shape.IsMklTensor() + ? input_gradient_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(), + this->data_format_tf_); + memory::desc target_diff_dst_md = + memory::desc(diff_dst_dims, MklDnnType(), original_output_format); + + return grad_reorder_needed ? target_diff_dst_md : original_input_grad_md; } }; -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML //------------------------------------------------------------------- // Utility functions diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index dc899d8c7ee231af403e6ca98ca60d94f78d0a81..51db3991e2a24f087771f571cd91fc9fbb26040b 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -16,29 +16,29 @@ limitations under the License. // See docs in ../ops/nn_ops.cc. #ifdef INTEL_MKL +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/platform/default/logging.h" -#include "tensorflow/core/util/mkl_util.h" #include "mkl_dnn.h" #include "mkl_dnn_types.h" +#include "tensorflow/core/platform/default/logging.h" +#include "tensorflow/core/util/mkl_util.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" -using mkldnn::stream; -using mkldnn::prop_kind; using mkldnn::algorithm; -using mkldnn::relu_forward; -using mkldnn::relu_backward; -using mkldnn::eltwise_relu; using mkldnn::eltwise_elu; +using mkldnn::eltwise_relu; using mkldnn::eltwise_tanh; +using mkldnn::prop_kind; +using mkldnn::relu_backward; +using mkldnn::relu_forward; +using mkldnn::stream; #endif namespace tensorflow { @@ -58,7 +58,7 @@ struct MklReluHelpers { } }; -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML template class MklReluOp : public OpKernel { @@ -180,7 +180,6 @@ class MklReluOp : public OpKernel { } MklReluOpContext; }; - template class MklReluGradOp : public OpKernel { public: @@ -214,10 +213,11 @@ class MklReluGradOp : public OpKernel { if (!dnnLayoutCompare_F32(lt_input, lt_grad)) { AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad, &mkl_buffer_convert); - CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, - lt_grad), E_SUCCESS); + CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad), + E_SUCCESS); CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input, - mkl_buffer_convert), E_SUCCESS); + mkl_buffer_convert), + E_SUCCESS); relu_res[dnnResourceSrc] = mkl_buffer_convert; dnnDelete_F32(cv_input_to_grad); } else { @@ -325,7 +325,8 @@ void MklReluGradOp::Compute(OpKernelContext* context) { float negative_slope = 0.0; CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL, mkl_context.lt_grad, mkl_context.lt_grad, - negative_slope), E_SUCCESS); + negative_slope), + E_SUCCESS); Tensor mkl_tmp_input_buf_tensor; mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_input_buf_tensor); @@ -348,7 +349,8 @@ void MklReluGradOp::Compute(OpKernelContext* context) { } tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast( - mkl_context.output_shape.GetMklLayout())) / sizeof(T)); + mkl_context.output_shape.GetMklLayout())) / + sizeof(T)); AllocateOutputSetMklShape(context, 0, &output, tf_shape, mkl_context.output_shape); } else { @@ -361,22 +363,19 @@ void MklReluGradOp::Compute(OpKernelContext* context) { mkl_context.relu_res[dnnResourceDiffSrc] = static_cast(output->flat().data()); - CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_bwd, - mkl_context.relu_res), - E_SUCCESS); + CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_bwd, mkl_context.relu_res), + E_SUCCESS); mkl_context.MklCleanup(); } - -#else // INTEL_MKL_DNN +#else // INTEL_MKL_ML template class MklReluOpBase : public OpKernel { public: ~MklReluOpBase() {} - explicit MklReluOpBase(OpKernelConstruction* context) : OpKernel(context) { - } + explicit MklReluOpBase(OpKernelConstruction* context) : OpKernel(context) {} virtual void Compute_Scalar(OpKernelContext* context) = 0; @@ -413,12 +412,12 @@ class MklReluOpBase : public OpKernel { T alpha = 0, beta = 0; std::shared_ptr relu_fwd_pd; - auto relu_fwd_desc = relu_forward::desc(prop_kind::forward_training, + auto relu_fwd_desc = relu_forward::desc( + prop_kind::forward_training, // Operator memory descriptor is same as user memory descriptor. - alg_kind, src.GetUsrMemDesc(), - alpha, beta); - relu_fwd_pd.reset(new relu_forward::primitive_desc(relu_fwd_desc, - cpu_engine)); + alg_kind, src.GetUsrMemDesc(), alpha, beta); + relu_fwd_pd.reset( + new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine)); // allocate dst tensor MklDnnShape dnn_shape_dst; @@ -431,7 +430,7 @@ class MklReluOpBase : public OpKernel { dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), dnn_shape_src.GetSizesAsMklDnnDims(), dnn_shape_src.GetTfDataFormat()); - tf_shape_dst.AddDim(dst_pd.get_size()/sizeof(T)); + tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); } else { dnn_shape_dst.SetMklTensor(false); tf_shape_dst = src_tensor.shape(); @@ -445,34 +444,32 @@ class MklReluOpBase : public OpKernel { // execute net std::vector net; - auto relu_fwd = relu_forward(*relu_fwd_pd, src.GetOpMem(), - dst.GetOpMem()); + auto relu_fwd = + relu_forward(*relu_fwd_pd, src.GetOpMem(), dst.GetOpMem()); net.push_back(relu_fwd); stream(stream::kind::eager).submit(net).wait(); - } catch (mkldnn::error &e) { + } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Operation received an exception:", - error_msg)); + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } }; - template class MklReluGradOpBase : public OpKernel { public: ~MklReluGradOpBase() {} - explicit MklReluGradOpBase(OpKernelConstruction* context) : - OpKernel(context) {} + explicit MklReluGradOpBase(OpKernelConstruction* context) + : OpKernel(context) {} virtual void Compute_Scalar(OpKernelContext* context) = 0; - void Compute(OpKernelContext* context) { + void Compute(OpKernelContext* context) { try { auto cpu_engine = engine(engine::cpu, 0); MklDnnData src(&cpu_engine); @@ -483,9 +480,9 @@ class MklReluGradOpBase : public OpKernel { const size_t src_index = 1; // index of src input tensor const size_t diff_src_index = 0; // index of diff_src output tensor - const Tensor& src_tensor = MklGetInput(context, src_index); + const Tensor& src_tensor = MklGetInput(context, src_index); const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index); - Tensor* diff_src_tensor = nullptr; + Tensor* diff_src_tensor = nullptr; MklDnnShape dnn_shape_src, dnn_shape_diff_dst; GetMklShape(context, src_index, &dnn_shape_src); @@ -526,25 +523,25 @@ class MklReluGradOpBase : public OpKernel { src_md = dnn_shape_src.GetMklLayout(); memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat(); - auto src_tf_data_format = MklDnnDataFormatToTFDataFormat( - src_mkl_data_format); + auto src_tf_data_format = + MklDnnDataFormatToTFDataFormat(src_mkl_data_format); auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), src_tf_data_format); - diff_dst_md = memory::desc(diff_dst_dims, MklDnnType(), - src_mkl_data_format); + diff_dst_md = + memory::desc(diff_dst_dims, MklDnnType(), src_mkl_data_format); } else if (!dnn_shape_src.IsMklTensor() && - dnn_shape_diff_dst.IsMklTensor()) { + dnn_shape_diff_dst.IsMklTensor()) { // Same comment as above. diff_dst_md = dnn_shape_diff_dst.GetMklLayout(); memory::format diff_dst_mkl_data_format = - dnn_shape_diff_dst.GetTfDataFormat(); - auto diff_dst_tf_data_format = MklDnnDataFormatToTFDataFormat( - diff_dst_mkl_data_format); + dnn_shape_diff_dst.GetTfDataFormat(); + auto diff_dst_tf_data_format = + MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format); auto src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), diff_dst_tf_data_format); - src_md = memory::desc(src_dims, MklDnnType(), - diff_dst_mkl_data_format); + src_md = + memory::desc(src_dims, MklDnnType(), diff_dst_mkl_data_format); } else { // If both the inputs are in MKL format, we use Mkl layout of the input // tensors. @@ -572,12 +569,12 @@ class MklReluGradOpBase : public OpKernel { std::shared_ptr relu_fwd_pd; auto relu_fwd_desc = relu_forward::desc(prop_kind::forward_training, alg_kind, src_md, alpha, beta); - relu_fwd_pd.reset(new relu_forward::primitive_desc(relu_fwd_desc, - cpu_engine)); - auto relu_bwd_desc = relu_backward::desc(alg_kind, common_md, common_md, - alpha, beta); - auto relu_bwd_pd = relu_backward::primitive_desc(relu_bwd_desc, - cpu_engine, *relu_fwd_pd); + relu_fwd_pd.reset( + new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine)); + auto relu_bwd_desc = + relu_backward::desc(alg_kind, common_md, common_md, alpha, beta); + auto relu_bwd_pd = relu_backward::primitive_desc( + relu_bwd_desc, cpu_engine, *relu_fwd_pd); // allocate diff_src tensor MklDnnShape dnn_shape_diff_src; @@ -590,33 +587,32 @@ class MklReluGradOpBase : public OpKernel { dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), dnn_shape_src.GetSizesAsMklDnnDims(), dnn_shape_src.GetTfDataFormat()); - tf_shape_diff_src.AddDim(diff_src_pd.get_size()/sizeof(T)); + tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); } else { dnn_shape_diff_src.SetMklTensor(false); tf_shape_diff_src = src_tensor.shape(); } AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor, - tf_shape_diff_src, dnn_shape_diff_src); + tf_shape_diff_src, dnn_shape_diff_src); // diff_src memory descriptor is same as memory descriptor for both // inputs. diff_src.SetUsrMem(common_md, diff_src_tensor); PrepareAndExecuteNet(relu_bwd_pd, &src, &diff_src, &diff_dst); - } catch (mkldnn::error &e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Operation received an exception:", - error_msg)); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } void PrepareAndExecuteNet(const relu_backward::primitive_desc& relu_prim_desc, - MklDnnData* src, MklDnnData* diff_src, MklDnnData* - diff_dst) { + MklDnnData* src, MklDnnData* diff_src, + MklDnnData* diff_dst) { std::vector net; // Check if we need to reorder original input tensors into common_md layout @@ -632,14 +628,13 @@ class MklReluGradOpBase : public OpKernel { } }; - template class MklReluOp : public MklReluOpBase { public: ~MklReluOp() {} - explicit MklReluOp(OpKernelConstruction* context) : - MklReluOpBase(context) {} + explicit MklReluOp(OpKernelConstruction* context) + : MklReluOpBase(context) {} virtual void Compute_Scalar(OpKernelContext* context) { const size_t src_index = 0; // index of src input tensor @@ -649,15 +644,15 @@ class MklReluOp : public MklReluOpBase { GetMklShape(context, src_index, &dnn_shape_src); Tensor* dst_tensor = nullptr; - void* user_i = static_cast(const_cast( - src_tensor.flat().data())); + void* user_i = + static_cast(const_cast(src_tensor.flat().data())); MklDnnShape dnn_shape_dst; dnn_shape_dst.SetMklTensor(false); AllocateOutputSetMklShape(context, dst_index, &dst_tensor, src_tensor.shape(), dnn_shape_dst); void* out_o = static_cast(dst_tensor->flat().data()); (static_cast(out_o))[0] = - std::max((static_cast(user_i))[0], static_cast(0)); + std::max((static_cast(user_i))[0], static_cast(0)); return; } }; @@ -667,14 +662,14 @@ class MklReluGradOp : public MklReluGradOpBase { public: ~MklReluGradOp() {} - explicit MklReluGradOp(OpKernelConstruction* context) : - MklReluGradOpBase(context) {} + explicit MklReluGradOp(OpKernelConstruction* context) + : MklReluGradOpBase(context) {} virtual void Compute_Scalar(OpKernelContext* context) { const size_t diff_dst_index = 0; // index of diff_dst input tensor const size_t src_index = 1; // index of src input tensor const size_t diff_src_index = 0; // index of diff_src output tensor - const Tensor& src_tensor = MklGetInput(context, src_index); + const Tensor& src_tensor = MklGetInput(context, src_index); const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index); Tensor* diff_src_tensor = nullptr; @@ -687,11 +682,11 @@ class MklReluGradOp : public MklReluGradOpBase { diff_dst_tensor.shape(), dnn_shape_diff_src); void* out_o = static_cast(diff_src_tensor->flat().data()); void* user_i = - static_cast(const_cast(src_tensor.flat().data())); + static_cast(const_cast(src_tensor.flat().data())); void* user_g = - static_cast(const_cast(diff_dst_tensor.flat().data())); - (static_cast(out_o))[0] = (static_cast(user_g))[0] * - ((static_cast(user_i))[0] > 0); + static_cast(const_cast(diff_dst_tensor.flat().data())); + (static_cast(out_o))[0] = + (static_cast(user_g))[0] * ((static_cast(user_i))[0] > 0); return; } }; @@ -701,8 +696,8 @@ class MklEluOp : public MklReluOpBase { public: ~MklEluOp() {} - explicit MklEluOp(OpKernelConstruction* context) : - MklReluOpBase(context) {} + explicit MklEluOp(OpKernelConstruction* context) + : MklReluOpBase(context) {} virtual void Compute_Scalar(OpKernelContext* context) { const size_t src_index = 0; // index of src input tensor @@ -712,8 +707,8 @@ class MklEluOp : public MklReluOpBase { GetMklShape(context, src_index, &dnn_shape_src); Tensor* dst_tensor = nullptr; - void* user_i = static_cast(const_cast( - src_tensor.flat().data())); + void* user_i = + static_cast(const_cast(src_tensor.flat().data())); MklDnnShape dnn_shape_dst; dnn_shape_dst.SetMklTensor(false); AllocateOutputSetMklShape(context, dst_index, &dst_tensor, @@ -734,14 +729,14 @@ class MklEluGradOp : public MklReluGradOpBase { public: ~MklEluGradOp() {} - explicit MklEluGradOp(OpKernelConstruction* context) : - MklReluGradOpBase(context) {} + explicit MklEluGradOp(OpKernelConstruction* context) + : MklReluGradOpBase(context) {} virtual void Compute_Scalar(OpKernelContext* context) { const size_t diff_dst_index = 0; // index of diff_dst input tensor const size_t src_index = 1; // index of src input tensor const size_t diff_src_index = 0; // index of diff_src output tensor - const Tensor& src_tensor = MklGetInput(context, src_index); + const Tensor& src_tensor = MklGetInput(context, src_index); const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index); Tensor* diff_src_tensor = nullptr; @@ -754,9 +749,9 @@ class MklEluGradOp : public MklReluGradOpBase { diff_dst_tensor.shape(), dnn_shape_diff_src); void* out_o = static_cast(diff_src_tensor->flat().data()); void* user_i = - static_cast(const_cast(src_tensor.flat().data())); + static_cast(const_cast(src_tensor.flat().data())); void* user_g = - static_cast(const_cast(diff_dst_tensor.flat().data())); + static_cast(const_cast(diff_dst_tensor.flat().data())); // gradient of elu(x) = 1 if x > 0; elu(x) + 1 otherwise T feature = (static_cast(user_i))[0]; if (feature > 0) { @@ -773,8 +768,8 @@ class MklTanhOp : public MklReluOpBase { public: ~MklTanhOp() {} - explicit MklTanhOp(OpKernelConstruction* context) : - MklReluOpBase(context) {} + explicit MklTanhOp(OpKernelConstruction* context) + : MklReluOpBase(context) {} virtual void Compute_Scalar(OpKernelContext* context) { const size_t src_index = 0; // index of src input tensor @@ -784,8 +779,8 @@ class MklTanhOp : public MklReluOpBase { GetMklShape(context, src_index, &dnn_shape_src); Tensor* dst_tensor = nullptr; - void* user_i = static_cast(const_cast( - src_tensor.flat().data())); + void* user_i = + static_cast(const_cast(src_tensor.flat().data())); MklDnnShape dnn_shape_dst; dnn_shape_dst.SetMklTensor(false); AllocateOutputSetMklShape(context, dst_index, &dst_tensor, @@ -795,7 +790,7 @@ class MklTanhOp : public MklReluOpBase { T feature = (static_cast(user_i))[0]; T e1 = std::exp(feature); T e2 = std::exp(-feature); - (static_cast(out_o))[0] = (e1 - e2)/(e1 + e2); + (static_cast(out_o))[0] = (e1 - e2) / (e1 + e2); return; } }; @@ -805,14 +800,14 @@ class MklTanhGradOp : public MklReluGradOpBase { public: ~MklTanhGradOp() {} - explicit MklTanhGradOp(OpKernelConstruction* context) : - MklReluGradOpBase(context) {} + explicit MklTanhGradOp(OpKernelConstruction* context) + : MklReluGradOpBase(context) {} virtual void Compute_Scalar(OpKernelContext* context) { const size_t diff_dst_index = 0; // index of diff_dst input tensor const size_t src_index = 1; // index of src input tensor const size_t diff_src_index = 0; // index of diff_src output tensor - const Tensor& src_tensor = MklGetInput(context, src_index); + const Tensor& src_tensor = MklGetInput(context, src_index); const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index); Tensor* diff_src_tensor = nullptr; @@ -825,16 +820,16 @@ class MklTanhGradOp : public MklReluGradOpBase { diff_dst_tensor.shape(), dnn_shape_diff_src); void* out_o = static_cast(diff_src_tensor->flat().data()); void* user_i = - static_cast(const_cast(src_tensor.flat().data())); + static_cast(const_cast(src_tensor.flat().data())); // gradient of tanh(x) = 1 - tanh(x)^2 T feature = (static_cast(user_i))[0]; T e1 = std::exp(feature); T e2 = std::exp(-feature); - T tanh = (e1 - e2)/(e1 + e2); + T tanh = (e1 - e2) / (e1 + e2); void* user_g = - static_cast(const_cast(diff_dst_tensor.flat().data())); - (static_cast(out_o))[0] = (static_cast(user_g))[0] * - (1 - tanh * tanh); + static_cast(const_cast(diff_dst_tensor.flat().data())); + (static_cast(out_o))[0] = + (static_cast(user_g))[0] * (1 - tanh * tanh); } }; @@ -854,16 +849,16 @@ class MklTanhGradOp : public MklReluGradOpBase { MklReluGradOp); TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES); -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML // register dnn kernels for supported operations and supported types -#define REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES(type) \ - REGISTER_KERNEL_BUILDER(Name("_MklElu") \ +#define REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES(type) \ + REGISTER_KERNEL_BUILDER(Name("_MklElu") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ - MklEluOp); \ - REGISTER_KERNEL_BUILDER(Name("_MklEluGrad") \ + MklEluOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklEluGrad") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ @@ -888,4 +883,3 @@ TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES); } // namespace tensorflow #endif // INTEL_MKL - diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc index b41e529357b2e93570377aaf350c99e0c8f2bd3c..5dbc4a2709e2bc379ae3b9aa68ed14f3d6893e7c 100644 --- a/tensorflow/core/kernels/mkl_reshape_op.cc +++ b/tensorflow/core/kernels/mkl_reshape_op.cc @@ -28,7 +28,7 @@ limitations under the License. #include "mkl_dnn_types.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" using mkldnn::stream; #endif @@ -40,7 +40,7 @@ class MklReshapeOp : public OpKernel { public: explicit MklReshapeOp(OpKernelConstruction* context) : OpKernel(context) {} -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML void Compute(OpKernelContext* context) override { const Tensor& input = MklGetInput(context, 0); const Tensor& sizes = MklGetInput(context, 1); @@ -166,9 +166,9 @@ class MklReshapeOp : public OpKernel { MklDnnShape mkl_shape_input; GetMklShape(context, kInputSlotIdx, &mkl_shape_input); bool input_in_mkl_format = mkl_shape_input.IsMklTensor(); - const int64 nelems = input_in_mkl_format ? - mkl_shape_input.GetTfShape().num_elements() - : input_tensor.NumElements(); + const int64 nelems = input_in_mkl_format + ? mkl_shape_input.GetTfShape().num_elements() + : input_tensor.NumElements(); // Preliminary validation of sizes. OP_REQUIRES(context, IsLegacyVector(sizes.shape()), @@ -210,11 +210,11 @@ class MklReshapeOp : public OpKernel { product)); shape.set_dim(unknown_index, missing); } - OP_REQUIRES(context, shape.num_elements() == nelems, - errors::InvalidArgument("Input to reshape is a tensor with ", - nelems, - " values, but the requested shape has ", - shape.num_elements())); + OP_REQUIRES( + context, shape.num_elements() == nelems, + errors::InvalidArgument("Input to reshape is a tensor with ", nelems, + " values, but the requested shape has ", + shape.num_elements())); if (input_in_mkl_format) { TensorShape& shape_to = shape; @@ -237,38 +237,38 @@ class MklReshapeOp : public OpKernel { // need to update MklDnnShape object associated with the input // tensor to reflect the shape change expected by reshape. if (!SkipReorder(mkl_shape_input, shape_to)) { - // If dimensions that are being expanded or collapsed are not - // maintained contiguously by MKLDNN, then we use reorder. - - // Get Mkl layout of input tensor. - auto input_mkl_md = mkl_shape_input.GetMklLayout(); - // Set input Mkl layout as the user layout. - dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor); - // Get expected Tensorflow layout of input tensor. - auto output_tf_md = mkl_shape_input.GetTfLayout(); - auto output_tf_pd = memory::primitive_desc(output_tf_md, - cpu_engine); - - Tensor* output_tensor = nullptr; - MklShape mkl_shape_output; - mkl_shape_output.SetMklTensor(false); - // We allocate output tensor in the shape expected by Reshape. - AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor, - shape_to, mkl_shape_output); - - // Insert reorder between Mkl layout and TensorFlow layout if - // needed. If reorder is not needed but reshape is needed (since - // shape_from != shape_to), then we just copy input tensor to - // output tensor with target shape (we cannot forward Mkl layout - // in such case because shape has changed.) - std::vector net; - if (dnn_data_input.CheckReorderToOpMem(output_tf_pd, - output_tensor, &net)) { - stream(stream::kind::eager).submit(net).wait(); - } else { - output_tensor->CopyFrom(input_tensor, shape_to); - } - return; + // If dimensions that are being expanded or collapsed are not + // maintained contiguously by MKLDNN, then we use reorder. + + // Get Mkl layout of input tensor. + auto input_mkl_md = mkl_shape_input.GetMklLayout(); + // Set input Mkl layout as the user layout. + dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor); + // Get expected Tensorflow layout of input tensor. + auto output_tf_md = mkl_shape_input.GetTfLayout(); + auto output_tf_pd = + memory::primitive_desc(output_tf_md, cpu_engine); + + Tensor* output_tensor = nullptr; + MklShape mkl_shape_output; + mkl_shape_output.SetMklTensor(false); + // We allocate output tensor in the shape expected by Reshape. + AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor, + shape_to, mkl_shape_output); + + // Insert reorder between Mkl layout and TensorFlow layout if + // needed. If reorder is not needed but reshape is needed (since + // shape_from != shape_to), then we just copy input tensor to + // output tensor with target shape (we cannot forward Mkl layout + // in such case because shape has changed.) + std::vector net; + if (dnn_data_input.CheckReorderToOpMem(output_tf_pd, output_tensor, + &net)) { + stream(stream::kind::eager).submit(net).wait(); + } else { + output_tensor->CopyFrom(input_tensor, shape_to); + } + return; } else { // If dimensions that are being expanded or collapsed are // maintained contiguously by MKLDNN, then we skip reorder, just @@ -276,10 +276,10 @@ class MklReshapeOp : public OpKernel { // Tensorflow tensor as it is to the output. auto output_dims = TFShapeToMklDnnDims(shape_to); auto output_strides = CalculateTFStrides(output_dims); - auto output_tf_md = MklDnnData::CreateBlockedMemDesc(output_dims, - output_strides); - auto output_tf_pd = memory::primitive_desc(output_tf_md, - cpu_engine); + auto output_tf_md = MklDnnData::CreateBlockedMemDesc( + output_dims, output_strides); + auto output_tf_pd = + memory::primitive_desc(output_tf_md, cpu_engine); // Set MklDnnShape MklDnnShape mkl_shape_output; @@ -291,18 +291,17 @@ class MklReshapeOp : public OpKernel { // We now simply forward input Mkl tensor to output and change its // output MklDnnShape object. - ForwardMklTensorInToOutWithMklShape(context, kInputSlotIdx, - kOutputSlotIdx, mkl_shape_output); + ForwardMklTensorInToOutWithMklShape( + context, kInputSlotIdx, kOutputSlotIdx, mkl_shape_output); return; } - } catch (mkldnn::error &e) { + } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + - ", in file " + string(__FILE__) + ":" + - std::to_string(__LINE__); - OP_REQUIRES_OK(context, - errors::Aborted("Operation received an exception:", - error_msg)); + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } } else { @@ -313,7 +312,7 @@ class MklReshapeOp : public OpKernel { } } -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML private: const int kInputSlotIdx = 0; diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc index c46eabdde103913a712c3d058aa23a627d19f5ea..aceef1e234eff3660b33f5a091a2cd10e25ea2f9 100644 --- a/tensorflow/core/kernels/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl_softmax_op.cc @@ -15,7 +15,7 @@ limitations under the License. // See docs in ../ops/nn_ops.cc. #ifdef INTEL_MKL -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" @@ -156,5 +156,5 @@ TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES); } // namespace tensorflow -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h index c4d5a45d3caff0f59b1ecc61f95dd26fe16fd06b..ddea9e281b2fbcf0e061fd2bf2758984833a3727 100644 --- a/tensorflow/core/kernels/mkl_tfconv_op.h +++ b/tensorflow/core/kernels/mkl_tfconv_op.h @@ -35,7 +35,7 @@ limitations under the License. #include "mkl_dnn_types.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML using mkldnn::stream; #endif @@ -61,7 +61,7 @@ class MklToTfOp : public OpKernel { VLOG(1) << "MKLToTFConversion complete successfully."; } -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context, string data_format_str, DataType op_data_type, bool has_avx512f, uint input_number) { @@ -128,7 +128,7 @@ class MklToTfOp : public OpKernel { #else static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context, string data_format_str, DataType op_data_type, - bool has_avx512f, uint input_number) { + bool has_avx512f, uint32 input_number) { // Check that input tensor is in MKL format. const Tensor& input_tensor = MklGetInput(context, input_number); MklShape input_shape; diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc index 764d4c9400e5751de29b9651eebc1328fdd09d59..b44b4d6f542ed0128a83d20eedf6629f67427867 100644 --- a/tensorflow/core/kernels/mkl_transpose_op.cc +++ b/tensorflow/core/kernels/mkl_transpose_op.cc @@ -18,9 +18,6 @@ limitations under the License. #ifdef INTEL_MKL #define EIGEN_USE_THREADS -#include "tensorflow/core/framework/numeric_types.h" -#define MKL_Complex8 tensorflow::complex64 -#define MKL_Complex16 tensorflow::complex128 #include "mkl_trans.h" #include "tensorflow/core/kernels/transpose_functor.h" #include "tensorflow/core/kernels/transpose_op.h" @@ -62,10 +59,31 @@ Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out); INSTANTIATE(float, s) INSTANTIATE(double, d) -INSTANTIATE(complex64, c) -INSTANTIATE(complex128, z) + #undef INSTANTIATE +template <> +Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) { + const MKL_Complex8 alpha = { 1.0f, 0.0f }; + mkl_comatcopy('R', trans, in.dim_size(0), in.dim_size(1), alpha, + reinterpret_cast(in.flat().data()), + in.dim_size(1), + reinterpret_cast(const_cast(out->flat().data())), + in.dim_size(0)); + return Status::OK(); +} + +template <> +Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) { + const MKL_Complex16 alpha = { 1.0, 0.0 }; + mkl_zomatcopy('R', trans, in.dim_size(0), in.dim_size(1), alpha, + reinterpret_cast(in.flat().data()), + in.dim_size(1), + reinterpret_cast(const_cast(out->flat().data())), + in.dim_size(0)); + return Status::OK(); +} + static const char kMKLTranspose = 'T'; static const char kMKLConjugateTranspose = 'C'; diff --git a/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc b/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc index 17f2af550f248a6924bb3d1e7546eca84d4c1e51..0e820bbb6208ae9c13ac2fb33f67590b9e66ba7e 100644 --- a/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc +++ b/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc @@ -71,10 +71,10 @@ class NeonDepthwiseConv2dNativeOp : public BinaryOp { filter.shape().DebugString())); const int32 in_depth = input.dim_size(3); - OP_REQUIRES( - context, in_depth == filter.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - in_depth, " vs ", filter.dim_size(2))); + OP_REQUIRES(context, in_depth == filter.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, + " vs ", filter.dim_size(2))); const int32 batch = input.dim_size(0); const int32 input_rows = input.dim_size(1); const int32 input_cols = input.dim_size(2); diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index 64bdef0008f20f3947e990e30e2af7b93a69d50c..903b898d0ac850e88c216cb1cc266cdb29fb4ca7 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -92,13 +92,11 @@ static inline bool IOUGreaterThanThreshold( return iou > iou_threshold; } -void DoNonMaxSuppressionOp(OpKernelContext* context, - const Tensor& boxes, - const Tensor& scores, - const Tensor& max_output_size, +void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes, + const Tensor& scores, const Tensor& max_output_size, const float iou_threshold) { OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1, - errors::InvalidArgument("iou_threshold must be in [0, 1]")); + errors::InvalidArgument("iou_threshold must be in [0, 1]")); int num_boxes = 0; ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes); @@ -106,10 +104,8 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, return; } - const int output_size = - std::min(max_output_size.scalar()(), num_boxes); - typename TTypes::ConstTensor boxes_data = - boxes.tensor(); + const int output_size = std::min(max_output_size.scalar()(), num_boxes); + TTypes::ConstTensor boxes_data = boxes.tensor(); std::vector scores_data(num_boxes); std::copy_n(scores.flat().data(), num_boxes, scores_data.begin()); @@ -142,8 +138,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, Tensor* output = nullptr; TensorShape output_shape({static_cast(selected.size())}); OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - typename TTypes::Tensor selected_indices_data = - output->tensor(); + TTypes::Tensor selected_indices_data = output->tensor(); std::copy_n(selected.begin(), selected.size(), selected_indices_data.data()); } @@ -181,8 +176,7 @@ template class NonMaxSuppressionV2Op : public OpKernel { public: explicit NonMaxSuppressionV2Op(OpKernelConstruction* context) - : OpKernel(context) { - } + : OpKernel(context) {} void Compute(OpKernelContext* context) override { // boxes: [num_boxes, 4] @@ -197,10 +191,9 @@ class NonMaxSuppressionV2Op : public OpKernel { max_output_size.shape().DebugString())); // iou_threshold: scalar const Tensor& iou_threshold = context->input(3); - OP_REQUIRES( - context, TensorShapeUtils::IsScalar(iou_threshold.shape()), - errors::InvalidArgument("iou_threshold must be 0-D, got shape ", - iou_threshold.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), + errors::InvalidArgument("iou_threshold must be 0-D, got shape ", + iou_threshold.shape().DebugString())); const float iou_threshold_val = iou_threshold.scalar()(); diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc index fdbcf05b89ddf122eee9e0133651355edbb1ba5a..67d9217b9502a30f5727b6a91fbf36da872ab972 100644 --- a/tensorflow/core/kernels/non_max_suppression_op_test.cc +++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc @@ -43,9 +43,10 @@ class NonMaxSuppressionOpTest : public OpsTestBase { TEST_F(NonMaxSuppressionOpTest, TestSelectFromThreeClusters) { MakeOp(.5); - AddInputFromArray(TensorShape({6, 4}), - {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, - 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); AddInputFromArray(TensorShape({}), {3}); TF_ASSERT_OK(RunOpKernel()); @@ -58,7 +59,7 @@ TEST_F(NonMaxSuppressionOpTest, TestSelectFromThreeClusters) { TEST_F(NonMaxSuppressionOpTest, TestSelectFromThreeClustersFlippedCoordinates) { MakeOp(.5); AddInputFromArray(TensorShape({6, 4}), - {1, 1, 0, 0, 0, 0.1f, 1, 1.1f, 0, .9f, 1, -0.1f, + {1, 1, 0, 0, 0, 0.1f, 1, 1.1f, 0, .9f, 1, -0.1f, 0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100}); AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); AddInputFromArray(TensorShape({}), {3}); @@ -71,9 +72,10 @@ TEST_F(NonMaxSuppressionOpTest, TestSelectFromThreeClustersFlippedCoordinates) { TEST_F(NonMaxSuppressionOpTest, TestSelectAtMostTwoBoxesFromThreeClusters) { MakeOp(.5); - AddInputFromArray(TensorShape({6, 4}), - {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, - 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); AddInputFromArray(TensorShape({}), {2}); TF_ASSERT_OK(RunOpKernel()); @@ -85,9 +87,10 @@ TEST_F(NonMaxSuppressionOpTest, TestSelectAtMostTwoBoxesFromThreeClusters) { TEST_F(NonMaxSuppressionOpTest, TestSelectAtMostThirtyBoxesFromThreeClusters) { MakeOp(.5); - AddInputFromArray(TensorShape({6, 4}), - {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, - 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); AddInputFromArray(TensorShape({}), {30}); TF_ASSERT_OK(RunOpKernel()); @@ -134,9 +137,10 @@ TEST_F(NonMaxSuppressionOpTest, TestSelectFromTenIdenticalBoxes) { TEST_F(NonMaxSuppressionOpTest, TestInconsistentBoxAndScoreShapes) { MakeOp(.5); - AddInputFromArray(TensorShape({6, 4}), - {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, - 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); AddInputFromArray(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f}); AddInputFromArray(TensorShape({}), {30}); Status s = RunOpKernel(); diff --git a/tensorflow/core/kernels/nth_element_op.cc b/tensorflow/core/kernels/nth_element_op.cc index da825e408c24617862e8613c6b63ed1a51944041..7f12eb953a31ec667a5f3cee379bd3d1970b3a56 100644 --- a/tensorflow/core/kernels/nth_element_op.cc +++ b/tensorflow/core/kernels/nth_element_op.cc @@ -16,15 +16,15 @@ limitations under the License. // See docs in ../ops/nn_ops.cc. #include "tensorflow/core/kernels/nth_element_op.h" +#include +#include +#include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/work_sharder.h" -#include -#include -#include namespace tensorflow { @@ -54,8 +54,9 @@ class NthElementOp : public OpKernel { errors::InvalidArgument("Input must be >= 1-D, got shape ", input_in.shape().DebugString())); // The last dimension of input tensor must be greater than N. - OP_REQUIRES(context, input_in.dim_size(num_dims-1) > n, - errors::InvalidArgument("Input must have at least n+1 columns")); + OP_REQUIRES( + context, input_in.dim_size(num_dims - 1) > n, + errors::InvalidArgument("Input must have at least n+1 columns")); // std::nth_element only support the nth-smallest selection. if (reverse_) { @@ -64,7 +65,7 @@ class NthElementOp : public OpKernel { // Assume input_shape is [d1,d2,...dk], and output_shape is [d1,d2...dk-1]. TensorShape out_shape; - for (int i = 0; i < num_dims-1; ++i) { + for (int i = 0; i < num_dims - 1; ++i) { out_shape.AddDim(input_in.dim_size(i)); } Tensor* output_tensor = nullptr; @@ -83,32 +84,28 @@ namespace functor { template struct NthElementFunctor { - void operator() (OpKernelContext* context, - const Tensor& input_tensor, - Tensor& output_tensor, - int n, - bool reverse) { + void operator()(OpKernelContext* context, const Tensor& input_tensor, + Tensor& output_tensor, int n, bool reverse) { const T* input = input_tensor.flat().data(); T* output = output_tensor.flat().data(); // Assume input_shape is [d1,d2,...dk], and output_shape is [d1,d2...dk-1], // then num_rows = d1*d2...dk-1, last_dim = dk. const int num_rows = output_tensor.NumElements(); - const int last_dim = input_tensor.dim_size(input_tensor.dims()-1); + const int last_dim = input_tensor.dim_size(input_tensor.dims() - 1); // Allocate each row to different shard. - auto SubNthElement = [&, input, output, last_dim, n](int start, - int limit) { + auto SubNthElement = [&, input, output, last_dim, n](int start, int limit) { // std::nth_element would rearrange the array, so we need a new buffer. std::vector buf(last_dim); for (int b = start; b < limit; ++b) { // Copy from one row of elements to buffer const T* input_start = input + b * last_dim; - const T* input_end = input + (b+1) * last_dim; + const T* input_end = input + (b + 1) * last_dim; std::copy(input_start, input_end, buf.begin()); - std::nth_element(buf.begin(), buf.begin()+n, buf.end()); + std::nth_element(buf.begin(), buf.begin() + n, buf.end()); // The element placed in the nth position is exactly the element that // would occur in this position if the range was fully sorted. output[b] = buf[n]; @@ -116,9 +113,9 @@ struct NthElementFunctor { }; auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); - // The average time complexity of partition-based nth_element (BFPRT) is O(n), - // althought the worst time complexity could be O(n^2). - // Here, 20 is a empirical factor of cost_per_unit. + // The average time complexity of partition-based nth_element (BFPRT) is + // O(n), althought the worst time complexity could be O(n^2). Here, 20 is a + // empirical factor of cost_per_unit. Shard(worker_threads.num_threads, worker_threads.workers, num_rows, 20 * last_dim, SubNthElement); } @@ -126,7 +123,6 @@ struct NthElementFunctor { } // namespace functor - #define REGISTER_NTHOP(T) \ REGISTER_KERNEL_BUILDER( \ Name("NthElement").Device(DEVICE_CPU).TypeConstraint("T"), \ @@ -136,4 +132,3 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_NTHOP); #undef REGISTER_NTHOP } // end namespace tensorflow - diff --git a/tensorflow/core/kernels/nth_element_op.h b/tensorflow/core/kernels/nth_element_op.h index 11a6c996b093fa7255a230122f64eb1054789453..e7d25daecc74a6d7b178034d5d78776a390ffe04 100644 --- a/tensorflow/core/kernels/nth_element_op.h +++ b/tensorflow/core/kernels/nth_element_op.h @@ -26,10 +26,8 @@ namespace functor { template struct NthElementFunctor { - void operator() (OpKernelContext* context, - const Tensor& input_tensor, - Tensor& output_tensor, - int n); + void operator()(OpKernelContext* context, const Tensor& input_tensor, + Tensor& output_tensor, int n); }; } // namespace functor diff --git a/tensorflow/core/kernels/one_hot_op_gpu.cu.cc b/tensorflow/core/kernels/one_hot_op_gpu.cu.cc index 49fd4bdebad420d8e848b0491a764d976f4557cd..647515ae38ab5530b69fa135257584eea531d46c 100644 --- a/tensorflow/core/kernels/one_hot_op_gpu.cu.cc +++ b/tensorflow/core/kernels/one_hot_op_gpu.cu.cc @@ -19,16 +19,16 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/one_hot_op.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/one_hot_op.h" namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -#define DEFINE_GPU_SPEC_INDEX(T, TI) \ - template class generator::OneGenerator; \ +#define DEFINE_GPU_SPEC_INDEX(T, TI) \ + template class generator::OneGenerator; \ template struct functor::OneHot; #define DEFINE_GPU_SPEC(T) \ diff --git a/tensorflow/core/kernels/ops_util_test.cc b/tensorflow/core/kernels/ops_util_test.cc index 9d53882deef89230bd39d8318f11d84269406f20..13427d71ff6841a85c31d3bf42c038f6413c1fe6 100644 --- a/tensorflow/core/kernels/ops_util_test.cc +++ b/tensorflow/core/kernels/ops_util_test.cc @@ -218,7 +218,8 @@ TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_2) { // in_size = 3, ksize = 3, stride = 2, pad_size = 0 TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_0) { bcast_struct bcast[] = { - {{0, 3, 3, 2, 0}, {0, 3}}, {{1, 3, 3, 2, 0}, {2, 1}}, + {{0, 3, 3, 2, 0}, {0, 3}}, + {{1, 3, 3, 2, 0}, {2, 1}}, }; for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { VerifyBcastValues(bcast[i]); @@ -228,7 +229,8 @@ TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_0) { // in_size = 3, ksize = 3, stride = 2, pad_size = 1 TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_1) { bcast_struct bcast[] = { - {{0, 3, 3, 2, 1}, {0, 2}}, {{1, 3, 3, 2, 1}, {1, 2}}, + {{0, 3, 3, 2, 1}, {0, 2}}, + {{1, 3, 3, 2, 1}, {1, 2}}, }; for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { VerifyBcastValues(bcast[i]); @@ -258,7 +260,8 @@ TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_0) { // in_size = 3, ksize = 3, stride = 3, pad_size = 1 TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_1) { bcast_struct bcast[] = { - {{0, 3, 3, 3, 1}, {0, 2}}, {{1, 3, 3, 3, 1}, {2, 1}}, + {{0, 3, 3, 3, 1}, {0, 2}}, + {{1, 3, 3, 3, 1}, {2, 1}}, }; for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { VerifyBcastValues(bcast[i]); @@ -348,8 +351,8 @@ TEST_F(OpsUtilTest, Misaligned1DSlice) { TEST_F(OpsUtilTest, Aligned2DSliceOfDim0) { #if EIGEN_MAX_ALIGN_BYTES == 0 - // When EIGEN_MAX_ALIGN_BYTES is 0 and the size of the first dimension is nonzero, - // a multidimensional tensor is always aligned. + // When EIGEN_MAX_ALIGN_BYTES is 0 and the size of the first dimension is + // nonzero, a multidimensional tensor is always aligned. Tensor t(DT_FLOAT, TensorShape({3, 4})); int64 start = 1; int64 end = 2; diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc index 2033fbf5dc3f238b665c6f4afced06e90c81bb7c..5645275cfa98eb820b7d1e885b18894bfab17e49 100644 --- a/tensorflow/core/kernels/pack_op.cc +++ b/tensorflow/core/kernels/pack_op.cc @@ -36,7 +36,7 @@ typedef Eigen::GpuDevice GPUDevice; #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // -------------------------------------------------------------------------- template @@ -123,7 +123,7 @@ class PackOp : public OpKernel { ConcatSYCL(c->eigen_sycl_device(), inputs_flat, &output_flat); return; } -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL ConcatCPU(c->device(), inputs_flat, &output_flat); } } @@ -139,7 +139,6 @@ class PackOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_PACK); TF_CALL_QUANTIZED_TYPES(REGISTER_PACK); -TF_CALL_variant(REGISTER_PACK); #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) // Primarily used for SavedModel support on mobile. diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc index b232ba16a76877b4d9f0e8c24e7ccd17a9bc0856..0ab9ff9f650e137017b49d5d279f1a28ff45fa29 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc @@ -95,9 +95,10 @@ struct TruncatedNormalFunctor { int64 sample = b * samples_per_batch; // On GPU, this check will just fill samples with NAN if it fails. - OP_REQUIRES(ctx, stddev > T(0) && minval < maxval && - (Eigen::numext::isfinite(minval) || - Eigen::numext::isfinite(maxval)), + OP_REQUIRES(ctx, + stddev > T(0) && minval < maxval && + (Eigen::numext::isfinite(minval) || + Eigen::numext::isfinite(maxval)), errors::InvalidArgument("Invalid parameters")); int numIterations = 0; @@ -118,8 +119,9 @@ struct TruncatedNormalFunctor { // Determine the method to use. const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4)); const T cutoff = - T(2) * Eigen::numext::exp( - T(0.5) + (normMin * (normMin - sqrtFactor)) / T(4)) / + T(2) * + Eigen::numext::exp(T(0.5) + + (normMin * (normMin - sqrtFactor)) / T(4)) / (normMin + sqrtFactor); const T diff = normMax - normMin; if (diff < cutoff) { @@ -309,30 +311,34 @@ class ParameterizedTruncatedNormalOp : public OpKernel { } else { // Parameters must be broadcastable to the shape [num_batches]. OP_REQUIRES( - ctx, TensorShapeUtils::IsScalar(means_tensor.shape()) || - means_tensor.dim_size(0) == 1 || - means_tensor.dim_size(0) == num_batches, + ctx, + TensorShapeUtils::IsScalar(means_tensor.shape()) || + means_tensor.dim_size(0) == 1 || + means_tensor.dim_size(0) == num_batches, errors::InvalidArgument( "Input means should have length 1 or shape[0], got shape: ", means_tensor.shape().DebugString())); OP_REQUIRES( - ctx, TensorShapeUtils::IsScalar(stddevs_tensor.shape()) || - stddevs_tensor.dim_size(0) == 1 || - stddevs_tensor.dim_size(0) == num_batches, + ctx, + TensorShapeUtils::IsScalar(stddevs_tensor.shape()) || + stddevs_tensor.dim_size(0) == 1 || + stddevs_tensor.dim_size(0) == num_batches, errors::InvalidArgument( "Input stddevs should have length 1 or shape[0], got shape: ", stddevs_tensor.shape().DebugString())); OP_REQUIRES( - ctx, TensorShapeUtils::IsScalar(minvals_tensor.shape()) || - minvals_tensor.dim_size(0) == 1 || - minvals_tensor.dim_size(0) == num_batches, + ctx, + TensorShapeUtils::IsScalar(minvals_tensor.shape()) || + minvals_tensor.dim_size(0) == 1 || + minvals_tensor.dim_size(0) == num_batches, errors::InvalidArgument( "Input minvals should have length 1 or shape[0], got shape: ", minvals_tensor.shape().DebugString())); OP_REQUIRES( - ctx, TensorShapeUtils::IsScalar(maxvals_tensor.shape()) || - maxvals_tensor.dim_size(0) == 1 || - maxvals_tensor.dim_size(0) == num_batches, + ctx, + TensorShapeUtils::IsScalar(maxvals_tensor.shape()) || + maxvals_tensor.dim_size(0) == 1 || + maxvals_tensor.dim_size(0) == num_batches, errors::InvalidArgument( "Input maxvals should have length 1 or shape[0], got shape: ", maxvals_tensor.shape().DebugString())); diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc index 933de65c15a772154ce439cc54489c4a29c42ea5..661d47d925d1143d88b88d73b4ca51c654b43498 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/util/cuda_kernel_helper.h" -#ifdef COMPILER_MSVC +#if defined(_MSC_VER) && !defined(__clang__) // msvc does not support unroll. One could try the loop pragma but we need to // take a closer look if this generates better code in this case. For now let // the compiler take care of it. @@ -202,12 +202,13 @@ struct TruncatedNormalFunctor { typename TTypes::Flat output) { const auto config = GetCudaLaunchConfig(num_elements, d); - TruncatedNormalKernel< - T><<>>( - gen, output.data(), num_batches, samples_per_batch, num_elements, - means.data(), means.dimension(0) == 1, stddevs.data(), - stddevs.dimension(0) == 1, minvals.data(), minvals.dimension(0) == 1, - maxvals.data(), maxvals.dimension(0) == 1, kMaxIterations); + TruncatedNormalKernel + <<>>( + gen, output.data(), num_batches, samples_per_batch, num_elements, + means.data(), means.dimension(0) == 1, stddevs.data(), + stddevs.dimension(0) == 1, minvals.data(), + minvals.dimension(0) == 1, maxvals.data(), + maxvals.dimension(0) == 1, kMaxIterations); }; }; diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc index 6b599612ad7fde0bac44282521be26581aa752b8..8e175fe8d4b4fa203809e5871bfd301188c985da 100644 --- a/tensorflow/core/kernels/parse_tensor_op.cc +++ b/tensorflow/core/kernels/parse_tensor_op.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/framework/register_types.h" namespace tensorflow { @@ -92,7 +91,6 @@ class SerializeTensorOp : public OpKernel { Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint("T"), \ SerializeTensorOp); TF_CALL_ALL_TYPES(REGISTER) -TF_CALL_variant(REGISTER) #undef REGISTER } // namespace tensorflow diff --git a/tensorflow/core/kernels/pooling_ops_3d.cc b/tensorflow/core/kernels/pooling_ops_3d.cc index a406317213f51d557d7b5a9942260156c0fe6369..01bcfede1e8d1f1a71059c5171f8a4d7290d7a5b 100644 --- a/tensorflow/core/kernels/pooling_ops_3d.cc +++ b/tensorflow/core/kernels/pooling_ops_3d.cc @@ -258,7 +258,7 @@ struct LaunchMaxPooling3dGradOp { Eigen::array bcast = {1, csize, rsize, psize, 1}; #else Eigen::IndexList, int, int, int, - Eigen::type2index<1> > + Eigen::type2index<1>> bcast; bcast.set(1, csize); bcast.set(2, rsize); @@ -431,7 +431,7 @@ struct LaunchAvgPooling3dGradOp { Eigen::array bcast = {1, csize, rsize, psize, 1}; #else Eigen::IndexList, int, int, int, - Eigen::type2index<1> > + Eigen::type2index<1>> bcast; bcast.set(1, csize); bcast.set(2, rsize); @@ -833,7 +833,7 @@ TF_CALL_float(REGISTER_GPU_KERNELS) TF_CALL_half(REGISTER_GPU_KERNELS) #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T) -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS) + TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS) #undef REGISTER_SYCL_KERNELS #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/pooling_ops_3d_sycl.h b/tensorflow/core/kernels/pooling_ops_3d_sycl.h index c1bc5af4986ee7102929af3e9b37a7301830de0e..b4bead2456d58c636301678d8a81864b25e3e85b 100644 --- a/tensorflow/core/kernels/pooling_ops_3d_sycl.h +++ b/tensorflow/core/kernels/pooling_ops_3d_sycl.h @@ -281,12 +281,11 @@ class MaxPool3DGradSYCL { const T* input_data_n = input_data + n * p_.in_planes_ * p_.in_cols_ * p_.in_rows_ * p_.depth_; - const T* output_data_n = - output_data + - n * p_.out_planes_ * p_.out_cols_ * p_.out_rows_ * p_.depth_; - const T* input_backprop_n = - input_backprop + - n * p_.out_planes_ * p_.out_cols_ * p_.out_rows_ * p_.depth_; + const T* output_data_n = output_data + n * p_.out_planes_ * p_.out_cols_ * + p_.out_rows_ * p_.depth_; + const T* input_backprop_n = input_backprop + n * p_.out_planes_ * + p_.out_cols_ * + p_.out_rows_ * p_.depth_; for (int poolp = poolpstart; poolp < poolpend; ++poolp) { int pstart = poolp * p_.stride_planes_ - p_.pad_planes_; const int pend = std::min(pstart + p_.window_planes_, p_.in_planes_); @@ -678,9 +677,9 @@ class AvgPool3DGradSYCL { n /= p_.in_planes_; T gradient = T(0); - const T* input_backprop_n = - input_backprop + - n * p_.out_planes_ * p_.out_cols_ * p_.out_rows_ * p_.depth_; + const T* input_backprop_n = input_backprop + n * p_.out_planes_ * + p_.out_cols_ * + p_.out_rows_ * p_.depth_; for (int poolp = poolpstart; poolp < poolpend; ++poolp) { int pstart = poolp * p_.stride_planes_ - p_.pad_planes_; const int pend = std::min(pstart + p_.window_planes_, p_.in_planes_); diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index e3131b804f2412c890016dcfb3aace1648729172..fc7cb437b8f583a811427deaf52a94d9ef996f37 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -195,7 +195,6 @@ class MaxPoolingOp : public OpKernel { // and updates the corresponding column(s) in output_as_matrix with the // max value. auto shard = [¶ms, &in_mat, &out_mat](int64 start, int64 limit) { - const int32 in_rows = params.tensor_in_rows; const int32 in_cols = params.tensor_in_cols; const int32 pad_rows = params.pad_rows; @@ -443,7 +442,6 @@ class MaxPoolingV2Op : public OpKernel { // and updates the corresponding column(s) in output_as_matrix with the // max value. auto shard = [¶ms, &in_mat, &out_mat](int64 start, int64 limit) { - const int32 in_rows = params.tensor_in_rows; const int32 in_cols = params.tensor_in_cols; const int32 pad_rows = params.pad_rows; diff --git a/tensorflow/core/kernels/quantization_utils_test.cc b/tensorflow/core/kernels/quantization_utils_test.cc index d148c9f78d61d9b1840cc7a14f82c9254a4d434c..176720c22cc54ea8d9b79dacfc77f6cd2532f93a 100644 --- a/tensorflow/core/kernels/quantization_utils_test.cc +++ b/tensorflow/core/kernels/quantization_utils_test.cc @@ -385,8 +385,12 @@ void TestQuantizedToFloatInPlaceUsingEigen( // These are the float values we're going to test the conversions on. typedef std::pair FPair; for (FPair min_and_max : std::vector{ - FPair(-255.0f, 255.0f), FPair(-1.0f, 1.0f), FPair(-1.0f, 255.0f), - FPair(0.0f, 1e6), FPair(0.0f, 1.0f), FPair(-31.0f, 13.0f), + FPair(-255.0f, 255.0f), + FPair(-1.0f, 1.0f), + FPair(-1.0f, 255.0f), + FPair(0.0f, 1e6), + FPair(0.0f, 1.0f), + FPair(-31.0f, 13.0f), FPair(-5.89505e+08, 5.89505e+08), }) { const float f_min = min_and_max.first; diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.h b/tensorflow/core/kernels/quantize_and_dequantize_op.h index 1363c7e325b6a251d97039df3de271e92f59f6c0..3b09ea2527d8b401941c6ef0951c620edd0c5217 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op.h +++ b/tensorflow/core/kernels/quantize_and_dequantize_op.h @@ -71,7 +71,8 @@ struct QuantizeAndDequantizeOneScaleImpl { out.device(d) = ((input.cwiseMin(max_range).cwiseMax(min_range) - min_range) * scale + - T(0.5)).floor() * + T(0.5)) + .floor() * inverse_scale + min_range; } else { diff --git a/tensorflow/core/kernels/quantize_op_test.cc b/tensorflow/core/kernels/quantize_op_test.cc index d2cc55a94ddd7b3e31a5cfc841de25519abe2746..57982bdf76e3969b31f4ee73cbf47c564b2b53e6 100644 --- a/tensorflow/core/kernels/quantize_op_test.cc +++ b/tensorflow/core/kernels/quantize_op_test.cc @@ -250,7 +250,8 @@ TEST_F(QuantizedOpTest, QuantizeV2_32Bit) { Tensor expected(allocator(), DT_QINT32, TensorShape({element_count})); test::FillValues(&expected, { - std::numeric_limits::min(), 0, + std::numeric_limits::min(), + 0, static_cast(1.0f * (1 << 23)), static_cast(1.25f * (1 << 23)), static_cast(1.75f * (1 << 23)), diff --git a/tensorflow/core/kernels/quantized_batch_norm_op.cc b/tensorflow/core/kernels/quantized_batch_norm_op.cc index 18d83b414940504fcb4e031f3304412da3baf51b..b03da7ad17fab45086438691a1013b2acf54ee87 100644 --- a/tensorflow/core/kernels/quantized_batch_norm_op.cc +++ b/tensorflow/core/kernels/quantized_batch_norm_op.cc @@ -16,11 +16,11 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/kernels/quantization_utils.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/quantization_utils.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/quantized_concat_op.cc b/tensorflow/core/kernels/quantized_concat_op.cc index d67f1ab3ec28934bc08c11997a8b2f448c30ad91..b03ac8e87dac8fabe0d45d8685ec4fa5fd642519 100644 --- a/tensorflow/core/kernels/quantized_concat_op.cc +++ b/tensorflow/core/kernels/quantized_concat_op.cc @@ -135,8 +135,8 @@ class QuantizedConcatOp : public OpKernel { context, in.dims() == input_dims || (input_is_scalar && in_is_scalar), errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, "] = ", - in.shape().DebugString())); + input_shape.DebugString(), " vs. shape[", i, + "] = ", in.shape().DebugString())); for (int j = 0; j < input_dims; ++j) { if (j == concat_dim) { continue; @@ -145,8 +145,8 @@ class QuantizedConcatOp : public OpKernel { context, in.dim_size(j) == input_shape.dim_size(j), errors::InvalidArgument( "ConcatOp : Dimensions of inputs should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, "] = ", - in.shape().DebugString())); + input_shape.DebugString(), " vs. shape[", i, + "] = ", in.shape().DebugString())); } if (in.NumElements() > 0) { int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0; diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc index 1921b83d12c0688a96bad0c561080a0189e49bbe..5b3570edff5fee4b77d02684ef3da2af1d5f14b1 100644 --- a/tensorflow/core/kernels/quantized_conv_ops.cc +++ b/tensorflow/core/kernels/quantized_conv_ops.cc @@ -278,10 +278,9 @@ class Im2ColConvFunctor { *resource = new Im2ColBufferResource(); return Status::OK(); }; - OP_REQUIRES_OK( - context, - context->resource_manager()->LookupOrCreate( - "Conv2d", "im2col_buffer", &im2col_buffer_resource, creator)); + OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate( + "Conv2d", "im2col_buffer", + &im2col_buffer_resource, creator)); // This means that multiple ops can't be run simultaneously on different // threads, because we have a single shared resource. The platforms this is // aimed at have intra-op parallelism as their focus though, so it shouldn't diff --git a/tensorflow/core/kernels/quantized_instance_norm.cc b/tensorflow/core/kernels/quantized_instance_norm.cc index c29f534f31b524f6e1d9ec09750b6de265ec10f8..d62094cc9fad85536edba8bb3854e71870df217c 100644 --- a/tensorflow/core/kernels/quantized_instance_norm.cc +++ b/tensorflow/core/kernels/quantized_instance_norm.cc @@ -278,10 +278,10 @@ class QuantizedInstanceNorm : public OpKernel { float input_max = context->input(2).flat()(0); float input_scale = (input_max - input_min) / 255.0f; - OP_REQUIRES( - context, input_min < input_max, - errors::InvalidArgument("input_min must be less than input_max : ", - input_min, " >= ", input_max)); + OP_REQUIRES(context, input_min < input_max, + errors::InvalidArgument( + "input_min must be less than input_max : ", input_min, + " >= ", input_max)); auto input_tensor = input.tensor(); auto N = input_tensor.dimension(0); diff --git a/tensorflow/core/kernels/quantized_matmul_op.cc b/tensorflow/core/kernels/quantized_matmul_op.cc index afb30d5f627feab1a009ec84c5f0bb9f851766e0..da8c46dc5162f30ea129e71fb5a1c81ee594718d 100644 --- a/tensorflow/core/kernels/quantized_matmul_op.cc +++ b/tensorflow/core/kernels/quantized_matmul_op.cc @@ -104,9 +104,9 @@ class QuantizedMatMulOp : public OpKernel { OP_REQUIRES(context, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second), - errors::InvalidArgument("Matrix size-compatible: In[0]: ", - a.shape().DebugString(), ", In[1]: ", - b.shape().DebugString())); + errors::InvalidArgument( + "Matrix size-compatible: In[0]: ", a.shape().DebugString(), + ", In[1]: ", b.shape().DebugString())); OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)), errors::InvalidArgument("shift_c must be between 0 and 31, " diff --git a/tensorflow/core/kernels/quantized_matmul_op_test.cc b/tensorflow/core/kernels/quantized_matmul_op_test.cc index 535b5115c34e61333a0e7e1fdbfbe2b35571bf6c..c9f05dbc10bb8bcd3acae2d2ca0c149ac620bb79 100644 --- a/tensorflow/core/kernels/quantized_matmul_op_test.cc +++ b/tensorflow/core/kernels/quantized_matmul_op_test.cc @@ -206,17 +206,32 @@ TEST_F(QuantizedMatMulTest, Small_WithParams) { // We have set the transpose_a flag to true, so the matrix is transposed, and // for filling the values the in-memory storage order is effectively // column major, rather than the default row-major. - AddInputFromArray(TensorShape({a_rows, a_cols}), - { - 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, - }); + AddInputFromArray(TensorShape({a_rows, a_cols}), { + 11, + 10, + 9, + 8, + 7, + 6, + 5, + 4, + 3, + 2, + 1, + 0, + }); // The B matrix is: // | 1 | 4| // | 2 | 5| // | 3 | 6| AddInputFromArray(TensorShape({b_rows, b_cols}), { - 1, 4, 2, 5, 3, 6, + 1, + 4, + 2, + 5, + 3, + 6, }); AddInputFromArray(TensorShape({1}), {-12.0f}); AddInputFromArray(TensorShape({1}), {243.0f}); @@ -238,10 +253,16 @@ TEST_F(QuantizedMatMulTest, Small_WithParams) { // | -50 | -113 | // | -56 | -128 | Tensor expected(allocator(), DT_QINT32, TensorShape({a_cols, b_cols})); - test::FillValues(&expected, - { - -38, -83, -44, -98, -50, -113, -56, -128, - }); + test::FillValues(&expected, { + -38, + -83, + -44, + -98, + -50, + -113, + -56, + -128, + }); test::ExpectTensorEqual(expected, *GetOutput(0)); } diff --git a/tensorflow/core/kernels/quantized_mul_op.cc b/tensorflow/core/kernels/quantized_mul_op.cc index eaa5e667f7d5681e886a5de9e64a055ec175cf1e..3c7536e037396c338663ce0136832acb87bef401 100644 --- a/tensorflow/core/kernels/quantized_mul_op.cc +++ b/tensorflow/core/kernels/quantized_mul_op.cc @@ -298,9 +298,8 @@ class QuantizedMulOp : public OpKernel { return; } Tensor* z; - OP_REQUIRES_OK( - context, - context->allocate_output(0, BCast::ToShape(bcast.output_shape()), &z)); + OP_REQUIRES_OK(context, context->allocate_output( + 0, BCast::ToShape(bcast.output_shape()), &z)); // Make sure that we have valid quantization ranges for the input buffers. // If the difference between the min and max is negative or zero, it makes diff --git a/tensorflow/core/kernels/quantized_mul_op_test.cc b/tensorflow/core/kernels/quantized_mul_op_test.cc index b0550c8260c0ec7e40eeab4e07a5ecaf4cb8e32b..a4e407c7a94c9c2e11808eeb4533be5c346fb6f4 100644 --- a/tensorflow/core/kernels/quantized_mul_op_test.cc +++ b/tensorflow/core/kernels/quantized_mul_op_test.cc @@ -188,11 +188,12 @@ void TestManualScalar() { 10.0f, {1}, {10.0f}, -100.0f, 100.0f, {10}, {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f, 70.0f, 80.0f, 90.0f, 100.0f}, 3.0f); - TestMul({1}, {10.0f}, -100.0f, 100.0f, {10}, - {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f, - 10.0f, {10}, {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f, 70.0f, 80.0f, - 90.0f, 100.0f}, - 3.0f); + TestMul( + {1}, {10.0f}, -100.0f, 100.0f, {10}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f, + 10.0f, {10}, + {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f, 70.0f, 80.0f, 90.0f, 100.0f}, + 3.0f); } void TestScalar() { diff --git a/tensorflow/core/kernels/quantized_resize_bilinear_op.cc b/tensorflow/core/kernels/quantized_resize_bilinear_op.cc index fb2faede2f9f9e56728ad3ab354440eabd488818..9a1dcd0d496e45977704f49c10fba1048effc943 100644 --- a/tensorflow/core/kernels/quantized_resize_bilinear_op.cc +++ b/tensorflow/core/kernels/quantized_resize_bilinear_op.cc @@ -697,8 +697,8 @@ class QuantizedResizeBilinearOp : public OpKernel { // Return if the output is empty. if (st.output->NumElements() == 0) return; - typename TTypes::ConstTensor image_data = input.tensor(); - typename TTypes::Tensor output_data = st.output->tensor(); + typename TTypes::ConstTensor image_data(input.tensor()); + typename TTypes::Tensor output_data(st.output->tensor()); ResizeBilinear(image_data, st.height_scale, st.width_scale, in_min, in_max, &output_data); diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc index 330d161c32bc1a48b671765cacc21618545fa71a..de495c19cba300fbd034cda01adfd0518548ce68 100644 --- a/tensorflow/core/kernels/queue_base.cc +++ b/tensorflow/core/kernels/queue_base.cc @@ -39,8 +39,8 @@ Status HandleSliceToElement(const Tensor& parent, Tensor* element, return errors::Internal( "HandleSliceToElement Cannot copy slice: number of elements does not " "match. Shapes are: [element]: ", - element->shape().DebugString(), ", [parent slice]: ", - chip_shape.DebugString()); + element->shape().DebugString(), + ", [parent slice]: ", chip_shape.DebugString()); } auto parent_as_matrix = parent.flat_outer_dims(); element->flat() = parent_as_matrix.chip(index, 0); diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc index 17831b74370bcd21cf7772f0ea6809ee840511c3..46a02854d732d6da657414a4e42b535f72ea7b64 100644 --- a/tensorflow/core/kernels/queue_ops.cc +++ b/tensorflow/core/kernels/queue_ops.cc @@ -428,13 +428,14 @@ REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp); class QueueIsClosedOp : public QueueOpKernel { public: explicit QueueIsClosedOp(OpKernelConstruction* context) - : QueueOpKernel(context) {} + : QueueOpKernel(context) {} protected: void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) override { Tensor* Tqueue_is_closed = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed)); + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed)); Tqueue_is_closed->flat().setConstant(queue->is_closed()); callback(); } @@ -443,8 +444,10 @@ class QueueIsClosedOp : public QueueOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp); }; -REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU), QueueIsClosedOp); -REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU), QueueIsClosedOp); +REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU), + QueueIsClosedOp); +REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU), + QueueIsClosedOp); class FakeQueueOp : public OpKernel { public: diff --git a/tensorflow/core/kernels/random_crop_op.cc b/tensorflow/core/kernels/random_crop_op.cc index ba94d6be5caff7245e08ca22b5f057e81f30db74..b89bda4769dd42590006f803ea45dbb7573bc332 100644 --- a/tensorflow/core/kernels/random_crop_op.cc +++ b/tensorflow/core/kernels/random_crop_op.cc @@ -68,10 +68,10 @@ class RandomCropOp : public OpKernel { // Edge case. The target dimensions are larger then the image, so // zero-pad the image. This guarantees that the image will *always* // be [target_height, target_width] in size. - OP_REQUIRES( - context, width >= target_width, - errors::FailedPrecondition("width must be >= target_width: width = ", - width, ", target_width = ", target_width)); + OP_REQUIRES(context, width >= target_width, + errors::FailedPrecondition( + "width must be >= target_width: width = ", width, + ", target_width = ", target_width)); OP_REQUIRES(context, height >= target_height, errors::FailedPrecondition( "height must be >= target_height: height = ", height, @@ -92,8 +92,8 @@ class RandomCropOp : public OpKernel { // TODO(shlens): Do this more efficiently with memcpy once padding is // available for smaller images. - typename TTypes::ConstTensor input_data = input.tensor(); - typename TTypes::Tensor output_data = output->tensor(); + typename TTypes::ConstTensor input_data(input.tensor()); + typename TTypes::Tensor output_data(output->tensor()); for (int y = 0; y < target_height; ++y) { for (int x = 0; x < target_width; ++x) { diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 55a8b9c9b67455483689a135306017bed8974ade..78ff7948fbf1b6406b2faca1d94acd7ea3325437 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -50,7 +50,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace functor { using random::PhiloxRandom; @@ -271,9 +271,10 @@ class RandomGammaOp : public OpKernel { const Tensor& shape_t = ctx->input(0); const Tensor& alpha_t = ctx->input(1); - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(shape_t.shape()) && - (shape_t.dtype() == DataType::DT_INT32 || - shape_t.dtype() == DataType::DT_INT64), + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(shape_t.shape()) && + (shape_t.dtype() == DataType::DT_INT32 || + shape_t.dtype() == DataType::DT_INT64), errors::InvalidArgument( "shape must be a vector of {int32,int64}, got shape: ", shape_t.DebugString())); @@ -325,7 +326,7 @@ class RandomGammaOp : public OpKernel { // avoid a couple flops which can be done on a per-alpha basis. auto DoWork = [num_samples, num_alphas, &rng, samples_flat, alpha_flat]( - int start_output, int limit_output) { + int start_output, int limit_output) { using Eigen::numext::exp; using Eigen::numext::log; using Eigen::numext::pow; @@ -448,40 +449,40 @@ class RandomGammaOp : public OpKernel { } // namespace -#define REGISTER(TYPE) \ - template struct functor::FillPhiloxRandom< \ - CPUDevice, random::UniformDistribution >; \ - template struct functor::FillPhiloxRandom< \ - CPUDevice, random::NormalDistribution >; \ - template struct functor::FillPhiloxRandom< \ - CPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter, TYPE> >; \ - REGISTER_KERNEL_BUILDER( \ - Name("RandomUniform") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint("dtype"), \ - PhiloxRandomOp >); \ - REGISTER_KERNEL_BUILDER( \ - Name("RandomStandardNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint("dtype"), \ - PhiloxRandomOp >); \ - REGISTER_KERNEL_BUILDER( \ - Name("TruncatedNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint("dtype"), \ - PhiloxRandomOp< \ - CPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint("T"), \ +#define REGISTER(TYPE) \ + template struct functor::FillPhiloxRandom< \ + CPUDevice, random::UniformDistribution>; \ + template struct functor::FillPhiloxRandom< \ + CPUDevice, random::NormalDistribution>; \ + template struct functor::FillPhiloxRandom< \ + CPUDevice, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter, TYPE>>; \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomUniform") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomStandardNormal") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TruncatedNormal") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp< \ + CPUDevice, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter, TYPE>>); \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint("T"), \ RandomGammaOp) #define REGISTER_INT(IntType) \ @@ -504,33 +505,33 @@ TF_CALL_int64(REGISTER_INT); #if GOOGLE_CUDA -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("RandomUniform") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .TypeConstraint("T") \ - .TypeConstraint("dtype"), \ - PhiloxRandomOp >); \ - REGISTER_KERNEL_BUILDER( \ - Name("RandomStandardNormal") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .TypeConstraint("T") \ - .TypeConstraint("dtype"), \ - PhiloxRandomOp >); \ - REGISTER_KERNEL_BUILDER( \ - Name("TruncatedNormal") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .TypeConstraint("T") \ - .TypeConstraint("dtype"), \ - PhiloxRandomOp< \ - GPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter, TYPE> >); +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomUniform") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .TypeConstraint("T") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomStandardNormal") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .TypeConstraint("T") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TruncatedNormal") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .TypeConstraint("T") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp< \ + GPUDevice, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter, TYPE>>); #define REGISTER_INT(IntType) \ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ @@ -565,13 +566,12 @@ struct FillPhiloxRandomKernel; template struct FillPhiloxRandomKernel { typedef typename Distribution::ResultElementType T; - using write_accessor = sycl::accessor; + using write_accessor = sycl::accessor; - FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, Distribution& dist) - : data_(data), - gen_(gen), - dist_(dist) { - } + FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, + Distribution& dist) + : data_(data), gen_(gen), dist_(dist) {} void operator()(sycl::nd_item<1> item) { const size_t kGroupSize = Distribution::kResultElementCount; @@ -597,7 +597,7 @@ struct FillPhiloxRandomKernel { const typename Distribution::ResultType samples = dist_(&gen_); for (size_t i = 0; i < kGroupSize; ++i) { if (offset >= size) { - return; + return; } data[offset] = samples[i]; ++offset; @@ -610,17 +610,15 @@ struct FillPhiloxRandomKernel { Distribution dist_; }; - template struct FillPhiloxRandomKernel { typedef typename Distribution::ResultElementType T; - using write_accessor = sycl::accessor; + using write_accessor = sycl::accessor; - FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, Distribution& dist) - : data_(data), - gen_(gen), - dist_(dist) { - } + FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, + Distribution& dist) + : data_(data), gen_(gen), dist_(dist) {} void operator()(sycl::nd_item<1> item) { using random::PhiloxRandom; @@ -628,9 +626,9 @@ struct FillPhiloxRandomKernel { const size_t kReservedSamplesPerOutput = 256; const size_t kGroupSize = Distribution::kResultElementCount; - const size_t kGeneratorSkipPerOutputGroup = kGroupSize * - kReservedSamplesPerOutput / - PhiloxRandom::kResultElementCount; + const size_t kGeneratorSkipPerOutputGroup = + kGroupSize * kReservedSamplesPerOutput / + PhiloxRandom::kResultElementCount; const size_t item_id = item.get_global(0); const size_t total_item_count = item.get_global_range(); @@ -674,10 +672,9 @@ class FillRandomKernel; // It splits the work into several tasks and run them in parallel template void FillPhiloxRandom::operator()( - OpKernelContext* context, const SYCLDevice& device, random::PhiloxRandom gen, - typename Distribution::ResultElementType* data, int64 size, - Distribution dist) { - + OpKernelContext* context, const SYCLDevice& device, + random::PhiloxRandom gen, typename Distribution::ResultElementType* data, + int64 size, Distribution dist) { const size_t group_size = device.maxSyclThreadsPerBlock(); const size_t group_count = (size + group_size - 1) / group_size; @@ -686,50 +683,52 @@ void FillPhiloxRandom::operator()( device.sycl_queue().submit([&](sycl::handler& cgh) { auto access = buffer.template get_access(cgh); - FillPhiloxRandomKernel task(access, gen, dist); + FillPhiloxRandomKernel + task(access, gen, dist); cgh.parallel_for>( - sycl::nd_range<1>(sycl::range<1>(group_count * group_size), sycl::range<1>(group_size)), - task - ); + sycl::nd_range<1>(sycl::range<1>(group_count * group_size), + sycl::range<1>(group_size)), + task); }); } -} +} // namespace functor + +#define REGISTER(TYPE) \ + template struct functor::FillPhiloxRandom< \ + SYCLDevice, random::UniformDistribution>; \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomUniform") \ + .Device(DEVICE_SYCL) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("RandomStandardNormal") \ + .Device(DEVICE_SYCL) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TruncatedNormal") \ + .Device(DEVICE_SYCL) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + PhiloxRandomOp< \ + SYCLDevice, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter, TYPE>>); -#define REGISTER(TYPE) \ - template struct functor::FillPhiloxRandom< \ - SYCLDevice, random::UniformDistribution >; \ - REGISTER_KERNEL_BUILDER( \ - Name("RandomUniform") \ - .Device(DEVICE_SYCL) \ - .HostMemory("shape") \ - .TypeConstraint("dtype"), \ - PhiloxRandomOp >); \ - REGISTER_KERNEL_BUILDER( \ - Name("RandomStandardNormal") \ - .Device(DEVICE_SYCL) \ - .HostMemory("shape") \ - .TypeConstraint("dtype"), \ - PhiloxRandomOp >); \ - REGISTER_KERNEL_BUILDER( \ - Name("TruncatedNormal") \ - .Device(DEVICE_SYCL) \ - .HostMemory("shape") \ - .TypeConstraint("dtype"), \ - PhiloxRandomOp< \ - SYCLDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter, TYPE> >); - -#define REGISTER_INT(IntType) \ - REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ - .Device(DEVICE_SYCL) \ - .HostMemory("shape") \ - .HostMemory("minval") \ - .HostMemory("maxval") \ - .TypeConstraint("Tout"), \ +#define REGISTER_INT(IntType) \ + REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ + .Device(DEVICE_SYCL) \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint("Tout"), \ RandomUniformIntOp); TF_CALL_float(REGISTER); @@ -740,6 +739,6 @@ TF_CALL_int64(REGISTER_INT); #undef REGISTER #undef REGISTER_INT -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // end namespace tensorflow diff --git a/tensorflow/core/kernels/random_op_gpu.cu.cc b/tensorflow/core/kernels/random_op_gpu.cu.cc index 7afa6974c6a9389782fbbcd39ddede2a97ecd566..3393b39faf4a25791b48af99a5e474f3e9bfbfce 100644 --- a/tensorflow/core/kernels/random_op_gpu.cu.cc +++ b/tensorflow/core/kernels/random_op_gpu.cu.cc @@ -222,9 +222,8 @@ void FillPhiloxRandom::operator()( (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) / block_size; - FillPhiloxRandomKernelLaunch< - Distribution><<>>(gen, data, size, - dist); + FillPhiloxRandomKernelLaunch + <<>>(gen, data, size, dist); }; // Explicit instantiation of the GPU distributions functors diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc index bf1d83ec7517d1bcfa9b88b482b983e6a2d3f7c4..64fb4a5c22848009743af6a577c719f206f022bb 100644 --- a/tensorflow/core/kernels/random_poisson_op.cc +++ b/tensorflow/core/kernels/random_poisson_op.cc @@ -103,7 +103,7 @@ struct PoissonFunctor { typedef random::UniformDistribution Uniform; auto DoWork = [num_samples, num_rate, &rng, samples_flat, rate_flat]( - int start_output, int limit_output) { + int start_output, int limit_output) { // Capturing "rng" by value would only make a copy for the _shared_ // lambda. Since we want to let each worker have its own copy, we pass // "rng" by reference and explicitly do a copy assignment. diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index e9695cfde30945c9c99db85f33e44030e5d45054..87fc94333162c4b721fa3608f282bf9d28fc792e 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -334,96 +334,95 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, // TODO(josh11b): This makes two copies of callback, avoid this if possible. dequeue_attempts_.emplace_back( num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token, - [callback, allow_small_batch, this](Attempt* attempt) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - int32 queue_size = queues_[0].size(); - if (closed_ && queue_size < attempt->elements_requested) { - // If we don't have enough for a full dequeue, we have - // to reset the attempt tuple. - if (!attempt->tuple.empty()) { - // Restore already-dequeued elements to the queue. - for (int64 i = attempt->tuple[0].dim_size(0) - - attempt->elements_requested - 1; - i >= 0; --i) { - for (int j = 0; j < num_components(); ++j) { - PersistentTensor element; - Status s = GetElementComponentFromBatch( - attempt->tuple, i, j, attempt->context, &element); - if (!s.ok()) { - attempt->context->SetStatus( - errors::DataLoss("Failed to restore element from " - "partially-dequeued batch " - "to RandomShuffleQueue: ", - s.error_message())); - } - queues_[j].push_back(element); - } - } - } - if (allow_small_batch && !queues_[0].empty()) { - // Request all remaining elements in the queue. - queue_size = queues_[0].size(); - attempt->tuple.clear(); - attempt->elements_requested = queue_size; - } else { - if (allow_small_batch) { - // There may be some other attempts containing - // values. If so, we'll yield and wait for them - // to add elements to the queue. - if (!enqueue_attempts_.empty()) return kProgress; - } - if (attempt->context->status().ok()) { - attempt->context->SetStatus(errors::OutOfRange( - "RandomShuffleQueue '", name_, "' is closed and has ", - "insufficient elements (requested ", - attempt->elements_requested, ", current size ", - queue_size, ")")); + [callback, allow_small_batch, + this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int32 queue_size = queues_[0].size(); + if (closed_ && queue_size < attempt->elements_requested) { + // If we don't have enough for a full dequeue, we have + // to reset the attempt tuple. + if (!attempt->tuple.empty()) { + // Restore already-dequeued elements to the queue. + for (int64 i = attempt->tuple[0].dim_size(0) - + attempt->elements_requested - 1; + i >= 0; --i) { + for (int j = 0; j < num_components(); ++j) { + PersistentTensor element; + Status s = GetElementComponentFromBatch( + attempt->tuple, i, j, attempt->context, &element); + if (!s.ok()) { + attempt->context->SetStatus( + errors::DataLoss("Failed to restore element from " + "partially-dequeued batch " + "to RandomShuffleQueue: ", + s.error_message())); } - return kComplete; + queues_[j].push_back(element); } } + } + if (allow_small_batch && !queues_[0].empty()) { + // Request all remaining elements in the queue. + queue_size = queues_[0].size(); + attempt->tuple.clear(); + attempt->elements_requested = queue_size; + } else { + if (allow_small_batch) { + // There may be some other attempts containing + // values. If so, we'll yield and wait for them + // to add elements to the queue. + if (!enqueue_attempts_.empty()) return kProgress; + } + if (attempt->context->status().ok()) { + attempt->context->SetStatus(errors::OutOfRange( + "RandomShuffleQueue '", name_, "' is closed and has ", + "insufficient elements (requested ", + attempt->elements_requested, ", current size ", + queue_size, ")")); + } + return kComplete; + } + } - RunResult result = kNoProgress; - if (!closed_) queue_size -= min_after_dequeue_; - for (; queue_size > 0; --queue_size) { - if (attempt->tuple.empty()) { - // Only allocate tuple when we have something to dequeue - // so we don't use excessive memory when there are many - // blocked dequeue attempts waiting. - attempt->tuple.reserve(num_components()); - for (int i = 0; i < num_components(); ++i) { - const TensorShape shape = - ManyOutShape(i, attempt->elements_requested); - Tensor element; - attempt->context->SetStatus( - attempt->context->allocate_temp(component_dtypes_[i], - shape, &element)); - if (!attempt->context->status().ok()) return kComplete; - attempt->tuple.emplace_back(element); - } - } - result = kProgress; - Tuple tuple; - DequeueLocked(attempt->context, &tuple); - const int index = attempt->tuple[0].dim_size(0) - - attempt->elements_requested; - for (int i = 0; i < num_components(); ++i) { - attempt->context->SetStatus(batch_util::CopyElementToSlice( - std::move(tuple[i]), &attempt->tuple[i], index)); - if (!attempt->context->status().ok()) return kComplete; - } - tuple.clear(); - --attempt->elements_requested; - if (attempt->elements_requested == 0) { - tuple = attempt->tuple; - attempt->done_callback = [callback, tuple]() { - callback(tuple); - }; - return kComplete; - } + RunResult result = kNoProgress; + if (!closed_) queue_size -= min_after_dequeue_; + for (; queue_size > 0; --queue_size) { + if (attempt->tuple.empty()) { + // Only allocate tuple when we have something to dequeue + // so we don't use excessive memory when there are many + // blocked dequeue attempts waiting. + attempt->tuple.reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + const TensorShape shape = + ManyOutShape(i, attempt->elements_requested); + Tensor element; + attempt->context->SetStatus(attempt->context->allocate_temp( + component_dtypes_[i], shape, &element)); + if (!attempt->context->status().ok()) return kComplete; + attempt->tuple.emplace_back(element); } - return result; - }); + } + result = kProgress; + Tuple tuple; + DequeueLocked(attempt->context, &tuple); + const int index = + attempt->tuple[0].dim_size(0) - attempt->elements_requested; + for (int i = 0; i < num_components(); ++i) { + attempt->context->SetStatus(batch_util::CopyElementToSlice( + std::move(tuple[i]), &attempt->tuple[i], index)); + if (!attempt->context->status().ok()) return kComplete; + } + tuple.clear(); + --attempt->elements_requested; + if (attempt->elements_requested == 0) { + tuple = attempt->tuple; + attempt->done_callback = [callback, tuple]() { + callback(tuple); + }; + return kComplete; + } + } + return result; + }); } } if (!already_cancelled) { diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index 36ca7f834f7b4fe7db1e2591189b1359231c7307..15ae4c1fc53b2b9bfe1d6085d2ecbc3659705b47 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -312,8 +312,7 @@ __global__ void ColumnReduceKernel( int col = blockIdx.x * 32 + threadIdx.x; value_type sum = initVal; - if (row < num_rows && col < num_cols) - sum = in[row * num_cols + col]; + if (row < num_rows && col < num_cols) sum = in[row * num_cols + col]; // 1D array necessary due to bug in CUDA 9 compiler. // TODO(nluehr) revert to 2D array when compiler is ready. @@ -366,8 +365,7 @@ __global__ void CleanupSegments( const int tid = threadIdx.x + blockIdx.x * blockDim.x; value_type val = initVal; - if (tid < segment_size * num_cols) - val = partial_sums[tid]; + if (tid < segment_size * num_cols) val = partial_sums[tid]; typedef cub::WarpReduce WarpReduce; diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h index d7bebfb24c82275da07fb5b548f7169b77ea3cb9..03d6e82e018a55214e3ce66d64f49b0a7eb42e11 100644 --- a/tensorflow/core/kernels/reduction_ops_common.h +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -239,9 +239,6 @@ class ReductionOp : public OpKernel { if (!out.CopyFrom(tmp_out, helper.out_shape())) { ctx->SetStatus(errors::Internal("Error during reduction copy.")); } - if (ctx->track_allocations()) { - ctx->record_temp_memory_size(-static_cast(out.AllocatedBytes())); - } ctx->set_output(0, out); } diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index afad288cc00e0c3934318834d8dae8c181541212..d52358737fd121398ff2a4c95e417fd9b80987ab 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -31,7 +31,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #define REGISTER_RELU_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ @@ -113,8 +113,7 @@ namespace functor { \ template <> \ void Selu::operator()( \ - const GPUDevice& d, \ - typename TTypes::ConstTensor features, \ + const GPUDevice& d, typename TTypes::ConstTensor features, \ typename TTypes::Tensor activations); \ extern template struct Selu; \ \ @@ -125,8 +124,6 @@ namespace functor { typename TTypes::Tensor backprops); \ extern template struct SeluGrad; - - TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); } // namespace functor @@ -157,8 +154,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ SeluGradOp) - - TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS @@ -192,10 +187,8 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ SeluGradOp) - - TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); #undef REGISTER_SYCL_KERNELS -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h index 24b789c5437c78a76c708a6637b60376d5087682..3bc5ba8a50de22156aa631ee6404ddfe04b3a105 100644 --- a/tensorflow/core/kernels/relu_op_functor.h +++ b/tensorflow/core/kernels/relu_op_functor.h @@ -85,10 +85,9 @@ struct Relu6Grad { // make sure not to propagate the associated gradient // value. This allows "features" to be either the input or the output of // the relu6. - backprops.device(d) = - gradients * - ((features > static_cast(0)) * (features < static_cast(6))) - .template cast(); + backprops.device(d) = gradients * ((features > static_cast(0)) * + (features < static_cast(6))) + .template cast(); } }; @@ -161,8 +160,8 @@ struct SeluGrad { const auto scale = static_cast(1.0507009873554804934193349852946); const auto scale_alpha = static_cast(1.7580993408473768599402175208123); backprops.device(d) = - (activations < static_cast(0)).select( - gradients * (activations + scale_alpha), gradients * scale); + (activations < static_cast(0)) + .select(gradients * (activations + scale_alpha), gradients * scale); } }; diff --git a/tensorflow/core/kernels/reshape_op.cc b/tensorflow/core/kernels/reshape_op.cc index 8b86596721aa41c124b35b19cac7aac264b6f574..33c63e70500971cbcfb847d03239e0721d4871ff 100644 --- a/tensorflow/core/kernels/reshape_op.cc +++ b/tensorflow/core/kernels/reshape_op.cc @@ -43,7 +43,6 @@ REGISTER_KERNEL_BUILDER(Name("Reshape") .TypeConstraint("Tshape"), \ ReshapeOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); -TF_CALL_bfloat16(REGISTER_GPU_KERNEL); TF_CALL_bool(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc index ada50dfb70de447d9be9f735c6b973a25933cfa5..98b8a0df282a21f6711cc8926762f7bbb4ef52b0 100644 --- a/tensorflow/core/kernels/resize_area_op.cc +++ b/tensorflow/core/kernels/resize_area_op.cc @@ -149,7 +149,7 @@ class ResizeAreaOp : public OpKernel { if (!context->status().ok()) return; - typename TTypes::ConstTensor input_data = input.tensor(); + typename TTypes::ConstTensor input_data(input.tensor()); // Precompute values used when iterating over x coordinates within a row. // Note that it may be useful to cache x_interps for a given @@ -190,8 +190,7 @@ class ResizeAreaOp : public OpKernel { void ComputeLoop(const ImageResizerState& st, const std::vector& x_interps, typename TTypes::ConstTensor input_data) { - typename TTypes::Tensor output_data = - st.output->tensor(); + TTypes::Tensor output_data = st.output->tensor(); // When using this algorithm for downsizing, the target pixel value is the // weighted average of all the source pixels. The weight is determined by diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc index 1a9cf4c6406d85bf26b43e0b9b855760a4888a4c..65014b6c44eb2e5b0adb528c3ce08f01c21e4f26 100644 --- a/tensorflow/core/kernels/resize_bicubic_op.cc +++ b/tensorflow/core/kernels/resize_bicubic_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/kernels/image_resizer_state.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace { @@ -480,9 +480,8 @@ class ResizeBicubicOp : public OpKernel { if (!context->status().ok()) return; - typename TTypes::ConstTensor input_data = input.tensor(); - typename TTypes::Tensor output_data = - st.output->tensor(); + typename TTypes::ConstTensor input_data(input.tensor()); + TTypes::Tensor output_data = st.output->tensor(); interpolate_with_caching(input_data, st, output_data); } @@ -510,9 +509,8 @@ class ResizeBicubicOpGrad : public OpKernel { if (!context->status().ok()) return; - typename TTypes::ConstTensor input_grad = - input.tensor(); - typename TTypes::Tensor output_grad = st.output->tensor(); + TTypes::ConstTensor input_grad = input.tensor(); + typename TTypes::Tensor output_grad(st.output->tensor()); ResizeBicubicGrad(input_grad, st, output_grad); } diff --git a/tensorflow/core/kernels/resize_bicubic_op_test.cc b/tensorflow/core/kernels/resize_bicubic_op_test.cc index 9e10fec42321023d95f3ae8d32a5a1c8f2c7a94e..25a37d5e1af5835d56dedb50922967704500ad46 100644 --- a/tensorflow/core/kernels/resize_bicubic_op_test.cc +++ b/tensorflow/core/kernels/resize_bicubic_op_test.cc @@ -286,13 +286,14 @@ BM_ResizeBicubicDev(32, 128, 3); BM_ResizeBicubicDev(32, 512, 3); BM_ResizeBicubicDev(32, 1024, 3); -#define BM_ResizeBicubicExpand(BATCH, SIZE, CHANNELS) \ - static void BM_ResizeBicubicExpand##_##BATCH##_##SIZE##_##CHANNELS(int iters) { \ - testing::ItemsProcessed(static_cast(iters) * BATCH * SIZE * SIZE * \ - CHANNELS * 8 * 8); \ - test::Benchmark("cpu", ResizeBicubic(BATCH, SIZE, CHANNELS, 8, 8)) \ - .Run(iters); \ - } \ +#define BM_ResizeBicubicExpand(BATCH, SIZE, CHANNELS) \ + static void BM_ResizeBicubicExpand##_##BATCH##_##SIZE##_##CHANNELS( \ + int iters) { \ + testing::ItemsProcessed(static_cast(iters) * BATCH * SIZE * SIZE * \ + CHANNELS * 8 * 8); \ + test::Benchmark("cpu", ResizeBicubic(BATCH, SIZE, CHANNELS, 8, 8)) \ + .Run(iters); \ + } \ BENCHMARK(BM_ResizeBicubicExpand##_##BATCH##_##SIZE##_##CHANNELS); BM_ResizeBicubicExpand(12, 48, 1); diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc index d9cb993a4b296d053ec5f9f8a44955728dc5c949..dde59e8e741aca2c715aeb9d548979200af8789b 100644 --- a/tensorflow/core/kernels/resize_bilinear_op.cc +++ b/tensorflow/core/kernels/resize_bilinear_op.cc @@ -51,9 +51,8 @@ class ResizeBilinearOp : public OpKernel { // Return if the output is empty. if (st.output->NumElements() == 0) return; - typename TTypes::ConstTensor image_data = input.tensor(); - typename TTypes::Tensor output_data = - st.output->tensor(); + typename TTypes::ConstTensor image_data(input.tensor()); + TTypes::Tensor output_data = st.output->tensor(); functor::ResizeBilinear()(context->eigen_device(), image_data, st.height_scale, @@ -258,9 +257,8 @@ class ResizeBilinearOpGrad : public OpKernel { if (!context->status().ok()) return; - typename TTypes::ConstTensor input_grad = - input.tensor(); - typename TTypes::Tensor output_grad = st.output->tensor(); + TTypes::ConstTensor input_grad = input.tensor(); + typename TTypes::Tensor output_grad(st.output->tensor()); functor::ResizeBilinearGrad()(context->eigen_device(), input_grad, st.height_scale, diff --git a/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc b/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc index a7da7a0777d0cb35ade6a04dfff4edf604c1a169..f82c3fcd9ff45e26d2f44408890fa760c64477e4 100644 --- a/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc +++ b/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc @@ -164,11 +164,11 @@ struct ResizeBilinear { if (total_count == 0) return; CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); - ResizeBilinearKernel< - T><<>>( - config.virtual_thread_count, images.data(), height_scale, width_scale, - batch, in_height, in_width, channels, out_height, out_width, - output.data()); + ResizeBilinearKernel + <<>>( + config.virtual_thread_count, images.data(), height_scale, + width_scale, batch, in_height, in_width, channels, out_height, + out_width, output.data()); } }; @@ -200,11 +200,11 @@ struct ResizeBilinearGrad { // Accumulate. total_count = batch * resized_height * resized_width * channels; config = GetCudaLaunchConfig(total_count, d); - ResizeBilinearGradKernel< - T><<>>( - config.virtual_thread_count, input_grad.data(), height_scale, - width_scale, batch, original_height, original_width, channels, - resized_height, resized_width, output_grad.data()); + ResizeBilinearGradKernel + <<>>( + config.virtual_thread_count, input_grad.data(), height_scale, + width_scale, batch, original_height, original_width, channels, + resized_height, resized_width, output_grad.data()); } }; diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc index bfd29b7ec89e6a2d0e2757db31b707be70d12c1d..8ec526c2b25dc870e150d2afbfb9af6fbd1d778d 100644 --- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc +++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc @@ -56,8 +56,8 @@ class ResizeNearestNeighborOp : public OpKernel { // Return if the output is empty. if (st.output->NumElements() == 0) return; - typename TTypes::ConstTensor input_data = input.tensor(); - typename TTypes::Tensor output_data = st.output->tensor(); + typename TTypes::ConstTensor input_data(input.tensor()); + typename TTypes::Tensor output_data(st.output->tensor()); bool status; if (align_corners_) { @@ -162,8 +162,8 @@ class ResizeNearestNeighborOpGrad : public OpKernel { // Return if the output is empty. if (output->NumElements() == 0) return; - typename TTypes::ConstTensor input_data = input.tensor(); - typename TTypes::Tensor output_data = output->tensor(); + typename TTypes::ConstTensor input_data(input.tensor()); + typename TTypes::Tensor output_data(output->tensor()); const float height_scale = CalculateResizeScale(out_height, in_height, align_corners_); diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 9cc8e03e3ac6b17f16d65f1a9ade04d8fdcba034..702fb89aac9afe577cf7e4cd72616f7136a63b0b 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -130,6 +130,7 @@ REGISTER_KERNEL_BUILDER( ResourceHandleOp) TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); TF_CALL_variant(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA @@ -387,7 +388,6 @@ class AssignVariableOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); -TF_CALL_variant(REGISTER_KERNELS); #undef REGISTER_KERNELS #if GOOGLE_CUDA @@ -399,6 +399,7 @@ TF_CALL_variant(REGISTER_KERNELS); AssignVariableOp); TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); TF_CALL_variant(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA @@ -457,6 +458,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); AssignUpdateVariableOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA @@ -635,6 +637,9 @@ class ResourceScatterUpdateOp : public OpKernel { TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); +REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate", + scatter_op::UpdateOp::ASSIGN); + // Registers GPU kernels. #if GOOGLE_CUDA #define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc index 8f82784d936c05d64317e8f27dd8703502083b9b..bb96c42f10c498d0ec3d6a726728cb1e7bc8f111 100644 --- a/tensorflow/core/kernels/reverse_op.cc +++ b/tensorflow/core/kernels/reverse_op.cc @@ -269,10 +269,10 @@ class ReverseV2Op : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); -// TODO(cwhipkey): we can do dimension folding to reduce, e.g., a reverse of -// a single dimension to the dims=3 or dims=2 case, regardless of the number -// of dimensions in the tensor. This would let some ops use faster -// lower-dimension code (and use optimized versions). + // TODO(cwhipkey): we can do dimension folding to reduce, e.g., a reverse + // of a single dimension to the dims=3 or dims=2 case, regardless of the + // number of dimensions in the tensor. This would let some ops use faster + // lower-dimension code (and use optimized versions). #define HANDLE_REVERSE(NDIMS) \ case NDIMS: \ diff --git a/tensorflow/core/kernels/reverse_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_op_gpu.cu.cc index b05a7c5550438c6937745df5e58e81630361d64a..3ee49db669faaa85f2eff7a7f119725fc7170dea 100644 --- a/tensorflow/core/kernels/reverse_op_gpu.cu.cc +++ b/tensorflow/core/kernels/reverse_op_gpu.cu.cc @@ -28,14 +28,14 @@ typedef Eigen::GpuDevice GPUDevice; #define DEFINE_REVERSE(T, DIM) \ template struct functor::Reverse; #define DEFINE_REVERSE_ALL_DIMS(T) \ - DEFINE_REVERSE(T, 0) \ - DEFINE_REVERSE(T, 1) \ - DEFINE_REVERSE(T, 2) \ - DEFINE_REVERSE(T, 3) \ - DEFINE_REVERSE(T, 4) \ - DEFINE_REVERSE(T, 5) \ - DEFINE_REVERSE(T, 6) \ - DEFINE_REVERSE(T, 7) \ + DEFINE_REVERSE(T, 0) \ + DEFINE_REVERSE(T, 1) \ + DEFINE_REVERSE(T, 2) \ + DEFINE_REVERSE(T, 3) \ + DEFINE_REVERSE(T, 4) \ + DEFINE_REVERSE(T, 5) \ + DEFINE_REVERSE(T, 6) \ + DEFINE_REVERSE(T, 7) \ DEFINE_REVERSE(T, 8) TF_CALL_uint8(DEFINE_REVERSE_ALL_DIMS); diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc index d1980d4b652ecb507d8745bf64be2395d14920bb..15a707a9c6609e2ac5b790ea519f6c8e523067b1 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.cc +++ b/tensorflow/core/kernels/reverse_sequence_op.cc @@ -51,8 +51,7 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { // Copy seq_len info down for validity checks context->eigen_device().memcpyDeviceToHost( - seq_lens_vec.data(), seq_lens_t.data(), - sizeof(Tlen) * seq_lens_t.size()); + seq_lens_vec.data(), seq_lens_t.data(), sizeof(Tlen) * seq_lens_t.size()); OP_REQUIRES(context, batch_dim != seq_dim, errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); @@ -76,8 +75,7 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { } } -void CheckErrorsGPU(OpKernelContext* context, int batch_dim, - int seq_dim) { +void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) { const Tensor& input = context->input(0); const Tensor& seq_lens = context->input(1); @@ -98,13 +96,13 @@ void CheckErrorsGPU(OpKernelContext* context, int batch_dim, template <> void CheckErrors(OpKernelContext* context, int batch_dim, - int seq_dim) { + int seq_dim) { CheckErrorsGPU(context, batch_dim, seq_dim); } template <> void CheckErrors(OpKernelContext* context, int batch_dim, - int seq_dim) { + int seq_dim) { CheckErrorsGPU(context, batch_dim, seq_dim); } @@ -164,14 +162,15 @@ class ReverseSequenceOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp); }; -#define REGISTER_REVERSE_SEQUENCE(type, len_type) \ - REGISTER_KERNEL_BUILDER( \ - Name("ReverseSequence").Device(DEVICE_CPU).TypeConstraint("T"). \ - TypeConstraint("Tlen"), \ - ReverseSequenceOp); +#define REGISTER_REVERSE_SEQUENCE(type, len_type) \ + REGISTER_KERNEL_BUILDER(Name("ReverseSequence") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tlen"), \ + ReverseSequenceOp); -#define REGISTER_REVERSE_SEQUENCE_LEN(type) \ - REGISTER_REVERSE_SEQUENCE(type, int32); \ +#define REGISTER_REVERSE_SEQUENCE_LEN(type) \ + REGISTER_REVERSE_SEQUENCE(type, int32); \ REGISTER_REVERSE_SEQUENCE(type, int64); TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_LEN); @@ -181,23 +180,23 @@ TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_LEN); // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T, Tlen, Dims) \ - template <> \ - void ReverseSequence::Compute( \ - const GPUDevice& d, typename TTypes::ConstTensor input, \ - int32 batch_dim, int32 seq_dim, \ - typename TTypes::ConstVec seq_lens, \ - typename TTypes::Tensor output); \ +#define DECLARE_GPU_SPEC(T, Tlen, Dims) \ + template <> \ + void ReverseSequence::Compute( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + int32 batch_dim, int32 seq_dim, \ + typename TTypes::ConstVec seq_lens, \ + typename TTypes::Tensor output); \ extern template struct ReverseSequence; -#define DECLARE_GPU_SPEC_LEN(T, Dims) \ - DECLARE_GPU_SPEC(T, int32, Dims); \ +#define DECLARE_GPU_SPEC_LEN(T, Dims) \ + DECLARE_GPU_SPEC(T, int32, Dims); \ DECLARE_GPU_SPEC(T, int64, Dims); -#define DECLARE_GPU_SPECS(T) \ - DECLARE_GPU_SPEC_LEN(T, 2); \ - DECLARE_GPU_SPEC_LEN(T, 3); \ - DECLARE_GPU_SPEC_LEN(T, 4); \ +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPEC_LEN(T, 2); \ + DECLARE_GPU_SPEC_LEN(T, 3); \ + DECLARE_GPU_SPEC_LEN(T, 4); \ DECLARE_GPU_SPEC_LEN(T, 5); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); @@ -206,14 +205,15 @@ TF_CALL_bool(DECLARE_GPU_SPECS); } // namespace functor // Registration of the GPU implementations. -#define REGISTER_REVERSE_SEQUENCE_GPU(type, len_type) \ - REGISTER_KERNEL_BUILDER( \ - Name("ReverseSequence").Device(DEVICE_GPU).TypeConstraint("T"). \ - TypeConstraint("Tlen"), \ - ReverseSequenceOp); - -#define REGISTER_REVERSE_SEQUENCE_GPU_LEN(type) \ - REGISTER_REVERSE_SEQUENCE_GPU(type, int32); \ +#define REGISTER_REVERSE_SEQUENCE_GPU(type, len_type) \ + REGISTER_KERNEL_BUILDER(Name("ReverseSequence") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tlen"), \ + ReverseSequenceOp); + +#define REGISTER_REVERSE_SEQUENCE_GPU_LEN(type) \ + REGISTER_REVERSE_SEQUENCE_GPU(type, int32); \ REGISTER_REVERSE_SEQUENCE_GPU(type, int64); TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU_LEN); diff --git a/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc index cb49f14525a3c54ea46df47fb2edeaa9277dc2d3..4a2136a2cd37f4d549c62396d5e30616a306f84f 100644 --- a/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc +++ b/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc @@ -28,14 +28,14 @@ typedef Eigen::GpuDevice GPUDevice; template class generator::ReverseGenerator; \ template struct functor::ReverseSequence; -#define DEFINE_GPU_SPEC_LEN(T, dims) \ - DEFINE_GPU_SPEC(T, int32, dims); \ +#define DEFINE_GPU_SPEC_LEN(T, dims) \ + DEFINE_GPU_SPEC(T, int32, dims); \ DEFINE_GPU_SPEC(T, int64, dims); -#define DEFINE_GPU_SPECS(T) \ - DEFINE_GPU_SPEC_LEN(T, 2); \ - DEFINE_GPU_SPEC_LEN(T, 3); \ - DEFINE_GPU_SPEC_LEN(T, 4); \ +#define DEFINE_GPU_SPECS(T) \ + DEFINE_GPU_SPEC_LEN(T, 2); \ + DEFINE_GPU_SPEC_LEN(T, 3); \ + DEFINE_GPU_SPEC_LEN(T, 4); \ DEFINE_GPU_SPEC_LEN(T, 5); TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bcbdbee058b4fdb587f2099c54545b8a6aec8ca9 --- /dev/null +++ b/tensorflow/core/kernels/roll_op.cc @@ -0,0 +1,334 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/register_types_traits.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +#define EIGEN_USE_THREADS +using CPUDevice = Eigen::ThreadPoolDevice; + +// dim_size - the size of each dimension +// dim_range - the number of indices over in the flattened tensor +// you need to skip in order to make it over from one side of a dimension +// to the other. Used to make the shifts wrap around after a threshold. +// threshold - the index for each dimension that the roll starts to wrap +// back to the front +template +void DoRoll(OpKernelContext* context, const int64 num_elements, + const int num_dims, const gtl::ArraySlice& dim_size, + const T* input, T* output, const gtl::ArraySlice& threshold, + const gtl::ArraySlice& dim_range) { + auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range]( + int64 start, int64 end) { + // array of indices for each dimension + gtl::InlinedVector indices(num_dims); + int offset = 0; // the shift along the flattened tensor for current element + // initialize indices and offset + for (int i = 0; i < num_dims; i++) { + // stride is the number of indices over in the flattened tensor + // you need to skip in order to make it over to an adjacent element + // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1) + const int64 stride = dim_range[i] / dim_size[i]; + const int shift = dim_size[i] - threshold[i]; + const int indx = (start / stride) % dim_size[i]; + indices[i] = indx; + // calculate dimension index after the shift + const int shifted_indx = (indx + shift) % dim_size[i]; + offset += (shifted_indx - indx) * stride; + } + + for (int64 i = start; i < end; i++) { + output[i + offset] = input[i]; + // create next combination of indices + // while at it adjust offset if needed + for (int j = num_dims - 1; j >= 0; j--) { + const int indx = (indices[j] + 1) % dim_size[j]; + indices[j] = indx; + if (indx != 0) { + if (indx == threshold[j]) { // we've reached the threshold + // dim_range[j] = threshold[j] + shift[j] + // offset = shift[j] + ... other offsets + // offset - dim_range[j] = -threshold[j] + ... other offsets + // thus we undo our previous offset as well as add a new offset of + // -threshold[j] in one operation + offset -= dim_range[j]; // now wraps around + } + break; // indx != 0 don't need to carry + } else if (threshold[j] != 0) { // if threshold is 0 shift is 0 + offset += dim_range[j]; // indx became 0 so reverse wrap around + } + } + } + }; + // Shard + auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); + // 15 - expiramentally determined with float and bool types + const int cost_per_element = 15 * sizeof(T); // rough esitmate + Shard(worker_threads->num_threads, worker_threads->workers, num_elements, + cost_per_element, std::move(work)); +} + +// dim_size - the size of each dimension +// dim_range - the number of indices over in the flattened tensor +// you need to skip in order to make it over from one side of a dimension +// to the other. Used to make the shifts wrap around after a threshold. +// threshold - the index for each dimension that the roll starts to wrap +// back to the front +// isd - inner shift dimension +template +// Use memcpy to copy memory in groups when the data type supports memcpy +void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements, + const int num_dims, const gtl::ArraySlice& dim_size, + const T* input, T* output, + const gtl::ArraySlice& threshold, + const gtl::ArraySlice& dim_range, + const int64 isd) { + auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd]( + int64 start, int64 end) { + // the number of indices over in the flattened tensor you need to skip in + // order to make it over from one side of the isd to the other + const int64 isd_range = std::max(dim_range[isd], 1); + // the distance along the flattend tensor to the next element in the isd + const int64 isd_stride = isd_range / std::max(dim_size[isd], 1); + + // start and end represent the i-th group currently so we will convert + // them into numbers representing the i-th elements. + // there are 2 groups per isd one for all elements before threshold[isd] + // and another for all elements after threshold[isd]. + const int64 start_remainder = (start % 2) * threshold[isd] * isd_stride; + const int64 end_remainder = (end % 2) * threshold[isd] * isd_stride; + start = (start / 2) * isd_range + start_remainder; + end = (end / 2) * isd_range + end_remainder; + + const T* in_ptr = &input[0]; + T* out_ptr = &output[0]; + in_ptr += start; + out_ptr += start; + + // array of indices for each dimension + // indicies = [i, j, k, l, m, n] + gtl::InlinedVector indicies(num_dims); + // the offset needed to make all inner non-shifting dimensions become 0 + int64 remainder_offset = 0; + // initialize indicies + for (int i = 0; i < num_dims; i++) { + // stride is the number of indices over in the flattened tensor + // you need to skip in order to make it over to an adjacent element + // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1) + const int64 stride = dim_range[i] / dim_size[i]; + const int shift = dim_size[i] - threshold[i]; + const int indx = (start / stride) % dim_size[i]; + indicies[i] = indx; + // calculate dimension index after the shift + int out_indx = (indx + shift) % dim_size[i]; + if (i > isd) { + // trailing zeroes for indices after the inner shifted dimension + out_indx = 0; + remainder_offset += (out_indx - indx) * stride; + } + out_ptr += (out_indx - indx) * stride; + } + // set trailing zeroes for indices after the inner shifted dimension + for (int i = num_dims - 1; i > isd; i--) indicies[i] = 0; + + // the number of indices in the isd dimension the next group will skip + // to make it to the next threshold or end point + int isd_indx_skip = 0; + // the size of the next group + int64 group_size = 0; + // initialize isd_indx_skip and group_size + if (indicies[isd] < threshold[isd]) { + isd_indx_skip = threshold[isd] - indicies[isd]; + group_size = isd_indx_skip * isd_stride + remainder_offset; + } else { + isd_indx_skip = dim_size[isd] - indicies[isd]; + group_size = isd_indx_skip * isd_stride + remainder_offset; + } + + int64 i = start; + while (i < end) { + // copy group of elements + memcpy(out_ptr, in_ptr, group_size * sizeof(T)); + + // shift i and the pointers over to the next group position + i += group_size; + out_ptr += group_size; + in_ptr += group_size; + + // produce next combination of indices and adjust the out_ptr position + // to fix the offset if necessary + // the isd (inner shift dim) should skip to next threshold or endpoint + // all dimensions to the left increment by 1 when a digit is carried + // all dimensions to the right remain set to 0 + // +1 +1 +1 +isd_indx_skip + // indicies = [i, j, k, l, 0, 0] + // ^isd + for (int j = isd; j >= 0; j--) { + int inc = 1; + if (j == isd) inc = isd_indx_skip; + const int indx = (indicies[j] + inc) % dim_size[j]; + indicies[j] = indx; + if (indx != 0) { + if (indx == threshold[j]) { + out_ptr -= dim_range[j]; // now wraps around + } + break; // indx != 0 don't need to carry + } else if (threshold[j] != 0) { // if threshold is 0 shift is 0 + out_ptr += dim_range[j]; // indx became 0 so reverse wrap around + } + } + + // set isd_indx_skip and group_size for next iteration + if (indicies[isd] < threshold[isd]) { + isd_indx_skip = threshold[isd] - indicies[isd]; + group_size = isd_indx_skip * isd_stride; + } else { + isd_indx_skip = dim_size[isd] - indicies[isd]; + group_size = isd_indx_skip * isd_stride; + } + } + }; + // Shard + auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); + const int64 ave_group_size = dim_range[isd] / 2; + const int total_work = 2 * num_elements / std::max(dim_range[isd], 1); + // 25000 - expiramentally determined with float and bool types + const int cost_per_group = 25000 * sizeof(T) * ave_group_size; + Shard(worker_threads->num_threads, worker_threads->workers, total_work, + cost_per_group, std::move(work)); +} + +template +class RollOp : public OpKernel { + public: + explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input = context->input(0); + const Tensor& shift = context->input(1); + const Tensor& axis = context->input(2); + + auto shift_flat = shift.flat(); + auto axis_flat = axis.flat(); + + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()), + errors::InvalidArgument("input must be 1-D or higher")); + OP_REQUIRES(context, shift.shape().dims() <= 1, + errors::InvalidArgument( + "shift must be a scalar or a 1-D vector. Found: ", + shift.shape().DebugString())); + OP_REQUIRES(context, axis.shape().dims() <= 1, + errors::InvalidArgument( + "axis must be a scalar or a 1-D vector. Found: ", + axis.shape().DebugString())); + OP_REQUIRES( + context, shift.shape() == axis.shape(), + errors::InvalidArgument("shift and axis must have the same size")); + const int64 num_elements = input.NumElements(); + const int num_shifts = static_cast(shift_flat.size()); + const int num_dims = input.dims(); + + // if there are any duplicate axes, shift_mod_sum will have the + // total modulo sum of shifts for each dimension + gtl::InlinedVector shift_mod_sum(num_dims, 0); + for (int i = 0; i < num_shifts; i++) { + const int axis = axis_flat(i); + OP_REQUIRES(context, axis < num_dims, + errors::InvalidArgument("axis ", axis, " is out of range")); + const int ds = std::max(static_cast(input.dim_size(axis)), 1); + const int sum = shift_mod_sum[axis] + static_cast(shift_flat(i)); + // modulo that works with negatives: ((x % y) + y) % y + shift_mod_sum[axis] = (sum % ds + ds) % ds; + } + // the size of each dimension + gtl::InlinedVector dim_size(num_dims); + // threshold[i] is the index that the roll starts to wrap back to the front + gtl::InlinedVector threshold(num_dims); + // dim_range is the number of indices over in the flattened tensor + // you need to skip in order to make it over from one side of a dimension + // to the other. Used to make the shifts wrap around after a threshold. + gtl::InlinedVector dim_range(num_dims); + int64 dim_size_prod = 1; // dimension size product + // inner shift dimension (inner most shifted dimension) + int64 isd = 0; + for (int i = num_dims - 1; i >= 0; i--) { + if (isd == 0 && shift_mod_sum[i] != 0) isd = i; + const int ds = std::max(static_cast(input.dim_size(i)), 1); + dim_size[i] = ds; + threshold[i] = (ds - shift_mod_sum[i]) % ds; + dim_size_prod *= static_cast(input.dim_size(i)); + dim_range[i] = dim_size_prod; + } + + Tensor* output = NULL; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + auto input_flat = input.flat().data(); + auto output_flat = output->flat().data(); + + if (std::is_same::value) { + if (DataTypeCanUseMemcpy(DataTypeToEnum::v())) { + // V2 copies memory in groups instead of element by element + DoRollWithMemcpy(context, num_elements, num_dims, dim_size, + input_flat, output_flat, threshold, dim_range, isd); + } else { + // incase memcpy does not work for current data type + DoRoll(context, num_elements, num_dims, dim_size, input_flat, + output_flat, threshold, dim_range); + } + } + } +}; + +// Register the CPU kernels. +#define REGISTER_CPU(type) \ + REGISTER_KERNEL_BUILDER(Name("Roll") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tshift") \ + .TypeConstraint("Taxis"), \ + RollOp) \ + REGISTER_KERNEL_BUILDER(Name("Roll") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tshift") \ + .TypeConstraint("Taxis"), \ + RollOp) \ + REGISTER_KERNEL_BUILDER(Name("Roll") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tshift") \ + .TypeConstraint("Taxis"), \ + RollOp) \ + REGISTER_KERNEL_BUILDER(Name("Roll") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tshift") \ + .TypeConstraint("Taxis"), \ + RollOp) + +TF_CALL_ALL_TYPES(REGISTER_CPU); +#undef REGISTER_CPU +} // namespace tensorflow diff --git a/tensorflow/core/kernels/roll_op_test.cc b/tensorflow/core/kernels/roll_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..90b6f8d0f3094224ca694b59c851c14bb424d120 --- /dev/null +++ b/tensorflow/core/kernels/roll_op_test.cc @@ -0,0 +1,484 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +class RollOpTest : public OpsTestBase { + protected: + void MakeOp(DataType data_type, DataType index_type) { + TF_ASSERT_OK(NodeDefBuilder("myop", "Roll") + .Input(FakeInput(data_type)) + .Input(FakeInput(index_type)) + .Input(FakeInput(index_type)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + } +}; + +TEST_F(RollOpTest, ScalarIndices) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5}), {0, 1, 2, 3, 4}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {0}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({5})); + test::FillValues(&expected, {2, 3, 4, 0, 1}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, ScalarIndices_NoMemcpy) { + MakeOp(DT_STRING, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5}), {"a", "b", "c", "d", "e"}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {0}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_STRING, TensorShape({5})); + test::FillValues(&expected, {"c", "d", "e", "a", "b"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, ScalarIndices_Complex) { + MakeOp(DT_COMPLEX64, DT_INT32); + + // Feed and run + AddInputFromArray>( + TensorShape({5}), {std::complex(0, 10), std::complex(1, 11), + std::complex(2, 12), std::complex(3, 13), + std::complex(4, 14)}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {0}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_COMPLEX64, TensorShape({5})); + test::FillValues>( + &expected, {std::complex(2, 12), std::complex(3, 13), + std::complex(4, 14), std::complex(0, 10), + std::complex(1, 11)}); + test::ExpectTensorEqual>(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, Simple_TwoD32) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({3, 5}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray(TensorShape({2}), {2, -1}); + AddInputFromArray(TensorShape({2}), {0, 1}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({3, 5})); + test::FillValues(&expected, + {6, 7, 8, 9, 5, 11, 12, 13, 14, 10, 1, 2, 3, 4, 0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, Simple_TwoD32_NoMemcpy) { + MakeOp(DT_STRING, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({3, 5}), + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + "k", "l", "m", "n", "o"}); + AddInputFromArray(TensorShape({2}), {2, -1}); + AddInputFromArray(TensorShape({2}), {0, 1}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_STRING, TensorShape({3, 5})); + test::FillValues(&expected, {"g", "h", "i", "j", "f", "l", "m", "n", + "o", "k", "b", "c", "d", "e", "a"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, Simple_ThreeD32) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({2, 2, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + AddInputFromArray(TensorShape({3}), {1, -1, -1}); + AddInputFromArray(TensorShape({3}), {0, 1, 2}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 3})); + test::FillValues(&expected, {10, 11, 9, 7, 8, 6, 4, 5, 3, 1, 2, 0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, Simple_ThreeD32_NoMemcpy) { + MakeOp(DT_STRING, DT_INT32); + + // Feed and run + AddInputFromArray( + TensorShape({2, 2, 3}), + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"}); + AddInputFromArray(TensorShape({3}), {1, -1, -1}); + AddInputFromArray(TensorShape({3}), {0, 1, 2}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_STRING, TensorShape({2, 2, 3})); + test::FillValues( + &expected, {"k", "l", "j", "h", "i", "g", "e", "f", "d", "b", "c", "a"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, Simple_TwoD64) { + MakeOp(DT_FLOAT, DT_INT64); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray(TensorShape({2}), {-1, 4}); + AddInputFromArray(TensorShape({2}), {0, 1}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3})); + test::FillValues(&expected, + {5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 2, 0, 1}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, Simple_TwoD64_NoMemcpy) { + MakeOp(DT_STRING, DT_INT64); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + "k", "l", "m", "n", "o"}); + AddInputFromArray(TensorShape({2}), {-1, 4}); + AddInputFromArray(TensorShape({2}), {0, 1}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_STRING, TensorShape({5, 3})); + test::FillValues(&expected, {"f", "d", "e", "i", "g", "h", "l", "j", + "k", "o", "m", "n", "c", "a", "b"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, Simple_ThreeD64) { + MakeOp(DT_FLOAT, DT_INT64); + + // Feed and run + AddInputFromArray(TensorShape({4, 1, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + AddInputFromArray(TensorShape({3}), {4, 3, 2}); + AddInputFromArray(TensorShape({3}), {0, 1, 2}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 1, 3})); + test::FillValues(&expected, {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, Simple_ThreeD64_NoMemcpy) { + MakeOp(DT_STRING, DT_INT64); + + // Feed and run + AddInputFromArray( + TensorShape({4, 1, 3}), + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"}); + AddInputFromArray(TensorShape({3}), {4, 3, 2}); + AddInputFromArray(TensorShape({3}), {0, 1, 2}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_STRING, TensorShape({4, 1, 3})); + test::FillValues( + &expected, {"b", "c", "a", "e", "f", "d", "h", "i", "g", "k", "l", "j"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, ZeroShift_ThreeD32) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({2, 2, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + AddInputFromArray(TensorShape({3}), {0, 0, 0}); + AddInputFromArray(TensorShape({3}), {0, 1, 2}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 3})); + test::FillValues(&expected, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, ZeroShift_ThreeD32_NoMemcpy) { + MakeOp(DT_STRING, DT_INT32); + + // Feed and run + AddInputFromArray( + TensorShape({2, 2, 3}), + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"}); + AddInputFromArray(TensorShape({3}), {0, 0, 0}); + AddInputFromArray(TensorShape({3}), {0, 1, 2}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_STRING, TensorShape({2, 2, 3})); + test::FillValues( + &expected, {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, ZeroSize_ThreeD32) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5, 0, 0}), {}); + AddInputFromArray(TensorShape({}), {1}); + AddInputFromArray(TensorShape({}), {0}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 0, 0})); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, ZeroSize_ThreeD32_NoMemcpy) { + MakeOp(DT_STRING, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5, 0, 0}), {}); + AddInputFromArray(TensorShape({}), {1}); + AddInputFromArray(TensorShape({}), {0}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_STRING, TensorShape({5, 0, 0})); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, OneSize_ThreeD32) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({1, 1, 1}), {5}); + AddInputFromArray(TensorShape({}), {1}); + AddInputFromArray(TensorShape({}), {0}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1})); + test::FillValues(&expected, {5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, OneSize_ThreeD32_NoMemcpy) { + MakeOp(DT_STRING, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({1, 1, 1}), {"a"}); + AddInputFromArray(TensorShape({}), {1}); + AddInputFromArray(TensorShape({}), {0}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_STRING, TensorShape({1, 1, 1})); + test::FillValues(&expected, {"a"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, MultiShifts_TwoD32) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({3, 5}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray(TensorShape({4}), {-2, 2, -1, 1}); + AddInputFromArray(TensorShape({4}), {1, 0, 0, 1}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({3, 5})); + test::FillValues(&expected, + {11, 12, 13, 14, 10, 1, 2, 3, 4, 0, 6, 7, 8, 9, 5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, MultiShifts_TwoD32_NoMemcpy) { + MakeOp(DT_STRING, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({3, 5}), + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + "k", "l", "m", "n", "o"}); + AddInputFromArray(TensorShape({4}), {-2, 2, -1, 1}); + AddInputFromArray(TensorShape({4}), {1, 0, 0, 1}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_STRING, TensorShape({3, 5})); + test::FillValues(&expected, {"l", "m", "n", "o", "k", "b", "c", "d", + "e", "a", "g", "h", "i", "j", "f"}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(RollOpTest, Error_InputMustBeVectorOrHigher) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({}), {7}); + AddInputFromArray(TensorShape({}), {1}); + AddInputFromArray(TensorShape({}), {0}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()).contains("input must be 1-D or higher")) + << s; +} + +TEST_F(RollOpTest, Error_AxisMustBeScalarOrVector) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({2, 2}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({}), {1}); + AddInputFromArray(TensorShape({1, 2}), {0, 1}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("axis must be a scalar or a 1-D vector")) + << s; +} + +TEST_F(RollOpTest, Error_ShiftMustBeScalarOrVector) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({2, 2}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({1, 2}), {0, 1}); + AddInputFromArray(TensorShape({}), {1}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("shift must be a scalar or a 1-D vector")) + << s; +} + +TEST_F(RollOpTest, Error_ShiftAndAxisMustBeSameSize) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({2, 2}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({1}), {1}); + AddInputFromArray(TensorShape({2}), {0, 1}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()) + .contains("shift and axis must have the same size")) + << s; +} + +TEST_F(RollOpTest, Error_AxisOutOfRange) { + MakeOp(DT_FLOAT, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({4}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({}), {1}); + AddInputFromArray(TensorShape({}), {1}); + Status s = RunOpKernel(); + EXPECT_TRUE(StringPiece(s.ToString()).contains("is out of range")) << s; +} + +// isd - (inner shift dimension) The inner most dimension to be shifted. +// All outer dimensions will also be shifted for testing. +static Graph* RollGraph(const TensorShape& shape, int isd) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor input(DT_FLOAT, shape); + input.flat().setRandom(); + const int dims = static_cast(input.dims()); + Tensor shift(DT_INT32, TensorShape({dims})); + for (int i = 0; i < dims; i++) { + // shift the inner shift dimension and all outer dimensions + shift.flat()(i) = (i <= isd) ? 2 : 0; + } + Tensor axis(DT_INT32, TensorShape({dims})); + for (int i = 0; i < dims; i++) { + axis.flat()(i) = i; + } + test::graph::Roll(g, test::graph::Constant(g, input), + test::graph::Constant(g, shift), + test::graph::Constant(g, axis)); + return g; +} + +#define BM_ROLL_OUTER(DEVICE) \ + static void BM_##DEVICE##_roll_outer(int iters, int rows, int columns) { \ + TensorShape shape{rows, columns}; \ + const int64 num_items = static_cast(iters) * shape.num_elements(); \ + testing::ItemsProcessed(num_items); \ + testing::BytesProcessed(num_items * sizeof(float)); \ + testing::UseRealTime(); \ + test::Benchmark(#DEVICE, RollGraph(shape, 0)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_roll_outer) \ + ->ArgPair(256, 256) \ + ->ArgPair(512, 512) \ + ->ArgPair(1024, 1024) \ + ->ArgPair(2048, 2048) + +#define BM_ROLL_ALL(DEVICE) \ + static void BM_##DEVICE##_roll_all(int iters, int rows, int columns) { \ + TensorShape shape{rows, columns}; \ + const int64 num_items = static_cast(iters) * shape.num_elements(); \ + testing::ItemsProcessed(num_items); \ + testing::BytesProcessed(num_items * sizeof(float)); \ + testing::UseRealTime(); \ + test::Benchmark(#DEVICE, RollGraph(shape, 1)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_roll_all) \ + ->ArgPair(256, 256) \ + ->ArgPair(512, 512) \ + ->ArgPair(1024, 1024) \ + ->ArgPair(2048, 2048) + +BM_ROLL_OUTER(cpu); +BM_ROLL_ALL(cpu); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc b/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc index 44a817a5c76d31aa8bde25a5f608b75b81116355..c0fde8042e816c325475a36129fb71630f0ca7c6 100644 --- a/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc +++ b/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc @@ -387,9 +387,9 @@ class SampleDistortedBoundingBoxV2Op : public OpKernel { OP_REQUIRES_OK( context, context->allocate_output(2, TensorShape({1, 1, 4}), &bboxes)); - typename TTypes::Tensor begin_data = begin->tensor(); - typename TTypes::Tensor size_data = size->tensor(); - typename TTypes::Tensor bboxes_data = bboxes->tensor(); + typename TTypes::Tensor begin_data(begin->tensor()); + typename TTypes::Tensor size_data(size->tensor()); + TTypes::Tensor bboxes_data = bboxes->tensor(); begin_data(0) = T(offset_height); size_data(0) = T(target_height); diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index df60eda75978ff9f6a9d7059b9594f86831aa6f5..990bd2bff94ac9cf18dd6f6316503890bb31884d 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -106,11 +106,11 @@ void SaveTensors( OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice( shape_spec, &shape, &slice, &slice_shape)); OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()), - errors::InvalidArgument("Slice in shape_and_slice " - "specification does not match the " - "shape of the tensor to save: ", - shape_spec, ", tensor: ", - input.shape().DebugString())); + errors::InvalidArgument( + "Slice in shape_and_slice " + "specification does not match the " + "shape of the tensor to save: ", + shape_spec, ", tensor: ", input.shape().DebugString())); } #define WRITER_ADD(T) \ diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h index c6e35fe329e1c1b7acb62daedeeb2f1a92444b78..079f15e101308867389745ee42146086af91c47c 100644 --- a/tensorflow/core/kernels/scatter_functor.h +++ b/tensorflow/core/kernels/scatter_functor.h @@ -29,7 +29,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace scatter_op { @@ -117,7 +117,7 @@ struct AssignSYCL { p.device(d) = p / u; } }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace internal } // namespace scatter_op @@ -156,7 +156,7 @@ struct ScatterFunctorBase { #ifdef TENSORFLOW_USE_SYCL template -struct ScatterFunctorBase { +struct ScatterFunctorBase { Index operator()(OpKernelContext* c, const SYCLDevice& d, typename TTypes::Matrix params, typename TTypes::ConstMatrix updates, @@ -171,13 +171,13 @@ struct ScatterFunctorBase { const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); if (!FastBoundsCheck(index, limit)) return i; // Copy last Ndim-1 dimensions of updates[i] to params[index] - scatter_op::internal::AssignSYCL::Run(d, params.template chip<0>(index), - updates.template chip<0>(i)); + scatter_op::internal::AssignSYCL::Run( + d, params.template chip<0>(index), updates.template chip<0>(i)); } return -1; } }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template struct ScatterFunctorBase { @@ -217,7 +217,7 @@ struct ScatterFunctorBase { template struct ScatterFunctor - : ScatterFunctorBase{}; + : ScatterFunctorBase {}; #ifdef TENSORFLOW_USE_SYCL template @@ -239,7 +239,7 @@ struct ScatterFunctorSYCL { return -1; } }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h index e116077d3cfc37871009ee3fede633590d269681..be18658543ea330e3196d0f372154df32e4e1dfc 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h +++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h @@ -30,9 +30,10 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; template -__global__ void ScatterOpCustomKernel( - T* params, const T* updates, const Index* indices, - Index first_dim_size, Index updates_size, Index indices_size) { +__global__ void ScatterOpCustomKernel(T* params, const T* updates, + const Index* indices, + Index first_dim_size, Index updates_size, + Index indices_size) { Index update_block = updates_size / indices_size; CUDA_1D_KERNEL_LOOP(i, updates_size) { int indices_i = i / update_block; @@ -85,8 +86,8 @@ struct ScatterFunctor { CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d); ScatterOpCustomKernel <<>>( - params.data(), updates.data(), indices.data(), - first_dim_size, updates_size, indices_size); + params.data(), updates.data(), indices.data(), first_dim_size, + updates_size, indices_size); return -1; } }; diff --git a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h index c6c9d4e6588f1f4d847810de1e736220d5572f25..e82660dcc1dcf9dbb7d531c0223e211ce46a8635 100644 --- a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h @@ -40,7 +40,7 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL class OpKernelContext; @@ -251,7 +251,7 @@ REGISTER_SCATTER_ND_MATH_SYCL(int32); #undef REGISTER_SCATTER_ND_INDEX_SYCL #undef REGISTER_SCATTER_ND_FULL_SYCL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace functor diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc index 31f74671cabdabce2884fcae61a6e56dbfdefe8b..a3c21edc15f684e51c7f1806aeeeeead679ea22e 100644 --- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc @@ -55,6 +55,27 @@ struct LeftUpdate { } }; +// Specializations for std::complex, updating real and imaginary part +// individually. Even though this is not an atomic op anymore, it is safe +// because there is only one type of op per kernel. +template +struct LeftUpdate, scatter_nd_op::UpdateOp::ADD> { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()( + std::complex* out, const std::complex& val) { + T* ptr = reinterpret_cast(out); + CudaAtomicAdd(ptr, val.real()); + CudaAtomicAdd(ptr, val.imag()); + } +}; + +template +struct LeftUpdate, scatter_nd_op::UpdateOp::SUB> { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()( + std::complex* out, const std::complex& val) { + LeftUpdate, scatter_nd_op::UpdateOp::ADD>()(out, -val); + } +}; + } // namespace template diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 8607c7f95af79c8f581768cfc698bad9fe085188..282165349f316144d261859d5a3a992f047e0df3 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -25,7 +25,7 @@ limitations under the License. #ifdef TENSORFLOW_USE_SYCL #include "tensorflow/core/common_runtime/sycl/sycl_util.h" -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace tensorflow { @@ -33,7 +33,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Check whether updates.shape = indices.shape + params.shape[1:] static bool ValidShapes(const Tensor& params, const Tensor& updates, @@ -102,11 +102,12 @@ class ScatterUpdateOp : public OpKernel { // Check that we have enough index space const int64 N_big = indices.NumElements(); - OP_REQUIRES(c, N_big <= std::numeric_limits::max(), - errors::InvalidArgument( - "indices has too many elements for ", - DataTypeString(DataTypeToEnum::v()), " indexing: ", - N_big, " > ", std::numeric_limits::max())); + OP_REQUIRES( + c, N_big <= std::numeric_limits::max(), + errors::InvalidArgument("indices has too many elements for ", + DataTypeString(DataTypeToEnum::v()), + " indexing: ", N_big, " > ", + std::numeric_limits::max())); const Index N = static_cast(indices.NumElements()); OP_REQUIRES( c, params.dim_size(0) <= std::numeric_limits::max(), @@ -137,7 +138,7 @@ class ScatterUpdateOp : public OpKernel { #ifdef TENSORFLOW_USE_SYCL template -class ScatterUpdateOp : public OpKernel { +class ScatterUpdateOp : public OpKernel { public: explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_)); @@ -165,11 +166,12 @@ class ScatterUpdateOp : public OpKernel { // Check that we have enough index space const int64 N_big = indices.NumElements(); - OP_REQUIRES(c, N_big <= std::numeric_limits::max(), - errors::InvalidArgument( - "indices has too many elements for ", - DataTypeString(DataTypeToEnum::v()), " indexing: ", - N_big, " > ", std::numeric_limits::max())); + OP_REQUIRES( + c, N_big <= std::numeric_limits::max(), + errors::InvalidArgument("indices has too many elements for ", + DataTypeString(DataTypeToEnum::v()), + " indexing: ", N_big, " > ", + std::numeric_limits::max())); const Index N = static_cast(indices.NumElements()); OP_REQUIRES( c, params.dim_size(0) <= std::numeric_limits::max(), @@ -206,7 +208,7 @@ class ScatterUpdateOp : public OpKernel { } } }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \ REGISTER_KERNEL_BUILDER(Name(name) \ diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc index 863c123b43f781239dab62e6b57719376fc49dad..066a4b80a2bc6976a6c95ced2c5efecbef13eeba 100644 --- a/tensorflow/core/kernels/sdca_internal.cc +++ b/tensorflow/core/kernels/sdca_internal.cc @@ -37,9 +37,8 @@ void FeatureWeightsDenseStorage::UpdateDenseDeltaWeights( const size_t num_weight_vectors = normalized_bounded_dual_delta.size(); if (num_weight_vectors == 1) { deltas_.device(device) = - deltas_ + - dense_vector.RowAsMatrix() * - deltas_.constant(normalized_bounded_dual_delta[0]); + deltas_ + dense_vector.RowAsMatrix() * + deltas_.constant(normalized_bounded_dual_delta[0]); } else { // Transform the dual vector into a column matrix. const Eigen::TensorMap> @@ -61,9 +60,8 @@ void FeatureWeightsSparseStorage::UpdateSparseDeltaWeights( const Example::SparseFeatures& sparse_features, const std::vector& normalized_bounded_dual_delta) { for (int64 k = 0; k < sparse_features.indices->size(); ++k) { - const double feature_value = sparse_features.values == nullptr - ? 1.0 - : (*sparse_features.values)(k); + const double feature_value = + sparse_features.values == nullptr ? 1.0 : (*sparse_features.values)(k); auto it = indices_to_id_.find((*sparse_features.indices)(k)); for (size_t l = 0; l < normalized_bounded_dual_delta.size(); ++l) { deltas_(l, it->second) += @@ -122,23 +120,24 @@ Status ModelWeights::Initialize(OpKernelContext* const context) { } // Reads in the weights, and allocates and initializes the delta weights. - const auto initialize_weights = [&]( - const OpInputList& weight_inputs, OpOutputList* const weight_outputs, - std::vector* const feature_weights) { - for (int i = 0; i < weight_inputs.size(); ++i) { - Tensor* delta_t; - TF_RETURN_IF_ERROR( - weight_outputs->allocate(i, weight_inputs[i].shape(), &delta_t)); - // Convert the input vector to a row matrix in internal representation. - auto deltas = delta_t->shaped({1, delta_t->NumElements()}); - deltas.setZero(); - feature_weights->emplace_back( - FeatureWeightsDenseStorage{weight_inputs[i].shaped( - {1, weight_inputs[i].NumElements()}), - deltas}); - } - return Status::OK(); - }; + const auto initialize_weights = + [&](const OpInputList& weight_inputs, OpOutputList* const weight_outputs, + std::vector* const feature_weights) { + for (int i = 0; i < weight_inputs.size(); ++i) { + Tensor* delta_t; + TF_RETURN_IF_ERROR( + weight_outputs->allocate(i, weight_inputs[i].shape(), &delta_t)); + // Convert the input vector to a row matrix in internal + // representation. + auto deltas = delta_t->shaped({1, delta_t->NumElements()}); + deltas.setZero(); + feature_weights->emplace_back(FeatureWeightsDenseStorage{ + weight_inputs[i].shaped( + {1, weight_inputs[i].NumElements()}), + deltas}); + } + return Status::OK(); + }; return initialize_weights(dense_weights_inputs, &dense_weights_outputs, &dense_weights_); diff --git a/tensorflow/core/kernels/sdca_internal.h b/tensorflow/core/kernels/sdca_internal.h index 9f072700754320700024be57ebe3c4ca780a1ae9..45915693ac6f0b4ad2d5f2aacebcd4aa34c03439 100644 --- a/tensorflow/core/kernels/sdca_internal.h +++ b/tensorflow/core/kernels/sdca_internal.h @@ -149,7 +149,8 @@ class Example { // 1.0f. struct SparseFeatures { std::unique_ptr::UnalignedConstVec> indices; - std::unique_ptr::UnalignedConstVec> values; // nullptr encodes optional. + std::unique_ptr::UnalignedConstVec> + values; // nullptr encodes optional. }; // A dense vector which is a row-slice of the underlying matrix. diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc index 0f5c2424b38aeed5912287bba7a218575a107073..dbe0177dda337a271433cd3bb4257026dc702364 100644 --- a/tensorflow/core/kernels/sdca_ops.cc +++ b/tensorflow/core/kernels/sdca_ops.cc @@ -57,11 +57,11 @@ namespace tensorflow { namespace { -using sdca::Regularizations; using sdca::Example; using sdca::Examples; using sdca::ExampleStatistics; using sdca::ModelWeights; +using sdca::Regularizations; struct ComputeOptions { explicit ComputeOptions(OpKernelConstruction* const context) { @@ -76,8 +76,9 @@ struct ComputeOptions { } else if (loss_type == "smooth_hinge_loss") { loss_updater.reset(new SmoothHingeLossUpdater); } else { - OP_REQUIRES(context, false, errors::InvalidArgument( - "Unsupported loss type: ", loss_type)); + OP_REQUIRES( + context, false, + errors::InvalidArgument("Unsupported loss type: ", loss_type)); } OP_REQUIRES_OK(context, context->GetAttr("adaptative", &adaptative)); OP_REQUIRES_OK( @@ -90,9 +91,10 @@ struct ComputeOptions { context, num_sparse_features + num_dense_features > 0, errors::InvalidArgument("Requires at least one feature to train.")); - OP_REQUIRES(context, static_cast(num_sparse_features) + - static_cast(num_dense_features) <= - std::numeric_limits::max(), + OP_REQUIRES(context, + static_cast(num_sparse_features) + + static_cast(num_dense_features) <= + std::numeric_limits::max(), errors::InvalidArgument( strings::Printf("Too many feature groups: %lld > %d", static_cast(num_sparse_features) + diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc index 3ef1cd1e062b5f5abecca2f4f788e3fed20e33e9..6c4685a50a4139b9f33d22b409059f7c03fa2812 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.cc +++ b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -20,10 +20,10 @@ limitations under the License. #define EIGEN_USE_GPU #endif // GOOGLE_CUDA -#include "tensorflow/core/kernels/segment_reduction_ops.h" -#include #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/segment_reduction_ops.h" +#include #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -115,7 +115,7 @@ class SegmentReductionOp : public OpKernel { Eigen::DSizes dims_to_reduce; dims_to_reduce[0] = 0; #else - Eigen::IndexList> dims_to_reduce; + Eigen::IndexList > dims_to_reduce; #endif Index start = 0, end = 1; @@ -356,158 +356,180 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL); #undef REGISTER_GPU_SORTED_KERNELS_ALL #endif // GOOGLE_CUDA +// ____________________________________________________________________________ +// Unsorted segment reduction ops. + namespace functor { -// UnsortedSegmentSumFunctor implementation for CPUDevice. -// todo: Remove duplicate code in UnsortedSegmentSumFunctor and UnsortedSegmentMaxFunctor. -template -struct UnsortedSegmentSumFunctor - : UnsortedSegmentBaseFunctor { - void operator()(OpKernelContext* ctx, const CPUDevice& d, - const Index output_rows, const TensorShape& segment_ids_shape, +// The ReductionFunctor implementation for CPU. +template +struct UnsortedSegmentFunctor { + void operator()(OpKernelContext* ctx, const Index num_segments, + const TensorShape& segment_ids_shape, typename TTypes::ConstFlat segment_ids, const Index data_size, const T* data, - typename TTypes::Tensor output) override { - output.setZero(); + typename TTypes::Tensor output) { + output.setConstant(InitialValueF()()); if (data_size == 0) { return; } const int64 N = segment_ids.dimension(0); + ReductionF reduction; auto data_flat = typename TTypes::ConstTensor(data, N, data_size / N); for (int64 i = 0; i < N; ++i) { Index j = internal::SubtleMustCopy(segment_ids(i)); if (j < 0) { continue; } - OP_REQUIRES(ctx, FastBoundsCheck(j, output_rows), + OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments), errors::InvalidArgument( "segment_ids", SliceDebugString(segment_ids_shape, i), - " = ", j, " is out of range [0, ", output_rows, ")")); - output.template chip<0>(j) += data_flat.template chip<0>(i); + " = ", j, " is out of range [0, ", num_segments, ")")); + reduction(data_flat.template chip<0>(i), output.template chip<0>(j)); } } }; -// UnsortedSegmentMaxFunctor implementation for CPUDevice. -template -struct UnsortedSegmentMaxFunctor - : UnsortedSegmentBaseFunctor { - void operator()(OpKernelContext* ctx, const CPUDevice& d, - const Index output_rows, const TensorShape& segment_ids_shape, - typename TTypes::ConstFlat segment_ids, - const Index data_size, const T* data, - typename TTypes::Tensor output) override { - output.setConstant(std::numeric_limits::lowest()); - if (data_size == 0) { - return; - } - const int64 N = segment_ids.dimension(0); - auto data_flat = typename TTypes::ConstTensor(data, N, data_size / N); - for (int64 i = 0; i < N; ++i) { - Index j = internal::SubtleMustCopy(segment_ids(i)); - OP_REQUIRES(ctx, FastBoundsCheck(j, output_rows), - errors::InvalidArgument( - "segment_ids", SliceDebugString(segment_ids_shape, i), - " = ", j, " is out of range [0, ", output_rows, ")")); - output.template chip<0>(j) = - data_flat.template chip<0>(i).cwiseMax(output.template chip<0>(j)); - } + +template +using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes::Matrix>; + +template +using constMatrixChip = + Eigen::TensorChippingOp<0l, const typename TTypes::ConstMatrix>; + +// reduction functors +template +struct SumOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output += data; + } +}; + +template +struct MaxOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output = data.cwiseMax(output); + } +}; + +template +struct MinOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output = data.cwiseMin(output); + } +}; + +template +struct ProdOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output *= data; } }; } // namespace functor -// Base class for SegmentReductionOps that can handle unsorted segment -// definitions -// and specifying the size of the output in addition to a reduction function -template -class UnsortedSegmentBaseOp : public OpKernel { +// Static check routines not in the templated class to reduce code size +static void UnsortedSegmentReductionValidation(OpKernel* op_kernel, + OpKernelContext* context, + const Tensor& data, + const Tensor& segment_ids, + const Tensor& num_segments) { + OP_REQUIRES( + context, op_kernel->IsLegacyScalar(num_segments.shape()), + errors::InvalidArgument("num_segments should be a scalar, not shape ", + num_segments.shape().DebugString())); + OP_REQUIRES( + context, TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()), + errors::InvalidArgument("data.shape = ", data.shape().DebugString(), + " does not start with segment_ids.shape = ", + segment_ids.shape().DebugString())); +} + +static bool UnsortedSegmentReductionDoValidation(OpKernel* op_kernel, + OpKernelContext* context, + const Tensor& data, + const Tensor& segment_ids, + const Tensor& num_segments) { + UnsortedSegmentReductionValidation(op_kernel, context, data, segment_ids, + num_segments); + return context->status().ok(); +} + +// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor +// is the device specific implementation of the reduction. These device +// specific implementations are templated themselves with the corresponding +// initial value functors and reduction functors. +template +class UnsortedSegmentReductionOp : public OpKernel { public: - explicit UnsortedSegmentBaseOp( - OpKernelConstruction* context, - functor::UnsortedSegmentBaseFunctor& functor) - : OpKernel(context), reduction_functor_(functor) {} + explicit UnsortedSegmentReductionOp(OpKernelConstruction* context) + : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {} void Compute(OpKernelContext* context) override { const Tensor& data = context->input(0); const Tensor& segment_ids = context->input(1); const Tensor& num_segments = context->input(2); - - OP_REQUIRES( - context, IsLegacyScalar(num_segments.shape()), - errors::InvalidArgument("num_segments should be a scalar, not shape ", - num_segments.shape().DebugString())); - OP_REQUIRES( - context, - TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()), - errors::InvalidArgument("data.shape = ", data.shape().DebugString(), - " does not start with segment_ids.shape = ", - segment_ids.shape().DebugString())); - + if (!UnsortedSegmentReductionDoValidation(this, context, data, segment_ids, + num_segments)) { + return; + } const auto segment_flat = segment_ids.flat(); const Index output_rows = internal::SubtleMustCopy(num_segments.scalar()()); OP_REQUIRES(context, output_rows >= 0, errors::InvalidArgument("Input num_segments == ", output_rows, " must not be negative.")); - TensorShape output_shape; output_shape.AddDim(output_rows); for (int i = segment_ids.dims(); i < data.dims(); i++) { output_shape.AddDim(data.dim_size(i)); } - Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); auto output_flat = output->flat_outer_dims(); - auto data_ptr = data.template flat().data(); - reduction_functor_(context, context->template eigen_device(), - output_rows, segment_ids.shape(), segment_flat, - data.NumElements(), data_ptr, output_flat); + reduction_functor_(context, output_rows, segment_ids.shape(), segment_flat, + data.NumElements(), data_ptr, output_flat); } - private: - functor::UnsortedSegmentBaseFunctor& reduction_functor_; -}; -template -class UnsortedSegmentSumOp : public UnsortedSegmentBaseOp { - public: - explicit UnsortedSegmentSumOp(OpKernelConstruction* context) - : UnsortedSegmentBaseOp( - context, - sum_functor_) {} - private: - functor::UnsortedSegmentSumFunctor sum_functor_; + protected: + DeviceReductionFunctor reduction_functor_; }; -template -class UnsortedSegmentMaxOp : public UnsortedSegmentBaseOp { - public: - explicit UnsortedSegmentMaxOp(OpKernelConstruction* context) - : UnsortedSegmentBaseOp( - context, - max_functor_) {} - private: - functor::UnsortedSegmentMaxFunctor max_functor_; -}; - -#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \ - REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - UnsortedSegmentSumOp); \ - REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentMax") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - UnsortedSegmentMaxOp); - -#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \ - REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - UnsortedSegmentSumOp); +#define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT( \ + name, type, index_type, initial_value_functor, reduction_functor) \ + REGISTER_KERNEL_BUILDER( \ + Name(name) \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + UnsortedSegmentReductionOp< \ + type, index_type, \ + functor::UnsortedSegmentFunctor >) + +#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \ + functor::Zero, \ + functor::SumOp); \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \ + functor::Lowest, \ + functor::MaxOp); \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \ + functor::Highest, \ + functor::MinOp); \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \ + functor::One, \ + functor::ProdOp); + +#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \ + functor::Zero, \ + functor::SumOp); \ + REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \ + functor::One, \ + functor::ProdOp) #define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \ REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32); \ @@ -520,31 +542,72 @@ class UnsortedSegmentMaxOp : public UnsortedSegmentBaseOp { TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL); REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64); REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128); + #undef REGISTER_REAL_CPU_UNSORTED_KERNELS +#undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL #undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL #if GOOGLE_CUDA -#define REGISTER_GPU_UNSORTED_KERNELS(type, index_type) \ - REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \ - .Device(DEVICE_GPU) \ - .HostMemory("num_segments") \ - .TypeConstraint("T") \ - .TypeConstraint("Tindices"), \ - UnsortedSegmentSumOp); - -#define REGISTER_GPU_UNSORTED_KERNELS_ALL(type) \ - REGISTER_GPU_UNSORTED_KERNELS(type, int32); \ - REGISTER_GPU_UNSORTED_KERNELS(type, int64); +#define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT( \ + name, type, index_type, initial_value_functor, reduction_kernel_functor) \ + REGISTER_KERNEL_BUILDER( \ + Name(name) \ + .Device(DEVICE_GPU) \ + .HostMemory("num_segments") \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + UnsortedSegmentReductionOp< \ + type, index_type, \ + functor::UnsortedSegmentFunctor >) + +// sum is the only op that supports all input types currently +#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \ + functor::Lowest, \ + functor::MaxOpGpu); \ + REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \ + functor::Highest, \ + functor::MinOpGpu); \ + REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \ + functor::One, \ + functor::ProdOpGpu); + +#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \ + REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \ + functor::Zero, \ + functor::SumOpGpu); + +#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \ + REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32); \ + REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64); + +#define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(type) \ + REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int32); \ + REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int64); + + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL); +TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); +TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); +TF_CALL_complex64(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); +TF_CALL_complex128(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL); + +#undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT +#undef REGISTER_REAL_GPU_UNSORTED_KERNELS +#undef REGISTER_SUM_GPU_UNSORTED_KERNELS +#undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL +#undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_UNSORTED_KERNELS_ALL); -TF_CALL_complex64(REGISTER_GPU_UNSORTED_KERNELS_ALL); -TF_CALL_complex128(REGISTER_GPU_UNSORTED_KERNELS_ALL); -#undef REGISTER_GPU_UNSORTED_KERNELS -#undef REGISTER_GPU_UNSORTED_KERNELS_ALL #endif // GOOGLE_CUDA +// ____________________________________________________________________________ +// Sparse segment reduction ops. + // Same as SegmentReductionOp but takes as input a "sparse" tensor, represented // by two dense tensors, one containing the data, and the other containing // indices into the data. @@ -663,9 +726,9 @@ class SparseSegmentReductionOpBase : public OpKernel { Reduce(input_flat, indices_vec, start, end - start, out); OP_REQUIRES(context, bad_offset < 0, errors::InvalidArgument( - "Bad: indices[", start + bad_offset, "] == ", - indices_vec(start + bad_offset), " out of range [0, ", - input_flat.dimension(0), ")")); + "Bad: indices[", start + bad_offset, + "] == ", indices_vec(start + bad_offset), + " out of range [0, ", input_flat.dimension(0), ")")); start = end; ++end; diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index bcdd42c80c18af381988808db74319e5072f38a7..51814273b305bfa35bca0ddce0376658064ea56a 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ -#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" @@ -46,56 +46,81 @@ struct SegmentSumFunctor { const Index data_size, const T* data, typename TTypes::Tensor output); }; + #endif -// BaseFunctor for definition of UnsorteSegmentReductionOp -// for usage without templates. -template -struct UnsortedSegmentBaseFunctor{ - virtual ~UnsortedSegmentBaseFunctor(){} - virtual void operator()(OpKernelContext* ctx, const Device& d, - const Index output_rows, const TensorShape& segment_ids_shape, +template +struct UnsortedSegmentFunctor { + void operator()(OpKernelContext* ctx, const Index num_segments, + const TensorShape& segment_ids_shape, typename TTypes::ConstFlat segment_ids, const Index data_size, const T* data, - typename TTypes::Tensor output){}; + typename TTypes::Tensor output); }; -// Functor for UnsortedSegmentSumOp. -// output_rows: the number of output segments (unique segment ids in -// 'segment_ids'). -// segment_ids_shape: shape of 'segment_ids' tensor. -// segment_ids: unsorted map from input to output segment ids at which to -// perform segment sum operation. -// data_size: size of input data tensor. -// data: input data tensor. -// output: output reshaped to {output_rows, output.size/output_rows} -template -struct UnsortedSegmentSumFunctor: public UnsortedSegmentBaseFunctor { - void operator()(OpKernelContext* ctx, const Device& d, - const Index output_rows, const TensorShape& segment_ids_shape, - typename TTypes::ConstFlat segment_ids, - const Index data_size, const T* data, - typename TTypes::Tensor output); +#ifdef GOOGLE_CUDA +// reduction functors for the gpu +template +struct SumOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + CudaAtomicAdd(dest, value); + } }; -// Functor for UnsortedSegmentMaxOp. -// output_rows: the number of output segments (unique segment ids in -// 'segment_ids'). -// segment_ids_shape: shape of 'segment_ids' tensor. -// segment_ids: unsorted map from input to output segment ids at which to -// perform segment sum operation. -// data_size: size of input data tensor. -// data: input data tensor. -// output: output reshaped to {output_rows, output.size/output_rows} -template -struct UnsortedSegmentMaxFunctor: public UnsortedSegmentBaseFunctor { - void operator()(OpKernelContext* ctx, const Device& d, - const Index output_rows, const TensorShape& segment_ids_shape, - typename TTypes::ConstFlat segment_ids, - const Index data_size, const T* data, - typename TTypes::Tensor output); +template +struct ProdOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + CudaAtomicMul(dest, value); + } +}; + +template +struct MaxOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + CudaAtomicMax(dest, value); + } +}; + +template +struct MinOpGpu { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, + const T& value) { + CudaAtomicMin(dest, value); + } }; + +#endif // GOOGLE_CUDA + +// initial value functors +template +struct Zero { + EIGEN_STRONG_INLINE T operator()() const { return T(0); } +}; + +template +struct One { + EIGEN_STRONG_INLINE T operator()() const { return T(1); } +}; + +template +struct Lowest { + EIGEN_STRONG_INLINE T operator()() const { + return Eigen::NumTraits::lowest(); + } +}; + +template +struct Highest { + EIGEN_STRONG_INLINE T operator()() const { + return Eigen::NumTraits::highest(); + } +}; + } // namespace functor } // namespace tensorflow -#endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc index 159fada621bd88de259e9b044491f3ecebf10b19..ba979e6bb216b649ff4fc3cefa7099ac9cbc1b91 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc @@ -18,42 +18,15 @@ limitations under the License. #define EIGEN_USE_GPU #include "tensorflow/core/kernels/segment_reduction_ops.h" - #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/util/cuda_device_functions.h" #include "tensorflow/core/util/cuda_kernel_helper.h" + namespace tensorflow { using GPUDevice = Eigen::GpuDevice; -// Helper for UnusortedSegmentSumCustomKernel that adds value into dest -// atomically. -template -static __device__ __forceinline__ void AccumulateInto(T* dest, const T& value) { - CudaAtomicAdd(dest, value); -} - -// Specializations of AccumulateInto for complex types, which CudaAtomicAdd does -// not support. We treat a std::complex* as a T* (the C++ standard section -// 26.4.4 allows this explicitly) and atomic add the real and imaginary -// components individually. The operation as a whole is not atomic, but we can -// safely treat the components independently for the purpose of accumulating. -template <> -__device__ __forceinline__ void AccumulateInto( - std::complex* dest, const std::complex& value) { - auto dest_scalar = reinterpret_cast(dest); - CudaAtomicAdd(dest_scalar, value.real()); - CudaAtomicAdd(dest_scalar + 1, value.imag()); -} - -template <> -__device__ __forceinline__ void AccumulateInto( - std::complex* dest, const std::complex& value) { - auto dest_scalar = reinterpret_cast(dest); - CudaAtomicAdd(dest_scalar, value.real()); - CudaAtomicAdd(dest_scalar + 1, value.imag()); -} - // SortedSegmentSumFunctor kernel reduces input data just as // UnsortedSegmentSumCustomKernel does except that input data // is partitioned along the outer reduction dimension. This is @@ -81,7 +54,7 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size, const Index* segment_ids, const T* input, T* output, const Index total_stripe_count) { - CUDA_1D_KERNEL_LOOP(stripe_index, total_stripe_count) { + for (int stripe_index : CudaGridRangeX(total_stripe_count)) { const Index segment_offset = stripe_index % inner_dim_size; const Index input_outer_dim_index_base = stripe_index / inner_dim_size * Index(OuterDimTileSize); @@ -106,7 +79,7 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size, // decide whether to write result to global memory using atomic // operations if (last_output_segment_id == first_segment_id) { - AccumulateInto(output + output_index, sum); + CudaAtomicAdd(output + output_index, sum); } else { *(output + output_index) = sum; } @@ -121,31 +94,31 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size, // the following strip. const Index output_index = last_output_segment_id * inner_dim_size + segment_offset; - AccumulateInto(output + output_index, sum); + CudaAtomicAdd(output + output_index, sum); } } -// UnsortedSegmentSumFunctor kernel processes 'input_total_size' elements. +// UnsortedSegmentSumKernel processes 'input_total_size' elements. // Each element is mapped from input to output by a combination of its // 'segment_ids' mapping and 'inner_dim_size'. -template -__global__ void UnsortedSegmentSumCustomKernel( - const Index input_outer_dim_size, const Index inner_dim_size, - const Index output_outer_dim_size, const Index* segment_ids, const T* input, - T* output) { +template +__global__ void UnsortedSegmentCustomKernel(const Index input_outer_dim_size, + const Index inner_dim_size, + const Index output_outer_dim_size, + const Index* segment_ids, + const T* input, T* output) { const Index input_total_size = input_outer_dim_size * inner_dim_size; const Index output_total_size = output_outer_dim_size * inner_dim_size; - CUDA_1D_KERNEL_LOOP(input_index, input_total_size) { + for (int input_index : CudaGridRangeX(input_total_size)) { const Index input_segment_index = input_index / inner_dim_size; const Index segment_offset = input_index % inner_dim_size; const Index output_segment_index = segment_ids[input_segment_index]; - if (output_segment_index < 0 || output_segment_index >= output_total_size) { continue; } const Index output_index = output_segment_index * inner_dim_size + segment_offset; - AccumulateInto(output + output_index, ldg(input + input_index)); + KernelReductionFunctor()(output + output_index, ldg(input + input_index)); } } @@ -190,42 +163,40 @@ void SegmentSumFunctor::operator()( <<>>( input_outer_dim_size, input_inner_dim_size, output_rows, segment_ids.data(), data, output.data(), total_stripe_count); -}; +} -// UnsortedSegmentSumFunctor implementation for GPUDevice. -template -struct UnsortedSegmentSumFunctor: UnsortedSegmentBaseFunctor { - void operator()(OpKernelContext* ctx, const GPUDevice& d, - const Index output_rows, const TensorShape& segment_ids_shape, +template +struct UnsortedSegmentFunctor { + void operator()(OpKernelContext* ctx, const Index num_segments, + const TensorShape& segment_ids_shape, typename TTypes::ConstFlat segment_ids, const Index data_size, const T* data, - typename TTypes::Tensor output) override { + typename TTypes::Tensor output) { if (output.size() == 0) { return; } - // Set 'output' to zeros. + // Set 'output' to initial value. + GPUDevice d = ctx->template eigen_device(); CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d); - SetZero<<>>( - output.size(), output.data()); + SetToValue<<>>( + output.size(), output.data(), InitialValueF()()); if (data_size == 0 || segment_ids_shape.num_elements() == 0) { return; } - - // Launch kernel to compute unsorted segment sum. + // Launch kernel to compute unsorted segment reduction. // Notes: - // *) 'input_total_size' is the total number of elements to process. + // *) 'data_size' is the total number of elements to process. // *) 'segment_ids.shape' is a prefix of data's shape. // *) 'input_outer_dim_size' is the total number of segments to process. - const Index input_total_size = data_size; const Index input_outer_dim_size = segment_ids.dimension(0); - const Index input_inner_dim_size = input_total_size / input_outer_dim_size; - - config = GetCudaLaunchConfig(input_total_size, d); - UnsortedSegmentSumCustomKernel< - T, - Index><<>>( - input_outer_dim_size, input_inner_dim_size, output_rows, - segment_ids.data(), data, output.data()); + const Index input_inner_dim_size = data_size / input_outer_dim_size; + config = GetCudaLaunchConfig(data_size, d); + + UnsortedSegmentCustomKernel + <<>>( + input_outer_dim_size, input_inner_dim_size, num_segments, + segment_ids.data(), data, output.data()); } }; @@ -238,19 +209,40 @@ struct UnsortedSegmentSumFunctor: UnsortedSegmentBaseFuncto TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS); -#define DEFINE_GPU_SPECS_INDEX(T, Index) \ - template struct UnsortedSegmentSumFunctor - -#define DEFINE_GPU_SPECS(T) \ - DEFINE_GPU_SPECS_INDEX(T, int32); \ - DEFINE_GPU_SPECS_INDEX(T, int64); - -TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); -TF_CALL_complex64(DEFINE_GPU_SPECS); -TF_CALL_complex128(DEFINE_GPU_SPECS); - -#undef DEFINE_GPU_SPECS -#undef DEFINE_GPU_SPECS_INDEX +#define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index) \ + template struct UnsortedSegmentFunctor< \ + GPUDevice, T, Index, functor::Lowest, functor::MaxOpGpu>; \ + template struct UnsortedSegmentFunctor< \ + GPUDevice, T, Index, functor::Highest, functor::MinOpGpu>; \ + template struct UnsortedSegmentFunctor, \ + functor::ProdOpGpu>; + +// sum is the only op that supports all input types currently +#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \ + template struct UnsortedSegmentFunctor< \ + GPUDevice, T, Index, functor::Zero, functor::SumOpGpu>; + +#define DEFINE_REAL_GPU_SPECS(T) \ + DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \ + DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int64); + +#define DEFINE_SUM_GPU_SPECS(T) \ + DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int32); \ + DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int64); + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_REAL_GPU_SPECS); +TF_CALL_int32(DEFINE_REAL_GPU_SPECS); +TF_CALL_GPU_NUMBER_TYPES(DEFINE_SUM_GPU_SPECS); +TF_CALL_int32(DEFINE_SUM_GPU_SPECS); +TF_CALL_complex64(DEFINE_SUM_GPU_SPECS); +TF_CALL_complex128(DEFINE_SUM_GPU_SPECS); + +#undef DEFINE_SORTED_GPU_SPECS_INDEX +#undef DEFINE_SORTED_GPU_SPECS +#undef DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX +#undef DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX +#undef DEFINE_REAL_GPU_SPECS +#undef DEFINE_SUM_GPU_SPECS } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/self_adjoint_eig_op.cc b/tensorflow/core/kernels/self_adjoint_eig_op.cc index 97657807268d30d66a01573bc3df09e318ce1d51..bcd88773902824c6e88db4226af43993d5649007 100644 --- a/tensorflow/core/kernels/self_adjoint_eig_op.cc +++ b/tensorflow/core/kernels/self_adjoint_eig_op.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" - namespace tensorflow { template diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index 206fd40fa68c3158fa60b7651d40121ab1344bbd..688e61fcadc3ad01b579f8dfc712af2d8032ee35 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -114,7 +114,7 @@ REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp); REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_SYCL), SendOp); REGISTER_KERNEL_BUILDER( Name("_HostSend").Device(DEVICE_SYCL).HostMemory("tensor"), SendOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp); REGISTER_KERNEL_BUILDER( @@ -198,7 +198,7 @@ REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp); #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_SYCL), RecvOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp); REGISTER_KERNEL_BUILDER( @@ -207,6 +207,6 @@ REGISTER_KERNEL_BUILDER( #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER( Name("_HostRecv").Device(DEVICE_SYCL).HostMemory("tensor"), RecvOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // end namespace tensorflow diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index e2e3758d87e49702ebc48f78c022affe49a3b7e4..9db0bd4d98bdb9964cb561d96d91782ba3615a7f 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -53,13 +53,13 @@ class RangeOp : public OpKernel { if (delta > 0) { OP_REQUIRES( context, start <= limit, - errors::InvalidArgument("Requires start <= limit when delta > 0: ", - start, "/", limit)); + errors::InvalidArgument( + "Requires start <= limit when delta > 0: ", start, "/", limit)); } else { OP_REQUIRES( context, start >= limit, - errors::InvalidArgument("Requires start >= limit when delta < 0: ", - start, "/", limit)); + errors::InvalidArgument( + "Requires start >= limit when delta < 0: ", start, "/", limit)); } int64 size = (std::is_integral::value ? ((std::abs(limit - start) + std::abs(delta) - 1) / diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc index 61e40caef99c019914fc331bee5d8beab0883f41..799c574d1542c345c606c276b0cc24fe61a47bba 100644 --- a/tensorflow/core/kernels/serialize_sparse_op.cc +++ b/tensorflow/core/kernels/serialize_sparse_op.cc @@ -426,7 +426,6 @@ class DeserializeSparseOp : public OpKernel { switch (dtype_) { TF_CALL_ALL_TYPES(HANDLE_TYPE); TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); - TF_CALL_variant(HANDLE_TYPE); #undef HANDLE_TYPE default: OP_REQUIRES(context, false, diff --git a/tensorflow/core/kernels/session_ops.cc b/tensorflow/core/kernels/session_ops.cc index 185c5b248fca8f5a4e8edf6d46e9447f8a0b4750..f2dd2812b53e2c974efac3d3e1aef1052d907da6 100644 --- a/tensorflow/core/kernels/session_ops.cc +++ b/tensorflow/core/kernels/session_ops.cc @@ -144,7 +144,7 @@ REGISTER_GPU_KERNEL(bool); TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL); REGISTER_SYCL_KERNEL(bool); #undef REGISTER_SYCL_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL class DeleteSessionTensorOp : public OpKernel { public: diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h index 8d9d0ea84612b51bdcd597698b89e3b8ffb8a915..55be308901b2b1233090c097944f441a17938125 100644 --- a/tensorflow/core/kernels/shape_ops.h +++ b/tensorflow/core/kernels/shape_ops.h @@ -235,10 +235,10 @@ class SqueezeOp : public OpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument("Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", - existing_dim)); + errors::InvalidArgument( + "Tried to explicitly squeeze " + "dimension ", + i, " but dimension was not 1: ", existing_dim)); } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index 82595de77947fab01a2107e009982f6db96601e5..77594479cb1252d311fbfea8572590b0b32faecd 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -58,7 +58,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Shared code that is not dependent on the type of T. We do this to reduce // code size by not duplicating all this for all T (float, double, int32, etc.) @@ -72,10 +72,11 @@ static void SharedValidation(OpKernelContext* context, const Tensor& size_tensor = context->input(2); OP_REQUIRES( - context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) && - context->op_kernel().IsLegacyVector(size_tensor.shape()) && - begin_tensor.NumElements() == input.dims() && - size_tensor.NumElements() == input.dims(), + context, + context->op_kernel().IsLegacyVector(begin_tensor.shape()) && + context->op_kernel().IsLegacyVector(size_tensor.shape()) && + begin_tensor.NumElements() == input.dims() && + size_tensor.NumElements() == input.dims(), errors::InvalidArgument( "Expected begin and size arguments to be 1-D tensors of size ", input.dims(), ", but got shapes ", begin_tensor.shape().DebugString(), @@ -125,8 +126,7 @@ static void SharedSliceCommonCases(OpKernelContext* context, TensorShape* output_shape, gtl::InlinedVector* begin, gtl::InlinedVector* size, - Tensor** result, - bool* done) { + Tensor** result, bool* done) { bool is_identity = true; bool slice_dim0 = true; *done = false; @@ -142,8 +142,8 @@ static void SharedSliceCommonCases(OpKernelContext* context, return; } - if (slice_dim0 && IsDim0SliceAligned(input.shape(), (*begin)[0], - (*size)[0])) { + if (slice_dim0 && + IsDim0SliceAligned(input.shape(), (*begin)[0], (*size)[0])) { VLOG(1) << "Slice dim 0: " << input.shape().DebugString(); CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true. context->set_output(0, input.Slice((*begin)[0], (*begin)[0] + (*size)[0])); @@ -154,7 +154,6 @@ static void SharedSliceCommonCases(OpKernelContext* context, OP_REQUIRES_OK(context, context->allocate_output(0, *output_shape, result)); } - template class SliceOp : public OpKernel { public: @@ -206,8 +205,9 @@ class SliceOp : public OpKernel { #undef HANDLE_DIM - OP_REQUIRES(context, false, errors::Unimplemented( - "SliceOp : Unhandled input dimensions")); + OP_REQUIRES( + context, false, + errors::Unimplemented("SliceOp : Unhandled input dimensions")); } } @@ -280,8 +280,9 @@ class MklSliceOp : public OpKernel { #undef HANDLE_DIM - OP_REQUIRES(context, false, errors::Unimplemented( - "SliceOp : Unhandled input dimensions")); + OP_REQUIRES( + context, false, + errors::Unimplemented("SliceOp : Unhandled input dimensions")); } } @@ -292,9 +293,9 @@ class MklSliceOp : public OpKernel { // as the sizes of all the dimensions of the input except slice_dim, then // returns True. Otherwise, returns False. bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape, - const gtl::ArraySlice& begin, - const gtl::ArraySlice& size, - int slice_dim) { + const gtl::ArraySlice& begin, + const gtl::ArraySlice& size, + int slice_dim) { for (int dim = 0; dim < 4; dim++) { if (dim != slice_dim && (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) { @@ -316,9 +317,9 @@ class MklSliceOp : public OpKernel { // Returns True if Slicing over a single dimension, and sets slice_dim // to the number of the dimension that satisfies criteria. bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape, - const gtl::ArraySlice& begin, - const gtl::ArraySlice& size, - int* slice_dim) { + const gtl::ArraySlice& begin, + const gtl::ArraySlice& size, + int* slice_dim) { for (int dim = 0; dim < 4; dim++) { if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) { *slice_dim = dim; @@ -329,8 +330,7 @@ class MklSliceOp : public OpKernel { } template - void HandleCase(OpKernelContext* context, - const gtl::ArraySlice& begin, + void HandleCase(OpKernelContext* context, const gtl::ArraySlice& begin, const gtl::ArraySlice& size, Tensor* result) { int slice_dim = -1; TensorShape in_shape = context->input(0).shape(); @@ -340,67 +340,63 @@ class MklSliceOp : public OpKernel { // format over channel dimension. if (NDIM == 4 && DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) { - size_t in_strides[4] = { (size_t) in_shape.dim_size(1) * - in_shape.dim_size(2) * - in_shape.dim_size(3), - (size_t) in_shape.dim_size(2) * - in_shape.dim_size(3), - (size_t) in_shape.dim_size(3), - (size_t) 1 - }; - - size_t out_strides[4] = { (size_t) size[1] * size[2] * size[3], - (size_t) size[2] * size[3], - (size_t) size[3], - (size_t) 1 }; - - T *in_buf = const_cast(const_cast( - context->input(0).flat().data())); - T *op_buf = result->flat().data(); - - if (slice_dim == 1) { - /* data format = NCHW */ - - #pragma omp parallel for - for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { - T *ip = in_buf + (d0 * in_strides[0]); - T *op = op_buf + ((d0 - begin[0]) * out_strides[0]); - #pragma omp parallel for - for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { - T *ip1 = ip + (d1 * in_strides[1]); - T *op1 = op + ((d1 - begin[1]) * out_strides[1]); - // For NCHW, H and W will be contiguous. So we can copy - // both with one memcpy. - memcpy(static_cast(op1), static_cast(ip1), - sizeof(T) * in_strides[1]); - } + size_t in_strides[4] = { + (size_t)in_shape.dim_size(1) * in_shape.dim_size(2) * + in_shape.dim_size(3), + (size_t)in_shape.dim_size(2) * in_shape.dim_size(3), + (size_t)in_shape.dim_size(3), (size_t)1}; + + size_t out_strides[4] = {(size_t)size[1] * size[2] * size[3], + (size_t)size[2] * size[3], (size_t)size[3], + (size_t)1}; + + T* in_buf = const_cast( + const_cast(context->input(0).flat().data())); + T* op_buf = result->flat().data(); + + if (slice_dim == 1) { + /* data format = NCHW */ + +#pragma omp parallel for + for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { + T* ip = in_buf + (d0 * in_strides[0]); + T* op = op_buf + ((d0 - begin[0]) * out_strides[0]); +#pragma omp parallel for + for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { + T* ip1 = ip + (d1 * in_strides[1]); + T* op1 = op + ((d1 - begin[1]) * out_strides[1]); + // For NCHW, H and W will be contiguous. So we can copy + // both with one memcpy. + memcpy(static_cast(op1), static_cast(ip1), + sizeof(T) * in_strides[1]); } - return; - } else if (slice_dim == 3) { - /* data_format = NHWC */ - - #pragma omp parallel for - for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { - T *ip = in_buf + (d0 * in_strides[0]); - T *op = op_buf + ((d0 - begin[0]) * out_strides[0]); - #pragma omp parallel for - for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { - T *ip1 = ip + (d1 * in_strides[1]); - T *op1 = op + ((d1 - begin[1]) * out_strides[1]); - #pragma omp parallel for - for (size_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) { - T *ip2 = ip1 + (d2 * in_strides[2]); - T *ip3 = ip2 + begin[3]; - T *op2 = op1 + ((d2 - begin[2]) * out_strides[2]); - T *op3 = op2; - memcpy(static_cast(op3), static_cast(ip3), - sizeof(T) * size[3]); - } + } + return; + } else if (slice_dim == 3) { + /* data_format = NHWC */ + +#pragma omp parallel for + for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { + T* ip = in_buf + (d0 * in_strides[0]); + T* op = op_buf + ((d0 - begin[0]) * out_strides[0]); +#pragma omp parallel for + for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { + T* ip1 = ip + (d1 * in_strides[1]); + T* op1 = op + ((d1 - begin[1]) * out_strides[1]); +#pragma omp parallel for + for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) { + T* ip2 = ip1 + (d2 * in_strides[2]); + T* ip3 = ip2 + begin[3]; + T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]); + T* op3 = op2; + memcpy(static_cast(op3), static_cast(ip3), + sizeof(T) * size[3]); } } - return; } - // slice_dim is not 1 or 3, then we fallback to Eigen implementation. + return; + } + // slice_dim is not 1 or 3, then we fallback to Eigen implementation. } Eigen::DSizes indices; @@ -535,13 +531,13 @@ REGISTER_KERNEL_BUILDER(Name("Slice") #ifdef TENSORFLOW_USE_SYCL // Forward declarations of the functor specializations for SYCL. namespace functor { -#define DECLARE_SYCL_SPEC(T, NDIM) \ - template <> \ - void Slice::operator()( \ - const SYCLDevice& d, typename TTypes::Tensor output,\ - typename TTypes::ConstTensor input, \ - const Eigen::DSizes& indices, \ - const Eigen::DSizes& sizes); \ +#define DECLARE_SYCL_SPEC(T, NDIM) \ + template <> \ + void Slice::operator()( \ + const SYCLDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, \ + const Eigen::DSizes& indices, \ + const Eigen::DSizes& sizes); \ extern template struct Slice; #define DECLARE_FOR_N(T) \ diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h index 0362a021336f633b88a666c68f42fa5082f4f66d..db7eded745eb0d3c880dc46d164aad31b2531829 100644 --- a/tensorflow/core/kernels/slice_op.h +++ b/tensorflow/core/kernels/slice_op.h @@ -24,7 +24,6 @@ limitations under the License. namespace tensorflow { namespace functor { - template struct Slice { void operator()(const Device& d, typename TTypes::Tensor output, diff --git a/tensorflow/core/kernels/slice_op_cpu_impl.h b/tensorflow/core/kernels/slice_op_cpu_impl.h index 47f1d5342a9e56301dabad2eb9700ce97d45695d..64b6948190a23b554582975d38dae8be638840fa 100644 --- a/tensorflow/core/kernels/slice_op_cpu_impl.h +++ b/tensorflow/core/kernels/slice_op_cpu_impl.h @@ -43,7 +43,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_SYCL_KERNELS); DEFINE_SYCL_KERNELS(int32); #undef DEFINE_SYCL_KERNELS -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/snapshot_op.h b/tensorflow/core/kernels/snapshot_op.h index 2c79893b49661519515a7b4a537ff3caeceba2be..b94834f15988a21ad41eefc8030b3da1a58875f8 100644 --- a/tensorflow/core/kernels/snapshot_op.h +++ b/tensorflow/core/kernels/snapshot_op.h @@ -35,12 +35,17 @@ class SnapshotOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); Tensor* output = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, input.shape(), &output)); - const Device& device = context->eigen_device(); - device.memcpy(output->template flat().data(), - input.template flat().data(), - input.NumElements() * sizeof(Scalar)); + // Try to use buffer forwarding to avoid an explicit copy. + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {0}, 0, input.shape(), &output)); + if (!output->SharesBufferWith(input)) { + // We had to allocate a new buffer since the refcount on the input was + // greater than 1. Copy the input to the new buffer. + const Device& device = context->eigen_device(); + device.memcpy(output->template flat().data(), + input.template flat().data(), + input.NumElements() * sizeof(Scalar)); + } } }; diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc index 590f01c4691f479cbf90971b368656ff3c78c91a..e1712ac239d6be2d51b0c0598a799959a8b53a94 100644 --- a/tensorflow/core/kernels/softmax_op.cc +++ b/tensorflow/core/kernels/softmax_op.cc @@ -30,7 +30,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL // Partial specialization for a CPUDevice, that uses the Eigen implementation // from SoftmaxEigenImpl. @@ -48,7 +48,7 @@ struct SoftmaxFunctor : SoftmaxFunctorBase {}; #ifdef TENSORFLOW_USE_SYCL template struct SoftmaxFunctor : SoftmaxFunctorBase {}; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace functor template @@ -100,5 +100,5 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER( Name("Softmax").Device(DEVICE_SYCL).TypeConstraint("T"), SoftmaxOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/spacetobatch_benchmark_test.cc b/tensorflow/core/kernels/spacetobatch_benchmark_test.cc index c25ce2d8bb5ee5fe50034e74f0362fd6b0e79589..92ddf8edbfbe5e3c8fbc2c3b5ddeddd847838814 100644 --- a/tensorflow/core/kernels/spacetobatch_benchmark_test.cc +++ b/tensorflow/core/kernels/spacetobatch_benchmark_test.cc @@ -70,7 +70,7 @@ static Graph* ConstructSpaceToBatchGraph( } \ BENCHMARK( \ BM_##OP##_##DEVICE##_##DTYPE##_##B##_##H##_##W##_##D##_bs##BS##_pad##P00##_##P01##_##P10##_##P11); -#define BM_SpaceToBatch(OP, ...) \ +#define BM_SpaceToBatch(OP, ...) \ BM_Expand(BM_SpaceToBatchDev(OP, cpu, DT_FLOAT, __VA_ARGS__)); \ BM_Expand(BM_SpaceToBatchDev(OP, gpu, DT_FLOAT, __VA_ARGS__)); \ BM_Expand(BM_SpaceToBatchDev(OP, cpu, DT_HALF, __VA_ARGS__)); \ diff --git a/tensorflow/core/kernels/spacetobatch_functor.cc b/tensorflow/core/kernels/spacetobatch_functor.cc index 23d8a5f9ed4483c0e7d5c15108db6cbbdbe0890a..4c374b8d99444023c14fcb4ed770a5c263535be0 100644 --- a/tensorflow/core/kernels/spacetobatch_functor.cc +++ b/tensorflow/core/kernels/spacetobatch_functor.cc @@ -154,7 +154,7 @@ struct SpaceToBatchFunctor { #define INSTANTIATE(NUM_BLOCK_DIMS, T) \ template struct SpaceToBatchFunctor; \ template struct SpaceToBatchFunctor; \ -/**/ + /**/ #define INSTANTIATE_FOR_T(T) \ TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(INSTANTIATE, T) diff --git a/tensorflow/core/kernels/spacetobatch_functor.h b/tensorflow/core/kernels/spacetobatch_functor.h index 06813650c08ec26a38edfe2ba01440a2fb8066fc..f46a84da1e951113382e4d44b44463c2a621ca10 100644 --- a/tensorflow/core/kernels/spacetobatch_functor.h +++ b/tensorflow/core/kernels/spacetobatch_functor.h @@ -44,7 +44,7 @@ constexpr int kMaxSpaceToBatchBlockDims = 4; MACRO(2 /**/, ##__VA_ARGS__) \ MACRO(3 /**/, ##__VA_ARGS__) \ MACRO(4 /**/, ##__VA_ARGS__) \ -/**/ + /**/ namespace internal { namespace spacetobatch { diff --git a/tensorflow/core/kernels/spacetobatch_functor_gpu.cu.cc b/tensorflow/core/kernels/spacetobatch_functor_gpu.cu.cc index db8d419c38ff5f8a06a1aafde14076b55b7c75e6..5687141c9eaeec11498c1d2cc954155bd9e05856 100644 --- a/tensorflow/core/kernels/spacetobatch_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/spacetobatch_functor_gpu.cu.cc @@ -141,10 +141,10 @@ struct SpaceToBatchFunctor { } CudaLaunchConfig config = GetCudaLaunchConfig(static_cast(total_count), d); - S2B<<>>( - config.virtual_thread_count, const_cast(space_tensor.data()), args, - const_cast(batch_tensor.data())); + S2B + <<>>( + config.virtual_thread_count, const_cast(space_tensor.data()), + args, const_cast(batch_tensor.data())); return Status::OK(); } }; @@ -153,7 +153,7 @@ struct SpaceToBatchFunctor { #define INSTANTIATE(NUM_BLOCK_DIMS, T) \ template struct SpaceToBatchFunctor; \ template struct SpaceToBatchFunctor; \ -/**/ + /**/ #define INSTANTIATE_FOR_T(T) \ TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(INSTANTIATE, T) diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc index 95c1f5e7e8ca978fda334396538de0cf4ed5b774..fdc08ec8e3bfd128a3e341efab8e5ba319c90e4f 100644 --- a/tensorflow/core/kernels/spacetobatch_op.cc +++ b/tensorflow/core/kernels/spacetobatch_op.cc @@ -58,9 +58,10 @@ void SpaceToBatchOpCompute(OpKernelContext* context, errors::InvalidArgument("input rank should be >= ", 1 + block_dims, " instead of ", orig_input_tensor.dims())); - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(orig_paddings.shape()) && - block_dims == orig_paddings.dim_size(0) && - 2 == orig_paddings.dim_size(1), + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(orig_paddings.shape()) && + block_dims == orig_paddings.dim_size(0) && + 2 == orig_paddings.dim_size(1), errors::InvalidArgument("paddings should have shape [", block_dims, ", 2] instead of ", orig_paddings.shape().DebugString())); diff --git a/tensorflow/core/kernels/sparse_add_grad_op.cc b/tensorflow/core/kernels/sparse_add_grad_op.cc index d8ed0c6f0c20d13d5e7870159ed1569514333c5e..8597f3a8f7307584d27a265bc8df8949f20898b6 100644 --- a/tensorflow/core/kernels/sparse_add_grad_op.cc +++ b/tensorflow/core/kernels/sparse_add_grad_op.cc @@ -35,9 +35,10 @@ class SparseAddGradOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("b_indices", &b_indices)); OP_REQUIRES_OK(ctx, ctx->input("sum_indices", &sum_indices)); - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()) && - TensorShapeUtils::IsMatrix(b_indices->shape()) && - TensorShapeUtils::IsMatrix(sum_indices->shape()), + OP_REQUIRES(ctx, + TensorShapeUtils::IsMatrix(a_indices->shape()) && + TensorShapeUtils::IsMatrix(b_indices->shape()) && + TensorShapeUtils::IsMatrix(sum_indices->shape()), errors::InvalidArgument( "Input indices should be matrices but received shapes: ", a_indices->shape().DebugString(), " and ", @@ -49,8 +50,9 @@ class SparseAddGradOp : public OpKernel { "Input backprop_val_grad should be a vector but received shape: ", backprop_val_grad->shape().DebugString())); OP_REQUIRES( - ctx, a_indices->dim_size(1) == b_indices->dim_size(1) && - b_indices->dim_size(1) == sum_indices->dim_size(1), + ctx, + a_indices->dim_size(1) == b_indices->dim_size(1) && + b_indices->dim_size(1) == sum_indices->dim_size(1), errors::InvalidArgument("The densified operands should have the same " "ndims; for A, B, sum got: ", a_indices->dim_size(1), b_indices->dim_size(1), diff --git a/tensorflow/core/kernels/sparse_add_op.cc b/tensorflow/core/kernels/sparse_add_op.cc index bd91dfdce64cbfc697345e0f0c7278de938ecc5b..d16317af671dd6592d3e30ac52941508c4ffd088 100644 --- a/tensorflow/core/kernels/sparse_add_op.cc +++ b/tensorflow/core/kernels/sparse_add_op.cc @@ -34,8 +34,9 @@ class SparseAddOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices)); OP_REQUIRES_OK(ctx, ctx->input("b_indices", &b_indices)); - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()) && - TensorShapeUtils::IsMatrix(b_indices->shape()), + OP_REQUIRES(ctx, + TensorShapeUtils::IsMatrix(a_indices->shape()) && + TensorShapeUtils::IsMatrix(b_indices->shape()), errors::InvalidArgument( "Input indices should be matrices but received shapes: ", a_indices->shape().DebugString(), " and ", @@ -46,8 +47,9 @@ class SparseAddOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t)); OP_REQUIRES_OK(ctx, ctx->input("b_values", &b_values_t)); - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_values_t->shape()) && - TensorShapeUtils::IsVector(b_values_t->shape()), + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(a_values_t->shape()) && + TensorShapeUtils::IsVector(b_values_t->shape()), errors::InvalidArgument( "Input values should be vectors but received shapes: ", a_values_t->shape().DebugString(), " and ", @@ -62,8 +64,9 @@ class SparseAddOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape)); OP_REQUIRES_OK(ctx, ctx->input("b_shape", &b_shape)); - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape->shape()) && - TensorShapeUtils::IsVector(b_shape->shape()), + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(a_shape->shape()) && + TensorShapeUtils::IsVector(b_shape->shape()), errors::InvalidArgument( "Input shapes should be a vector but received shapes ", a_shape->shape().DebugString(), " and ", diff --git a/tensorflow/core/kernels/sparse_add_op_test.cc b/tensorflow/core/kernels/sparse_add_op_test.cc index 4cad02bbee8dd20328bac3ec24074c22493009b8..1f08e6c5ce2e8a40cf464760434f9161015b643c 100644 --- a/tensorflow/core/kernels/sparse_add_op_test.cc +++ b/tensorflow/core/kernels/sparse_add_op_test.cc @@ -61,9 +61,9 @@ TEST_F(SparseAddOpTest, TwoD_AddSparseTensorWithSelf) { // [3 4] const auto indices_shape = TensorShape({4, 2}); - std::initializer_list in{ 0, 1, 1, 0, 2, 0, 2, 1 }; + std::initializer_list in{0, 1, 1, 0, 2, 0, 2, 1}; const gtl::ArraySlice indices(in); - std::initializer_list sh{ 3, 2 }; + std::initializer_list sh{3, 2}; const gtl::ArraySlice shape(sh); #define ADD_TENSOR_INPUT() \ diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc index c122616cf15b8567494a604337951c8d278f5ead..80bc1f19344dffadaae864f64c98d1f15addd1fb 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc @@ -103,8 +103,9 @@ class SparseAccumulatorTakeGradientOp DoneCallback callback) override { // Check signature OP_REQUIRES_OK_ASYNC( - ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32}, - {DT_INT64, accumulator->dtype(), DT_INT64}), + ctx, + ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + {DT_INT64, accumulator->dtype(), DT_INT64}), callback); } diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index 07d935d55fe06150309736ba0fec88091ed007c6..7cd4532ad63812d905ceb6b96291aa50293070ef 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -288,8 +288,7 @@ struct CrossTraits { template class SparseCrossOp : public OpKernel { public: - explicit SparseCrossOp(OpKernelConstruction* context) - : OpKernel(context) { + explicit SparseCrossOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_)); // Read signed_hash_key_ as int64 since uint64 attributes are not // supported by REGISTER_OP. @@ -316,8 +315,8 @@ class SparseCrossOp : public OpKernel { GenerateColumnsFromInput(indices_list_in, values_list_in, shapes_list_in, dense_list_in); - typename CrossTraits::Crosser - crosser(columns, num_buckets_, hash_key_); + typename CrossTraits::Crosser crosser( + columns, num_buckets_, hash_key_); Tensor* indices_out; Tensor* values_out; Tensor* shape_out; @@ -326,8 +325,8 @@ class SparseCrossOp : public OpKernel { CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out, &shape_out, &output_start_indices); - typename CrossTraits::Updater - updater(output_start_indices, indices_out, values_out); + typename CrossTraits::Updater updater( + output_start_indices, indices_out, values_out); auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) { for (int b = begin; b < end; b++) { ProductIterator product_iterator(columns, b); @@ -381,8 +380,9 @@ class SparseCrossOp : public OpKernel { "Input values should be a std::vector but received shape ", values_list_in[i].shape().DebugString(), " at position ", i)); OP_REQUIRES( - context, indices_list_in[i].shape().dim_size(0) == - values_list_in[i].shape().dim_size(0), + context, + indices_list_in[i].shape().dim_size(0) == + values_list_in[i].shape().dim_size(0), errors::InvalidArgument( "Expected size of values to be ", indices_list_in[i].shape().dim_size(0), " got ", diff --git a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc index cc0f86ce05e613767b22d51875f90e8391504bdb..ac48202ada2204ea36478257630f20f7892be50b 100644 --- a/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc +++ b/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc @@ -70,8 +70,9 @@ class SparseDenseBinaryOpShared : public OpKernel { errors::InvalidArgument( "Input sp_indices should be a matrix but received shape: ", indices_t->shape().DebugString())); - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values_t->shape()) && - TensorShapeUtils::IsVector(shape_t->shape()), + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(values_t->shape()) && + TensorShapeUtils::IsVector(shape_t->shape()), errors::InvalidArgument( "Inputs sp_values and sp_shape should be vectors " "but received shapes: ", @@ -150,8 +151,9 @@ class SparseDenseBinaryOpShared : public OpKernel { CASE(4); CASE(5); default: - OP_REQUIRES(ctx, false, errors::InvalidArgument( - "Only tensors with ranks between 1 and 5 " + OP_REQUIRES( + ctx, false, + errors::InvalidArgument("Only tensors with ranks between 1 and 5 " "are currently supported. Tensor rank: ", ndims)); #undef CASE diff --git a/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc b/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc index eaf1884243ec19689af783e29adaee886e7498d6..fe198af7e6c131ab19daf877063a2a6838d1f2c7 100644 --- a/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc +++ b/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc @@ -96,9 +96,9 @@ TEST_F(SparseDenseCDivTest, SameShape) { // [2 ] cdiv [dense: same shape, all 1's] // [3 4] const auto indices_shape = TensorShape({4, 2}); - std::initializer_list in{ 0, 1, 1, 0, 2, 0, 2, 1 }; + std::initializer_list in{0, 1, 1, 0, 2, 0, 2, 1}; const gtl::ArraySlice indices(in); - std::initializer_list sh{ 3, 2 }; + std::initializer_list sh{3, 2}; const gtl::ArraySlice shape(sh); // Tensor dense(DT_FLOAT, TensorShape({3, 1})); @@ -125,9 +125,9 @@ TEST_F(SparseDenseCDivTest, BroadcastDenseSameDims) { // [2 ] cdiv [dense: shape [3,1], all 1's] // [3 4] const auto indices_shape = TensorShape({4, 2}); - std::initializer_list in{ 0, 1, 1, 0, 2, 0, 2, 1 }; + std::initializer_list in{0, 1, 1, 0, 2, 0, 2, 1}; const gtl::ArraySlice indices(in); - std::initializer_list sh{ 3, 2 }; + std::initializer_list sh{3, 2}; const gtl::ArraySlice shape(sh); Tensor dense(DT_FLOAT, TensorShape({3, 1})); @@ -152,9 +152,9 @@ TEST_F(SparseDenseCDivTest, BroadcastDenseFewerDims) { // [2 ] cdiv [dense: shape [2]] // [3 4] const auto indices_shape = TensorShape({4, 2}); - std::initializer_list in{ 0, 1, 1, 0, 2, 0, 2, 1 }; + std::initializer_list in{0, 1, 1, 0, 2, 0, 2, 1}; const gtl::ArraySlice indices(in); - std::initializer_list sh{ 3, 2 }; + std::initializer_list sh{3, 2}; const gtl::ArraySlice shape(sh); Tensor dense(DT_FLOAT, TensorShape({2})); @@ -184,9 +184,9 @@ TEST_F(SparseDenseCMulTest, BroadcastDense) { // [1 ?] where ? remains implicitly zero. // [1.5 0] const auto indices_shape = TensorShape({4, 2}); - std::initializer_list in{ 0, 1, 1, 0, 2, 0, 2, 1 }; + std::initializer_list in{0, 1, 1, 0, 2, 0, 2, 1}; const gtl::ArraySlice indices(in); - std::initializer_list sh{ 3, 2 }; + std::initializer_list sh{3, 2}; const gtl::ArraySlice shape(sh); Tensor dense(DT_FLOAT, TensorShape({2})); diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc index 8ab23b64d3d94c604ae027bbfd75357a4e2e284b..a1f9667b783ca5f455523874bc4e342f1368d4f3 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_matmul_op.cc @@ -159,8 +159,8 @@ struct SparseSlice { template template -void SparseSlice::Initialize(const typename SparseSlice::ConstMatrixMap& mat, - int col_offset) { +void SparseSlice::Initialize( + const typename SparseSlice::ConstMatrixMap& mat, int col_offset) { const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0); const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1); DCHECK_LE(num_rows, mat_rows); @@ -278,9 +278,9 @@ ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) { float out = 0; auto tmp = reinterpret_cast(&out); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - tmp[0] = *src; + tmp[0] = *src; #else - tmp[1] = *src; + tmp[1] = *src; #endif return out; } @@ -970,9 +970,9 @@ class SparseMatMulOp : public OpKernel { const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0); OP_REQUIRES(ctx, k == k2, - errors::InvalidArgument("Matrix size incompatible: a: ", - a.shape().DebugString(), ", b: ", - b.shape().DebugString())); + errors::InvalidArgument( + "Matrix size incompatible: a: ", a.shape().DebugString(), + ", b: ", b.shape().DebugString())); Tensor* output = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output)); @@ -1224,8 +1224,9 @@ ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src, template inline BlockingCounter* SparseMatMul::ShuffleMatrix( - const typename SparseMatMul::ConstMatrixMapR& mat, int slice_row_start, - int slice_num_rows, int slice_col_start, int slice_num_cols, const int N, + const typename SparseMatMul::ConstMatrixMapR& mat, + int slice_row_start, int slice_num_rows, int slice_col_start, + int slice_num_cols, const int N, const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) { DCHECK_EQ(N % 2, 0); DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N); @@ -1306,8 +1307,9 @@ inline std::unique_ptr SparseMatMul::CreateDenseSlices( template inline void SparseMatMul::ComputeBlockSizes( const typename SparseMatMul::ConstMatrixMapL& left, - const typename SparseMatMul::ConstMatrixMapR& right, bool transpose_left, - int num_threads, int* KR, int* NR, int* KL, int* JB, int* IB) { + const typename SparseMatMul::ConstMatrixMapR& right, + bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB, + int* IB) { // Heuristics for calculating block sizes // Assume two hyperthreads per core. const int est_num_cores = std::max(1, (num_threads + 1) / 2); diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h index cca52558ae25a7a0840d8551440f68ccc5ec2277..14ef2ed7044a796dff67e287230d955e32ca62cd 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.h +++ b/tensorflow/core/kernels/sparse_matmul_op.h @@ -159,25 +159,25 @@ EIGEN_STRONG_INLINE Packet4f pload2bf16(const float* from) { // Return a packet with the first value of the input Packet replicated template <> EIGEN_STRONG_INLINE Packet4f pbroadcast_first(const Packet4f& a) { - return vec_splat (a, 0); + return vec_splat(a, 0); } // Return a packet with the second value of the input Packet replicated template <> EIGEN_STRONG_INLINE Packet4f pbroadcast_second(const Packet4f& a) { - return vec_splat (a, 1); + return vec_splat(a, 1); } // Return a packet with the third value of the input Packet replicated template <> EIGEN_STRONG_INLINE Packet4f pbroadcast_third(const Packet4f& a) { - return vec_splat (a, 2); + return vec_splat(a, 2); } // Return a packet with the fourth value of the input Packet replicated template <> EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth(const Packet4f& a) { - return vec_splat (a, 3); + return vec_splat(a, 3); } #endif diff --git a/tensorflow/core/kernels/sparse_matmul_op_test.cc b/tensorflow/core/kernels/sparse_matmul_op_test.cc index f815ca9e344664c4c95befccb88e750eb99d0eaf..ebc6d8fa4ec5422925e57c25856e0007702299b1 100644 --- a/tensorflow/core/kernels/sparse_matmul_op_test.cc +++ b/tensorflow/core/kernels/sparse_matmul_op_test.cc @@ -284,11 +284,11 @@ class SparseMatmulOpTest : public ::testing::Test { uint16_t* data3_bfloat16_p = reinterpret_cast(data3_bfloat16) + i; #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - data3_p[1] = 0; - data3_bfloat16_p[0] = data3_p[0]; + data3_p[1] = 0; + data3_bfloat16_p[0] = data3_p[0]; #else - data3_p[0] = 0; - data3_bfloat16_p[0] = data3_p[1]; + data3_p[0] = 0; + data3_bfloat16_p[0] = data3_p[1]; #endif } } diff --git a/tensorflow/core/kernels/sparse_reduce_sum_op_test.cc b/tensorflow/core/kernels/sparse_reduce_sum_op_test.cc index 110376be42573fe31cc1a13306c80e5050477f03..96246c7a71272bf638523fafb548b7e802f09039 100644 --- a/tensorflow/core/kernels/sparse_reduce_sum_op_test.cc +++ b/tensorflow/core/kernels/sparse_reduce_sum_op_test.cc @@ -51,9 +51,9 @@ TEST_F(SparseReduceSumOpTest, SimpleReduce) { // [3 4] const auto indices_shape = TensorShape({4, 2}); - std::initializer_list in{ 0, 1, 1, 0, 2, 0, 2, 1 }; + std::initializer_list in{0, 1, 1, 0, 2, 0, 2, 1}; const gtl::ArraySlice indices(in); - std::initializer_list sh{ 3, 2 }; + std::initializer_list sh{3, 2}; const gtl::ArraySlice shape(sh); AddInputFromArray(indices_shape, indices); @@ -93,9 +93,9 @@ TEST_F(SparseReduceSumSparseOpTest, SimpleReduce) { // [3 4] const auto indices_shape = TensorShape({4, 2}); - std::initializer_list in{ 0, 1, 1, 0, 2, 0, 2, 1 }; + std::initializer_list in{0, 1, 1, 0, 2, 0, 2, 1}; const gtl::ArraySlice indices(in); - std::initializer_list sh{ 3, 2 }; + std::initializer_list sh{3, 2}; const gtl::ArraySlice shape(sh); AddInputFromArray(indices_shape, indices); diff --git a/tensorflow/core/kernels/sparse_softmax_op.cc b/tensorflow/core/kernels/sparse_softmax_op.cc index 327a94b8a12e1d8568c5ca79263cc6eb78501d15..444a5f657a969290d9cc67d88c500a49a0971282 100644 --- a/tensorflow/core/kernels/sparse_softmax_op.cc +++ b/tensorflow/core/kernels/sparse_softmax_op.cc @@ -50,8 +50,9 @@ class SparseSoftmaxOp : public OpKernel { errors::InvalidArgument( "Input sp_indices should be a matrix but received shape: ", indices_t->shape().DebugString())); - OP_REQUIRES(context, TensorShapeUtils::IsVector(values_t->shape()) && - TensorShapeUtils::IsVector(shape_t->shape()), + OP_REQUIRES(context, + TensorShapeUtils::IsVector(values_t->shape()) && + TensorShapeUtils::IsVector(shape_t->shape()), errors::InvalidArgument( "Inputs sp_values and sp_shape should be vectors " "but received shapes: ", diff --git a/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc b/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc index b027adba6b384c63d119387b5b13122fb1c25b12..09cb2a6a71c7c0f0ebc9cbc2e7b1951705890a41 100644 --- a/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc +++ b/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc @@ -132,14 +132,16 @@ class SparseSparseBinaryOpShared : public OpKernel { // Validations. OP_REQUIRES( - ctx, TensorShapeUtils::IsMatrix(a_indices_t->shape()) && - TensorShapeUtils::IsMatrix(b_indices_t->shape()), + ctx, + TensorShapeUtils::IsMatrix(a_indices_t->shape()) && + TensorShapeUtils::IsMatrix(b_indices_t->shape()), errors::InvalidArgument("Inputs a_indices and b_indices should be " "matrices but received shapes: ", a_indices_t->shape().DebugString(), ", ", b_indices_t->shape().DebugString())); - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_values_t->shape()) && - TensorShapeUtils::IsVector(b_values_t->shape()), + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(a_values_t->shape()) && + TensorShapeUtils::IsVector(b_values_t->shape()), errors::InvalidArgument( "Inputs a_values and b_values should be vectors " "but received shapes: ", @@ -157,8 +159,9 @@ class SparseSparseBinaryOpShared : public OpKernel { " non-empty input values, got ", a_values.size(), " and ", b_values.size())); - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape_t->shape()) && - TensorShapeUtils::IsVector(b_shape_t->shape()), + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(a_shape_t->shape()) && + TensorShapeUtils::IsVector(b_shape_t->shape()), errors::InvalidArgument( "Input shapes should be a vector but received shapes ", a_shape_t->shape().DebugString(), " and ", diff --git a/tensorflow/core/kernels/sparse_split_op.cc b/tensorflow/core/kernels/sparse_split_op.cc index 6171b532aa243e6a3d8b42e5c8856aaa1c7ad207..67dcf05a6ced17fa2dbd44fb03dca21a032bcc5b 100644 --- a/tensorflow/core/kernels/sparse_split_op.cc +++ b/tensorflow/core/kernels/sparse_split_op.cc @@ -48,18 +48,20 @@ class SparseSplitOp : public OpKernel { "Input shape should be a vector but received shape ", input_shape.shape().DebugString())); - OP_REQUIRES(context, input_shape.dim_size(0) && - split_dim < input_shape.vec().size(), - errors::InvalidArgument( - "Input split_dim should be between 0 and rank (", - input_shape.vec().size(), "), got ", split_dim)); - - OP_REQUIRES(context, num_split_ >= 1 && - num_split_ <= input_shape.vec()(split_dim), - errors::InvalidArgument("Input num_split should be between 1 " - "and the splitting dimension size (", - input_shape.vec()(split_dim), - "), got ", num_split_)); + OP_REQUIRES( + context, + input_shape.dim_size(0) && split_dim < input_shape.vec().size(), + errors::InvalidArgument( + "Input split_dim should be between 0 and rank (", + input_shape.vec().size(), "), got ", split_dim)); + + OP_REQUIRES( + context, + num_split_ >= 1 && num_split_ <= input_shape.vec()(split_dim), + errors::InvalidArgument("Input num_split should be between 1 " + "and the splitting dimension size (", + input_shape.vec()(split_dim), "), got ", + num_split_)); sparse::SparseTensor sparse_tensor(input_indices, input_values, TensorShape(input_shape.vec())); diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc index 6a6cc3d81382a783aa9e34c841cb7be650dd7c87..ba3da21a4331562354e7dfce3348954fda3d46ad 100644 --- a/tensorflow/core/kernels/sparse_to_dense_op.cc +++ b/tensorflow/core/kernels/sparse_to_dense_op.cc @@ -73,8 +73,9 @@ class SparseToDense : public OpKernel { // sparse_values const Tensor& sparse_values = c->input(2); const int64 num_values = sparse_values.NumElements(); - OP_REQUIRES(c, sparse_values.dims() == 0 || - (sparse_values.dims() == 1 && num_values == num_elems), + OP_REQUIRES(c, + sparse_values.dims() == 0 || + (sparse_values.dims() == 1 && num_values == num_elems), errors::InvalidArgument("sparse_values has incorrect shape ", sparse_values.shape().DebugString(), ", should be [] or [", num_elems, "]")); diff --git a/tensorflow/core/kernels/sparse_to_dense_op_test.cc b/tensorflow/core/kernels/sparse_to_dense_op_test.cc index f0d19da8046e7cb3c243f1e4e6c3266a0f96d921..d8b0f93082453bab574fe5fd5edbb78041efad54 100644 --- a/tensorflow/core/kernels/sparse_to_dense_op_test.cc +++ b/tensorflow/core/kernels/sparse_to_dense_op_test.cc @@ -38,7 +38,6 @@ namespace { class SparseToDenseTest : public OpsTestBase { protected: - void MakeOp(int dim, DataType index_type, DataType value_type) { TF_ASSERT_OK(NodeDefBuilder("sparsetodense", "SparseToDense") .Input(FakeInput(index_type)) diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc index c35ba42db2915216fe74a1f82d403e9b6803f63a..f84ffd53238f7753c1b4562268be9058c6c03e6d 100644 --- a/tensorflow/core/kernels/sparse_xent_op.cc +++ b/tensorflow/core/kernels/sparse_xent_op.cc @@ -39,10 +39,10 @@ Status CheckInvalidLabelIndex(const Tensor& labels, int64 max_index) { if (*min_max_dim_value.first < 0 || *min_max_dim_value.second >= max_index) { bad_index = (*min_max_dim_value.first < 0) ? *min_max_dim_value.first : *min_max_dim_value.second; - return errors::InvalidArgument("Received a label value of ", bad_index, - " which is outside the valid range of [0, ", - max_index, "). Label values: ", - labels.SummarizeValue(labels.NumElements())); + return errors::InvalidArgument( + "Received a label value of ", bad_index, + " which is outside the valid range of [0, ", max_index, + "). Label values: ", labels.SummarizeValue(labels.NumElements())); } return Status::OK(); } diff --git a/tensorflow/core/kernels/sparse_xent_op_test.cc b/tensorflow/core/kernels/sparse_xent_op_test.cc index b8ea0d2d7e279bc089aeb5574fc58c1af1686ca9..afb0bf76267f24ba1e2142954abfdcb41356cb96 100644 --- a/tensorflow/core/kernels/sparse_xent_op_test.cc +++ b/tensorflow/core/kernels/sparse_xent_op_test.cc @@ -41,10 +41,10 @@ static Graph* SparseXent(int batch_size, int num_classes) { return g; } -#define BM_SparseXentDev(BATCH, CLASS, DEVICE) \ - static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE(int iters) { \ +#define BM_SparseXentDev(BATCH, CLASS, DEVICE) \ + static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE(int iters) { \ testing::ItemsProcessed(static_cast(iters) * BATCH * CLASS); \ - test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS)).Run(iters); \ + test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS)).Run(iters); \ } \ BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE); diff --git a/tensorflow/core/kernels/split_lib.h b/tensorflow/core/kernels/split_lib.h index ff92ffeeb38a964dcd068b54f9558ca8da7c969e..a08949e626cc8e5d4c3707b75a902d82b46c3376 100644 --- a/tensorflow/core/kernels/split_lib.h +++ b/tensorflow/core/kernels/split_lib.h @@ -57,7 +57,7 @@ struct Split { const Eigen::DSizes& slice_indices, const Eigen::DSizes& slice_sizes); }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/split_lib_cpu.cc b/tensorflow/core/kernels/split_lib_cpu.cc index 25026208d1ee78cb614e4ad41dccb7a0fa0f7817..771c633b156edf7c7d9944fe95703a0e0cd9e981 100644 --- a/tensorflow/core/kernels/split_lib_cpu.cc +++ b/tensorflow/core/kernels/split_lib_cpu.cc @@ -49,13 +49,13 @@ void Split::operator()( typename TTypes::ConstTensor input, const Eigen::DSizes& slice_indices, const Eigen::DSizes& slice_sizes) { - output.device(d) = input.slice(slice_indices, slice_sizes); + output.device(d) = input.slice(slice_indices, slice_sizes); } #define DEFINE_SYCL_KERNELS(T) template struct Split; TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_SYCL_KERNELS); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc index 78badde27e5c4ca33faa00073e7b412e85d82970..85f529326dbf5d9d5ae72156da05f08f805d1271 100644 --- a/tensorflow/core/kernels/split_op.cc +++ b/tensorflow/core/kernels/split_op.cc @@ -39,7 +39,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class SplitOpBase : public OpKernel { @@ -142,8 +142,9 @@ class SplitOpCPU : public SplitOpBase { // Android also uses int32 indexing, so check here also. OP_REQUIRES( - context, FastBoundsCheck(input.NumElements(), - std::numeric_limits::max()), + context, + FastBoundsCheck(input.NumElements(), + std::numeric_limits::max()), errors::InvalidArgument("Split requires input size < ", std::numeric_limits::max())); @@ -245,10 +246,11 @@ class SplitOpGPU : public SplitOpBase { const int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; const int32 num_split = Base::num_outputs(); - OP_REQUIRES(context, FastBoundsCheck(input.NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument("Split on GPU requires input size " - "< max int32")); + OP_REQUIRES( + context, + FastBoundsCheck(input.NumElements(), std::numeric_limits::max()), + errors::InvalidArgument("Split on GPU requires input size " + "< max int32")); int32 prefix_dim_size; int32 split_dim_size; int32 suffix_dim_size; @@ -304,8 +306,9 @@ class SplitOpSYCL : public SplitOpBase { // Android also uses int32 indexing, so check here also. OP_REQUIRES( - context, FastBoundsCheck(input.NumElements(), - std::numeric_limits::max()), + context, + FastBoundsCheck(input.NumElements(), + std::numeric_limits::max()), errors::InvalidArgument("Split requires input size < ", std::numeric_limits::max())); @@ -342,14 +345,14 @@ class SplitOpSYCL : public SplitOpBase { {prefix_dim_size, split_dim_output_size, suffix_dim_size}); functor::Split()(context->eigen_device(), - result_shaped, input_reshaped, - slice_indices, slice_sizes); + result_shaped, input_reshaped, + slice_indices, slice_sizes); } indices[1] += split_dim_output_size; } } }; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #define REGISTER_SPLIT(type) \ REGISTER_KERNEL_BUILDER(Name("Split") \ @@ -381,11 +384,11 @@ REGISTER_GPU(bfloat16); #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL(type) \ - REGISTER_KERNEL_BUILDER(Name("Split") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint("T") \ - .HostMemory("split_dim"), \ +#define REGISTER_SYCL(type) \ + REGISTER_KERNEL_BUILDER(Name("Split") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("split_dim"), \ SplitOpSYCL) TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL); diff --git a/tensorflow/core/kernels/split_v_op.cc b/tensorflow/core/kernels/split_v_op.cc index f1078ac349c979bb14f3949c05a7c493c9355567..7ff5df47d70fa8e47aabfb24e82874c146708ef1 100644 --- a/tensorflow/core/kernels/split_v_op.cc +++ b/tensorflow/core/kernels/split_v_op.cc @@ -197,8 +197,9 @@ class SplitVOpCPU : public SplitVOpBase { // Android also uses int32 indexing, so check here also. OP_REQUIRES( - context, FastBoundsCheck(input.NumElements(), - std::numeric_limits::max()), + context, + FastBoundsCheck(input.NumElements(), + std::numeric_limits::max()), errors::InvalidArgument("Split requires input size < ", std::numeric_limits::max())); @@ -305,10 +306,11 @@ class SplitVOpGPU : public SplitVOpBase { const int32 split_dim_orig = context->input(2).flat()(0); const int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; - OP_REQUIRES(context, FastBoundsCheck(input.NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument("Split on GPU requires input size " - "< max int32")); + OP_REQUIRES( + context, + FastBoundsCheck(input.NumElements(), std::numeric_limits::max()), + errors::InvalidArgument("Split on GPU requires input size " + "< max int32")); int32 prefix_dim_size; int32 split_dim_size; diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc index affe81a55567d1ef304c7161c65c201021da1363..65296f61fd180e2f57855d4cee1566bf827dd46a 100644 --- a/tensorflow/core/kernels/stack_ops.cc +++ b/tensorflow/core/kernels/stack_ops.cc @@ -42,7 +42,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL class Stack : public ResourceBase { public: @@ -242,7 +242,7 @@ REGISTER_KERNEL_BUILDER(Name("StackV2") .HostMemory("max_size") .HostMemory("handle"), StackOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class StackPushOp : public AsyncOpKernel { @@ -274,11 +274,11 @@ class StackPushOp : public AsyncOpKernel { static constexpr int kCopyThreshold = 2048; static constexpr double kOccupancy = 0.7; if (swap_memory_ && !alloc_attrs.on_host() && - ( std::is_same::value + (std::is_same::value #ifdef TENSORFLOW_USE_SYCL - || std::is_same::value -#endif // TENSORFLOW_USE_SYCL - ) && + || std::is_same::value +#endif // TENSORFLOW_USE_SYCL + ) && tensor.TotalBytes() > kCopyThreshold && stack->IsUsefulToSwap(tensor)) { DeviceContext* device_ctxt = ctx->op_device_context(); auto device = static_cast(ctx->device()); @@ -391,7 +391,7 @@ REGISTER_SYCL_HOST_KERNEL(int32); REGISTER_SYCL_HOST_KERNEL(bool); #undef REGISTER_SYCL_KERNEL #undef REGISTER_SYCL_HOST_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL class StackPopOp : public AsyncOpKernel { public: @@ -498,7 +498,7 @@ REGISTER_SYCL_HOST_KERNEL(bool); #undef REGISTER_SYCL_KERNEL #undef REGISTER_SYCL_HOST_KERNEL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL class StackCloseOp : public OpKernel { public: @@ -526,6 +526,6 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER( Name("StackCloseV2").Device(DEVICE_SYCL).HostMemory("handle"), StackCloseOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc index 0fae46dea61d361bd4ead0afc0fa33711407fc9b..03fc4467a1dcf9d70c90c19809690934b0a7c2f4 100644 --- a/tensorflow/core/kernels/stage_op.cc +++ b/tensorflow/core/kernels/stage_op.cc @@ -70,12 +70,11 @@ class Buffer : public ResourceBase { return bytes + current_bytes_ > memory_limit_; } - std::size_t GetTupleBytes(const Tuple & tuple) - { + std::size_t GetTupleBytes(const Tuple& tuple) { return std::accumulate(tuple.begin(), tuple.end(), 0, - [](const std::size_t & lhs, const Tensor & rhs) { - return lhs + rhs.TotalBytes(); - }); + [](const std::size_t& lhs, const Tensor& rhs) { + return lhs + rhs.TotalBytes(); + }); } public: @@ -90,19 +89,22 @@ class Buffer : public ResourceBase { std::size_t tuple_bytes = GetTupleBytes(*tuple); // Sanity check so that we don't block for ever below - if(memory_limit_ > 0 && tuple_bytes > memory_limit_) { - return Status(errors::ResourceExhausted("Attempted to insert " - "tensors with combined size of '", tuple_bytes, "' bytes into " - "Staging Area with a memory limit of '", memory_limit_, "'.")); + if (memory_limit_ > 0 && tuple_bytes > memory_limit_) { + return Status( + errors::ResourceExhausted("Attempted to insert " + "tensors with combined size of '", + tuple_bytes, + "' bytes into " + "Staging Area with a memory limit of '", + memory_limit_, "'.")); } - // If buffer capacity is bounded wait until elements have been removed - if(IsBounded()) { + if (IsBounded()) { full_cond_var_.wait(lock, [tuple_bytes, this]() { // If there's a memory limit, check if there's space for insertion - bool memory_limit_valid = memory_limit_ > 0 ? - !WouldExceedMemoryLimit(tuple_bytes) : true; + bool memory_limit_valid = + memory_limit_ > 0 ? !WouldExceedMemoryLimit(tuple_bytes) : true; // If we're configured for capacity check if there's space for insertion bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true; @@ -186,8 +188,7 @@ Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) { ContainerInfo cinfo; // Lambda for creating the Staging Area - auto create_fn = [&ndef](Buffer** ret) -> Status - { + auto create_fn = [&ndef](Buffer** ret) -> Status { int64 capacity; int64 memory_limit; TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity)); @@ -196,7 +197,6 @@ Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) { return Status::OK(); }; - TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */)); TF_RETURN_IF_ERROR(rm->LookupOrCreate(cinfo.container(), cinfo.name(), buf, create_fn)); @@ -228,7 +228,7 @@ REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_GPU), StageOp); #endif #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_SYCL), StageOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL class UnstageOp : public OpKernel { public: @@ -244,7 +244,8 @@ class UnstageOp : public OpKernel { buf->Get(&tuple); - OP_REQUIRES(ctx, tuple.size() == (size_t)ctx->num_outputs(), + OP_REQUIRES( + ctx, tuple.size() == (size_t)ctx->num_outputs(), errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(), " vs. ", ctx->num_outputs())); @@ -260,7 +261,7 @@ REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_GPU), UnstageOp); #endif #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_SYCL), UnstageOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL class StagePeekOp : public OpKernel { public: @@ -278,7 +279,8 @@ class StagePeekOp : public OpKernel { OP_REQUIRES_OK(ctx, buf->Peek(index, &tuple)); - OP_REQUIRES(ctx, tuple.size() == (size_t)ctx->num_outputs(), + OP_REQUIRES( + ctx, tuple.size() == (size_t)ctx->num_outputs(), errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(), " vs. ", ctx->num_outputs())); @@ -288,17 +290,15 @@ class StagePeekOp : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("StagePeek").Device(DEVICE_CPU), - StagePeekOp); +REGISTER_KERNEL_BUILDER(Name("StagePeek").Device(DEVICE_CPU), StagePeekOp); #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("StagePeek").HostMemory("index"). - Device(DEVICE_GPU), StagePeekOp); +REGISTER_KERNEL_BUILDER( + Name("StagePeek").HostMemory("index").Device(DEVICE_GPU), StagePeekOp); #endif #ifdef TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("StagePeek").HostMemory("index") - .Device(DEVICE_SYCL), StagePeekOp); -#endif // TENSORFLOW_USE_SYCL - +REGISTER_KERNEL_BUILDER( + Name("StagePeek").HostMemory("index").Device(DEVICE_SYCL), StagePeekOp); +#endif // TENSORFLOW_USE_SYCL class StageSizeOp : public OpKernel { public: @@ -312,9 +312,8 @@ class StageSizeOp : public OpKernel { core::ScopedUnref scope(buf); // Allocate size output tensor - Tensor * size = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), - &size)); + Tensor* size = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size)); // Set it to the actual size size->scalar().setConstant(buf->Size()); @@ -323,13 +322,13 @@ class StageSizeOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("StageSize").Device(DEVICE_CPU), StageSizeOp); #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("StageSize").HostMemory("size") - .Device(DEVICE_GPU), StageSizeOp); +REGISTER_KERNEL_BUILDER(Name("StageSize").HostMemory("size").Device(DEVICE_GPU), + StageSizeOp); #endif #ifdef TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("StageSize").HostMemory("size") - .Device(DEVICE_SYCL), StageSizeOp); -#endif // TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("StageSize").HostMemory("size").Device(DEVICE_SYCL), StageSizeOp); +#endif // TENSORFLOW_USE_SYCL class StageClearOp : public OpKernel { public: @@ -352,7 +351,6 @@ REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_GPU), StageClearOp); #endif #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_SYCL), StageClearOp); -#endif // TENSORFLOW_USE_SYCL - +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 7c213e14d21efd6fcf033d3cd341c35838fe9f7b..7745effe2abe94ba73a2f0d761210e07c62e499c 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -294,6 +294,11 @@ class StridedSliceAssignOp : public OpKernel { OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &v)); old_lhs = *v->tensor(); + OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum::value, + errors::InvalidArgument( + "l-value dtype ", DataTypeString(old_lhs.dtype()), + " does not match r-value dtype ", + DataTypeString(DataTypeToEnum::value))); } else { context->forward_ref_input_to_ref_output(0, 0); old_lhs = context->mutable_input(0, true); @@ -541,5 +546,5 @@ REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") .HostMemory("strides"), StridedSliceAssignOp) #undef REGISTER_SYCL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/strided_slice_op.h b/tensorflow/core/kernels/strided_slice_op.h index 0f72c4b771025458a1403ce13842787249a2718f..2b5863229860c256e1c74f1fe11bf57ed502008e 100644 --- a/tensorflow/core/kernels/strided_slice_op.h +++ b/tensorflow/core/kernels/strided_slice_op.h @@ -21,6 +21,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h index a84ba38ef41486f86f5e37bd95287b8ae6c9bb2e..1c4472bb1ab4e6b9d09a1f1464577172056c6fbe 100644 --- a/tensorflow/core/kernels/strided_slice_op_impl.h +++ b/tensorflow/core/kernels/strided_slice_op_impl.h @@ -26,6 +26,8 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types_traits.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/dense_update_functor.h" #include "tensorflow/core/kernels/ops_util.h" @@ -302,7 +304,7 @@ DECLARE_FOR_N_SYCL(int32); DECLARE_FOR_N_SYCL(int64); #undef DECLARE_FOR_N_SYCL -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #undef INSTANTIATE #undef DECLARE_FOR_N_CPU diff --git a/tensorflow/core/kernels/string_join_op.cc b/tensorflow/core/kernels/string_join_op.cc index 721702bec68efa24d4dafef1e9aaa0c5f1b4c849..28cca9f44849b39647ba08c54d9e1f3c108f91fd 100644 --- a/tensorflow/core/kernels/string_join_op.cc +++ b/tensorflow/core/kernels/string_join_op.cc @@ -50,9 +50,9 @@ class StringJoinOp : public OpKernel { } else { OP_REQUIRES( context, input_shape == input.shape(), - errors::InvalidArgument("Input shapes do not match: ", - input_shape.DebugString(), " vs. ", - input.shape().DebugString())); + errors::InvalidArgument( + "Input shapes do not match: ", input_shape.DebugString(), + " vs. ", input.shape().DebugString())); } } } diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 743f11315042af94cfe41cecf52d145ae69f8209..22e45918a03833c784f23911061c5b049658ffbe 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -95,9 +95,9 @@ class SubstrOp : public OpKernel { // Create BCast helper with shape of input and pos/len BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape)); OP_REQUIRES(context, bcast.IsValid(), - errors::InvalidArgument("Incompatible shapes: ", - input_shape.DebugString(), " vs. ", - pos_shape.DebugString())); + errors::InvalidArgument( + "Incompatible shapes: ", input_shape.DebugString(), + " vs. ", pos_shape.DebugString())); TensorShape output_shape = BCast::ToShape(bcast.result_shape()); int ndims = output_shape.dims(); Tensor* output_tensor = nullptr; @@ -115,7 +115,7 @@ class SubstrOp : public OpKernel { Tensor input_buffer; OP_REQUIRES_OK(context, context->allocate_temp( DT_STRING, output_shape, &input_buffer)); - typename TTypes::Tensor input_bcast = + TTypes::Tensor input_bcast = input_buffer.shaped(bcast.result_shape()); input_bcast = input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast())); @@ -125,8 +125,8 @@ class SubstrOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), output_shape, &pos_buffer)); - typename TTypes::Tensor pos_bcast = - pos_buffer.shaped(bcast.result_shape()); + typename TTypes::Tensor pos_bcast( + pos_buffer.shaped(bcast.result_shape())); pos_bcast = pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); @@ -135,8 +135,8 @@ class SubstrOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), output_shape, &len_buffer)); - typename TTypes::Tensor len_bcast = - len_buffer.shaped(bcast.result_shape()); + typename TTypes::Tensor len_bcast( + len_buffer.shaped(bcast.result_shape())); len_bcast = len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); @@ -164,7 +164,7 @@ class SubstrOp : public OpKernel { Tensor input_buffer; OP_REQUIRES_OK(context, context->allocate_temp( DT_STRING, output_shape, &input_buffer)); - typename TTypes::Tensor input_bcast = + TTypes::Tensor input_bcast = input_buffer.shaped(bcast.result_shape()); input_bcast = input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast())); @@ -174,8 +174,8 @@ class SubstrOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), output_shape, &pos_buffer)); - typename TTypes::Tensor pos_bcast = - pos_buffer.shaped(bcast.result_shape()); + typename TTypes::Tensor pos_bcast( + pos_buffer.shaped(bcast.result_shape())); pos_bcast = pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); @@ -184,8 +184,8 @@ class SubstrOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), output_shape, &len_buffer)); - typename TTypes::Tensor len_bcast = - len_buffer.shaped(bcast.result_shape()); + typename TTypes::Tensor len_bcast( + len_buffer.shaped(bcast.result_shape())); len_bcast = len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc index 233b824bcc3bab74d70c990c44389e6df7b10f02..29b21ee7353fe03ce87bc03dad72b05ca8fd4311 100644 --- a/tensorflow/core/kernels/summary_image_op.cc +++ b/tensorflow/core/kernels/summary_image_op.cc @@ -54,18 +54,20 @@ class SummaryImageOp : public OpKernel { const Tensor& tensor = c->input(1); OP_REQUIRES(c, IsLegacyScalar(tags.shape()), errors::InvalidArgument("Tags must be a scalar")); - OP_REQUIRES(c, tensor.dims() == 4 && - (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 || - tensor.dim_size(3) == 4), + OP_REQUIRES(c, + tensor.dims() == 4 && + (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 || + tensor.dim_size(3) == 4), errors::InvalidArgument( "Tensor must be 4-D with last dim 1, 3, or 4, not ", tensor.shape().DebugString())); const string& base_tag = tags.scalar()(); - OP_REQUIRES(c, tensor.dim_size(0) < (1LL << 31) && - tensor.dim_size(1) < (1LL << 31) && - tensor.dim_size(2) < (1LL << 31) && - (tensor.dim_size(1) * tensor.dim_size(2)) < (1LL << 29), + OP_REQUIRES(c, + tensor.dim_size(0) < (1LL << 31) && + tensor.dim_size(1) < (1LL << 31) && + tensor.dim_size(2) < (1LL << 31) && + (tensor.dim_size(1) * tensor.dim_size(2)) < (1LL << 29), errors::InvalidArgument("Tensor too large for summary ", tensor.shape().DebugString())); diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index 41cbece1d648f3e2dba112375e494d2ed8192db9..d317a8d33db5b69a84a0d193cb6322afaa53dff6 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -42,11 +42,16 @@ class CreateSummaryFileWriterOp : public OpKernel { const int32 flush_millis = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp)); const string filename_suffix = tmp->scalar()(); - SummaryWriterInterface* s; - OP_REQUIRES_OK(ctx, - CreateSummaryFileWriter(max_queue, flush_millis, logdir, - filename_suffix, ctx->env(), &s)); - OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s)); + + SummaryWriterInterface* s = nullptr; + OP_REQUIRES_OK(ctx, LookupOrCreateResource( + ctx, HandleFromInput(ctx, 0), &s, + [max_queue, flush_millis, logdir, filename_suffix, + ctx](SummaryWriterInterface** s) { + return CreateSummaryFileWriter( + max_queue, flush_millis, logdir, + filename_suffix, ctx->env(), s); + })); } }; REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU), @@ -66,17 +71,23 @@ class CreateSummaryDbWriterOp : public OpKernel { const string run_name = tmp->scalar()(); OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp)); const string user_name = tmp->scalar()(); - SummaryWriterInterface* s; - Sqlite* db; - OP_REQUIRES_OK(ctx, Sqlite::Open(db_uri, - SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, - &db)); - core::ScopedUnref unref(db); - OP_REQUIRES_OK(ctx, SetupTensorboardSqliteDb(db)); + + SummaryWriterInterface* s = nullptr; OP_REQUIRES_OK( - ctx, CreateSummaryDbWriter(db, experiment_name, - run_name, user_name, ctx->env(), &s)); - OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s)); + ctx, + LookupOrCreateResource( + ctx, HandleFromInput(ctx, 0), &s, + [db_uri, experiment_name, run_name, user_name, + ctx](SummaryWriterInterface** s) { + Sqlite* db; + TF_RETURN_IF_ERROR(Sqlite::Open( + db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db)); + core::ScopedUnref unref(db); + TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db)); + TF_RETURN_IF_ERROR(CreateSummaryDbWriter( + db, experiment_name, run_name, user_name, ctx->env(), s)); + return Status::OK(); + })); } }; REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU), @@ -267,8 +278,6 @@ class WriteAudioSummaryOp : public OpKernel { private: int max_outputs_; - bool has_sample_rate_attr_; - float sample_rate_attr_; }; REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU), WriteAudioSummaryOp); diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc index b818724ec2e895d3995fe19b811327ed0ba112c2..1f4e3418f4826dee789002d4aa688f8ce14e17d2 100644 --- a/tensorflow/core/kernels/summary_op.cc +++ b/tensorflow/core/kernels/summary_op.cc @@ -41,11 +41,12 @@ class SummaryScalarOp : public OpKernel { const Tensor& values = c->input(1); OP_REQUIRES( - c, tags.IsSameSize(values) || - (IsLegacyScalar(tags.shape()) && IsLegacyScalar(values.shape())), - errors::InvalidArgument("tags and values not the same shape: ", - tags.shape().DebugString(), " != ", - values.shape().DebugString(), SingleTag(tags))); + c, + tags.IsSameSize(values) || + (IsLegacyScalar(tags.shape()) && IsLegacyScalar(values.shape())), + errors::InvalidArgument( + "tags and values not the same shape: ", tags.shape().DebugString(), + " != ", values.shape().DebugString(), SingleTag(tags))); auto Ttags = tags.flat(); auto Tvalues = values.flat(); Summary s; diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc index dedc2da60bab0d0c0613630c384c2f23ddae31e3..8c3a58b108abe66f2b61b5153923bee192246cd1 100644 --- a/tensorflow/core/kernels/svd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc @@ -63,8 +63,8 @@ __global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m, int64 ldu, const Scalar* M, const Scalar* U, const Scalar* S, Scalar* V) { - CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) { - CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count, y) { + CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) { + CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) { Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch]; CudaAtomicAdd(V + batch, v); } diff --git a/tensorflow/core/kernels/tile_functor_cpu.cc b/tensorflow/core/kernels/tile_functor_cpu.cc index b2fd669541d32406512c4618fac77604baefedbe..f8144867014eccf04c892d0ce90a2aa280dfd764 100644 --- a/tensorflow/core/kernels/tile_functor_cpu.cc +++ b/tensorflow/core/kernels/tile_functor_cpu.cc @@ -15,10 +15,10 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "tensorflow/core/kernels/tile_functor.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/tile_functor.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl.h b/tensorflow/core/kernels/tile_ops_cpu_impl.h index 054b31ef9e0b4904d8803d1c4542ff805e0a7673..df6a666cd441d9c1306d950bbe0e79bf3dae28d9 100644 --- a/tensorflow/core/kernels/tile_ops_cpu_impl.h +++ b/tensorflow/core/kernels/tile_ops_cpu_impl.h @@ -63,7 +63,7 @@ TF_CALL_int64(DEFINE_TYPE); #undef DEFINE_DIM #undef DEFINE_TYPE -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // end namespace functor } // end namespace tensorflow diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 38e77ab60fb7126bcdedc09bfe9e2ec7de88c0ad..233aa03c32333e62281cb8ab71828649b4fabe7e 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -1228,11 +1228,8 @@ inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1, quadratic = Eigen::numext::pow(accum, -lr_power) / lr + static_cast(2) * l2; } - if (Eigen::numext::abs(linear) > l1) { - return (l1 * sgn(linear) - linear) / quadratic; - } else { - return static_cast(0.0); - } + auto l1_reg_adjust = std::max(std::min(linear, l1), -l1); + return (l1_reg_adjust - linear) / quadratic; } } // namespace @@ -3279,7 +3276,6 @@ REGISTER_KERNELS(double, int64); #undef REGISTER_KERNELS - template class ApplyAddSignOp : public OpKernel { public: @@ -3362,17 +3358,15 @@ TF_CALL_double(REGISTER_CPU_KERNELS); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void ApplyAddSign::operator()( \ - const GPUDevice& d, \ - typename TTypes::Flat var, \ - typename TTypes::Flat m, \ - typename TTypes::ConstScalar lr, \ - typename TTypes::ConstScalar alpha, \ - typename TTypes::ConstScalar sign_decay, \ - typename TTypes::ConstScalar beta, \ - typename TTypes::ConstFlat grad); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ApplyAddSign::operator()( \ + const GPUDevice& d, typename TTypes::Flat var, \ + typename TTypes::Flat m, typename TTypes::ConstScalar lr, \ + typename TTypes::ConstScalar alpha, \ + typename TTypes::ConstScalar sign_decay, \ + typename TTypes::ConstScalar beta, \ + typename TTypes::ConstFlat grad); \ extern template struct ApplyAddSign; DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); @@ -3387,7 +3381,6 @@ REGISTER_KERNELS(GPU, double); #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS - template class ApplyPowerSignOp : public OpKernel { public: @@ -3470,17 +3463,15 @@ TF_CALL_double(REGISTER_CPU_KERNELS); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void ApplyPowerSign::operator()( \ - const GPUDevice& d, \ - typename TTypes::Flat var, \ - typename TTypes::Flat m, \ - typename TTypes::ConstScalar lr, \ - typename TTypes::ConstScalar logbase, \ - typename TTypes::ConstScalar sign_decay, \ - typename TTypes::ConstScalar beta, \ - typename TTypes::ConstFlat grad); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ApplyPowerSign::operator()( \ + const GPUDevice& d, typename TTypes::Flat var, \ + typename TTypes::Flat m, typename TTypes::ConstScalar lr, \ + typename TTypes::ConstScalar logbase, \ + typename TTypes::ConstScalar sign_decay, \ + typename TTypes::ConstScalar beta, \ + typename TTypes::ConstFlat grad); \ extern template struct ApplyPowerSign; DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index d443a6b3c1d0b548e915216adbc05549a66eaeda..0376a3b2c602c13b3082b7762cf61a2b30552199 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -17,8 +17,8 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/training_ops.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/training_ops.h" namespace tensorflow { @@ -115,13 +115,11 @@ struct ApplyAdam { Eigen::Sizes<1> single; const auto one = static_cast(1.0); m.device(d) = - m + - (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) * - (grad - m); + m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) * + (grad - m); v.device(d) = - v + - (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) * - (grad.square() - v); + v + (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) * + (grad.square() - v); if (use_nesterov) { var.device(d) -= @@ -157,9 +155,9 @@ struct ApplyRMSProp { bcast[0] = grad.dimension(0); Eigen::Sizes<1> single; const auto one = static_cast(1.0); - ms.device(d) = ms + - (rho.constant(one) - rho).reshape(single).broadcast(bcast) * - (grad.square() - ms); + ms.device(d) = + ms + (rho.constant(one) - rho).reshape(single).broadcast(bcast) * + (grad.square() - ms); mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) + lr.reshape(single).broadcast(bcast) * grad / @@ -212,7 +210,7 @@ struct ApplyAddSign { auto beta_bcast = beta.reshape(single).broadcast(bcast); auto one_minus_beta = (beta.constant(one) - beta).reshape(single).broadcast(bcast); - m.device(d) = m * beta_bcast + grad * one_minus_beta; + m.device(d) = m * beta_bcast + grad * one_minus_beta; // The following is the GPU equivalent of the CPU version: // var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad; @@ -244,7 +242,7 @@ struct ApplyPowerSign { auto beta_bcast = beta.reshape(single).broadcast(bcast); auto one_minus_beta = (beta.constant(one) - beta).reshape(single).broadcast(bcast); - m.device(d) = m * beta_bcast + grad * one_minus_beta; + m.device(d) = m * beta_bcast + grad * one_minus_beta; // The following is the GPU equivalent of the CPU version: // auto grad_scale = (logbase() * sign_decay() * sign_gm).exp(); @@ -253,7 +251,7 @@ struct ApplyPowerSign { auto lr_bcast = lr.reshape(single).broadcast(bcast); auto logbase_bcast = logbase.reshape(single).broadcast(bcast); auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast); - auto grad_scale = (logbase_bcast * sign_decay_bcast * sign_gm).exp(); + auto grad_scale = (logbase_bcast * sign_decay_bcast * sign_gm).exp(); var.device(d) -= lr_bcast * grad_scale * grad; } }; diff --git a/tensorflow/core/kernels/training_ops_test.cc b/tensorflow/core/kernels/training_ops_test.cc index ffa7f87c9efda0e3288b9fb06d0c9d1a3dcba277..2dcc4a500e6c64753c6fde4f88582f914a50089e 100644 --- a/tensorflow/core/kernels/training_ops_test.cc +++ b/tensorflow/core/kernels/training_ops_test.cc @@ -176,8 +176,9 @@ static void Adam(int32 n, Graph** init_g, Graph** train_g) { auto beta2 = Scalar(g, 0.99); auto epsilon = Scalar(g, 1e-8); auto grad = Random(g, n); - test::graph::Multi(g, "ApplyAdam", {var, m, v, beta1_power, beta2_power, lr, - beta1, beta2, epsilon, grad}); + test::graph::Multi( + g, "ApplyAdam", + {var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad}); *train_g = g; } } diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index 2e0d18b634a8aebeaf2b7a0118ea8a9367804086..7177ad78884cae85a847a283017511dcad2e4878 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -176,9 +176,10 @@ void TransposeOp::Compute(OpKernelContext* ctx) { } } for (int i = 0; i < dims; ++i) { - OP_REQUIRES(ctx, bits[i], errors::InvalidArgument( - i, " is missing from {", - str_util::Join(permutation, ","), "}.")); + OP_REQUIRES( + ctx, bits[i], + errors::InvalidArgument(i, " is missing from {", + str_util::Join(permutation, ","), "}.")); } // 0-D, 1-D, and identity transposes do nothing. diff --git a/tensorflow/core/kernels/typed_queue.h b/tensorflow/core/kernels/typed_queue.h index 0d608d9b8799d561141ac2d3378a0f0e3507acfd..43dcb4cef74c568a6bc31abc8c460cff241fc6fa 100644 --- a/tensorflow/core/kernels/typed_queue.h +++ b/tensorflow/core/kernels/typed_queue.h @@ -58,9 +58,9 @@ Status TypedQueue::Initialize() { if (!component_shapes_.empty() && component_dtypes_.size() != component_shapes_.size()) { return errors::InvalidArgument( - "Different number of component types. ", "Types: ", - DataTypeSliceString(component_dtypes_), ", Shapes: ", - ShapeListString(component_shapes_)); + "Different number of component types. ", + "Types: ", DataTypeSliceString(component_dtypes_), + ", Shapes: ", ShapeListString(component_shapes_)); } mutex_lock lock(mu_); diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc index 397bdd56708d766d06e5a68f3b049a5b928195e1..764b6a252adf09c13511a01f95332857f46eee96 100644 --- a/tensorflow/core/kernels/unpack_op.cc +++ b/tensorflow/core/kernels/unpack_op.cc @@ -34,7 +34,7 @@ typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class UnpackOp : public OpKernel { @@ -65,8 +65,9 @@ class UnpackOp : public OpKernel { output_shape.RemoveDim(axis); const int64 output_size = output_shape.num_elements(); OP_REQUIRES( - context, FastBoundsCheck(output_size, - std::numeric_limits::max()), + context, + FastBoundsCheck(output_size, + std::numeric_limits::max()), errors::InvalidArgument("output size must fit in Eigen DenseIndex")); // This optimization is currently not applicable for SYCL devices diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..62e814ff773ccb2ee3d7e9445966f5d805817802 --- /dev/null +++ b/tensorflow/core/kernels/unravel_index_op.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +namespace { +template +struct mod_op { + const T operator()(const T& a, const T& b) const { return a % b; } +}; +} // namespace + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class UnravelIndexOp : public OpKernel { + public: + explicit UnravelIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& indices_tensor = ctx->input(0); + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(indices_tensor.shape()) || + TensorShapeUtils::IsScalar(indices_tensor.shape()), + errors::InvalidArgument( + "The indices can only be scalar or vector, got \"", + indices_tensor.shape().DebugString(), "\"")); + + const Tensor& dims_tensor = ctx->input(1); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(dims_tensor.shape()), + errors::InvalidArgument("The indices can only be 1-D, got \"", + dims_tensor.shape().DebugString(), "\"")); + + auto dims = dims_tensor.vec(); + + Eigen::array reverse({true}); + + Tensor strides_tensor; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({dims_tensor.NumElements()}), + &strides_tensor)); + + auto strides = strides_tensor.vec(); + strides = dims.reverse(reverse) + .scan(0, Eigen::internal::ProdReducer(), false) + .reverse(reverse); + + Tensor strides_shifted_tensor; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({dims_tensor.NumElements()}), + &strides_shifted_tensor)); + + auto strides_shifted = strides_shifted_tensor.vec(); + strides_shifted = dims.reverse(reverse) + .scan(0, Eigen::internal::ProdReducer(), true) + .reverse(reverse); + + Tensor* output_tensor = nullptr; + if (TensorShapeUtils::IsScalar(indices_tensor.shape())) { + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, TensorShape({dims_tensor.NumElements()}), + &output_tensor)); + + auto output = output_tensor->vec(); + + output = output.constant(indices_tensor.scalar()()); + output = output.binaryExpr(strides, mod_op()) / strides_shifted; + } else { + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, + TensorShape({dims_tensor.NumElements(), + indices_tensor.NumElements()}), + &output_tensor)); + + auto output = output_tensor->matrix(); + + Eigen::array reshape{{dims_tensor.NumElements(), 1}}; + Eigen::array bcast({1, indices_tensor.NumElements()}); + Eigen::array indices_reshape{{1, indices_tensor.NumElements()}}; + Eigen::array indices_bcast({dims_tensor.NumElements(), 1}); + + output = indices_tensor.vec() + .reshape(indices_reshape) + .broadcast(indices_bcast); + output = output.binaryExpr(strides.reshape(reshape).broadcast(bcast), + mod_op()) / + strides_shifted.reshape(reshape).broadcast(bcast); + } + } +}; + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("UnravelIndex").Device(DEVICE_CPU).TypeConstraint("Tidx"), \ + UnravelIndexOp); +TF_CALL_int32(REGISTER_KERNEL) TF_CALL_int64(REGISTER_KERNEL) +#undef REGISTER_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc index 10ccc85b7cd63db7f8d329a4253784abed7174cf..7fd5809ca49eba6af24d7dafe3b34b7f2c238279 100644 --- a/tensorflow/core/kernels/variable_ops.cc +++ b/tensorflow/core/kernels/variable_ops.cc @@ -237,6 +237,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL); IsVariableInitializedOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/word2vec_kernels.cc b/tensorflow/core/kernels/word2vec_kernels.cc index 2d05d72bff162f98e8d13e8a3208e4dd00a48fa4..3477445197a961b275e3efb8ce09d5b075342f9e 100644 --- a/tensorflow/core/kernels/word2vec_kernels.cc +++ b/tensorflow/core/kernels/word2vec_kernels.cc @@ -188,9 +188,9 @@ class SkipgramOp : public OpKernel { ++corpus_size_; } if (corpus_size_ < window_size_ * 10) { - return errors::InvalidArgument("The text file ", filename, - " contains too little data: ", - corpus_size_, " words"); + return errors::InvalidArgument( + "The text file ", filename, + " contains too little data: ", corpus_size_, " words"); } typedef std::pair WordFreq; std::vector ordered; diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc index 0f8d027caadab2dee04d3041ed515a40f22476f3..a6a71fdfaf126410b26766954c0c2fc5b86d003a 100644 --- a/tensorflow/core/kernels/xent_op.cc +++ b/tensorflow/core/kernels/xent_op.cc @@ -30,7 +30,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class SoftmaxXentWithLogitsOp : public OpKernel { @@ -44,8 +44,8 @@ class SoftmaxXentWithLogitsOp : public OpKernel { OP_REQUIRES(context, logits_in.IsSameSize(labels_in), errors::InvalidArgument( "logits and labels must be same size: logits_size=", - logits_in.shape().DebugString(), " labels_size=", - labels_in.shape().DebugString())); + logits_in.shape().DebugString(), + " labels_size=", labels_in.shape().DebugString())); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()), errors::InvalidArgument("logits must be 2-dimensional")); // As we already tested that both inputs have the same shape no need to @@ -72,7 +72,7 @@ class SoftmaxXentWithLogitsOp : public OpKernel { functor(context->eigen_device(), logits_in.matrix(), labels_in.matrix(), scratch.matrix(), loss_out->vec(), back_out->matrix()); - } + } } }; @@ -87,7 +87,7 @@ struct XentFunctorBase { typename TTypes::Vec loss, typename TTypes::Matrix backprop) { XentEigenImpl::Compute(d, logits, labels, scratch, loss, - backprop); + backprop); } }; @@ -97,7 +97,7 @@ struct XentFunctor : XentFunctorBase {}; #ifdef TENSORFLOW_USE_SYCL template struct XentFunctor : XentFunctorBase {}; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace functor #define REGISTER_CPU(T) \ @@ -129,6 +129,6 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") .Device(DEVICE_SYCL) .TypeConstraint("T"), SoftmaxXentWithLogitsOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc index 601704c8a70f0b18c611cf8cd10d140314f61dc4..ba03357cc6ac22d42f9f1cceab6875ef7e49b4c2 100644 --- a/tensorflow/core/kernels/xsmm_conv2d.cc +++ b/tensorflow/core/kernels/xsmm_conv2d.cc @@ -27,9 +27,6 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(); #include #include -#if 0 -#include -#endif #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/blocking_counter.h" @@ -360,7 +357,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, l_tick6 = libxsmm_timer_tick(); #endif -#if 1 BlockingCounter counter(num_threads); for (int i = 0; i < num_threads; ++i) { @@ -371,14 +367,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, }); } counter.Wait(); -#else -#pragma omp parallel - { - chk_libxsmm_err( - libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, omp_get_thread_num()), - "Worker"); - } -#endif #if defined(LIBXSMM_DETAILED_TIMING) l_tick7 = libxsmm_timer_tick(); diff --git a/tensorflow/core/kernels/xsmm_conv2d_test.cc b/tensorflow/core/kernels/xsmm_conv2d_test.cc index e29470124674636a0e125a5cd1b856a467f4c6f0..481f3b7ba46bac42a276d46e60c11f34bc163e3b 100644 --- a/tensorflow/core/kernels/xsmm_conv2d_test.cc +++ b/tensorflow/core/kernels/xsmm_conv2d_test.cc @@ -13,18 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/conv_ops.h" -#include "tensorflow/core/platform/test.h" +#include "include/libxsmm.h" +#include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/conv_ops.h" #include "tensorflow/core/kernels/ops_testutil.h" -#include "include/libxsmm.h" -#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { - typedef struct { int nImg; int nIfm; @@ -49,45 +48,41 @@ typedef struct { int stride_w; } naive_conv_t; - -LIBXSMM_INLINE void naive_copy_NCHW_to_NHWC(const float* nchw, Tensor &nhwc, int N, int H, int W, int C) -{ - LIBXSMM_VLA_DECL(4, const float, input, nchw, C, H, W); +LIBXSMM_INLINE void naive_copy_NCHW_to_NHWC(const float* nchw, Tensor& nhwc, + int N, int H, int W, int C) { + LIBXSMM_VLA_DECL(4, const float, input, nchw, C, H, W); int n, h, w, c; - auto output = nhwc.flat(); - for ( n = 0; n < N; n++ ) { - for ( h = 0; h < H; h++ ) { - for ( w = 0; w < W; w++ ) { - for ( c = 0; c < C; c++ ) { - output(n*H*W*C + h*W*C +w*C + c) = - LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W); + auto output = nhwc.flat(); + for (n = 0; n < N; n++) { + for (h = 0; h < H; h++) { + for (w = 0; w < W; w++) { + for (c = 0; c < C; c++) { + output(n * H * W * C + h * W * C + w * C + c) = + LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W); } } } } } - -LIBXSMM_INLINE void naive_copy_KCRS_to_RSCK(const float* kcrs, Tensor &rsck, int R, int S, int C, int K) -{ - LIBXSMM_VLA_DECL(4, const float, input, kcrs, C, R, S); +LIBXSMM_INLINE void naive_copy_KCRS_to_RSCK(const float* kcrs, Tensor& rsck, + int R, int S, int C, int K) { + LIBXSMM_VLA_DECL(4, const float, input, kcrs, C, R, S); int r, s, c, k; - auto output = rsck.flat(); - - for ( r = 0; r < R; r++ ) { - for ( s = 0; s < S; s++ ) { - for ( c = 0; c < C; c++ ) { - for ( k = 0; k < K; k++ ) { - output(r*S*C*K + s*C*K + c*K + k) = - LIBXSMM_VLA_ACCESS(4, input, k, c, r, s, C, R, S); + auto output = rsck.flat(); + + for (r = 0; r < R; r++) { + for (s = 0; s < S; s++) { + for (c = 0; c < C; c++) { + for (k = 0; k < K; k++) { + output(r * S * C * K + s * C * K + c * K + k) = + LIBXSMM_VLA_ACCESS(4, input, k, c, r, s, C, R, S); } } } } } - - LIBXSMM_INLINE void zero_buf(float* buf, long size) { int i; for (i = 0; i < size; ++i) { @@ -95,52 +90,53 @@ LIBXSMM_INLINE void zero_buf(float* buf, long size) { } } -LIBXSMM_INLINE void copy_buf(Tensor &dst,float *src,long size) { - long i; - auto output = dst.flat(); - for (i = 0; i < size; ++i) - output(i) = src[i]; +LIBXSMM_INLINE void copy_buf(Tensor& dst, float* src, long size) { + long i; + auto output = dst.flat(); + for (i = 0; i < size; ++i) output(i) = src[i]; } -LIBXSMM_INLINE void init_buf(float* buf, long size, int initPos, int initOne) -{ +LIBXSMM_INLINE void init_buf(float* buf, long size, int initPos, int initOne) { int i; zero_buf(buf, size); for (i = 0; i < size; ++i) { - buf[i] = (float)((initOne != 0) ? 1.0 : ((initPos != 0) ? drand48() : (0.05 - drand48()/10.0))); + buf[i] = + (float)((initOne != 0) + ? 1.0 + : ((initPos != 0) ? drand48() : (0.05 - drand48() / 10.0))); } } - - -LIBXSMM_INLINE void naive_conv_fp(naive_conv_t* param, const float* input, float* output, const float* filter) -{ - int nImg = param->nImg; - int nIfm = param->nIfm; - int nOfm = param->nOfm; - int ifhp = param->ifhp; - int ifwp = param->ifwp; - int ofhp = param->ofhp; - int ofwp = param->ofwp; - int ifh = param->ifh; - int ifw = param->ifw; - int ofh = param->ofh; - int ofw = param->ofw; - int pad_h = param->pad_h; - int pad_w = param->pad_w; - int pad_h_in = param->pad_h_in; - int pad_w_in = param->pad_w_in; +LIBXSMM_INLINE void naive_conv_fp(naive_conv_t* param, const float* input, + float* output, const float* filter) { + int nImg = param->nImg; + int nIfm = param->nIfm; + int nOfm = param->nOfm; + int ifhp = param->ifhp; + int ifwp = param->ifwp; + int ofhp = param->ofhp; + int ofwp = param->ofwp; + int ifh = param->ifh; + int ifw = param->ifw; + int ofh = param->ofh; + int ofw = param->ofw; + int pad_h = param->pad_h; + int pad_w = param->pad_w; + int pad_h_in = param->pad_h_in; + int pad_w_in = param->pad_w_in; int pad_h_out = param->pad_h_out; int pad_w_out = param->pad_w_out; - int kh = param->kh; - int kw = param->kw; - int stride_h = param->stride_h; - int stride_w = param->stride_w; + int kh = param->kh; + int kw = param->kw; + int stride_h = param->stride_h; + int stride_w = param->stride_w; /* loop counters */ int img, ofm, ifm, oj, oi, ij, ii, kj, ki; - LIBXSMM_VLA_DECL(4, float, output_t, output + (pad_w_out * ofwp + pad_h_out), nOfm, ofhp, ofwp); - LIBXSMM_VLA_DECL(4, const float, input_t, input + (pad_w_in * ifwp + pad_h_in), nIfm, ifhp, ifwp); + LIBXSMM_VLA_DECL(4, float, output_t, output + (pad_w_out * ofwp + pad_h_out), + nOfm, ofhp, ofwp); + LIBXSMM_VLA_DECL(4, const float, input_t, + input + (pad_w_in * ifwp + pad_h_in), nIfm, ifhp, ifwp); LIBXSMM_VLA_DECL(4, const float, filter_t, filter, nIfm, kh, kw); for (img = 0; img < nImg; ++img) { @@ -151,12 +147,15 @@ LIBXSMM_INLINE void naive_conv_fp(naive_conv_t* param, const float* input, float for (oi = 0; oi < ofw; ++oi) { ii = oi * stride_w - pad_w; for (kj = 0; kj < kh; ++kj) { - if(ij+kj < 0 || ij+kj >= ifh) continue; + if (ij + kj < 0 || ij + kj >= ifh) continue; for (ki = 0; ki < kw; ++ki) { - if(ii+ki < 0 || ii+ki >= ifw) continue; - LIBXSMM_VLA_ACCESS( 4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) += - LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp) - * LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw); + if (ii + ki < 0 || ii + ki >= ifw) continue; + LIBXSMM_VLA_ACCESS(4, output_t, img, ofm, oj, oi, nOfm, ofhp, + ofwp) += + LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij + kj, ii + ki, + nIfm, ifhp, ifwp) * + LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, + kw); } } } @@ -168,134 +167,118 @@ LIBXSMM_INLINE void naive_conv_fp(naive_conv_t* param, const float* input, float void RunXsmmVsGeneric() {} - class XsmmConv2DTest : public OpsTestBase { protected: void MakeOp(int stride) { - TF_CHECK_OK(NodeDefBuilder("xsmm", "Conv2D") - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Attr("strides", {1, stride,stride, 1}) - .Attr("padding", "VALID" ) - .Finalize(node_def())); - + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "VALID") + .Finalize(node_def())); TF_ASSERT_OK(InitOp()); } }; TEST_F(XsmmConv2DTest, Basic) { - MakeOp(1); + MakeOp(1); - // setup scoped allocator, which uses cpu_allocator() for this scope - const libxsmm_tf_allocator tf_allocator; + // setup scoped allocator, which uses cpu_allocator() for this scope + const libxsmm_tf_allocator tf_allocator; - int ifw = 14; /* input width, "W" */ - int ifh = 14; /* input height, "H" */ - int nImg = 32; /* mini-batch size, "N" */ - int nIfm = 64; /* number of input feature maps, "C" */ - int nOfm = 64; /* number of output feature maps, "K" */ - int kh = 3; /* filter height, "R" */ - int kw = 3; /* filter width, "S" */ - int pad = 0; /* padding in output */ - int stride = 1; /* stride when accessing inputs */ + int ifw = 14; /* input width, "W" */ + int ifh = 14; /* input height, "H" */ + int nImg = 32; /* mini-batch size, "N" */ + int nIfm = 64; /* number of input feature maps, "C" */ + int nOfm = 64; /* number of output feature maps, "K" */ + int kh = 3; /* filter height, "R" */ + int kw = 3; /* filter width, "S" */ + int pad = 0; /* padding in output */ + int stride = 1; /* stride when accessing inputs */ + int stride_w = stride; + int stride_h = stride; + int pad_h = pad; + int pad_w = pad; - int stride_w = stride; - int stride_h = stride; - int pad_h = pad; - int pad_w = pad; + int pad_h_in = pad_h; + int pad_w_in = pad_w; - int pad_h_in = pad_h; - int pad_w_in = pad_w; - - int pad_h_out = 0; - int pad_w_out = 0; + int pad_h_out = 0; + int pad_w_out = 0; /* deriving some values for naive code */ - int ofh = (ifh + 2 * pad_h - kh) / stride_h + 1; - int ofw = (ifw + 2 * pad_w - kw) / stride_w + 1; - int ifhp = ifh + 2 * pad_h_in; - int ifwp = ifw + 2 * pad_w_in; - int ofhp = ofh + 2 * pad_h_out; - int ofwp = ofw + 2 * pad_w_out; - - - //Initialization of Filter and Image - - /* allocate data */ - float *naive_input = (float*)libxsmm_aligned_scratch( nImg*nIfm*ifhp*ifwp*sizeof(float), 2097152); - float *naive_output = (float*)libxsmm_aligned_scratch( nImg*nOfm*ofhp*ofwp*sizeof(float), 2097152); - float *naive_filter = (float*)libxsmm_aligned_scratch( nOfm*nIfm*kh*kw* sizeof(float), 2097152); - /* initialize data */ - init_buf(naive_input, nImg*nIfm*ifhp*ifwp, 0, 0); - zero_buf(naive_output, nImg*nOfm*ofhp*ofwp); - init_buf(naive_filter, nOfm*nIfm*kh*kw, 0, 0); - - - Tensor image(DT_FLOAT, - {nImg, ifhp, ifwp, nIfm}); - - - Tensor filter(DT_FLOAT, {kh,kw,nIfm,nOfm}); - - - naive_copy_NCHW_to_NHWC(naive_input, image, nImg, ifhp, ifwp, nIfm); - naive_copy_KCRS_to_RSCK(naive_filter, filter, kh, kw, nIfm, nOfm); - - - //Run naive convolution - - naive_conv_t naive_param; - - naive_param.nImg = nImg; - naive_param.nIfm = nIfm; - naive_param.nOfm = nOfm; - naive_param.ifhp = ifhp; - naive_param.ifwp = ifwp; - naive_param.ofhp = ofhp; - naive_param.ofwp = ofwp; - naive_param.ifh = ifh; - naive_param.ifw = ifw; - naive_param.ofh = ofh; - naive_param.ofw = ofw; - naive_param.pad_h = pad_h; - naive_param.pad_w = pad_w; - naive_param.pad_h_in = pad_h_in; - naive_param.pad_w_in = pad_w_in; - naive_param.pad_h_out = pad_h_out; - naive_param.pad_w_out = pad_w_out; - naive_param.kh = kh; - naive_param.kw = kw; - naive_param.stride_h = stride_h; - naive_param.stride_w = stride_w; - - - naive_conv_fp(&naive_param, naive_input, naive_output, naive_filter); - - - - AddInputFromArray(image.shape(), image.flat()); - AddInputFromArray(filter.shape(), filter.flat()); - - - - //Run Op (TF) - TF_ASSERT_OK(RunOpKernel()); - - // Check the output. - Tensor expected(DT_FLOAT, {nImg,ofhp,ofwp, nOfm}); - naive_copy_NCHW_to_NHWC(naive_output, expected, nImg, ofhp, ofwp, nOfm); - - - test::ExpectTensorNear(expected, *GetOutput(0), 1e-5); - libxsmm_free(naive_input); - libxsmm_free(naive_output); - libxsmm_free(naive_filter); - - - + int ofh = (ifh + 2 * pad_h - kh) / stride_h + 1; + int ofw = (ifw + 2 * pad_w - kw) / stride_w + 1; + int ifhp = ifh + 2 * pad_h_in; + int ifwp = ifw + 2 * pad_w_in; + int ofhp = ofh + 2 * pad_h_out; + int ofwp = ofw + 2 * pad_w_out; + + // Initialization of Filter and Image + + /* allocate data */ + float* naive_input = (float*)libxsmm_aligned_scratch( + nImg * nIfm * ifhp * ifwp * sizeof(float), 2097152); + float* naive_output = (float*)libxsmm_aligned_scratch( + nImg * nOfm * ofhp * ofwp * sizeof(float), 2097152); + float* naive_filter = (float*)libxsmm_aligned_scratch( + nOfm * nIfm * kh * kw * sizeof(float), 2097152); + /* initialize data */ + init_buf(naive_input, nImg * nIfm * ifhp * ifwp, 0, 0); + zero_buf(naive_output, nImg * nOfm * ofhp * ofwp); + init_buf(naive_filter, nOfm * nIfm * kh * kw, 0, 0); + + Tensor image(DT_FLOAT, {nImg, ifhp, ifwp, nIfm}); + + Tensor filter(DT_FLOAT, {kh, kw, nIfm, nOfm}); + + naive_copy_NCHW_to_NHWC(naive_input, image, nImg, ifhp, ifwp, nIfm); + naive_copy_KCRS_to_RSCK(naive_filter, filter, kh, kw, nIfm, nOfm); + + // Run naive convolution + + naive_conv_t naive_param; + + naive_param.nImg = nImg; + naive_param.nIfm = nIfm; + naive_param.nOfm = nOfm; + naive_param.ifhp = ifhp; + naive_param.ifwp = ifwp; + naive_param.ofhp = ofhp; + naive_param.ofwp = ofwp; + naive_param.ifh = ifh; + naive_param.ifw = ifw; + naive_param.ofh = ofh; + naive_param.ofw = ofw; + naive_param.pad_h = pad_h; + naive_param.pad_w = pad_w; + naive_param.pad_h_in = pad_h_in; + naive_param.pad_w_in = pad_w_in; + naive_param.pad_h_out = pad_h_out; + naive_param.pad_w_out = pad_w_out; + naive_param.kh = kh; + naive_param.kw = kw; + naive_param.stride_h = stride_h; + naive_param.stride_w = stride_w; + + naive_conv_fp(&naive_param, naive_input, naive_output, naive_filter); + + AddInputFromArray(image.shape(), image.flat()); + AddInputFromArray(filter.shape(), filter.flat()); + + // Run Op (TF) + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(DT_FLOAT, {nImg, ofhp, ofwp, nOfm}); + naive_copy_NCHW_to_NHWC(naive_output, expected, nImg, ofhp, ofwp, nOfm); + + test::ExpectTensorNear(expected, *GetOutput(0), 1e-5); + libxsmm_free(naive_input); + libxsmm_free(naive_output); + libxsmm_free(naive_filter); } /* @@ -325,7 +308,8 @@ TEST(XsmmConv2DTest, Basic) { desc.threads = num_threads; desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT; desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC; - desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK; + desc.filter_format = +LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK; desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; desc.options = LIBXSMM_DNN_CONV_OPTION_NONE; desc.datatype = LIBXSMM_DNN_DATATYPE_F32; diff --git a/tensorflow/core/lib/core/status.h b/tensorflow/core/lib/core/status.h index 58a50a70c26a63a9edd55349e2253a9ace16f1f2..49f74ff47fbc839c84465ba86e85b38cb3bd38ec 100644 --- a/tensorflow/core/lib/core/status.h +++ b/tensorflow/core/lib/core/status.h @@ -131,7 +131,7 @@ inline tensorflow::string* TfCheckOpHelper(::tensorflow::Status v, while (auto _result = ::tensorflow::TfCheckOpHelper(val, #val)) \ LOG(level) << *(_result) -#define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL) +#define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL) #define TF_QCHECK_OK(val) TF_DO_CHECK_OK(val, QFATAL) // DEBUG only version of TF_CHECK_OK. Compiler still parses 'val' even in opt diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc index 2b10ebeaf7cbed4a8466a69898d6d4d6660ed5cb..e55ed79d36cd2db7a6f6b19f3579f47e73b4b2d9 100644 --- a/tensorflow/core/lib/core/threadpool.cc +++ b/tensorflow/core/lib/core/threadpool.cc @@ -66,7 +66,9 @@ struct EigenEnvironment { } return Task{ std::unique_ptr(new TaskImpl{ - std::move(f), Context(ContextKind::kThread), id, + std::move(f), + Context(ContextKind::kThread), + id, }), }; } diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc index 49ddb16645c32a82d90eafa5f550b8887ac84b79..627ef5a892a35ec43d0c31220dcf046b4b8eda55 100644 --- a/tensorflow/core/lib/core/threadpool_test.cc +++ b/tensorflow/core/lib/core/threadpool_test.cc @@ -97,8 +97,8 @@ TEST(ThreadPool, ParallelForWithWorkerId) { } pool.ParallelForWithWorkerId( kWorkItems, kHugeCost, - [&threads_running, &work, num_threads]( - int64 begin, int64 end, int64 id) { + [&threads_running, &work, num_threads](int64 begin, int64 end, + int64 id) { // Store true for the current thread, and assert that another thread // is not running with the same id. ASSERT_LE(0, id); diff --git a/tensorflow/core/lib/db/sqlite.h b/tensorflow/core/lib/db/sqlite.h index 0faa458f1d692a103099d5b05d0400944ffdaad7..efe97f78d259199a74bf5e830f70de657d1cd679 100644 --- a/tensorflow/core/lib/db/sqlite.h +++ b/tensorflow/core/lib/db/sqlite.h @@ -18,12 +18,12 @@ limitations under the License. #include #include "sqlite3.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/lib/core/refcount.h" /// TensorFlow SQLite Veneer /// @@ -121,10 +121,7 @@ class LOCKABLE Sqlite : public core::RefCounted { Sqlite(sqlite3* db, sqlite3_stmt* begin, sqlite3_stmt* commit, sqlite3_stmt* rollback) noexcept - : db_(db), - begin_(begin), - commit_(commit), - rollback_(rollback) {} + : db_(db), begin_(begin), commit_(commit), rollback_(rollback) {} sqlite3* const db_; sqlite3_stmt* const begin_; @@ -233,7 +230,8 @@ class SqliteStatement { /// freed until this statement is Reset() or finalized. void BindText(int parameter, const StringPiece& text) { Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), - SQLITE_TRANSIENT, SQLITE_UTF8), parameter); + SQLITE_TRANSIENT, SQLITE_UTF8), + parameter); size_ += text.size(); } void BindText(const char* parameter, const StringPiece& text) { @@ -241,7 +239,8 @@ class SqliteStatement { } void BindTextUnsafe(int parameter, const StringPiece& text) { Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), - SQLITE_STATIC, SQLITE_UTF8), parameter); + SQLITE_STATIC, SQLITE_UTF8), + parameter); size_ += text.size(); } void BindTextUnsafe(const char* parameter, const StringPiece& text) { @@ -254,7 +253,8 @@ class SqliteStatement { /// freed until this statement is Reset() or finalized. void BindBlob(int parameter, const StringPiece& blob) { Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), - SQLITE_TRANSIENT), parameter); + SQLITE_TRANSIENT), + parameter); size_ += blob.size(); } void BindBlob(const char* parameter, const StringPiece& blob) { @@ -262,7 +262,8 @@ class SqliteStatement { } void BindBlobUnsafe(int parameter, const StringPiece& blob) { Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), - SQLITE_STATIC), parameter); + SQLITE_STATIC), + parameter); size_ += blob.size(); } void BindBlobUnsafe(const char* parameter, const StringPiece& text) { @@ -320,9 +321,7 @@ class SqliteStatement { /// \brief Move constructor, after which is reset to empty. SqliteStatement(SqliteStatement&& other) noexcept - : db_(other.db_), - stmt_(other.stmt_), - bind_error_(other.bind_error_) { + : db_(other.db_), stmt_(other.stmt_), bind_error_(other.bind_error_) { other.db_ = nullptr; other.stmt_ = nullptr; other.bind_error_ = SQLITE_OK; diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc index c9c76ea5f2cd30b8abe7e3c9766ce4946ca25200..1e88323d017bec4b2705c6dbb19005efb8adbaa9 100644 --- a/tensorflow/core/lib/db/sqlite_test.cc +++ b/tensorflow/core/lib/db/sqlite_test.cc @@ -33,9 +33,7 @@ class SqliteTest : public ::testing::Test { db_->PrepareOrDie("CREATE TABLE T (a BLOB, b BLOB)").StepAndResetOrDie(); } - void TearDown() override { - db_->Unref(); - } + void TearDown() override { db_->Unref(); } Sqlite* db_; bool is_done_; @@ -213,7 +211,7 @@ TEST_F(SqliteTest, BindFailed) { Status s = stmt.StepOnce(); EXPECT_NE(string::npos, s.error_message().find("INSERT INTO T (a) VALUES (123)")) - << s.error_message(); + << s.error_message(); } TEST_F(SqliteTest, SnappyExtension) { @@ -226,7 +224,7 @@ TEST_F(SqliteTest, SnappyBinaryCompatibility) { EXPECT_EQ( "today is the end of the republic", db_->PrepareOrDie("SELECT UNSNAP(X'03207C746F6461792069732074686520656E64" - "206F66207468652072657075626C6963')") + "206F66207468652072657075626C6963')") .StepOnceOrDie() .ColumnString(0)); } diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc index 0f6999c88fca3fd7ab91d2f3e28348e22d106f45..9a5215320f58d10c22872c2837e882bed82f5b52 100644 --- a/tensorflow/core/lib/gif/gif_io.cc +++ b/tensorflow/core/lib/gif/gif_io.cc @@ -16,6 +16,7 @@ limitations under the License. // Functions to read images in GIF format. #include "tensorflow/core/lib/gif/gif_io.h" +#include #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/gif.h" @@ -44,6 +45,14 @@ int input_callback(GifFileType* gif_file, GifByteType* buf, int size) { return 0; } +static const char* GifErrorStringNonNull(int error_code) { + const char* error_string = GifErrorString(error_code); + if (error_string == nullptr) { + return "Unknown error"; + } + return error_string; +} + uint8* Decode(const void* srcdata, int datasize, const std::function& allocate_output, string* error_string) { @@ -55,17 +64,17 @@ uint8* Decode(const void* srcdata, int datasize, int error_code = D_GIF_SUCCEEDED; if (gif_file && DGifCloseFile(gif_file, &error_code) != GIF_OK) { LOG(WARNING) << "Fail to close gif file, reason: " - << GifErrorString(error_code); + << GifErrorStringNonNull(error_code); } }); if (error_code != D_GIF_SUCCEEDED) { *error_string = strings::StrCat("failed to open gif file: ", - GifErrorString(error_code)); + GifErrorStringNonNull(error_code)); return nullptr; } if (DGifSlurp(gif_file) != GIF_OK) { *error_string = strings::StrCat("failed to slurp gif file: ", - GifErrorString(gif_file->Error)); + GifErrorStringNonNull(gif_file->Error)); return nullptr; } if (gif_file->ImageCount <= 0) { @@ -81,23 +90,52 @@ uint8* Decode(const void* srcdata, int datasize, uint8* const dstdata = allocate_output(num_frames, width, height, channel); if (!dstdata) return nullptr; for (int k = 0; k < num_frames; k++) { + uint8* this_dst = dstdata + k * width * channel * height; + SavedImage* this_image = &gif_file->SavedImages[k]; GifImageDesc* img_desc = &this_image->ImageDesc; + + int imgLeft = img_desc->Left; + int imgTop = img_desc->Top; + int imgRight = img_desc->Left + img_desc->Width; + int imgBottom = img_desc->Top + img_desc->Height; + if (img_desc->Left != 0 || img_desc->Top != 0 || img_desc->Width != width || img_desc->Height != height) { - *error_string = strings::StrCat("can't process optimized gif"); - return nullptr; + // If the first frame does not fill the entire canvas then return error. + if (k == 0) { + *error_string = + strings::StrCat("the first frame does not fill the canvas"); + return nullptr; + } + // Otherwise previous frame will be reused to fill the unoccupied canvas. + imgLeft = std::max(imgLeft, 0); + imgTop = std::max(imgTop, 0); + imgRight = std::min(imgRight, width); + imgBottom = std::min(imgBottom, height); + + uint8* last_dst = dstdata + (k - 1) * width * channel * height; + for (int i = 0; i < height; ++i) { + uint8* p_dst = this_dst + i * width * channel; + uint8* l_dst = last_dst + i * width * channel; + for (int j = 0; j < width; ++j) { + p_dst[j * channel + 0] = l_dst[j * channel + 0]; + p_dst[j * channel + 1] = l_dst[j * channel + 1]; + p_dst[j * channel + 2] = l_dst[j * channel + 2]; + } + } } ColorMapObject* color_map = this_image->ImageDesc.ColorMap ? this_image->ImageDesc.ColorMap : gif_file->SColorMap; - uint8* this_dst = dstdata + k * width * channel * height; - for (int i = 0; i < height; ++i) { + for (int i = imgTop; i < imgBottom; ++i) { uint8* p_dst = this_dst + i * width * channel; - for (int j = 0; j < width; ++j) { - GifByteType color_index = this_image->RasterBits[i * width + j]; + for (int j = imgLeft; j < imgRight; ++j) { + GifByteType color_index = + this_image->RasterBits[(i - img_desc->Top) * (img_desc->Width) + + (j - img_desc->Left)]; const GifColorType& gif_color = color_map->Colors[color_index]; p_dst[j * channel + 0] = gif_color.Red; p_dst[j * channel + 1] = gif_color.Green; diff --git a/tensorflow/core/lib/gtl/cleanup.h b/tensorflow/core/lib/gtl/cleanup.h index 6053e986402598568299d1756d23068693c193c8..6bd60ca482430cf13f4f076badf460cf2e1d593b 100644 --- a/tensorflow/core/lib/gtl/cleanup.h +++ b/tensorflow/core/lib/gtl/cleanup.h @@ -55,22 +55,21 @@ namespace gtl { template class Cleanup { public: - Cleanup() - : released_(true), f_() {} + Cleanup() : released_(true), f_() {} template - explicit Cleanup(G&& f) // NOLINT + explicit Cleanup(G&& f) // NOLINT : f_(std::forward(f)) {} // NOLINT(build/c++11) Cleanup(Cleanup&& src) // NOLINT - : released_(src.is_released()), f_(src.release()) { } + : released_(src.is_released()), f_(src.release()) {} // Implicitly move-constructible from any compatible Cleanup. // The source will be released as if src.release() were called. // A moved-from Cleanup can be safely destroyed or reassigned. template Cleanup(Cleanup&& src) // NOLINT - : released_(src.is_released()), f_(src.release()) { } + : released_(src.is_released()), f_(src.release()) {} // Assignment to a Cleanup object behaves like destroying it // and making a new one in its place, analogous to unique_ptr @@ -102,8 +101,8 @@ class Cleanup { F f_; }; -template ::type> +template ::type> TF_MUST_USE_RESULT Cleanup MakeCleanup(F&& f) { return Cleanup(std::forward(f)); } diff --git a/tensorflow/core/lib/gtl/cleanup_test.cc b/tensorflow/core/lib/gtl/cleanup_test.cc index bd151cb2ab1c8a830eb1bd9546ab452d05c6c20c..a86ffd5fe284485f15fa824026e8d79f5191a384 100644 --- a/tensorflow/core/lib/gtl/cleanup_test.cc +++ b/tensorflow/core/lib/gtl/cleanup_test.cc @@ -65,15 +65,14 @@ TEST(CleanupTest, Release) { TEST(FinallyTest, TypeErasedWithoutFactory) { string s = "active"; { - AnyCleanup s_cleaner([&s]{ s.append(" clean"); }); + AnyCleanup s_cleaner([&s] { s.append(" clean"); }); EXPECT_EQ("active", s); } EXPECT_EQ("active clean", s); } struct Appender { - Appender(string* s, const string& msg) - : s_(s), msg_(msg) {} + Appender(string* s, const string& msg) : s_(s), msg_(msg) {} void operator()() const { s_->append(msg_); } string* s_; string msg_; @@ -163,7 +162,12 @@ class CleanupReferenceTest : public ::testing::Test { int* i; F(int* cp, int* i) : cp(cp), i(i) {} F(const F& o) : cp(o.cp), i(o.i) { ++*cp; } - F& operator=(const F& o) { cp = o.cp; i = o.i; ++*cp; return *this; } + F& operator=(const F& o) { + cp = o.cp; + i = o.i; + ++*cp; + return *this; + } F(F&&) = default; F& operator=(F&&) = default; void operator()() const { ++*i; } @@ -279,7 +283,7 @@ BENCHMARK(BM_AnyCleanup); void BM_AnyCleanupNoFactory(int iters) { while (iters--) { - AnyCleanup fin([]{Incr();}); + AnyCleanup fin([] { Incr(); }); } } BENCHMARK(BM_AnyCleanupNoFactory); diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h index d6e5d9effa794c46b7aa98691bb993dbd7e764c8..6e3cb2206d9658a3b0bc24b506049f503ae304ed 100644 --- a/tensorflow/core/lib/gtl/inlined_vector.h +++ b/tensorflow/core/lib/gtl/inlined_vector.h @@ -31,12 +31,12 @@ limitations under the License. #ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ #define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ -#include #include #include #include #include #include +#include #include #include #include @@ -407,7 +407,7 @@ class InlinedVector { }; // 2) Construct a T with args at not-yet-initialized memory pointed by dst. struct Construct { - template + template void operator()(T* dst, Args&&... args) const { new (dst) T(std::forward(args)...); } diff --git a/tensorflow/core/lib/gtl/int_type.h b/tensorflow/core/lib/gtl/int_type.h index 647fc81aa7e4925d1d2b74b82146b18b0c17a4a9..af3e50ad78ff9d07bc0e8e79a5ff7cb3d1aacbfe 100644 --- a/tensorflow/core/lib/gtl/int_type.h +++ b/tensorflow/core/lib/gtl/int_type.h @@ -255,13 +255,13 @@ class IntType { value_ op arg_value; \ return *this; \ } - INT_TYPE_ASSIGNMENT_OP(+= ); - INT_TYPE_ASSIGNMENT_OP(-= ); - INT_TYPE_ASSIGNMENT_OP(*= ); - INT_TYPE_ASSIGNMENT_OP(/= ); - INT_TYPE_ASSIGNMENT_OP(<<= ); // NOLINT - INT_TYPE_ASSIGNMENT_OP(>>= ); // NOLINT - INT_TYPE_ASSIGNMENT_OP(%= ); + INT_TYPE_ASSIGNMENT_OP(+=); + INT_TYPE_ASSIGNMENT_OP(-=); + INT_TYPE_ASSIGNMENT_OP(*=); + INT_TYPE_ASSIGNMENT_OP(/=); + INT_TYPE_ASSIGNMENT_OP(<<=); // NOLINT + INT_TYPE_ASSIGNMENT_OP(>>=); // NOLINT + INT_TYPE_ASSIGNMENT_OP(%=); #undef INT_TYPE_ASSIGNMENT_OP ThisType& operator=(ValueType arg_value) { @@ -314,10 +314,10 @@ std::ostream& operator<<(std::ostream& os, // NOLINT INT_TYPE_ARITHMETIC_OP(+); INT_TYPE_ARITHMETIC_OP(-); INT_TYPE_ARITHMETIC_OP(*); -INT_TYPE_ARITHMETIC_OP(/ ); -INT_TYPE_ARITHMETIC_OP(<< ); // NOLINT -INT_TYPE_ARITHMETIC_OP(>> ); // NOLINT -INT_TYPE_ARITHMETIC_OP(% ); +INT_TYPE_ARITHMETIC_OP(/); +INT_TYPE_ARITHMETIC_OP(<<); // NOLINT +INT_TYPE_ARITHMETIC_OP(>>); // NOLINT +INT_TYPE_ARITHMETIC_OP(%); #undef INT_TYPE_ARITHMETIC_OP // -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------ @@ -345,12 +345,12 @@ INT_TYPE_ARITHMETIC_OP(% ); IntType id) { \ return val op id.value(); \ } -INT_TYPE_COMPARISON_OP(== ); // NOLINT -INT_TYPE_COMPARISON_OP(!= ); // NOLINT -INT_TYPE_COMPARISON_OP(< ); // NOLINT -INT_TYPE_COMPARISON_OP(<= ); // NOLINT -INT_TYPE_COMPARISON_OP(> ); // NOLINT -INT_TYPE_COMPARISON_OP(>= ); // NOLINT +INT_TYPE_COMPARISON_OP(==); // NOLINT +INT_TYPE_COMPARISON_OP(!=); // NOLINT +INT_TYPE_COMPARISON_OP(<); // NOLINT +INT_TYPE_COMPARISON_OP(<=); // NOLINT +INT_TYPE_COMPARISON_OP(>); // NOLINT +INT_TYPE_COMPARISON_OP(>=); // NOLINT #undef INT_TYPE_COMPARISON_OP } // namespace gtl diff --git a/tensorflow/core/lib/gtl/int_type_test.cc b/tensorflow/core/lib/gtl/int_type_test.cc index d3c405d9acdb221f465e98d957ba55ba6bc63f57..61d364017cb90933e8e9af7e800db4a6988d8442 100644 --- a/tensorflow/core/lib/gtl/int_type_test.cc +++ b/tensorflow/core/lib/gtl/int_type_test.cc @@ -42,7 +42,8 @@ class IntTypeTest : public ::testing::Test { // All tests below will be executed on all supported IntTypes. typedef ::testing::Types SupportedIntTypes; + Int64_IT, UInt64_IT, Long_IT> + SupportedIntTypes; TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes); @@ -232,7 +233,8 @@ TYPED_TEST(IntTypeTest, TestOperators) { TYPED_TEST(IntTypeTest, TestHashFunctor) { std::unordered_map map; + typename TestFixture::T::Hasher> + map; typename TestFixture::T a(0); map[a] = 'c'; EXPECT_EQ('c', map[a]); diff --git a/tensorflow/core/lib/gtl/optional.h b/tensorflow/core/lib/gtl/optional.h index 2ff8b9c7d1adbbc206e0429142389e9730efa33c..4ee3f88d186562e5d3261bc634952fb53b4f5774 100644 --- a/tensorflow/core/lib/gtl/optional.h +++ b/tensorflow/core/lib/gtl/optional.h @@ -478,7 +478,7 @@ class optional : private internal_optional::optional_data, return *this; } - // Copy assigment, standard semantics. + // Copy assignment, standard semantics. optional& operator=(const optional& src) = default; // Move assignment, standard semantics. @@ -593,12 +593,12 @@ class optional : private internal_optional::optional_data, assert(this->engaged_); return this->pointer(); } - constexpr const T& operator*() const & { return reference(); } + constexpr const T& operator*() const& { return reference(); } T& operator*() & { assert(this->engaged_); return reference(); } - constexpr const T&& operator*() const && { return std::move(reference()); } + constexpr const T&& operator*() const&& { return std::move(reference()); } T&& operator*() && { assert(this->engaged_); return std::move(reference()); @@ -621,7 +621,7 @@ class optional : private internal_optional::optional_data, // Use `opt.value()` to get a reference to underlying value. The constness // and lvalue/rvalue-ness of `opt` is preserved to the view of the T // subobject. - const T& value() const & { + const T& value() const& { CHECK(*this) << "Bad optional access"; return reference(); } @@ -633,7 +633,7 @@ class optional : private internal_optional::optional_data, CHECK(*this) << "Bad optional access"; return std::move(reference()); } - const T&& value() const && { // NOLINT(build/c++11) + const T&& value() const&& { // NOLINT(build/c++11) CHECK(*this) << "Bad optional access"; return std::move(reference()); } @@ -641,7 +641,7 @@ class optional : private internal_optional::optional_data, // Use `opt.value_or(val)` to get either the value of T or the given default // `val` in the empty case. template - constexpr T value_or(U&& v) const & { + constexpr T value_or(U&& v) const& { return static_cast(*this) ? **this : static_cast(std::forward(v)); } @@ -656,8 +656,8 @@ class optional : private internal_optional::optional_data, constexpr const T& reference() const { return *this->pointer(); } T& reference() { return *(this->pointer()); } - // T constraint checks. You can't have an optional of nullopt_t, in_place_t or - // a reference. + // T constraint checks. You can't have an optional of nullopt_t, in_place_t + // or a reference. static_assert( !std::is_same::type>::value, "optional is not allowed."); diff --git a/tensorflow/core/lib/gtl/optional_test.cc b/tensorflow/core/lib/gtl/optional_test.cc index 547bee7b75f3d05e290ec7d53d889ff7e82794a9..12b5bbc60be9961a5f852210c42479b2cd48ea92 100644 --- a/tensorflow/core/lib/gtl/optional_test.cc +++ b/tensorflow/core/lib/gtl/optional_test.cc @@ -24,17 +24,29 @@ limitations under the License. namespace tensorflow { namespace { -using tensorflow::gtl::optional; -using tensorflow::gtl::nullopt; -using tensorflow::gtl::nullopt_t; using tensorflow::gtl::in_place; using tensorflow::gtl::in_place_t; using tensorflow::gtl::make_optional; +using tensorflow::gtl::nullopt; +using tensorflow::gtl::nullopt_t; +using tensorflow::gtl::optional; -template string TypeQuals(T&) { return "&"; } -template string TypeQuals(T&&) { return "&&"; } -template string TypeQuals(const T&) { return "c&"; } -template string TypeQuals(const T&&) { return "c&&"; } +template +string TypeQuals(T&) { + return "&"; +} +template +string TypeQuals(T&&) { + return "&&"; +} +template +string TypeQuals(const T&) { + return "c&"; +} +template +string TypeQuals(const T&&) { + return "c&&"; +} struct StructorListener { int construct0 = 0; diff --git a/tensorflow/core/lib/gtl/top_n_test.cc b/tensorflow/core/lib/gtl/top_n_test.cc index fae85570dc071568a53abcb72fea6ffc22a465ea..ba30c072a9033073a7439f60dbfa3402dbfc5923 100644 --- a/tensorflow/core/lib/gtl/top_n_test.cc +++ b/tensorflow/core/lib/gtl/top_n_test.cc @@ -28,10 +28,10 @@ limitations under the License. namespace { +using tensorflow::string; using tensorflow::gtl::TopN; using tensorflow::random::PhiloxRandom; using tensorflow::random::SimplePhilox; -using tensorflow::string; // Move the contents from an owned raw pointer, returning by value. // Objects are easier to manage by value. diff --git a/tensorflow/core/lib/io/compression.cc b/tensorflow/core/lib/io/compression.cc index c12de98e40105907460f74f967e20aa41bdb0ceb..0d25bca9eccf2b28800a288858ffbc0caeb2dbd3 100644 --- a/tensorflow/core/lib/io/compression.cc +++ b/tensorflow/core/lib/io/compression.cc @@ -22,6 +22,6 @@ namespace compression { const char kNone[] = ""; const char kGzip[] = "GZIP"; -} -} -} +} // namespace compression +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/compression.h b/tensorflow/core/lib/io/compression.h index ef90c60a3a411cdc94a9f92522116db340e04f1b..4d8e7788cad823e0e79a4e9567c6f17a3d9259cf 100644 --- a/tensorflow/core/lib/io/compression.h +++ b/tensorflow/core/lib/io/compression.h @@ -23,8 +23,8 @@ namespace compression { extern const char kNone[]; extern const char kGzip[]; -} -} -} +} // namespace compression +} // namespace io +} // namespace tensorflow #endif // TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index 403c82818ef3293a1dc027d362eb766906d0e94a..254fdf115da132343b8e6f176e67672a11281cd0 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -49,7 +49,7 @@ RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions( #endif // IS_SLIM_BUILD } else if (compression_type != compression::kNone) { LOG(ERROR) << "Unsupported compression_type:" << compression_type - << ". No comprression will be used."; + << ". No compression will be used."; } return options; } @@ -207,7 +207,7 @@ Status RecordReader::SkipNBytes(uint64 offset) { } } return Status::OK(); -} +} // namespace io SequentialRecordReader::SequentialRecordReader( RandomAccessFile* file, const RecordReaderOptions& options) diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 3657243c5d38a2076c1ca2c2e5f31b488b5a281b..ebc56482699948974ad434b6ea76fe26e1a4a5c5 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -49,7 +49,7 @@ RecordWriterOptions RecordWriterOptions::CreateRecordWriterOptions( #endif // IS_SLIM_BUILD } else if (compression_type != compression::kNone) { LOG(ERROR) << "Unsupported compression_type:" << compression_type - << ". No comprression will be used."; + << ". No compression will be used."; } return options; } diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc index 507c26a63ff587809e80739f8d015d1adcc3b21d..b7e51256a22b0d84e734e2a036a184b3adc3e547 100644 --- a/tensorflow/core/lib/io/recordio_test.cc +++ b/tensorflow/core/lib/io/recordio_test.cc @@ -218,8 +218,8 @@ TEST_F(RecordioTest, RandomRead) { // Tests of all the error paths in log_reader.cc follow: static void AssertHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain " - << expected; + EXPECT_TRUE(StringPiece(s).contains(expected)) + << s << " does not contain " << expected; } TEST_F(RecordioTest, ReadError) { diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc index 354c819b090ce5e04047f13d2ff19441a499d770..cba473927dd1fce30bbe690b4bfda1e382ca12c0 100644 --- a/tensorflow/core/lib/png/png_io.cc +++ b/tensorflow/core/lib/png/png_io.cc @@ -90,11 +90,8 @@ void WarningHandler(png_structp png_ptr, png_const_charp msg) { void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) { DecodeContext* const ctx = bit_cast(png_get_io_ptr(png_ptr)); if (static_cast(ctx->data_left) < length) { - if (!ctx->error_condition) { - VLOG(1) << "PNG read decoding error"; - ctx->error_condition = true; - } memset(data, 0, length); + png_error(png_ptr, "More bytes requested to read than available"); } else { memcpy(data, ctx->data, length); ctx->data += length; @@ -197,8 +194,8 @@ bool CommonInitDecode(StringPiece png_string, int desired_channels, int desired_channel_bits, DecodeContext* context) { CHECK(desired_channel_bits == 8 || desired_channel_bits == 16) << "desired_channel_bits = " << desired_channel_bits; - CHECK(0 <= desired_channels && desired_channels <= 4) << "desired_channels = " - << desired_channels; + CHECK(0 <= desired_channels && desired_channels <= 4) + << "desired_channels = " << desired_channels; context->error_condition = false; context->channels = desired_channels; context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context, diff --git a/tensorflow/core/lib/random/philox_random_test_utils.h b/tensorflow/core/lib/random/philox_random_test_utils.h index f4bb087e107e10f90196a807c03ed2407d9d1ad6..6c29ae6b6a224d9c0369172bbf21af465ad53a19 100644 --- a/tensorflow/core/lib/random/philox_random_test_utils.h +++ b/tensorflow/core/lib/random/philox_random_test_utils.h @@ -35,8 +35,8 @@ void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p, int64 size) { const int granularity = Distribution::kResultElementCount; - CHECK(size % granularity == 0) << " size: " << size - << " granularity: " << granularity; + CHECK(size % granularity == 0) + << " size: " << size << " granularity: " << granularity; Distribution dist; for (int i = 0; i < size; i += granularity) { diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h index 0e281403f8748ffbb7dbfac888cd2303c0a7253f..3fe1f9bc6cf06158df4811eaa177988b60890006 100644 --- a/tensorflow/core/lib/random/random_distributions.h +++ b/tensorflow/core/lib/random/random_distributions.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ #define _USE_MATH_DEFINES -#include #include +#include #undef _USE_MATH_DEFINES #include @@ -27,7 +27,6 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/lib/random/philox_random.h" - namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc index 90d0dba4a7793f51472b2e5434489448eb40a498..85d68f456e1e27b7a62315f2b0a962843da87d52 100644 --- a/tensorflow/core/lib/random/random_distributions_test.cc +++ b/tensorflow/core/lib/random/random_distributions_test.cc @@ -45,8 +45,8 @@ void FillRandomsWithSingles(PhiloxRandom gen, int64 size) { int granularity = Distribution::kResultElementCount; - CHECK(size % granularity == 0) << " size: " << size - << " granularity: " << granularity; + CHECK(size % granularity == 0) + << " size: " << size << " granularity: " << granularity; SingleSampleAdapter single_samples(&gen); diff --git a/tensorflow/core/lib/strings/ordered_code.cc b/tensorflow/core/lib/strings/ordered_code.cc index af9a15125948d8ed390e5873f3677527ebddea8e..ef90050b4f628ab65c1dd939ba358fec714c95b5 100644 --- a/tensorflow/core/lib/strings/ordered_code.cc +++ b/tensorflow/core/lib/strings/ordered_code.cc @@ -472,7 +472,8 @@ void OrderedCode::WriteSignedNumIncreasing(string* dest, int64 val) { // buf = val in network byte order, sign extended to 10 bytes const char sign_byte = val < 0 ? '\xff' : '\0'; char buf[10] = { - sign_byte, sign_byte, + sign_byte, + sign_byte, }; StoreBigEndian64(buf + 2, val); static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch"); diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h index 5835b0101d9ede219a71acf554c5928e4b624ce7..2bc14945cd0413751003c03c7f5255c300790321 100644 --- a/tensorflow/core/lib/strings/strcat.h +++ b/tensorflow/core/lib/strings/strcat.h @@ -126,7 +126,7 @@ class AlphaNum { : piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {} AlphaNum(const Eigen::half &f); // NOLINT(runtime/explicit) - AlphaNum(Hex hex); // NOLINT(runtime/explicit) + AlphaNum(Hex hex); // NOLINT(runtime/explicit) AlphaNum(const char *c_str) : piece_(c_str) {} // NOLINT(runtime/explicit) AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit) diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 279a5876f962bb32b09a4b832794dfdcfffc6d46..267ce88440080399aae783903503f0bbd025d8b4 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -335,6 +335,13 @@ REGISTER_OP("Unpack") return Status::OK(); }); +REGISTER_OP("UnravelIndex") + .Input("indices: Tidx") + .Input("dims: Tidx") + .Output("output: Tidx") + .Attr("Tidx: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { return Status::OK(); }); + // -------------------------------------------------------------------------- // TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph // in the N == 1 case to remove the node. @@ -701,10 +708,11 @@ REGISTER_OP("MatrixDiagPart") // -------------------------------------------------------------------------- REGISTER_OP("MatrixBandPart") .Input("input: T") - .Input("num_lower: int64") - .Input("num_upper: int64") + .Input("num_lower: Tindex") + .Input("num_upper: Tindex") .Output("band: T") .Attr("T: type") + .Attr("Tindex: {int32, int64} = DT_INT64") .SetShapeFn(shape_inference::UnchangedShape); // -------------------------------------------------------------------------- @@ -977,8 +985,8 @@ REGISTER_OP("GatherNd") if (c->Value(r_dim) > c->Rank(params)) { return errors::InvalidArgument( "indices.shape[-1] must be <= params.rank, but saw indices shape: ", - c->DebugString(indices), " and params shape: ", - c->DebugString(params)); + c->DebugString(indices), + " and params shape: ", c->DebugString(params)); } // Remove r_dim from indices to get output. @@ -1252,12 +1260,12 @@ REGISTER_OP("ReverseSequence") // Validate batch_dim and seq_dim against input. const int32 input_rank = c->Rank(input); if (batch_dim >= input_rank) { - return errors::InvalidArgument("batch_dim must be < input rank: ", - batch_dim, " vs. ", input_rank); + return errors::InvalidArgument( + "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank); } if (seq_dim >= input_rank) { - return errors::InvalidArgument("seq_dim must be < input rank: ", - seq_dim, " vs. ", input_rank); + return errors::InvalidArgument( + "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank); } DimensionHandle batch_dim_dim = c->Dim(input, batch_dim); @@ -2638,8 +2646,9 @@ Status ScatterNdShape(InferenceContext* c) { Status s = c->Merge(prefix_indices, prefix_updates, &unused); if (!s.ok()) { return errors::InvalidArgument( - "The outer ", outer_dims, " dimensions of indices.shape=", - c->DebugString(indices_shape), " must match the outer ", outer_dims, + "The outer ", outer_dims, + " dimensions of indices.shape=", c->DebugString(indices_shape), + " must match the outer ", outer_dims, " dimensions of updates.shape=", c->DebugString(updates_shape), ": ", s.error_message()); } diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index a182fd1c475ad44dcd0f05d42a9cbd6eeab16469..86d64635f4c1bc1c34407a517267758ce5cf60fc 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -142,8 +142,13 @@ TEST(ArrayOpsTest, Const_ShapeFn) { TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) { for (const char* op_name : { - "CheckNumerics", "Identity", "RefIdentity", "QuantizeAndDequantize", - "StopGradient", "ZerosLike", "OnesLike", + "CheckNumerics", + "Identity", + "RefIdentity", + "QuantizeAndDequantize", + "StopGradient", + "ZerosLike", + "OnesLike", }) { ShapeInferenceTestOp op(op_name); INFER_OK(op, "?", "in0"); diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc index a64582acee7e84a6eb5c73a61d57148d994558c9..0a62965eedd3c053dff558108f21e99a77407587 100644 --- a/tensorflow/core/ops/batch_ops.cc +++ b/tensorflow/core/ops/batch_ops.cc @@ -26,6 +26,7 @@ REGISTER_OP("Batch") .Output("id: int64") .Attr("num_batch_threads: int") .Attr("max_batch_size: int") + .Attr("max_enqueued_batches: int = 10") .Attr("batch_timeout_micros: int") .Attr("allowed_batch_sizes: list(int) = []") .Attr("grad_timeout_micros: int") diff --git a/tensorflow/core/ops/candidate_sampling_ops_test.cc b/tensorflow/core/ops/candidate_sampling_ops_test.cc index c79b4439148e5795e313c71bbce35c82242cd335..f367371604097b7a500d746a3b8a8a5906082cbb 100644 --- a/tensorflow/core/ops/candidate_sampling_ops_test.cc +++ b/tensorflow/core/ops/candidate_sampling_ops_test.cc @@ -23,9 +23,12 @@ namespace tensorflow { TEST(CandidateSamplerOpsTest, CandidateSampler_ShapeFn) { for (const char* op_name : { - "AllCandidateSampler", "FixedUnigramCandidateSampler", - "LearnedUnigramCandidateSampler", "LogUniformCandidateSampler", - "ThreadUnsafeUnigramCandidateSampler", "UniformCandidateSampler", + "AllCandidateSampler", + "FixedUnigramCandidateSampler", + "LearnedUnigramCandidateSampler", + "LogUniformCandidateSampler", + "ThreadUnsafeUnigramCandidateSampler", + "UniformCandidateSampler", }) { ShapeInferenceTestOp op(op_name); TF_ASSERT_OK(NodeDefBuilder("test", op.name) diff --git a/tensorflow/core/ops/compat/backwards_compatibility_test.cc b/tensorflow/core/ops/compat/backwards_compatibility_test.cc index add05d6610ae62158b653d27699f61bc511ee3b6..6e05ae4be4fb967ac8dcc5a03fa548c7cb6c0f9b 100644 --- a/tensorflow/core/ops/compat/backwards_compatibility_test.cc +++ b/tensorflow/core/ops/compat/backwards_compatibility_test.cc @@ -25,8 +25,9 @@ namespace tensorflow { namespace { TEST(BackwardsCompatibilityTest, IsCompatible) { - OpCompatibilityLib compatibility( - "tensorflow/core/ops", strings::StrCat("v", TF_MAJOR_VERSION), nullptr); + OpCompatibilityLib compatibility("tensorflow/core/ops", + strings::StrCat("v", TF_MAJOR_VERSION), + nullptr); Env* env = Env::Default(); int changed_ops = 0; diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 65ab81931ad4261f432034f73269d1e8c8005384..fc9e5b02a2253621203a47c5f7d1b7d311c82a97 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -7973,6 +7973,83 @@ op { minimum: 1 } } +op { + name: "Batch" + input_arg { + name: "in_tensors" + type_list_attr: "T" + } + output_arg { + name: "batched_tensors" + type_list_attr: "T" + } + output_arg { + name: "batch_index" + type: DT_INT64 + } + output_arg { + name: "id" + type: DT_INT64 + } + attr { + name: "num_batch_threads" + type: "int" + } + attr { + name: "max_batch_size" + type: "int" + } + attr { + name: "max_enqueued_batches" + type: "int" + default_value { + i: 10 + } + } + attr { + name: "batch_timeout_micros" + type: "int" + } + attr { + name: "allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "grad_timeout_micros" + type: "int" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "batching_queue" + type: "string" + default_value { + s: "" + } + } + attr { + name: "T" + type: "list(type)" + has_minimum: true + minimum: 1 + } +} op { name: "BatchCholesky" input_arg { @@ -17136,6 +17213,24 @@ op { type: DT_STRING } } +op { + name: "EnqueueInQueueDataset" + input_arg { + name: "queue" + type: DT_VARIANT + } + input_arg { + name: "components" + type_list_attr: "Tcomponents" + } + attr { + name: "Tcomponents" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "Enter" input_arg { @@ -21533,53 +21628,6 @@ op { } } } -op { - name: "IgnoreErrorsDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } - is_stateful: true -} -op { - name: "IgnoreErrorsDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} op { name: "Imag" input_arg { @@ -24840,6 +24888,42 @@ op { type: "type" } } +op { + name: "MatrixBandPart" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "num_lower" + type_attr: "Tindex" + } + input_arg { + name: "num_upper" + type_attr: "Tindex" + } + output_arg { + name: "band" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tindex" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} op { name: "MatrixDeterminant" input_arg { @@ -32096,6 +32180,48 @@ op { minimum: 1 } } +op { + name: "PrependFromQueueAndPaddedBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "padded_shapes" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "padding_values" + type_list_attr: "Toutput_types" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "Toutput_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } +} op { name: "PreventGradient" input_arg { @@ -42820,6 +42946,36 @@ op { } is_stateful: true } +op { + name: "ResourceScatterUpdate" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} op { name: "ResourceSparseApplyAdadelta" input_arg { @@ -46563,6 +46719,49 @@ op { } } } +op { + name: "Roll" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "shift" + type_attr: "Tshift" + } + input_arg { + name: "axis" + type_attr: "Taxis" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshift" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Taxis" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} op { name: "Round" input_arg { @@ -65191,6 +65390,34 @@ op { } } } +op { + name: "UnravelIndex" + input_arg { + name: "indices" + type_attr: "Tidx" + } + input_arg { + name: "dims" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "Tidx" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} op { name: "UnsortedSegmentMax" input_arg { diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index 12c27c79840de6981629984732147671b8a1e28e..4f946fb3ca7608816180351b7753d01f13d469f2 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -171,29 +171,10 @@ Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { return Status::OK(); } -Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { - ShapeHandle handle; - DimensionHandle unused_handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - for (int i = 1; i < c->num_inputs(); ++i) { - TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); - } - for (int i = 0; i < c->num_outputs(); ++i) { - c->set_output(i, c->Scalar()); - } - return Status::OK(); -} - Status TwoElementOutput(InferenceContext* c) { c->set_output(0, c->Vector(2)); return Status::OK(); } - -Status ScalarOutput(InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); -} } // namespace REGISTER_OP("RandomShuffleQueue") diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 2cae814eab1602e72ffcfd100f9813f8f41c6ac9..9e98f56c745a2b0b16531e2785e43ba8464d42b8 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -107,13 +107,6 @@ REGISTER_OP("SkipDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); -REGISTER_OP("IgnoreErrorsDataset") - .Input("input_dataset: variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); - REGISTER_OP("BytesProducedStatsDataset") .Input("input_dataset: variant") .Input("tag: string") @@ -491,4 +484,29 @@ REGISTER_OP("StatsAggregatorSummary") .Output("summary: string") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("PrependFromQueueAndPaddedBatchDataset") + .Input("input_dataset: variant") + .Input("batch_size: int64") + .Input("padded_shapes: N * int64") + .Input("padding_values: Toutput_types") + .Output("handle: variant") + .Attr("Toutput_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("N: int >= 1") + // TODO(ebrevdo): Validate that `padded_shapes` are all vectors, the lengths + // of `Toutput_types` and `output_shapes` are `N`, that the + // length of `output_types` is `N`, the `output_shapes` are + // (as far as possible to tell statically) compatible with `padded_shapes`, + // and that `padding_values` are all scalars. + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("EnqueueInQueueDataset") + .Input("queue: variant") + .Input("components: Tcomponents") + .Attr("Tcomponents: list(type) >= 1") + .SetIsStateful() // To avoid CSE on multiple calls to Enqueue. + // TODO(ebrevdo): SetShapeFn to test input dtypes and shapes by + // reading from queue handle (is that even possible?). + .SetShapeFn(shape_inference::NoOutputs); + } // namespace tensorflow diff --git a/tensorflow/core/ops/functional_grad.cc b/tensorflow/core/ops/functional_grad.cc index 6df3536795ce7772faef72d63e0cb276719d7b44..eeccb72da65d7cef1073f54bf7f639436f69e930 100644 --- a/tensorflow/core/ops/functional_grad.cc +++ b/tensorflow/core/ops/functional_grad.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/function.h" #include +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index 515b31623bfbffe12f7722becd839d99279d4fdc..9e18d20db65075e471862def3924811b260f8a08 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -48,4 +48,63 @@ REGISTER_OP("RemoteCall") .Attr("Tout: list(type)") .Attr("f: func") .SetShapeFn(shape_inference::UnknownShape); + +REGISTER_OP("_If") + .Input("cond: Tcond") + .Input("input: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("Tin: list(type)") + .Attr("Tout: list(type)") + .Attr("then_branch: func") + .Attr("else_branch: func") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = cond ? then_branch(input) : else_branch(input) + +cond: A Tensor. If the tensor is a scalar of non-boolean type, the + scalar is converted to a boolean according to the + following rule: if the scalar is a numerical value, non-zero means + True and zero means False; if the scalar is a string, non-empty + means True and empty means False. If the tensor is not a scalar, + being empty means False and being non-empty means True. +input: A list of input tensors. +then_branch: A function that takes 'inputs' and returns a list of + tensors, whose types are the same as what else_branch returns. +else_branch: A function that takes 'inputs' and returns a list of + tensors. whose types are the same as what then_branch returns. +)doc"); + +// TODO(b/37549631) setting the While Op to always be stateful is too +// conservative. +REGISTER_OP("_While") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(i)); + } + return Status::OK(); + }) + .Doc(R"doc( +output = input; While (Cond(output)) { output = Body(output) } + +input: A list of input tensors whose types are T. +output: A list of output tensors whose types are T. +cond: A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +body: A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified + by T. +)doc"); + } // end namespace tensorflow diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 7484ebb07808a7670d80a4bfdb590e85b94de04f..c3b08e067a2c35432e45f98ef9d57af629b90e02 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -25,42 +25,6 @@ using shape_inference::ShapeHandle; namespace { -const char kDecodeJpegCommonDocStr[] = R"doc( -The attr `channels` indicates the desired number of color channels for the -decoded image. - -Accepted values are: - -* 0: Use the number of channels in the JPEG-encoded image. -* 1: output a grayscale image. -* 3: output an RGB image. - -If needed, the JPEG-encoded image is transformed to match the requested number -of color channels. - -The attr `ratio` allows downscaling the image by an integer factor during -decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than -downscaling the image later. - -)doc"; - -const char kDecodeJpegCommonParamsDocStr[] = R"doc( -channels: Number of color channels for the decoded image. -ratio: Downscaling ratio. -fancy_upscaling: If true use a slower but nicer upscaling of the - chroma planes (yuv420/422 only). -try_recover_truncated: If true try to recover an image from truncated input. -acceptable_fraction: The minimum required fraction of lines before a truncated - input is accepted. -dct_method: string specifying a hint about the algorithm used for - decompression. Defaults to "" which maps to a system-specific - default. Currently valid values are ["INTEGER_FAST", - "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal - jpeg library changes to a version that does not have that specific - option.) -image: 3-D with shape `[height, width, channels]`.. -)doc"; - // Sets output[0] to shape [batch_dim,height,width,channel_dim], where // height and width come from the size_tensor. Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim, @@ -491,6 +455,17 @@ REGISTER_OP("SampleDistortedBoundingBox") .Attr("use_image_if_no_bounding_boxes: bool = false") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { + // Get inputs and validate ranks. + ShapeHandle image_size; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size)); + ShapeHandle bounding_boxes; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes)); + // image_size: 1-D with [height, width, channels] + // bounding_boxes: 3-D with shape [batch, N, 4] + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused)); + c->set_output(0, c->Vector(3)); c->set_output(1, c->Vector(3)); c->set_output(2, c->MakeShape({1, 1, 4})); @@ -513,6 +488,19 @@ REGISTER_OP("SampleDistortedBoundingBoxV2") .Attr("use_image_if_no_bounding_boxes: bool = false") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { + // Get inputs and validate ranks. + ShapeHandle image_size; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size)); + ShapeHandle bounding_boxes; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes)); + ShapeHandle min_object_covered; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered)); + // image_size: 1-D with [height, width, channels] + // bounding_boxes: 3-D with shape [batch, N, 4] + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused)); + c->set_output(0, c->Vector(3)); c->set_output(1, c->Vector(3)); c->set_output(2, c->MakeShape({1, 1, 4})); @@ -622,6 +610,21 @@ REGISTER_OP("NonMaxSuppression") .Output("selected_indices: int32") .Attr("iou_threshold: float = 0.5") .SetShapeFn([](InferenceContext* c) { + // Get inputs and validate ranks. + ShapeHandle boxes; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); + ShapeHandle scores; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); + ShapeHandle max_output_size; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); + // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. + DimensionHandle unused; + // The boxes[0] and scores[0] are both num_boxes. + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); + // The boxes[1] is 4. + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); + c->set_output(0, c->Vector(c->UnknownDim())); return Status::OK(); }); @@ -633,6 +636,23 @@ REGISTER_OP("NonMaxSuppressionV2") .Input("iou_threshold: float") .Output("selected_indices: int32") .SetShapeFn([](InferenceContext* c) { + // Get inputs and validate ranks. + ShapeHandle boxes; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); + ShapeHandle scores; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); + ShapeHandle max_output_size; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); + ShapeHandle iou_threshold; + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold)); + // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. + DimensionHandle unused; + // The boxes[0] and scores[0] are both num_boxes. + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); + // The boxes[1] is 4. + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); + c->set_output(0, c->Vector(c->UnknownDim())); return Status::OK(); }); diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index fa40f41bb949767f76ee9dae60a3f6312bd80186..3487c955cbb2b06bdb33000da549c0fc6e7f86e8 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -241,6 +241,7 @@ REGISTER_OP("TensorListSetItem") DataType t; TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t)); auto* handle_data = c->input_handle_shapes_and_types(0); + c->set_output(0, c->Scalar()); if (handle_data == nullptr) { c->set_output_handle_shapes_and_types(0, {{c->UnknownShape(), t}}); return Status::OK(); diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc index a67267418d608e7c824030225f906b010794a160..444aa8b9544c62d81f288f21e4eaaac23d8691cb 100644 --- a/tensorflow/core/ops/lookup_ops.cc +++ b/tensorflow/core/ops/lookup_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/shape_inference.h" @@ -102,6 +103,8 @@ REGISTER_OP("LookupTableFindV2") c->set_output(0, c->UnknownShape()); return Status::OK(); }); +WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableFindV2"); +// TODO(b/72710477): Update this. REGISTER_OP("LookupTableInsert") .Input("table_handle: Ref(string)") diff --git a/tensorflow/core/ops/manip_ops.cc b/tensorflow/core/ops/manip_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..95b4774fe6e230800e71d237c2cd027acf6e054b --- /dev/null +++ b/tensorflow/core/ops/manip_ops.cc @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +// -------------------------------------------------------------------------- +REGISTER_OP("Roll") + .Input("input: T") + .Input("shift: Tshift") + .Input("axis: Taxis") + .Output("output: T") + .Attr("T: type") + .Attr("Tshift: {int32,int64}") + .Attr("Taxis: {int32,int64}") + .SetShapeFn(shape_inference::UnchangedShape); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index dd484c3ee752b47f4a196cd45c6e26984b5ef0bd..8f33d51d5a20fc207102e4bf79e7605d9817eb9f 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1065,6 +1065,26 @@ REGISTER_OP("UnsortedSegmentMax") .Attr("Tnumsegments: {int32,int64} = DT_INT32") .SetShapeFn(UnsortedSegmentReductionShapeFn); +REGISTER_OP("UnsortedSegmentMin") + .Input("data: T") + .Input("segment_ids: Tindices") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(UnsortedSegmentReductionShapeFn); + +REGISTER_OP("UnsortedSegmentProd") + .Input("data: T") + .Input("segment_ids: Tindices") + .Input("num_segments: Tnumsegments") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("Tindices: {int32,int64}") + .Attr("Tnumsegments: {int32,int64} = DT_INT32") + .SetShapeFn(UnsortedSegmentReductionShapeFn); + REGISTER_OP("SparseSegmentSum") .Input("data: T") .Input("indices: Tidx") @@ -1172,12 +1192,12 @@ Status RangeSize(const Tensor* start_t, const Tensor* limit_t, T limit = limit_t->scalar()(); T delta = delta_t->scalar()(); if (start > limit && delta > 0) { - return errors::InvalidArgument("Requires start <= limit when delta > 0: ", - start, "/", limit); + return errors::InvalidArgument( + "Requires start <= limit when delta > 0: ", start, "/", limit); } if (start < limit && delta < 0) { - return errors::InvalidArgument("Requires start >= limit when delta < 0: ", - start, "/", limit); + return errors::InvalidArgument( + "Requires start >= limit when delta < 0: ", start, "/", limit); } if (delta == 0) { return errors::InvalidArgument("Requires delta != 0"); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 3f72b415699562a0d79fc1f41ff1b4a360bfc7db..67481fd202b3c3b35033b72e4c1c5fd294d98696 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1155,9 +1155,9 @@ Status TopKShapeFn(InferenceContext* c) { DimensionHandle last_dim = c->Dim(input, -1); if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) && c->Value(last_dim) < c->Value(k_dim)) { - return errors::InvalidArgument("input must have last dimension >= k = ", - c->Value(k_dim), " but is ", - c->Value(last_dim)); + return errors::InvalidArgument( + "input must have last dimension >= k = ", c->Value(k_dim), " but is ", + c->Value(last_dim)); } // Replace last_dim with k_dim. @@ -1211,9 +1211,9 @@ REGISTER_OP("NthElement") DimensionHandle last_dim = c->Dim(input, -1); if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) && c->Value(last_dim) <= c->Value(n_dim)) { - return errors::InvalidArgument("Input must have last dimension > n = ", - c->Value(n_dim), " but is ", - c->Value(last_dim)); + return errors::InvalidArgument( + "Input must have last dimension > n = ", c->Value(n_dim), + " but is ", c->Value(last_dim)); } // Reduce last_dim for output tensor @@ -1818,7 +1818,7 @@ REGISTER_OP("_MklMaxPool") .Input("input: T") .Input("mkl_input: uint8") .Output("output: T") -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML .Output("workspace: T") #else .Output("workspace: uint8") @@ -1844,7 +1844,7 @@ REGISTER_OP("_MklMaxPoolGrad") .Input("orig_input: T") .Input("orig_output: T") .Input("grad: T") -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML .Input("workspace: T") #else .Input("workspace: uint8") @@ -1916,7 +1916,7 @@ REGISTER_OP("_MklLRN") .Input("input: T") .Input("mkl_input: uint8") .Output("output: T") -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML .Output("workspace: T") #else .Output("workspace: uint8") @@ -1944,7 +1944,7 @@ REGISTER_OP("_MklLRNGrad") .Input("input_grads: T") .Input("input_image: T") .Input("output_image: T") -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML .Input("workspace: T") #else .Input("workspace: uint8") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index b57206c9c4f53fbf73537f466206f5c1b0caefcb..45ff08f38b134f963460d15f949411a7f1619d0c 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -2763,6 +2763,13 @@ op { name: "max_batch_size" type: "int" } + attr { + name: "max_enqueued_batches" + type: "int" + default_value { + i: 10 + } + } attr { name: "batch_timeout_micros" type: "int" @@ -7644,6 +7651,24 @@ op { type: DT_STRING } } +op { + name: "EnqueueInQueueDataset" + input_arg { + name: "queue" + type: DT_VARIANT + } + input_arg { + name: "components" + type_list_attr: "Tcomponents" + } + attr { + name: "Tcomponents" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "Enter" input_arg { @@ -10219,29 +10244,6 @@ op { } } } -op { - name: "IgnoreErrorsDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} op { name: "Imag" input_arg { @@ -12330,11 +12332,11 @@ op { } input_arg { name: "num_lower" - type: DT_INT64 + type_attr: "Tindex" } input_arg { name: "num_upper" - type: DT_INT64 + type_attr: "Tindex" } output_arg { name: "band" @@ -12344,6 +12346,19 @@ op { name: "T" type: "type" } + attr { + name: "Tindex" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } } op { name: "MatrixDeterminant" @@ -15926,6 +15941,48 @@ op { minimum: 1 } } +op { + name: "PrependFromQueueAndPaddedBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "padded_shapes" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "padding_values" + type_list_attr: "Toutput_types" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "Toutput_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } +} op { name: "PreventGradient" input_arg { @@ -20925,27 +20982,6 @@ op { attr { name: "dtype" type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT32 - type: DT_UINT8 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_INT64 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_BFLOAT16 - type: DT_UINT16 - type: DT_COMPLEX128 - type: DT_HALF - type: DT_UINT32 - type: DT_UINT64 - } - } } attr { name: "Tindices" @@ -22082,6 +22118,49 @@ op { } } } +op { + name: "Roll" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "shift" + type_attr: "Tshift" + } + input_arg { + name: "axis" + type_attr: "Taxis" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshift" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Taxis" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} op { name: "Round" input_arg { @@ -30835,6 +30914,34 @@ op { } } } +op { + name: "UnravelIndex" + input_arg { + name: "indices" + type_attr: "Tidx" + } + input_arg { + name: "dims" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "Tidx" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} op { name: "UnsortedSegmentMax" input_arg { diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index f6cfbf873a024e3a035842468fc5ccca2d341ce7..8dae7e1ff5f872c33dd56509c0349180cec78593 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -193,7 +193,7 @@ REGISTER_OP("ResourceScatterUpdate") .Input("resource: resource") .Input("indices: Tindices") .Input("updates: dtype") - .Attr("dtype: numbertype") + .Attr("dtype: type") .Attr("Tindices: {int32, int64}") .SetShapeFn([](InferenceContext* c) { ShapeAndType handle_shape_and_type; diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc index e67d95fa8cb8466365bf12a46a123de174103d0f..4025070adb2b193edacdaf728f240961bf9d2530 100644 --- a/tensorflow/core/ops/sdca_ops.cc +++ b/tensorflow/core/ops/sdca_ops.cc @@ -19,8 +19,8 @@ limitations under the License. namespace tensorflow { -using shape_inference::ShapeHandle; using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; // -------------------------------------------------------------------------- static Status ApplySdcaOptimizerShapeFn(InferenceContext* c) { diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index 8beb28de0a2e32832b2db60eeb8272a88536e91f..e4c5bcfb540660a609aca013b795d566e69f54a8 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -137,9 +137,9 @@ REGISTER_OP("Substr") DimensionHandle pos_dim = c->Dim(pos_shape, i); DimensionHandle len_dim = c->Dim(len_shape, i); if (c->Value(pos_dim) != c->Value(len_dim)) { - return errors::InvalidArgument("pos and len shapes must match: ", - c->DebugString(pos_shape), " vs. ", - c->DebugString(len_shape)); + return errors::InvalidArgument( + "pos and len shapes must match: ", c->DebugString(pos_shape), + " vs. ", c->DebugString(len_shape)); } } // c->input(0) is the ShapeHandle to input strings diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc index e8d03877c91402394567b05df8b738de1c15c8c6..6ce9595fb60b78525bde19515077f7245a219d39 100644 --- a/tensorflow/core/ops/training_ops.cc +++ b/tensorflow/core/ops/training_ops.cc @@ -22,48 +22,6 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -const char kAddSignCommonDocStr[] = R"doc( -Update '*var' according to the AddSign update. - -m_t <- beta1 * m_{t-1} + (1 - beta1) * g -update <- (alpha + sign_decay * sign(g) *sign(m)) * g -variable <- variable - lr_t * update - -var: Should be from a Variable(). -m: Should be from a Variable(). -lr: Scaling factor. Must be a scalar. -sign_decay: Must be a scalar. -alpha: Must be a scalar. -beta: Must be a scalar. -grad: The gradient. -)doc"; - -const char kPowerSignCommonDocStr[] = R"doc( -Update '*var' according to the AddSign update. - -m_t <- beta1 * m_{t-1} + (1 - beta1) * g -update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g -variable <- variable - lr_t * update - -var: Should be from a Variable(). -m: Should be from a Variable(). -lr: Scaling factor. Must be a scalar. -logbase: Must be a scalar. -sign_decay: Must be a scalar. -beta: Must be a scalar. -grad: The gradient. -)doc"; - -const char kOutDocStr[] = R"doc( -out: Same as "var". -)doc"; - -const char kLockDocStr[] = R"doc( -use_locking: If `True`, updating of the var and m tensors is - protected by a lock; otherwise the behavior is undefined, but may exhibit less - contention. -)doc"; - static ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) { auto* handle_data = c->input_handle_shapes_and_types(input); if (handle_data != nullptr && !handle_data->empty() && diff --git a/tensorflow/core/ops/training_ops_test.cc b/tensorflow/core/ops/training_ops_test.cc index de4e3cd9e70014ea9b29d4d473d94c0abb52eabc..0f309c1f4e956c98b6f20fa3b6c810116a2b339c 100644 --- a/tensorflow/core/ops/training_ops_test.cc +++ b/tensorflow/core/ops/training_ops_test.cc @@ -24,7 +24,7 @@ static void TestGradAndIndicesErrorHandling(const ShapeInferenceTestOp& op, string shape_spec_middle, const string& shape_spec_end = "") { auto shape_spec = [&shape_spec_middle, shape_spec_end]( - const char* var_spec, const char* grad_indices_spec) { + const char* var_spec, const char* grad_indices_spec) { return strings::StrCat(var_spec, ";", shape_spec_middle, ";", grad_indices_spec, shape_spec_end); }; diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 07aecf848326b23b18b58ae60e896150ab7b4ef9..9ba25dea4fb278cbfaf4080e21beef8a3e9de769 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -57,6 +57,17 @@ cc_library( ], ) +cc_library( + name = "gcs_throttle", + srcs = ["gcs_throttle.cc"], + hdrs = ["gcs_throttle.h"], + copts = tf_copts(), + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/core:lib", + ], +) + cc_library( name = "gcs_file_system", srcs = ["gcs_file_system.cc"], @@ -69,6 +80,7 @@ cc_library( ":expiring_lru_cache", ":file_block_cache", ":gcs_dns_cache", + ":gcs_throttle", ":google_auth_provider", ":http_request", ":retrying_file_system", @@ -271,6 +283,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gcs_throttle_test", + size = "small", + srcs = ["gcs_throttle_test.cc"], + linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]), + deps = [ + ":gcs_throttle", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "curl_http_request_test", size = "small", diff --git a/tensorflow/core/platform/cloud/file_block_cache.cc b/tensorflow/core/platform/cloud/file_block_cache.cc index 0375af516b0504e8b527409ba22da0caa149ad9d..6add1142a15fb69044828bd82a6d6e838959de08 100644 --- a/tensorflow/core/platform/cloud/file_block_cache.cc +++ b/tensorflow/core/platform/cloud/file_block_cache.cc @@ -131,6 +131,7 @@ Status FileBlockCache::MaybeFetch(const Key& key, block->mu.lock(); // Reacquire the lock immediately afterwards if (status.ok()) { block->data.resize(bytes_transferred, 0); + block->data.shrink_to_fit(); downloaded_block = true; block->state = FetchState::FINISHED; } else { diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.cc b/tensorflow/core/platform/cloud/gcs_dns_cache.cc index 2b0e55bf371da9660f1422cef97e3ec1a25a9b61..4d9aff4d24f06c7bd1269ad590c9687092a5b132 100644 --- a/tensorflow/core/platform/cloud/gcs_dns_cache.cc +++ b/tensorflow/core/platform/cloud/gcs_dns_cache.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include #else +#include #include #include -#include #endif #include diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 520720372d9ff12556110967d2c47703ec4b5132..01ca0d76bab2720513775ef33ff8670bd148c241 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -50,7 +50,6 @@ limitations under the License. #endif namespace tensorflow { - namespace { constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/"; @@ -59,9 +58,6 @@ constexpr char kGcsUploadUriBase[] = constexpr char kStorageHost[] = "storage.googleapis.com"; constexpr size_t kReadAppendableFileBufferSize = 1024 * 1024; // In bytes. constexpr int kGetChildrenDefaultPageSize = 1000; -// Initial delay before retrying a GCS upload. -// Subsequent delays can be larger due to exponential back-off. -constexpr uint64 kUploadRetryDelayMicros = 1000000L; // The HTTP response code "308 Resume Incomplete". constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308; // The environment variable that overrides the size of the readahead buffer. @@ -120,6 +116,15 @@ constexpr char kWriteRequestTimeout[] = "GCS_WRITE_REQUEST_TIMEOUT_SECS"; // The environment variable to configure an additional header to send with // all requests to GCS (format HEADERNAME:HEADERCONTENT) constexpr char kAdditionalRequestHeader[] = "GCS_ADDITIONAL_REQUEST_HEADER"; +// The environment variable to configure the throttle (format: ) +constexpr char kThrottleRate[] = "GCS_THROTTLE_TOKEN_RATE"; +// The environment variable to configure the token bucket size (format: ) +constexpr char kThrottleBucket[] = "GCS_THROTTLE_BUCKET_SIZE"; +// The environment variable that controls the number of tokens per request. +// (format: ) +constexpr char kTokensPerRequest[] = "GCS_TOKENS_PER_REQUEST"; +// The environment variable to configure the initial tokens (format: ) +constexpr char kInitialTokens[] = "GCS_INITIAL_TOKENS"; // TODO: DO NOT use a hardcoded path Status GetTmpFilename(string* filename) { @@ -725,6 +730,26 @@ GcsFileSystem::GcsFileSystem() if (GetEnvVar(kWriteRequestTimeout, strings::safe_strtou32, &timeout_value)) { timeouts_.write = timeout_value; } + + int64 token_value; + if (GetEnvVar(kThrottleRate, strings::safe_strto64, &token_value)) { + GcsThrottleConfig config; + config.enabled = true; + config.token_rate = token_value; + + if (GetEnvVar(kThrottleBucket, strings::safe_strto64, &token_value)) { + config.bucket_size = token_value; + } + + if (GetEnvVar(kTokensPerRequest, strings::safe_strto64, &token_value)) { + config.tokens_per_request = token_value; + } + + if (GetEnvVar(kInitialTokens, strings::safe_strto64, &token_value)) { + config.initial_tokens = token_value; + } + throttle_.SetConfig(config); + } } GcsFileSystem::GcsFileSystem( @@ -778,7 +803,9 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object)); std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request), + "when reading gs://", bucket, "/", object); + request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket, "/", request->EscapeString(object))); request->SetRange(offset, offset + n - 1); @@ -793,6 +820,8 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ " << offset << " of size: " << bytes_read; + throttle_.RecordResponse(bytes_read); + if (bytes_read < block_size()) { // Check stat cache to see if we encountered an interrupted read. FileStatistics stat; @@ -930,41 +959,43 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket, "'object' must be a non-empty string. (File: %s)", fname.c_str())); } - StatCache::ComputeFunc compute_func = - [this, &bucket, &object](const string& fname, FileStatistics* stat) { - std::vector output_buffer; - std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); - request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/", - request->EscapeString(object), - "?fields=size%2Cupdated")); - request->SetResultBuffer(&output_buffer); - request->SetTimeouts(timeouts_.connect, timeouts_.idle, - timeouts_.metadata); + StatCache::ComputeFunc compute_func = [this, &bucket, &object]( + const string& fname, + FileStatistics* stat) { + std::vector output_buffer; + std::unique_ptr request; + TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request), + " when reading metadata of gs://", bucket, + "/", object); - TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), - " when reading metadata of gs://", - bucket, "/", object); + request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/", + request->EscapeString(object), + "?fields=size%2Cupdated")); + request->SetResultBuffer(&output_buffer); + request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata); - Json::Value root; - TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), + " when reading metadata of gs://", bucket, + "/", object); - // Parse file size. - TF_RETURN_IF_ERROR(GetInt64Value(root, "size", &stat->length)); + Json::Value root; + TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root)); - // Parse file modification time. - string updated; - TF_RETURN_IF_ERROR(GetStringValue(root, "updated", &updated)); - TF_RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->mtime_nsec))); + // Parse file size. + TF_RETURN_IF_ERROR(GetInt64Value(root, "size", &stat->length)); - VLOG(1) << "Stat of: gs://" << bucket << "/" << object << " -- " - << " length: " << stat->length - << "; mtime_nsec: " << stat->mtime_nsec - << "; updated: " << updated; + // Parse file modification time. + string updated; + TF_RETURN_IF_ERROR(GetStringValue(root, "updated", &updated)); + TF_RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->mtime_nsec))); - stat->is_directory = false; - return Status::OK(); - }; + VLOG(1) << "Stat of: gs://" << bucket << "/" << object << " -- " + << " length: " << stat->length + << "; mtime_nsec: " << stat->mtime_nsec << "; updated: " << updated; + + stat->is_directory = false; + return Status::OK(); + }; TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute(fname, stat, compute_func)); if (stat->is_directory) { @@ -1442,6 +1473,10 @@ Status GcsFileSystem::CreateHttpRequest(std::unique_ptr* request) { additional_header_->second); } + if (!throttle_.AdmitRequest()) { + return errors::Unavailable("Request throttled"); + } + *request = std::move(new_request); return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index 2eae39608e38184450290e86bc12d81494bb8302..e8edde8a445aad4c0310394d89480dc6ae445dfa 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/expiring_lru_cache.h" #include "tensorflow/core/platform/cloud/file_block_cache.h" #include "tensorflow/core/platform/cloud/gcs_dns_cache.h" +#include "tensorflow/core/platform/cloud/gcs_throttle.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/cloud/retrying_file_system.h" #include "tensorflow/core/platform/file_system.h" @@ -194,6 +195,7 @@ class GcsFileSystem : public FileSystem { std::unique_ptr http_request_factory_; std::unique_ptr file_block_cache_; std::unique_ptr dns_cache_; + GcsThrottle throttle_; using StatCache = ExpiringLRUCache; std::unique_ptr stat_cache_; diff --git a/tensorflow/core/platform/cloud/gcs_throttle.cc b/tensorflow/core/platform/cloud/gcs_throttle.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb5f8958a37f45aeac1a836ca037f91931bb34a6 --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_throttle.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/gcs_throttle.h" + +#include + +namespace tensorflow { + +GcsThrottle::GcsThrottle(EnvTime* env_time) + : last_updated_secs_(env_time->NowSeconds()), + available_tokens_(0), + env_time_(env_time) {} + +bool GcsThrottle::AdmitRequest() { + mutex_lock l(mu_); + if (!config_.enabled) return true; + UpdateState(); + if (available_tokens_ < config_.tokens_per_request) { + return false; + } + available_tokens_ -= config_.tokens_per_request; + return true; +} + +void GcsThrottle::RecordResponse(size_t num_bytes) { + mutex_lock l(mu_); + if (!config_.enabled) return; + UpdateState(); + available_tokens_ -= request_bytes_to_tokens(num_bytes); +} + +void GcsThrottle::SetConfig(GcsThrottleConfig config) { + mutex_lock l(mu_); + config_ = config; + available_tokens_ = config.initial_tokens; + last_updated_secs_ = env_time_->NowSeconds(); +} + +void GcsThrottle::UpdateState() { + // TODO(b/72643279): Switch to a monotonic clock. + int64 now = env_time_->NowSeconds(); + uint64 delta_secs = + std::max(0LL, now - static_cast(last_updated_secs_)); + available_tokens_ += delta_secs * config_.token_rate; + available_tokens_ = std::min(available_tokens_, config_.bucket_size); + last_updated_secs_ = now; +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_throttle.h b/tensorflow/core/platform/cloud/gcs_throttle.h new file mode 100644 index 0000000000000000000000000000000000000000..1a89daef084e921f1ad8bd856cefcc62d0d7aa1c --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_throttle.h @@ -0,0 +1,156 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_THROTTLE_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_THROTTLE_H_ + +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +/** + * GcsThrottleConfig is used to configure the GcsThrottle. + */ +struct GcsThrottleConfig { + /** + * enabled is true if GcsThrottle should throttle requests, false otherwise. + */ + bool enabled = false; + + /** + * token_rate is the number of tokens accrued every second that can be used + * for making requests to the GCS service. + */ + int64 token_rate = 100000; // Approximately 800 MBits/second bandwidth-only. + + /** + * bucket_size is the maximum number of available tokens the GcsThrottle can + * accrue. + */ + int64 bucket_size = 10000000; // 10 million tokens total + + /** + * tokens_per_request determines the number of tokens consumed for every + * request. + * + * Note: tokens are also consumed in proportion to the response size. + */ + int64 tokens_per_request = 100; + + /** + * initial_tokens determines how many tokens should be available immediately + * after the GcsThrottle is constructed. + */ + int64 initial_tokens = 0; +}; + +/** + * GcsThrottle is used to ensure fair use of the available GCS capacity. + * + * GcsThrottle operates around a concept of tokens. Tokens are consumed when + * making requests to the GCS service. Tokens are consumed both based on the + * number of requests made, as well as the bandwidth consumed (response sizes). + * + * GcsThrottle is thread safe and can be used from multiple threads. + */ +class GcsThrottle { + public: + /** + * Constructs a GcsThrottle. + */ + explicit GcsThrottle(EnvTime* env_time = EnvTime::Default()); + + /** + * AdmitRequest updates the GcsThrottle to record a request will be made. + * + * AdmitRequest should be called before any request is made. AdmitRequest + * returns false if the request should be denied. If AdmitRequest + * returns false, no tokens are consumed. If true is returned, the configured + * number of tokens are consumed. + */ + bool AdmitRequest(); + + /** + * RecordResponse updates the GcsThrottle to record a request has been made. + * + * RecordResponse should be called after the response has been received. + * RecordResponse will update the internal state based on the number of bytes + * in the response. + * + * Note: we split up the request and the response in this fashion in order to + * avoid penalizing consumers who are using large readahead buffers at higher + * layers of the I/O stack. + */ + void RecordResponse(size_t num_bytes); + + /** + * SetConfig sets the configuration for GcsThrottle and re-initializes state. + * + * After calling this, the token pool will be config.initial_tokens. + */ + void SetConfig(GcsThrottleConfig config); + + /** + * available_tokens gives a snapshot of how many tokens are available. + * + * The returned value should not be used to make admission decisions. The + * purpose of this function is to make available to monitoring or other + * instrumentation the number of available tokens in the pool. + */ + inline int64 available_tokens() { + mutex_lock l(mu_); + if (!config_.enabled) return 0; + UpdateState(); + return available_tokens_; + } + + private: + /** + * UpdateState updates the available_tokens_ and last_updated_secs_ variables. + * + * UpdateState should be called in order to mark the passage of time, and + * therefore add tokens to the availble_tokens_ pool. + */ + void UpdateState() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + inline uint64 request_bytes_to_tokens(size_t num_bytes) { + return num_bytes >> 10; + } + + mutex mu_; + + /** + * last_updated_secs_ records the number of seconds since the Unix epoch that + * the internal state of the GcsThrottle was updated. This is important when + * determining the number of tokens to add to the available_tokens_ pool. + */ + uint64 last_updated_secs_ GUARDED_BY(mu_) = 0; + + /** + * available_tokens_ records how many tokens are available to be consumed. + * + * Note: it is possible for available_tokens_ to become negative. If a + * response comes back that consumes more than the available tokens, the count + * will go negative, and block future requests until we have available tokens. + */ + int64 available_tokens_ GUARDED_BY(mu_) = 0; + + EnvTime* const env_time_; + GcsThrottleConfig config_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_THROTTLE_H_ diff --git a/tensorflow/core/platform/cloud/gcs_throttle_test.cc b/tensorflow/core/platform/cloud/gcs_throttle_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..694756022e37263a07f8215bf7496c9ca130fd58 --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_throttle_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/gcs_throttle.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +class TestTime : public EnvTime { + public: + uint64 NowMicros() override { return now_; } + + void SetTime(uint64 now_micros) { now_ = now_micros; } + + void AdvanceSeconds(int64 secs) { now_ += secs * 1000000L; } + + private: + uint64 now_ = 1234567890000000ULL; +}; + +class GcsThrottleTest : public ::testing::Test { + protected: + GcsThrottleTest() : throttle_(&time_) { + config_.enabled = true; + throttle_.SetConfig(config_); + } + + GcsThrottleConfig config_; + TestTime time_; + GcsThrottle throttle_; +}; + +TEST_F(GcsThrottleTest, ReplenishTokens) { + EXPECT_EQ(0, throttle_.available_tokens()); + time_.AdvanceSeconds(1); + EXPECT_EQ(100000, throttle_.available_tokens()); + time_.AdvanceSeconds(2); + EXPECT_EQ(300000, throttle_.available_tokens()); +} + +TEST_F(GcsThrottleTest, RejectRequest) { + EXPECT_EQ(0, throttle_.available_tokens()); + time_.AdvanceSeconds(1); + EXPECT_TRUE(throttle_.AdmitRequest()); + EXPECT_EQ(99900, throttle_.available_tokens()); + for (int i = 1; i < 1000; i++) { + EXPECT_TRUE(throttle_.AdmitRequest()); + } + EXPECT_FALSE(throttle_.AdmitRequest()); +} + +TEST_F(GcsThrottleTest, MarkResponses) { + time_.AdvanceSeconds(1); + EXPECT_TRUE(throttle_.AdmitRequest()); + throttle_.RecordResponse(128000000); // 128 MB response + EXPECT_EQ(-25100, throttle_.available_tokens()); + EXPECT_FALSE(throttle_.AdmitRequest()); + time_.AdvanceSeconds(1); + EXPECT_TRUE(throttle_.AdmitRequest()) + << "Available tokens: " << throttle_.available_tokens(); +} + +TEST_F(GcsThrottleTest, Skippingtime_) { + EXPECT_EQ(0, throttle_.available_tokens()); + time_.AdvanceSeconds(90); + EXPECT_EQ(9000000, throttle_.available_tokens()); +} + +TEST_F(GcsThrottleTest, BucketLimit) { + time_.AdvanceSeconds(120); + EXPECT_EQ(10000000, throttle_.available_tokens()); +} + +TEST_F(GcsThrottleTest, ReverseTime) { + time_.AdvanceSeconds(1); + EXPECT_EQ(100000, throttle_.available_tokens()); + time_.AdvanceSeconds(-3600); + EXPECT_EQ(100000, throttle_.available_tokens()); + time_.AdvanceSeconds(1); + EXPECT_EQ(200000, throttle_.available_tokens()); +} + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h index 682b97f6ec6d697bef2ef6301a39be35c95c5861..7711eaceb290fb21c54c9656c473d912ebbd84cf 100644 --- a/tensorflow/core/platform/cloud/http_request_fake.h +++ b/tensorflow/core/platform/cloud/http_request_fake.h @@ -38,8 +38,7 @@ class FakeHttpRequest : public CurlHttpRequest { public: /// Return the response for the given request. FakeHttpRequest(const string& request, const string& response) - : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) { - } + : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) {} /// Return the response with headers for the given request. FakeHttpRequest(const string& request, const string& response, diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc index 236259dbc16ffc806779bd100e1ec6ace2b7bb39..ad569758cc6ec11555a81a3bc7fbefbc580d6529 100644 --- a/tensorflow/core/platform/cloud/oauth_client_test.cc +++ b/tensorflow/core/platform/cloud/oauth_client_test.cc @@ -160,12 +160,12 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) { ASSERT_EQ(1, EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key)); ASSERT_EQ(1, EVP_DigestVerifyUpdate(md_ctx, header_dot_claim.c_str(), header_dot_claim.size())); - ASSERT_EQ( - 1, - EVP_DigestVerifyFinal( - md_ctx, const_cast( - reinterpret_cast(signature.data())), - signature.size())); + ASSERT_EQ(1, + EVP_DigestVerifyFinal( + md_ctx, + const_cast( + reinterpret_cast(signature.data())), + signature.size())); EVP_MD_CTX_cleanup(md_ctx); // Free all the crypto-related resources. diff --git a/tensorflow/core/platform/cloud/retrying_file_system.cc b/tensorflow/core/platform/cloud/retrying_file_system.cc index c3b6831361305f69e8a9882dbff90ce139ca13c0..be9ebe67b18e7be76e95149258cb1fcce6047d85 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system.cc @@ -25,7 +25,6 @@ namespace tensorflow { namespace { - class RetryingRandomAccessFile : public RandomAccessFile { public: RetryingRandomAccessFile(std::unique_ptr base_file, @@ -203,4 +202,6 @@ Status RetryingFileSystem::DeleteRecursively(const string& dirname, initial_delay_microseconds_); } +void RetryingFileSystem::FlushCaches() { base_file_system_->FlushCaches(); } + } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h index d9d8ea6b004c3cf1d0d77ff65fa415e746310afd..a262a5fd940f9b269721790c80caaef38d79d690 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.h +++ b/tensorflow/core/platform/cloud/retrying_file_system.h @@ -69,6 +69,8 @@ class RetryingFileSystem : public FileSystem { Status DeleteRecursively(const string& dirname, int64* undeleted_files, int64* undeleted_dirs) override; + void FlushCaches() override; + private: std::unique_ptr base_file_system_; const int64 initial_delay_microseconds_; diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc index 232dcb3e71aa7c5b05b45e37332fe58970fc3fe8..d3f763bb3c845436e8458135a0a754d8cb002957 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc @@ -84,7 +84,8 @@ class MockWritableFile : public WritableFile { class MockFileSystem : public FileSystem { public: - explicit MockFileSystem(const ExpectedCalls& calls) : calls_(calls) {} + explicit MockFileSystem(const ExpectedCalls& calls, bool* flushed = nullptr) + : calls_(calls), flushed_(flushed) {} Status NewRandomAccessFile( const string& fname, std::unique_ptr* result) override { @@ -156,11 +157,18 @@ class MockFileSystem : public FileSystem { return calls_.ConsumeNextCall("DeleteRecursively"); } + void FlushCaches() override { + if (flushed_) { + *flushed_ = true; + } + } + std::unique_ptr writable_file_to_return; std::unique_ptr random_access_file_to_return; private: MockCallSequence calls_; + bool* flushed_ = nullptr; }; TEST(RetryingFileSystemTest, NewRandomAccessFile_ImmediateSuccess) { @@ -702,5 +710,14 @@ TEST(RetryingFileSystemTest, DeleteRecursively_AllRetriesFailed) { << status; } +TEST(RetryingFileSystemTest, FlushCaches) { + ExpectedCalls none; + bool flushed = false; + std::unique_ptr base_fs(new MockFileSystem(none, &flushed)); + RetryingFileSystem fs(std::move(base_fs), 0); + fs.FlushCaches(); + EXPECT_TRUE(flushed); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/platform/cpu_feature_guard.cc b/tensorflow/core/platform/cpu_feature_guard.cc index b0d7b3a67ae9f92d8e321978a3b899c243c22d1d..b5706581580ea00865b45cf50a4d92d22c647e53 100644 --- a/tensorflow/core/platform/cpu_feature_guard.cc +++ b/tensorflow/core/platform/cpu_feature_guard.cc @@ -97,14 +97,17 @@ std::once_flag g_cpu_feature_guard_warn_once_flag; void InfoAboutUnusedCPUFeatures() { std::call_once(g_cpu_feature_guard_warn_once_flag, [] { string missing_instructions; -#ifdef PLATFORM_WINDOWS +#if defined(_MSC_VER) && !defined(__clang__) + #ifndef __AVX__ CheckIfFeatureUnused(CPUFeature::AVX, "AVX", missing_instructions); #endif // __AVX__ #ifndef __AVX2__ CheckIfFeatureUnused(CPUFeature::AVX2, "AVX2", missing_instructions); #endif // __AVX2__ -#else // ifdef platform windows + +#else // if defined(_MSC_VER) && !defined(__clang__) + #ifndef __SSE__ CheckIfFeatureUnused(CPUFeature::SSE, "SSE", missing_instructions); #endif // __SSE__ @@ -132,7 +135,7 @@ void InfoAboutUnusedCPUFeatures() { #ifndef __FMA__ CheckIfFeatureUnused(CPUFeature::FMA, "FMA", missing_instructions); #endif // __FMA__ -#endif // else of ifdef platform windows +#endif // else of if defined(_MSC_VER) && !defined(__clang__) if (!missing_instructions.empty()) { LOG(INFO) << "Your CPU supports instructions that this TensorFlow " << "binary was not compiled to use:" << missing_instructions; diff --git a/tensorflow/core/platform/cuda_libdevice_path_test.cc b/tensorflow/core/platform/cuda_libdevice_path_test.cc index 639f6804ea236b86f458263091f371c1374e50ae..2d34239a9958d722a1cb84213657ca8229ebaf2c 100644 --- a/tensorflow/core/platform/cuda_libdevice_path_test.cc +++ b/tensorflow/core/platform/cuda_libdevice_path_test.cc @@ -27,8 +27,7 @@ TEST(CudaLibdevicePathTest, LibdevicePath) { VLOG(2) << "Libdevice root = " << LibdeviceRoot(); std::vector libdevice_files; TF_EXPECT_OK(Env::Default()->GetMatchingPaths( - io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"), - &libdevice_files)); + io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"), &libdevice_files)); EXPECT_LT(0, libdevice_files.size()); } #endif diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc index f4b0f16393d70521386ad49fbf010591e5afb08c..8e60a7f0910ff9cf77a33f9d72d680ec42847777 100644 --- a/tensorflow/core/platform/default/device_tracer.cc +++ b/tensorflow/core/platform/default/device_tracer.cc @@ -579,8 +579,10 @@ Status DeviceTracerImpl::Collect(StepStatsCollector *collector) { // TODO(pbar) Handle device IDs and prefix properly. const string prefix = ""; const int id = 0; - const string stream_device = strings::StrCat(prefix, "/device:GPU:", id, "/stream:"); - const string memcpy_device = strings::StrCat(prefix, "/device:GPU:", id, "/memcpy"); + const string stream_device = + strings::StrCat(prefix, "/device:GPU:", id, "/stream:"); + const string memcpy_device = + strings::StrCat(prefix, "/device:GPU:", id, "/memcpy"); mutex_lock l2(trace_mu_); for (const auto &rec : kernel_records_) { diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc index 82bd69f9ca46eb1b8dd586d18ed852a2e8c5084e..2b874da1981bed396330ca3c526d82779046bdf2 100644 --- a/tensorflow/core/platform/default/logging.cc +++ b/tensorflow/core/platform/default/logging.cc @@ -83,15 +83,14 @@ void LogMessage::GenerateLogMessage() { const size_t time_buffer_size = 30; char time_buffer[time_buffer_size]; strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", - localtime(&now_seconds)); + localtime(&now_seconds)); // TODO(jeff,sanjay): Replace this with something that logs through the env. fprintf(stderr, "%s.%06d: %c %s:%d] %s\n", time_buffer, micros_remainder, - "IWEF"[severity_], fname_, line_, str().c_str()); + "IWEF"[severity_], fname_, line_, str().c_str()); } #endif - namespace { // Parse log level (int64) from environment variable (char*) diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h index 40c260f236613e533e30dc006e77b02f393bdd48..f0efa31d5576393e9d9bba6e39a454b2a33cddc3 100644 --- a/tensorflow/core/platform/default/logging.h +++ b/tensorflow/core/platform/default/logging.h @@ -19,8 +19,8 @@ limitations under the License. // IWYU pragma: private, include "third_party/tensorflow/core/platform/logging.h" // IWYU pragma: friend third_party/tensorflow/core/platform/logging.h -#include #include +#include #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -205,16 +205,18 @@ string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) { inline string* name##Impl(int v1, int v2, const char* exprtext) { \ return name##Impl(v1, v2, exprtext); \ } \ - inline string* name##Impl(const size_t v1, const int v2, const char* exprtext) { \ + inline string* name##Impl(const size_t v1, const int v2, \ + const char* exprtext) { \ if (TF_PREDICT_FALSE(v2 < 0)) { \ - return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext);\ + return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext); \ } \ const size_t uval = (size_t)((unsigned)v1); \ return name##Impl(uval, v2, exprtext); \ } \ - inline string* name##Impl(const int v1, const size_t v2, const char* exprtext) { \ - if (TF_PREDICT_FALSE(v2 >= std::numeric_limits::max())) { \ - return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext);\ + inline string* name##Impl(const int v1, const size_t v2, \ + const char* exprtext) { \ + if (TF_PREDICT_FALSE(v2 >= std::numeric_limits::max())) { \ + return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext); \ } \ const size_t uval = (size_t)((unsigned)v2); \ return name##Impl(v1, uval, exprtext); \ @@ -225,12 +227,12 @@ string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) { // This happens if, for example, those are used as token names in a // yacc grammar. TF_DEFINE_CHECK_OP_IMPL(Check_EQ, - == ) // Compilation error with CHECK_EQ(NULL, x)? -TF_DEFINE_CHECK_OP_IMPL(Check_NE, != ) // Use CHECK(x == NULL) instead. -TF_DEFINE_CHECK_OP_IMPL(Check_LE, <= ) -TF_DEFINE_CHECK_OP_IMPL(Check_LT, < ) -TF_DEFINE_CHECK_OP_IMPL(Check_GE, >= ) -TF_DEFINE_CHECK_OP_IMPL(Check_GT, > ) + ==) // Compilation error with CHECK_EQ(NULL, x)? +TF_DEFINE_CHECK_OP_IMPL(Check_NE, !=) // Use CHECK(x == NULL) instead. +TF_DEFINE_CHECK_OP_IMPL(Check_LE, <=) +TF_DEFINE_CHECK_OP_IMPL(Check_LT, <) +TF_DEFINE_CHECK_OP_IMPL(Check_GE, >=) +TF_DEFINE_CHECK_OP_IMPL(Check_GT, >) #undef TF_DEFINE_CHECK_OP_IMPL // In optimized mode, use CheckOpString to hint to compiler that diff --git a/tensorflow/core/platform/denormal.cc b/tensorflow/core/platform/denormal.cc index f13b0af2a79bec4538c64cbc475681f6eb0ce127..3631d9ddf99430372c11403dba56c14331a3db24 100644 --- a/tensorflow/core/platform/denormal.cc +++ b/tensorflow/core/platform/denormal.cc @@ -40,36 +40,51 @@ limitations under the License. namespace tensorflow { namespace port { -ScopedFlushDenormal::ScopedFlushDenormal() { -// For now, we flush denormals only on SSE 3. Other architectures such as ARM -// can be added as needed. +static void SetDenormalState(bool flush_zero_mode, bool denormals_zero_mode) { + // For now, we flush denormals only on SSE 3. Other architectures such as ARM + // can be added as needed. #ifdef DENORM_USE_INTRINSICS if (TestCPUFeature(SSE3)) { - // Save existing flags - flush_zero_mode_ = _MM_GET_FLUSH_ZERO_MODE() == _MM_FLUSH_ZERO_ON; - denormals_zero_mode_ = - _MM_GET_DENORMALS_ZERO_MODE() == _MM_DENORMALS_ZERO_ON; - - // Flush denormals to zero (the FTZ flag). - _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); - - // Interpret denormal inputs as zero (the DAZ flag). - _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); + // Restore flags + _MM_SET_FLUSH_ZERO_MODE(flush_zero_mode ? _MM_FLUSH_ZERO_ON + : _MM_FLUSH_ZERO_OFF); + _MM_SET_DENORMALS_ZERO_MODE(denormals_zero_mode ? _MM_DENORMALS_ZERO_ON + : _MM_DENORMALS_ZERO_OFF); } #endif } -ScopedFlushDenormal::~ScopedFlushDenormal() { +static std::pair GetDernormalState() { + // For now, we flush denormals only on SSE 3. Other architectures such as ARM + // can be added as needed. + #ifdef DENORM_USE_INTRINSICS if (TestCPUFeature(SSE3)) { - // Restore flags - _MM_SET_FLUSH_ZERO_MODE(flush_zero_mode_ ? _MM_FLUSH_ZERO_ON - : _MM_FLUSH_ZERO_OFF); - _MM_SET_DENORMALS_ZERO_MODE(denormals_zero_mode_ ? _MM_DENORMALS_ZERO_ON - : _MM_DENORMALS_ZERO_OFF); + // Save existing flags + bool flush_zero_mode = _MM_GET_FLUSH_ZERO_MODE() == _MM_FLUSH_ZERO_ON; + bool denormals_zero_mode = + _MM_GET_DENORMALS_ZERO_MODE() == _MM_DENORMALS_ZERO_ON; + return {flush_zero_mode, denormals_zero_mode}; } #endif + return {false, false}; +} + +ScopedRestoreFlushDenormalState::ScopedRestoreFlushDenormalState() { + std::tie(flush_zero_mode_, denormals_zero_mode_) = GetDernormalState(); +} + +ScopedRestoreFlushDenormalState::~ScopedRestoreFlushDenormalState() { + SetDenormalState(flush_zero_mode_, denormals_zero_mode_); +} + +ScopedFlushDenormal::ScopedFlushDenormal() { + SetDenormalState(/*flush_zero_mode=*/true, /*denormals_zero_mode=*/true); +} + +ScopedDontFlushDenormal::ScopedDontFlushDenormal() { + SetDenormalState(/*flush_zero_mode=*/false, /*denormals_zero_mode=*/false); } } // namespace port diff --git a/tensorflow/core/platform/denormal.h b/tensorflow/core/platform/denormal.h index 5e34131a3b8d8ec5b74bf66add1567e4f5207a02..09bb0352a2f375fac73054ca516cee79905795c1 100644 --- a/tensorflow/core/platform/denormal.h +++ b/tensorflow/core/platform/denormal.h @@ -21,19 +21,41 @@ limitations under the License. namespace tensorflow { namespace port { +// Remembers the flush denormal state on construction and restores that same +// state on destruction. +class ScopedRestoreFlushDenormalState { + public: + ScopedRestoreFlushDenormalState(); + ~ScopedRestoreFlushDenormalState(); + + private: + bool flush_zero_mode_; + bool denormals_zero_mode_; + TF_DISALLOW_COPY_AND_ASSIGN(ScopedRestoreFlushDenormalState); +}; + // While this class is active, denormal floating point numbers are flushed // to zero. The destructor restores the original flags. class ScopedFlushDenormal { public: ScopedFlushDenormal(); - ~ScopedFlushDenormal(); private: - bool flush_zero_mode_; - bool denormals_zero_mode_; + ScopedRestoreFlushDenormalState restore_; TF_DISALLOW_COPY_AND_ASSIGN(ScopedFlushDenormal); }; +// While this class is active, denormal floating point numbers are not flushed +// to zero. The destructor restores the original flags. +class ScopedDontFlushDenormal { + public: + ScopedDontFlushDenormal(); + + private: + ScopedRestoreFlushDenormalState restore_; + TF_DISALLOW_COPY_AND_ASSIGN(ScopedDontFlushDenormal); +}; + } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/device_tracer_test.cc b/tensorflow/core/platform/device_tracer_test.cc index c0c08dabacbcb9fdbbfd9bdbe16bcfaea7328507..89f14e905afa4e2c10055f59721fe4cabf082781 100644 --- a/tensorflow/core/platform/device_tracer_test.cc +++ b/tensorflow/core/platform/device_tracer_test.cc @@ -77,7 +77,8 @@ class DeviceTracerTest : public ::testing::Test { Node* y_neg = test::graph::Unary(&graph, "Neg", i); y_neg_ = y_neg->name(); - y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); + y_neg->set_assigned_device_name( + "/job:localhost/replica:0/task:0/device:GPU:0"); test::graph::ToGraphDef(&graph, &def_); } diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc index 1bcca1243fb636b6cd75f2ec796f1f6c7ac364bb..12509c250eab9047b869694e930bf523a975a4f8 100644 --- a/tensorflow/core/platform/env.cc +++ b/tensorflow/core/platform/env.cc @@ -44,6 +44,9 @@ limitations under the License. namespace tensorflow { +// 128KB copy buffer +constexpr size_t kCopyFileBufferSize = 128 * 1024; + class FileSystemRegistryImpl : public FileSystemRegistry { public: Status Register(const string& scheme, Factory factory) override; @@ -278,6 +281,17 @@ Status Env::RenameFile(const string& src, const string& target) { return src_fs->RenameFile(src, target); } +Status Env::CopyFile(const string& src, const string& target) { + FileSystem* src_fs; + FileSystem* target_fs; + TF_RETURN_IF_ERROR(GetFileSystemForFile(src, &src_fs)); + TF_RETURN_IF_ERROR(GetFileSystemForFile(target, &target_fs)); + if (src_fs == target_fs) { + return src_fs->CopyFile(src, target); + } + return FileSystemCopyFile(src_fs, src, target_fs, target); +} + string Env::GetExecutablePath() { char exe_path[PATH_MAX] = {0}; #ifdef __APPLE__ @@ -406,6 +420,29 @@ Status WriteStringToFile(Env* env, const string& fname, return s; } +Status FileSystemCopyFile(FileSystem* src_fs, const string& src, + FileSystem* target_fs, const string& target) { + std::unique_ptr src_file; + TF_RETURN_IF_ERROR(src_fs->NewRandomAccessFile(src, &src_file)); + + std::unique_ptr target_file; + TF_RETURN_IF_ERROR(target_fs->NewWritableFile(target, &target_file)); + + uint64 offset = 0; + std::unique_ptr scratch(new char[kCopyFileBufferSize]); + Status s = Status::OK(); + while (s.ok()) { + StringPiece result; + s = src_file->Read(offset, kCopyFileBufferSize, &result, scratch.get()); + if (!(s.ok() || s.code() == error::OUT_OF_RANGE)) { + return s; + } + TF_RETURN_IF_ERROR(target_file->Append(result)); + offset += result.size(); + } + return target_file->Close(); +} + // A ZeroCopyInputStream on a RandomAccessFile. namespace { class FileStream : public ::tensorflow::protobuf::io::ZeroCopyInputStream { diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h index 557bfa87e50a85a6f9de86548931ea215d8ac7ff..4ce4e0b4e024d50ae2bd081ec7b8b155060d2a4a 100644 --- a/tensorflow/core/platform/env.h +++ b/tensorflow/core/platform/env.h @@ -214,6 +214,9 @@ class Env { /// replaced. Status RenameFile(const string& src, const string& target); + /// \brief Copy the src to target. + Status CopyFile(const string& src, const string& target); + /// \brief Returns the absolute path of the current executable. It resolves /// symlinks if there is any. string GetExecutablePath(); @@ -286,7 +289,7 @@ class Env { // "version" should be the version of the library or NULL // returns the name that LoadLibrary() can use virtual string FormatLibraryFileName(const string& name, - const string& version) = 0; + const string& version) = 0; private: // Returns a possible list of local temporary directories. @@ -353,6 +356,7 @@ class EnvWrapper : public Env { const string& version) override { return target_->FormatLibraryFileName(name, version); } + private: Env* target_; }; @@ -380,6 +384,11 @@ struct ThreadOptions { size_t guard_size = 0; // 0: use system default value }; +/// A utility routine: copy contents of `src` in file system `src_fs` +/// to `target` in file system `target_fs`. +Status FileSystemCopyFile(FileSystem* src_fs, const string& src, + FileSystem* target_fs, const string& target); + /// A utility routine: reads contents of named file into `*data` Status ReadFileToString(Env* env, const string& fname, string* data); diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc index 14755891fa2d3b916396c75c9647acafe66ec524..271d73f5f1a7bd3e1301520aed09cbafd89c8ebc 100644 --- a/tensorflow/core/platform/file_system.cc +++ b/tensorflow/core/platform/file_system.cc @@ -131,18 +131,19 @@ Status FileSystem::GetMatchingPaths(const string& pattern, if (children.empty()) continue; // This IsDirectory call can be expensive for some FS. Parallelizing it. children_dir_status.resize(children.size()); - ForEach(0, children.size(), [this, ¤t_dir, &children, &fixed_prefix, - &children_dir_status](int i) { - const string child_path = io::JoinPath(current_dir, children[i]); - // In case the child_path doesn't start with the fixed_prefix then - // we don't need to explore this path. - if (!StringPiece(child_path).starts_with(fixed_prefix)) { - children_dir_status[i] = - Status(tensorflow::error::CANCELLED, "Operation not needed"); - } else { - children_dir_status[i] = IsDirectory(child_path); - } - }); + ForEach(0, children.size(), + [this, ¤t_dir, &children, &fixed_prefix, + &children_dir_status](int i) { + const string child_path = io::JoinPath(current_dir, children[i]); + // In case the child_path doesn't start with the fixed_prefix then + // we don't need to explore this path. + if (!StringPiece(child_path).starts_with(fixed_prefix)) { + children_dir_status[i] = Status(tensorflow::error::CANCELLED, + "Operation not needed"); + } else { + children_dir_status[i] = IsDirectory(child_path); + } + }); for (int i = 0; i < children.size(); ++i) { const string child_path = io::JoinPath(current_dir, children[i]); // If the IsDirectory call was cancelled we bail. @@ -264,4 +265,8 @@ Status FileSystem::RecursivelyCreateDir(const string& dirname) { return Status::OK(); } +Status FileSystem::CopyFile(const string& src, const string& target) { + return FileSystemCopyFile(this, src, this, target); +} + } // namespace tensorflow diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h index d32efcea0967eea321d512f4d0f3218128f3d59b..3085b6958fd921ae124b885107e807f0a02e1d9d 100644 --- a/tensorflow/core/platform/file_system.h +++ b/tensorflow/core/platform/file_system.h @@ -189,6 +189,9 @@ class FileSystem { /// \brief Overwrites the target if it exists. virtual Status RenameFile(const string& src, const string& target) = 0; + /// \brief Copy the src to target. + virtual Status CopyFile(const string& src, const string& target); + /// \brief Translate an URI to a filename for the FileSystem implementation. /// /// The implementation in this class cleans up the path, removing diff --git a/tensorflow/core/platform/gif.h b/tensorflow/core/platform/gif.h index 9c72d34ff518abcabf773af607589fe8114beebf..ab095a35c93517c6527b55bd922dbeb46d695ca4 100644 --- a/tensorflow/core/platform/gif.h +++ b/tensorflow/core/platform/gif.h @@ -20,7 +20,8 @@ limitations under the License. #if defined(PLATFORM_GOOGLE) #include "tensorflow/core/platform/google/build_config/gif.h" -#elif defined(PLATFORM_POSIX)|| defined(PLATFORM_WINDOWS) ||defined(PLATFORM_POSIX_ANDROID) +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ + defined(PLATFORM_POSIX_ANDROID) #include #else #error Define the appropriate PLATFORM_ macro for this platform diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 0baeac09841073ad6013a4700646e82d5d97182f..74863293a32451e8881c93de468539b913169aaa 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -164,8 +164,9 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) { } else { hdfs_->hdfsBuilderSetNameNode(builder, nn.c_str()); } - // KERB_TICKET_CACHE_PATH will be deleted in the future, Because KRB5CCNAME is the build in - // environment variable of Kerberos, so KERB_TICKET_CACHE_PATH and related code are unnecessary. + // KERB_TICKET_CACHE_PATH will be deleted in the future, Because KRB5CCNAME is + // the build in environment variable of Kerberos, so KERB_TICKET_CACHE_PATH + // and related code are unnecessary. char* ticket_cache_path = getenv("KERB_TICKET_CACHE_PATH"); if (ticket_cache_path != nullptr) { hdfs_->hdfsBuilderSetKerbTicketCachePath(builder, ticket_cache_path); diff --git a/tensorflow/core/platform/jpeg.h b/tensorflow/core/platform/jpeg.h index edbcbd960a7d61970119bfb385f075e1d3ffb96f..1b5e633f0aad09850afa82bee59d45c7943bbd8a 100644 --- a/tensorflow/core/platform/jpeg.h +++ b/tensorflow/core/platform/jpeg.h @@ -20,7 +20,8 @@ limitations under the License. #if defined(PLATFORM_GOOGLE) #include "tensorflow/core/platform/google/build_config/jpeg.h" -#elif defined(PLATFORM_POSIX)|| defined(PLATFORM_WINDOWS) ||defined(PLATFORM_POSIX_ANDROID) +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ + defined(PLATFORM_POSIX_ANDROID) #include #include #include diff --git a/tensorflow/core/platform/mem.h b/tensorflow/core/platform/mem.h index dc389a8741501d27394ac559c95eaa73c2014afd..7bb9fc264fbf6ee3f20e9b2687c9ba52b6171ec4 100644 --- a/tensorflow/core/platform/mem.h +++ b/tensorflow/core/platform/mem.h @@ -59,6 +59,9 @@ void MallocExtension_ReleaseToSystem(std::size_t num_bytes); // routine, this routine returns 0. std::size_t MallocExtension_GetAllocatedSize(const void* p); +// Returns the amount of RAM available in kB, or INT64_MAX if unknown. +int64 AvailableRam(); + } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/platform.h b/tensorflow/core/platform/platform.h index 12120c4ab96ae8327864c46a8e0dc434b900e67e..0481b3687137c8b00fa84d33eb317a1a4f5be9df 100644 --- a/tensorflow/core/platform/platform.h +++ b/tensorflow/core/platform/platform.h @@ -43,10 +43,11 @@ limitations under the License. #elif defined(__arm__) #define PLATFORM_POSIX -// Require an outside macro to tell us if we're building for Raspberry Pi. -#if !defined(RASPBERRY_PI) +// Require an outside macro to tell us if we're building for Raspberry Pi or +// another ARM device that's not a mobile platform. +#if !defined(RASPBERRY_PI) && !defined(ARM_NON_MOBILE) #define IS_MOBILE_PLATFORM -#endif // !defined(RASPBERRY_PI) +#endif // !defined(RASPBERRY_PI) && !defined(ARM_NON_MOBILE) #else // If no platform specified, use: diff --git a/tensorflow/core/platform/png.h b/tensorflow/core/platform/png.h index 5b0203c343e6b1764a9cc8a7908919422d826bcb..dad18d72195953e78c6a169a19b9182ae6571485 100644 --- a/tensorflow/core/platform/png.h +++ b/tensorflow/core/platform/png.h @@ -20,7 +20,8 @@ limitations under the License. #if defined(PLATFORM_GOOGLE) #include "tensorflow/core/platform/google/build_config/png.h" -#elif defined(PLATFORM_POSIX)|| defined(PLATFORM_WINDOWS) ||defined(PLATFORM_POSIX_ANDROID) +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ + defined(PLATFORM_POSIX_ANDROID) #include #else #error Define the appropriate PLATFORM_ macro for this platform diff --git a/tensorflow/core/platform/posix/error.cc b/tensorflow/core/platform/posix/error.cc index cda6d7d8f9d6ad3e7f2c8fa56cc99a8dbe07fa00..2bb9443fb3c45e0cd4bb31a48539355747684b5f 100644 --- a/tensorflow/core/platform/posix/error.cc +++ b/tensorflow/core/platform/posix/error.cc @@ -73,19 +73,19 @@ error::Code ErrnoToCode(int err_number) { case ECHILD: // No child processes case EISCONN: // Socket is connected #if !defined(_WIN32) && !defined(__HAIKU__) - case ENOTBLK: // Block device required + case ENOTBLK: // Block device required #endif - case ENOTCONN: // The socket is not connected - case EPIPE: // Broken pipe + case ENOTCONN: // The socket is not connected + case EPIPE: // Broken pipe #if !defined(_WIN32) - case ESHUTDOWN: // Cannot send after transport endpoint shutdown + case ESHUTDOWN: // Cannot send after transport endpoint shutdown #endif - case ETXTBSY: // Text file busy + case ETXTBSY: // Text file busy code = error::FAILED_PRECONDITION; break; - case ENOSPC: // No space left on device + case ENOSPC: // No space left on device #if !defined(_WIN32) - case EDQUOT: // Disk quota exceeded + case EDQUOT: // Disk quota exceeded #endif case EMFILE: // Too many open files case EMLINK: // Too many links @@ -95,7 +95,7 @@ error::Code ErrnoToCode(int err_number) { case ENOMEM: // Not enough space case ENOSR: // No STREAM resources #if !defined(_WIN32) && !defined(__HAIKU__) - case EUSERS: // Too many users + case EUSERS: // Too many users #endif code = error::RESOURCE_EXHAUSTED; break; @@ -104,17 +104,17 @@ error::Code ErrnoToCode(int err_number) { case ERANGE: // Result too large code = error::OUT_OF_RANGE; break; - case ENOSYS: // Function not implemented - case ENOTSUP: // Operation not supported - case EAFNOSUPPORT: // Address family not supported + case ENOSYS: // Function not implemented + case ENOTSUP: // Operation not supported + case EAFNOSUPPORT: // Address family not supported #if !defined(_WIN32) - case EPFNOSUPPORT: // Protocol family not supported + case EPFNOSUPPORT: // Protocol family not supported #endif case EPROTONOSUPPORT: // Protocol not supported #if !defined(_WIN32) && !defined(__HAIKU__) case ESOCKTNOSUPPORT: // Socket type not supported #endif - case EXDEV: // Improper link + case EXDEV: // Improper link code = error::UNIMPLEMENTED; break; case EAGAIN: // Resource temporarily unavailable @@ -123,7 +123,7 @@ error::Code ErrnoToCode(int err_number) { case ECONNRESET: // Connection reset case EINTR: // Interrupted function call #if !defined(_WIN32) - case EHOSTDOWN: // Host is down + case EHOSTDOWN: // Host is down #endif case EHOSTUNREACH: // Host is unreachable case ENETDOWN: // Network is down @@ -139,7 +139,7 @@ error::Code ErrnoToCode(int err_number) { break; case EDEADLK: // Resource deadlock avoided #if !defined(_WIN32) - case ESTALE: // Stale file handle + case ESTALE: // Stale file handle #endif code = error::ABORTED; break; @@ -158,7 +158,7 @@ error::Code ErrnoToCode(int err_number) { case ENOMSG: // No message of the desired type case EPROTO: // Protocol error #if !defined(_WIN32) && !defined(__HAIKU__) - case EREMOTE: // Object is remote + case EREMOTE: // Object is remote #endif code = error::UNKNOWN; break; diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc index 614ee00b0133976e9fe49caf7c75a01194e10237..494acde803a778fb839a7444e4d5ac2fd094eb09 100644 --- a/tensorflow/core/platform/posix/port.cc +++ b/tensorflow/core/platform/posix/port.cc @@ -29,6 +29,7 @@ limitations under the License. #if defined(__linux__) && !defined(__ANDROID__) #include +#include #endif #include #include @@ -171,5 +172,16 @@ double NominalCPUFrequency() { #endif } +int64 AvailableRam() { +#if defined(__linux__) && !defined(__ANDROID__) + struct sysinfo info; + int err = sysinfo(&info); + if (err == 0) { + return info.freeram / 1024; + } +#endif + return INT64_MAX; +} + } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc index fb7a5a9995985fd09936472d4b6f8a45254f7312..9a8021565cbcc2a172a23439d2a7139108c0df39 100644 --- a/tensorflow/core/platform/posix/posix_file_system.cc +++ b/tensorflow/core/platform/posix/posix_file_system.cc @@ -18,6 +18,9 @@ limitations under the License. #include #include #include +#if !defined(__APPLE__) +#include +#endif #include #include #include @@ -34,6 +37,9 @@ limitations under the License. namespace tensorflow { +// 128KB of copy buffer +constexpr size_t kPosixCopyFileBufferSize = 128 * 1024; + // pread() based random-access class PosixRandomAccessFile : public RandomAccessFile { private: @@ -276,4 +282,70 @@ Status PosixFileSystem::RenameFile(const string& src, const string& target) { return result; } +Status PosixFileSystem::CopyFile(const string& src, const string& target) { + string translated_src = TranslateName(src); + struct stat sbuf; + if (stat(translated_src.c_str(), &sbuf) != 0) { + return IOError(src, errno); + } + int src_fd = open(translated_src.c_str(), O_RDONLY); + if (src_fd < 0) { + return IOError(src, errno); + } + string translated_target = TranslateName(target); + // O_WRONLY | O_CREAT: + // Open file for write and if file does not exist, create the file. + // S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH: + // Create the file with permission of 0644 + int target_fd = open(translated_target.c_str(), O_WRONLY | O_CREAT, + S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); + if (target_fd < 0) { + close(src_fd); + return IOError(target, errno); + } + int rc = 0; + off_t offset = 0; + std::unique_ptr buffer(new char[kPosixCopyFileBufferSize]); + while (offset < sbuf.st_size) { + // Use uint64 for safe compare SSIZE_MAX + uint64 chunk = sbuf.st_size - offset; + if (chunk > SSIZE_MAX) { + chunk = SSIZE_MAX; + } +#if defined(__linux__) && !defined(__ANDROID__) + rc = sendfile(target_fd, src_fd, &offset, static_cast(chunk)); +#else + if (chunk > kPosixCopyFileBufferSize) { + chunk = kPosixCopyFileBufferSize; + } + rc = read(src_fd, buffer.get(), static_cast(chunk)); + if (rc <= 0) { + break; + } + rc = write(target_fd, buffer.get(), static_cast(chunk)); + offset += chunk; +#endif + if (rc <= 0) { + break; + } + } + + Status result = Status::OK(); + if (rc < 0) { + result = IOError(target, errno); + } + + // Keep the error code + rc = close(target_fd); + if (rc < 0 && result == Status::OK()) { + result = IOError(target, errno); + } + rc = close(src_fd); + if (rc < 0 && result == Status::OK()) { + result = IOError(target, errno); + } + + return result; +} + } // namespace tensorflow diff --git a/tensorflow/core/platform/posix/posix_file_system.h b/tensorflow/core/platform/posix/posix_file_system.h index fe050fd5a0ee87efb339c9ee2b9e447fce803615..98ffa43b8acf8a10a4ace1bf11cc7d6f5e8a95a7 100644 --- a/tensorflow/core/platform/posix/posix_file_system.h +++ b/tensorflow/core/platform/posix/posix_file_system.h @@ -56,6 +56,8 @@ class PosixFileSystem : public FileSystem { Status GetFileSize(const string& fname, uint64* size) override; Status RenameFile(const string& src, const string& target) override; + + Status CopyFile(const string& src, const string& target) override; }; Status IOError(const string& context, int err_number); diff --git a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h index 8604b01c53ef69040a919dadda73df897e98b0e1..ce2069b004473a684a1882068d3479ed049c58d6 100644 --- a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h +++ b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h @@ -58,8 +58,8 @@ class AndroidArmV7ACpuUtilsHelper : public ICpuUtilsHelper { TF_DISALLOW_COPY_AND_ASSIGN(AndroidArmV7ACpuUtilsHelper); }; -} // profile_utils -} // tensorflow +} // namespace profile_utils +} // namespace tensorflow #endif // defined(__ANDROID__) && (__ANDROID_API__ >= 21) && // (defined(__ARM_ARCH_7A__) || defined(__aarch64__)) diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.cc b/tensorflow/core/platform/profile_utils/cpu_utils.cc index d3362690d7e08c8e88e8168b62c8134b6af5a319..02de7d1362bbfca645d07ee72165283351944b9b 100644 --- a/tensorflow/core/platform/profile_utils/cpu_utils.cc +++ b/tensorflow/core/platform/profile_utils/cpu_utils.cc @@ -28,15 +28,17 @@ namespace profile_utils { static ICpuUtilsHelper* cpu_utils_helper_instance_ = nullptr; -#if (defined(__powerpc__) || defined(__ppc__) && ( __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || (defined(__s390x__)) - /* static */ uint64 CpuUtils::GetCycleCounterFrequency() { - static const uint64 cpu_frequency = GetCycleCounterFrequencyImpl(); - return cpu_frequency; +#if (defined(__powerpc__) || \ + defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || \ + (defined(__s390x__)) +/* static */ uint64 CpuUtils::GetCycleCounterFrequency() { + static const uint64 cpu_frequency = GetCycleCounterFrequencyImpl(); + return cpu_frequency; } #else - /* static */ int64 CpuUtils::GetCycleCounterFrequency() { - static const int64 cpu_frequency = GetCycleCounterFrequencyImpl(); - return cpu_frequency; +/* static */ int64 CpuUtils::GetCycleCounterFrequency() { + static const int64 cpu_frequency = GetCycleCounterFrequencyImpl(); + return cpu_frequency; } #endif diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.h b/tensorflow/core/platform/profile_utils/cpu_utils.h index 5d215b4804dbee8cb785c99b09ec725101bacb4e..7b580c8bf606cdd9acf998fa21cb1d946e5e6ada 100644 --- a/tensorflow/core/platform/profile_utils/cpu_utils.h +++ b/tensorflow/core/platform/profile_utils/cpu_utils.h @@ -42,7 +42,7 @@ namespace profile_utils { class CpuUtils { public: // Constant for invalid frequency. - // This value is returned when the furequency is not obtained somehow. + // This value is returned when the frequency is not obtained somehow. static constexpr int64 INVALID_FREQUENCY = -1; static constexpr uint64 DUMMY_CYCLE_CLOCK = 1; @@ -94,16 +94,18 @@ class CpuUtils { #endif } - // Return cycle counter frequency. - // As this method caches the cpu frequency internally, - // the first call will incur overhead, but not subsequent calls. - #if (defined(__powerpc__) || defined(__ppc__) && ( __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || (defined(__s390x__)) - static uint64 GetCycleCounterFrequency(); - #else - static int64 GetCycleCounterFrequency(); - #endif +// Return cycle counter frequency. +// As this method caches the cpu frequency internally, +// the first call will incur overhead, but not subsequent calls. +#if (defined(__powerpc__) || \ + defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || \ + (defined(__s390x__)) + static uint64 GetCycleCounterFrequency(); +#else + static int64 GetCycleCounterFrequency(); +#endif - // Return micro secound per each clock + // Return micro second per each clock // As this method caches the cpu frequency internally, // the first call will incur overhead, but not subsequent calls. static double GetMicroSecPerClock(); diff --git a/tensorflow/core/platform/profile_utils/cpu_utils_test.cc b/tensorflow/core/platform/profile_utils/cpu_utils_test.cc index 5b11b684dd9833bf742faaeaa3e79d2b49a78c6d..eb8161fbfd5ddfc796edd66a9119ad70c3c1de8e 100644 --- a/tensorflow/core/platform/profile_utils/cpu_utils_test.cc +++ b/tensorflow/core/platform/profile_utils/cpu_utils_test.cc @@ -53,15 +53,17 @@ TEST_F(CpuUtilsTest, CheckGetCurrentClockCycle) { } TEST_F(CpuUtilsTest, CheckCycleCounterFrequency) { - #if (defined(__powerpc__) || defined(__ppc__) && ( __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || (defined(__s390x__)) - const uint64 cpu_frequency = CpuUtils::GetCycleCounterFrequency(); - CHECK_GT(cpu_frequency, 0); - CHECK_NE(cpu_frequency, unsigned(CpuUtils::INVALID_FREQUENCY)); - #else - const int64 cpu_frequency = CpuUtils::GetCycleCounterFrequency(); - CHECK_GT(cpu_frequency, 0); - CHECK_NE(cpu_frequency, CpuUtils::INVALID_FREQUENCY); - #endif +#if (defined(__powerpc__) || \ + defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) || \ + (defined(__s390x__)) + const uint64 cpu_frequency = CpuUtils::GetCycleCounterFrequency(); + CHECK_GT(cpu_frequency, 0); + CHECK_NE(cpu_frequency, unsigned(CpuUtils::INVALID_FREQUENCY)); +#else + const int64 cpu_frequency = CpuUtils::GetCycleCounterFrequency(); + CHECK_GT(cpu_frequency, 0); + CHECK_NE(cpu_frequency, CpuUtils::INVALID_FREQUENCY); +#endif if (DBG) { LOG(INFO) << "Cpu frequency = " << cpu_frequency; } diff --git a/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h b/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h index 51c54d50d1dadcf78e8263ce44b07c998b68c05c..11b739c0096b5b5fd498bb5c753a54c8b1628208 100644 --- a/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h +++ b/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h @@ -47,7 +47,7 @@ class ICpuUtilsHelper { TF_DISALLOW_COPY_AND_ASSIGN(ICpuUtilsHelper); }; -} // profile_utils -} // tensorflow +} // namespace profile_utils +} // namespace tensorflow #endif // TENSORFLOW_PLATFORM_PROFILEUTILS_I_CPU_UTILS_HELPER_H__ diff --git a/tensorflow/core/platform/protobuf_internal.h b/tensorflow/core/platform/protobuf_internal.h index 7d6e8f57a62e08a7897bdccdeb7033363b282bd4..2f151a5aee6af067e4536bb569b4c0799c831b98 100644 --- a/tensorflow/core/platform/protobuf_internal.h +++ b/tensorflow/core/platform/protobuf_internal.h @@ -45,8 +45,8 @@ Status ParseAny(const google::protobuf::Any& any, T* message, #ifdef TENSORFLOW_LITE_PROTOS if (any.type_url() != strings::StrCat("type.googleapis.com/", type_name)) { return errors::FailedPrecondition( - "Expected Any type_url for: ", type_name, ". Got: ", - string(any.type_url().data(), any.type_url().size()), "."); + "Expected Any type_url for: ", type_name, + ". Got: ", string(any.type_url().data(), any.type_url().size()), "."); } if (!message->ParseFromString(any.value())) { return errors::FailedPrecondition("Failed to unpack: ", diff --git a/tensorflow/core/platform/s3/aws_logging.cc b/tensorflow/core/platform/s3/aws_logging.cc index fbca0acc36b01fa91dece4bdd0d19b7059dc114e..44317f1a3e41831b903bd0044d53d1eba80168df 100644 --- a/tensorflow/core/platform/s3/aws_logging.cc +++ b/tensorflow/core/platform/s3/aws_logging.cc @@ -96,7 +96,7 @@ Aws::Utils::Logging::LogLevel ParseLogLevelFromEnv() { return log_level; } -} +} // namespace static bool initialized = false; static mutex s3_logging_mutex(LINKER_INITIALIZED); diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc index 2c0babe098f2e7a066338e5cb2a25aedf16db8d9..301fcb9dbf653d29f6ac5321332c8764adaad681 100644 --- a/tensorflow/core/platform/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -22,8 +22,10 @@ limitations under the License. #include #include #include +#include #include #include +#include #include #include #include @@ -128,8 +130,7 @@ Aws::Client::ClientConfiguration& GetDefaultClientConfig() { return cfg; }; - -void ShutdownClient(Aws::S3::S3Client *s3_client) { +void ShutdownClient(Aws::S3::S3Client* s3_client) { if (s3_client != nullptr) { delete s3_client; Aws::SDKOptions options; @@ -166,7 +167,7 @@ Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket, class S3RandomAccessFile : public RandomAccessFile { public: S3RandomAccessFile(const string& bucket, const string& object, - std::shared_ptr s3_client) + std::shared_ptr s3_client) : bucket_(bucket), object_(object), s3_client_(s3_client) {} Status Read(uint64 offset, size_t n, StringPiece* result, @@ -202,7 +203,7 @@ class S3RandomAccessFile : public RandomAccessFile { class S3WritableFile : public WritableFile { public: S3WritableFile(const string& bucket, const string& object, - std::shared_ptr s3_client) + std::shared_ptr s3_client) : bucket_(bucket), object_(object), s3_client_(s3_client), @@ -284,8 +285,8 @@ class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { } // namespace -S3FileSystem::S3FileSystem() : - s3_client_(nullptr, ShutdownClient), client_lock_() {} +S3FileSystem::S3FileSystem() + : s3_client_(nullptr, ShutdownClient), client_lock_() {} S3FileSystem::~S3FileSystem() {} @@ -305,8 +306,15 @@ std::shared_ptr S3FileSystem::GetS3Client() { }; Aws::InitAPI(options); - this->s3_client_ = std::shared_ptr( - new Aws::S3::S3Client(GetDefaultClientConfig())); + // The creation of S3Client disables virtual addressing: + // S3Client(clientConfiguration, signPayloads, useVirtualAdressing = true) + // The purpose is to address the issue encountered when there is an `.` + // in the bucket name. Due to TLS hostname validation or DNS rules, + // the bucket may not be resolved. Disabling of virtual addressing + // should address the issue. See GitHub issue 16397 for details. + this->s3_client_ = std::shared_ptr(new Aws::S3::S3Client( + GetDefaultClientConfig(), + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, false)); } return this->s3_client_; @@ -400,7 +408,8 @@ Status S3FileSystem::GetChildren(const string& dir, Aws::S3::Model::ListObjectsResult listObjectsResult; do { - auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); + auto listObjectsOutcome = + this->GetS3Client()->ListObjects(listObjectsRequest); if (!listObjectsOutcome.IsSuccess()) { string error = strings::StrCat( listObjectsOutcome.GetError().GetExceptionName().c_str(), ": ", @@ -473,7 +482,8 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) { .WithMaxKeys(1); listObjectsRequest.SetResponseStreamFactory( []() { return Aws::New(kS3FileSystemAllocationTag); }); - auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); + auto listObjectsOutcome = + this->GetS3Client()->ListObjects(listObjectsRequest); if (listObjectsOutcome.IsSuccess()) { if (listObjectsOutcome.GetResult().GetContents().size() > 0) { stats->length = 0; @@ -495,7 +505,7 @@ Status S3FileSystem::DeleteFile(const string& fname) { deleteObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str()); auto deleteObjectOutcome = - this->GetS3Client()->DeleteObject(deleteObjectRequest); + this->GetS3Client()->DeleteObject(deleteObjectRequest); if (!deleteObjectOutcome.IsSuccess()) { string error = strings::StrCat( deleteObjectOutcome.GetError().GetExceptionName().c_str(), ": ", @@ -542,7 +552,8 @@ Status S3FileSystem::DeleteDir(const string& dirname) { .WithMaxKeys(2); listObjectsRequest.SetResponseStreamFactory( []() { return Aws::New(kS3FileSystemAllocationTag); }); - auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); + auto listObjectsOutcome = + this->GetS3Client()->ListObjects(listObjectsRequest); if (listObjectsOutcome.IsSuccess()) { auto contents = listObjectsOutcome.GetResult().GetContents(); if (contents.size() > 1 || @@ -594,7 +605,8 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { Aws::S3::Model::ListObjectsResult listObjectsResult; do { - auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); + auto listObjectsOutcome = + this->GetS3Client()->ListObjects(listObjectsRequest); if (!listObjectsOutcome.IsSuccess()) { string error = strings::StrCat( listObjectsOutcome.GetError().GetExceptionName().c_str(), ": ", @@ -607,13 +619,15 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { Aws::String src_key = object.GetKey(); Aws::String target_key = src_key; target_key.replace(0, src_object.length(), target_object.c_str()); - Aws::String source = Aws::String(src_bucket.c_str()) + "/" + src_key; + Aws::String source = Aws::String(src_bucket.c_str()) + "/" + + Aws::Utils::StringUtils::URLEncode(src_key.c_str()); copyObjectRequest.SetBucket(target_bucket.c_str()); copyObjectRequest.SetKey(target_key); copyObjectRequest.SetCopySource(source); - auto copyObjectOutcome = this->GetS3Client()->CopyObject(copyObjectRequest); + auto copyObjectOutcome = + this->GetS3Client()->CopyObject(copyObjectRequest); if (!copyObjectOutcome.IsSuccess()) { string error = strings::StrCat( copyObjectOutcome.GetError().GetExceptionName().c_str(), ": ", diff --git a/tensorflow/core/platform/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h index 168b8007f3b60c60724682dd7fc4e95f8d15a413..31264be621d93c1efb68f7b0b49e28cb65b05de1 100644 --- a/tensorflow/core/platform/s3/s3_file_system.h +++ b/tensorflow/core/platform/s3/s3_file_system.h @@ -55,8 +55,21 @@ class S3FileSystem : public FileSystem { Status GetFileSize(const string& fname, uint64* size) override; Status RenameFile(const string& src, const string& target) override; + private: // Returns the member S3 client, initializing as-needed. + // When the client tries to access the object in S3, e.g., + // s3://bucket-name/path/to/object + // the behavior could be controlled by various environmental + // variables. + // By default S3 access regional endpoint, with region + // controlled by `AWS_REGION`. The endpoint could be overridden + // explicitly with `S3_ENDPOINT`. S3 uses HTTPS by default. + // If S3_USE_HTTPS=0 is specified, HTTP is used. Also, + // S3_VERIFY_SSL=0 could disable SSL verification in case + // HTTPS is used. + // This S3 Client does not support Virtual Hosted–Style Method + // for a bucket. std::shared_ptr GetS3Client(); std::shared_ptr s3_client_; diff --git a/tensorflow/core/platform/setround.cc b/tensorflow/core/platform/setround.cc index 0c66da09bb9aa1c892063be11c66aedaf75d7eb6..592626bfa17e691d1b10ddce5c7f0f31ed825861 100644 --- a/tensorflow/core/platform/setround.cc +++ b/tensorflow/core/platform/setround.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/core/platform/setround.h" - namespace tensorflow { namespace port { diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h index a6636225ccbbc8154e290cd7f1aa6cafe3d2027a..327237dba933230cb313dd06091d2ff2ca3cc4b2 100644 --- a/tensorflow/core/platform/test_benchmark.h +++ b/tensorflow/core/platform/test_benchmark.h @@ -60,7 +60,7 @@ class Benchmark { private: string name_; int num_args_; - std::vector> args_; + std::vector > args_; void (*fn0_)(int) = nullptr; void (*fn1_)(int, int) = nullptr; void (*fn2_)(int, int, int) = nullptr; diff --git a/tensorflow/core/platform/windows/cpu_info.h b/tensorflow/core/platform/windows/cpu_info.h index d6e78dbc8f9f25070d94141e46d35dcb8d727ef7..f20939d3c0ff02be30f19be170644fab44b6f45e 100644 --- a/tensorflow/core/platform/windows/cpu_info.h +++ b/tensorflow/core/platform/windows/cpu_info.h @@ -22,8 +22,10 @@ limitations under the License. // Byte order defines provided by gcc. MSVC doesn't define those so // we define them here. // We assume that all windows platform out there are little endian. +#if defined(_MSC_VER) && !defined(__clang__) #define __ORDER_LITTLE_ENDIAN__ 0x4d2 #define __ORDER_BIG_ENDIAN__ 0x10e1 #define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__ +#endif #endif // TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_ diff --git a/tensorflow/core/platform/windows/env.cc b/tensorflow/core/platform/windows/env.cc index 788a4bf4b1af74393099d1b590a1e589d9a07f25..41b264417071cadb5f70806b458ee2b46ebb2feb 100644 --- a/tensorflow/core/platform/windows/env.cc +++ b/tensorflow/core/platform/windows/env.cc @@ -24,9 +24,9 @@ limitations under the License. #undef LoadLibrary #undef ERROR +#include #include #include -#include #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/platform/load_library.h" @@ -53,8 +53,7 @@ class StdThread : public Thread { class WindowsEnv : public Env { public: - WindowsEnv() - : GetSystemTimePreciseAsFileTime_(NULL) { + WindowsEnv() : GetSystemTimePreciseAsFileTime_(NULL) { // GetSystemTimePreciseAsFileTime function is only available in the latest // versions of Windows. For that reason, we try to look it up in // kernel32.dll at runtime and use an alternative option if the function @@ -72,8 +71,8 @@ class WindowsEnv : public Env { } bool MatchPath(const string& path, const string& pattern) override { - std::wstring ws_path(WindowsFileSystem::Utf8ToWideChar(path)); - std::wstring ws_pattern(WindowsFileSystem::Utf8ToWideChar(pattern)); + std::wstring ws_path(WindowsFileSystem::Utf8ToWideChar(path)); + std::wstring ws_pattern(WindowsFileSystem::Utf8ToWideChar(pattern)); return PathMatchSpecW(ws_path.c_str(), ws_pattern.c_str()) == TRUE; } @@ -122,14 +121,14 @@ class WindowsEnv : public Env { SetThreadpoolTimer(timer, &FileDueTime, 0, 0); } - Status LoadLibrary(const char *library_filename, void** handle) override { + Status LoadLibrary(const char* library_filename, void** handle) override { std::string file_name = library_filename; std::replace(file_name.begin(), file_name.end(), '/', '\\'); std::wstring ws_file_name(WindowsFileSystem::Utf8ToWideChar(file_name)); HMODULE hModule = LoadLibraryExW(ws_file_name.c_str(), NULL, - LOAD_WITH_ALTERED_SEARCH_PATH); + LOAD_WITH_ALTERED_SEARCH_PATH); if (!hModule) { return errors::NotFound(file_name + " not found"); } @@ -138,31 +137,30 @@ class WindowsEnv : public Env { } Status GetSymbolFromLibrary(void* handle, const char* symbol_name, - void** symbol) override { + void** symbol) override { FARPROC found_symbol; found_symbol = GetProcAddress((HMODULE)handle, symbol_name); if (found_symbol == NULL) { return errors::NotFound(std::string(symbol_name) + " not found"); } - *symbol = (void **)found_symbol; + *symbol = (void**)found_symbol; return Status::OK(); } - string FormatLibraryFileName(const string& name, const string& version) - override { + string FormatLibraryFileName(const string& name, + const string& version) override { string filename; if (version.size() == 0) { filename = name + ".dll"; - } - else { + } else { filename = name + version + ".dll"; } return filename; } private: - typedef VOID(WINAPI * FnGetSystemTimePreciseAsFileTime)(LPFILETIME); + typedef VOID(WINAPI* FnGetSystemTimePreciseAsFileTime)(LPFILETIME); FnGetSystemTimePreciseAsFileTime GetSystemTimePreciseAsFileTime_; }; diff --git a/tensorflow/core/platform/windows/error.cc b/tensorflow/core/platform/windows/error.cc index 39e941a3834f7f7cd03e7791d43d56f190dc1fd6..291fc5003fb6bbc07274cdea72d73e92a453f363 100644 --- a/tensorflow/core/platform/windows/error.cc +++ b/tensorflow/core/platform/windows/error.cc @@ -21,7 +21,7 @@ namespace internal { std::string GetWindowsErrorMessage(DWORD err) { LPSTR buffer = NULL; DWORD flags = FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | - FORMAT_MESSAGE_IGNORE_INSERTS; + FORMAT_MESSAGE_IGNORE_INSERTS; FormatMessageA(flags, NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), reinterpret_cast(&buffer), 0, NULL); std::string message = buffer; diff --git a/tensorflow/core/platform/windows/error.h b/tensorflow/core/platform/windows/error.h index 026e0d5aa946f7c851dacc05a3306631e06886aa..ba643a0fa8f92f58fbd88ac00fba3f663bb7e0f2 100644 --- a/tensorflow/core/platform/windows/error.h +++ b/tensorflow/core/platform/windows/error.h @@ -24,9 +24,7 @@ namespace tensorflow { namespace internal { std::string GetWindowsErrorMessage(DWORD err); - -} } +} // namespace tensorflow #endif // TENSORFLOW_CORE_PLATFORM_WINDOWS_ERROR_H_ - diff --git a/tensorflow/core/platform/windows/integral_types.h b/tensorflow/core/platform/windows/integral_types.h index 4970b8ca6a1673dd24d2d445348fe5b337ae13be..46338a536dbc3541763e62954fee74b2a5a0700b 100644 --- a/tensorflow/core/platform/windows/integral_types.h +++ b/tensorflow/core/platform/windows/integral_types.h @@ -1,18 +1,18 @@ - /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ - +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_ #define TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_ diff --git a/tensorflow/core/platform/windows/net.cc b/tensorflow/core/platform/windows/net.cc index 46eb072d42592028859122a4cad3d9478a96476e..2ab558ab95cafd15b10f7b887c846b32ab7e4c47 100644 --- a/tensorflow/core/platform/windows/net.cc +++ b/tensorflow/core/platform/windows/net.cc @@ -26,7 +26,7 @@ limitations under the License. #undef ERROR -#pragma comment(lib,"Ws2_32.lib") +#pragma comment(lib, "Ws2_32.lib") namespace tensorflow { namespace internal { @@ -44,8 +44,8 @@ bool IsPortAvailable(int* port, bool is_tcp) { CHECK_GE(*port, 0); CHECK_LE(*port, 65535); if (sock == INVALID_SOCKET) { - LOG(ERROR) << "socket() failed: " << - GetWindowsErrorMessage(WSAGetLastError()); + LOG(ERROR) << "socket() failed: " + << GetWindowsErrorMessage(WSAGetLastError()); return false; } @@ -54,8 +54,8 @@ bool IsPortAvailable(int* port, bool is_tcp) { int result = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&one), sizeof(one)); if (result == SOCKET_ERROR) { - LOG(ERROR) << "setsockopt() failed: " << - GetWindowsErrorMessage(WSAGetLastError()); + LOG(ERROR) << "setsockopt() failed: " + << GetWindowsErrorMessage(WSAGetLastError()); closesocket(sock); return false; } @@ -66,8 +66,8 @@ bool IsPortAvailable(int* port, bool is_tcp) { addr.sin_port = htons((uint16_t)*port); result = bind(sock, (struct sockaddr*)&addr, sizeof(addr)); if (result == SOCKET_ERROR) { - LOG(WARNING) << "bind(port=" << *port << ") failed: " << - GetWindowsErrorMessage(WSAGetLastError()); + LOG(WARNING) << "bind(port=" << *port + << ") failed: " << GetWindowsErrorMessage(WSAGetLastError()); closesocket(sock); return false; } @@ -75,8 +75,8 @@ bool IsPortAvailable(int* port, bool is_tcp) { // Get the bound port number. result = getsockname(sock, (struct sockaddr*)&addr, &addr_len); if (result == SOCKET_ERROR) { - LOG(WARNING) << "getsockname() failed: " << - GetWindowsErrorMessage(WSAGetLastError()); + LOG(WARNING) << "getsockname() failed: " + << GetWindowsErrorMessage(WSAGetLastError()); closesocket(sock); return false; } diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index e327d53949caf7e2d30e6deba0be2848f010afc2..582b232054b850a2ef5ab8f47c089eb35a7bb3cf 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -149,8 +149,20 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) { string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { - // TODO(yuefengz): implement it for this platform. +#ifdef TENSORFLOW_USE_ABSL + return absl::base_internal::NominalCPUFrequency(); +#else return 1.0; +#endif +} + +int64 AvailableRam() { + MEMORYSTATUSEX statex; + statex.dwLength = sizeof(statex); + if (GlobalMemoryStatusEx(&statex)) { + return statex.ullAvailPhys / 1024; + } + return INT64_MAX; } } // namespace port diff --git a/tensorflow/core/platform/windows/subprocess.h b/tensorflow/core/platform/windows/subprocess.h index b65313363ed79ab327414179a9923ba2d436dd0b..66ec44885d52195b807f4957aec6d590324b2975 100644 --- a/tensorflow/core/platform/windows/subprocess.h +++ b/tensorflow/core/platform/windows/subprocess.h @@ -19,8 +19,7 @@ limitations under the License. namespace tensorflow { // SubProcess is not yet implemented for Windows. -class SubProcess { -}; +class SubProcess {}; } // namespace tensorflow diff --git a/tensorflow/core/platform/windows/test.cc b/tensorflow/core/platform/windows/test.cc index 0ffd02ff14849d77761e85c30388dc49a53c84db..584acad91b24fc6be9b93f71b7d44b0fba3cb2e8 100644 --- a/tensorflow/core/platform/windows/test.cc +++ b/tensorflow/core/platform/windows/test.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/net.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc index 604348fe03a01d44195ba8a8ff427ae3ef3a4137..b6b3722caae4dc0cdc0ddff91be479ab91a744b2 100644 --- a/tensorflow/core/platform/windows/windows_file_system.cc +++ b/tensorflow/core/platform/windows/windows_file_system.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include #include -#include #undef StrCat #include #include @@ -75,16 +75,16 @@ SSIZE_T pread(HANDLE hfile, char* src, size_t num_bytes, uint64_t offset) { if (TRUE == read_result) { result = bytes_read; } else if ((FALSE == read_result) && - ((last_error = GetLastError()) != ERROR_IO_PENDING)) { + ((last_error = GetLastError()) != ERROR_IO_PENDING)) { result = (last_error == ERROR_HANDLE_EOF) ? 0 : -1; } else { - if (ERROR_IO_PENDING == last_error) { // Otherwise bytes_read already has the result. - BOOL overlapped_result = ::GetOverlappedResult(hfile, &overlapped, - &bytes_read, TRUE); + if (ERROR_IO_PENDING == + last_error) { // Otherwise bytes_read already has the result. + BOOL overlapped_result = + ::GetOverlappedResult(hfile, &overlapped, &bytes_read, TRUE); if (FALSE == overlapped_result) { result = (::GetLastError() == ERROR_HANDLE_EOF) ? 0 : -1; - } - else { + } else { result = bytes_read; } } @@ -151,11 +151,11 @@ class WindowsWritableFile : public WritableFile { Status Append(const StringPiece& data) override { DWORD bytes_written = 0; DWORD data_size = static_cast(data.size()); - BOOL write_result = ::WriteFile(hfile_, data.data(), data_size, - &bytes_written, NULL); + BOOL write_result = + ::WriteFile(hfile_, data.data(), data_size, &bytes_written, NULL); if (FALSE == write_result) { - return IOErrorFromWindowsError( - "Failed to WriteFile: " + filename_, ::GetLastError()); + return IOErrorFromWindowsError("Failed to WriteFile: " + filename_, + ::GetLastError()); } assert(size_t(bytes_written) == data.size()); @@ -171,8 +171,8 @@ class WindowsWritableFile : public WritableFile { } if (FALSE == ::CloseHandle(hfile_)) { - return IOErrorFromWindowsError( - "CloseHandle failed for: " + filename_, ::GetLastError()); + return IOErrorFromWindowsError("CloseHandle failed for: " + filename_, + ::GetLastError()); } hfile_ = INVALID_HANDLE_VALUE; @@ -187,9 +187,7 @@ class WindowsWritableFile : public WritableFile { return Status::OK(); } - Status Sync() override { - return Flush(); - } + Status Sync() override { return Flush(); } }; class WinReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { @@ -204,7 +202,10 @@ class WinReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { public: WinReadOnlyMemoryRegion(const std::string& filename, HANDLE hfile, HANDLE hmap, const void* address, uint64 length) - : filename_(filename), hfile_(hfile), hmap_(hmap), address_(address), + : filename_(filename), + hfile_(hfile), + hmap_(hmap), + address_(address), length_(length) {} ~WinReadOnlyMemoryRegion() { @@ -238,9 +239,9 @@ Status WindowsFileSystem::NewRandomAccessFile( // almost all tests would work with a possible exception of fault_injection. DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; - HANDLE hfile = ::CreateFileW(ws_translated_fname.c_str(), GENERIC_READ, - share_mode, NULL, OPEN_EXISTING, file_flags, - NULL); + HANDLE hfile = + ::CreateFileW(ws_translated_fname.c_str(), GENERIC_READ, share_mode, NULL, + OPEN_EXISTING, file_flags, NULL); if (INVALID_HANDLE_VALUE == hfile) { string context = "NewRandomAccessFile failed to Create/Open: " + fname; @@ -258,9 +259,9 @@ Status WindowsFileSystem::NewWritableFile( result->reset(); DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; - HANDLE hfile = ::CreateFileW(ws_translated_fname.c_str(), GENERIC_WRITE, - share_mode, NULL, CREATE_ALWAYS, - FILE_ATTRIBUTE_NORMAL, NULL); + HANDLE hfile = + ::CreateFileW(ws_translated_fname.c_str(), GENERIC_WRITE, share_mode, + NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL); if (INVALID_HANDLE_VALUE == hfile) { string context = "Failed to create a NewWriteableFile: " + fname; @@ -278,9 +279,9 @@ Status WindowsFileSystem::NewAppendableFile( result->reset(); DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; - HANDLE hfile = ::CreateFileW(ws_translated_fname.c_str(), GENERIC_WRITE, - share_mode, NULL, OPEN_ALWAYS, - FILE_ATTRIBUTE_NORMAL, NULL); + HANDLE hfile = + ::CreateFileW(ws_translated_fname.c_str(), GENERIC_WRITE, share_mode, + NULL, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL); if (INVALID_HANDLE_VALUE == hfile) { string context = "Failed to create a NewAppendableFile: " + fname; @@ -316,9 +317,9 @@ Status WindowsFileSystem::NewReadOnlyMemoryRegionFromFile( file_flags |= FILE_FLAG_OVERLAPPED; DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; - HANDLE hfile = ::CreateFileW(ws_translated_fname.c_str(), GENERIC_READ, - share_mode, NULL, OPEN_EXISTING, file_flags, - NULL); + HANDLE hfile = + ::CreateFileW(ws_translated_fname.c_str(), GENERIC_READ, share_mode, NULL, + OPEN_EXISTING, file_flags, NULL); if (INVALID_HANDLE_VALUE == hfile) { return IOErrorFromWindowsError( @@ -345,28 +346,32 @@ Status WindowsFileSystem::NewReadOnlyMemoryRegionFromFile( NULL); // Mapping name if (!hmap) { - string context = "Failed to create file mapping for " - "NewReadOnlyMemoryRegionFromFile: " + fname; + string context = + "Failed to create file mapping for " + "NewReadOnlyMemoryRegionFromFile: " + + fname; return IOErrorFromWindowsError(context, ::GetLastError()); } UniqueCloseHandlePtr map_guard(hmap, CloseHandleFunc); - const void* mapped_region = ::MapViewOfFileEx( - hmap, FILE_MAP_READ, - 0, // High DWORD of access start - 0, // Low DWORD - file_size, - NULL); // Let the OS choose the mapping + const void* mapped_region = + ::MapViewOfFileEx(hmap, FILE_MAP_READ, + 0, // High DWORD of access start + 0, // Low DWORD + file_size, + NULL); // Let the OS choose the mapping if (!mapped_region) { - string context = "Failed to MapViewOfFile for " - "NewReadOnlyMemoryRegionFromFile: " + fname; + string context = + "Failed to MapViewOfFile for " + "NewReadOnlyMemoryRegionFromFile: " + + fname; return IOErrorFromWindowsError(context, ::GetLastError()); } - result->reset(new WinReadOnlyMemoryRegion(fname, hfile, hmap, - mapped_region, file_size)); + result->reset(new WinReadOnlyMemoryRegion(fname, hfile, hmap, mapped_region, + file_size)); map_guard.release(); file_guard.release(); @@ -404,8 +409,8 @@ Status WindowsFileSystem::GetChildren(const string& dir, } do { - string file_name = WideCharToUtf8(find_data.cFileName); - const StringPiece basename = file_name; + string file_name = WideCharToUtf8(find_data.cFileName); + const StringPiece basename = file_name; if (basename != "." && basename != "..") { result->push_back(file_name); } @@ -457,8 +462,7 @@ Status WindowsFileSystem::GetFileSize(const string& fname, uint64* size) { file_size.HighPart = attrs.nFileSizeHigh; file_size.LowPart = attrs.nFileSizeLow; *size = file_size.QuadPart; - } - else { + } else { string context = "Can not get size for: " + fname; result = IOErrorFromWindowsError(context, ::GetLastError()); } @@ -472,7 +476,7 @@ Status WindowsFileSystem::RenameFile(const string& src, const string& target) { std::wstring ws_translated_src = Utf8ToWideChar(TranslateName(src)); std::wstring ws_translated_target = Utf8ToWideChar(TranslateName(target)); if (!::MoveFileExW(ws_translated_src.c_str(), ws_translated_target.c_str(), - MOVEFILE_REPLACE_EXISTING)) { + MOVEFILE_REPLACE_EXISTING)) { string context(strings::StrCat("Failed to rename: ", src, " to: ", target)); result = IOErrorFromWindowsError(context, ::GetLastError()); } diff --git a/tensorflow/core/platform/windows/windows_file_system.h b/tensorflow/core/platform/windows/windows_file_system.h index 8dcc1530370f0615ec45785a1f3d10ce828d11a3..ba0302f0fd8b56dabaf9271a725bebdac4716102 100644 --- a/tensorflow/core/platform/windows/windows_file_system.h +++ b/tensorflow/core/platform/windows/windows_file_system.h @@ -63,33 +63,35 @@ class WindowsFileSystem : public FileSystem { Status RenameFile(const string& src, const string& target) override; - string TranslateName(const string& name) const override { - return name; - } + string TranslateName(const string& name) const override { return name; } static std::wstring Utf8ToWideChar(const string& utf8str) { - int size_required = MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), (int)utf8str.size(), NULL, 0); - std::wstring ws_translated_str(size_required, 0); - MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), (int)utf8str.size(), &ws_translated_str[0], size_required); - return ws_translated_str; + int size_required = MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), + (int)utf8str.size(), NULL, 0); + std::wstring ws_translated_str(size_required, 0); + MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), (int)utf8str.size(), + &ws_translated_str[0], size_required); + return ws_translated_str; } - static string WideCharToUtf8(const std::wstring &wstr) { - if (wstr.empty()) return std::string(); - int size_required = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), (int)wstr.size(), NULL, 0, NULL, NULL); - string utf8_translated_str(size_required, 0); - WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), (int)wstr.size(), &utf8_translated_str[0], size_required, NULL, NULL); - return utf8_translated_str; + static string WideCharToUtf8(const std::wstring& wstr) { + if (wstr.empty()) return std::string(); + int size_required = WideCharToMultiByte( + CP_UTF8, 0, wstr.c_str(), (int)wstr.size(), NULL, 0, NULL, NULL); + string utf8_translated_str(size_required, 0); + WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), (int)wstr.size(), + &utf8_translated_str[0], size_required, NULL, NULL); + return utf8_translated_str; } }; class LocalWinFileSystem : public WindowsFileSystem { -public: - string TranslateName(const string& name) const override { - StringPiece scheme, host, path; - io::ParseURI(name, &scheme, &host, &path); - return path.ToString(); - } + public: + string TranslateName(const string& name) const override { + StringPiece scheme, host, path; + io::ParseURI(name, &scheme, &host, &path); + return path.ToString(); + } }; } // namespace tensorflow diff --git a/tensorflow/core/profiler/README.md b/tensorflow/core/profiler/README.md index 9e628b10651423a7ce05392e675453c87f8b6c8c..57d76eb4cb9382790c80a0d55ee94b64e7b9dcdc 100644 --- a/tensorflow/core/profiler/README.md +++ b/tensorflow/core/profiler/README.md @@ -240,8 +240,9 @@ Open a Chrome browser, enter URL chrome://tracing and load the timeline file. # can also generate memory profile using `-select bytes` tfprof> code -select accelerator_micros -max_depth 100000 -output pprof:outfile= -trim_name_regexes .*apply_op.* -# Use pprof to visualize the generated file. -pprof -png --nodecount=100 --sample_index=1 +# Use google-pprof, from the google-perftools package to visualize the generated file. +# On Ubuntu you can install it with `apt-get install it google-perftools`. +google-pprof --pdf --nodecount=100 ``` ![PprofGraph](g3doc/pprof.jpg) @@ -256,7 +257,7 @@ bug fix. `OpLogProto` is a good plus if it is used. #### Teams -* Xin Pan (xpan@google.com, github: panyx0718) +* Xin Pan * Chris Antaki * Yao Zhang * Jon Shlens diff --git a/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc b/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc index d05143aff9b8cc0b9a0e9af9445ba79345e4bf62..e968b9c97e28eeae22954102d5f0e07e09d75f7f 100644 --- a/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc +++ b/tensorflow/core/profiler/internal/advisor/tfprof_advisor_test.cc @@ -53,10 +53,13 @@ class TFProfAdvisorTest : public ::testing::Test { NodeExecStats node_stat; node_stat.set_all_start_micros(start_miros); node_stat.set_op_end_rel_micros(end_rel_micros); - node->AddStepStat(step, "/job:localhost/replica:0/task:0/device:GPU:0", node_stat); - node->AddStepStat(step, "/job:localhost/replica:0/task:0/device:GPU:0:stream:all", + node->AddStepStat(step, "/job:localhost/replica:0/task:0/device:GPU:0", node_stat); - node->AddStepStat(step, "/job:localhost/replica:0/task:0/device:GPU:0:stream:0", + node->AddStepStat(step, + "/job:localhost/replica:0/task:0/device:GPU:0:stream:all", + node_stat); + node->AddStepStat(step, + "/job:localhost/replica:0/task:0/device:GPU:0:stream:0", node_stat); return node; } diff --git a/tensorflow/core/profiler/internal/tfprof_op.cc b/tensorflow/core/profiler/internal/tfprof_op.cc index 5a8429d4893effc8bbfa0bf69e18b4a182e9a5df..3dce1d85db35436d162e73bf0946b320b899d5eb 100644 --- a/tensorflow/core/profiler/internal/tfprof_op.cc +++ b/tensorflow/core/profiler/internal/tfprof_op.cc @@ -113,8 +113,9 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, root_->formatted_str = FormatNode(root_.get(), root_.get(), opts); } if (timeline) { - fprintf(stderr, "op view doesn't support timeline yet. " - "Consider graph/scope/code view.\n"); + fprintf(stderr, + "op view doesn't support timeline yet. " + "Consider graph/scope/code view.\n"); return root_.get(); } if (cnodes_map_.empty()) { @@ -265,9 +266,9 @@ string TFOp::FormatNode(OpNode* node, OpNode* root, const Options& opts) const { double pct = 0.0; if (node->proto().total_parameters() > 0) { accu_pct = 100.0 * node->proto().total_parameters() / - root->proto().total_parameters(); - pct = 100.0 * node->proto().parameters() / - root->proto().total_parameters(); + root->proto().total_parameters(); + pct = + 100.0 * node->proto().parameters() / root->proto().total_parameters(); } attrs.push_back(strings::Printf( "%30s", @@ -282,9 +283,8 @@ string TFOp::FormatNode(OpNode* node, OpNode* root, const Options& opts) const { double pct = 0.0; if (node->proto().total_float_ops() > 0) { accu_pct = 100.0 * node->proto().total_float_ops() / - root->proto().total_float_ops(); - pct = 100.0 * node->proto().float_ops() / - root->proto().total_float_ops(); + root->proto().total_float_ops(); + pct = 100.0 * node->proto().float_ops() / root->proto().total_float_ops(); } attrs.push_back(strings::Printf( diff --git a/tensorflow/core/profiler/internal/tfprof_op.h b/tensorflow/core/profiler/internal/tfprof_op.h index fe1c3b2ae826783c1405b6151b82f153c05d2901..aa22182d36cac8d7e1f9fb3143beadfdfe0efce6 100644 --- a/tensorflow/core/profiler/internal/tfprof_op.h +++ b/tensorflow/core/profiler/internal/tfprof_op.h @@ -41,8 +41,7 @@ namespace tfprof { // to input ops. class TFOp : public TFMultiShow { public: - explicit TFOp() - : TFMultiShow() {} + explicit TFOp() : TFMultiShow() {} ~TFOp() override {} void AddNode(TFGraphNode* node) override; @@ -51,7 +50,7 @@ class TFOp : public TFMultiShow { private: const ShowMultiNode* ShowInternal(const Options& opts, - Timeline* timeline) override; + Timeline* timeline) override; int64 SearchRoot(const std::vector nodes, const std::vector& regexes); diff --git a/tensorflow/core/profiler/internal/tfprof_show.h b/tensorflow/core/profiler/internal/tfprof_show.h index 4d6de060705435c5346f6f49810b7dfc05d4530e..81b021549a49625cd5ba4a6ba8130f12cc7cf5f7 100644 --- a/tensorflow/core/profiler/internal/tfprof_show.h +++ b/tensorflow/core/profiler/internal/tfprof_show.h @@ -78,40 +78,43 @@ class TFShow { return nodes; } std::vector sorted_nodes = nodes; - std::sort(sorted_nodes.begin(), sorted_nodes.end(), [&opts](const T* n1, - const T* n2) { - if (n1->name() == kTFProfRoot) return true; - if (n2->name() == kTFProfRoot) return false; - bool name_cmp = n1->name() < n2->name(); - if (opts.order_by == kOrderBy[0]) { - return name_cmp; - } else if (opts.order_by == kOrderBy[1]) { - return n1->proto().total_requested_bytes() > - n2->proto().total_requested_bytes(); - } else if (opts.order_by == kOrderBy[2]) { - return n1->proto().total_peak_bytes() > n2->proto().total_peak_bytes(); - } else if (opts.order_by == kOrderBy[3]) { - return n1->proto().total_residual_bytes() > - n2->proto().total_residual_bytes(); - } else if (opts.order_by == kOrderBy[4]) { - return n1->proto().total_output_bytes() > - n2->proto().total_output_bytes(); - } else if (opts.order_by == kOrderBy[5]) { - return n1->proto().total_exec_micros() > - n2->proto().total_exec_micros(); - } else if (opts.order_by == kOrderBy[6]) { - return n1->proto().total_accelerator_exec_micros() > - n2->proto().total_accelerator_exec_micros(); - } else if (opts.order_by == kOrderBy[7]) { - return n1->proto().total_cpu_exec_micros() > - n2->proto().total_cpu_exec_micros(); - } else if (opts.order_by == kOrderBy[8]) { - return n1->proto().total_parameters() > n2->proto().total_parameters(); - } else if (opts.order_by == kOrderBy[9]) { - return n1->proto().total_float_ops() > n2->proto().total_float_ops(); - } - return name_cmp; - }); + std::sort(sorted_nodes.begin(), sorted_nodes.end(), + [&opts](const T* n1, const T* n2) { + if (n1->name() == kTFProfRoot) return true; + if (n2->name() == kTFProfRoot) return false; + bool name_cmp = n1->name() < n2->name(); + if (opts.order_by == kOrderBy[0]) { + return name_cmp; + } else if (opts.order_by == kOrderBy[1]) { + return n1->proto().total_requested_bytes() > + n2->proto().total_requested_bytes(); + } else if (opts.order_by == kOrderBy[2]) { + return n1->proto().total_peak_bytes() > + n2->proto().total_peak_bytes(); + } else if (opts.order_by == kOrderBy[3]) { + return n1->proto().total_residual_bytes() > + n2->proto().total_residual_bytes(); + } else if (opts.order_by == kOrderBy[4]) { + return n1->proto().total_output_bytes() > + n2->proto().total_output_bytes(); + } else if (opts.order_by == kOrderBy[5]) { + return n1->proto().total_exec_micros() > + n2->proto().total_exec_micros(); + } else if (opts.order_by == kOrderBy[6]) { + return n1->proto().total_accelerator_exec_micros() > + n2->proto().total_accelerator_exec_micros(); + } else if (opts.order_by == kOrderBy[7]) { + return n1->proto().total_cpu_exec_micros() > + n2->proto().total_cpu_exec_micros(); + } else if (opts.order_by == kOrderBy[8]) { + return n1->proto().total_parameters() > + n2->proto().total_parameters(); + } else if (opts.order_by == kOrderBy[9]) { + return n1->proto().total_float_ops() > + n2->proto().total_float_ops(); + } + return name_cmp; + }); return sorted_nodes; } diff --git a/tensorflow/core/profiler/internal/tfprof_show_multi.h b/tensorflow/core/profiler/internal/tfprof_show_multi.h index 2a2208d8e78efd5bc20d0db23e5fdaabbb3e8d5a..711d35f9753cf85f7f318a9ac3de40d6d2bf786e 100644 --- a/tensorflow/core/profiler/internal/tfprof_show_multi.h +++ b/tensorflow/core/profiler/internal/tfprof_show_multi.h @@ -50,7 +50,7 @@ class TFMultiShow { protected: virtual const ShowMultiNode* ShowInternal(const Options& opts, - Timeline* timeline) = 0; + Timeline* timeline) = 0; bool LookUpCheckPoint(const string& name, std::unique_ptr* tensor); diff --git a/tensorflow/core/profiler/internal/tfprof_timeline.h b/tensorflow/core/profiler/internal/tfprof_timeline.h index 4428ab571f84ff75499f24d78af2547d512a8c1c..baf3fb2bedb13e13b21940485ec439c19a97dd02 100644 --- a/tensorflow/core/profiler/internal/tfprof_timeline.h +++ b/tensorflow/core/profiler/internal/tfprof_timeline.h @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/profiler/internal/tfprof_node_show.h" +#include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { namespace tfprof { @@ -178,7 +178,6 @@ class Timeline { int64 step_; const string outfile_; int64 next_pid_ = 0; - int64 allocator_pid_ = -1; MemoryTracker mem_tracker_; ChromeTraceFormatter chrome_formatter_; std::map device_pids_; diff --git a/tensorflow/core/profiler/internal/tfprof_utils.cc b/tensorflow/core/profiler/internal/tfprof_utils.cc index 2813bb46fa44bc1ed04e7e8f5cd02737a81abad4..7712ebd926f1df2d65b7f7d732b55846654ed218 100644 --- a/tensorflow/core/profiler/internal/tfprof_utils.cc +++ b/tensorflow/core/profiler/internal/tfprof_utils.cc @@ -355,9 +355,6 @@ static const char* const kOpTypes = static const char* const kScope = "scope: The nodes in the model graph are organized by their names, which " "is hierarchical like filesystem."; -static const char* const kGraph = - "graph: The nodes in the model graph are organized by their operation " - "input and output."; static const char* const kCode = "code: When python trace is available, the nodes are python lines and " "their are organized by the python call stack."; diff --git a/tensorflow/core/profiler/profiler.cc b/tensorflow/core/profiler/profiler.cc index 2cc212d5898c15c0d066a477068f7c68fa244b54..808e3c853bec0efb9523ee413f3d5272a833358d 100644 --- a/tensorflow/core/profiler/profiler.cc +++ b/tensorflow/core/profiler/profiler.cc @@ -206,8 +206,12 @@ int Run(int argc, char** argv) { "graph_path,op_log_path,run_meta_path\n"); std::unique_ptr graph(new GraphDef()); if (!FLAGS_graph_path.empty()) { - TF_CHECK_OK( - ReadProtoFile(Env::Default(), FLAGS_graph_path, graph.get(), false)); + s = ReadProtoFile(Env::Default(), FLAGS_graph_path, graph.get(), false); + if (!s.ok()) { + fprintf(stderr, "Failed to read graph_path: %s\n", + s.ToString().c_str()); + return 1; + } } std::unique_ptr op_log(new OpLogProto()); diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index ccab69b9c04cad1fdd95f7ff4304fc60e2f459da..3606c5f127ce1f533d018e645b0a48c20e79cd8d 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -387,7 +387,7 @@ message RunOptions { // EXPERIMENTAL. Options used to initialize DebuggerState, if enabled. DebugOptions debug_options = 6; - // When enabled, causes tensor alllocation information to be included in + // When enabled, causes tensor allocation information to be included in // the error message when the Run() call fails because the allocator ran // out of memory (OOM). // diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index d3c3d432a3739a28a5703f3c1a2aa4cbc95f461c..0e9e202bc9a2d2368772c7fede9eb877d9d99023 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -29,7 +29,7 @@ message RewriterConfig { AGGRESSIVE = 3; } - // Optimize tensor layouts + // Optimize tensor layouts (default is ON) Toggle layout_optimizer = 1; // Fold constants (default is ON) Toggle constant_folding = 3; @@ -37,11 +37,13 @@ message RewriterConfig { Toggle arithmetic_optimization = 7; // Control dependency optimizations (default is ON). Toggle dependency_optimization = 8; + // Loop optimizations (default is OFF). + Toggle loop_optimization = 9; // If true, don't remove unnecessary ops from the graph bool disable_model_pruning = 2; enum MemOptType { - // The default setting (currently disabled) + // The default setting (SCHEDULING_HEURISTICS only) DEFAULT_MEM_OPT = 0; // Disabled in the meta-optimizer. NO_MEM_OPT = 1; diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 67da7bf4526235ae51eb172f8da9fc267cc12b98..7405e01e14494fb6e4e241f1a2b8bc33a4200fa7 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -19,7 +19,7 @@ limitations under the License. // TensorFlow uses semantic versioning, see http://semver.org/. #define TF_MAJOR_VERSION 1 -#define TF_MINOR_VERSION 5 +#define TF_MINOR_VERSION 6 #define TF_PATCH_VERSION 0 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1", diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc index 1eab7e3d024c181f260500686b9127dd76dbe206..3a5f1f83af8d2d2324f3139568aa69f204cf1248 100644 --- a/tensorflow/core/util/bcast.cc +++ b/tensorflow/core/util/bcast.cc @@ -69,9 +69,9 @@ BCast::BCast(const Vec& sx, const Vec& sy, const bool fewer_dims_optimization) { State curr = UNKNOWN; const int64 x_i = x[i]; // i-th dimension of x. const int64 y_i = y[i]; // i-th dimension of y. - int64 o_i; // i-th dimension of the output. - int64 bx_i; // i-th broadcast for x. - int64 by_i; // i-th broadcast for y. + int64 o_i; // i-th dimension of the output. + int64 bx_i; // i-th broadcast for x. + int64 by_i; // i-th broadcast for y. // Invariant: // o_i = x_i * bx_i = y_i * by_i if (x_i == y_i) { diff --git a/tensorflow/core/util/ctc/ctc_loss_calculator.h b/tensorflow/core/util/ctc/ctc_loss_calculator.h index be00895b0d3517fe06a852685f79f32e5a0b5167..dd1163310bf406b66bdd450ac6bf840272f7c592 100644 --- a/tensorflow/core/util/ctc/ctc_loss_calculator.h +++ b/tensorflow/core/util/ctc/ctc_loss_calculator.h @@ -130,13 +130,13 @@ Status CTCLossCalculator::CalculateLoss( for (int t = 1; t < num_time_steps; ++t) { if (inputs[t].rows() != batch_size) { return errors::InvalidArgument("Expected batch size at t: ", t, - " to be: ", batch_size, " but got: ", - inputs[t].rows()); + " to be: ", batch_size, + " but got: ", inputs[t].rows()); } if (inputs[t].cols() != num_classes) { return errors::InvalidArgument("Expected class count at t: ", t, - " to be: ", num_classes, " but got: ", - inputs[t].cols()); + " to be: ", num_classes, + " but got: ", inputs[t].cols()); } } @@ -282,8 +282,8 @@ Status CTCLossCalculator::PopulateLPrimes( LabelSequences* l_primes) const { // labels is a Label array of size batch_size if (labels.size() != batch_size) { - return errors::InvalidArgument("labels.size() != batch_size: ", - labels.size(), " vs. ", batch_size); + return errors::InvalidArgument( + "labels.size() != batch_size: ", labels.size(), " vs. ", batch_size); } *max_u_prime = 0; // keep track of longest l' modified label sequence. @@ -325,12 +325,13 @@ Status CTCLossCalculator::PopulateLPrimes( for (int l_i : l) { if (l_i < 0) { return errors::InvalidArgument( - "All labels must be nonnegative integers, batch: ", b, " labels: ", - str_util::Join(l, ",")); + "All labels must be nonnegative integers, batch: ", b, + " labels: ", str_util::Join(l, ",")); } else if (l_i >= num_classes) { return errors::InvalidArgument( - "No label may be greater than num_classes. ", "num_classes: ", - num_classes, ", batch: ", b, " labels: ", str_util::Join(l, ",")); + "No label may be greater than num_classes. ", + "num_classes: ", num_classes, ", batch: ", b, + " labels: ", str_util::Join(l, ",")); } } if (!ignore_longer_outputs_than_inputs) { diff --git a/tensorflow/core/util/cuda_device_functions.h b/tensorflow/core/util/cuda_device_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..f2d4e470c82d9a1480ac1bf7726a7a7a9ae08715 --- /dev/null +++ b/tensorflow/core/util/cuda_device_functions.h @@ -0,0 +1,635 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_ +#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_ + +/** + * Wrappers and helpers for CUDA device code. + * + * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide + * backwards compatibility, see go/volta-porting for details. + * Provides atomic operations on types that aren't natively supported. + */ + +#if GOOGLE_CUDA + +#include +#include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "cuda/include/cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace detail { + +// Helper for range-based for loop using 'delta' increments. +// Usage: see CudaGridRange?() functions below. +template +class CudaGridRange { + struct Iterator { + __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {} + __device__ T operator*() const { return index_; } + __device__ Iterator& operator++() { + index_ += delta_; + return *this; + } + __device__ bool operator!=(const Iterator& other) const { + bool greater = index_ > other.index_; + bool less = index_ < other.index_; + // Anything past an end iterator (delta_ == 0) is equal. + // In range-based for loops, this optimizes to 'return less'. + if (!other.delta_) { + return less; + } + if (!delta_) { + return greater; + } + return less || greater; + } + + private: + T index_; + const T delta_; + }; + + public: + __device__ CudaGridRange(T begin, T delta, T end) + : begin_(begin), delta_(delta), end_(end) {} + + __device__ Iterator begin() const { return Iterator{begin_, delta_}; } + __device__ Iterator end() const { return Iterator{end_, 0}; } + + private: + T begin_; + T delta_; + T end_; +}; + +} // namespace detail + +// Helper to visit indices in the range 0 <= i < count, using the x-coordinate +// of the global thread index. That is, each index i is visited by all threads +// with the same x-coordinate. +// Usage: for(int i : CudaGridRangeX(count)) { visit(i); } +template +__device__ detail::CudaGridRange CudaGridRangeX(T count) { + return detail::CudaGridRange(blockIdx.x * blockDim.x + threadIdx.x, + gridDim.x * blockDim.x, count); +} + +// Helper to visit indices in the range 0 <= i < count using the y-coordinate. +// Usage: for(int i : CudaGridRangeY(count)) { visit(i); } +template +__device__ detail::CudaGridRange CudaGridRangeY(T count) { + return detail::CudaGridRange(blockIdx.y * blockDim.y + threadIdx.y, + gridDim.y * blockDim.y, count); +} + +// Helper to visit indices in the range 0 <= i < count using the z-coordinate. +// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); } +template +__device__ detail::CudaGridRange CudaGridRangeZ(T count) { + return detail::CudaGridRange(blockIdx.z * blockDim.z + threadIdx.z, + gridDim.z * blockDim.z, count); +} + +// Mask for all 32 threads in a warp. +const unsigned kCudaWarpAll = 0xffffffff; + +// Returns the warp lane ID of the calling thread +__device__ inline unsigned CudaLaneId() { + unsigned int lane_id; + asm("mov.u32 %0, %%laneid;" : "=r"(lane_id)); + return lane_id; +} + +namespace detail { +// Returns true if mask is a valid parameter for __shfl*sync to return a well +// defined value, assuming the calling lane will read from src_lane as part of +// the shuffle operation. +// +// Specifically, returns true iff mask has the calling lane bit and the src_lane +// bit set, and the src_lane calls this function with the same mask value +// (required for the two threads to wait for each other). +// +// On Volta, for some invalid masks, this function hangs or returns false +// positives, because the implementation shuffles with the same mask that +// we are validating. Run on Pascal if you suspect that the mask is incorrect. +__device__ inline bool CudaValidateShuffleSyncMask(unsigned mask, + unsigned src_lane) { + unsigned src_dst_mask = 1u << CudaLaneId() | 1u << src_lane; +#if CUDA_VERSION >= 9000 + unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane); +#else + unsigned src_lane_mask = __shfl(mask, src_lane); +#endif + return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask; +} + +// Returns the actual source lane for shuffle. +__device__ inline unsigned CudaShuffleGetSrcLane(int src_lane, int width) { + int lane_id = CudaLaneId(); + int lane_base = lane_id & ~width + 1; + int lane_offset = src_lane & width - 1; + return lane_base + lane_offset; +} + +// Returns the source lane for shuffle up. +__device__ inline unsigned CudaShuffleUpGetSrcLane(unsigned delta, int width) { + unsigned lane_id = CudaLaneId(); + if ((lane_id & width - 1) < delta) { + return lane_id; + } + return lane_id - delta; +} + +// Returns the source lane for shuffle down. +__device__ inline unsigned CudaShuffleDownGetSrcLane(unsigned delta, + int width) { + unsigned lane_id = CudaLaneId(); + if ((lane_id & width - 1) + delta >= width) { + return lane_id; + } + return lane_id + delta; +} + +// Returns the source lane for shuffle xor. +__device__ inline unsigned CudaShuffleXorGetSrcLane(int lane_mask, int width) { + int lane_id = CudaLaneId(); + int src_lane = lane_id ^ lane_mask; + if (src_lane > (lane_id | width - 1)) { + return lane_id; + } + return src_lane; +} +} // namespace detail + +// For all *_sync wrappers below, it is illegal to synchronize threads from +// different program locations, because that is not supported before sm_70. +// In other words, all threads in 'mask' must call the functions in convergence. +// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly. +// +// It is also illegal to shuffle with a mask that produces an undefined result +// for any of the threads. Specifically, all source threads of the shuffle +// must have their corresponding bit in 'mask' set. + +// Wrapper for __syncwarp. No-op for CUDA 8 and earlier. +__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) { + assert(mask & 1u << CudaLaneId()); +#if CUDA_VERSION >= 9000 + __syncwarp(mask); +#endif +} + +// Wrapper for __ballot_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) { + assert(mask & 1u << CudaLaneId()); +#if CUDA_VERSION >= 9000 + return __ballot_sync(mask, pred); +#else + return __ballot(pred) & mask; // Apply mask to match __ballot_sync's spec. +#endif +} + +// Wrapper for __any_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +__device__ inline int CudaAnySync(unsigned mask, int pred) { + assert(mask & 1u << CudaLaneId()); +#if CUDA_VERSION >= 9000 + return __any_sync(mask, pred); +#else + return __any(pred); +#endif +} + +// Wrapper for __all_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +__device__ inline int CudaAllSync(unsigned mask, int pred) { + assert(mask & 1u << CudaLaneId()); +#if CUDA_VERSION >= 9000 + return __all_sync(mask, pred); +#else + return __all(pred); +#endif +} + +// Wrapper for __shfl_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +template +__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::CudaValidateShuffleSyncMask( + mask, detail::CudaShuffleGetSrcLane(src_lane, width))); +#if CUDA_VERSION >= 9000 + return __shfl_sync(mask, value, src_lane, width); +#else + return __shfl(value, src_lane, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. +__device__ inline double CudaShuffleSync(unsigned mask, double value, + int src_lane, int width = warpSize) { + unsigned lo, hi; + asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); + hi = CudaShuffleSync(mask, hi, src_lane, width); + lo = CudaShuffleSync(mask, lo, src_lane, width); + asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); + return value; +} + +// Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +template +__device__ inline T CudaShuffleUpSync(unsigned mask, T value, unsigned delta, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::CudaValidateShuffleSyncMask( + mask, detail::CudaShuffleUpGetSrcLane(delta, width))); +#if CUDA_VERSION >= 9000 + return __shfl_up_sync(mask, value, delta, width); +#else + return __shfl_up(value, delta, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. +__device__ inline double CudaShuffleUpSync(unsigned mask, double value, + unsigned delta, + int width = warpSize) { + unsigned lo, hi; + asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); + hi = CudaShuffleUpSync(mask, hi, delta, width); + lo = CudaShuffleUpSync(mask, lo, delta, width); + asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); + return value; +} + +// Wrapper for __shfl_down_sync. All threads in 'mask' must call this function +// in convergence, see comment above for details. +template +__device__ inline T CudaShuffleDownSync(unsigned mask, T value, unsigned delta, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::CudaValidateShuffleSyncMask( + mask, detail::CudaShuffleDownGetSrcLane(delta, width))); +#if CUDA_VERSION >= 9000 + return __shfl_down_sync(mask, value, delta, width); +#else + return __shfl_down(value, delta, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. +__device__ inline double CudaShuffleDownSync(unsigned mask, double value, + unsigned delta, + int width = warpSize) { + unsigned lo, hi; + asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); + hi = CudaShuffleDownSync(mask, hi, delta, width); + lo = CudaShuffleDownSync(mask, lo, delta, width); + asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); + return value; +} + +// Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +template +__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::CudaValidateShuffleSyncMask( + mask, detail::CudaShuffleXorGetSrcLane(lane_mask, width))); +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, lane_mask, width); +#else + return __shfl_xor(value, lane_mask, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. +__device__ inline double CudaShuffleXorSync(unsigned mask, double value, + int lane_mask, + int width = warpSize) { + unsigned lo, hi; + asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); + hi = CudaShuffleXorSync(mask, hi, lane_mask, width); + lo = CudaShuffleXorSync(mask, lo, lane_mask, width); + asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); + return value; +} + +// Wrapper for __ldg. +template +__host__ __device__ T CudaLdg(const T* address) { +#if __CUDA_ARCH__ >= 350 + return __ldg(address); +#else + return *address; +#endif +} + +__host__ __device__ inline bool CudaLdg(const bool* address) { + return CudaLdg(reinterpret_cast(address)) != 0; +} + +__host__ __device__ inline std::complex CudaLdg( + const std::complex* address) { +#if __CUDA_ARCH__ >= 350 + float2 mem = __ldg(reinterpret_cast(address)); + return std::complex(mem.x, mem.y); +#else + return *address; +#endif +} + +__host__ __device__ inline std::complex CudaLdg( + const std::complex* address) { +#if __CUDA_ARCH__ >= 350 + double2 mem = __ldg(reinterpret_cast(address)); + return std::complex(mem.x, mem.y); +#else + return *address; +#endif +} + +// Zeroes count elements starting at ptr using all threads of a 1-D grid. +// Note: this function does not synchronize, and therefore the memory range is +// not guaranteed to be zero until the next kernel launch. +template +__global__ void SetZero(const int count, T* ptr) { + // Check that the grid is one dimensional and index doesn't overflow. + assert(blockDim.y == 1 && blockDim.z == 1); + assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x); + for (int i : CudaGridRangeX(count)) { + ptr[i] = T(0); + } +} + +// Helper to set all tensor entries to a specific value. +template +__global__ void SetToValue(const int count, T* ptr, T value) { + // Check that the grid is one dimensional and index doesn't overflow. + assert(blockDim.y == 1 && blockDim.z == 1); + assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x); + for (int i : CudaGridRangeX(count)) { + ptr[i] = value; + } +} + +namespace detail { +// Helper function for atomic accumulation implemented as CAS. +template +__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) { + T old = *ptr; + T assumed; + do { + assumed = old; + old = atomicCAS(ptr, assumed, accumulate(assumed)); + } while (assumed != old); + return old; +} + +// Overload for floating point (using integer comparison to handle NaN +// correctly). +template +__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) { + return __float_as_int( + CudaAtomicCasHelper(reinterpret_cast(ptr), [accumulate](int32 a) { + return __float_as_int(accumulate(__int_as_float(a))); + })); +} +template +__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) { + return __longlong_as_double(CudaAtomicCasHelper( + reinterpret_cast(ptr), + [accumulate](tensorflow::uint64 a) { + return __double_as_longlong(accumulate(__longlong_as_double(a))); + })); +} + +// Overload of above function for half. Note that we don't have +// atomicCAS() for anything less than 32 bits, so we need to include the +// other 16 bits in the operation. +// +// This version is going to be very slow +// under high concurrency, since most threads will be spinning on failing +// their compare-and-swap tests. (The fact that we get false sharing on the +// neighboring fp16 makes this even worse.) If you are doing a large reduction, +// you are much better off with doing the intermediate steps in fp32 and then +// switching to fp16 as late as you can in the calculations. +// +// Note: Assumes little endian. +template +__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) { +#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) + static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian"); +#endif + namespace half_impl = Eigen::half_impl; + intptr_t intptr = reinterpret_cast(ptr); + assert(!(intptr & 0x1)); // should be 2-aligned. + if (intptr & 0x2) { + // The half is in the second part of the uint32 (upper 16 bits). + uint32* address = reinterpret_cast(intptr - 2); + uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) { + unsigned short high = static_cast(arg >> 16); + Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high)); + return (static_cast(acc.x) << 16) | (arg & 0xffff); + }); + return half_impl::raw_uint16_to_half(static_cast(result >> 16)); + } else { + // The half is in the first part of the uint32 (lower 16 bits). + uint32* address = reinterpret_cast(intptr); + uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) { + unsigned short low = static_cast(arg & 0xffff); + Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low)); + return (arg & 0xffff0000) | static_cast(acc.x); + }); + return half_impl::raw_uint16_to_half(static_cast(result & 0xffff)); + } +} + +template +using ToTypeIfConvertible = + typename std::enable_if::value, To>::type; + +} // namespace detail + +// CUDA provides atomic ops, but not for all types. We provide wrappers +// for some ops and provide implementation for all reasonable types. + +template +__device__ detail::ToTypeIfConvertible CudaAtomicAdd(T* ptr, U value) { + return atomicAdd(ptr, value); +} + +__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr, + Eigen::half value) { + return detail::CudaAtomicCasHelper( + ptr, [value](Eigen::half a) { return a + value; }); +} + + +#if __CUDA_ARCH__ < 600 +__device__ inline double CudaAtomicAdd(double* ptr, double value) { + return detail::CudaAtomicCasHelper(ptr, + [value](double a) { return a + value; }); +} +#elif __clang__ +// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX. +// see https://reviews.llvm.org/D39638 +__device__ inline double CudaAtomicAdd(double* ptr, double value) { + double result; + asm volatile("atom.add.f64 %0, [%1], %2;" + : "=d"(result) + : "l"(ptr), "d"(value) + : "memory"); + return result; +} +#endif +// CudaAtomicAdd +// Specializations of CudaAtomicAdd for complex types, which CudaAtomicAdd does +// not support. We treat a std::complex* as a T* (the C++ standard section +// 26.4.4 allows this explicitly) and atomic add the real and imaginary +// components individually. The operation as a whole is not atomic, but we can +// safely treat the components independently for the purpose of accumulating. +__device__ inline std::complex CudaAtomicAdd(std::complex* ptr, + std::complex value) { + auto ptr_scalar = reinterpret_cast(ptr); + return std::complex(CudaAtomicAdd(ptr_scalar, value.real()), + CudaAtomicAdd(ptr_scalar + 1, value.imag())); +} + +__device__ inline std::complex CudaAtomicAdd( + std::complex* ptr, std::complex value) { + auto ptr_scalar = reinterpret_cast(ptr); + return std::complex(CudaAtomicAdd(ptr_scalar, value.real()), + CudaAtomicAdd(ptr_scalar + 1, value.imag())); +} + +// CudaAtomicSub +template +__device__ detail::ToTypeIfConvertible CudaAtomicSub(T* ptr, U value) { + return atomicSub(ptr, value); +} + +// Specializations of substraction which add the negative value. +__device__ inline float CudaAtomicSub(float* ptr, float value) { + return CudaAtomicAdd(ptr, -value); +} + +__device__ inline double CudaAtomicSub(double* ptr, double value) { + return CudaAtomicAdd(ptr, -value); +} + +__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr, + tensorflow::uint64 value) { + return CudaAtomicAdd(ptr, -value); +} + +__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr, + Eigen::half value) { + return detail::CudaAtomicCasHelper( + ptr, [value](Eigen::half a) { return a - value; }); +} + +// CudaAtomicMax +template +__device__ detail::ToTypeIfConvertible CudaAtomicMax(T* ptr, U value) { + return atomicMax(ptr, value); +} + +__device__ inline float CudaAtomicMax(float* ptr, float value) { + return detail::CudaAtomicCasHelper( + ptr, [value](float a) { return max(a, value); }); +} + +__device__ inline double CudaAtomicMax(double* ptr, double value) { + return detail::CudaAtomicCasHelper( + ptr, [value](double a) { return max(a, value); }); +} + +__device__ inline Eigen::half CudaAtomicMax(Eigen::half* ptr, + Eigen::half value) { + return detail::CudaAtomicCasHelper( + ptr, [value](Eigen::half a) { return max(a, value); }); +} + +#if __CUDA_ARCH__ < 320 +__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr, + tensorflow::uint64 value) { + return detail::CudaAtomicCasHelper( + ptr, [value](tensorflow::uint64 a) { return max(a, value); }); +} +#endif + +// CudaAtomicMin +template +__device__ detail::ToTypeIfConvertible CudaAtomicMin(T* ptr, U value) { + return atomicMin(ptr, value); +} + +__device__ inline float CudaAtomicMin(float* ptr, float value) { + return detail::CudaAtomicCasHelper( + ptr, [value](float a) { return min(a, value); }); +} + +__device__ inline double CudaAtomicMin(double* ptr, double value) { + return detail::CudaAtomicCasHelper( + ptr, [value](double a) { return min(a, value); }); +} + +__device__ inline Eigen::half CudaAtomicMin(Eigen::half* ptr, + Eigen::half value) { + return detail::CudaAtomicCasHelper( + ptr, [value](Eigen::half a) { return min(a, value); }); +} + +#if __CUDA_ARCH__ < 320 +__device__ inline tensorflow::uint64 CudaAtomicMin(tensorflow::uint64* ptr, + tensorflow::uint64 value) { + return detail::CudaAtomicCasHelper( + ptr, [value](tensorflow::uint64 a) { return min(a, value); }); +} +#endif + +// CudaAtomicMul +template +__device__ detail::ToTypeIfConvertible CudaAtomicMul(T* ptr, U value) { + return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; }); +} + +// CudaAtomicDiv +template +__device__ detail::ToTypeIfConvertible CudaAtomicDiv(T* ptr, U value) { + return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; }); +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h index 3e32ec79731e1529affb49cf6e1aff3f23b84262..3c59524cb6f85911544b8f2d7d3339e19af7f5b4 100644 --- a/tensorflow/core/util/cuda_kernel_helper.h +++ b/tensorflow/core/util/cuda_kernel_helper.h @@ -18,299 +18,79 @@ limitations under the License. #if GOOGLE_CUDA -#include +#include "tensorflow/core/util/cuda_device_functions.h" +#include "tensorflow/core/util/cuda_launch_config.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "cuda/include/cuda.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor.h" -#include "tensorflow/core/platform/types.h" +// Deprecated, use 'for(int i : CudaGridRangeX(n))' instead. +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i : ::tensorflow::CudaGridRangeX(n)) +// Deprecated, use 'for(int i : CudaGridRange?(n))' instead. +#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \ + for (int i : ::tensorflow::CudaGridRange##axis(n)) -// Mask for all 32 threads in a warp. -#define CUDA_WARP_ALL 0xFFFFFFFF - -#if defined(CUDA_VERSION) && CUDA_VERSION < 9000 -// CUDA 9.0 introduces a new, light-weight barrier synchronization primitive -// that operates at the warp-scope. This is required to ensure visibility of -// reads/writes among threads that can make indepenent progress on Volta. -// For previous CUDA versions these synchronizations not necessary, and we -// define an empty function as a convenience for backward compatibility. -__device__ inline void __syncwarp(unsigned mask = CUDA_WARP_ALL) {} - -// CUDA 9.0 deprecates the warp-intrinsic functions (shfl, ballot, etc.) in -// favor of synchronizing versions. These ensure that all warp lanes specified -// in mask execute the intrinsic in convergence. Here we provide legacy mappings -// to the less-verbose routines provided in previous versions of CUDA. -#define __ballot_sync(mask, predicate) __ballot(predicate) -#define __shfl_sync(mask, val, srcLane, width) __shfl(val, srcLane, width) -#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width) -#define __shfl_up_sync(mask, val, delta, width) __shfl_up(val, delta, width) -#define __shfl_xor_sync(mask, val, laneMask, width) \ - __shfl_xor(val, laneMask, width) -#endif - -// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and -// GetCuda3DLaunchConfig: -// -// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one -// version uses heuristics without any knowledge of the device kernel, the other -// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical -// launch parameters that maximize occupancy. Currently, only the maximum -// occupancy version of GetCuda3DLaunchConfig is available. -// -// For large number of work elements, the convention is that each kernel would -// iterate through its assigned range. The return value of GetCudaLaunchConfig -// is struct CudaLaunchConfig, which contains all the information needed for the -// kernel launch, including: virtual number of threads, the number of threads -// per block and number of threads per block used inside <<< >>> of a kernel -// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing -// as CudaLaunchConfig. The only difference is the dimension. The macros -// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop. -// -/* Sample code: - -__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) { - CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) { - do_your_job_here; - } +namespace tensorflow { +__host__ __device__ inline tensorflow::bfloat16 CudaLdg( + const tensorflow::bfloat16* address) { + tensorflow::bfloat16 return_value; + return_value.value = CudaLdg(reinterpret_cast(address)); + return return_value; } -__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { - do_your_job_here; - } - } +template +__host__ __device__ inline T ldg(const T* ptr) { + return CudaLdg(ptr); } -__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { - CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { - do_your_job_here; - } - } - } +template +__host__ __device__ inline const T& tf_min(const T& x, const T& y) { + return x < y ? x : y; } -void MyDriverFunc(const GPUDevice &d) { - // use heuristics - CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d); - MyKernel1D <<>> (cfg1, other_args...); - Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d); - MyKernel2D <<>> (cfg2, other_args...); - Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d); - MyKernel3D <<>> (cfg3, other_args...); - - // maximize occupancy - CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 ); - MyKernel1D <<>> (cfg4, other_args...); - Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d, - MyKernel1D, 0, 0); - MyKernel2D <<>> (cfg5, other_args...); - Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d, - MyKernel1D, 0, 0); - MyKernel3D <<>> (cfg6, other_args...); +template +__host__ __device__ inline const T& tf_max(const T& x, const T& y) { + return x < y ? y : x; } -// See the test for this for more example: -// -https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc - -*/ - -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) - -#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \ - for (int i = blockIdx.axis * blockDim.axis + threadIdx.axis; i < n.axis; \ - i += blockDim.axis * gridDim.axis) - -#define DIV_UP(a, b) (((a) + (b)-1) / (b)) - -namespace tensorflow { - -typedef Eigen::GpuDevice GPUDevice; - -struct CudaLaunchConfig { - // Logical number of thread that works on the elements. If each logical - // thread works on exactly a single element, this is the same as the working - // element count. - int virtual_thread_count = -1; - // Number of threads per block. - int thread_per_block = -1; - // Number of blocks for Cuda kernel launch. - int block_count = -1; -}; - -// Calculate the Cuda launch config we should use for a kernel launch. -// This is assuming the kernel is quite simple and will largely be -// memory-limited. -// REQUIRES: work_element_count > 0. -inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, - const GPUDevice& d) { - CHECK_GT(work_element_count, 0); - CudaLaunchConfig config; - const int virtual_thread_count = work_element_count; - const int physical_thread_count = std::min( - d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(), - virtual_thread_count); - const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock()); - const int block_count = - std::min(DIV_UP(physical_thread_count, thread_per_block), - d.getNumCudaMultiProcessors()); - - config.virtual_thread_count = virtual_thread_count; - config.thread_per_block = thread_per_block; - config.block_count = block_count; - return config; +// Overloads of the above functions for float and double. +__host__ __device__ inline float tf_min(float x, float y) { + return fminf(x, y); } - -// Calculate the Cuda launch config we should use for a kernel launch. This -// variant takes the resource limits of func into account to maximize occupancy. -// REQUIRES: work_element_count > 0. -template -inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, - const GPUDevice& d, DeviceFunc func, - size_t dynamic_shared_memory_size, - int block_size_limit) { - CHECK_GT(work_element_count, 0); - CudaLaunchConfig config; - int block_count = 0; - int thread_per_block = 0; - - cudaError_t err = cudaOccupancyMaxPotentialBlockSize( - &block_count, &thread_per_block, func, dynamic_shared_memory_size, - block_size_limit); - CHECK_EQ(err, cudaSuccess); - - block_count = - std::min(block_count, DIV_UP(work_element_count, thread_per_block)); - - config.virtual_thread_count = work_element_count; - config.thread_per_block = thread_per_block; - config.block_count = block_count; - return config; +__host__ __device__ inline double tf_min(double x, double y) { + return fmin(x, y); } - -struct Cuda2DLaunchConfig { - dim3 virtual_thread_count = dim3(0, 0, 0); - dim3 thread_per_block = dim3(0, 0, 0); - dim3 block_count = dim3(0, 0, 0); -}; - -inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim, - const GPUDevice& d) { - Cuda2DLaunchConfig config; - - if (xdim <= 0 || ydim <= 0) { - return config; - } - - const int kThreadsPerBlock = 256; - int block_cols = std::min(xdim, kThreadsPerBlock); - // ok to round down here and just do more loops in the kernel - int block_rows = std::max(kThreadsPerBlock / block_cols, 1); - - const int physical_thread_count = - d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(); - - const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1); - - config.virtual_thread_count = dim3(xdim, ydim, 1); - config.thread_per_block = dim3(block_cols, block_rows, 1); - - int grid_x = std::min(DIV_UP(xdim, block_cols), max_blocks); - - config.block_count = dim3( - grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1); - return config; +__host__ __device__ inline float tf_max(float x, float y) { + return fmaxf(x, y); +} +__host__ __device__ inline double tf_max(double x, double y) { + return fmax(x, y); } -// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch. -// This variant takes the resource limits of func into account to maximize -// occupancy. -using Cuda3DLaunchConfig = Cuda2DLaunchConfig; - -template -inline Cuda3DLaunchConfig GetCuda3DLaunchConfig( - int xdim, int ydim, int zdim, const GPUDevice& d, DeviceFunc func, - size_t dynamic_shared_memory_size, int block_size_limit) { - Cuda3DLaunchConfig config; - - if (xdim <= 0 || ydim <= 0 || zdim <= 0) { - return config; - } - - int dev; - cudaGetDevice(&dev); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, dev); - int xthreadlimit = deviceProp.maxThreadsDim[0]; - int ythreadlimit = deviceProp.maxThreadsDim[1]; - int zthreadlimit = deviceProp.maxThreadsDim[2]; - int xgridlimit = deviceProp.maxGridSize[0]; - int ygridlimit = deviceProp.maxGridSize[1]; - int zgridlimit = deviceProp.maxGridSize[2]; - - int block_count = 0; - int thread_per_block = 0; - cudaError_t err = cudaOccupancyMaxPotentialBlockSize( - &block_count, &thread_per_block, func, dynamic_shared_memory_size, - block_size_limit); - CHECK_EQ(err, cudaSuccess); - -#define MIN3(a, b, c) std::min((a), std::min((b), (c))) - int threadsx = MIN3(xdim, thread_per_block, xthreadlimit); - int threadsy = - MIN3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit); - int threadsz = - MIN3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1), - zthreadlimit); - - int blocksx = MIN3(block_count, DIV_UP(xdim, threadsx), xgridlimit); - int blocksy = - MIN3(DIV_UP(block_count, blocksx), DIV_UP(ydim, threadsy), ygridlimit); - int blocksz = MIN3(DIV_UP(block_count, (blocksx * blocksy)), - DIV_UP(zdim, threadsz), zgridlimit); -#undef MIN3 +__device__ inline Eigen::half CudaShuffleSync(unsigned mask, Eigen::half value, + int src_lane, + int width = warpSize) { + return Eigen::half( + CudaShuffleSync(mask, static_cast(value), src_lane, width)); +} - config.virtual_thread_count = dim3(xdim, ydim, zdim); - config.thread_per_block = dim3(threadsx, threadsy, threadsz); - config.block_count = dim3(blocksx, blocksy, blocksz); - return config; +__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleUpSync( + unsigned mask, Eigen::half value, int delta, int width = warpSize) { + return Eigen::half( + CudaShuffleUpSync(mask, static_cast(value), delta, width)); } -template -inline Cuda2DLaunchConfig GetCuda2DLaunchConfig( - int xdim, int ydim, const GPUDevice& d, DeviceFunc func, - size_t dynamic_shared_memory_size, int block_size_limit) { - return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func, - dynamic_shared_memory_size, block_size_limit); +__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDownSync( + unsigned mask, Eigen::half value, int delta, int width = warpSize) { + return Eigen::half( + CudaShuffleDownSync(mask, static_cast(value), delta, width)); } -// Returns a raw reference to the current cuda stream. Required by a -// number of kernel calls (for which StreamInterface* does not work), i.e. -// CUB and certain cublas primitives. -inline const cudaStream_t& GetCudaStream(OpKernelContext* context) { - const cudaStream_t* ptr = CHECK_NOTNULL( - reinterpret_cast(context->op_device_context() - ->stream() - ->implementation() - ->CudaStreamMemberHack())); - return *ptr; +__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync( + unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) { + return Eigen::half( + CudaShuffleXorSync(mask, static_cast(value), lane_mask, width)); } namespace cuda_helper { - template __device__ IntType upper_bound(IntType* first, IntType count, IntType val) { IntType* orig = first; @@ -330,495 +110,8 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) { return first - orig; } - } // namespace cuda_helper - -template -__device__ __host__ inline T ldg(const T* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - return __ldg(address); -#else - return *address; -#endif -} - -template <> -__device__ __host__ inline std::complex ldg( - const std::complex* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - float2 mem = __ldg(reinterpret_cast(address)); - return std::complex(mem.x, mem.y); -#else - return *address; -#endif -} - -template <> -__device__ __host__ inline std::complex ldg( - const std::complex* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - double2 mem = __ldg(reinterpret_cast(address)); - return std::complex(mem.x, mem.y); -#else - return *address; -#endif -} - -template <> -__device__ __host__ inline Eigen::half ldg(const Eigen::half* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - return Eigen::half_impl::raw_uint16_to_half( - __ldg(reinterpret_cast(address))); -#else - return *address; -#endif -} - -template <> -__device__ __host__ inline tensorflow::bfloat16 ldg( - const tensorflow::bfloat16* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - tensorflow::bfloat16 return_value; - asm volatile("ld.global.nc.u16 %0, [%1];" - : "=h"(return_value.value) - : "l"(address)); - return return_value; -#else - return *address; -#endif -} - -template <> -__device__ __host__ inline bool ldg(const bool* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - return *reinterpret_cast( - __ldg(reinterpret_cast(address))); -#else - return *address; -#endif -} - -// CUDA provides atomic ops, but not for all types. We provide wrappers -// for some ops and provide implementation for all reasonable types. -#define CUDA_ATOMIC_WRAPPER(op, T) \ - __device__ __forceinline__ T CudaAtomic##op(T* address, T val) - -#define USE_CUDA_ATOMIC(op, T) \ - CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); } - -// For atomicAdd. -USE_CUDA_ATOMIC(Add, int32); -USE_CUDA_ATOMIC(Add, uint32); -USE_CUDA_ATOMIC(Add, uint64); -USE_CUDA_ATOMIC(Add, float); - -// For atomicMax. -USE_CUDA_ATOMIC(Max, int32); -USE_CUDA_ATOMIC(Max, uint32); -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 -USE_CUDA_ATOMIC(Max, uint64); -#else -// The uint64 overload of atomicMax() is only available for __CUDA_ARCH__ >= -// 350. If not satisfied, we provide a custom implementation using atomicCAS(). -CUDA_ATOMIC_WRAPPER(Max, uint64) { - uint64* address_as_ull = reinterpret_cast(address); - uint64 old = *address_as_ull, assumed; - - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, max(val, assumed)); - } while (assumed != old); - - return old; -} -#endif - -// Custom implementation of atomicAdd for double. -// This implementation is copied from CUDA manual. -CUDA_ATOMIC_WRAPPER(Add, double) { - uint64* address_as_ull = reinterpret_cast(address); - uint64 old = *address_as_ull, assumed; - - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val + __longlong_as_double(assumed))); - - // Note: uses integer comparison to avoid hang in case of NaN - } while (assumed != old); - - return __longlong_as_double(old); -} - -// Custom implementation of atomicAdd for std::complex. -// This implementation performs to atomic additions on the components. -CUDA_ATOMIC_WRAPPER(Add, std::complex) { -#if defined(__CUDA_ARCH__) -#if __CUDA_ARCH__ >= 350 - float2* addr_as_float2 = reinterpret_cast(address); - float2* val_as_float2 = reinterpret_cast(&val); - CudaAtomicAdd(&(addr_as_float2->x), val_as_float2->x); - CudaAtomicAdd(&(addr_as_float2->y), val_as_float2->y); -#else - static_assert(sizeof(std::complex) == 2 * sizeof(float), - "Unable to compile CudaAtomicAdd for complex64 because " - "sizeof(complex64) != 2*sizeof(float32)"); - float* addr_as_float = reinterpret_cast(address); - float* val_as_float = reinterpret_cast(&val); - CudaAtomicAdd(addr_as_float, *val_as_float); - CudaAtomicAdd(addr_as_float + 1, *(val_as_float + 1)); -#endif -#endif - return *address; -} - -// Custom implementation of atomicAdd for std::complex. -// This implementation performs to atomic additions on the components -// using the double atomic wrapper above. -CUDA_ATOMIC_WRAPPER(Add, complex128) { -#if defined(__CUDA_ARCH__) -#if __CUDA_ARCH__ >= 350 - double2* addr_as_double2 = reinterpret_cast(address); - double2* val_as_double2 = reinterpret_cast(&val); - CudaAtomicAdd(&(addr_as_double2->x), val_as_double2->x); - CudaAtomicAdd(&(addr_as_double2->y), val_as_double2->y); -#else - static_assert(sizeof(std::complex) == 2 * sizeof(double), - "Unable to compile CudaAtomicAdd for complex128 because " - "sizeof(complex128) != 2*sizeof(float64)"); - double* addr_as_double = reinterpret_cast(address); - double* val_as_double = reinterpret_cast(&val); - CudaAtomicAdd(addr_as_double, *val_as_double); - CudaAtomicAdd(addr_as_double + 1, *(val_as_double + 1)); -#endif -#endif - return *address; -} - -// Helper functions for CudaAtomicAdd(half*, half), below. -// -// Note that if __CUDA_ARCH__ >= 530, we could probably use __hadd2() -// for a more efficient implementation, assuming that adding -0.0 -// will never harm the neighboring value. In this version, we take special -// care to guarantee the bits of the untouched value are unchanged. -inline __device__ uint32 add_to_low_half(uint32 val, float x) { - Eigen::half low_half; - low_half.x = static_cast(val & 0xffffu); - low_half = static_cast(static_cast(low_half) + x); - return (val & 0xffff0000u) | low_half.x; -} - -inline __device__ uint32 add_to_high_half(uint32 val, float x) { - Eigen::half high_half; - high_half.x = static_cast(val >> 16); - high_half = static_cast(static_cast(high_half) + x); - return (val & 0xffffu) | (high_half.x << 16); -} - -// Custom implementation of atomicAdd for half. Note that we don't have -// atomicCAS() for anything less than 32 bits, so we need to include the -// other 16 bits in the operation. -// -// Unlike the other atomic adds, this version is going to be very slow -// under high concurrency, since most threads will be spinning on failing -// their compare-and-swap tests. (The fact that we get false sharing on the -// neighboring fp16 makes this even worse.) If you are doing a large reduction, -// you are much better off with doing the intermediate steps in fp32 and then -// switching to fp16 as late as you can in the calculations. -// -// Note: Assumes little endian. -CUDA_ATOMIC_WRAPPER(Add, Eigen::half) { - float val_as_float(val); - intptr_t address_int = reinterpret_cast(address); - if ((address_int & 0x2) == 0) { - // The half is in the first part of the uint32 (lower 16 bits). - uint32* address_as_uint32 = reinterpret_cast(address); - assert(((intptr_t)address_as_uint32 & 0x3) == 0); - uint32 old = *address_as_uint32, assumed; - - do { - assumed = old; - old = atomicCAS(address_as_uint32, assumed, - add_to_low_half(assumed, val_as_float)); - - // Note: uses integer comparison to avoid hang in case of NaN - } while (assumed != old); - - Eigen::half ret; - ret.x = old & 0xffffu; - return ret; - } else { - // The half is in the second part of the uint32 (upper 16 bits). - uint32* address_as_uint32 = reinterpret_cast(address_int - 2); - assert(((intptr_t)address_as_uint32 & 0x3) == 0); - uint32 old = *address_as_uint32, assumed; - - do { - assumed = old; - old = atomicCAS(address_as_uint32, assumed, - add_to_high_half(assumed, val_as_float)); - - // Note: uses integer comparison to avoid hang in case of NaN - } while (assumed != old); - - Eigen::half ret; - ret.x = old >> 16; - return ret; - } -} - -template -__global__ void SetZero(const int nthreads, T* bottom_diff) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = T(0); } -} - -// For atomicSub. - -// Custom implementation for sub by just negating the value. -#define WRAPPED_ATOMIC_SUB(T) \ - CUDA_ATOMIC_WRAPPER(Sub, T) { return CudaAtomicAdd(address, -val); } - -WRAPPED_ATOMIC_SUB(uint64); -WRAPPED_ATOMIC_SUB(int32); -WRAPPED_ATOMIC_SUB(uint32); -WRAPPED_ATOMIC_SUB(Eigen::half); -WRAPPED_ATOMIC_SUB(float); -WRAPPED_ATOMIC_SUB(double); - -CUDA_ATOMIC_WRAPPER(Sub, complex64) { - const std::complex Tneg(-val.real(), -val.imag()); - return CudaAtomicAdd(address, Tneg); -} - -CUDA_ATOMIC_WRAPPER(Sub, complex128) { - const std::complex Tneg(-val.real(), -val.imag()); - return CudaAtomicAdd(address, Tneg); -} - -#undef WRAPPED_ATOMIC_SUB - -// For atomicMul. -CUDA_ATOMIC_WRAPPER(Mul, int32) { - int32 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, val * assumed); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Mul, uint32) { - uint32 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, val * assumed); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Mul, uint64) { - uint64 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, val * assumed); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Mul, float) { - int32* address_as_int = reinterpret_cast(address); - int32 old = *address_as_int, assumed; - do { - assumed = old; - old = atomicCAS(address_as_int, assumed, - __float_as_int(val * __int_as_float(assumed))); - } while (assumed != old); - return __int_as_float(old); -} - -CUDA_ATOMIC_WRAPPER(Mul, double) { - uint64* address_as_ull = reinterpret_cast(address); - uint64 old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val * __longlong_as_double(assumed))); - } while (assumed != old); - return __longlong_as_double(old); -} - -// For atomicDiv. -CUDA_ATOMIC_WRAPPER(Div, int32) { - int32 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, assumed / val); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Div, uint32) { - uint32 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, assumed / val); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Div, uint64) { - uint64 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, assumed / val); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Div, float) { - int32* address_as_int = reinterpret_cast(address); - int32 old = *address_as_int, assumed; - do { - assumed = old; - old = atomicCAS(address_as_int, assumed, - __float_as_int(__int_as_float(assumed) / val)); - } while (assumed != old); - return __int_as_float(old); -} - -CUDA_ATOMIC_WRAPPER(Div, double) { - uint64* address_as_ull = reinterpret_cast(address); - uint64 old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(__longlong_as_double(assumed) / val)); - } while (assumed != old); - return __longlong_as_double(old); -} - -#undef USE_CUDA_ATOMIC -#undef CUDA_ATOMIC_WRAPPER - -template -EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_min(const T& x, const T& y) { - return x > y ? y : x; -} - -template -EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) { - return x < y ? y : x; -} - -__device__ EIGEN_ALWAYS_INLINE unsigned CudaBallot(unsigned mask, - int predicate) { - return __ballot_sync(mask, predicate); -} - -template -__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(unsigned mask, T value, - int srcLane, - int width = warpSize) { - return __shfl_sync(mask, value, srcLane, width); -} - -// Variant of the (undocumented) version from the CUDA SDK, but using unsigned -// instead of float for lo and hi (which is incorrect with ftz, for example). -// A bug has been filed with NVIDIA and will be fixed in the next CUDA release. -// TODO(csigg): remove when the bug is fixed in the next CUDA release. -__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(unsigned mask, double value, - int srcLane, - int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); - hi = __shfl_sync(mask, hi, srcLane, width); - lo = __shfl_sync(mask, lo, srcLane, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; -} - -template -__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(unsigned mask, T value, - int delta, - int width = warpSize) { - return __shfl_up_sync(mask, value, delta, width); -} - -// Variant of the (undocumented) version from the CUDA SDK, but using unsigned -// instead of float for lo and hi (which is incorrect with ftz, for example). -// A bug has been filed with NVIDIA and will be fixed in the next CUDA release. -// TODO(csigg): remove when the bug is fixed in the next CUDA release. -__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(unsigned mask, double value, - int delta, - int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); - hi = __shfl_up_sync(mask, hi, delta, width); - lo = __shfl_up_sync(mask, lo, delta, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; -} - -template -__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask, T value, - int delta, - int width = warpSize) { - return __shfl_down_sync(mask, value, delta, width); -} - -__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDown( - unsigned mask, Eigen::half value, int delta, int width = warpSize) { - return Eigen::half( - __shfl_down_sync(mask, static_cast(value), delta, width)); -} - -// Variant of the (undocumented) version from the CUDA SDK, but using unsigned -// instead of float for lo and hi (which is incorrect with ftz, for example). -// A bug has been filed with NVIDIA and will be fixed in the next CUDA release. -// TODO(csigg): remove when the bug is fixed in the next CUDA release. -__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(unsigned mask, - double value, int delta, - int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); - hi = __shfl_down_sync(mask, hi, delta, width); - lo = __shfl_down_sync(mask, lo, delta, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; -} - -template -__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask, T value, - int laneMask, - int width = warpSize) { - return __shfl_xor_sync(mask, value, laneMask, width); -} - -__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXor( - unsigned mask, Eigen::half value, int laneMask, int width = warpSize) { - return Eigen::half( - __shfl_xor_sync(mask, static_cast(value), laneMask, width)); -} - -// Variant of the (undocumented) version from the CUDA SDK, but using unsigned -// instead of float for lo and hi (which is incorrect with ftz, for example). -// A bug has been filed with NVIDIA and will be fixed in the next CUDA release. -// TODO(csigg): remove when the bug is fixed in the next CUDA release. -__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(unsigned mask, - double value, int laneMask, - int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); - hi = __shfl_xor_sync(mask, hi, laneMask, width); - lo = __shfl_xor_sync(mask, lo, laneMask, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; -} - } // namespace tensorflow -#undef DIV_UP - #endif // GOOGLE_CUDA - #endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ diff --git a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc index 6991554effd9088c04bfcb71f274b82408507463..732ed33ede17bc90d3301d3f1eee6302a96028d7 100644 --- a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc +++ b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc @@ -52,11 +52,11 @@ __global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) { } } __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { if (x < 0) { // x might overflow when testing extreme case break; } - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { if (y < 0) { // y might overflow when testing extreme case break; } @@ -66,15 +66,15 @@ __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) { } } __global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { if (x < 0) { // x might overflow when testing extreme case break; } - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { if (y < 0) { // y might overflow when testing extreme case break; } - CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { + CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { if (z < 0) { // z might overflow when testing extreme case break; } @@ -87,6 +87,44 @@ __global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) { } } +__global__ void CudaShuffleGetSrcLaneTest(unsigned* failure_count) { + unsigned lane_id = CudaLaneId(); + for (int width = warpSize; width > 1; width /= 2) { + auto check_result = [&](const char* op_name, int param, unsigned actual, + unsigned expected) { + if (actual != expected) { + printf("Cuda%sGetSrcLane(%d, %d) for lane %d returned %d, not %d\n", + op_name, param, width, lane_id, actual, expected); + CudaAtomicAdd(failure_count, 1); + } + }; + for (int src_lane = -warpSize; src_lane <= warpSize; ++src_lane) { + unsigned actual_lane = detail::CudaShuffleGetSrcLane(src_lane, width); + unsigned expect_lane = + CudaShuffleSync(kCudaWarpAll, lane_id, src_lane, width); + check_result("Shuffle", src_lane, actual_lane, expect_lane); + } + for (unsigned delta = 0; delta <= warpSize; ++delta) { + unsigned actual_lane = detail::CudaShuffleUpGetSrcLane(delta, width); + unsigned expect_lane = + CudaShuffleUpSync(kCudaWarpAll, lane_id, delta, width); + check_result("ShuffleUp", delta, actual_lane, expect_lane); + } + for (unsigned delta = 0; delta <= warpSize; ++delta) { + unsigned actual_lane = detail::CudaShuffleDownGetSrcLane(delta, width); + unsigned expect_lane = + CudaShuffleDownSync(kCudaWarpAll, lane_id, delta, width); + check_result("ShuffleDown", delta, actual_lane, expect_lane); + } + for (int lane_lane = warpSize; lane_lane > 0; lane_lane /= 2) { + unsigned actual_lane = detail::CudaShuffleXorGetSrcLane(lane_lane, width); + unsigned expect_lane = + CudaShuffleXorSync(kCudaWarpAll, lane_id, lane_lane, width); + check_result("ShuffleXor", lane_lane, actual_lane, expect_lane); + } + } +} + } // namespace class CudaLaunchConfigTest : public ::testing::Test { @@ -94,7 +132,7 @@ class CudaLaunchConfigTest : public ::testing::Test { const int bufsize = 1024; int* outbuf = nullptr; Eigen::CudaStreamDevice stream; - GPUDevice d = GPUDevice(&stream); + Eigen::GpuDevice d = Eigen::GpuDevice(&stream); virtual void SetUp() { cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize); @@ -111,27 +149,27 @@ class CudaLaunchConfigTest : public ::testing::Test { TEST_F(CudaLaunchConfigTest, GetCudaLaunchConfig) { CudaLaunchConfig cfg; - // test valid inputs - #define TEST_LAUNCH_PARAMETER(work_element_count) \ - cfg = GetCudaLaunchConfig(bufsize, d); \ - SetOutbufZero<<>> \ - (cfg, outbuf); \ - CUDA_ASSERT_SUCCESS \ - cfg = GetCudaLaunchConfig(work_element_count, d); \ - Count1D<<>> ( \ - cfg, bufsize, outbuf); \ - CUDA_EXPECT_SUCCESS \ - EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0));\ - \ - cfg = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ - SetOutbufZero<<>> \ - (cfg, outbuf); \ - CUDA_ASSERT_SUCCESS \ - cfg = GetCudaLaunchConfig(work_element_count, d, Count1D, 0, 0); \ - Count1D<<>> ( \ - cfg, bufsize, outbuf); \ - CUDA_EXPECT_SUCCESS \ - EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)) +// test valid inputs +#define TEST_LAUNCH_PARAMETER(work_element_count) \ + cfg = GetCudaLaunchConfig(bufsize, d); \ + SetOutbufZero<<>>( \ + cfg, outbuf); \ + CUDA_ASSERT_SUCCESS \ + cfg = GetCudaLaunchConfig(work_element_count, d); \ + Count1D<<>>( \ + cfg, bufsize, outbuf); \ + CUDA_EXPECT_SUCCESS \ + EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)); \ + \ + cfg = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ + SetOutbufZero<<>>( \ + cfg, outbuf); \ + CUDA_ASSERT_SUCCESS \ + cfg = GetCudaLaunchConfig(work_element_count, d, Count1D, 0, 0); \ + Count1D<<>>( \ + cfg, bufsize, outbuf); \ + CUDA_EXPECT_SUCCESS \ + EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)) TEST_LAUNCH_PARAMETER(128); TEST_LAUNCH_PARAMETER(129); @@ -143,7 +181,7 @@ TEST_F(CudaLaunchConfigTest, GetCudaLaunchConfig) { TEST_LAUNCH_PARAMETER(8192); TEST_LAUNCH_PARAMETER(123456); TEST_LAUNCH_PARAMETER(1 << 30); - #undef TEST_LAUNCH_PARAMETER +#undef TEST_LAUNCH_PARAMETER } bool operator==(const Cuda2DLaunchConfig& a, const Cuda2DLaunchConfig& b) { @@ -162,27 +200,27 @@ TEST_F(CudaLaunchConfigTest, GetCuda2DLaunchConfig) { Cuda2DLaunchConfig cfg; CudaLaunchConfig cfg1d; - // test valid inputs - #define TEST_LAUNCH_PARAMETER(dimx, dimy) \ - cfg1d = GetCudaLaunchConfig(bufsize, d); \ - SetOutbufZero<<>> \ - (cfg1d, outbuf);\ - CUDA_ASSERT_SUCCESS \ - cfg = GetCuda2DLaunchConfig(dimx, dimy, d); \ - Count2D<<>> ( \ - cfg, bufsize, outbuf); \ - CUDA_EXPECT_SUCCESS \ - EXPECT_EQ(dimx * dimy, std::accumulate(outbuf, outbuf + bufsize, 0)); \ - \ - cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ - SetOutbufZero<<>> \ - (cfg1d, outbuf);\ - CUDA_ASSERT_SUCCESS \ - cfg = GetCuda2DLaunchConfig(dimx, dimy, d, Count2D, 0, 0); \ - Count2D<<>> ( \ - cfg, bufsize, outbuf); \ - CUDA_EXPECT_SUCCESS \ - EXPECT_EQ(dimx * dimy, std::accumulate(outbuf, outbuf + bufsize, 0)) +// test valid inputs +#define TEST_LAUNCH_PARAMETER(dimx, dimy) \ + cfg1d = GetCudaLaunchConfig(bufsize, d); \ + SetOutbufZero<<>>( \ + cfg1d, outbuf); \ + CUDA_ASSERT_SUCCESS \ + cfg = GetCuda2DLaunchConfig(dimx, dimy, d); \ + Count2D<<>>( \ + cfg, bufsize, outbuf); \ + CUDA_EXPECT_SUCCESS \ + EXPECT_EQ(dimx* dimy, std::accumulate(outbuf, outbuf + bufsize, 0)); \ + \ + cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ + SetOutbufZero<<>>( \ + cfg1d, outbuf); \ + CUDA_ASSERT_SUCCESS \ + cfg = GetCuda2DLaunchConfig(dimx, dimy, d, Count2D, 0, 0); \ + Count2D<<>>( \ + cfg, bufsize, outbuf); \ + CUDA_EXPECT_SUCCESS \ + EXPECT_EQ(dimx* dimy, std::accumulate(outbuf, outbuf + bufsize, 0)) TEST_LAUNCH_PARAMETER(128, 128); TEST_LAUNCH_PARAMETER(129, 64); @@ -195,24 +233,24 @@ TEST_F(CudaLaunchConfigTest, GetCuda2DLaunchConfig) { TEST_LAUNCH_PARAMETER(123456, 12); TEST_LAUNCH_PARAMETER(1, 1 << 30); TEST_LAUNCH_PARAMETER(1 << 30, 1); - #undef TEST_LAUNCH_PARAMETER +#undef TEST_LAUNCH_PARAMETER } TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) { Cuda3DLaunchConfig cfg; CudaLaunchConfig cfg1d; - // test valid inputs - #define TEST_LAUNCH_PARAMETER(dimx, dimy, dimz) \ - cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ - SetOutbufZero<<>> \ - (cfg1d, outbuf);\ - CUDA_ASSERT_SUCCESS \ - cfg = GetCuda3DLaunchConfig(dimx, dimy, dimz, d, Count3D, 0, 0); \ - Count3D<<>> ( \ - cfg, bufsize, outbuf); \ - CUDA_EXPECT_SUCCESS \ - EXPECT_EQ(dimx * dimy * dimz, std::accumulate(outbuf, outbuf + bufsize, 0)) +// test valid inputs +#define TEST_LAUNCH_PARAMETER(dimx, dimy, dimz) \ + cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ + SetOutbufZero<<>>( \ + cfg1d, outbuf); \ + CUDA_ASSERT_SUCCESS \ + cfg = GetCuda3DLaunchConfig(dimx, dimy, dimz, d, Count3D, 0, 0); \ + Count3D<<>>( \ + cfg, bufsize, outbuf); \ + CUDA_EXPECT_SUCCESS \ + EXPECT_EQ(dimx* dimy* dimz, std::accumulate(outbuf, outbuf + bufsize, 0)) TEST_LAUNCH_PARAMETER(128, 128, 128); TEST_LAUNCH_PARAMETER(129, 64, 1024); @@ -226,7 +264,17 @@ TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) { TEST_LAUNCH_PARAMETER(1, 1, 1 << 30); TEST_LAUNCH_PARAMETER(1, 1 << 30, 1); TEST_LAUNCH_PARAMETER(1 << 30, 1, 1); - #undef TEST_LAUNCH_PARAMETER +#undef TEST_LAUNCH_PARAMETER +} + +TEST(CudaDeviceFunctionsTest, ShuffleGetSrcLane) { + unsigned* failure_count; + ASSERT_EQ(cudaMallocManaged(&failure_count, sizeof(unsigned)), cudaSuccess); + *failure_count = 0; + CudaShuffleGetSrcLaneTest<<<1, 32>>>(failure_count); + ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess); + ASSERT_EQ(*failure_count, 0); + cudaFree(failure_count); } } // namespace tensorflow diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/cuda_launch_config.h new file mode 100644 index 0000000000000000000000000000000000000000..81df7a51d703986b040b5d15e128139ae56c24fb --- /dev/null +++ b/tensorflow/core/util/cuda_launch_config.h @@ -0,0 +1,306 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_ +#define TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_ + +#if GOOGLE_CUDA + +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "cuda/include/cuda.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/types.h" + +// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and +// GetCuda3DLaunchConfig: +// +// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one +// version uses heuristics without any knowledge of the device kernel, the other +// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical +// launch parameters that maximize occupancy. Currently, only the maximum +// occupancy version of GetCuda3DLaunchConfig is available. +// +// For large number of work elements, the convention is that each kernel would +// iterate through its assigned range. The return value of GetCudaLaunchConfig +// is struct CudaLaunchConfig, which contains all the information needed for the +// kernel launch, including: virtual number of threads, the number of threads +// per block and number of threads per block used inside <<< >>> of a kernel +// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing +// as CudaLaunchConfig. The only difference is the dimension. The macros +// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop. +// +/* Sample code: + +__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) { + CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) { + do_your_job_here; + } +} + +__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) { + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + do_your_job_here; + } + } +} + +__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) { + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { + do_your_job_here; + } + } + } +} + +void MyDriverFunc(const Eigen::GpuDevice &d) { + // use heuristics + CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d); + MyKernel1D <<>> (cfg1, other_args...); + Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d); + MyKernel2D <<>> (cfg2, other_args...); + Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d); + MyKernel3D <<>> (cfg3, other_args...); + + // maximize occupancy + CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 ); + MyKernel1D <<>> (cfg4, other_args...); + Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d, + MyKernel1D, 0, 0); + MyKernel2D <<>> (cfg5, other_args...); + Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d, + MyKernel1D, 0, 0); + MyKernel3D <<>> (cfg6, other_args...); +} + +// See the test for this for more example: +// +https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc + +*/ + +namespace tensorflow { + +inline int DivUp(int a, int b) { return (a + b - 1) / b; } + +struct CudaLaunchConfig { + // Logical number of thread that works on the elements. If each logical + // thread works on exactly a single element, this is the same as the working + // element count. + int virtual_thread_count = -1; + // Number of threads per block. + int thread_per_block = -1; + // Number of blocks for Cuda kernel launch. + int block_count = -1; +}; + +// Calculate the Cuda launch config we should use for a kernel launch. +// This is assuming the kernel is quite simple and will largely be +// memory-limited. +// REQUIRES: work_element_count > 0. +inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, + const Eigen::GpuDevice& d) { + CHECK_GT(work_element_count, 0); + CudaLaunchConfig config; + const int virtual_thread_count = work_element_count; + const int physical_thread_count = std::min( + d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(), + virtual_thread_count); + const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock()); + const int block_count = + std::min(DivUp(physical_thread_count, thread_per_block), + d.getNumCudaMultiProcessors()); + + config.virtual_thread_count = virtual_thread_count; + config.thread_per_block = thread_per_block; + config.block_count = block_count; + return config; +} + +// Calculate the Cuda launch config we should use for a kernel launch. This +// variant takes the resource limits of func into account to maximize occupancy. +// REQUIRES: work_element_count > 0. +template +inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, + const Eigen::GpuDevice& d, + DeviceFunc func, + size_t dynamic_shared_memory_size, + int block_size_limit) { + CHECK_GT(work_element_count, 0); + CudaLaunchConfig config; + int block_count = 0; + int thread_per_block = 0; + + cudaError_t err = cudaOccupancyMaxPotentialBlockSize( + &block_count, &thread_per_block, func, dynamic_shared_memory_size, + block_size_limit); + CHECK_EQ(err, cudaSuccess); + + block_count = + std::min(block_count, DivUp(work_element_count, thread_per_block)); + + config.virtual_thread_count = work_element_count; + config.thread_per_block = thread_per_block; + config.block_count = block_count; + return config; +} + +// Calculate the Cuda launch config we should use for a kernel launch. This +// variant takes the resource limits of func into account to maximize occupancy. +// The returned launch config has thread_per_block set to fixed_block_size. +// REQUIRES: work_element_count > 0. +template +inline CudaLaunchConfig GetCudaLaunchConfigFixedBlockSize( + int work_element_count, const Eigen::GpuDevice& d, DeviceFunc func, + size_t dynamic_shared_memory_size, int fixed_block_size) { + CHECK_GT(work_element_count, 0); + CudaLaunchConfig config; + int block_count = 0; + + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &block_count, func, fixed_block_size, dynamic_shared_memory_size); + CHECK_EQ(err, cudaSuccess); + block_count = std::min(block_count * d.getNumCudaMultiProcessors(), + DivUp(work_element_count, fixed_block_size)); + + config.virtual_thread_count = work_element_count; + config.thread_per_block = fixed_block_size; + config.block_count = block_count; + return config; +} + +struct Cuda2DLaunchConfig { + dim3 virtual_thread_count = dim3(0, 0, 0); + dim3 thread_per_block = dim3(0, 0, 0); + dim3 block_count = dim3(0, 0, 0); +}; + +inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim, + const Eigen::GpuDevice& d) { + Cuda2DLaunchConfig config; + + if (xdim <= 0 || ydim <= 0) { + return config; + } + + const int kThreadsPerBlock = 256; + int block_cols = std::min(xdim, kThreadsPerBlock); + // ok to round down here and just do more loops in the kernel + int block_rows = std::max(kThreadsPerBlock / block_cols, 1); + + const int physical_thread_count = + d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(); + + const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1); + + config.virtual_thread_count = dim3(xdim, ydim, 1); + config.thread_per_block = dim3(block_cols, block_rows, 1); + + int grid_x = std::min(DivUp(xdim, block_cols), max_blocks); + + config.block_count = dim3( + grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1); + return config; +} + +// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch. +// This variant takes the resource limits of func into account to maximize +// occupancy. +using Cuda3DLaunchConfig = Cuda2DLaunchConfig; + +template +inline Cuda3DLaunchConfig GetCuda3DLaunchConfig( + int xdim, int ydim, int zdim, const Eigen::GpuDevice& d, DeviceFunc func, + size_t dynamic_shared_memory_size, int block_size_limit) { + Cuda3DLaunchConfig config; + + if (xdim <= 0 || ydim <= 0 || zdim <= 0) { + return config; + } + + int dev; + cudaGetDevice(&dev); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, dev); + int xthreadlimit = deviceProp.maxThreadsDim[0]; + int ythreadlimit = deviceProp.maxThreadsDim[1]; + int zthreadlimit = deviceProp.maxThreadsDim[2]; + int xgridlimit = deviceProp.maxGridSize[0]; + int ygridlimit = deviceProp.maxGridSize[1]; + int zgridlimit = deviceProp.maxGridSize[2]; + + int block_count = 0; + int thread_per_block = 0; + cudaError_t err = cudaOccupancyMaxPotentialBlockSize( + &block_count, &thread_per_block, func, dynamic_shared_memory_size, + block_size_limit); + CHECK_EQ(err, cudaSuccess); + + int threadsx = std::min({xdim, thread_per_block, xthreadlimit}); + int threadsy = + std::min({ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit}); + int threadsz = + std::min({zdim, std::max(thread_per_block / (threadsx * threadsy), 1), + zthreadlimit}); + + int blocksx = std::min({block_count, DivUp(xdim, threadsx), xgridlimit}); + int blocksy = std::min( + {DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit}); + int blocksz = std::min({DivUp(block_count, (blocksx * blocksy)), + DivUp(zdim, threadsz), zgridlimit}); + + config.virtual_thread_count = dim3(xdim, ydim, zdim); + config.thread_per_block = dim3(threadsx, threadsy, threadsz); + config.block_count = dim3(blocksx, blocksy, blocksz); + return config; +} + +template +inline Cuda2DLaunchConfig GetCuda2DLaunchConfig( + int xdim, int ydim, const Eigen::GpuDevice& d, DeviceFunc func, + size_t dynamic_shared_memory_size, int block_size_limit) { + return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func, + dynamic_shared_memory_size, block_size_limit); +} + +// Returns a raw reference to the current cuda stream. Required by a +// number of kernel calls (for which StreamInterface* does not work), i.e. +// CUB and certain cublas primitives. +inline const cudaStream_t& GetCudaStream(OpKernelContext* context) { + const cudaStream_t* ptr = CHECK_NOTNULL( + reinterpret_cast(context->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + return *ptr; +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ diff --git a/tensorflow/core/util/event.proto b/tensorflow/core/util/event.proto index 5c3799c13228142fcd8b81e3db85332f6e618d4f..65d2c5a09c5c98a70e834e182d5751350506a1a1 100644 --- a/tensorflow/core/util/event.proto +++ b/tensorflow/core/util/event.proto @@ -80,3 +80,8 @@ message TaggedRunMetadata { // deserialization. bytes run_metadata = 2; } + +// For communicating live events back to a coordinator +message SessionStatus { + repeated Event event = 1; +} diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc index 23b00e23dd0e7054aaf0e4e442c60f1372ce2d5b..49507616ed8c6461f8d59d8899d93abb4ba58cd2 100644 --- a/tensorflow/core/util/events_writer.cc +++ b/tensorflow/core/util/events_writer.cc @@ -17,6 +17,7 @@ limitations under the License. #include // for NULL +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -35,10 +36,21 @@ EventsWriter::EventsWriter(const string& file_prefix) file_prefix_(file_prefix), num_outstanding_events_(0) {} -bool EventsWriter::InitIfNeeded() { +EventsWriter::~EventsWriter() { + Close().IgnoreError(); // Autoclose in destructor. +} + +Status EventsWriter::Init() { return InitWithSuffix(""); } + +Status EventsWriter::InitWithSuffix(const string& suffix) { + file_suffix_ = suffix; + return InitIfNeeded(); +} + +Status EventsWriter::InitIfNeeded() { if (recordio_writer_ != nullptr) { CHECK(!filename_.empty()); - if (FileHasDisappeared()) { + if (!FileStillExists().ok()) { // Warn user of data loss and let .reset() below do basic cleanup. if (num_outstanding_events_ > 0) { LOG(WARNING) << "Re-initialization, attempting to open a new file, " @@ -46,7 +58,7 @@ bool EventsWriter::InitIfNeeded() { } } else { // No-op: File is present and writer is initialized. - return true; + return Status::OK(); } } @@ -57,15 +69,12 @@ bool EventsWriter::InitIfNeeded() { static_cast(time_in_seconds), port::Hostname().c_str(), file_suffix_.c_str()); - Status s = env_->NewWritableFile(filename_, &recordio_file_); - if (!s.ok()) { - LOG(ERROR) << "Could not open events file: " << filename_ << ": " << s; - return false; - } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + env_->NewWritableFile(filename_, &recordio_file_), + "Creating writable file ", filename_); recordio_writer_.reset(new io::RecordWriter(recordio_file_.get())); if (recordio_writer_ == nullptr) { - LOG(ERROR) << "Could not create record writer"; - return false; + return errors::Unknown("Could not create record writer"); } num_outstanding_events_ = 0; VLOG(1) << "Successfully opened events file: " << filename_; @@ -77,21 +86,21 @@ bool EventsWriter::InitIfNeeded() { event.set_wall_time(time_in_seconds); event.set_file_version(strings::StrCat(kVersionPrefix, kCurrentVersion)); WriteEvent(event); - Flush(); + TF_RETURN_WITH_CONTEXT_IF_ERROR(Flush(), "Flushing first event."); } - return true; + return Status::OK(); } string EventsWriter::FileName() { if (filename_.empty()) { - InitIfNeeded(); + InitIfNeeded().IgnoreError(); } return filename_; } void EventsWriter::WriteSerializedEvent(StringPiece event_str) { if (recordio_writer_ == nullptr) { - if (!InitIfNeeded()) { + if (!InitIfNeeded().ok()) { LOG(ERROR) << "Write failed because file could not be opened."; return; } @@ -108,60 +117,51 @@ void EventsWriter::WriteEvent(const Event& event) { WriteSerializedEvent(record); } -bool EventsWriter::Flush() { - if (num_outstanding_events_ == 0) return true; +Status EventsWriter::Flush() { + if (num_outstanding_events_ == 0) return Status::OK(); CHECK(recordio_file_ != nullptr) << "Unexpected NULL file"; - if (!recordio_writer_->Flush().ok()) { - LOG(ERROR) << "Failed to flush " << num_outstanding_events_ << " events to " - << filename_; - return false; - } + TF_RETURN_WITH_CONTEXT_IF_ERROR(recordio_writer_->Flush(), "Failed to flush ", + num_outstanding_events_, " to ", filename_); + TF_RETURN_WITH_CONTEXT_IF_ERROR(recordio_file_->Sync(), "Failed to sync ", + num_outstanding_events_, " to ", filename_); - // The FileHasDisappeared() condition is necessary because - // recordio_writer_->Sync() can return true even if the underlying + // The FileStillExists() condition is necessary because + // recordio_writer_->Sync() can return OK even if the underlying // file has been deleted. EventWriter.FileDeletionBeforeWriting // demonstrates this and will fail if the FileHasDisappeared() // condition is removed. // Also, we deliberately attempt to Sync() before checking for a // disappearing file, in case for some file system File::Exists() is // false after File::Open() but before File::Sync(). - if (!recordio_file_->Flush().ok() || !recordio_file_->Sync().ok() || - FileHasDisappeared()) { - LOG(ERROR) << "Failed to flush " << num_outstanding_events_ << " events to " - << filename_; - return false; - } + TF_RETURN_WITH_CONTEXT_IF_ERROR(FileStillExists(), "Failed to flush ", + num_outstanding_events_, " to ", filename_); VLOG(1) << "Wrote " << num_outstanding_events_ << " events to disk."; num_outstanding_events_ = 0; - return true; + return Status::OK(); } -bool EventsWriter::Close() { - bool return_value = Flush(); +Status EventsWriter::Close() { + Status status = Flush(); if (recordio_file_ != nullptr) { - Status s = recordio_file_->Close(); - if (!s.ok()) { - LOG(ERROR) << "Error when closing previous event file: " << filename_ - << ": " << s; - return_value = false; + Status close_status = recordio_file_->Close(); + if (!close_status.ok()) { + status = close_status; } recordio_writer_.reset(nullptr); recordio_file_.reset(nullptr); } num_outstanding_events_ = 0; - return return_value; + return status; } -bool EventsWriter::FileHasDisappeared() { +Status EventsWriter::FileStillExists() { if (env_->FileExists(filename_).ok()) { - return false; - } else { - // This can happen even with non-null recordio_writer_ if some other - // process has removed the file. - LOG(ERROR) << "The events file " << filename_ << " has disappeared."; - return true; + return Status::OK(); } + // This can happen even with non-null recordio_writer_ if some other + // process has removed the file. + return errors::Unknown("The events file ", filename_, " has disappeared."); } } // namespace tensorflow diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h index a1a8cf790d4e2735d705cc2050c14970e5bfab4a..5dbaf97af4ad145cb09009b44d6f93d1c270d17d 100644 --- a/tensorflow/core/util/events_writer.h +++ b/tensorflow/core/util/events_writer.h @@ -18,6 +18,8 @@ limitations under the License. #include #include + +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" @@ -43,7 +45,7 @@ class EventsWriter { // Note that it is not recommended to simultaneously have two // EventWriters writing to the same file_prefix. explicit EventsWriter(const string& file_prefix); - ~EventsWriter() { Close(); } // Autoclose in destructor. + ~EventsWriter(); // Sets the event file filename and opens file for writing. If not called by // user, will be invoked automatically by a call to FileName() or Write*(). @@ -51,11 +53,8 @@ class EventsWriter { // and is open this is a no-op. If on the other hand the file was opened, // but has since disappeared (e.g. deleted by another process), this will open // a new file with a new timestamp in its filename. - bool Init() { return InitWithSuffix(""); } - bool InitWithSuffix(const string& suffix) { - file_suffix_ = suffix; - return InitIfNeeded(); - } + Status Init(); + Status InitWithSuffix(const string& suffix); // Returns the filename for the current events file: // filename_ = [file_prefix_].out.events.[timestamp].[hostname][suffix] @@ -77,12 +76,12 @@ class EventsWriter { // be written too. // Close() calls Flush() and then closes the current events file. // Returns true only if both the flush and the closure were successful. - bool Flush(); - bool Close(); + Status Flush(); + Status Close(); private: - bool FileHasDisappeared(); // True if event_file_path_ does not exist. - bool InitIfNeeded(); + Status FileStillExists(); // OK if event_file_path_ exists. + Status InitIfNeeded(); Env* env_; const string file_prefix_; diff --git a/tensorflow/core/util/events_writer_test.cc b/tensorflow/core/util/events_writer_test.cc index a6286ea701f09b94fe18cb373a42b5a83aab893a..a75b26abc631eb782ba527f9d15ac25ce9f72b2b 100644 --- a/tensorflow/core/util/events_writer_test.cc +++ b/tensorflow/core/util/events_writer_test.cc @@ -112,7 +112,7 @@ TEST(EventWriter, WriteFlush) { string file_prefix = GetDirName("/writeflush_test"); EventsWriter writer(file_prefix); WriteFile(&writer); - EXPECT_TRUE(writer.Flush()); + TF_EXPECT_OK(writer.Flush()); string filename = writer.FileName(); VerifyFile(filename); } @@ -121,7 +121,7 @@ TEST(EventWriter, WriteClose) { string file_prefix = GetDirName("/writeclose_test"); EventsWriter writer(file_prefix); WriteFile(&writer); - EXPECT_TRUE(writer.Close()); + TF_EXPECT_OK(writer.Close()); string filename = writer.FileName(); VerifyFile(filename); } @@ -143,7 +143,7 @@ TEST(EventWriter, FailFlush) { TF_EXPECT_OK(env()->FileExists(filename)); TF_ASSERT_OK(env()->DeleteFile(filename)); EXPECT_EQ(errors::Code::NOT_FOUND, env()->FileExists(filename).code()); - EXPECT_FALSE(writer.Flush()); + EXPECT_FALSE(writer.Flush().ok()); EXPECT_EQ(errors::Code::NOT_FOUND, env()->FileExists(filename).code()); } @@ -155,18 +155,18 @@ TEST(EventWriter, FailClose) { TF_EXPECT_OK(env()->FileExists(filename)); TF_ASSERT_OK(env()->DeleteFile(filename)); EXPECT_EQ(errors::Code::NOT_FOUND, env()->FileExists(filename).code()); - EXPECT_FALSE(writer.Close()); + EXPECT_FALSE(writer.Close().ok()); EXPECT_EQ(errors::Code::NOT_FOUND, env()->FileExists(filename).code()); } TEST(EventWriter, InitWriteClose) { string file_prefix = GetDirName("/initwriteclose_test"); EventsWriter writer(file_prefix); - EXPECT_TRUE(writer.Init()); + TF_EXPECT_OK(writer.Init()); string filename0 = writer.FileName(); TF_EXPECT_OK(env()->FileExists(filename0)); WriteFile(&writer); - EXPECT_TRUE(writer.Close()); + TF_EXPECT_OK(writer.Close()); string filename1 = writer.FileName(); EXPECT_EQ(filename0, filename1); VerifyFile(filename1); @@ -178,7 +178,7 @@ TEST(EventWriter, NameWriteClose) { string filename = writer.FileName(); TF_EXPECT_OK(env()->FileExists(filename)); WriteFile(&writer); - EXPECT_TRUE(writer.Close()); + TF_EXPECT_OK(writer.Close()); VerifyFile(filename); } @@ -186,7 +186,7 @@ TEST(EventWriter, NameClose) { string file_prefix = GetDirName("/nameclose_test"); EventsWriter writer(file_prefix); string filename = writer.FileName(); - EXPECT_TRUE(writer.Close()); + TF_EXPECT_OK(writer.Close()); TF_EXPECT_OK(env()->FileExists(filename)); TF_ASSERT_OK(env()->DeleteFile(filename)); } @@ -199,9 +199,9 @@ TEST(EventWriter, FileDeletionBeforeWriting) { env()->SleepForMicroseconds( 2000000); // To make sure timestamp part of filename will differ. TF_ASSERT_OK(env()->DeleteFile(filename0)); - EXPECT_TRUE(writer.Init()); // Init should reopen file. + TF_EXPECT_OK(writer.Init()); // Init should reopen file. WriteFile(&writer); - EXPECT_TRUE(writer.Flush()); + TF_EXPECT_OK(writer.Flush()); string filename1 = writer.FileName(); EXPECT_NE(filename0, filename1); VerifyFile(filename1); diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc index 9b6a8e12511448b72e17a0b20a4418c4a5cd2c7a..13e41c17f7c7df5ad581bd3f6a39051641139258 100644 --- a/tensorflow/core/util/example_proto_fast_parsing_test.cc +++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc @@ -57,6 +57,7 @@ void TestCorrectness(const string& serialized) { Example example; Example fast_example; EXPECT_TRUE(example.ParseFromString(serialized)); + example.DiscardUnknownFields(); EXPECT_TRUE(TestFastParse(serialized, &fast_example)); EXPECT_EQ(example.DebugString(), fast_example.DebugString()); if (example.DebugString() != fast_example.DebugString()) { diff --git a/tensorflow/core/util/example_proto_helper.cc b/tensorflow/core/util/example_proto_helper.cc index 41f56d2daa48e651f5ac4051deae9c05ef1ed859..e156a3bc8f0f01acc543e9b385bd9782870be52a 100644 --- a/tensorflow/core/util/example_proto_helper.cc +++ b/tensorflow/core/util/example_proto_helper.cc @@ -247,8 +247,9 @@ Status SingleExampleProtoToTensors( bool types_match; TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match)); if (!types_match) { - return errors::InvalidArgument("Name: ", example_name, ", Feature: ", - key, ". Data types don't match. ", + return errors::InvalidArgument("Name: ", example_name, + ", Feature: ", key, + ". Data types don't match. ", "Expected type: ", DataTypeString(dtype), " Feature is: ", ProtoDebugString(f)); } @@ -278,8 +279,9 @@ Status SingleExampleProtoToTensors( bool types_match; TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match)); if (!types_match) { - return errors::InvalidArgument("Name: ", example_name, ", Feature: ", - key, ". Data types don't match. ", + return errors::InvalidArgument("Name: ", example_name, + ", Feature: ", key, + ". Data types don't match. ", "Expected type: ", DataTypeString(dtype), " Feature is: ", ProtoDebugString(f)); } diff --git a/tensorflow/core/util/memmapped_file_system_test.cc b/tensorflow/core/util/memmapped_file_system_test.cc index 616eb5dac32188688ac01cf49ff583dc1623d5ad..504d2d353f8f76f77e4efd3e4a6a6edcaa200711 100644 --- a/tensorflow/core/util/memmapped_file_system_test.cc +++ b/tensorflow/core/util/memmapped_file_system_test.cc @@ -144,8 +144,8 @@ TEST(MemmappedFileSystemTest, ProxyToDefault) { TF_ASSERT_OK(memmapped_env.NewAppendableFile(filename, &writable_file_temp)); // Making sure to clean up after the test finishes. const auto adh = [&memmapped_env, &filename](WritableFile* f) { - delete f; - TF_CHECK_OK(memmapped_env.DeleteFile(filename)); + delete f; + TF_CHECK_OK(memmapped_env.DeleteFile(filename)); }; std::unique_ptr writable_file( writable_file_temp.release(), adh); diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 2caf5fc56dafb5a8879db8026a78bc7bf46346a4..eda966bc3342912334f90a7beddf8ccd3aefa68d 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -35,7 +35,7 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML #include "mkldnn.hpp" using mkldnn::engine; @@ -210,31 +210,32 @@ class MklShape { CHECK_EQ(dnnDelete_F32(convert), E_SUCCESS); } -// The following methods are used for serializing and de-serializing the -// contents of the mklshape object. -// The data is serialized in this order -// isMklTensor_ -// dimension_ -// sizes_ -// strides_ -// mklLayout_ -// tfLayout_ -// tf_to_mkl_dim_map_ + // The following methods are used for serializing and de-serializing the + // contents of the mklshape object. + // The data is serialized in this order + // isMklTensor_ + // dimension_ + // sizes_ + // strides_ + // mklLayout_ + // tfLayout_ + // tf_to_mkl_dim_map_ #define SIZE_OF_MKL_DNN_BUF \ (dnnLayoutSerializationBufferSize_F32()) // Size of buffer needed to // serialize dnn_layout pointer -// Size of buffer to hold the serialized object, the size is computed as follows -// sizeof(isMklTensor_) + sizeof(dimension_) + sizeof(sizes_) + sizeof(strides_) -// + sizeof(mklLayout_ buffer) + sizeof(tfLayout_ buffer) -// + sizeof(tf_to_mkl_dim_map_) + // Size of buffer to hold the serialized object, the size is computed as + // follows sizeof(isMklTensor_) + sizeof(dimension_) + sizeof(sizes_) + + // sizeof(strides_) + // + sizeof(mklLayout_ buffer) + sizeof(tfLayout_ buffer) + // + sizeof(tf_to_mkl_dim_map_) #define SIZE_OF_MKL_SERIAL_DATA(dims) \ (2 * sizeof(size_t) + 3 * dims * sizeof(size_t) + 2 * SIZE_OF_MKL_DNN_BUF) -// First we need to define some macro for offsets into the serial buffer where -// different elements of Mklshape is written/read from + // First we need to define some macro for offsets into the serial buffer where + // different elements of Mklshape is written/read from #define IS_MKL_TENSOR_OFFSET 0 // Location from start of buffer where isMklTensor_ is serialized @@ -324,7 +325,7 @@ class MklShape { nullptr; // TF dimension corresponding to this MKL dimension }; -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML // Forward decl TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format); @@ -388,7 +389,7 @@ class MklDnnShape { /// Equality function for MklDnnShape objects /// @return true if both are equal; false otherwise. - inline bool operator == (const MklDnnShape& input_shape) const { + inline bool operator==(const MklDnnShape& input_shape) const { if (this->IsMklTensor() != input_shape.IsMklTensor()) { return false; } @@ -406,7 +407,7 @@ class MklDnnShape { /// Equality operator for MklDnnShape and TFShape. /// Returns: true if TF shapes for both are the same, false otherwise - inline bool operator == (const TensorShape& input_shape) const { + inline bool operator==(const TensorShape& input_shape) const { if (!this->IsMklTensor()) { return false; } @@ -425,7 +426,7 @@ class MklDnnShape { inline size_t GetDimension(char dimension) const { int index = GetMklDnnTensorDimIndex(dimension); CHECK(index >= 0 && index < this->GetDimension()) - << "Invalid index from the dimension: " << index << ", " << dimension; + << "Invalid index from the dimension: " << index << ", " << dimension; return this->DimSize(index); } @@ -659,7 +660,7 @@ class MklDnnShape { typedef std::vector MklShapeList; -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML typedef std::vector MklDnnShapeList; #endif @@ -673,7 +674,7 @@ inline bool AreAllMklTensors(const MklShapeList& shapes) { return true; } -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML template inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor, const MklShape& mkl_shape) { @@ -705,8 +706,8 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor, Tensor output_tensor; TensorShape output_shape; - TF_CHECK_OK(Status(error::Code::UNIMPLEMENTED, - "Unimplemented conversion function")); + TF_CHECK_OK( + Status(error::Code::UNIMPLEMENTED, "Unimplemented conversion function")); return output_tensor; } @@ -724,7 +725,7 @@ inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) { sizeof(uint8)); } -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) { mklshape->DeSerializeMklDnnShape( ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs())) @@ -748,8 +749,7 @@ inline void GetMklInputList(OpKernelContext* ctext, StringPiece name, ctext->input_list(name, input_tensors); } - -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name, MklShapeList* mkl_shapes) { @@ -779,7 +779,7 @@ inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name, #endif -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML /// Get shape of input tensor pointed by 'input_idx' in TensorShape format. /// If the input tensor is in MKL layout, then obtains TensorShape from /// MklShape. @@ -814,7 +814,7 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n, second_tensor->flat().size() * sizeof(uint8)); } -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML // Allocate the second output tensor that will contain // the MKL shape serialized inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n, @@ -851,7 +851,7 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n, second_tensor->flat().size() * sizeof(uint8)); } -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML // Allocate the output tensor, create a second output tensor that will contain // the MKL shape serialized inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n, @@ -875,7 +875,7 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n, // Allocates a temp tensor and returns the data buffer for temporary storage. // Currently -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML template inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out, const memory::primitive_desc& pd, void** buf_out) { @@ -973,8 +973,8 @@ inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) { return mkl_shape.dim_size(index); } -inline void CopyMklTensorInToOut(OpKernelContext* context, - int idx_in, int idx_out) { +inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in, + int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -994,9 +994,9 @@ inline void CopyMklTensorInToOut(OpKernelContext* context, context->set_output(idx_meta_out, meta_output); } -#ifndef INTEL_MKL_DNN -inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, - int idx_in, int idx_out, +#ifdef INTEL_MKL_ML +inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in, + int idx_out, const TensorShape& shape) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); @@ -1013,8 +1013,8 @@ inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, context->set_output(idx_data_out, output); } #else -inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, - int idx_in, int idx_out, +inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in, + int idx_out, const TensorShape& shape) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); @@ -1032,10 +1032,10 @@ inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, } #endif -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML -inline void ForwardTfTensorInToOut(OpKernelContext* context, - int idx_in, int idx_out) { +inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in, + int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -1053,8 +1053,8 @@ inline void ForwardTfTensorInToOut(OpKernelContext* context, #else -inline void ForwardTfTensorInToOut(OpKernelContext* context, - int idx_in, int idx_out) { +inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in, + int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -1072,8 +1072,8 @@ inline void ForwardTfTensorInToOut(OpKernelContext* context, #endif -inline void ForwardMklTensorInToOut(OpKernelContext* context, - int idx_in, int idx_out) { +inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in, + int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -1090,10 +1090,10 @@ inline void ForwardMklTensorInToOut(OpKernelContext* context, } } -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context, - int idx_in, int idx_out, - const MklDnnShape& mkl_shape) { + int idx_in, int idx_out, + const MklDnnShape& mkl_shape) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); @@ -1112,9 +1112,9 @@ inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context, // Forward the MKL shape ONLY (used in elementwise and other ops where // we call the eigen implementation and MKL shape is not used) inline void ForwardMklMetaDataInToOut(OpKernelContext* context, - uint idx_data_in, uint idx_data_out) { - uint idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs()); - uint idx_meta_out = + uint32 idx_data_in, uint32_t idx_data_out) { + uint32 idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs()); + uint32 idx_meta_out = GetTensorMetaDataIndex(idx_data_out, context->num_outputs()); if (IsRefType(context->input_dtype(idx_data_in))) { @@ -1126,13 +1126,13 @@ inline void ForwardMklMetaDataInToOut(OpKernelContext* context, // Set a dummy MKL shape (called when the output is in TF format) inline void SetDummyMklShapeOutput(OpKernelContext* context, - uint idx_data_out) { + uint32 idx_data_out) { MklShape mkl_shape_output; mkl_shape_output.SetMklTensor(false); AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output); } -#ifndef INTEL_MKL_DNN +#ifdef INTEL_MKL_ML // We don't need these functions in MKLDNN. We have defined equality operator // on MklDnnShape class directly. @@ -1216,11 +1216,11 @@ inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) { int64 H = input.dim_size(1); int64 W = input.dim_size(2); int64 C = input.dim_size(3); - int64 stride_n = H*W*C; -# pragma omp parallel for num_threads(16) + int64 stride_n = H * W * C; +#pragma omp parallel for num_threads(16) for (int64 n = 0; n < N; ++n) { - mkl_somatcopy('R', 'T', H*W, C, 1, buf_in + n*stride_n, C, - buf_out + n*stride_n, H*W); + mkl_somatcopy('R', 'T', H * W, C, 1, buf_in + n * stride_n, C, + buf_out + n * stride_n, H * W); } } @@ -1232,17 +1232,17 @@ inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) { int64 H = (*output)->dim_size(1); int64 W = (*output)->dim_size(2); int64 C = (*output)->dim_size(3); - int64 stride_n = H*W*C; -# pragma omp parallel for num_threads(16) + int64 stride_n = H * W * C; +#pragma omp parallel for num_threads(16) for (int64 n = 0; n < N; ++n) { - mkl_somatcopy('R', 'T', C, H*W, 1, buf_in + n*stride_n, H*W, - buf_out + n*stride_n, C); + mkl_somatcopy('R', 'T', C, H * W, 1, buf_in + n * stride_n, H * W, + buf_out + n * stride_n, C); } } // ------------------------------------------------------------------- -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML /// Return MKL-DNN data type (memory::data_type) for input type T /// @@ -1279,10 +1279,11 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) { /// @return: Tensorflow data format corresponding to memory::format /// Fails with an error if invalid data format. inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) { - if (format == memory::format::nhwc) return FORMAT_NHWC; - else if (format == memory::format::nchw) return FORMAT_NCHW; - TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, - "Unsupported data format")); + if (format == memory::format::nhwc) + return FORMAT_NHWC; + else if (format == memory::format::nchw) + return FORMAT_NCHW; + TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format")); // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure // that we don't come here. @@ -1425,7 +1426,6 @@ inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim, return memory::desc(md); } - /* * Class to represent all the resources corresponding to a tensor in TensorFlow * that are required to execute an operation (such as Convolution). @@ -1494,7 +1494,7 @@ class MklDnnData { /// @return: memory::desc object corresponding to blocked memory format /// for given dimensions and strides. static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim, - const memory::dims& strides) { + const memory::dims& strides) { return CreateBlockedMemDescHelper(dim, strides, MklDnnType()); } @@ -1563,7 +1563,6 @@ class MklDnnData { return user_memory_->get_primitive_desc(); } - /// Get function for descriptor of user memory. inline memory::desc GetUsrMemDesc() { // This is ugly. Why MKL-DNN does not provide desc() method of const type?? @@ -1634,7 +1633,8 @@ class MklDnnData { /// @return: true in case reorder of input is needed; false, otherwise. inline bool IsReorderNeeded(const memory::format& target_format) const { CHECK_NOTNULL(user_memory_); - return target_format != user_memory_->get_primitive_desc().desc().data.format; + return target_format != + user_memory_->get_primitive_desc().desc().data.format; } /// Function to create a reorder from memory pointed by from to memory pointed @@ -1753,7 +1753,7 @@ class MklDnnData { } }; -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/util/mkl_util_test.cc b/tensorflow/core/util/mkl_util_test.cc index 8b73eadb40046518179fcaaa5c244aa7f3d52ebe..cd1d0713ad58b594005847f48943a228743e530d 100644 --- a/tensorflow/core/util/mkl_util_test.cc +++ b/tensorflow/core/util/mkl_util_test.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { namespace { -#ifdef INTEL_MKL_DNN +#ifndef INTEL_MKL_ML TEST(MklUtilTest, MklDnnTfShape) { auto cpu_engine = engine(engine::cpu, 0); @@ -84,7 +84,7 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) { EXPECT_EQ(b_md2.data.format, mkldnn_blocked); } -#endif // INTEL_MKL_DNN +#endif // INTEL_MKL_ML } // namespace } // namespace tensorflow diff --git a/tensorflow/core/util/presized_cuckoo_map.h b/tensorflow/core/util/presized_cuckoo_map.h index e7dab830f0ec9e3401d621f04358d3ee62cb0b63..f88ad2faaff344832d65b04357c3d8c2665ebad5 100644 --- a/tensorflow/core/util/presized_cuckoo_map.h +++ b/tensorflow/core/util/presized_cuckoo_map.h @@ -67,7 +67,7 @@ inline uint64 multiply_high_u64(uint64 x, uint64 y) { return prod_hi + (prod_mid1 >> 32) + (prod_mid2 >> 32) + carry; #endif } -} +} // namespace presized_cuckoo_map template class PresizedCuckooMap { diff --git a/tensorflow/core/util/reporter_test.cc b/tensorflow/core/util/reporter_test.cc index 1cb07718feee820c334d8f5183cafb2de0cb009b..575c27d4ef72ec33c4b9352de59fc806b12d6385 100644 --- a/tensorflow/core/util/reporter_test.cc +++ b/tensorflow/core/util/reporter_test.cc @@ -29,8 +29,8 @@ namespace { // Tests of all the error paths in log_reader.cc follow: static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain " - << expected; + EXPECT_TRUE(StringPiece(s).contains(expected)) + << s << " does not contain " << expected; } TEST(TestReporter, NoLogging) { diff --git a/tensorflow/core/util/session_message.cc b/tensorflow/core/util/session_message.cc new file mode 100644 index 0000000000000000000000000000000000000000..28a6517a1a3c584b896c0b51f9937bc786283b16 --- /dev/null +++ b/tensorflow/core/util/session_message.cc @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/session_message.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/event.pb.h" + +static const int kMaxLogEvents = 1000; + +namespace tensorflow { + +SessionLogger::SessionLogger() : status_(new SessionStatus) {} + +SessionLogger::~SessionLogger() {} + +string SessionLogger::DebugString() { return "SessionLogger"; } + +void SessionLogger::Log(StringPiece message) { + mutex_lock lock(mu_); + + Event* event = status_->add_event(); + event->set_wall_time(Env::Default()->NowMicros()); + event->set_step(0); + LogMessage* log = event->mutable_log_message(); + log->set_message(message.ToString()); + log->set_level(LogMessage::INFO); + + // Clip log events by 10% if we overflow + if (status_->event_size() > kMaxLogEvents) { + auto events = status_->mutable_event(); + events->DeleteSubrange(0, kMaxLogEvents / 10); + } +} + +SessionLogger* GetSessionLogger(ResourceMgr* rm) { + SessionLogger* logger; + + std::function status_creator = + [](SessionLogger** result) { + *result = new SessionLogger(); + return Status::OK(); + }; + + if (!rm->LookupOrCreate("session", "status", &logger, + status_creator) + .ok()) { + return nullptr; + } + + return logger; +} + +void LogSessionMessage(ResourceMgr* rm, StringPiece message) { + return GetSessionLogger(rm)->Log(message); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/session_message.h b/tensorflow/core/util/session_message.h new file mode 100644 index 0000000000000000000000000000000000000000..c0f3d78b46a50386403c453fcc92d56a456206de --- /dev/null +++ b/tensorflow/core/util/session_message.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H_ +#define TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H_ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +class ResourceMgr; +class SessionStatus; + +class SessionLogger : public ResourceBase { + public: + SessionLogger(); + ~SessionLogger(); + + void Log(StringPiece message); + string DebugString() override; + + const SessionStatus& status() { return *status_; } + + private: + std::unique_ptr status_; + mutex mu_; +}; + +// Return a SessionLogger instance for the current session. If the logger +// will be used across multiple computations, you must explicitly acquire +// and release references using Ref()/Unref(). +// +// Returns nullptr if a logger cannot be created. +SessionLogger* GetSessionLogger(ResourceMgr* rm); + +// Attach `message` to the logger for the current session. +void LogSessionMessage(ResourceMgr* rm, StringPiece message); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h index f2401a0af4e60f66c606e86e90a37bcf09eb6308..258ee418c145bae161c7603d4249875fb687c94a 100644 --- a/tensorflow/core/util/sparse/sparse_tensor.h +++ b/tensorflow/core/util/sparse/sparse_tensor.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/sparse/dim_comparator.h" #include "tensorflow/core/util/sparse/group_iterator.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace sparse { @@ -59,8 +59,8 @@ class SparseTensor { shape_(shape.begin(), shape.end()), order_(order.begin(), order.end()), dims_(GetDimsFromIx(ix)) { - CHECK_EQ(ix.dtype(), DT_INT64) << "indices must be type int64 but got: " - << ix.dtype(); + CHECK_EQ(ix.dtype(), DT_INT64) + << "indices must be type int64 but got: " << ix.dtype(); CHECK(TensorShapeUtils::IsVector(vals.shape())) << "vals must be a vec, but got: " << vals.shape().DebugString(); CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0)) diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc index efdd97fd3d6ffa5c1f66f2a0950d7bd44ba01eb1..85de0320857e307ea54594c2eff611b9e413945b 100644 --- a/tensorflow/core/util/sparse/sparse_tensor_test.cc +++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace sparse { diff --git a/tensorflow/core/util/stream_executor_util.h b/tensorflow/core/util/stream_executor_util.h index 6a5ddec04c9d6c2f723e0caa7343103f09c63183..f7767ace716782e53a2023bea7acc7b2f3c6604c 100644 --- a/tensorflow/core/util/stream_executor_util.h +++ b/tensorflow/core/util/stream_executor_util.h @@ -41,9 +41,10 @@ class StreamExecutorUtil { // This assumes that the error codes between the two implementations // match. static Status ConvertStatus(const perftools::gputools::port::Status& s) { - return s.ok() ? Status::OK() : Status(static_cast( - static_cast(s.code())), - s.error_message()); + return s.ok() ? Status::OK() + : Status(static_cast( + static_cast(s.code())), + s.error_message()); } }; diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 579b70ab5149f05749205f24a0c6e64c95f12dfd..0426fee0e2679718a80cfb46bcb78a668c6b6e83 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -286,7 +286,7 @@ Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out, TF_RETURN_IF_ERROR(out->Append(len)); *crc32c = crc32c::Extend(*crc32c, reinterpret_cast(&elem_size), sizeof(uint64)); - *bytes_written += sizeof(uint64); + *bytes_written += len.size(); // Write the serialized variant. TF_RETURN_IF_ERROR(out->Append(elem)); @@ -913,8 +913,8 @@ Status BundleReader::LookupSlice(StringPiece full_tensor_key, Status BundleReader::GetSliceValue(StringPiece full_tensor_key, const BundleEntryProto& full_tensor_entry, const TensorSlice& slice_spec, Tensor* val) { - using checkpoint::TensorSliceSet; using checkpoint::RegisterTensorSlice; + using checkpoint::TensorSliceSet; DCHECK_GE(full_tensor_entry.slices_size(), 0); const TensorShape full_shape(TensorShape(full_tensor_entry.shape())); diff --git a/tensorflow/core/util/tensor_slice_reader_cache.cc b/tensorflow/core/util/tensor_slice_reader_cache.cc index 0f009d7de57a3cf1471c1ba694d3a771bc00635c..424f8098a9c1e3cec3851be06d04d49bed93e9af 100644 --- a/tensorflow/core/util/tensor_slice_reader_cache.cc +++ b/tensorflow/core/util/tensor_slice_reader_cache.cc @@ -55,7 +55,7 @@ const TensorSliceReader* TensorSliceReaderCache::GetReader( TensorSliceReader::OpenTableFunction open_function, int preferred_shard) { mutex_lock l(mu_); -#if defined(__GXX_RTTI) || defined(_CPPRTTI) +#if defined(__GXX_RTTI) || defined(_CPPRTTI) // Get the function pointer from the open_function value. TensorSliceReaderCache::OpenFuncType* func_ptr = open_function.target(); diff --git a/tensorflow/core/util/tensor_slice_set.cc b/tensorflow/core/util/tensor_slice_set.cc index 4217df90ca147ccc17cadf6c46c6e4ef4524f12b..7c1d325c0a54e7ba5261f645a2962970fa2d3630 100644 --- a/tensorflow/core/util/tensor_slice_set.cc +++ b/tensorflow/core/util/tensor_slice_set.cc @@ -188,9 +188,9 @@ Status RegisterTensorSlice( } if (type != tss->type()) { return errors::Internal("Incompatible tensor types detected for tensor ", - name, ": existing = ", - DataTypeString(tss->type()), ", new = ", - DataTypeString(type)); + name, + ": existing = ", DataTypeString(tss->type()), + ", new = ", DataTypeString(type)); } } // Register the tensor slices without the actual data. diff --git a/tensorflow/core/util/tensor_slice_util.h b/tensorflow/core/util/tensor_slice_util.h index c7edae66b267d4cbd88d497c745b4d81802ab3a9..8f5a6f1d93591e94ec759d343ec26146c67552c0 100644 --- a/tensorflow/core/util/tensor_slice_util.h +++ b/tensorflow/core/util/tensor_slice_util.h @@ -139,9 +139,9 @@ static bool CopyDataFromTensorSliceToTensorSlice(const TensorShape& shape, const TensorSlice& slice_d, const SrcT* ptr_s, DstT* ptr_d) { - CHECK_LE(shape.dims(), kTensorSliceMaxRank) << "Only tensors of size up to " - << kTensorSliceMaxRank - << " are supported"; + CHECK_LE(shape.dims(), kTensorSliceMaxRank) + << "Only tensors of size up to " << kTensorSliceMaxRank + << " are supported"; // We need to compute the intersection of the two slices. TensorSlice inter; if (!slice_s.Intersect(slice_d, &inter)) { diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h index bdb4921e1bbf8611d84420c1e52d01fa39c25264..2888c66d10fa3c2ab0eaf755a23da3eb3fcd6b09 100644 --- a/tensorflow/core/util/tensor_slice_writer.h +++ b/tensorflow/core/util/tensor_slice_writer.h @@ -101,8 +101,8 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape, // The tensor and the slice have to be compatible if (shape.dims() != slice.dims()) { return errors::Internal("Incompatible tensor shape and slice: ", "shape = ", - shape.DebugString(), ", slice = ", - slice.DebugString()); + shape.DebugString(), + ", slice = ", slice.DebugString()); } DataType dt = DataTypeToEnum::value; // We need to add an entry for "name" if there isn't an entry already. @@ -114,9 +114,9 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape, CHECK_EQ(name, ssm.name()) << ProtoShortDebugString(ssm); TensorShape ssm_shape(ssm.shape()); if (!shape.IsSameSize(ssm_shape)) { - return errors::Internal("Mismatching shapes: existing tensor = ", - ssm_shape.DebugString(), ", trying to add name ", - name, ", shape = ", shape.DebugString()); + return errors::Internal( + "Mismatching shapes: existing tensor = ", ssm_shape.DebugString(), + ", trying to add name ", name, ", shape = ", shape.DebugString()); } if (dt != ssm.type()) { return errors::Internal( diff --git a/tensorflow/docs_src/about/bib.md b/tensorflow/docs_src/about/bib.md index c9f0c532c62791a9fcf854f11fd2f330955ee7d6..5593a3d95c435df38174fde5db37f4dd3437acd4 100644 --- a/tensorflow/docs_src/about/bib.md +++ b/tensorflow/docs_src/about/bib.md @@ -60,7 +60,7 @@ author={ Lukasz~Kaiser and Manjunath~Kudlur and Josh~Levenberg and - Dan~Man\'{e} and + Dandelion~Man\'{e} and Rajat~Monga and Sherry~Moore and Derek~Murray and diff --git a/tensorflow/docs_src/about/roadmap.md b/tensorflow/docs_src/about/roadmap.md index 3ee825ed400de93553bf69fee065fcf8ef13be4d..1f934acab69276d4c32393bb73632d978e0d15c3 100644 --- a/tensorflow/docs_src/about/roadmap.md +++ b/tensorflow/docs_src/about/roadmap.md @@ -1,37 +1,86 @@ # Roadmap -**Last updated: January 23, 2017** +**Last updated: Feb 15, 2018** -TensorFlow is a fast moving project. In order for the community to better -understand what the near future will bring, this document shares what we are -working on internally. Many of these features were requested by the community, -and we welcome -[contributions](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome). +TensorFlow is a rapidly moving, community supported project. This document is intended +to provide guidance about priorities and focus areas of the core set of TensorFlow +developers and about functionality that can be expected in the upcoming releases of +TensorFlow. Many of these areas are driven by community use cases, and we welcome +further +[contributions](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md) +to TensorFlow. -The features on this list are targeted for the next few months. At this point, -we do not have timelines for these features. +The features below do not have concrete release dates. However, the majority can be +expected in the next one to two releases. -### Improve non-Python language support +### APIs +#### High Level APIs: +* Easy multi-GPU utilization with Estimators +* Easy-to-use high-level pre-made estimators for Gradient Boosted Trees, Time Series, and other models -* Support for adding gradient computation for graphs constructed in other - languages (C++, Java, Go etc.) +#### Eager Execution: +* Efficient utilization of multiple GPUs +* Distributed training (multi-machine) +* Performance improvements +* Simpler export to a GraphDef/SavedModel -### Making TensorFlow easier to use -* High-level APIs -* Well-maintained models showing best practices +#### Keras API: +* Better integration with tf.data (ability to call `model.fit` with data tensors) +* Full support for Eager Execution (both Eager support for the regular Keras API, and ability +to create Keras models Eager- style via Model subclassing) +* Better distribution/multi-GPU support and TPU support (including a smoother model-to-estimator workflow) -### Performance -* Speed and memory benchmarks -* Distributed full model benchmarks -* Performance and memory usage improvements +#### Official Models: +* A set of +[reference models](https://github.com/tensorflow/models/tree/master/official) +across image recognition, speech, object detection, and + translation that demonstrate best practices and serve as a starting point for + high-performance model development. + +#### Contrib: +* Deprecation notices added to parts of tf.contrib where preferred implementations exist outside of tf.contrib. +* As much as possible, large projects inside tf.contrib moved to separate repositories. +* The tf.contrib module will eventually be discontinued in its current form, experimental development will in future happen in other repositories. -### Core Features -* Automatic op placement ([#2126](https://github.com/tensorflow/tensorflow/issues/2126)) -* Support for graph-level functions + +#### Probabilistic Reasoning and Statistical Analysis: +* Rich set of tools for probabilistic and statistical analysis in tf.distributions + and tf.probability. These include new samplers, layers, optimizers, losses, and structured models +* Statistical tools for hypothesis testing, convergence diagnostics, and sample statistics +* Edward 2.0: High-level API for probabilistic programming ### Platforms -* OpenCL support ([#22](https://github.com/tensorflow/tensorflow/issues/22)) +#### TensorFlow Lite: +* Increased coverage of supported ops in TensorFlow Lite +* Easier conversion of a trained TensorFlow graph for use on TensorFlow Lite +* Support for GPU acceleration in TensorFlow Lite (iOS and Android) +* Support for hardware accelerators via Android NeuralNets API +* Improved CPU performance by quantization and other network optimizations (eg. pruning, distillation) +* Increased support for devices beyond Android and iOS (eg. RPi, Cortex-M) + +### Performance +#### Distributed TensorFlow: +* Multi-GPU support optimized for a variety of GPU topologies +* Improved mechanisms for distributing computations on several machines + +#### Optimizations: +* Mixed precision training support with initial example model and guide +* Native TensorRT support +* Int8 support for SkyLake via MKL +* Dynamic loading of SIMD-optimized kernels + +### Documentation and Usability: +* Updated documentation, tutorials and Getting Started guides +* Process to enable external contributions to tutorials, documentation, and blogs showcasing best practice use-cases of TensorFlow and high-impact applications + +### Community and Partner Engagement +#### Special Interest Groups: +* Mobilizing the community to work together in focused domains +* [tf-distribute](https://groups.google.com/a/tensorflow.org/forum/#!forum/tf-distribute) +: build and packaging of TensorFlow +* More to be identified and launched -### Community -* More educational resources -* Better integration of TensorFlow into the opensource big data ecosystem (e.g. -[#2655](https://github.com/tensorflow/tensorflow/issues/2655)) +#### Community: +* Incorporate public feedback on significant design decisions via a Request-for-Comment (RFC) process +* Formalize process for external contributions to land in TensorFlow and associated projects +* Grow global TensorFlow communities and user groups +* Collaborate with partners to co-develop and publish research papers diff --git a/tensorflow/docs_src/about/uses.md b/tensorflow/docs_src/about/uses.md index 8818177a288ef16ac1907a20ab563ee3d871f7fd..d646880bd350c42e463680a5c7eb0903f2c0a497 100644 --- a/tensorflow/docs_src/about/uses.md +++ b/tensorflow/docs_src/about/uses.md @@ -22,6 +22,14 @@ This section describes some of the current uses of the TensorFlow system. > TensorFlow, or even better, send us a pull request to add an entry to this > file. +* **Deep Speech** +
    +
  • **Organization**: Mozilla
  • +
  • **Domain**: Speech Recognition
  • +
  • **Description**: A TensorFlow implementation motivated by Baidu's Deep Speech architecture.
  • +
  • **More info**: [GitHub Repo](https://github.com/mozilla/deepspeech)
  • +
+ * **RankBrain**
  • **Organization**: Google
  • diff --git a/tensorflow/docs_src/api_guides/python/contrib.distributions.md b/tensorflow/docs_src/api_guides/python/contrib.distributions.md index 7a3d509b75198461430195aa70a336f94b7f8cfa..533d7dac1373f61ca92dba288a7d29e07e0f37d3 100644 --- a/tensorflow/docs_src/api_guides/python/contrib.distributions.md +++ b/tensorflow/docs_src/api_guides/python/contrib.distributions.md @@ -17,7 +17,6 @@ initialized with parameters that define the distributions. * @{tf.contrib.distributions.Binomial} * @{tf.contrib.distributions.Bernoulli} -* @{tf.contrib.distributions.BernoulliWithSigmoidProbs} * @{tf.contrib.distributions.Beta} * @{tf.contrib.distributions.Categorical} * @{tf.contrib.distributions.Chi2} diff --git a/tensorflow/docs_src/api_guides/python/regression_examples.md b/tensorflow/docs_src/api_guides/python/regression_examples.md index 45cb9d829cfbc1b1efb735cc1ea27e33159db724..7de2be05521d9293e33664cdbbd7bf16b9ad7c52 100644 --- a/tensorflow/docs_src/api_guides/python/regression_examples.md +++ b/tensorflow/docs_src/api_guides/python/regression_examples.md @@ -38,7 +38,7 @@ The preceding examples rely on the following data set utility: Utility Description - imports85.py + imports85.py This program provides utility functions that load the imports85 data set into formats that other TensorFlow programs (for example, linear_regression.py and @@ -229,4 +229,4 @@ passed through to the `model_fn` when the `model_fn` is called. The `model_fn` returns an @{tf.estimator.EstimatorSpec$`EstimatorSpec`} which is a simple structure indicating to the `Estimator` which operations should be run to accomplish -varions tasks. +various tasks. diff --git a/tensorflow/docs_src/community/welcome.md b/tensorflow/docs_src/community/welcome.md index a3abf2550757e825ae2d023018def919de1bcd8f..9f6fe91b1490ef4ffe43acc877ecb83cc9121118 100644 --- a/tensorflow/docs_src/community/welcome.md +++ b/tensorflow/docs_src/community/welcome.md @@ -12,7 +12,6 @@ The source code for TensorFlow is on Before contributing to TensorFlow source code, please review the [Contribution guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md). - ### Projects developed by the TensorFlow community The TensorFlow community has created many great projects around TensorFlow, including: @@ -65,5 +64,6 @@ please read the following list carefully: [TensorFlow issues tracker](https://github.com/tensorflow/tensorflow/issues) on GitHub. For example, use the issue tracker to request a new operation in TensorFlow. - + * To report vulnerabilities, please follow our + [vulnerability disclosure guidelines](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/SECURITY.md). diff --git a/tensorflow/docs_src/deploy/index.md b/tensorflow/docs_src/deploy/index.md index 5831960b4f6e383a6babb0823893a5d9ec5017f0..07b1bc9257ff7b132d22ac186a2f462e9c784867 100644 --- a/tensorflow/docs_src/deploy/index.md +++ b/tensorflow/docs_src/deploy/index.md @@ -7,6 +7,8 @@ the following documents: a cluster of TensorFlow servers. * @{$hadoop$How to run TensorFlow on Hadoop}, which has a highly self-explanatory title. + * @{$s3$How to run TensorFlow with the S3 filesystem}, which explains how + to run TensorFlow with the S3 file system. * The entire document set for [TensorFlow serving](/serving), an open-source, flexible, high-performance serving system for machine-learned models designed for production environments. TensorFlow Serving provides diff --git a/tensorflow/docs_src/deploy/leftnav_files b/tensorflow/docs_src/deploy/leftnav_files index f8f8d578e602cac8dd814326e318ebe0e85ec700..c682e7add16c741279aedb40c1b12f4ca8f0286a 100644 --- a/tensorflow/docs_src/deploy/leftnav_files +++ b/tensorflow/docs_src/deploy/leftnav_files @@ -1,3 +1,4 @@ index.md distributed.md hadoop.md +s3.md diff --git a/tensorflow/docs_src/deploy/s3.md b/tensorflow/docs_src/deploy/s3.md new file mode 100644 index 0000000000000000000000000000000000000000..38f84286347622d1de0646cdc621d5fb1447e588 --- /dev/null +++ b/tensorflow/docs_src/deploy/s3.md @@ -0,0 +1,40 @@ +# How to run TensorFlow on S3 + +This document describes how to run TensorFlow on S3 file system. + +## S3 + +We assume that you are familiar with @{$reading_data$reading data}. + +To use S3 with TensorFlow, change the file paths you use to read and write +data to an S3 path. For example: + +```python +filenames = ["s3://bucketname/path/to/file1.tfrecord", + "s3://bucketname/path/to/file2.tfrecord"] +dataset = tf.data.TFRecordDataset(filenames) +``` + +When reading or writing data on S3 with your TensorFlow program, the behavior +could be controlled by various environmental variables: + +* **AWS_REGION**: By default, regional endpoint is used for S3, with region + controlled by `AWS_REGION`. If `AWS_REGION` is not specified, then + `us-east-1` is used. +* **S3_ENDPOINT**: The endpoint could be overridden explicitly with + `S3_ENDPOINT` specified. +* **S3_USE_HTTPS**: HTTPS is used to access S3 by default, unless + `S3_USE_HTTPS=0`. +* **S3_VERIFY_SSL**: If HTTPS is used, SSL verification could be disabled + with `S3_VERIFY_SSL=0`. + +To read or write objects in a bucket that is no publicly accessible, +AWS credentials must be provided through one of the following methods: + +* Set credentials in the AWS credentials profile file on the local system, + located at: `~/.aws/credentials` on Linux, macOS, or Unix, or + `C:\Users\USERNAME\.aws\credentials` on Windows. +* Set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment + variables. +* If TensorFlow is deployed on an EC2 instance, specify an IAM role and then + give the EC2 instance access to that role. diff --git a/tensorflow/docs_src/extend/add_filesys.md b/tensorflow/docs_src/extend/add_filesys.md index f0591b7b7d8af478db067ecd3bdd949e75d813c9..06f11de4eb0ea7878b01cd37d994c5a40ec400be 100644 --- a/tensorflow/docs_src/extend/add_filesys.md +++ b/tensorflow/docs_src/extend/add_filesys.md @@ -81,6 +81,8 @@ filesystem implementations call their existing libraries. Examples include: plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hadoop/hadoop_file_system.h) * [GCS plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/cloud/gcs_file_system.h) +* [S3 + plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/s3/s3_file_system.h) #### The File interfaces diff --git a/tensorflow/docs_src/get_started/checkpoints.md b/tensorflow/docs_src/get_started/checkpoints.md index 680e1c0d3f58166a4f6b352816914f5220d84996..dfa2110e691167f54e6ea8b7a832f0a88d0ec41a 100644 --- a/tensorflow/docs_src/get_started/checkpoints.md +++ b/tensorflow/docs_src/get_started/checkpoints.md @@ -16,7 +16,7 @@ This document focuses on checkpoints. For details on SavedModel, see the ## Sample code This document relies on the same -[https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py](Iris classification example) detailed in @{$premade_estimators$Getting Started with TensorFlow}. +[Iris classification example](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py) detailed in @{$premade_estimators$Getting Started with TensorFlow}. To download and access the example, invoke the following two commands: ```shell diff --git a/tensorflow/docs_src/get_started/custom_estimators.md b/tensorflow/docs_src/get_started/custom_estimators.md index 6343cc4ee454c7242b98497a37e9852b4e9873ae..42a246678a054d637fea5a82a03ecb84ff412bd9 100644 --- a/tensorflow/docs_src/get_started/custom_estimators.md +++ b/tensorflow/docs_src/get_started/custom_estimators.md @@ -15,7 +15,7 @@ git clone https://github.com/tensorflow/models/ cd models/samples/core/get_started ``` -In this document we wil be looking at +In this document we will be looking at [`custom_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py). You can run it with the following command: @@ -161,7 +161,7 @@ classifier = tf.estimator.Estimator( To implement a typical model function, you must do the following: -* (Define the model)[#define_the_model]. +* [Define the model](#define_the_model). * Specify additional calculations for each of the [three different modes](#modes): * [Predict](#predict) diff --git a/tensorflow/docs_src/get_started/datasets_quickstart.md b/tensorflow/docs_src/get_started/datasets_quickstart.md index ecfbf160f0de2414f6cffa07d159a3e26733e3a6..bc69773d2138f5bf280015b61f1f243fd874bdac 100644 --- a/tensorflow/docs_src/get_started/datasets_quickstart.md +++ b/tensorflow/docs_src/get_started/datasets_quickstart.md @@ -28,8 +28,8 @@ def train_input_fn(features, labels, batch_size): # Shuffle, repeat, and batch the examples. dataset = dataset.shuffle(1000).repeat().batch(batch_size) - # Build the Iterator, and return the read end of the pipeline. - return dataset.make_one_shot_iterator().get_next() + # Return the dataset. + return dataset ``` Let's look at this more closely. @@ -40,7 +40,7 @@ This function expects three arguments. Arguments expecting an "array" can accept nearly anything that can be converted to an array with `numpy.array`. One exception is [`tuple`](https://docs.python.org/3/tutorial/datastructures.html#tuples-and-sequences) -which has special meaning for `Datasets`. +which, as we will see, has special meaning for `Datasets`. * `features`: A `{'feature_name':array}` dictionary (or [`DataFrame`](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html)) @@ -73,11 +73,12 @@ Let's walk through the `train_input_fn()`. ### Slices -In the simplest cases, @{tf.data.Dataset.from_tensor_slices} function takes an -array and returns a @{tf.data.Dataset} representing slices of the array. For -example, an array containing the @{$tutorials/layers$mnist training data} -has a shape of `(60000, 28, 28)`. Passing this to `from_tensor_slices` returns -a `Dataset` object containing 60000 slices, each one a 28x28 image. +The function starts by using the @{tf.data.Dataset.from_tensor_slices} function +to create a @{tf.data.Dataset} representing slices of the array. The array is +sliced across the first dimension. For example, an array containing the +@{$tutorials/layers$mnist training data} has a shape of `(60000, 28, 28)`. +Passing this to `from_tensor_slices` returns a `Dataset` object containing +60000 slices, each one a 28x28 image. The code that returns this `Dataset` is as follows: @@ -89,18 +90,24 @@ mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x) print(mnist_ds) ``` -This will print the following line, showing the @{$programmers_guide/tensors#shapes$shapes} and @{$programmers_guide/tensors#data_types$types} of the items in -the dataset. Note that the dataset does not know how many items it contains. +This will print the following line, showing the +@{$programmers_guide/tensors#shapes$shapes} and +@{$programmers_guide/tensors#data_types$types} of the items in +the dataset. Note that a `Dataset` does not know how many items it contains. ``` None ``` -The dataset above represents a collection of simple arrays, but datasets are -much more powerful than this. Datasets transparently handle any nested -combination of dictionaries or tuples. For example, ensuring that `features` -is a standard dictionary, you can then convert the dictionary of arrays to -a `Dataset` of dictionaries as follows: +The `Dataset` above represents a simple collection of arrays, but datasets are +much more powerful than this. A `Dataset` can transparently handle any nested +combination of dictionaries or tuples (or +[`namedtuple`](https://docs.python.org/2/library/collections.html#collections.namedtuple) +). + +For example after converting the iris `features` +to a standard python dictionary, you can then convert the dictionary of arrays +to a `Dataset` of dictionaries as follows: ``` python dataset = tf.data.Dataset.from_tensor_slices(dict(features)) @@ -124,9 +131,9 @@ and `types` of the `Dataset` take on the same structure. This dataset contains dictionaries of @{$programmers_guide/tensors#rank$scalars}, all of type `tf.float64`. -The first line of `train_input_fn` uses the same functionality, but adds -another level of structure. It creates a dataset containing -`(features, labels)` pairs. +The first line of the iris `train_input_fn` uses the same functionality, but +adds another level of structure. It creates a dataset containing +`(features_dict, label)` pairs. The following code shows that the label is a scalar with type `int64`: @@ -164,14 +171,14 @@ dataset = dataset.shuffle(1000).repeat().batch(batch_size) ``` The @{tf.data.Dataset.shuffle$`shuffle`} method uses a fixed-size buffer to -shuffle the items as they pass through. Setting a `buffer_size` greater than -the number of examples in the `Dataset` ensures that the data is completely -shuffled. The Iris data set only contains 150 examples. +shuffle the items as they pass through. In this case the `buffer_size` is +greater than the number of examples in the `Dataset`, ensuring that the data is +completely shuffled (The Iris data set only contains 150 examples). -The @{tf.data.Dataset.repeat$`repeat`} method has the `Dataset` restart when -it reaches the end. To limit the number of epochss, set the `count` argument. +The @{tf.data.Dataset.repeat$`repeat`} method restarts the `Dataset` when +it reaches the end. To limit the number of epochs, set the `count` argument. -The @{tf.data.Dataset.repeat$`batch`} method collects a number of examples and +The @{tf.data.Dataset.batch$`batch`} method collects a number of examples and stacks them, to create batches. This adds a dimension to their shape. The new dimension is added as the first dimension. The following code uses the `batch` method on the MNIST `Dataset`, from earlier. This results in a @@ -213,35 +220,16 @@ print(dataset) ### Return - - -The `train`, `evaluate`, and `predict` methods of every Estimator require -input functions to return a `(features, label)` pair containing -@{$programmers_guide/tensors$tensorflow tensors}. The `train_input_fn` uses -the following line to convert the Dataset into the expected format: - -```python -# Build the Iterator, and return the read end of the pipeline. -features_result, labels_result = dataset.make_one_shot_iterator().get_next() -``` +At this point the `Dataset` contains `(features_dict, labels)` pairs. +This is the format expected by the `train` and `evaluate` methods, so the +`input_fn` returns the dataset. -The result is a structure of @{$programmers_guide/tensors$TensorFlow tensors}, -matching the layout of the items in the `Dataset`. -For an introduction to what these objects are and how to work with them, -see @{$programmers_guide/low_level_intro}. +The `labels` can/should be omitted when using the `predict` method. -``` python -print((features_result, labels_result)) -``` + -```None -({ - 'SepalLength': , - 'PetalWidth': , - 'PetalLength': , - 'SepalWidth': }, -Tensor("IteratorGetNext_1:4", shape=(?,), dtype=int64)) -``` ## Reading a CSV File @@ -282,7 +270,7 @@ produce the necessary `(features, label)` pairs. We will start by building a function to parse a single line. -The following `iris_data.parse_line` function acomplishes this taks using the +The following `iris_data.parse_line` function accomplishes this task using the @{tf.decode_csv} function, and some simple python code: We must parse each of the lines in the dataset in order to generate the diff --git a/tensorflow/docs_src/get_started/feature_columns.md b/tensorflow/docs_src/get_started/feature_columns.md index e3308ed716d63f10bf0e9dda858c23eef30709a6..ad3e1fe3e3a4e3f5278e76bcaa0fc8eee2faf374 100644 --- a/tensorflow/docs_src/get_started/feature_columns.md +++ b/tensorflow/docs_src/get_started/feature_columns.md @@ -461,8 +461,8 @@ permitting a richer palette of numbers for every cell, an embedding column contains far fewer cells than an indicator column. Let's look at an example comparing indicator and embedding columns. Suppose our -input examples consists of different words from a limited palette of only 81 -words. Further suppose that the data set provides provides the following input +input examples consist of different words from a limited palette of only 81 +words. Further suppose that the data set provides the following input words in 4 separate examples: * `"dog"` diff --git a/tensorflow/docs_src/get_started/get_started_for_beginners.md b/tensorflow/docs_src/get_started/get_started_for_beginners.md index ea1c2fb3f473b9e39567c7607d3b3ad10d2de6b5..b88483be699630d2275850cbc7c461eeb90f5943 100644 --- a/tensorflow/docs_src/get_started/get_started_for_beginners.md +++ b/tensorflow/docs_src/get_started/get_started_for_beginners.md @@ -36,6 +36,7 @@ the following three: alt="Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor" src="../images/iris_three_species.jpg">
+ **From left to right, [*Iris setosa*](https://commons.wikimedia.org/w/index.php?curid=170298) (by [Radomil](https://commons.wikimedia.org/wiki/User:Radomil), CC BY-SA 3.0), @@ -90,11 +91,10 @@ a number. Here's the representation scheme: A **model** is the relationship between features and the label. For the Iris problem, the model defines the relationship -between the sepal and petal measurements and the Iris species. -Some simple models can be described with a few lines of algebra; -more complex machine learning models -contain such a large number of interlacing mathematical functions and -parameters that they become hard to summarize mathematically. +between the sepal and petal measurements and the predicted Iris species. Some +simple models can be described with a few lines of algebra, but complex machine +learning models have a large number of parameters that are difficult to +summarize. Could you determine the relationship between the four features and the Iris species *without* using machine learning? That is, could you use @@ -188,6 +188,7 @@ provides a programming stack consisting of multiple API layers:
+ **The TensorFlow Programming Environment.**

 

@@ -331,7 +332,7 @@ interpret data is such a rich topic that we devote an entire From a code perspective, you build a list of `feature_column` objects by calling functions from the @{tf.feature_column} module. Each object describes an input to the model. To tell the model to interpret data as a floating-point value, -call @{tf.feature_column.numeric_column). In `premade_estimator.py`, all +call @{tf.feature_column.numeric_column}. In `premade_estimator.py`, all four features should be interpreted as literal floating-point values, so the code to create a feature column looks as follows: @@ -357,7 +358,7 @@ my_feature_columns = [ ### Select the type of model -We need the select the kind of model that will be trained. +We need to select the kind of model that will be trained. Lots of model types exist; picking the ideal type takes experience. We've selected a neural network to solve the Iris problem. [**Neural networks**](https://developers.google.com/machine-learning/glossary/#neural_network) @@ -380,6 +381,7 @@ fully connected neural network consisting of three hidden layers:
+ **A neural network with three hidden layers.**

 

@@ -568,6 +570,7 @@ of 0.5. The following suggests a more effective model: 5.5 2.5 4.0 1.3 1 1 + **A model that is 80% accurate.**

 

@@ -655,7 +658,9 @@ calls as follows: ```python predictions = classifier.predict( - input_fn=lambda:eval_input_fn(predict_x, batch_size=args.batch_size)) + input_fn=lambda:eval_input_fn(predict_x, + labels=None, + batch_size=args.batch_size)) ``` As with the `evaluate` method, our `predict` method also gathers examples @@ -700,7 +705,7 @@ for pred_dict, expec in zip(predictions, expected): class_id = pred_dict['class_ids'][0] probability = pred_dict['probabilities'][class_id] - print(template.format(SPECIES[class_id], 100 * probability, expec)) + print(template.format(iris_data.SPECIES[class_id], 100 * probability, expec)) ``` Running the program yields the following output: diff --git a/tensorflow/docs_src/get_started/premade_estimators.md b/tensorflow/docs_src/get_started/premade_estimators.md index dbc35065abf22c88c325c4edc370b6da91c4df5b..6bffd2e065548a42eb726df34542ecc7480ad38d 100644 --- a/tensorflow/docs_src/get_started/premade_estimators.md +++ b/tensorflow/docs_src/get_started/premade_estimators.md @@ -2,37 +2,39 @@ # Getting Started with TensorFlow This document introduces the TensorFlow programming environment and shows you -how to write the Iris classification problem in TensorFlow. +how to solve the Iris classification problem in TensorFlow. -Prior to reading this document, do the following: +## Prerequisites + +Prior to using the sample code in this document, you'll need to do the +following: * @{$install$Install TensorFlow}. * If you installed TensorFlow with virtualenv or Anaconda, activate your TensorFlow environment. -* To keep the data import simple, our Iris example uses Pandas. You can - install Pandas with: +* Install or upgrade pandas by issuing the following command: - `pip install pandas` + pip install pandas ## Getting the sample code -Take the following steps to get the sample code for this program: +Take the following steps to get the sample code we'll be going through: -1. Clone the TensorFlow Models repository from github by entering the following +1. Clone the TensorFlow Models repository from GitHub by entering the following command: - `git clone https://github.com/tensorflow/models` + git clone https://github.com/tensorflow/models 1. Change directory within that branch to the location containing the examples used in this document: - `cd models/samples/core/get_started/` + cd models/samples/core/get_started/ The program described in this document is [`premade_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py). This program uses [`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py) -To fetch its training data. +to fetch its training data. ### Running the program @@ -45,7 +47,7 @@ python premade_estimator.py The program should output training logs followed by some predictions against the test set. For example, the first line in the following output shows that the model thinks there is a 99.6% chance that the first example in the test -set is a Setosa. Since the test set `expected "Setosa"`, this appears to be +set is a Setosa. Since the test set expected Setosa, this appears to be a good prediction. ``` None @@ -61,9 +63,9 @@ If the program generates errors instead of answers, ask yourself the following questions: * Did you install TensorFlow properly? -* Are you using the correct version of tensorflow? +* Are you using the correct version of TensorFlow? * Did you activate the environment you installed TensorFlow in? (This is - only relevant in certain installation environments.) + only relevant in certain installation mechanisms.) ## The programming stack @@ -74,18 +76,15 @@ provides a programming stack consisting of multiple API layers:
-
-The TensorFlow Programming Environment -
We strongly recommend writing TensorFlow programs with the following APIs: -* @{tf.estimator$Estimators}, which represent a complete model. +* @{$programmers_guide/estimators$Estimators}, which represent a complete model. The Estimator API provides methods to train the model, to judge the model's accuracy, and to generate predictions. * @{$get_started/datasets_quickstart$Datasets}, which build a data input pipeline. The Dataset API has methods to load and manipulate data, and feed - it into your model. The Datasets API meshes well with the Estimators API. + it into your model. The Dataset API meshes well with the Estimators API. ## Classifying irises: an overview @@ -99,6 +98,7 @@ classifies Iris flowers into three different species based on the size of their alt="Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor" src="../images/iris_three_species.jpg">
+ **From left to right, [*Iris setosa*](https://commons.wikimedia.org/w/index.php?curid=170298) (by [Radomil](https://commons.wikimedia.org/wiki/User:Radomil), CC BY-SA 3.0), @@ -120,7 +120,7 @@ individual Iris flowers: * petal length * petal width -Our model will represent these features as float32 numerical data. +Our model will represent these features as `float32` numerical data. The label identifies the Iris species, which must be one of the following: @@ -154,9 +154,6 @@ The following figure illustrates the features, hidden layers, and predictions alt="A diagram of the network architecture: Inputs, 2 hidden layers, and outputs" src="../images/custom_estimators/full_network.png"> -
-The Model. -
### Inference @@ -174,12 +171,12 @@ example is an Iris Versicolor. ## Overview of programming with Estimators -An Estimator is TensorFlow's high level representation of a complete model. It +An Estimator is TensorFlow's high-level representation of a complete model. It handles the details of initialization, logging, saving and restoring, and many other features so you can concentrate on your model. For more details see @{$programmers_guide/estimators}. -An "Estimator" is any class derived from @{tf.estimator.Estimator}. TensorFlow +An Estimator is any class derived from @{tf.estimator.Estimator}. TensorFlow provides a collection of [pre-made Estimators](https://developers.google.com/machine-learning/glossary/#pre-made_Estimator) (for example, `LinearRegressor`) to implement common ML algorithms. Beyond @@ -199,7 +196,7 @@ following tasks: * Call one or more methods on the Estimator object, passing the appropriate input function as the source of the data. -Let's see how those tasks are implemented in Iris. +Let's see how those tasks are implemented for Iris classification. ## Create input functions @@ -209,17 +206,30 @@ evaluating, and prediction. An **input function** is a function that returns a @{tf.data.Dataset} object which outputs the following two-element tuple: -* "features" - A Python dictionary in which: +* [`features`](https://developers.google.com/machine-learning/glossary/#feature) - A Python dictionary in which: * Each key is the name of a feature. * Each value is an array containing all of that feature's values. -* "label" - An array containing the values of the +* `label` - An array containing the values of the [label](https://developers.google.com/machine-learning/glossary/#label) for every example. -Your input function may generate the "features" dictionary and "label" list any -way you like. However, we recommend using TensorFlow's @{tf.data.Dataset} API, -which can deftly parse all sorts of data. At a high-level, -the @{tf.data.Dataset} API consists of the following classes: +Just to demonstrate the format of the input function, here's a simple +implementation: + +```python +def input_evaluation_set(): + features = {'SepalLength': np.array([6.4, 5.0]), + 'SepalWidth': np.array([2.8, 2.3]), + 'PetalLength': np.array([5.6, 3.3]), + 'PetalWidth': np.array([2.2, 1.0])} + labels = np.array([2, 1]) + return features, labels +``` + +Your input function may generate the `features` dictionary and `label` list any +way you like. However, we recommend using TensorFlow's Dataset API, which can +parse all sorts of data. At a high level, the Dataset API consists of the +following classes:
+Where the individual members are: -Where: - -* Dataset: Base class containing methods to create and transform datasets. Also - allows you to initialize a dataset from data in memory, or from a Python - generator. -* TextLineDataset: Reads lines from text files. -* TFRecordDataset: Reads records from TFRecord files. -* FixedLengthRecordDataset: Reads fixed size records from binary files. -* Iterator: Provides a way to access one data set element at a time. +* `Dataset` - Base class containing methods to create and transform + datasets. Also allows you to initialize a dataset from data in memory, or from + a Python generator. +* `TextLineDataset` - Reads lines from text files. +* `TFRecordDataset` - Reads records from TFRecord files. +* `FixedLengthRecordDataset` - Reads fixed size records from binary files. +* `Iterator` - Provides a way to access one data set element at a time. The Dataset API can handle a lot of common cases for you. For example, using the Dataset API, you can easily read in records from a large collection of files in parallel and join them into a single stream. -To keep things simple in this example we are going to load the data with pandas, -and build our input pipeline from this in-memory data. +To keep things simple in this example we are going to load the data with +[pandas](https://pandas.pydata.org/), and build our input pipeline from this +in-memory data. Here is the input function used for training in this program, which is available in [`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py): @@ -258,9 +268,9 @@ def train_input_fn(features, labels, batch_size): return dataset.shuffle(1000).repeat().batch(batch_size) ``` -## Define the Feature Columns +## Define the feature columns -A [**Feature Column**](https://developers.google.com/machine-learning/glossary/#feature_columns) +A [**feature column**](https://developers.google.com/machine-learning/glossary/#feature_columns) is an object describing how the model should use raw input data from the features dictionary. When you build an Estimator model, you pass it a list of feature columns that describes each of the features you want the model to use. @@ -270,7 +280,7 @@ to the model. For Iris, the 4 raw features are numeric values, so we'll build a list of feature columns to tell the Estimator model to represent each of the four features as 32-bit floating-point values. Therefore, the code to create the -Feature Column is simply: +feature column is: ```python # Feature columns describe how to use the input. @@ -279,29 +289,29 @@ for key in train_x.keys(): my_feature_columns.append(tf.feature_column.numeric_column(key=key)) ``` -Feature Columns can be far more sophisticated than those we're showing here. -We detail feature columns @{$get_started/feature_columns$later on} in -getting started. +Feature columns can be far more sophisticated than those we're showing here. We +detail feature columns @{$get_started/feature_columns$later on} in our Getting +Started guide. Now that we have the description of how we want the model to represent the raw features, we can build the estimator. -## Instantiate an Estimator +## Instantiate an estimator The Iris problem is a classic classification problem. Fortunately, TensorFlow provides several pre-made classifier Estimators, including: -* @{tf.estimator.DNNClassifier}—for deep models that perform multi-class +* @{tf.estimator.DNNClassifier} for deep models that perform multi-class classification. -* @{tf.estimator.DNNLinearCombinedClassifier}—for wide-n-deep models. -* @{tf.estimator.LinearClassifier}— for classifiers based on linear models. +* @{tf.estimator.DNNLinearCombinedClassifier} for wide & deep models. +* @{tf.estimator.LinearClassifier} for classifiers based on linear models. For the Iris problem, `tf.estimator.DNNClassifier` seems like the best choice. Here's how we instantiated this Estimator: ```python -# Build 2 hidden layer DNN with 10, 10 units respectively. +# Build a DNN with 2 hidden layers and 10 nodes in each hidden layer. classifier = tf.estimator.DNNClassifier( feature_columns=my_feature_columns, # Two hidden layers of 10 nodes each. @@ -363,7 +373,7 @@ Test set accuracy: 0.967 We now have a trained model that produces good evaluation results. We can now use the trained model to predict the species of an Iris flower -based on some unlabeled measurments. As with training and evaluation, we make +based on some unlabeled measurements. As with training and evaluation, we make predictions using a single function call: ```python diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md index ba1a4118aece1f42822f7cd084feed50c5cf6ebb..f3620cf687359ebc4abfc3365beb3da694ec7baf 100644 --- a/tensorflow/docs_src/install/install_c.md +++ b/tensorflow/docs_src/install/install_c.md @@ -38,7 +38,7 @@ enable TensorFlow for C: OS="linux" # Change to "darwin" for macOS TARGET_DIRECTORY="/usr/local" curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.5.0-rc1.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.6.0-rc1.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md index 87cc647317a11fab0d9d0219dd5764af3dcb2ecc..4bf4bacaecb88c9335cbe5ccbd7e6557cd21aca6 100644 --- a/tensorflow/docs_src/install/install_go.md +++ b/tensorflow/docs_src/install/install_go.md @@ -38,7 +38,7 @@ steps to install this library and enable TensorFlow for Go: TF_TYPE="cpu" # Change to "gpu" for GPU support TARGET_DIRECTORY='/usr/local' curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.5.0-rc1.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.6.0-rc1.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md index 37e109a6e4bdee97ad02bc7aceb2c0c24e1ec7ec..a63a6c7ebe29c70c547336831ee12e51d49f851e 100644 --- a/tensorflow/docs_src/install/install_java.md +++ b/tensorflow/docs_src/install/install_java.md @@ -36,7 +36,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs: org.tensorflow tensorflow - 1.5.0-rc1 + 1.6.0-rc1 ``` @@ -65,7 +65,11 @@ As an example, these steps will create a Maven project that uses TensorFlow: org.tensorflow tensorflow - 1.5.0-rc1 +<<<<<<< HEAD + 1.6.0-rc1 +======= + 1.6.0-rc0 +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d @@ -123,12 +127,20 @@ instead: org.tensorflow libtensorflow - 1.5.0-rc1 +<<<<<<< HEAD + 1.6.0-rc1 +======= + 1.6.0-rc0 +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d org.tensorflow libtensorflow_jni_gpu - 1.5.0-rc1 +<<<<<<< HEAD + 1.6.0-rc1 +======= + 1.6.0-rc0 +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d ``` @@ -147,7 +159,11 @@ refer to the simpler instructions above instead. Take the following steps to install TensorFlow for Java on Linux or macOS: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.5.0-rc1.jar), +<<<<<<< HEAD + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc1.jar), +======= + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc0.jar), +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d which is the TensorFlow Java Archive (JAR). 2. Decide whether you will run TensorFlow for Java on CPU(s) only or with @@ -166,7 +182,11 @@ Take the following steps to install TensorFlow for Java on Linux or macOS: OS=$(uname -s | tr '[:upper:]' '[:lower:]') mkdir -p ./jni curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.5.0-rc1.tar.gz" | +<<<<<<< HEAD + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.6.0-rc1.tar.gz" | +======= + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.6.0-rc0.tar.gz" | +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d tar -xz -C ./jni ### Install on Windows @@ -174,10 +194,17 @@ Take the following steps to install TensorFlow for Java on Linux or macOS: Take the following steps to install TensorFlow for Java on Windows: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.5.0-rc1.jar), +<<<<<<< HEAD + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc1.jar), which is the TensorFlow Java Archive (JAR). 2. Download the following Java Native Interface (JNI) file appropriate for - [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.5.0-rc1.zip). + [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.6.0-rc1.zip). +======= + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc0.jar), + which is the TensorFlow Java Archive (JAR). + 2. Download the following Java Native Interface (JNI) file appropriate for + [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.6.0-rc0.zip). +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d 3. Extract this .zip file. @@ -225,7 +252,11 @@ must be part of your `classpath`. For example, you can include the downloaded `.jar` in your `classpath` by using the `-cp` compilation flag as follows: -
javac -cp libtensorflow-1.5.0-rc1.jar HelloTF.java
+<<<<<<< HEAD +
javac -cp libtensorflow-1.6.0-rc1.jar HelloTF.java
+======= +
javac -cp libtensorflow-1.6.0-rc0.jar HelloTF.java
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d ### Running @@ -239,11 +270,19 @@ two files are available to the JVM: For example, the following command line executes the `HelloTF` program on Linux and macOS X: -
java -cp libtensorflow-1.5.0-rc1.jar:. -Djava.library.path=./jni HelloTF
+<<<<<<< HEAD +
java -cp libtensorflow-1.6.0-rc1.jar:. -Djava.library.path=./jni HelloTF
+ +And the following command line executes the `HelloTF` program on Windows: + +
java -cp libtensorflow-1.6.0-rc1.jar;. -Djava.library.path=jni HelloTF
+======= +
java -cp libtensorflow-1.6.0-rc0.jar:. -Djava.library.path=./jni HelloTF
And the following command line executes the `HelloTF` program on Windows: -
java -cp libtensorflow-1.5.0-rc1.jar;. -Djava.library.path=jni HelloTF
+
java -cp libtensorflow-1.6.0-rc0.jar;. -Djava.library.path=jni HelloTF
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d If the program prints Hello from version, you've successfully installed TensorFlow for Java and are ready to use the API. If the program diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md index 03f12dff08cb3483666df4b8553b97fc1c4f34f9..681b45423f29f9f43da69e2d4d516a11f66636a3 100644 --- a/tensorflow/docs_src/install/install_linux.md +++ b/tensorflow/docs_src/install/install_linux.md @@ -31,13 +31,13 @@ If you are installing TensorFlow with GPU support using one of the mechanisms described in this guide, then the following NVIDIA software must be installed on your system: - * CUDA® Toolkit 8.0. For details, see + * CUDA® Toolkit 9.0. For details, see [NVIDIA's documentation](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/#axzz4VZnqTJ2A). Ensure that you append the relevant Cuda pathnames to the `LD_LIBRARY_PATH` environment variable as described in the NVIDIA documentation. - * The NVIDIA drivers associated with CUDA Toolkit 8.0. - * cuDNN v6.0. For details, see + * The NVIDIA drivers associated with CUDA Toolkit 9.0. + * cuDNN v7.0. For details, see [NVIDIA's documentation](https://developer.nvidia.com/cudnn). Ensure that you create the `CUDA_HOME` environment variable as described in the NVIDIA documentation. @@ -188,7 +188,7 @@ Take the following steps to install TensorFlow with Virtualenv: Virtualenv environment:
(tensorflow)$ pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp34-cp34m-linux_x86_64.whl
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl If you encounter installation problems, see [Common Installation Problems](#common_installation_problems). @@ -293,7 +293,11 @@ take the following steps:
      $ sudo pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp34-cp34m-linux_x86_64.whl
+<<<<<<< HEAD
+     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
+=======
+     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
      
If this step fails, see @@ -480,7 +484,11 @@ Take the following steps to install TensorFlow in an Anaconda environment:
      (tensorflow)$ pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp34-cp34m-linux_x86_64.whl
+<<<<<<< HEAD + https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl +======= + https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d @@ -648,14 +656,22 @@ This section documents the relevant values for Linux installations. CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp27-none-linux_x86_64.whl
+<<<<<<< HEAD
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp27-none-linux_x86_64.whl
+=======
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp27-none-linux_x86_64.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.5.0rc1-cp27-none-linux_x86_64.whl
+<<<<<<< HEAD
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp27-none-linux_x86_64.whl
+=======
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp27-none-linux_x86_64.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
 
Note that GPU support requires the NVIDIA hardware and software described in @@ -667,14 +683,22 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp34-cp34m-linux_x86_64.whl
+<<<<<<< HEAD
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
+=======
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.5.0rc1-cp34-cp34m-linux_x86_64.whl
+<<<<<<< HEAD
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
+=======
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
 
Note that GPU support requires the NVIDIA hardware and software described in @@ -686,14 +710,22 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp35-cp35m-linux_x86_64.whl
+<<<<<<< HEAD
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp35-cp35m-linux_x86_64.whl
+=======
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp35-cp35m-linux_x86_64.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.5.0rc1-cp35-cp35m-linux_x86_64.whl
+<<<<<<< HEAD
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp35-cp35m-linux_x86_64.whl
+=======
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp35-cp35m-linux_x86_64.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
 
@@ -705,14 +737,22 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0rc1-cp36-cp36m-linux_x86_64.whl
+<<<<<<< HEAD
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp36-cp36m-linux_x86_64.whl
+=======
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp36-cp36m-linux_x86_64.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.5.0rc1-cp36-cp36m-linux_x86_64.whl
+<<<<<<< HEAD
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp36-cp36m-linux_x86_64.whl
+=======
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp36-cp36m-linux_x86_64.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
 
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index e13ddadab7b0e2ed96bdaf5600b3479a4b5eec55..a2f484ebf8867620630fa815fee55c06585f4fa1 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -115,7 +115,7 @@ Take the following steps to install TensorFlow with Virtualenv: TensorFlow in the active Virtualenv is as follows:
 $ pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py2-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl If you encounter installation problems, see [Common Installation Problems](#common-installation-problems). @@ -238,7 +238,11 @@ take the following steps: issue the following command:
 $ sudo pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py2-none-any.whl 
+<<<<<<< HEAD + https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl +======= + https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d If the preceding command fails, see [installation problems](#common-installation-problems). @@ -347,7 +351,11 @@ Take the following steps to install TensorFlow in an Anaconda environment: TensorFlow for Python 2.7:
 (targetDirectory)$ pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py2-none-any.whl
+<<<<<<< HEAD + https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl +======= + https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py2-none-any.whl +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d @@ -520,7 +528,11 @@ This section documents the relevant values for Mac OS installations.
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py2-none-any.whl
+<<<<<<< HEAD
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl
+=======
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py2-none-any.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
 
@@ -528,5 +540,9 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py2-none-a
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py3-none-any.whl
+<<<<<<< HEAD
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl
+=======
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl
+>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d
 
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md index f494cc7a7c0575fd7950b6fe28d7671e1f25725f..d01be212601d5bd2941fe4852498a4fb4d4403d6 100644 --- a/tensorflow/docs_src/install/install_sources.md +++ b/tensorflow/docs_src/install/install_sources.md @@ -133,7 +133,7 @@ The following NVIDIA hardware must be installed on your system: The following NVIDIA software must be installed on your system: - * NVIDIA's Cuda Toolkit (>= 7.0). We recommend version 8.0. + * NVIDIA's Cuda Toolkit (>= 7.0). We recommend version 9.0. For details, see [NVIDIA's documentation](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/#axzz4VZnqTJ2A). Ensure that you append the relevant Cuda pathnames to the @@ -221,7 +221,7 @@ problem, do either of the following: * Download Xcode 7.2 and select it as your default by issuing the following command: -
 $ sudo xcode-select -s /Application/Xcode-7.2/Xcode.app
+
 $ sudo xcode-select -s /Applications/Xcode-7.2/Xcode.app
**NOTE:** Your system must fulfill the NVIDIA software requirements described in one of the following documents: @@ -272,8 +272,6 @@ Found possible Python library paths: Please input the desired Python library path to use. Default is [/usr/lib/python2.7/dist-packages] Using python library path: /usr/local/lib/python2.7/dist-packages -Do you wish to build TensorFlow with MKL support? [y/N] -No MKL support will be enabled for TensorFlow Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]: Do you wish to use jemalloc as the malloc implementation? [Y/n] jemalloc enabled @@ -291,11 +289,11 @@ Do you wish to build TensorFlow with CUDA support? [y/N] Y CUDA support will be enabled for TensorFlow Do you want to use clang as CUDA compiler? [y/N] nvcc will be used as CUDA compiler -Please specify the Cuda SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 8.0]: 8.0 -Please specify the location where CUDA 8.0 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: +Please specify the Cuda SDK version you want to use, e.g. 7.0. [Leave empty to default to CUDA 9.0]: 9.0 +Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: Please specify which gcc should be used by nvcc as the host compiler. [Default is /usr/bin/gcc]: -Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 6.0]: 6 -Please specify the location where cuDNN 6 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: +Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: 7 +Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: Please specify a list of comma-separated Cuda compute capabilities you want to build with. You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus. Please note that each additional compute capability significantly increases your build time and binary size. @@ -361,10 +359,10 @@ Invoke `pip install` to install that pip package. The filename of the `.whl` file depends on your platform. For example, the following command will install the pip package -for TensorFlow 1.5.0rc1 on Linux: +for TensorFlow 1.6.0rc1 on Linux:
-$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.5.0rc1-py2-none-any.whl
+$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.6.0rc1-py2-none-any.whl
 
## Validate your installation @@ -395,7 +393,7 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started/get_started$Getting Started with +If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}. If the system outputs an error message instead of a greeting, see [Common @@ -462,9 +460,15 @@ Stack Overflow and specify the `tensorflow` tag. **Linux** - - - +<<<<<<< HEAD + + +======= + + +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d + + @@ -480,7 +484,12 @@ Stack Overflow and specify the `tensorflow` tag. **Mac**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.5.0-rc1CPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.0N/AN/A
tensorflow_gpu-1.5.0-rc1GPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.079
tensorflow-1.6.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.0N/AN/A
tensorflow_gpu-1.6.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.079
tensorflow-1.6.0rc0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.0N/AN/A
tensorflow_gpu-1.6.0rc0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.079
tensorflow-1.5.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.0N/AN/A
tensorflow_gpu-1.5.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.079
tensorflow-1.4.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.5.4N/AN/A
tensorflow_gpu-1.4.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.5.468
tensorflow-1.3.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.4.5N/AN/A
- +<<<<<<< HEAD + +======= + +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d + @@ -493,8 +502,15 @@ Stack Overflow and specify the `tensorflow` tag. **Windows**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.5.0-rc1CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
tensorflow-1.6.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
tensorflow-1.6.0rc0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
tensorflow-1.5.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
tensorflow-1.4.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.5.4N/AN/A
tensorflow-1.3.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.4.5N/AN/A
tensorflow-1.2.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.4.5N/AN/A
- - +<<<<<<< HEAD + + +======= + + +>>>>>>> 943a21fcdc1c48c8e95d872911ad52b13f0c037d + + diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md index 8d0eb7966fdf17be1c259627a64803f0a392943a..dedf485f93d6fd6a8ce7b4465548cc998d307daa 100644 --- a/tensorflow/docs_src/install/install_windows.md +++ b/tensorflow/docs_src/install/install_windows.md @@ -30,13 +30,13 @@ If you are installing TensorFlow with GPU support using one of the mechanisms described in this guide, then the following NVIDIA software must be installed on your system: - * CUDA® Toolkit 8.0. For details, see + * CUDA® Toolkit 9.0. For details, see [NVIDIA's documentation](http://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/) Ensure that you append the relevant Cuda pathnames to the `%PATH%` environment variable as described in the NVIDIA documentation. - * The NVIDIA drivers associated with CUDA Toolkit 8.0. - * cuDNN v6.0. For details, see + * The NVIDIA drivers associated with CUDA Toolkit 9.0. + * cuDNN v7.0. For details, see [NVIDIA's documentation](https://developer.nvidia.com/cudnn). Note that cuDNN is typically installed in a different location from the other CUDA DLLs. Ensure that you add the directory where you installed @@ -47,7 +47,7 @@ installed on your system: If you have a different version of one of the preceding packages, please change to the specified versions. In particular, the cuDNN version -must match exactly: TensorFlow will not load if it cannot find `cuDNN64_6.dll`. +must match exactly: TensorFlow will not load if it cannot find `cuDNN64_7.dll`. To use a different version of cuDNN, you must build from source. ## Determine how to install TensorFlow @@ -153,7 +153,7 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started/get_started$Getting Started with +If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}. If the system outputs an error message instead of a greeting, see [Common diff --git a/tensorflow/docs_src/install/leftnav_files b/tensorflow/docs_src/install/leftnav_files index 0e8b5ae7a17eb43cffc76d40692c4f0042de44af..e523e06f67aad508238ee0965f34ebe16c77bf90 100644 --- a/tensorflow/docs_src/install/leftnav_files +++ b/tensorflow/docs_src/install/leftnav_files @@ -1,16 +1,16 @@ index.md ### Python -install_linux.md -install_mac.md -install_windows.md -install_sources.md +install_linux.md: Ubuntu +install_mac.md: MacOS +install_windows.md: Windows +install_sources.md: From source >>> migration.md ### Other Languages -install_java.md -install_go.md -install_c.md +install_java.md: Java +install_go.md: Go +install_c.md: C diff --git a/tensorflow/docs_src/mobile/mobile_intro.md b/tensorflow/docs_src/mobile/mobile_intro.md index 17dbf1c3e6ad89768529864ba884274a51b3dfb2..69b63ae7d22ced9fd0299f17d1ae2d614c9a6be7 100644 --- a/tensorflow/docs_src/mobile/mobile_intro.md +++ b/tensorflow/docs_src/mobile/mobile_intro.md @@ -235,7 +235,7 @@ TensorFlow [on Github](https://github.com/tensorflow/models) that you can look through. Lean towards the simplest model you can find, and try to get started as soon as you have even a small amount of labelled data, since you’ll get the best results when you’re able to iterate quickly. The shorter the time it takes to -try training a model and running it in s real application, the better overall +try training a model and running it in its real application, the better overall results you’ll see. It’s common for an algorithm to get great training accuracy numbers but then fail to be useful within a real application because there’s a mismatch between the dataset and real usage. Prototype end-to-end usage as soon diff --git a/tensorflow/docs_src/performance/datasets_performance.md b/tensorflow/docs_src/performance/datasets_performance.md index 4f95e17c3598c23645fad07441c267266e5ef34e..46b43b7673c561679e89fff0ae738b0e751fcff5 100644 --- a/tensorflow/docs_src/performance/datasets_performance.md +++ b/tensorflow/docs_src/performance/datasets_performance.md @@ -92,11 +92,11 @@ transform the data. Without pipelining, the CPU and the GPU/TPU sit idle much of the time: -![without pipelining](https://www.tensorflow.org/images/datasets_without_pipelining.png) +![without pipelining](/images/datasets_without_pipelining.png) With pipelining, idle time diminishes significantly: -![with pipelining](https://www.tensorflow.org/images/datasets_with_pipelining.png) +![with pipelining](/images/datasets_with_pipelining.png) The `tf.data` API provides a software pipelining mechanism through the @{tf.data.Dataset.prefetch} transformation, which can be used to decouple the @@ -139,7 +139,7 @@ multiple CPU cores. To make this possible, the `map` transformation provides the the following diagram illustrates the effect of setting `num_parallel_calls=2` to the `map` transformation: -![parallel map](https://www.tensorflow.org/images/datasets_parallel_map.png) +![parallel map](/images/datasets_parallel_map.png) Choosing the best value for the `num_parallel_calls` argument depends on your hardware, characteristics of your training data (such as its size and shape), @@ -213,7 +213,7 @@ number of datasets to overlap can be specified by the `cycle_length` argument. The following diagram illustrates the effect of supplying `cycle_length=2` to the `parallel_interleave` transformation: -![parallel io](https://www.tensorflow.org/images/datasets_parallel_io.png) +![parallel io](/images/datasets_parallel_io.png) To apply this change to our running example, change: diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md index 10e7ad7ada533c8da5e5b871b38809b90604685e..cd47fc2803bc1429d28bd0ae4c2ad68e632a6f03 100644 --- a/tensorflow/docs_src/performance/performance_guide.md +++ b/tensorflow/docs_src/performance/performance_guide.md @@ -498,7 +498,7 @@ For TensorFlow source versions after 1.3.0: ```bash ./configure # Pick the desired options -bazel build --config=mkl -c opt //tensorflow/tools/pip_package:build_pip_package +bazel build --config=mkl --config=opt //tensorflow/tools/pip_package:build_pip_package ``` diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 1e9b8b35db65ef19a4bcb607b98af1e1de4e6d5b..daa2d4767c20d30594ecd90911bfd1771053c2e9 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -252,7 +252,7 @@ Clamps an operand to within the range between a minimum and maximum value. Given an operand and minimum and maximum values, returns the operand if it is in the range between the minimum and maximum, else returns the minimum value if the operand is below this range or the maximum value if the operand is above this -range. That is, `clamp(a, x, b) = max(min(a, x), b)`. +range. That is, `clamp(a, x, b) = min(max(a, x), b)`. All three arrays must be the same shape. Alternately, as a restricted form of [broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`. @@ -717,6 +717,7 @@ in 'dimension_numbers'. Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need to be the same, but must be listed in the same order in both 'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes. +There must be exactly one contracting dimension on both 'lhs' and 'rhs'. Example with contracting dimension numbers: @@ -736,8 +737,9 @@ DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0}, ``` Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same -dimension number, must be listed in the same order in both arrays, and must -have the same dimension sizes. +dimension number, must be listed in the same order in both arrays, must +have the same dimension sizes, and must be ordered before contracting and +non-contracting/non-batch dimension numbers. Example with batch dimension numbers (batch size 2, 2x2 matrices): @@ -769,6 +771,10 @@ DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0}, | [b0, m, k] `dot` [b0, k, n] | [b0, m, n] | batch matmul | | [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n] | batch matmul | +It follows that the resulting dimension number starts with the batch dimension, +then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs' +non-contracting/non-batch dimension. + ## DynamicSlice See also @@ -1021,6 +1027,194 @@ Arguments | Type | Semantics The function is applied to each element in the `operand` array, resulting in an array with the same shape. It is allowed for `operand` to be a scalar (rank 0). +## Gather + +The XLA gather operation stitches together several slices (each slice at a +potentially different runtime offset) of an input tensor into an output tensor. + +### General Semantics + +See also +[`ComputationBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h). +For a more intuitive description, see the "Informal Description" section below. + + `gather(operand, gather_indices, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)` + +|Arguments | Type | Semantics | +|----------------- | ----------------------- | --------------------------------| +|`operand` | `ComputationDataHandle` | The tensor we’re gathering | +: : : from. : +|`gather_indices` | `ComputationDataHandle` | Tensor containing the starting | +: : : indices of the slices we're : +: : : we're stitching together into : +: : : the output tensor. : +|`output_window_dims` | `ArraySlice` | The set of dimensions in the | +: : : output shape that are _window : +: : : dimensions_ (defined below). : +: : : Not all window dimensions may : +: : : be present in the output shape. : +|`elided_window_dims` | `ArraySlice` | The set of _window dimensions_ | +: : : that are not present in the output shape. : +: : : `window_bounds[i]` must be `1` for all `i` : +: : : in `elided_window_dims`. : +|`window_bounds` | `ArraySlice` | `window_bounds[i]` is the bounds | +: : : for window dimension `i`. This includes : +: : : both the window dimensions that are : +: : : explicitly part of the output shape (via : +: : : `output_window_dims`) and the window : +: : : dimensions that are elided (via : +: : : `elided_window_dims`). : +|`gather_dims_to_operand_dims` | `ArraySlice` | A dimension map (the | +: : : array is interpreted as mapping `i` to : +: : : `gather_dims_to_operand_dims[i]`) from : +: : : the gather indices in `gather_indices` to : +: : : the operand index space. It has to be : +: : : one-to-one and total. : + +If `gather_indices` is a vector with `N` elements then we implicitly reshape it +to a tensor of shape `[N,1]` before proceeding. + +For every index `Out` in the output tensor, we compute two things (more +precisely described later): + + - An index into the first `gather_indices.rank` - `1` dimensions of + `gather_indices`, which gives us a starting index of a slice, _operand + slice_, in the operand tensor. + + - A _window index_ that has the same rank as the operand. This index is + composed of the values in `Out` at dimensions `output_window_dims`, embedded + with zeroes according to `elided_window_dims`. + +The _window index_ is the relative index of the element in _operand slice_ that +should be present in the output at index `Out`. + +The output is a tensor of rank `output_window_dims.size` + `gather_indices.rank` +- `1`. Additionally, as a shorthand, we define `output_gather_dims` of type +`ArraySlice` as the set of dimensions in the output shape but not in +`output_window_dims`, in ascending order. E.g. if the output tensor has rank 5, +`output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`, `3`} + +The bounds for the output tensor along dimension `i` is computed as follows: + + 1. If `i` is present in `output_gather_dims` (i.e. is equal to + `output_gather_dims[k]` for some `k`) then we pick the corresponding + dimension bounds out of `gather_indices.shape` (i.e. pick + `gather_indices.shape.dims[k]`). + 2. If `i` is present in `output_window_dims` (i.e. equal to + `output_window_dims[k]` for some `k`) then we pick the corresponding bound + out of `window_bounds` after accounting for `elided_window_dims` (i.e. we + pick `adjusted_window_bounds[k]` where `adjusted_window_bounds` is + `window_bounds` with the bounds at indices `elided_window_dims` removed). + +The operand index `In` corresponding to an output index `Out` is computed as +follows: + + 1. Let `G` = { `Out`[`k`] for `k` in `output_gather_dims` }. Use `G` to slice + out vector `S` such that `S`[`i`] = `gather_indices`[`G`, `i`]. + 2. Create an index, `S``in`, into `operand` using `S` by scattering + `S` using the `gather_dims_to_operand_dims` map (`S``in` is the + starting indices for _operand slice_ mentioned above.). More precisely: + 1. `S``in`[`gather_dims_to_operand_dims`[`k`]] = `S`[`k`] if `k` < + `gather_dims_to_operand_dims.size`. + 2. `S``in`[`_`] = `0` otherwise. + 3. Create an index `W``in` into `operand` by scattering the indices + at the output window dimensions in `Out` according to + the `elided_window_dims` set (`W``in` is the _window index_ + mentioned above). More precisely: + 1. `W``in`[`window_dims_to_operand_dims`(`k`)] = `Out`[`k`] if + `k` < `output_window_dims.size` (`window_dims_to_operand_dims` is + defined below). + 2. `W``in`[`_`] = `0` otherwise. + 4. `In` is `W``in` + `S``in` where + is element-wise + addition. + +`window_dims_to_operand_dims` is the monotonic function with domain [`0`, +`output_window_dims.size`) and range [`0`, `operand.rank`) \ +`elided_window_dims`. So if, e.g., `output_window_dims.size` is `4`, +`operand.rank` is `6` and `elided_window_dims` is {`0`, `2`} then +`window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}. + +### Informal Description + +To get an intuition on how all of the above fits together, let's look at an +example that gathers 5 slices of shape `[8,6]` from a `[16,11]` tensor. The +position of a slice into the `[16,11]` tensor can be represented as an index +vector of shape `S64[2]`, so the set of 5 positions can be represented as a +`S64[5,2]` tensor. + +The behavior of the gather operation can then be depicted as an index +transformation that takes [`G`,`W``0`,`W``1`], an index in +the output shape, and maps it to an element in the input tensor in the following +way: + +
+ +
+ +We first select an (`X`,`Y`) vector from the gather indices tensor using `G`. +The element in the output tensor at index +[`G`,`W``0`,`W``1`] is then the element in the input +tensor at index [`X`+`W``0`,`Y`+`W``1`]. + +`window_bounds` is `[8,6]`, which decides the range of W`0` and +W`1`, and this in turn decides the bounds of the slice. + +This gather operation acts as a batch dynamic slice with `G` as the batch +dimension. + +The gather indices may be multidimensional. For instance, a more general +version of the example above using a "gather indices" tensor of shape `[4,5,2]` +would translate indices like this: + +
+ +
+ +Again, this acts as a batch dynamic slice `G``0` and +`G``1` as the batch dimensions. The window bounds are still `[8,6]`. + +The gather operation in XLA generalizes the informal semantics outlined above in +the following ways: + + 1. We can configure which dimensions in the output shape are the window + dimensions (dimensions containing `W``0`, `W``1` in + the last example). The output gather dimensions (dimensions containing + `G``0`, `G``1` in the last example) are defined to be + the output dimensions that are not window dimensions. + + 2. The number of output window dimensions explicitly present in the output + shape may be smaller than the input rank. These "missing" dimensions, which + are listed explicitly as `elided_window_dims`, must have a window bound of + `1`. Since they have a window bound of `1` the only valid index for them is + `0` and eliding them does not introduce ambiguity. + + 3. The slice extracted from the "Gather Indices" tensor ((`X`, `Y`) in the last + example) may have fewer elements than the input tensor rank, and an explicit + mapping dictates how the index should be expanded to have the same rank as + the input. + +As a final example, we use (2) and (3) to implement `tf.gather_nd`: + +
+ +
+ +`G``0` and `G``1` are used to slice out a starting index +from the gather indices tensor as usual, except the starting index has only one +element, `X`. Similarly, there is only one output window index with the value +`W``0`. However, before being used as indices into the input tensor, +these are expanded in accordance to "Gather Index Mapping" +(`gather_dims_to_operand_dims` in the formal description) and "Window Mapping" +(`window_dims_to_operand_dims` in the formal description) into +[`0`,`W``0`] and [`X`,`0`] respectively, adding up to +[`X`,`W``0`]. In other words, the output index +[`G``0`,`G``1`,`W``0`] maps to the input index +[`GatherIndices`[`G``0`,`G``1`,`0`],`X`] which gives us +the semantics for `tf.gather_nd`. + +`window_bounds` for this case is `[1,11]`. Intuitively this means that every +index `X` in the gather indices tensor picks an entire row and the result is the +concatenation of all these rows. ## GetTupleElement diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md index 9ede4ab83c1dcdb7370e83dfb9227fbb235d0689..d19200e80cdfe6620789ddd273647660c10b2a60 100644 --- a/tensorflow/docs_src/programmers_guide/datasets.md +++ b/tensorflow/docs_src/programmers_guide/datasets.md @@ -322,9 +322,10 @@ sess.run(iterator.initializer) next1, (next2, next3) = iterator.get_next() ``` -Note that evaluating *any* of `next1`, `next2`, or `next3` will advance the -iterator for all components. A typical consumer of an iterator will include all -components in a single expression. +Note that `next1`, `next2`, and `next3` are tensors produced by the +same op/node (created by `Iterator.get_next()`). Therefore, evaluating *any* of +these tensors will advance the iterator for all components. A typical consumer +of an iterator will include all components in a single expression. ## Reading input data diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md index 9eaee2702829cbfd96cd56e832003724eba5bb1b..c8fdae6f60c33776b6d9a8c1a33666ce4ddb1cb2 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/programmers_guide/debugger.md @@ -1,4 +1,4 @@ -# Debugging TensorFlow Programs +# TensorFlow Debugger @@ -214,7 +214,7 @@ navigate between these screens by clicking the `<--` and ### Other Features of the tfdbg CLI In addition to the commands listed above, the tfdbg CLI provides the following -addditional features: +additional features: * To navigate through previous tfdbg commands, type in a few characters followed by the Up or Down arrow keys. tfdbg will show you the history of diff --git a/tensorflow/docs_src/programmers_guide/graphs.md b/tensorflow/docs_src/programmers_guide/graphs.md index 2b4896c381052b5a3fb97385a18dbff82c2c0d89..9049a5a9f3d44e255188c6c41cdb12a619464379 100644 --- a/tensorflow/docs_src/programmers_guide/graphs.md +++ b/tensorflow/docs_src/programmers_guide/graphs.md @@ -125,14 +125,14 @@ an operation: @{tf.Tensor} accepts an optional `name` argument. For example, `tf.constant(42.0, name="answer")` creates a new @{tf.Operation} named `"answer"` and returns a @{tf.Tensor} named `"answer:0"`. If the default graph - already contained an operation named `"answer"`, the TensorFlow would append + already contains an operation named `"answer"`, then TensorFlow would append `"_1"`, `"_2"`, and so on to the name, in order to make it unique. * The @{tf.name_scope} function makes it possible to add a **name scope** prefix to all operations created in a particular context. The current name scope prefix is a `"/"`-delimited list of the names of all active @{tf.name_scope} context managers. If a name scope has already been used in the current - context, TensorFlow appens `"_1"`, `"_2"`, and so on. For example: + context, TensorFlow appends `"_1"`, `"_2"`, and so on. For example: ```python c_0 = tf.constant(0, name="c") # => operation named "c" diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md index d45e666ce7b440bae20ba32d894526372af7e17b..7a5e90081d9145ca934929f0af11f2a40cb2dcae 100644 --- a/tensorflow/docs_src/programmers_guide/index.md +++ b/tensorflow/docs_src/programmers_guide/index.md @@ -13,7 +13,7 @@ works. The units are as follows: ## Low Level APIs * @{$programmers_guide/low_level_intro}, which introduces the - basics of how you can to use TensorFlow outside of the high Level APIs. + basics of how you can use TensorFlow outside of the high Level APIs. * @{$programmers_guide/tensors}, which explains how to create, manipulate, and access Tensors--the fundamental object in TensorFlow. * @{$programmers_guide/variables}, which details how diff --git a/tensorflow/docs_src/programmers_guide/leftnav_files b/tensorflow/docs_src/programmers_guide/leftnav_files index 38de3ccc3e474e6051976c810519212da8f5051e..3fe4cb2ddaee40d9d6c6470bee171dedb27ad890 100644 --- a/tensorflow/docs_src/programmers_guide/leftnav_files +++ b/tensorflow/docs_src/programmers_guide/leftnav_files @@ -10,7 +10,10 @@ tensors.md variables.md graphs.md saved_model.md + +### Accelerators using_gpu.md +using_tpu.md ### ML Concepts embedding.md @@ -19,9 +22,9 @@ embedding.md debugger.md ### TensorBoard -summaries_and_tensorboard.md -graph_viz.md -tensorboard_histograms.md +summaries_and_tensorboard.md: Visualizing Learning +graph_viz.md: Graphs +tensorboard_histograms.md: Histograms ### Misc version_compat.md diff --git a/tensorflow/docs_src/programmers_guide/low_level_intro.md b/tensorflow/docs_src/programmers_guide/low_level_intro.md index 8f6d3fbd46d8b76d6033d95fd51c1df45733f5a3..05709ad10a9275953d351e4a62cbf6d7fbffbbe3 100644 --- a/tensorflow/docs_src/programmers_guide/low_level_intro.md +++ b/tensorflow/docs_src/programmers_guide/low_level_intro.md @@ -286,6 +286,23 @@ while True: break ``` +If the `Dataset` depends on stateful operations you may need to +initialize the iterator before using it, as shown below: + +``` python +r = tf.random_normal([10,3]) +dataset = tf.data.Dataset.from_tensor_slices(r) +iterator = dataset.make_initializable_iterator() +next_row = iterator.get_next() + +sess.run(iterator.initializer) +while True: + try: + print(sess.run(next_row)) + except tf.errors.OutOfRangeError: + break +``` + For more details on Datasets and Iterators see: @{$programmers_guide/datasets}. ## Layers @@ -295,7 +312,7 @@ the same input. @{tf.layers$Layers} are the preferred way to add trainable parameters to a graph. Layers package together both the variables and the operations that act -on them, . For example a +on them. For example a [densely-connected layer](https://developers.google.com/machine-learning/glossary/#fully_connected_layer) performs a weighted sum across all inputs for each output and applies an optional @@ -478,7 +495,7 @@ good. Here's what we got; your own output will almost certainly differ: [ 0.10527515]] ``` -### loss +### Loss To optimize a model, you first need to define the loss. We'll use the mean square error, a standard loss for regression problems. @@ -504,7 +521,7 @@ TensorFlow provides [**optimizers**](https://developers.google.com/machine-learning/glossary/#optimizer) implementing standard optimization algorithms. These are implemented as sub-classes of @{tf.train.Optimizer}. They incrementally change each -variable in order to minimizethe loss. The simplest optimization algorithm is +variable in order to minimize the loss. The simplest optimization algorithm is [**gradient descent**](https://developers.google.com/machine-learning/glossary/#gradient_descent), implemented by @{tf.train.GradientDescentOptimizer}. It modifies each variable according to the magnitude of the derivative of loss with respect to diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/programmers_guide/saved_model.md index 9f50be5b31cd8b61b81426f50aa9ef9beb3138f2..f27a658342b8d33407e1c6ed5799a10c2305a74c 100644 --- a/tensorflow/docs_src/programmers_guide/saved_model.md +++ b/tensorflow/docs_src/programmers_guide/saved_model.md @@ -285,7 +285,7 @@ with tf.Session(graph=tf.Graph()) as sess: ``` -### Loading a Savedmodel in C++ +### Loading a SavedModel in C++ The C++ version of the SavedModel [loader](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/loader.h) @@ -303,6 +303,30 @@ LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain}, &bundle); ``` +### Loading and Serving a SavedModel in TensorFlow Serving + +You can easily load and serve a SavedModel with the TensorFlow Serving Model +Server binary. See [instructions](https://www.tensorflow.org/serving/setup#installing_using_apt-get) +on how to install the server, or build it if you wish. + +Once you have the Model Server, run it with: +``` +tensorflow_model_server --port=port-numbers --model_name=your-model-name --model_base_path=your_model_base_path +``` +Set the port and model_name flags to values of your choosing. The +model_base_path flag expects to be to a base directory, with each version of +your model residing in a numerically named subdirectory. If you only have a +single version of your model, simply place it in a subdirectory like so: +* Place the model in /tmp/model/0001 +* Set model_base_path to /tmp/model + +Store different versions of your model in numerically named subdirectories of a +common base directory. For example, suppose the base directory is `/tmp/model`. +If you have only one version of your model, store it in `/tmp/model/0001`. If +you have two versions of your model, store the second version in +`/tmp/model/0002`, and so on. Set the `--model-base_path` flag to the base +directory (`/tmp/model`, in this example). TensorFlow Model Server will serve +the model in the highest numbered subdirectory of that base directory. ### Standard constants diff --git a/tensorflow/docs_src/programmers_guide/using_tpu.md b/tensorflow/docs_src/programmers_guide/using_tpu.md new file mode 100644 index 0000000000000000000000000000000000000000..d74d7f3181c9cf44e6c97e13742db682858f4694 --- /dev/null +++ b/tensorflow/docs_src/programmers_guide/using_tpu.md @@ -0,0 +1,396 @@ +# Using TPUs + +This document walks through the principal TensorFlow APIs necessary to make +effective use of a [Cloud TPU](https://cloud.google.com/tpu/), and highlights +the differences between regular TensorFlow usage, and usage on a TPU. + +This doc is aimed at users who: + +* Are familiar with TensorFlow's `Estimator` and `Dataset` APIs +* Have maybe [tried out a Cloud TPU](https://cloud.google.com/tpu/docs/quickstart) + using an existing model. +* Have, perhaps, skimmed the code of an example TPU model + [[1]](https://github.com/tensorflow/models/blob/master/official/mnist/mnist_tpu.py) + [[2]](https://github.com/tensorflow/tpu-demos/tree/master/cloud_tpu/models). +* Are interested in porting an existing `Estimator` model to + run on Cloud TPUs + +## TPUEstimator + +@{tf.estimator.Estimator$Estimators} are TensorFlow's model-level abstraction. +Standard `Estimators` can drive models on CPU and GPUs. You must use +@{tf.contrib.tpu.TPUEstimator} to drive a model on TPUs. + +Refer to TensorFlow's Getting Started section for an introduction to the basics +of using a @{$get_started/premade_estimators$pre-made `Estimator`}, and +@{$get_started/custom_estimators$custom `Estimator`s}. + +The `TPUEstimator` class differs somewhat from the `Estimator` class. + +The simplest way to maintain a model that can be run both on CPU/GPU or on a +Cloud TPU is to define the model's inference phase (from inputs to predictions) +outside of the `model_fn`. Then maintain separate implementations of the +`Estimator` setup and `model_fn`, both wrapping this inference step. For an +example of this pattern compare the `mnist.py` and `mnist_tpu.py` implementation in +[tensorflow/models](https://github.com/tensorflow/models/tree/master/official/mnist). + +### Running a `TPUEstimator` locally + +To create a standard `Estimator` you call the constructor, and pass it a +`model_fn`, for example: + +``` +my_estimator = tf.estimator.Estimator( + model_fn=my_model_fn) +``` + +The changes required to use a @{tf.contrib.tpu.TPUEstimator} on your local +machine are relatively minor. The constructor requires two additional arguments. +You should set the `use_tpu` argument to `False`, and pass a +@{tf.contrib.tpu.RunConfig} as the `config` argument, as shown below: + +``` python +my_tpu_estimator = tf.contrib.tpu.TPUEstimator( + model_fn=my_model_fn, + config=tf.contrib.tpu.RunConfig() + use_tpu=False) +``` + +Just this simple change will allow you to run a `TPUEstimator` locally. +The majority of example TPU models can be run in this local mode, +by setting the command line flags as follows: + + +``` +$> python mnist_tpu.py --use_tpu=false --master='' +``` + +Note: This `use_tpu=False` argument is useful for trying out the `TPUEstimator` +API. It is not meant to be a complete TPU compatibility test. Successfully +running a model locally in a `TPUEstimator` does not guarantee that it will +work on a TPU. + + +### Building a `tpu.RunConfig` + +While the default `RunConfig` is sufficient for local training, these settings +cannot be ignored in real usage. + +A more typical setup for a `RunConfig`, that can be switched to use a Cloud +TPU, might be as follows: + +``` python +import tempfile +import subprocess + +class FLAGS(object): + use_tpu=False + tpu_name=None + # Use a local temporary path for the `model_dir` + model_dir = tempfile.mkdtemp() + # Number of training steps to run on the Cloud TPU before returning control. + iterations = 50 + # A single Cloud TPU has 8 shards. + num_shards = 8 + +if FLAGS.use_tpu: + my_project_name = subprocess.check_output([ + 'gcloud','config','get-value','project']) + my_zone = subprocess.check_output([ + 'gcloud','config','get-value','compute/zone']) + cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( + tpu_names=[FLAGS.tpu_name], + zone=my_zone, + project=my_project) + master = tpu_cluster_resolver.get_master() +else: + master = '' + +my_tpu_run_config = tf.contrib.tpu.RunConfig( + master=master, + evaluation_master=master, + model_dir=FLAGS.model_dir, + session_config=tf.ConfigProto( + allow_soft_placement=True, log_device_placement=True), + tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations, + FLAGS.num_shards), +) +``` + +Then you must pass the @{tf.contrib.tpu.RunConfig} to the constructor: + +``` python +my_tpu_estimator = tf.contrib.tpu.TPUEstimator( + model_fn=my_model_fn, + config = my_tpu_run_config, + use_tpu=FLAGS.use_tpu) +``` + +Typically the `FLAGS` would be set by command line arguments. To switch from +training locally to training on a cloud TPU you would need to: + + 1) Set `FLAGS.use_tpu` to `True` + 1) Set `FLAGS.tpu_name` so the + `tf.contrib.cluster_resolver.TPUClusterResolver` can find it + 1) Set `FLAGS.model_dir` to a Google Cloud Storage bucket url (`gs://`). + + +## Optimizer + +When training on a cloud TPU you **must** wrap the optimizer in a +@{tf.contrib.tpu.CrossShardOptimizer}, which uses an `allreduce` to aggregate +gradients and broadcast the result to each shard (each TPU core). + +The `CrossShardOptimizer` is not compatible with local training. So, to have +the same code run both locally and on a Cloud TPU, add lines like the following: + +``` python +optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate) +if FLAGS.use_tpu: + optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) +``` + +If you prefer to avoid a global `FLAGS` variable in your model code, one +approach is to set the optimizer as one of the `Estimator`'s params, +as follows: + +``` python +my_tpu_estimator = tf.contrib.tpu.TPUEstimator( + model_fn=my_model_fn, + config = my_tpu_run_config, + use_tpu=FLAGS.use_tpu, + params={'optimizer':optimizer}) +``` + +## Model Function + +This section details the changes you must make to the model function +(`model_fn()`) to make it `TPUEstimator` compatible. + +### Static shapes + +During regular usage TensorFlow attempts to determine the shapes of each +`tf.Tensor` during graph construction. During execution any unknown shape +dimensions are determined dynamically, +see @{$programmers_guide/tensors#shape$Tensor Shapes} for more details. + +To run on Cloud TPUs TensorFlow models are compiled using @{$xla$XLA}. +XLA uses a similar system for determining shapes at compile time. XLA requires +that all tensor dimensions be statically defined at compile time. All shapes +must evaluate to a constant, and not depend on external data, or stateful +operations like variables or a random number generator. + + +### Summaries + +Remove any use of `tf.summary` from your model. + +@{$summaries_and_tensorboard$TensorBoard summaries} are a great way see inside +your model. A minimal set of basic summaries are automatically recorded by the +`TPUEstimator`, to `event` files in the `model_dir`. Custom summaries, however, +are currently unsupported when training on a Cloud TPU. So while the +`TPUEstimator` will still run locally with summaries, it will fail if used on a +TPU. + +### Metrics + +Build your evaluation metrics dictionary in a stand-alone `metric_fn`. + + + +Evaluation metrics are an essential part of training a model. These are fully +supported on Cloud TPUs, but with a slightly different syntax. + +A standard @{tf.metrics} returns two tensors. The first returns the running +average of the metric value, while the second updates the running average and +returns the value for this batch: + +``` +running_average, current_batch = tf.metrics.accuracy(labels, predictions) +``` + +In a standard `Estimator` you create a dictionary of these pairs, and return it +as part of the `EstimatorSpec`. + +```python +my_metrics = {'accuracy': tf.metrics.accuracy(labels, predictions)} + +return tf.estimator.EstimatorSpec( + ... + eval_metric_ops=my_metrics +) +``` + +In a `TPUEstimator` you instead pass a function (which returns a metrics +dictionary) and a list of argument tensors, as shown below: + +```python +def my_metric_fn(labels, predictions): + return {'accuracy': tf.metrics.accuracy(labels, predictions)} + +return tf.contrib.tpu.TPUEstimatorSpec( + ... + eval_metrics=(my_metric_fn, [labels, predictions]) +) +``` + +### Use `TPUEstimatorSpec` + +`TPUEstimatorSpec` do not support hooks, and require function wrappers for +some fields. + +An `Estimator`'s `model_fn` must return an `EstimatorSpec`. An `EstimatorSpec` +is a simple structure of named fields containing all the `tf.Tensors` of the +model that the `Estimator` may need to interact with. + +`TPUEstimators` use a @{tf.contrib.tpu.TPUEstimatorSpec}. There are a few +differences between it and a standard @{tf.estimator.EstimatorSpec}: + + +* The `eval_metric_ops` must be wrapped into a `metrics_fn`, this field is + renamed `eval_metrics` ([see above](#metrics)). +* The @{tf.train.SessionRunHook$hooks} are unsupported, so these fields are + omitted. +* The @{tf.train.Scaffold$`scaffold`}, if used, must also be wrapped in a + function. This field is renamed to `scaffold_fn`. + +`Scaffold` and `Hooks` are for advanced usage, and can typically be omitted. + +## Input functions + +Input functions work mainly unchanged as they run on the host computer, not the +Cloud TPU itself. This section explains the two necessary adjustments. + +### Params argument + + + +The `input_fn` for a standard `Estimator` _can_ include a +`params` argument; the `input_fn` for a `TPUEstimator` *must* include a +`params` argument. This is necessary to allow the estimator to set the batch +size for each replica of the input stream. So the minimum signature for an +`input_fn` for a `TPUEstimator` is: + +``` +def my_input_fn(params): + pass +``` + +Where `params['batch-size']` will contain the batch size. + +### Static shapes and batch size + +The input pipeline generated by your `input_fn` is run on CPU. So it is mostly +free strict static shape requirements imposed by the XLA/TPU environment. The +one requirement is that the batches of data fed from your input pipeline to +the TPU have a static shape, as determined by the standard TensorFlow shape +inference algorithm. Intermediate tensors are free to have a dynamic shapes. +If shape inference has failed, but the shape is known it is possible to +impose the correct shape using `tf.set_shape()`. + +In the example below the shape +inference algorithm fails, but it is corrected using `set_shape`: + +``` +>>> x = tf.zeros(tf.constant([1,2,3])+1) +>>> x.shape + +TensorShape([Dimension(None), Dimension(None), Dimension(None)]) + +>>> x.set_shape([2,3,4]) +``` + +In many cases the batch size is the only unknown dimension. + +A typical input pipeline, using `tf.data`, will usually produce batches of a +fixed size. The last batch of a finite `Dataset`, however, is typically smaller, +containing just the remaining elements. Since a `Dataset` does not know its own +length or finiteness, the standard @{tf.data.Dataset.batch$`batch`} method +cannot determine if all batches will have a fixed size batch on its own: + +``` +>>> params = {'batch_size':32} +>>> ds = tf.data.Dataset.from_tensors([0, 1, 2]) +>>> ds = ds.repeat().batch(params['batch-size']) +>>> ds + + +``` + +The most straightforward fix is to +@{tf.data.Dataset.apply$apply} @{tf.contrib.data.batch_and_drop_remainder} +as follows: + +``` +>>> params = {'batch_size':32} +>>> ds = tf.data.Dataset.from_tensors([0, 1, 2]) +>>> ds = ds.repeat().apply( +... tf.contrib.data.batch_and_drop_remainder(params['batch-size'])) +>>> ds + + <_RestructuredDataset shapes: (32, 3), types: tf.int32> +``` + +The one downside to this approach is that, as the name implies, this batching +method throws out any fractional batch at the end of the dataset. This is fine +for an infinitely repeating dataset being used for training, but could be a +problem if you want to train for an exact number of epochs. + +To do an exact 1-epoch of _evaluation_ you can work around this by manually +padding the length of the batches, and setting the padding entries to have zero +weight when creating your `tf.metrics`. + +## Datasets + +Efficient use of the `tf.data.Dataset` API is critical when using a Cloud +TPU, as it is impossible to use the Cloud TPU's unless you can feed it data +quickly enough. See @{$datasets_performance} for details on dataset performance. + +For all but the simplest experimentation (using +@{tf.data.Dataset.from_tensor_slices} or other in-graph data) you will need to +store all data files read by the `TPUEstimator`'s `Dataset` in Google Cloud +Storage Buckets. + + + +For most use-cases, we recommend converting your data into `TFRecord` +format and using a @{tf.data.TFRecordDataset} to read it. This, however, is not +a hard requirement and you can use other dataset readers +(`FixedLengthRecordDataset` or `TextLineDataset`) if you prefer. + +Small datasets can be loaded entirely into memory using +@{tf.data.Dataset.cache}. + +Regardless of the data format used, it is strongly recommended that you +@{$performance_guide#use_large_files$use large files}, on the order of +100MB. This is especially important in this networked setting as the overhead +of opening a file is significantly higher. + +It is also important, regardless of the type of reader used, to enable buffering +using the `buffer_size` argument to the constructor. This argument is specified +in bytes. A minimum of a few MB (`buffer_size=8*1024*1024`) is recommended so +that data is available when needed. + +The TPU-demos repo includes +[a script](https://github.com/tensorflow/tpu-demos/blob/master/cloud_tpu/datasets/imagenet_to_gcs.py) +for downloading the imagenet dataset and converting it to an appropriate format. +This together with the imagenet +[models](https://github.com/tensorflow/tpu-demos/tree/master/cloud_tpu/models) +included in the repo demonstrate all of these best-practices. + + +## What Next + +For details on how to actually set up and run a Cloud TPU see: + + * [Google Cloud TPU Documentation](https://cloud.google.com/tpu/docs/) + +This document is by no means exhaustive. The best source of more detail on how +to make a Cloud TPU compatible model are the example models published in: + + * The [TPU Demos Repository.](https://github.com/tensorflow/tpu-demos/) + +For more information about tuning TensorFlow code for performance see: + + * The @{$performance$Performance Section.} + diff --git a/tensorflow/docs_src/programmers_guide/variables.md b/tensorflow/docs_src/programmers_guide/variables.md index 64250738056043e236b5eb236bcbf29375655260..e8cf7711552f4c83ed1e03e0753b580cc7505ddc 100644 --- a/tensorflow/docs_src/programmers_guide/variables.md +++ b/tensorflow/docs_src/programmers_guide/variables.md @@ -62,9 +62,10 @@ them. For this reason TensorFlow provides **collections**, which are named lists of tensors or other objects, such as `tf.Variable` instances. By default every `tf.Variable` gets placed in the following two collections: + * `tf.GraphKeys.GLOBAL_VARIABLES` --- variables that can be shared across -multiple devices, - * `tf.GraphKeys.TRAINABLE_VARIABLES`--- variables for which TensorFlow will + multiple devices, + * `tf.GraphKeys.TRAINABLE_VARIABLES` --- variables for which TensorFlow will calculate gradients. If you don't want a variable to be trainable, add it to the diff --git a/tensorflow/docs_src/tutorials/layers.md b/tensorflow/docs_src/tutorials/layers.md index b898cbe29c2bac9ade341fe3b3566e42e133fc5b..5111b16247e2b5c3410e69dcdf08318a35b18c2f 100644 --- a/tensorflow/docs_src/tutorials/layers.md +++ b/tensorflow/docs_src/tutorials/layers.md @@ -635,7 +635,7 @@ should be logged after every 50 steps of training. ### Train the Model Now we're ready to train our model, which we can do by creating `train_input_fn` -ans calling `train()` on `mnist_classifier`. Add the following to `main()`: +and calling `train()` on `mnist_classifier`. Add the following to `main()`: ```python # Train the model diff --git a/tensorflow/docs_src/tutorials/leftnav_files b/tensorflow/docs_src/tutorials/leftnav_files index 41ffdc86010fb8407889df26eefa5fa59952c5da..888052428f951fa1a7cbd9c6d35497a056387097 100644 --- a/tensorflow/docs_src/tutorials/leftnav_files +++ b/tensorflow/docs_src/tutorials/leftnav_files @@ -1,22 +1,22 @@ index.md ### Images -layers.md -image_recognition.md -image_retraining.md +layers.md: MNIST +image_recognition.md: Image Recognition +image_retraining.md: Image Retraining deep_cnn.md ### Sequences recurrent.md -seq2seq.md -recurrent_quickdraw.md +seq2seq.md: Neural Machine Translation +recurrent_quickdraw.md: Drawing Classification audio_recognition.md ### Data Representation -wide.md -wide_and_deep.md +wide.md: Linear Models +wide_and_deep.md: Wide & Deep Learning word2vec.md -kernel_methods.md +kernel_methods.md: Kernel Methods ### Non-ML mandelbrot.md diff --git a/tensorflow/docs_src/tutorials/wide.md b/tensorflow/docs_src/tutorials/wide.md index dba6f54c52ca5bf2569c66ad055329708de3991c..005dc020f94f666da295f4ff0342fae858121012 100644 --- a/tensorflow/docs_src/tutorials/wide.md +++ b/tensorflow/docs_src/tutorials/wide.md @@ -82,7 +82,7 @@ Here's a list of columns available in the Census Income dataset: | hours_per_week | Continuous | Hours worked per week. | | native_country | Categorical | Country of origin of the | : : : individual. : -| income | Categorical | ">50K" or "<=50K", meaning | +| income_bracket | Categorical | ">50K" or "<=50K", meaning | : : : whether the person makes more : : : : than $50,000 annually. : diff --git a/tensorflow/examples/android/build.gradle b/tensorflow/examples/android/build.gradle index f7bdf8b816a8191770bc1ad59b890041b8e39912..0767726aa9a248fb073fbd4114f47d1b4ed6901b 100644 --- a/tensorflow/examples/android/build.gradle +++ b/tensorflow/examples/android/build.gradle @@ -56,10 +56,12 @@ def nativeOutDir = 'libs/' + cpuType def nativeBuildRule = 'buildNativeBazel' def demoLibPath = '../../../bazel-bin/tensorflow/examples/android/libtensorflow_demo.so' def inferenceLibPath = '../../../bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so' + +// Override for Makefile builds. if (nativeBuildSystem == 'makefile') { nativeBuildRule = 'buildNativeMake' - demoLibPath = '../../../tensorflow/contrib/makefile/gen/lib/libtensorflow_demo.so' - inferenceLibPath = '../../../tensorflow/contrib/makefile/gen/lib/libtensorflow_inference.so' + demoLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_demo.so' + inferenceLibPath = '../../../tensorflow/contrib/makefile/gen/lib/android_' + cpuType + '/libtensorflow_inference.so' } // If building with Bazel, this is the location of the bazel binary. @@ -154,7 +156,8 @@ task buildNativeMake(type: Exec) { '-s', \ 'tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in', \ '-t', \ - 'libtensorflow_inference.so libtensorflow_demo.so' \ + 'libtensorflow_inference.so libtensorflow_demo.so all' \ + , '-a', cpuType \ //, '-T' // Uncomment to skip protobuf and speed up subsequent builds. } diff --git a/tensorflow/examples/android/res/animator/color_animation.xml b/tensorflow/examples/android/res/animator/color_animation.xml new file mode 100644 index 0000000000000000000000000000000000000000..891d8cc1d4f3e59d0371030fd763c5ad468e7887 --- /dev/null +++ b/tensorflow/examples/android/res/animator/color_animation.xml @@ -0,0 +1,30 @@ + + + + + diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java index a317273acdff016c824031e06c413ecc01f82ec8..068c7b0d945669b8207097e81c03ade07bc7ca73 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java @@ -81,8 +81,11 @@ public class LegacyCameraConnectionFragment extends Fragment { try { Camera.Parameters parameters = camera.getParameters(); - parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE); - + List focusModes = parameters.getSupportedFocusModes(); + if (focusModes != null + && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) { + parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE); + } List cameraSizes = parameters.getSupportedPreviewSizes(); Size[] sizes = new Size[cameraSizes.size()]; int i = 0; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java index 184df1bdb42802bfe50b15429f09baeb5600e34f..8a1d86d9eedf3a1e1aa80e998ff150ad0c2447a1 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java @@ -31,7 +31,8 @@ the RecognizeCommands helper class. package org.tensorflow.demo; -import android.animation.ValueAnimator; +import android.animation.AnimatorInflater; +import android.animation.AnimatorSet; import android.app.Activity; import android.content.pm.PackageManager; import android.media.AudioFormat; @@ -329,17 +330,11 @@ public class SpeechActivity extends Activity { labelIndex = i; } } - final View labelView = (View) labelsListView.getChildAt(labelIndex - 2); - ValueAnimator colorAnimation = - ValueAnimator.ofArgb(0x00b3ccff, 0xffb3ccff, 0x00b3ccff); - colorAnimation.setDuration(750); - colorAnimation.addUpdateListener( - new ValueAnimator.AnimatorUpdateListener() { - @Override - public void onAnimationUpdate(ValueAnimator animator) { - labelView.setBackgroundColor((int) animator.getAnimatedValue()); - } - }); + final View labelView = labelsListView.getChildAt(labelIndex - 2); + + AnimatorSet colorAnimation = (AnimatorSet) AnimatorInflater.loadAnimator( + SpeechActivity.this, R.animator.color_animation); + colorAnimation.setTarget(labelView); colorAnimation.start(); } } diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java index 2fe2ba539edc84e80baf36b6d1ac1e192bc92163..af6af2bc8f508a70aa7e44a7236f0e7ea5e3d71c 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java @@ -199,7 +199,7 @@ public class MultiBoxTracker { final int w, final int h, final int rowStride, - final int sensorOrienation, + final int sensorOrientation, final byte[] frame, final long timestamp) { if (objectTracker == null && !initialized) { @@ -209,7 +209,7 @@ public class MultiBoxTracker { objectTracker = ObjectTracker.getInstance(w, h, rowStride, true); frameWidth = w; frameHeight = h; - this.sensorOrientation = sensorOrienation; + this.sensorOrientation = sensorOrientation; initialized = true; if (objectTracker == null) { diff --git a/tensorflow/examples/get_started/regression/imports85.py b/tensorflow/examples/get_started/regression/imports85.py index 6bee556eb887a643b3a81691324736427ecc2707..a8e4c782b3f7b5d01a91f38a48e5edb2202108de 100644 --- a/tensorflow/examples/get_started/regression/imports85.py +++ b/tensorflow/examples/get_started/regression/imports85.py @@ -131,7 +131,7 @@ def dataset(y_name="price", train_fraction=0.7): # booleans but we are dealing with symbolic tensors. return ~in_training_set(line) - base_dataset = (tf.contrib.data + base_dataset = (tf.data # Get the lines from the file. .TextLineDataset(path) # drop lines with question marks. diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py index fa4c1c0da5f31863aa4d99b6ec84e1e50e1a1551..461fb1c5173f66278eb585d30bd8749a58fb6245 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Train and Eval the MNIST network. This version is like fully_connected_feed.py but uses data converted @@ -65,6 +64,7 @@ def decode(serialized_example): return image, label + def augment(image, label): # OPTIONAL: Could reshape into a 28x28 image and apply distortions # here. Since we are not applying any distortions in this @@ -72,12 +72,14 @@ def augment(image, label): # into a vector, we don't bother. return image, label + def normalize(image, label): # Convert from [0, 255] -> [-0.5, 0.5] floats. image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 return image, label + def inputs(train, batch_size, num_epochs): """Reads input data num_epochs times. @@ -98,9 +100,10 @@ def inputs(train, batch_size, num_epochs): over the dataset once. On the other hand there is no special initialization required. """ - if not num_epochs: num_epochs = None - filename = os.path.join(FLAGS.train_dir, - TRAIN_FILE if train else VALIDATION_FILE) + if not num_epochs: + num_epochs = None + filename = os.path.join(FLAGS.train_dir, TRAIN_FILE + if train else VALIDATION_FILE) with tf.name_scope('input'): # TFRecordDataset opens a protobuf and reads entries line by line @@ -127,13 +130,11 @@ def run_training(): # Tell TensorFlow that the model will be built into the default Graph. with tf.Graph().as_default(): # Input images and labels. - image_batch, label_batch = inputs(train=True, batch_size=FLAGS.batch_size, - num_epochs=FLAGS.num_epochs) + image_batch, label_batch = inputs( + train=True, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs) # Build a Graph that computes predictions from the inference model. - logits = mnist.inference(image_batch, - FLAGS.hidden1, - FLAGS.hidden2) + logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2) # Add to the Graph the loss calculation. loss = mnist.loss(logits, label_batch) @@ -152,7 +153,7 @@ def run_training(): sess.run(init_op) try: step = 0 - while True: #train until OutOfRangeError + while True: #train until OutOfRangeError start_time = time.time() # Run one step of the model. The return values are @@ -168,10 +169,12 @@ def run_training(): # Print an overview fairly often. if step % 100 == 0: print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, - duration)) + duration)) step += 1 except tf.errors.OutOfRangeError: - print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) + print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, + step)) + def main(_): run_training() @@ -183,37 +186,27 @@ if __name__ == '__main__': '--learning_rate', type=float, default=0.01, - help='Initial learning rate.' - ) + help='Initial learning rate.') parser.add_argument( '--num_epochs', type=int, default=2, - help='Number of epochs to run trainer.' - ) + help='Number of epochs to run trainer.') parser.add_argument( '--hidden1', type=int, default=128, - help='Number of units in hidden layer 1.' - ) + help='Number of units in hidden layer 1.') parser.add_argument( '--hidden2', type=int, default=32, - help='Number of units in hidden layer 2.' - ) - parser.add_argument( - '--batch_size', - type=int, - default=100, - help='Batch size.' - ) + help='Number of units in hidden layer 2.') + parser.add_argument('--batch_size', type=int, default=100, help='Batch size.') parser.add_argument( '--train_dir', type=str, default='/tmp/data', - help='Directory with the training data.' - ) + help='Directory with the training data.') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index ec22684eaf63700c608c6ce45f22941555246b99..c49e7e7ee2e397e353b468c727263ff3eb931401 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -41,7 +41,6 @@ The subfolder names are important, since they define what label is applied to each image, but the filenames themselves don't matter. Once your images are prepared, you can run the training with a command like this: - ```bash bazel build tensorflow/examples/image_retraining:retrain && \ bazel-bin/tensorflow/examples/image_retraining/retrain \ @@ -70,12 +69,14 @@ on resource-limited platforms, you can try the `--architecture` flag with a Mobilenet model. For example: Run floating-point version of mobilenet: + ```bash python tensorflow/examples/image_retraining/retrain.py \ --image_dir ~/flower_photos --architecture mobilenet_1.0_224 ``` Run quantized version of mobilenet: + ```bash python tensorflow/examples/image_retraining/retrain.py \ --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized @@ -96,6 +97,12 @@ Visualize the summaries with this command: tensorboard --logdir /tmp/retrain_logs +To use with Tensorflow Serving: + +```bash +tensorflow_model_server --port=9000 --model_name=inception \ + --model_base_path=/tmp/saved_models/ +``` """ from __future__ import absolute_import from __future__ import division @@ -344,8 +351,8 @@ def maybe_download_and_extract(data_url): filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress) print() statinfo = os.stat(filepath) - tf.logging.info('Successfully downloaded', filename, statinfo.st_size, - 'bytes.') + tf.logging.info('Successfully downloaded %s %d bytes.', + filename, statinfo.st_size) print('Extracting file from ', filepath) tarfile.open(filepath, 'r:gz').extractall(dest_directory) else: @@ -1004,6 +1011,46 @@ def add_jpeg_decoding(input_width, input_height, input_depth, input_mean, return jpeg_data, mul_image +def export_model(sess, architecture, saved_model_dir): + """Exports model for serving. + + Args: + sess: Current active TensorFlow Session. + architecture: Model architecture. + saved_model_dir: Directory in which to save exported model and variables. + """ + if architecture == 'inception_v3': + input_tensor = 'DecodeJpeg/contents:0' + elif architecture.startswith('mobilenet_'): + input_tensor = 'input:0' + else: + raise ValueError('Unknown architecture', architecture) + in_image = sess.graph.get_tensor_by_name(input_tensor) + inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)} + + out_classes = sess.graph.get_tensor_by_name('final_result:0') + outputs = {'prediction': + tf.saved_model.utils.build_tensor_info(out_classes)} + + signature = tf.saved_model.signature_def_utils.build_signature_def( + inputs=inputs, + outputs=outputs, + method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) + + legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op') + + # Save out the SavedModel. + builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir) + builder.add_meta_graph_and_variables( + sess, [tf.saved_model.tag_constants.SERVING], + signature_def_map={ + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + signature + }, + legacy_init_op=legacy_init_op) + builder.save() + + def main(_): # Needed to make sure the logging output is visible. # See https://github.com/tensorflow/tensorflow/issues/3047 @@ -1179,6 +1226,8 @@ def main(_): with gfile.FastGFile(FLAGS.output_labels, 'w') as f: f.write('\n'.join(image_lists.keys()) + '\n') + export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir) + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -1362,5 +1411,11 @@ if __name__ == '__main__': takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html for more information on Mobilenet.\ """) + parser.add_argument( + '--saved_model_dir', + type=str, + default='/tmp/saved_models/1/', + help='Where to save the exported graph.' + ) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/label_image/label_image.py b/tensorflow/examples/label_image/label_image.py index d62b73384c4969dc56a2f91d89719ba02a8f9431..fe5e0fc684abce08d3d7b7f3fa22bb5ba701c64a 100644 --- a/tensorflow/examples/label_image/label_image.py +++ b/tensorflow/examples/label_image/label_image.py @@ -18,11 +18,11 @@ from __future__ import division from __future__ import print_function import argparse -import sys import numpy as np import tensorflow as tf + def load_graph(model_file): graph = tf.Graph() graph_def = tf.GraphDef() @@ -34,22 +34,26 @@ def load_graph(model_file): return graph -def read_tensor_from_image_file(file_name, input_height=299, input_width=299, - input_mean=0, input_std=255): + +def read_tensor_from_image_file(file_name, + input_height=299, + input_width=299, + input_mean=0, + input_std=255): input_name = "file_reader" output_name = "normalized" file_reader = tf.read_file(file_name, input_name) if file_name.endswith(".png"): - image_reader = tf.image.decode_png(file_reader, channels = 3, - name='png_reader') + image_reader = tf.image.decode_png( + file_reader, channels=3, name="png_reader") elif file_name.endswith(".gif"): - image_reader = tf.squeeze(tf.image.decode_gif(file_reader, - name='gif_reader')) + image_reader = tf.squeeze( + tf.image.decode_gif(file_reader, name="gif_reader")) elif file_name.endswith(".bmp"): - image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader') + image_reader = tf.image.decode_bmp(file_reader, name="bmp_reader") else: - image_reader = tf.image.decode_jpeg(file_reader, channels = 3, - name='jpeg_reader') + image_reader = tf.image.decode_jpeg( + file_reader, channels=3, name="jpeg_reader") float_caster = tf.cast(image_reader, tf.float32) dims_expander = tf.expand_dims(float_caster, 0) resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width]) @@ -59,6 +63,7 @@ def read_tensor_from_image_file(file_name, input_height=299, input_width=299, return result + def load_labels(label_file): label = [] proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines() @@ -66,6 +71,7 @@ def load_labels(label_file): label.append(l.rstrip()) return label + if __name__ == "__main__": file_name = "tensorflow/examples/label_image/data/grace_hopper.jpg" model_file = \ @@ -110,11 +116,12 @@ if __name__ == "__main__": output_layer = args.output_layer graph = load_graph(model_file) - t = read_tensor_from_image_file(file_name, - input_height=input_height, - input_width=input_width, - input_mean=input_mean, - input_std=input_std) + t = read_tensor_from_image_file( + file_name, + input_height=input_height, + input_width=input_width, + input_mean=input_mean, + input_std=input_std) input_name = "import/" + input_layer output_name = "import/" + output_layer @@ -122,8 +129,9 @@ if __name__ == "__main__": output_operation = graph.get_operation_by_name(output_name) with tf.Session(graph=graph) as sess: - results = sess.run(output_operation.outputs[0], - {input_operation.outputs[0]: t}) + results = sess.run(output_operation.outputs[0], { + input_operation.outputs[0]: t + }) results = np.squeeze(results) top_k = results.argsort()[-5:][::-1] diff --git a/tensorflow/examples/learn/text_classification.py b/tensorflow/examples/learn/text_classification.py index eb117c39a122f4f6c108dd18f8f8035edf05eaa1..e4e61862b02f9827f42c8d0052a7be8a57502dd8 100644 --- a/tensorflow/examples/learn/text_classification.py +++ b/tensorflow/examples/learn/text_classification.py @@ -34,8 +34,7 @@ MAX_LABEL = 15 WORDS_FEATURE = 'words' # Name of the input words feature. -def estimator_spec_for_softmax_classification( - logits, labels, mode): +def estimator_spec_for_softmax_classification(logits, labels, mode): """Returns EstimatorSpec instance for softmax classification.""" predicted_classes = tf.argmax(logits, 1) if mode == tf.estimator.ModeKeys.PREDICT: @@ -53,8 +52,8 @@ def estimator_spec_for_softmax_classification( return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) eval_metric_ops = { - 'accuracy': tf.metrics.accuracy( - labels=labels, predictions=predicted_classes) + 'accuracy': + tf.metrics.accuracy(labels=labels, predictions=predicted_classes) } return tf.estimator.EstimatorSpec( mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) @@ -67,8 +66,7 @@ def bag_of_words_model(features, labels, mode): bow_embedding_column = tf.feature_column.embedding_column( bow_column, dimension=EMBEDDING_SIZE) bow = tf.feature_column.input_layer( - features, - feature_columns=[bow_embedding_column]) + features, feature_columns=[bow_embedding_column]) logits = tf.layers.dense(bow, MAX_LABEL, activation=None) return estimator_spec_for_softmax_classification( @@ -110,9 +108,9 @@ def main(unused_argv): # Prepare training and testing data dbpedia = tf.contrib.learn.datasets.load_dataset( 'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data) - x_train = pandas.Series(dbpedia.train.data[:,1]) + x_train = pandas.Series(dbpedia.train.data[:, 1]) y_train = pandas.Series(dbpedia.train.target) - x_test = pandas.Series(dbpedia.test.data[:,1]) + x_test = pandas.Series(dbpedia.test.data[:, 1]) y_test = pandas.Series(dbpedia.test.target) # Process vocabulary @@ -152,10 +150,7 @@ def main(unused_argv): # Predict. test_input_fn = tf.estimator.inputs.numpy_input_fn( - x={WORDS_FEATURE: x_test}, - y=y_test, - num_epochs=1, - shuffle=False) + x={WORDS_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False) predictions = classifier.predict(input_fn=test_input_fn) y_predicted = np.array(list(p['class'] for p in predictions)) y_predicted = y_predicted.reshape(np.array(y_test).shape) diff --git a/tensorflow/examples/speech_commands/label_wav_dir.py b/tensorflow/examples/speech_commands/label_wav_dir.py new file mode 100644 index 0000000000000000000000000000000000000000..2f305359e380e7192795851112c8261ea896c290 --- /dev/null +++ b/tensorflow/examples/speech_commands/label_wav_dir.py @@ -0,0 +1,136 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Runs a trained audio graph against WAVE files and reports the results. + +The model, labels and .wav files specified in the arguments will be loaded, and +then the predictions from running the model against the audio data will be +printed to the console. This is a useful script for sanity checking trained +models, and as an example of how to use an audio model from Python. + +Here's an example of running it: + +python tensorflow/examples/speech_commands/label_wav_dir.py \ +--graph=/tmp/my_frozen_graph.pb \ +--labels=/tmp/speech_commands_train/conv_labels.txt \ +--wav_dir=/tmp/speech_dataset/left + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +import glob + +import tensorflow as tf + +# pylint: disable=unused-import +from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio +# pylint: enable=unused-import + +FLAGS = None + + +def load_graph(filename): + """Unpersists graph from file as default graph.""" + with tf.gfile.FastGFile(filename, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + + +def load_labels(filename): + """Read in labels, one label per line.""" + return [line.rstrip() for line in tf.gfile.GFile(filename)] + + +def run_graph(wav_dir, labels, input_layer_name, output_layer_name, + num_top_predictions): + """Runs the audio data through the graph and prints predictions.""" + with tf.Session() as sess: + # Feed the audio data as input to the graph. + # predictions will contain a two-dimensional array, where one + # dimension represents the input image count, and the other has + # predictions per class + for wav_path in glob.glob(wav_dir + "/*.wav"): + if not wav_path or not tf.gfile.Exists(wav_path): + tf.logging.fatal('Audio file does not exist %s', wav_path) + + with open(wav_path, 'rb') as wav_file: + wav_data = wav_file.read() + + softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name) + predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data}) + + # Sort to show labels in order of confidence + print('\n%s' % (wav_path.split('/')[-1])) + top_k = predictions.argsort()[-num_top_predictions:][::-1] + for node_id in top_k: + human_string = labels[node_id] + score = predictions[node_id] + print('%s (score = %.5f)' % (human_string, score)) + + return 0 + + +def label_wav(wav_dir, labels, graph, input_name, output_name, how_many_labels): + """Loads the model and labels, and runs the inference to print predictions.""" + if not labels or not tf.gfile.Exists(labels): + tf.logging.fatal('Labels file does not exist %s', labels) + + if not graph or not tf.gfile.Exists(graph): + tf.logging.fatal('Graph file does not exist %s', graph) + + labels_list = load_labels(labels) + + # load graph, which is stored in the default session + load_graph(graph) + + run_graph(wav_dir, labels_list, input_name, output_name, how_many_labels) + + +def main(_): + """Entry point for script, converts flags to arguments.""" + label_wav(FLAGS.wav_dir, FLAGS.labels, FLAGS.graph, FLAGS.input_name, + FLAGS.output_name, FLAGS.how_many_labels) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--wav_dir', type=str, default='', help='Audio file to be identified.') + parser.add_argument( + '--graph', type=str, default='', help='Model to use for identification.') + parser.add_argument( + '--labels', type=str, default='', help='Path to file containing labels.') + parser.add_argument( + '--input_name', + type=str, + default='wav_data:0', + help='Name of WAVE data input node in model.') + parser.add_argument( + '--output_name', + type=str, + default='labels_softmax:0', + help='Name of node outputting a prediction in the model.') + parser.add_argument( + '--how_many_labels', + type=int, + default=3, + help='Number of results to show.') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py index a4e80041f82191d7c58a3e52c929340eb604ec9d..f084931215261f183f1ecfc5517ea9a5126db039 100644 --- a/tensorflow/examples/speech_commands/train.py +++ b/tensorflow/examples/speech_commands/train.py @@ -357,12 +357,12 @@ if __name__ == '__main__': '--window_size_ms', type=float, default=30.0, - help='How long each spectrogram timeslice is',) + help='How long each spectrogram timeslice is.',) parser.add_argument( '--window_stride_ms', type=float, default=10.0, - help='How long each spectrogram timeslice is',) + help='How far to move in time between spectogram timeslices.',) parser.add_argument( '--dct_coefficient_count', type=int, diff --git a/tensorflow/examples/tutorials/mnist/input_data.py b/tensorflow/examples/tutorials/mnist/input_data.py index f1a7e1c4af57dba4f06326eb8b03c7eddae86b51..fa148ae3e6f44e140e3b4fb6a4204a601b6c0a24 100644 --- a/tensorflow/examples/tutorials/mnist/input_data.py +++ b/tensorflow/examples/tutorials/mnist/input_data.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=unused-import import gzip import os import tempfile @@ -27,3 +28,4 @@ from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets +# pylint: enable=unused-import diff --git a/tensorflow/examples/tutorials/mnist/mnist_softmax.py b/tensorflow/examples/tutorials/mnist/mnist_softmax.py index fb3ac942039e670fb5ca975c5d9835ba065190a2..47dd6a1947811765101529826c2b24d9798fef1f 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_softmax.py +++ b/tensorflow/examples/tutorials/mnist/mnist_softmax.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """A very simple MNIST classifier. See extensive documentation at @@ -67,12 +66,19 @@ def main(_): # Test trained model correct_prediction = tf.equal(tf.argmax(y, 1), y_) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - print(sess.run(accuracy, feed_dict={x: mnist.test.images, - y_: mnist.test.labels})) + print(sess.run( + accuracy, feed_dict={ + x: mnist.test.images, + y_: mnist.test.labels + })) + if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') + parser.add_argument( + '--data_dir', + type=str, + default='/tmp/tensorflow/mnist/input_data', + help='Directory for storing input data') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py index d055d157454d4cb351e8db59eec484f212893fe5..14ae7fbf35836ad7f5d56101ae0fc33a3f3fb9ba 100644 --- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py +++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py @@ -131,7 +131,7 @@ def generate_batch(batch_size, num_skips, skip_window): batch = np.ndarray(shape=(batch_size), dtype=np.int32) labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) span = 2 * skip_window + 1 # [ skip_window target skip_window ] - buffer = collections.deque(maxlen=span) + buffer = collections.deque(maxlen=span) # pylint: disable=redefined-builtin if data_index + span > len(data): data_index = 0 buffer.extend(data[data_index:data_index + span]) @@ -270,12 +270,6 @@ with tf.Session(graph=graph) as session: run_metadata=run_metadata) average_loss += loss_val - # Add returned summaries to writer in each step. - writer.add_summary(summary, step) - # Add metadata to visualize the graph for the last run. - if step == (num_steps - 1): - writer.add_run_metadata(run_metadata, 'step%d' % step) - # Add returned summaries to writer in each step. writer.add_summary(summary, step) # Add metadata to visualize the graph for the last run. diff --git a/tensorflow/examples/udacity/5_word2vec.ipynb b/tensorflow/examples/udacity/5_word2vec.ipynb index 18c456cad787b2ed5b39d5791de649874bbe7ae3..3b43d1fb55ee5d7f6a91754a221962755f04190c 100644 --- a/tensorflow/examples/udacity/5_word2vec.ipynb +++ b/tensorflow/examples/udacity/5_word2vec.ipynb @@ -455,7 +455,7 @@ " \n", " # Compute the similarity between minibatch examples and all embeddings.\n", " # We use the cosine distance:\n", - " norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))\n", + " norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keepdims=True))\n", " normalized_embeddings = embeddings / norm\n", " valid_embeddings = tf.nn.embedding_lookup(\n", " normalized_embeddings, valid_dataset)\n", diff --git a/tensorflow/examples/udacity/Dockerfile b/tensorflow/examples/udacity/Dockerfile index 3ca58566c1ddb4c2446f7d9b19ee31fb8b603909..00eb853e527c922121fae6dc5eab42c589b0b238 100644 --- a/tensorflow/examples/udacity/Dockerfile +++ b/tensorflow/examples/udacity/Dockerfile @@ -8,7 +8,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ apt-get clean && \ rm -rf /var/lib/apt/lists/* -RUN pip install scikit-learn pyreadline Pillow +RUN pip install scikit-learn pyreadline Pillow imageio RUN rm -rf /notebooks/* ADD *.ipynb /notebooks/ WORKDIR /notebooks diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index fc087d9d995dfe031e61fd0fa15d649c2ee35cc9..08943a527cbdc072b12b066240c213be45ffd54c 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -173,7 +173,11 @@ type OpSpec struct { // operation. Attrs map[string]interface{} - // Other possible fields: Device, ColocateWith, ControlInputs. + // Operations that must be executed before executing the operation + // being added. + ControlDependencies []*Operation + + // Other possible fields: Device, ColocateWith. } // AddOperation adds an operation to g. @@ -204,6 +208,9 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { } } } + for _, in := range args.ControlDependencies { + C.TF_AddControlInput(cdesc, in.c) + } status := newStatus() for name, value := range args.Attrs { if err := setAttr(cdesc, status, name, value); err != nil { diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go index a9ec79463a00022bf85bf00032df9004648525ae..13de4294dc2ebdfff9bb68d277c09239d0bc8593 100644 --- a/tensorflow/go/op/scope.go +++ b/tensorflow/go/op/scope.go @@ -33,10 +33,11 @@ import ( // A Scope object and all its derivates (e.g., obtained from Scope.SubScope) // are not safe for concurrent use by multiple goroutines. type Scope struct { - graph *tf.Graph - namemap map[string]int - namespace string - err *scopeErr + graph *tf.Graph + namemap map[string]int + namespace string + controlDependencies []*tf.Operation + err *scopeErr } // scopeErr is used to share errors between all derivatives of a root scope. @@ -80,6 +81,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation { if s.namespace != "" { args.Name = s.namespace + "/" + args.Name } + args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...) op, err := s.graph.AddOperation(args) if err != nil { s.UpdateErr(args.Type, err) @@ -103,6 +105,28 @@ func (s *Scope) SubScope(namespace string) *Scope { } } +// WithControlDependencies returns a new Scope which will cause all operations +// added to the graph to execute only after all the provided operations have +// executed first (in addition to any other control dependencies in s). +func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope { + // Force a copy of the control dependencies into a new underlying array on + // every call. We cannot alias the same underlying array as `ops`, otherwise + // the user could modify that array after calling s.WithControlDependencies, + // which would be confusing. We cannot alias the same underlying array as the + // original `s.controlDependencies`, since Scopes form a logical tree, and + // other calls to s.WithControlDependencies could stomp on each other. + deps := make([]*tf.Operation, 0, len(s.controlDependencies)+len(ops)) + deps = append(deps, s.controlDependencies...) + deps = append(deps, ops...) + return &Scope{ + graph: s.graph, + namemap: s.namemap, + namespace: s.namespace, + controlDependencies: deps, + err: s.err, + } +} + // Err returns the error, if any, encountered during the construction // of the Graph managed by s. // diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go index 6fb5d32e503c7c9a5a48747844da15be81b1de2d..b58a61de98b0f5b04959e1eca35c6b6c4d77e42b 100644 --- a/tensorflow/go/op/scope_test.go +++ b/tensorflow/go/op/scope_test.go @@ -69,6 +69,49 @@ func TestScopeSubScopeErrors(t *testing.T) { } } +func TestControlDependencies(t *testing.T) { + var ( + s = NewScope() + zero = Const(s.SubScope("zero"), int32(0)) + one = Const(s.SubScope("one"), int32(1)) + variable = VarHandleOp(s, tf.Int32, tf.ScalarShape()) + init = AssignVariableOp(s, variable, zero) + update = AssignAddVariableOp(s, variable, one) + readDeps = []*tf.Operation{update} + ) + // We intend for `read` to have a control dependency on `update`. + s = s.WithControlDependencies(readDeps...) + // Ensure that Scope.WithControlDependencies makes a copy of the underlying + // array, rather than just holding a slice reference to the same user-supplied + // underlying array. If the copy is correctly performed, overwriting + // readDeps[0] should have no effect on control dependencies for `read`. + readDeps[0] = init + read := ReadVariableOp(s, variable, tf.Int32) + + graph, err := s.Finalize() + if err != nil { + t.Fatal(err) + } + sess, err := tf.NewSession(graph, nil) + if err != nil { + t.Fatal(err) + } + if _, err = sess.Run(nil, nil, []*tf.Operation{init}); err != nil { + t.Fatal(err) + } + // Without the control dependency, the read operation may not see the + // update. + for i := int32(0); i < 10; i++ { + out, err := sess.Run(nil, []tf.Output{read}, nil) + if err != nil { + t.Fatal(err) + } + if got, want := out[0].Value().(int32), i+1; got != want { + t.Errorf("Got %d, want %d", got, want) + } + } +} + func TestScopeFinalize(t *testing.T) { var ( root = NewScope() diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 5b19c90238ef3bb1361a5e2476e94dd06e76d128..13f38dfb32a476477d306093bad6b56e1744a640 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -220,6 +220,64 @@ func CreateSummaryFileWriter(scope *Scope, writer tf.Output, logdir tf.Output, m return scope.AddOperation(opspec) } +// FakeQuantWithMinMaxVarsPerChannelGradientAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannelGradient. +type FakeQuantWithMinMaxVarsPerChannelGradientAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsPerChannelGradientNumBits sets the optional num_bits attribute to value. +// +// value: The bitwidth of the quantization; between 2 and 8, inclusive. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsPerChannelGradientNumBits(value int64) FakeQuantWithMinMaxVarsPerChannelGradientAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange sets the optional narrow_range attribute to value. +// +// value: Whether to quantize into 2^num_bits - 1 distinct values. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation. +// +// Arguments: +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation, +// shape one of: `[d]`, `[b, d]`, `[b, h, w, d]`. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation, shape +// same as `gradients`. +// min, max: Quantization interval, floats of shape `[d]`. +// +// +// +// Returns Backpropagated gradients w.r.t. inputs, shape same as +// `inputs`: +// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter, shape `[d]`: +// `sum_per_d(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter, shape `[d]`: +// `sum_per_d(gradients * (inputs > max))`. +func FakeQuantWithMinMaxVarsPerChannelGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FakeQuantWithMinMaxVarsPerChannelGradient", + Input: []tf.Input{ + gradients, inputs, min, max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // Partitions `data` into `num_partitions` tensors using indices from `partitions`. // // For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` @@ -629,6 +687,77 @@ func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, v return scope.AddOperation(opspec) } +// MapPeekAttr is an optional argument to MapPeek. +type MapPeekAttr func(optionalAttr) + +// MapPeekCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapPeekCapacity(value int64) MapPeekAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapPeekMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapPeekMemoryLimit(value int64) MapPeekAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapPeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapPeekContainer(value string) MapPeekAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapPeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapPeekSharedName(value string) MapPeekAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op peeks at the values at the specified key. If the +// +// underlying container does not contain this key +// this op will block until it does. +func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MapPeek", + Input: []tf.Input{ + key, indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapPeek", err) + return + } + return values +} + // Returns (x - y)(x - y) element-wise. // // *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting @@ -4509,6 +4638,68 @@ func CriticalSectionOp(scope *Scope, optional ...CriticalSectionOpAttr) (resourc return op.Output(0) } +// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. +type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) + +// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value. +// If not specified, defaults to -6 +func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["min"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value. +// If not specified, defaults to 6 +func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["max"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Compute gradients for a FakeQuantWithMinMaxArgs operation. +// +// Arguments: +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. +// +// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: +// `gradients * (inputs >= min && inputs <= max)`. +func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FakeQuantWithMinMaxArgsGradient", + Input: []tf.Input{ + gradients, inputs, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // AvgPool3DAttr is an optional argument to AvgPool3D. type AvgPool3DAttr func(optionalAttr) @@ -8729,31 +8920,6 @@ func IRFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Out return op.Output(0) } -// Compute the pairwise cross product. -// -// `a` and `b` must be the same shape; they can either be simple 3-element vectors, -// or any shape where the innermost dimension is 3. In the latter case, each pair -// of corresponding 3-element vectors is cross-multiplied independently. -// -// Arguments: -// a: A tensor containing 3-element vectors. -// b: Another tensor, of same type and shape as `a`. -// -// Returns Pairwise cross product of the vectors in `a` and `b`. -func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Cross", - Input: []tf.Input{ - a, b, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Transforms a vector of brain.Example protos (as strings) into typed tensors. // // Arguments: @@ -19872,16 +20038,42 @@ func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset t return op.Output(0) } -// Creates a dataset that contains the elements of `input_dataset` ignoring errors. -func IgnoreErrorsDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Adds a value to the current value of a variable. +// +// Any ReadVariableOp which depends directly or indirectly on this assign is +// guaranteed to see the incremented value or a subsequent newer one. +// +// Outputs the incremented value, which can be used to totally order the +// increments to this variable. +// +// Arguments: +// resource: handle to the resource in which to store the variable. +// value: the value by which the variable will be incremented. +// +// Returns the created operation. +func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AssignAddVariableOp", + Input: []tf.Input{ + resource, value, + }, + } + return scope.AddOperation(opspec) +} + +// Records the latency of producing `input_dataset` elements in a StatsAggregator. +func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "IgnoreErrorsDataset", + Type: "LatencyStatsDataset", Input: []tf.Input{ - input_dataset, + input_dataset, tag, }, Attrs: attrs, } @@ -19889,145 +20081,7 @@ func IgnoreErrorsDataset(scope *Scope, input_dataset tf.Output, output_types []t return op.Output(0) } -// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. -type CropAndResizeGradImageAttr func(optionalAttr) - -// CropAndResizeGradImageMethod sets the optional method attribute to value. -// -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { - return func(m optionalAttr) { - m["method"] = value - } -} - -// Computes the gradient of the crop_and_resize op wrt the input image tensor. -// -// Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` -// containing the original image size. Both `image_height` and `image_width` need -// to be positive. -// -// -// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"T": T} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CropAndResizeGradImage", - Input: []tf.Input{ - grads, boxes, box_ind, image_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Reads and outputs the entire contents of the input filename. -func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReadFile", - Input: []tf.Input{ - filename, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Concatenates tensors along one dimension. -// -// Arguments: -// values: List of `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// axis: 0-D. The dimension along which to concatenate. Must be in the -// range [-rank(values), rank(values)). -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ConcatV2", - Input: []tf.Input{ - tf.OutputList(values), axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Adds a value to the current value of a variable. -// -// Any ReadVariableOp which depends directly or indirectly on this assign is -// guaranteed to see the incremented value or a subsequent newer one. -// -// Outputs the incremented value, which can be used to totally order the -// increments to this variable. -// -// Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value by which the variable will be incremented. -// -// Returns the created operation. -func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AssignAddVariableOp", - Input: []tf.Input{ - resource, value, - }, - } - return scope.AddOperation(opspec) -} - -// Records the latency of producing `input_dataset` elements in a StatsAggregator. -func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "LatencyStatsDataset", - Input: []tf.Input{ - input_dataset, tag, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Convert JSON-encoded Example records to binary protocol buffer strings. +// Convert JSON-encoded Example records to binary protocol buffer strings. // // This op translates a tensor containing Example records, encoded using // the [standard JSON @@ -20705,68 +20759,6 @@ func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Outp return op.Output(0) } -// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. -type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) - -// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value. -// If not specified, defaults to -6 -func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["min"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value. -// If not specified, defaults to 6 -func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["max"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Compute gradients for a FakeQuantWithMinMaxArgs operation. -// -// Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. -// -// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: -// `gradients * (inputs >= min && inputs <= max)`. -func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxArgsGradient", - Input: []tf.Input{ - gradients, inputs, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // BatchToSpace for 4-D tensors of type T. // // This is a legacy version of the more general BatchToSpaceND. @@ -21290,6 +21282,31 @@ func StatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output return op.Output(0) } +// Compute the pairwise cross product. +// +// `a` and `b` must be the same shape; they can either be simple 3-element vectors, +// or any shape where the innermost dimension is 3. In the latter case, each pair +// of corresponding 3-element vectors is cross-multiplied independently. +// +// Arguments: +// a: A tensor containing 3-element vectors. +// b: Another tensor, of same type and shape as `a`. +// +// Returns Pairwise cross product of the vectors in `a` and `b`. +func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Cross", + Input: []tf.Input{ + a, b, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Performs a padding as a preprocess during a convolution. // // Similar to FusedResizeAndPadConv2d, this op allows for an optimized @@ -22211,45 +22228,210 @@ func EncodeBase64Pad(value bool) EncodeBase64Attr { // end so that the encoded has length multiple of 4. See Padding section of the // link above. // -// Web-safe means that the encoder uses - and _ instead of + and /. +// Web-safe means that the encoder uses - and _ instead of + and /. +// +// Arguments: +// input: Strings to be encoded. +// +// Returns Input strings encoded in base64. +func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EncodeBase64", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated. Use TensorArrayCloseV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 +// +// Returns the created operation. +func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArrayCloseV2", + Input: []tf.Input{ + handle, + }, + } + return scope.AddOperation(opspec) +} + +// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. +type CropAndResizeGradImageAttr func(optionalAttr) + +// CropAndResizeGradImageMethod sets the optional method attribute to value. +// +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { + return func(m optionalAttr) { + m["method"] = value + } +} + +// Computes the gradient of the crop_and_resize op wrt the input image tensor. +// +// Arguments: +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` +// containing the original image size. Both `image_height` and `image_width` need +// to be positive. +// +// +// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CropAndResizeGradImage", + Input: []tf.Input{ + grads, boxes, box_ind, image_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reads and outputs the entire contents of the input filename. +func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReadFile", + Input: []tf.Input{ + filename, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Concatenates tensors along one dimension. +// +// Arguments: +// values: List of `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// axis: 0-D. The dimension along which to concatenate. Must be in the +// range [-rank(values), rank(values)). +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ConcatV2", + Input: []tf.Input{ + tf.OutputList(values), axis, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Forwards the value of an available tensor from `inputs` to `output`. +// +// `Merge` waits for at least one of the tensors in `inputs` to become available. +// It is usually combined with `Switch` to implement branching. +// +// `Merge` forwards the first tensor to become available to `output`, and sets +// `value_index` to its index in `inputs`. // // Arguments: -// input: Strings to be encoded. +// inputs: The input tensors, exactly one of which will become available. // -// Returns Input strings encoded in base64. -func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { +// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. +func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "EncodeBase64", + Type: "Merge", Input: []tf.Input{ - input, + tf.OutputList(inputs), }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Deprecated. Use TensorArrayCloseV3 +// QueueCloseV2Attr is an optional argument to QueueCloseV2. +type QueueCloseV2Attr func(optionalAttr) + +// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. // -// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 +// value: If true, all pending enqueue requests that are +// blocked on the given queue will be canceled. +// If not specified, defaults to false +func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { + return func(m optionalAttr) { + m["cancel_pending_enqueues"] = value + } +} + +// Closes the given queue. +// +// This operation signals that no more elements will be enqueued in the +// given queue. Subsequent Enqueue(Many) operations will fail. +// Subsequent Dequeue(Many) operations will continue to succeed if +// sufficient elements remain in the queue. Subsequent Dequeue(Many) +// operations that would block will fail immediately. +// +// Arguments: +// handle: The handle to a queue. // // Returns the created operation. -func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { +func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArrayCloseV2", + Type: "QueueCloseV2", Input: []tf.Input{ handle, }, + Attrs: attrs, } return scope.AddOperation(opspec) } @@ -24203,147 +24385,6 @@ func MapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output return scope.AddOperation(opspec) } -// MapPeekAttr is an optional argument to MapPeek. -type MapPeekAttr func(optionalAttr) - -// MapPeekCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapPeekCapacity(value int64) MapPeekAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapPeekMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapPeekMemoryLimit(value int64) MapPeekAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapPeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapPeekContainer(value string) MapPeekAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapPeekSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapPeekSharedName(value string) MapPeekAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op peeks at the values at the specified key. If the -// -// underlying container does not contain this key -// this op will block until it does. -func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MapPeek", - Input: []tf.Input{ - key, indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapPeek", err) - return - } - return values -} - -// QueueCloseV2Attr is an optional argument to QueueCloseV2. -type QueueCloseV2Attr func(optionalAttr) - -// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. -// -// value: If true, all pending enqueue requests that are -// blocked on the given queue will be canceled. -// If not specified, defaults to false -func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { - return func(m optionalAttr) { - m["cancel_pending_enqueues"] = value - } -} - -// Closes the given queue. -// -// This operation signals that no more elements will be enqueued in the -// given queue. Subsequent Enqueue(Many) operations will fail. -// Subsequent Dequeue(Many) operations will continue to succeed if -// sufficient elements remain in the queue. Subsequent Dequeue(Many) -// operations that would block will fail immediately. -// -// Arguments: -// handle: The handle to a queue. -// -// Returns the created operation. -func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QueueCloseV2", - Input: []tf.Input{ - handle, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Forwards the value of an available tensor from `inputs` to `output`. -// -// `Merge` waits for at least one of the tensors in `inputs` to become available. -// It is usually combined with `Switch` to implement branching. -// -// `Merge` forwards the first tensor to become available to `output`, and sets -// `value_index` to its index in `inputs`. -// -// Arguments: -// inputs: The input tensors, exactly one of which will become available. -// -// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. -func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Merge", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - // MapUnstageAttr is an optional argument to MapUnstage. type MapUnstageAttr func(optionalAttr) @@ -28297,61 +28338,3 @@ func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max op := scope.AddOperation(opspec) return op.Output(0) } - -// FakeQuantWithMinMaxVarsPerChannelGradientAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannelGradient. -type FakeQuantWithMinMaxVarsPerChannelGradientAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsPerChannelGradientNumBits sets the optional num_bits attribute to value. -// -// value: The bitwidth of the quantization; between 2 and 8, inclusive. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsPerChannelGradientNumBits(value int64) FakeQuantWithMinMaxVarsPerChannelGradientAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange sets the optional narrow_range attribute to value. -// -// value: Whether to quantize into 2^num_bits - 1 distinct values. -// If not specified, defaults to false -func FakeQuantWithMinMaxVarsPerChannelGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelGradientAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation. -// -// Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation, -// shape one of: `[d]`, `[b, d]`, `[b, h, w, d]`. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation, shape -// same as `gradients`. -// min, max: Quantization interval, floats of shape `[d]`. -// -// -// -// Returns Backpropagated gradients w.r.t. inputs, shape same as -// `inputs`: -// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter, shape `[d]`: -// `sum_per_d(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter, shape `[d]`: -// `sum_per_d(gradients * (inputs > max))`. -func FakeQuantWithMinMaxVarsPerChannelGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVarsPerChannelGradient", - Input: []tf.Input{ - gradients, inputs, min, max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index 6285ee0483d9171d6cdb9b4dbf2675bafb953038..d35bb4111271c11839a160517dc9695ead5b46e9 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.5.0-rc1 + 1.6.0-rc1 ../ libtensorflow diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index b0e5c44fecc9bf3a95ac3d4e36d9f98d74d3b2bb..d9ba1bbbfb91170257f64a56f47c6c980e8a9570 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.5.0-rc1 + 1.6.0-rc1 ../ libtensorflow_jni diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index 02c5dca13f4d292718afca7e99bac82710e1949f..f6f532c2c10d0a4dad9fc2d7750ea708652000b1 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.5.0-rc1 + 1.6.0-rc1 ../ libtensorflow_jni_gpu diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index 949597ca7f1e7a05cf6c0e5a15cb5307b00859a1..0a6b3d23d7d37515cf275e6a46842e32ada4fee1 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ 4.0.0 org.tensorflow parentpom - 1.5.0-rc1 + 1.6.0-rc1 pom https://www.tensorflow.org diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index 9f0ebcf84c9c8e01662a93034a4407c6b58a6d7e..1d8e8723731f959c8142f0648fc805593d7beac8 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.5.0-rc1 + 1.6.0-rc1 ../ proto diff --git a/tensorflow/java/maven/tensorflow-android/update.py b/tensorflow/java/maven/tensorflow-android/update.py index 7c250718347f5fdd65aaf8003aad75a87a19c96a..4ae666e4e5351f1bdaf79d1b5cfdb63b0f811e2b 100644 --- a/tensorflow/java/maven/tensorflow-android/update.py +++ b/tensorflow/java/maven/tensorflow-android/update.py @@ -95,7 +95,7 @@ def main(): release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow' info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version) aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version) - build_type = 'release-matrix-android' + build_type = 'release-matrix-android2' # Retrieve build information build_info = get_json(info_url) diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index 88d897362ad6c8f84d93cbc9bcf3c30905b345be..5c1b55085c5df1ec473a3f4e0bf750b236cfc264 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.5.0-rc1 + 1.6.0-rc1 ../ tensorflow diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java index 499757e8cf4d6166e425d801ce20335bd8ad83e8..cf773e1686dea97f62f432be43f2c10b69fa8e24 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java +++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java @@ -88,7 +88,7 @@ final class NativeLibrary { // Deletions are in the reverse order of requests, so we need to request that the directory be // deleted first, so that it is empty when the request is fulfilled. tempPath.deleteOnExit(); - final String tempDirectory = tempPath.toString(); + final String tempDirectory = tempPath.getCanonicalPath(); if (frameworkResource != null) { extractResource(frameworkResource, frameworkLibName, tempDirectory); } else { diff --git a/tensorflow/java/src/main/java/org/tensorflow/package-info.java b/tensorflow/java/src/main/java/org/tensorflow/package-info.java index dd4859e1b14045e4123e7f15fbaff98e14d0b377..521c5c610c1f775cf9174664f5b786786ce1181d 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/package-info.java +++ b/tensorflow/java/src/main/java/org/tensorflow/package-info.java @@ -35,5 +35,9 @@ limitations under the License. *
  • Graph execution: Using a Session to execute the graphs and find the best label for an * image. * + * + *

    Additional examples can be found in the tensorflow/models + * GitHub repository. */ package org.tensorflow; diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 01b3e92d2d9edc12afc6c98da44a4442796592e9..cee7c47e00d5673cd2abf2b1e526523ad61bbafd 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1,5 +1,8 @@ # Description: # Python support for TensorFlow. +# +# Public targets: +# ":platform" - Low-level and platform-specific Python code. package( default_visibility = [ @@ -76,6 +79,7 @@ py_library( ":layers", ":lib", ":list_ops", + ":manip_ops", ":math_ops", ":metrics", ":nn", @@ -131,6 +135,7 @@ py_library( ], ) + ["platform/build_info.py"], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":lib", ":pywrap_tensorflow", @@ -298,6 +303,7 @@ cc_library( ":safe_ptr", "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -575,6 +581,7 @@ py_library( ":pywrap_tensorflow", ":random_seed", ":sparse_tensor", + ":tensor_spec", ":tensor_util", ":util", "//tensorflow/python/eager:context", @@ -779,6 +786,18 @@ py_library( ], ) +py_library( + name = "tensor_spec", + srcs = ["framework/tensor_spec.py"], + srcs_version = "PY2AND3", + deps = [ + ":common_shapes", + ":dtypes", + ":tensor_shape", + "//third_party/py/numpy", + ], +) + py_library( name = "tensor_util", srcs = ["framework/tensor_util.py"], @@ -1147,6 +1166,21 @@ py_test( ], ) +py_test( + name = "framework_tensor_spec_test", + size = "small", + srcs = ["framework/tensor_spec_test.py"], + main = "framework/tensor_spec_test.py", + srcs_version = "PY2AND3", + deps = [ + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":platform_test", + ":tensor_spec", + "//third_party/py/numpy", + ], +) + py_test( name = "framework_sparse_tensor_test", size = "small", @@ -1394,6 +1428,14 @@ tf_gen_op_wrapper_private_py( ], ) +tf_gen_op_wrapper_private_py( + name = "manip_ops_gen", + visibility = [ + "//learning/brain/python/ops:__pkg__", + "//tensorflow/python/kernel_tests:__pkg__", + ], +) + tf_gen_op_wrapper_private_py( name = "math_ops_gen", visibility = [ @@ -1726,6 +1768,8 @@ py_library( ":linalg_grad", ":linalg_ops", ":logging_ops", + ":manip_grad", + ":manip_ops", ":math_grad", ":math_ops", ":platform", @@ -1848,6 +1892,29 @@ py_library( ], ) +py_library( + name = "manip_grad", + srcs = ["ops/manip_grad.py"], + srcs_version = "PY2AND3", + deps = [ + ":control_flow_ops", + ":framework_for_generated_wrappers", + ":manip_ops", + ], +) + +py_library( + name = "manip_ops", + srcs = ["ops/manip_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":dtypes", + ":framework_ops", + ":manip_ops_gen", + "//third_party/py/numpy", + ], +) + py_library( name = "logging_ops", srcs = ["ops/logging_ops.py"], @@ -2310,6 +2377,8 @@ py_library( ":linalg_ops", ":logging_ops", ":lookup_ops", + ":manip_grad", + ":manip_ops", ":math_grad", ":math_ops", ":numerics", @@ -2449,6 +2518,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":array_ops", + ":checkpointable", ":control_flow_ops", ":dtypes", ":framework_ops", @@ -2668,6 +2738,7 @@ cuda_py_test( ":nn_ops_gen", "//third_party/py/numpy", ], + shard_count = 4, tags = ["no_windows"], ) @@ -2781,6 +2852,30 @@ py_library( ], ) +py_library( + name = "checkpointable", + srcs = ["training/checkpointable.py"], + srcs_version = "PY2AND3", + deps = [ + ":dtypes", + ":io_ops_gen", + ":ops", + ":pywrap_tensorflow", + ":util", + "//tensorflow/python/eager:context", + ], +) + +py_test( + name = "checkpointable_test", + srcs = ["training/checkpointable_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":checkpointable", + ":client_testlib", + ], +) + py_test( name = "evaluation_test", size = "small", @@ -4228,12 +4323,6 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) -filegroup( - name = "hidden_ops", - srcs = ["ops/hidden_ops.txt"], - visibility = ["//tensorflow:__subpackages__"], -) - cuda_py_test( name = "accumulate_n_benchmark", size = "large", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index bc9ddec2a54a784027120828e9b15a2bf500414e..02ed5517ca895ab070a89f8810f77dadcff9212b 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -60,7 +60,7 @@ from tensorflow.core.protobuf.tensorflow_server_pb2 import * from tensorflow.core.util.event_pb2 import * # Framework -from tensorflow.python.framework.framework_lib import * +from tensorflow.python.framework.framework_lib import * # pylint: disable=redefined-builtin from tensorflow.python.framework.versions import * from tensorflow.python.framework import errors from tensorflow.python.framework import graph_util @@ -84,6 +84,7 @@ from tensorflow.python.feature_column import feature_column_lib as feature_colum from tensorflow.python.layers import layers from tensorflow.python.ops import bitwise_ops as bitwise from tensorflow.python.ops import image_ops as image +from tensorflow.python.ops import manip_ops as manip from tensorflow.python.ops import metrics from tensorflow.python.ops import nn from tensorflow.python.ops import sets @@ -115,6 +116,7 @@ from tensorflow.python.platform import test from tensorflow.python.util.all_util import remove_undocumented from tensorflow.python.util.all_util import make_all +from tensorflow.python.util.tf_export import tf_export # Import modules whose docstrings contribute, for use by remove_undocumented # below. @@ -166,6 +168,31 @@ _allowed_symbols = [ 'TensorInfo', # Used for tf.saved_model functionality. ] +# Export protos +# pylint: disable=undefined-variable +tf_export('AttrValue')(AttrValue) +tf_export('ConfigProto')(ConfigProto) +tf_export('Event', 'summary.Event')(Event) +tf_export('GPUOptions')(GPUOptions) +tf_export('GraphDef')(GraphDef) +tf_export('GraphOptions')(GraphOptions) +tf_export('HistogramProto')(HistogramProto) +tf_export('LogMessage')(LogMessage) +tf_export('MetaGraphDef')(MetaGraphDef) +tf_export('NameAttrList')(NameAttrList) +tf_export('NodeDef')(NodeDef) +tf_export('OptimizerOptions')(OptimizerOptions) +tf_export('RunMetadata')(RunMetadata) +tf_export('RunOptions')(RunOptions) +tf_export('SessionLog', 'summary.SessionLog')(SessionLog) +tf_export('Summary', 'summary.Summary')(Summary) +tf_export('summary.SummaryDescription')(SummaryDescription) +tf_export('SummaryMetadata')(SummaryMetadata) +tf_export('summary.TaggedRunMetadata')(TaggedRunMetadata) +tf_export('TensorInfo')(TensorInfo) +# pylint: enable=undefined-variable + + # The following symbols are kept for compatibility. It is our plan # to remove them in the future. _allowed_symbols.extend([ @@ -241,6 +268,7 @@ _allowed_symbols.extend([ 'linalg', 'logging', 'losses', + 'manip', 'metrics', 'newaxis', 'nn', diff --git a/tensorflow/python/build_defs.bzl b/tensorflow/python/build_defs.bzl index 7f29adc06fcc5922114b7cd2bde8a8df5b1e0665..b9056f86e6d0465a8521f054a459c06eb5aeb37c 100644 --- a/tensorflow/python/build_defs.bzl +++ b/tensorflow/python/build_defs.bzl @@ -22,7 +22,6 @@ def tf_gen_op_wrapper_private_py(name, out=None, deps=[], bare_op_name = name[:-4] # Strip off the _gen tf_gen_op_wrapper_py(name=bare_op_name, out=out, - hidden_file="ops/hidden_ops.txt", visibility=visibility, deps=deps, require_shape_functions=require_shape_functions, diff --git a/tensorflow/python/client/device_lib_test.py b/tensorflow/python/client/device_lib_test.py index 7bba10efacfbc7fbde402c665b3d55d852e36eae..aaf41626ab0078489026036d2b838f33a893a540 100644 --- a/tensorflow/python/client/device_lib_test.py +++ b/tensorflow/python/client/device_lib_test.py @@ -34,7 +34,8 @@ class DeviceLibTest(test_util.TensorFlowTestCase): # GPU test if test.is_gpu_available(): self.assertGreater(len(devices), 1) - self.assertTrue("GPU" in [d.device_type for d in devices] or "SYCL" in [d.device_type for d in devices]) + self.assertTrue("GPU" in [d.device_type for d in devices] or + "SYCL" in [d.device_type for d in devices]) if __name__ == "__main__": diff --git a/tensorflow/python/client/events_writer.i b/tensorflow/python/client/events_writer.i index de030fcb4282912475ed8853bae9d41cde2c085d..c72b76b8fa4a05588841466a836bc189bb64d154 100644 --- a/tensorflow/python/client/events_writer.i +++ b/tensorflow/python/client/events_writer.i @@ -23,6 +23,9 @@ limitations under the License. %nodefaultctor EventsWriter; +%ignore tensorflow::Status::operator=; +%include "tensorflow/core/lib/core/status.h" + %ignoreall %unignore tensorflow; %unignore tensorflow::EventsWriter; diff --git a/tensorflow/python/client/notebook.py b/tensorflow/python/client/notebook.py index 8babe35b3230e7b46c0c9484ccddae4e5e22a335..4b6a0f71ae65aa28b70dd22ce6cffa82e9bc5973 100644 --- a/tensorflow/python/client/notebook.py +++ b/tensorflow/python/client/notebook.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Notebook front-end to TensorFlow. When you run this binary, you'll see something like below, which indicates @@ -43,10 +42,8 @@ from tensorflow.python.platform import app os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2" - FLAGS = None - ORIG_ARGV = sys.argv # Main notebook process calls itself with argv[1]="kernel" to start kernel # subprocesses. @@ -73,8 +70,8 @@ def main(unused_argv): notebookapp.ip = "0.0.0.0" notebookapp.password = passwd(FLAGS.password) else: - print ("\nNo password specified; Notebook server will only be available" - " on the local machine.\n") + print("\nNo password specified; Notebook server will only be available" + " on the local machine.\n") notebookapp.initialize(argv=["--notebook-dir", FLAGS.notebook_dir]) if notebookapp.ip == "0.0.0.0": @@ -125,8 +122,8 @@ if __name__ == "__main__": # kernel app. if IS_KERNEL: # Drop everything except --flagfile. - sys.argv = ([sys.argv[0]] + - [x for x in sys.argv[1:] if x.startswith("--flagfile")]) + sys.argv = ( + [sys.argv[0]] + [x for x in sys.argv[1:] if x.startswith("--flagfile")]) FLAGS, unparsed = parser.parse_known_args() app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 1481a4d035cbc63aa655be6c4d441e6f6741e118..f3c4fecdc0fde0436bea76cc774edaabe1bc07dd 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """A client interface for TensorFlow.""" from __future__ import absolute_import @@ -36,6 +35,7 @@ from tensorflow.python.ops import session_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export class SessionInterface(object): @@ -71,8 +71,9 @@ def _get_indexed_slices_value_from_fetches(fetched_vals): def _get_feeds_for_indexed_slices(feed, feed_val): - return list(zip([feed.values, feed.indices] if feed.dense_shape is None else - [feed.values, feed.indices, feed.dense_shape], feed_val)) + return list( + zip([feed.values, feed.indices] if feed.dense_shape is None else + [feed.values, feed.indices, feed.dense_shape], feed_val)) # List of extensions supported to convert run arguments into actual fetches and @@ -124,6 +125,7 @@ _REGISTERED_EXPANSIONS = [ lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]), lambda feed, feed_val: [(feed, feed_val)], lambda feed: [feed])] + # pylint: enable=g-long-lambda @@ -132,8 +134,11 @@ def _convert_to_numpy_obj(numpy_dtype, obj): return numpy_dtype(obj) if numpy_dtype is not object else str(obj) -def register_session_run_conversion_functions(tensor_type, fetch_function, - feed_function=None, feed_function_for_partial_run=None): +def register_session_run_conversion_functions( + tensor_type, + fetch_function, + feed_function=None, + feed_function_for_partial_run=None): """Register fetch and feed conversion functions for `tf.Session.run()`. This function registers a triple of conversion functions for fetching and/or @@ -174,11 +179,11 @@ def register_session_run_conversion_functions(tensor_type, fetch_function, """ for conversion_function in _REGISTERED_EXPANSIONS: if issubclass(conversion_function[0], tensor_type): - raise ValueError( - '%s has already been registered so ignore it.', tensor_type) + raise ValueError('%s has already been registered so ignore it.', + tensor_type) return - _REGISTERED_EXPANSIONS.insert(0, - (tensor_type, fetch_function, feed_function, feed_function_for_partial_run)) + _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function, + feed_function_for_partial_run)) class _FetchMapper(object): @@ -233,8 +238,8 @@ class _FetchMapper(object): An instance of a subclass of `_FetchMapper` that handles the shape. """ if fetch is None: - raise TypeError('Fetch argument %r has invalid type %r' % - (fetch, type(fetch))) + raise TypeError('Fetch argument %r has invalid type %r' % (fetch, + type(fetch))) elif isinstance(fetch, (list, tuple)): # NOTE(touts): This is also the code path for namedtuples. return _ListFetchMapper(fetch) @@ -247,8 +252,8 @@ class _FetchMapper(object): fetches, contraction_fn = fetch_fn(fetch) return _ElementFetchMapper(fetches, contraction_fn) # Did not find anything. - raise TypeError('Fetch argument %r has invalid type %r' % - (fetch, type(fetch))) + raise TypeError('Fetch argument %r has invalid type %r' % (fetch, + type(fetch))) class _ElementFetchMapper(_FetchMapper): @@ -277,8 +282,8 @@ class _ElementFetchMapper(_FetchMapper): fetch, allow_tensor=True, allow_operation=True)) except TypeError as e: raise TypeError('Fetch argument %r has invalid type %r, ' - 'must be a string or Tensor. (%s)' - % (fetch, type(fetch), str(e))) + 'must be a string or Tensor. (%s)' % + (fetch, type(fetch), str(e))) except ValueError as e: raise ValueError('Fetch argument %r cannot be interpreted as a ' 'Tensor. (%s)' % (fetch, str(e))) @@ -376,8 +381,9 @@ class _DictFetchMapper(_FetchMapper): """ self._fetch_type = type(fetches) self._keys = fetches.keys() - self._mappers = [_FetchMapper.for_fetch(fetch) - for fetch in fetches.values()] + self._mappers = [ + _FetchMapper.for_fetch(fetch) for fetch in fetches.values() + ] self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) def unique_fetches(self): @@ -401,6 +407,7 @@ class _FetchHandler(object): result structure matching the user-provided structure for fetches, but containing the corresponding results. """ + # TODO(touts): Make this class also take care of destructuring the feed # dict instead of doing it in the callers. @@ -551,8 +558,11 @@ class _DeviceAttributes(object): return self._memory_limit_bytes def __repr__(self): - return '_DeviceAttributes(%s, %s, %d)' % (self.name, self.device_type, - self.memory_limit_bytes,) + return '_DeviceAttributes(%s, %s, %d)' % ( + self.name, + self.device_type, + self.memory_limit_bytes, + ) class BaseSession(SessionInterface): @@ -601,8 +611,8 @@ class BaseSession(SessionInterface): if config is not None: if not isinstance(config, config_pb2.ConfigProto): - raise TypeError('config must be a tf.ConfigProto, but got %s' - % type(config)) + raise TypeError( + 'config must be a tf.ConfigProto, but got %s' % type(config)) self._config = config self._add_shapes = config.graph_options.infer_shapes else: @@ -976,8 +986,8 @@ class BaseSession(SessionInterface): for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS: if isinstance(feed, tensor_type): return feed_fn(feed) - raise TypeError('Feed argument %r has invalid type %r' - % (feed, type(feed))) + raise TypeError('Feed argument %r has invalid type %r' % (feed, + type(feed))) # Check session. if self._closed: @@ -998,8 +1008,8 @@ class BaseSession(SessionInterface): for feed in feeds: for subfeed in _feed_fn(feed): try: - subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True, - allow_operation=False) + subfeed_t = self.graph.as_graph_element( + subfeed, allow_tensor=True, allow_operation=False) if self._created_with_new_api: # pylint: disable=protected-access feed_list.append(subfeed_t._as_tf_output()) @@ -1007,8 +1017,7 @@ class BaseSession(SessionInterface): else: feed_list.append(compat.as_bytes(subfeed_t.name)) except Exception as e: - e.message = ('Cannot interpret feed_list key as Tensor: ' - + e.message) + e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message) e.args = (e.message,) raise e @@ -1041,12 +1050,13 @@ class BaseSession(SessionInterface): def _run(self, handle, fetches, feed_dict, options, run_metadata): """Perform either run or partial_run, depending the presence of `handle`.""" + def _feed_fn(feed, feed_val): for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS: if isinstance(feed, tensor_type): return feed_fn(feed, feed_val) - raise TypeError('Feed argument %r has invalid type %r' - % (feed, type(feed))) + raise TypeError('Feed argument %r has invalid type %r' % (feed, + type(feed))) # Check session. if self._closed: @@ -1066,11 +1076,11 @@ class BaseSession(SessionInterface): for feed, feed_val in feed_dict.items(): for subfeed, subfeed_val in _feed_fn(feed, feed_val): try: - subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True, - allow_operation=False) + subfeed_t = self.graph.as_graph_element( + subfeed, allow_tensor=True, allow_operation=False) except Exception as e: - raise TypeError('Cannot interpret feed_dict key as Tensor: ' - + e.args[0]) + raise TypeError( + 'Cannot interpret feed_dict key as Tensor: ' + e.args[0]) if isinstance(subfeed_val, ops.Tensor): raise TypeError('The value of a feed cannot be a tf.Tensor object. ' @@ -1081,10 +1091,9 @@ class BaseSession(SessionInterface): if isinstance(subfeed_val, int) and _convert_to_numpy_obj( subfeed_dtype, subfeed_val) != subfeed_val: raise TypeError( - 'Type of feed value ' + str(subfeed_val) + ' with type ' + - str(type(subfeed_val)) + - ' is not compatible with Tensor type ' + - str(subfeed_dtype) + + 'Type of feed value ' + str(subfeed_val) + ' with type ' + str( + type(subfeed_val)) + + ' is not compatible with Tensor type ' + str(subfeed_dtype) + '. Try explicitly setting the type of the feed tensor' ' to a larger type (e.g. int64).') @@ -1098,10 +1107,10 @@ class BaseSession(SessionInterface): if (not is_tensor_handle_feed and not subfeed_t.get_shape().is_compatible_with(np_val.shape)): - raise ValueError( - 'Cannot feed value of shape %r for Tensor %r, ' - 'which has shape %r' - % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape()))) + raise ValueError('Cannot feed value of shape %r for Tensor %r, ' + 'which has shape %r' % + (np_val.shape, subfeed_t.name, + str(subfeed_t.get_shape()))) if not self.graph.is_feedable(subfeed_t): raise ValueError('Tensor %s may not be fed.' % subfeed_t) @@ -1130,10 +1139,7 @@ class BaseSession(SessionInterface): results = [] return fetch_handler.build_results(self, results) - def make_callable(self, - fetches, - feed_list=None, - accept_options=False): + def make_callable(self, fetches, feed_list=None, accept_options=False): """Returns a Python callable that runs a particular step. The returned callable will take `len(feed_list)` arguments whose types @@ -1176,9 +1182,12 @@ class BaseSession(SessionInterface): # `Session._run()` so that we can convert the feeds to a list of # strings here. def _generic_run(*feed_args, **kwargs): - feed_dict = {feed: feed_val - for feed, feed_val in zip(feed_list, feed_args)} + feed_dict = { + feed: feed_val + for feed, feed_val in zip(feed_list, feed_args) + } return self.run(fetches, feed_dict=feed_dict, **kwargs) + return _generic_run # Ensure any changes to the graph are reflected in the runtime. @@ -1198,12 +1207,11 @@ class BaseSession(SessionInterface): fetch_list = _name_list(fetch_handler.fetches()) target_list = _name_list(fetch_handler.targets()) - def _callable_template_with_options_and_metadata( - fetch_list, - target_list, - fetch_handler, - options=None, - run_metadata=None): + def _callable_template_with_options_and_metadata(fetch_list, + target_list, + fetch_handler, + options=None, + run_metadata=None): """Template callable that accepts RunOptions and RunMetadata.""" options_ptr = tf_session.TF_NewBufferFromString( compat.as_bytes(options.SerializeToString())) if options else None @@ -1215,9 +1223,9 @@ class BaseSession(SessionInterface): self._session, options_ptr, {}, fetch_list, target_list, run_metadata_ptr, status) else: - results = tf_session.TF_Run( - self._session, options_ptr, {}, fetch_list, target_list, status, - run_metadata_ptr) + results = tf_session.TF_Run(self._session, options_ptr, {}, + fetch_list, target_list, status, + run_metadata_ptr) if fetch_handler: results = fetch_handler.build_results(self, results) else: @@ -1233,37 +1241,40 @@ class BaseSession(SessionInterface): return results if accept_options: - return functools.partial( - _callable_template_with_options_and_metadata, fetch_list, - target_list, fetch_handler) + return functools.partial(_callable_template_with_options_and_metadata, + fetch_list, target_list, fetch_handler) elif isinstance(fetches, ops.Operation): # Special case for fetching a single operation, because the # function will have no return value. assert not fetch_list assert len(target_list) == 1 + def _single_operation_run(): with errors.raise_exception_on_not_ok_status() as status: if self._created_with_new_api: - tf_session.TF_SessionRun_wrapper( - self._session, None, {}, [], target_list, None, status) + tf_session.TF_SessionRun_wrapper(self._session, None, {}, [], + target_list, None, status) else: - tf_session.TF_Run( - self._session, None, {}, [], target_list, status, None) + tf_session.TF_Run(self._session, None, {}, [], target_list, status, + None) + return _single_operation_run elif isinstance(fetches, ops.Tensor): # Special case for fetching a single tensor, because the # function can return the result of `TF_Run()` directly. assert len(fetch_list) == 1 assert not target_list + def _single_tensor_run(): with errors.raise_exception_on_not_ok_status() as status: if self._created_with_new_api: results = tf_session.TF_SessionRun_wrapper( self._session, None, {}, fetch_list, [], None, status) else: - results = tf_session.TF_Run( - self._session, None, {}, fetch_list, [], status, None) + results = tf_session.TF_Run(self._session, None, {}, fetch_list, [], + status, None) return results[0] + return _single_tensor_run else: # In all other cases, we must use `fetch_handler` to build the @@ -1274,16 +1285,17 @@ class BaseSession(SessionInterface): results = tf_session.TF_SessionRun_wrapper( self._session, None, {}, fetch_list, target_list, None, status) else: - results = tf_session.TF_Run( - self._session, None, {}, fetch_list, target_list, status, None) + results = tf_session.TF_Run(self._session, None, {}, fetch_list, + target_list, status, None) return fetch_handler.build_results(self, results) + return _fetch_handler_run # Captures the name of a node in an error status. _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =') - def _do_run(self, handle, target_list, fetch_list, feed_dict, - options, run_metadata): + def _do_run(self, handle, target_list, fetch_list, feed_dict, options, + run_metadata): """Runs a step based on the given fetches and feeds. Args: @@ -1320,13 +1332,12 @@ class BaseSession(SessionInterface): self._extend_graph() with errors.raise_exception_on_not_ok_status() as status: if self._created_with_new_api: - return tf_session.TF_SessionRun_wrapper( - session, options, feed_dict, fetch_list, target_list, - run_metadata, status) + return tf_session.TF_SessionRun_wrapper(session, options, feed_dict, + fetch_list, target_list, + run_metadata, status) else: - return tf_session.TF_Run(session, options, - feed_dict, fetch_list, target_list, - status, run_metadata) + return tf_session.TF_Run(session, options, feed_dict, fetch_list, + target_list, status, run_metadata) def _prun_fn(session, handle, feed_dict, fetch_list): if target_list: @@ -1365,20 +1376,20 @@ class BaseSession(SessionInterface): def _extend_graph(self): # Nothing to do if we're using the new session interface # TODO(skyewm): remove this function altogether eventually - if self._created_with_new_api: return + if self._created_with_new_api: + return # Ensure any changes to the graph are reflected in the runtime. with self._extend_lock: if self._graph.version > self._current_version: # pylint: disable=protected-access graph_def, self._current_version = self._graph._as_graph_def( - from_version=self._current_version, - add_shapes=self._add_shapes) + from_version=self._current_version, add_shapes=self._add_shapes) # pylint: enable=protected-access with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_ExtendGraph( - self._session, graph_def.SerializeToString(), status) + tf_session.TF_ExtendGraph(self._session, + graph_def.SerializeToString(), status) self._opened = True # The threshold to run garbage collection to delete dead tensors. @@ -1398,9 +1409,8 @@ class BaseSession(SessionInterface): feeds = {} fetches = [] for deleter_key, tensor_handle in enumerate(tensors_to_delete): - holder, deleter = session_ops._get_handle_deleter(self.graph, - deleter_key, - tensor_handle) + holder, deleter = session_ops._get_handle_deleter( + self.graph, deleter_key, tensor_handle) feeds[holder] = tensor_handle fetches.append(deleter) self.run(fetches, feed_dict=feeds) @@ -1432,6 +1442,7 @@ class BaseSession(SessionInterface): return handles +@tf_export('Session') class Session(BaseSession): """A class for running TensorFlow operations. @@ -1471,7 +1482,8 @@ class Session(BaseSession): sess.run(...) ``` - The [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) + The + [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) protocol buffer exposes various configuration options for a session. For example, to create a session that uses soft constraints for device placement, and log the resulting placement decisions, @@ -1502,7 +1514,8 @@ class Session(BaseSession): @{$distributed$Distributed TensorFlow} for more examples. graph: (Optional.) The `Graph` to be launched (described above). - config: (Optional.) A [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) + config: (Optional.) A + [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) protocol buffer with configuration options for the session. """ @@ -1526,8 +1539,22 @@ class Session(BaseSession): def __exit__(self, exec_type, exec_value, exec_tb): if exec_type is errors.OpError: logging.error('Session closing due to OpError: %s', (exec_value,)) - self._default_session_context_manager.__exit__( - exec_type, exec_value, exec_tb) + try: + self._default_session_context_manager.__exit__(exec_type, exec_value, + exec_tb) + except RuntimeError as error: + if error == exec_value: + # NOTE(skyewm): for some reason, in Python3, + # _default_session_context_manager.__exit__ will re-raise the "not + # re-entrant" exception raised in __enter__ above (note that if we're + # here, we're in the outer session context manager, since __exit__ is + # not called when __enter__ raises an exception). We still want to + # continue cleaning up this context manager before the exception is + # further propagated, so we ignore it here (note that it'll continue + # being propagated after this method completes). + pass + else: + raise self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb) self._default_session_context_manager = None @@ -1570,6 +1597,7 @@ class Session(BaseSession): tf_session.TF_Reset(target, containers, config) +@tf_export('InteractiveSession') class InteractiveSession(BaseSession): """A TensorFlow `Session` for use in interactive contexts, such as a shell. diff --git a/tensorflow/python/client/session_benchmark.py b/tensorflow/python/client/session_benchmark.py index 721bca91b71aa00479c27fad102d5888d58d35b1..da74855193dbfe3019f23c542d86c5e493e9ac7a 100644 --- a/tensorflow/python/client/session_benchmark.py +++ b/tensorflow/python/client/session_benchmark.py @@ -22,6 +22,7 @@ import time import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index c579fba33951c4624e02de1e20a9aa5bad11cd73..490572254b0be6a110ef06cea15d20d780f732cf 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Tests for tensorflow.python.client.session.Session.""" from __future__ import absolute_import from __future__ import division @@ -32,7 +31,6 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op @@ -47,8 +45,8 @@ from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops # Import resource_variable_ops for the variables-to-tensor implicit conversion. from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import from tensorflow.python.ops import state_ops @@ -57,7 +55,6 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import server_lib from tensorflow.python.util import compat - # NOTE(mrry): Dummy shape registration for ops used in the tests, since they # don't have C++ op registrations on which to attach C++ shape fns. ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) @@ -95,14 +92,18 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertAllEqual(arr, copy_val) # Test without feed. copy_val = copy.eval() - self.assertAllEqual(np.asarray([[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], - dtype=np.float32), copy_val) + self.assertAllEqual( + np.asarray( + [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32), + copy_val) def testManyCPUs(self): # TODO(keveman): Implement ListDevices and test for the number of # devices returned by ListDevices. with session.Session( - config=config_pb2.ConfigProto(device_count={'CPU': 2})): + config=config_pb2.ConfigProto(device_count={ + 'CPU': 2 + })): inp = constant_op.constant(10.0, name='W1') self.assertAllEqual(inp.eval(), 10.0) @@ -161,20 +162,23 @@ class SessionTest(test_util.TensorFlowTestCase): def exc_predicate(e): return (e.op is None and e.node_def is None and e.error_code == error_codes_pb2.INVALID_ARGUMENT) + with self.assertRaisesOpError(exc_predicate): # Run with a bogus handle. s.partial_run('foo', r1, feed_dict={a: 1, b: 2}) def testOpConstructionErrorPayload(self): - if ops._USE_C_API: return # No shape registration for 'ConstructionFails' + if ops._USE_C_API: + return # No shape registration for 'ConstructionFails' with session.Session(): failing_op = ops.get_default_graph().create_op( 'ConstructionFails', [], [], name='f') def exc_predicate(e): - return (e.op == failing_op - and e.error_code == error_codes_pb2.INVALID_ARGUMENT) + return (e.op == failing_op and + e.error_code == error_codes_pb2.INVALID_ARGUMENT) + with self.assertRaisesOpError(exc_predicate): failing_op.run() @@ -191,9 +195,9 @@ class SessionTest(test_util.TensorFlowTestCase): # pylint: enable=protected-access def exc_predicate(e): - return (e.op == c.op - and e.op._original_op == b.op - and e.op._original_op._original_op == a.op) + return (e.op == c.op and e.op._original_op == b.op and + e.op._original_op._original_op == a.op) + with self.assertRaisesOpError(exc_predicate): c.eval() @@ -341,8 +345,12 @@ class SessionTest(test_util.TensorFlowTestCase): b = control_flow_ops.no_op() # An op, not a tensor. c = constant_op.constant(c_val) # List of lists, tuples, namedtuple, and dict - res = sess.run([[a, b, c], (a, b, c), ABC(a=a, b=b, c=c), - {'a': a.name, 'c': c, 'b': b}]) + res = sess.run([[a, b, c], (a, b, c), + ABC(a=a, b=b, c=c), { + 'a': a.name, + 'c': c, + 'b': b + }]) self.assertTrue(isinstance(res, list)) self.assertEqual(4, len(res)) self.assertTrue(isinstance(res[0], list)) @@ -365,8 +373,11 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(b_val, res[3]['b']) self.assertEqual(c_val, res[3]['c']) # Tuple of lists, tuples, namedtuple, and dict - res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), - {'a': a, 'c': c, 'b': b})) + res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), { + 'a': a, + 'c': c, + 'b': b + })) self.assertTrue(isinstance(res, tuple)) self.assertEqual(4, len(res)) self.assertTrue(isinstance(res[0], list)) @@ -389,10 +400,16 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(b_val, res[3]['b']) self.assertEqual(c_val, res[3]['c']) # Namedtuple of lists, tuples, namedtuples, and dict - res = sess.run(DEFG(d=[a, b, c], - e=(a, b, c), - f=ABC(a=a.name, b=b, c=c), - g={'a': a, 'c': c, 'b': b})) + res = sess.run( + DEFG( + d=[a, b, c], + e=(a, b, c), + f=ABC(a=a.name, b=b, c=c), + g={ + 'a': a, + 'c': c, + 'b': b + })) self.assertTrue(isinstance(res, DEFG)) self.assertTrue(isinstance(res.d, list)) self.assertEqual(3, len(res.d)) @@ -414,10 +431,16 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(b_val, res.g['b']) self.assertEqual(c_val, res.g['c']) # Dict of lists, tuples, namedtuples, and dict - res = sess.run({'d': [a, b, c], - 'e': (a, b, c), - 'f': ABC(a=a, b=b, c=c), - 'g': {'a': a.name, 'c': c, 'b': b}}) + res = sess.run({ + 'd': [a, b, c], + 'e': (a, b, c), + 'f': ABC(a=a, b=b, c=c), + 'g': { + 'a': a.name, + 'c': c, + 'b': b + } + }) self.assertTrue(isinstance(res, dict)) self.assertEqual(4, len(res)) self.assertTrue(isinstance(res['d'], list)) @@ -516,8 +539,7 @@ class SessionTest(test_util.TensorFlowTestCase): values = np.array([1.0, 2.0]).astype(np.float32) shape = np.array([7, 9, 2]).astype(np.int64) sp = sparse_tensor.SparseTensor( - constant_op.constant(indices), - constant_op.constant(values), + constant_op.constant(indices), constant_op.constant(values), constant_op.constant(shape)) # Single fetch, use as tuple sp_out = s.run(sp) @@ -587,14 +609,17 @@ class SessionTest(test_util.TensorFlowTestCase): sp = sparse_tensor.SparseTensor( array_ops.placeholder(dtype=np.int64, shape=(2, 3)), array_ops.placeholder(dtype=np.float32, shape=(2,)), - array_ops.placeholder(dtype=np.int64, shape=(3,)),) + array_ops.placeholder(dtype=np.int64, shape=(3,)), + ) sp_indices = array_ops.identity(sp.indices) sp_values = array_ops.identity(sp.values) sp_shape = array_ops.identity(sp.dense_shape) sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) # Feed with tuple indices_out, values_out, shape_out = s.run( - [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)}) + [sp_indices, sp_values, sp_shape], { + sp: (indices, values, shape) + }) self.assertAllEqual(indices_out, indices) self.assertAllEqual(values_out, values) self.assertAllEqual(shape_out, shape) @@ -605,20 +630,23 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertAllEqual(sp_out.dense_shape, shape) # Feed with SparseTensorValue indices_out, values_out, shape_out = s.run( - [sp_indices, sp_values, sp_shape], - {sp: sparse_tensor.SparseTensorValue(indices, values, shape)}) + [sp_indices, sp_values, sp_shape], { + sp: sparse_tensor.SparseTensorValue(indices, values, shape) + }) self.assertAllEqual(indices_out, indices) self.assertAllEqual(values_out, values) self.assertAllEqual(shape_out, shape) # Feed with SparseTensorValue, fetch SparseTensorValue - sp2_out = s.run( - sp2, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)}) + sp2_out = s.run(sp2, { + sp: sparse_tensor.SparseTensorValue(indices, values, shape) + }) self.assertAllEqual(sp2_out.indices, indices) self.assertAllEqual(sp2_out.values, values) self.assertAllEqual(sp2_out.dense_shape, shape) # Feed SparseTensorValue and fetch sp directly. - sp_out = s.run( - sp, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)}) + sp_out = s.run(sp, { + sp: sparse_tensor.SparseTensorValue(indices, values, shape) + }) self.assertAllEqual(sp_out.indices, indices) self.assertAllEqual(sp_out.values, values) self.assertAllEqual(sp_out.dense_shape, shape) @@ -635,20 +663,24 @@ class SessionTest(test_util.TensorFlowTestCase): sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) # Feed with tuple indices_out, values_out, shape_out = s.run( - [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)}) + [sp_indices, sp_values, sp_shape], { + sp: (indices, values, shape) + }) self.assertAllEqual(indices_out, indices) self.assertAllEqual(values_out, values) self.assertAllEqual(shape_out, shape) # Feed with SparseTensorValue indices_out, values_out, shape_out = s.run( - [sp_indices, sp_values, sp_shape], - {sp: sparse_tensor.SparseTensorValue(indices, values, shape)}) + [sp_indices, sp_values, sp_shape], { + sp: sparse_tensor.SparseTensorValue(indices, values, shape) + }) self.assertAllEqual(indices_out, indices) self.assertAllEqual(values_out, values) self.assertAllEqual(shape_out, shape) # Feed with SparseTensorValue, fetch SparseTensorValue - sp2_out = s.run( - sp2, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)}) + sp2_out = s.run(sp2, { + sp: sparse_tensor.SparseTensorValue(indices, values, shape) + }) self.assertAllEqual(sp2_out.indices, indices) self.assertAllEqual(sp2_out.values, values) self.assertAllEqual(sp2_out.dense_shape, shape) @@ -666,20 +698,24 @@ class SessionTest(test_util.TensorFlowTestCase): sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) # Feed with tuple indices_out, values_out, shape_out = s.run( - [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)}) + [sp_indices, sp_values, sp_shape], { + sp: (indices, values, shape) + }) self.assertAllEqual(indices_out, indices) self.assertAllEqual(values_out, values) self.assertAllEqual(shape_out, shape) # Feed with SparseTensorValue indices_out, values_out, shape_out = s.run( - [sp_indices, sp_values, sp_shape], - {sp: sparse_tensor.SparseTensorValue(indices, values, shape)}) + [sp_indices, sp_values, sp_shape], { + sp: sparse_tensor.SparseTensorValue(indices, values, shape) + }) self.assertAllEqual(indices_out, indices) self.assertAllEqual(values_out, values) self.assertAllEqual(shape_out, shape) # Feed with SparseTensorValue, fetch SparseTensorValue - sp2_out = s.run( - sp2, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)}) + sp2_out = s.run(sp2, { + sp: sparse_tensor.SparseTensorValue(indices, values, shape) + }) self.assertAllEqual(sp2_out.indices, indices) self.assertAllEqual(sp2_out.values, values) self.assertAllEqual(sp2_out.dense_shape, shape) @@ -689,9 +725,8 @@ class SessionTest(test_util.TensorFlowTestCase): indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) values = np.array([1.0, 2.0]).astype(np.float32) shape = np.array([7, 9, 2]).astype(np.int64) - sp = array_ops.sparse_placeholder(dtype=np.float32, - shape=shape, - name='placeholder1') + sp = array_ops.sparse_placeholder( + dtype=np.float32, shape=shape, name='placeholder1') self.assertAllEqual(sp.dense_shape.eval(session=s), shape) self.assertAllEqual(tensor_util.constant_value(sp.dense_shape), shape) sp_indices = array_ops.identity(sp.indices) @@ -699,7 +734,9 @@ class SessionTest(test_util.TensorFlowTestCase): sp_shape = array_ops.identity(sp.dense_shape) # Feed with tuple indices_out, values_out, shape_out = s.run( - [sp_indices, sp_values, sp_shape], {sp: (indices, values)}) + [sp_indices, sp_values, sp_shape], { + sp: (indices, values) + }) self.assertAllEqual(indices_out, indices) self.assertAllEqual(values_out, values) self.assertAllEqual(shape_out, shape) @@ -745,33 +782,34 @@ class SessionTest(test_util.TensorFlowTestCase): indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) dense_shape = np.array([7, 9, 2]).astype(np.int64) ind = ops.IndexedSlices( - array_ops.placeholder(dtype=np.float32, - shape=(2,)), - array_ops.placeholder(dtype=np.int64, - shape=(2, 3)), - array_ops.placeholder(dtype=np.int64, - shape=(3,)),) + array_ops.placeholder(dtype=np.float32, shape=(2,)), + array_ops.placeholder(dtype=np.int64, shape=(2, 3)), + array_ops.placeholder(dtype=np.int64, shape=(3,)), + ) ind_values = array_ops.identity(ind.values) ind_indices = array_ops.identity(ind.indices) ind_dense_shape = array_ops.identity(ind.dense_shape) ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape) # Feed with tuple values_out, indices_out, dense_shape_out = s.run( - [ind_values, ind_indices, ind_dense_shape], - {ind: (values, indices, dense_shape)}) + [ind_values, ind_indices, ind_dense_shape], { + ind: (values, indices, dense_shape) + }) self.assertAllEqual(values_out, values) self.assertAllEqual(indices_out, indices) self.assertAllEqual(dense_shape_out, dense_shape) # Feed with IndexedSlicesValue values_out, indices_out, dense_shape_out = s.run( - [ind_values, ind_indices, ind_dense_shape], - {ind: ops.IndexedSlicesValue(values, indices, dense_shape)}) + [ind_values, ind_indices, ind_dense_shape], { + ind: ops.IndexedSlicesValue(values, indices, dense_shape) + }) self.assertAllEqual(values_out, values) self.assertAllEqual(indices_out, indices) self.assertAllEqual(dense_shape_out, dense_shape) # Feed with IndexedSlicesValue, fetch IndexedSlicesValue - ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices, - dense_shape)}) + ind2_out = s.run(ind2, { + ind: ops.IndexedSlicesValue(values, indices, dense_shape) + }) self.assertAllEqual(ind2_out.values, values) self.assertAllEqual(ind2_out.indices, indices) self.assertAllEqual(ind2_out.dense_shape, dense_shape) @@ -816,28 +854,27 @@ class SessionTest(test_util.TensorFlowTestCase): indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) dense_shape = None ind = ops.IndexedSlices( - array_ops.placeholder(dtype=np.float32, - shape=(2,)), - array_ops.placeholder(dtype=np.int64, - shape=(2, 3)), - None) + array_ops.placeholder(dtype=np.float32, shape=(2,)), + array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None) ind_values = array_ops.identity(ind.values) ind_indices = array_ops.identity(ind.indices) ind2 = ops.IndexedSlices(ind_values, ind_indices) # Feed with tuple - values_out, indices_out = s.run( - [ind_values, ind_indices], {ind: (values, indices)}) + values_out, indices_out = s.run([ind_values, ind_indices], { + ind: (values, indices) + }) self.assertAllEqual(values_out, values) self.assertAllEqual(indices_out, indices) # Feed with IndexedSlicesValue - values_out, indices_out = s.run( - [ind_values, ind_indices], - {ind: ops.IndexedSlicesValue(values, indices, dense_shape)}) + values_out, indices_out = s.run([ind_values, ind_indices], { + ind: ops.IndexedSlicesValue(values, indices, dense_shape) + }) self.assertAllEqual(values_out, values) self.assertAllEqual(indices_out, indices) # Feed with IndexedSlicesValue, fetch IndexedSlicesValue - ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices, - dense_shape)}) + ind2_out = s.run(ind2, { + ind: ops.IndexedSlicesValue(values, indices, dense_shape) + }) self.assertAllEqual(ind2_out.values, values) self.assertAllEqual(ind2_out.indices, indices) self.assertAllEqual(ind2_out.dense_shape, dense_shape) @@ -986,8 +1023,9 @@ class SessionTest(test_util.TensorFlowTestCase): constructed_events = [threading.Event() for _ in range(10)] continue_event = threading.Event() for i, constructed_event in enumerate(constructed_events): - t = self.checkedThread(target=self._testDefaultGraphInThread, - args=(constructed_event, continue_event, i)) + t = self.checkedThread( + target=self._testDefaultGraphInThread, + args=(constructed_event, continue_event, i)) threads.append(t) for t in threads: t.start() @@ -1006,6 +1044,7 @@ class SessionTest(test_util.TensorFlowTestCase): ev.wait() val = c.eval(session=sess) self.assertEqual(val, 5.0) + threads = [self.checkedThread(target=run_step) for _ in range(100)] for t in threads: t.start() @@ -1038,11 +1077,10 @@ class SessionTest(test_util.TensorFlowTestCase): def testGraphDef(self): with session.Session() as sess: - self.assertProtoEquals( - 'versions { producer: %d min_consumer: %d }' % ( - versions.GRAPH_DEF_VERSION, - versions.GRAPH_DEF_VERSION_MIN_CONSUMER), - sess.graph_def) + self.assertProtoEquals('versions { producer: %d min_consumer: %d }' % + (versions.GRAPH_DEF_VERSION, + versions.GRAPH_DEF_VERSION_MIN_CONSUMER), + sess.graph_def) c = constant_op.constant(5.0, name='c') self.assertEquals(len(sess.graph_def.node), 1) d = constant_op.constant(6.0, name='d') @@ -1072,6 +1110,7 @@ class SessionTest(test_util.TensorFlowTestCase): lambda e: 'Attempted to use a closed Session.' in str(e)): while True: sess.run(c) + t = threading.Thread(target=update_thread) t.start() time.sleep(0.1) @@ -1177,17 +1216,11 @@ class SessionTest(test_util.TensorFlowTestCase): def testFeedAndFetch(self): with session.Session() as sess: - for dtype in [dtypes.float16, - dtypes.float32, - dtypes.float64, - dtypes.int32, - dtypes.uint8, - dtypes.int16, - dtypes.int8, - dtypes.int64, - dtypes.bool, - dtypes.complex64, - dtypes.complex128]: + for dtype in [ + dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool, + dtypes.complex64, dtypes.complex128 + ]: for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: np_dtype = dtype.as_numpy_dtype @@ -1206,13 +1239,19 @@ class SessionTest(test_util.TensorFlowTestCase): np_array = np_array.astype(np_dtype) self.assertAllEqual(np_array, - sess.run(out_t, feed_dict={feed_t: np_array})) + sess.run(out_t, feed_dict={ + feed_t: np_array + })) # Check that we can also get the feed back. self.assertAllEqual(np_array, - sess.run(feed_t, feed_dict={feed_t: np_array})) + sess.run(feed_t, feed_dict={ + feed_t: np_array + })) # Also check that we can get both back. - out_v, feed_v = sess.run([out_t, feed_t], - feed_dict={feed_t: np_array}) + out_v, feed_v = sess.run( + [out_t, feed_t], feed_dict={ + feed_t: np_array + }) self.assertAllEqual(np_array, out_v) self.assertAllEqual(np_array, feed_v) @@ -1257,9 +1296,11 @@ class SessionTest(test_util.TensorFlowTestCase): trace_level=config_pb2.RunOptions.FULL_TRACE) run_metadata = config_pb2.RunMetadata() self.assertEqual(0, len(run_metadata.step_stats.dev_stats)) - self.assertAllClose( - 42.0, - tensor_runner(41.0, options=run_options, run_metadata=run_metadata)) + self.assertAllClose(42.0, + tensor_runner( + 41.0, + options=run_options, + run_metadata=run_metadata)) self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) def testFeedError(self): @@ -1296,8 +1337,9 @@ class SessionTest(test_util.TensorFlowTestCase): size = 1 for s in shape: size *= s - c_list = np.array([compat.as_bytes(str(i)) for i in xrange(size)], - dtype=np.object).reshape(shape) if size > 0 else [] + c_list = np.array( + [compat.as_bytes(str(i)) for i in xrange(size)], + dtype=np.object).reshape(shape) if size > 0 else [] c = constant_op.constant(c_list) self.assertAllEqual(c.eval(), c_list) @@ -1307,13 +1349,16 @@ class SessionTest(test_util.TensorFlowTestCase): size = 1 for s in shape: size *= s - c_list = np.array([compat.as_bytes(str(i)) for i in xrange(size)], - dtype=np.object).reshape(shape) + c_list = np.array( + [compat.as_bytes(str(i)) for i in xrange(size)], + dtype=np.object).reshape(shape) feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape) c = array_ops.identity(feed_t) self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list) - self.assertAllEqual(sess.run(feed_t, feed_dict={feed_t: c_list}), - c_list) + self.assertAllEqual( + sess.run(feed_t, feed_dict={ + feed_t: c_list + }), c_list) c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list}) self.assertAllEqual(c_v, c_list) self.assertAllEqual(feed_v, c_list) @@ -1329,8 +1374,10 @@ class SessionTest(test_util.TensorFlowTestCase): def testStringFeedWithUnicode(self): with session.Session(): - c_list = [u'\n\x01\x00', u'\n\x00\x01', - u'\u26a3 unicode', u'\U0001f60e deal with it'] + c_list = [ + u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode', + u'\U0001f60e deal with it' + ] feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)]) c = array_ops.identity(feed_t) @@ -1423,9 +1470,10 @@ class SessionTest(test_util.TensorFlowTestCase): sess.run(constant_op.constant(1.0), run_metadata=run_metadata) self.assertTrue(not run_metadata.HasField('step_stats')) - sess.run(constant_op.constant(1.0), - options=run_options, - run_metadata=run_metadata) + sess.run( + constant_op.constant(1.0), + options=run_options, + run_metadata=run_metadata) self.assertTrue(run_metadata.HasField('step_stats')) self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) @@ -1439,23 +1487,26 @@ class SessionTest(test_util.TensorFlowTestCase): with session.Session() as sess: # all combinations are valid sess.run(constant_op.constant(1.0), options=None, run_metadata=None) - sess.run(constant_op.constant(1.0), options=None, - run_metadata=run_metadata) + sess.run( + constant_op.constant(1.0), options=None, run_metadata=run_metadata) self.assertTrue(not run_metadata.HasField('step_stats')) - sess.run(constant_op.constant(1.0), options=run_options, - run_metadata=None) + sess.run( + constant_op.constant(1.0), options=run_options, run_metadata=None) self.assertTrue(not run_metadata.HasField('step_stats')) - sess.run(constant_op.constant(1.0), options=run_options, - run_metadata=run_metadata) + sess.run( + constant_op.constant(1.0), + options=run_options, + run_metadata=run_metadata) self.assertTrue(run_metadata.HasField('step_stats')) self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) def testFeedShapeCompatibility(self): # TODO(nolivia): C API doesn't yet handle marking nodes as not feedable. - if ops._USE_C_API: return + if ops._USE_C_API: + return with session.Session() as sess: some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0]) @@ -1499,8 +1550,11 @@ class SessionTest(test_util.TensorFlowTestCase): d = math_ops.multiply(c, c) for step in xrange(120): run_metadata = config_pb2.RunMetadata() - sess.run(d, feed_dict={a: 1.0}, - options=run_options, run_metadata=run_metadata) + sess.run( + d, + feed_dict={a: 1.0}, + options=run_options, + run_metadata=run_metadata) if step == 99: self.assertTrue(run_metadata.HasField('cost_graph')) else: @@ -1569,8 +1623,7 @@ class SessionTest(test_util.TensorFlowTestCase): def testTimeoutWithShortOperations(self): num_epochs = 5 - q = data_flow_ops.FIFOQueue( - capacity=50, dtypes=[dtypes.int32], shapes=[()]) + q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()]) enqueue_op = q.enqueue_many(constant_op.constant([1, 2])) # Use a 10-second timeout, which should be longer than any @@ -1582,7 +1635,9 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(sess.run(q.size()), num_epochs * 2) def testRegisterFetchAndFeedConversionFunctions(self): + class SquaredTensor(object): + def __init__(self, tensor): self.sq = math_ops.square(tensor) @@ -1591,24 +1646,27 @@ class SessionTest(test_util.TensorFlowTestCase): feed_fn2 = lambda feed: [feed.sq] session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, - feed_fn1, feed_fn2) + feed_fn1, feed_fn2) with self.assertRaises(ValueError): - session.register_session_run_conversion_functions(SquaredTensor, - fetch_fn, feed_fn1, feed_fn2) + session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, + feed_fn1, feed_fn2) with self.test_session() as sess: np1 = np.array([1.0, 1.5, 2.0, 2.5]) np2 = np.array([3.0, 3.5, 4.0, 4.5]) squared_tensor = SquaredTensor(np2) squared_eval = sess.run(squared_tensor) self.assertAllClose(np2 * np2, squared_eval) - squared_eval = sess.run(squared_tensor, feed_dict={ - squared_tensor : np1 * np1}) + squared_eval = sess.run( + squared_tensor, feed_dict={ + squared_tensor: np1 * np1 + }) self.assertAllClose(np1 * np1, squared_eval) partial_run = sess.partial_run_setup([squared_tensor], []) squared_eval = sess.partial_run(partial_run, squared_tensor) self.assertAllClose(np2 * np2, squared_eval) def testDefaultLogDevicePlacement(self): + class CaptureStderr(str): """Class to capture stderr from C++ shared library.""" @@ -1686,8 +1744,10 @@ class SessionTest(test_util.TensorFlowTestCase): def runTestBuildGraphError(self, sess): # Ensure that errors from building the graph get propagated. data = array_ops.placeholder(dtypes.float32, shape=[]) - enter_1 = control_flow_ops.enter(data, 'foo_1', False) - enter_2 = control_flow_ops.enter(data, 'foo_2', False) + # pylint: disable=protected-access + enter_1 = gen_control_flow_ops._enter(data, 'foo_1', False) + enter_2 = gen_control_flow_ops._enter(data, 'foo_2', False) + # pylint: enable=protected-access res = math_ops.add(enter_1, enter_2) with self.assertRaisesOpError('has inputs from different frames'): sess.run(res, feed_dict={data: 1.0}) @@ -1719,6 +1779,7 @@ class SessionTest(test_util.TensorFlowTestCase): def runTestAddFunctionToSession(self, target=''): """Add a function to a session after the graph has already been run.""" + @function.Defun(dtypes.float32) def foo(x): return x + 1 @@ -1753,6 +1814,7 @@ class SessionTest(test_util.TensorFlowTestCase): TypeError, 'Type of feed value 1 with type <(\w+) \'int\'> is not'): sess.run(a, feed_dict={a: 1}) + class GraphMutationTest(test_util.TensorFlowTestCase): def setUp(self): @@ -1803,8 +1865,7 @@ class GraphMutationTest(test_util.TensorFlowTestCase): with session.Session(graph=g) as sess: self.assertAllEqual(1.0, sess.run(b)) - b.op._set_attr('DstT', - attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT)) + b.op._set_attr('DstT', attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT)) with self.assertRaisesRegexp( errors.FailedPreconditionError, 'Cast.*was changed by setting attribute after it was run'): diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 5fb389cf92818c7a464cf4a4479d86377185d5cf..8b8adefa65a5c54d40bc28d8f50953513cfd3605 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -59,7 +59,7 @@ tf_py_test( tf_py_test( name = "dataset_from_generator_op_test", - size = "small", + size = "medium", srcs = ["dataset_from_generator_op_test.py"], additional_deps = [ "//third_party/py/numpy", @@ -357,6 +357,9 @@ tf_py_test( "//tensorflow/python:session", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:string_ops", + "//tensorflow/python:lookup_ops", ], grpc_enabled = True, tags = [ diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py index b71652c980f233ce116ea89544fcb38ad1d816d1..02720a2e985914d3a6774dc6f64d1316890c46bf 100644 --- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py @@ -28,6 +28,7 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -202,44 +203,45 @@ class FilesystemCacheDatasetTest(test.TestCase): class MemoryCacheDatasetTest(test.TestCase): def testCacheDatasetPassthrough(self): - repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64)) - dataset = dataset_ops.Dataset.range(3).flat_map( - lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count)) + with ops.device("cpu:0"): + repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64)) + dataset = dataset_ops.Dataset.range(3).flat_map( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count)) - cached_dataset = dataset.cache().repeat(2) - uncached_dataset = dataset.repeat(2) + cached_dataset = dataset.cache().repeat(2) + uncached_dataset = dataset.repeat(2) - # Needs to be initializable to capture the variable. - cached_iterator = cached_dataset.make_initializable_iterator() - cached_next = cached_iterator.get_next() - uncached_iterator = uncached_dataset.make_initializable_iterator() - uncached_next = uncached_iterator.get_next() + # Needs to be initializable to capture the variable. + cached_iterator = cached_dataset.make_initializable_iterator() + cached_next = cached_iterator.get_next() + uncached_iterator = uncached_dataset.make_initializable_iterator() + uncached_next = uncached_iterator.get_next() - with self.test_session() as sess: + with self.test_session() as sess: - sess.run(repeat_count.initializer) - sess.run(cached_iterator.initializer) - sess.run(uncached_iterator.initializer) + sess.run(repeat_count.initializer) + sess.run(cached_iterator.initializer) + sess.run(uncached_iterator.initializer) - for i in range(3): - for _ in range(10): - self.assertEqual(sess.run(cached_next), i) - self.assertEqual(sess.run(uncached_next), i) + for i in range(3): + for _ in range(10): + self.assertEqual(sess.run(cached_next), i) + self.assertEqual(sess.run(uncached_next), i) - sess.run(repeat_count.assign(0)) + sess.run(repeat_count.assign(0)) - # The uncached iterator should now be empty. - with self.assertRaises(errors.OutOfRangeError): - sess.run(uncached_next) + # The uncached iterator should now be empty. + with self.assertRaises(errors.OutOfRangeError): + sess.run(uncached_next) - # The cached iterator replays from cache. - for i in range(3): - for _ in range(10): - self.assertEqual(sess.run(cached_next), i) + # The cached iterator replays from cache. + for i in range(3): + for _ in range(10): + self.assertEqual(sess.run(cached_next), i) - # The cached iterator should now be empty. - with self.assertRaises(errors.OutOfRangeError): - sess.run(cached_next) + # The cached iterator should now be empty. + with self.assertRaises(errors.OutOfRangeError): + sess.run(cached_next) def testEmptyCacheReading(self): components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py index 28cb50c00208f95e64bb11ae80656383b1f41e1e..7dbf7268d74a2a18af551de64ced03daab264799 100644 --- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py @@ -201,6 +201,20 @@ class InterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testEmptyInput(self): + iterator = ( + dataset_ops.Dataset.from_tensor_slices([]) + .repeat(None) + .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py index 45dfa13720b09c7bba979b72a339c13dcd2d827b..25c91b42dc65f849a680e65fc7fc2548c1cea8ea 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py @@ -17,10 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function @@ -28,6 +31,9 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import test @@ -103,6 +109,67 @@ class IteratorClusterTest(test.TestCase): "/job:worker/replica:0/task:1/cpu:0", workers[0].target) + def testCaptureHashTableInSharedIterator(self): + worker, _ = test_util.create_local_cluster(1, 1) + + # NOTE(mrry): We must use the V2 variants of `HashTable` + # etc. because these produce a `tf.resource`-typed output that is + # compatible with the in-graph function implementation. + default_val = -1 + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup_ops.HashTable( + lookup_ops.KeyValueTensorInitializer(keys, values), + default_val, + shared_name="shared_table") + + input_sentences = dataset_ops.Dataset.from_tensor_slices( + ["brain brain tank salad surgery", "surgery brain"]) + + iterator = ( + input_sentences.map(lambda x: string_ops.string_split([x]).values).map( + table.lookup) + .make_initializable_iterator(shared_name="shared_iterator")) + init_op = iterator.initializer + get_next = iterator.get_next() + + with session.Session(worker[0].target) as sess: + sess.run(table.init) + sess.run(init_op) + self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next)) + + with session.Session(worker[0].target) as sess: + self.assertAllEqual([2, 0], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testImplicitDisposeParallelMapDataset(self): + # Tests whether a parallel map dataset will be cleaned up correctly when + # the pipeline does not run it until exhaustion. + # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> + # RepeatDataset(None) -> PrefetchDataset(100). + worker, _ = test_util.create_local_cluster(1, 1) + + components = (np.arange(1000), + np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], + np.array(37.0) * np.arange(1000)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(None).prefetch(10000)) + + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with session.Session(worker[0].target) as sess: + sess.run(init_op) + for _ in range(3): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py index ae08032e191487c38d73876374b24e8f6eefbc80..1d27b036eb804aa301b916b7ed0b7884f75e1a0f 100644 --- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py @@ -201,9 +201,7 @@ class SequenceDatasetTest(test.TestCase): with self.test_session() as sess: sess.run(init_op) - with self.assertRaisesRegexp( - errors.OutOfRangeError, - "Attempted to repeat an empty dataset infinitely."): + with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index c1ba67e4744c6282f0fd3d9a388aabc1ed51267b..b665443b7acb9eb266b6fcf36a002cfce54875f1 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -41,8 +41,10 @@ from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export +@tf_export("data.Dataset") class Dataset(object): """Represents a potentially large set of elements. @@ -556,6 +558,8 @@ class Dataset(object): - /path/to/dir/b.py - /path/to/dir/c.py + NOTE: The order of the file names returned can be non-deterministic. + Args: file_pattern: A string or scalar string `tf.Tensor`, representing the filename pattern that will be matched. @@ -769,7 +773,7 @@ class Dataset(object): return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values) def map(self, map_func, num_parallel_calls=None): - """Maps `map_func` across this datset. + """Maps `map_func` across this dataset. Args: map_func: A function mapping a nested structure of tensors (having @@ -899,10 +903,11 @@ class Dataset(object): Args: transformation_func: A function that takes one `Dataset` argument and - returns a `Dataset`. + returns a `Dataset`. Returns: - Dataset: The `Dataset` returned by applying `transformation_func` to this dataset. + Dataset: The `Dataset` returned by applying `transformation_func` to this + dataset. """ dataset = transformation_func(self) if not isinstance(dataset, Dataset): @@ -1454,6 +1459,19 @@ def _padding_value_to_tensor(value, output_type): return value +def _default_padding(input_dataset): + + def make_zero(t): + if t.base_dtype == dtypes.string: + return "" + elif t.base_dtype == dtypes.variant: + raise TypeError("Unable to create padding for field of type 'variant'") + else: + return np.zeros_like(t.as_numpy_dtype()) + + return nest.map_structure(make_zero, input_dataset.output_types) + + class PaddedBatchDataset(Dataset): """A `Dataset` that batches and pads contiguous elements from its input.""" @@ -1469,23 +1487,13 @@ class PaddedBatchDataset(Dataset): batch_size, dtype=dtypes.int64, name="batch_size") padding_values = ( padding_values - if padding_values is not None else self._default_padding(input_dataset)) + if padding_values is not None else _default_padding(input_dataset)) self._padded_shapes = nest.map_structure_up_to( input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes) self._padding_values = nest.map_structure_up_to( input_dataset.output_shapes, _padding_value_to_tensor, padding_values, input_dataset.output_types) - def _default_padding(self, input_dataset): - - def make_zero(t): - if t.base_dtype == dtypes.string: - return "" - else: - return np.zeros_like(t.as_numpy_dtype()) - - return nest.map_structure(make_zero, input_dataset.output_types) - def _as_variant_tensor(self): return gen_dataset_ops.padded_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 0cbdb3ab19d8f1b966a867dfcf709c1a4a49b871..4756ec74820bace5bea4e1f41ebe214420fe5c3d 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util.tf_export import tf_export # NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple @@ -43,10 +44,12 @@ GET_NEXT_CALL_WARNING_MESSAGE = ( "This often indicates that `Iterator.get_next()` is being called inside " "a training loop, which will cause gradual slowdown and eventual resource " "exhaustion. If this is the case, restructure your code to call " - "`next_element = iterator.get_next() once outside the loop, and use " - "`next_element` inside the loop.") + "`next_element = iterator.get_next()` once outside the loop, and use " + "`next_element` as the input to some computation that is invoked inside " + "the loop.") +@tf_export("data.Iterator") class Iterator(object): """Represents the state of iterating through a `Dataset`.""" @@ -165,8 +168,10 @@ class Iterator(object): iterator_resource = gen_dataset_ops.iterator( container="", shared_name=shared_name, - output_types=nest.flatten(output_types), - output_shapes=nest.flatten(output_shapes)) + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) return Iterator(iterator_resource, None, output_types, output_shapes, output_classes) @@ -232,8 +237,10 @@ class Iterator(object): string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string) iterator_resource = gen_dataset_ops.iterator_from_string_handle( string_handle, - output_types=nest.flatten(output_types), - output_shapes=nest.flatten(output_shapes)) + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) return Iterator(iterator_resource, None, output_types, output_shapes, output_classes) @@ -297,7 +304,42 @@ class Iterator(object): dataset._as_variant_tensor(), self._iterator_resource, name=name) # pylint: disable=protected-access def get_next(self, name=None): - """Returns a nested structure of `tf.Tensor`s containing the next element. + """Returns a nested structure of `tf.Tensor`s representing the next element. + + In graph mode, you should typically call this method *once* and use its + result as the input to another computation. A typical loop will then call + @{tf.Session.run} on the result of that computation. The loop will terminate + when the `Iterator.get_next()` operation raises + @{tf.errors.OutOfRangeError}. The following skeleton shows how to use + this method when building a training loop: + + ```python + dataset = ... # A `tf.data.Dataset` object. + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + # Build a TensorFlow graph that does something with each element. + loss = model_function(next_element) + optimizer = ... # A `tf.train.Optimizer` object. + train_op = optimizer.minimize(loss) + + with tf.Session() as sess: + try: + while True: + sess.run(train_op) + except tf.errors.OutOfRangeError: + pass + ``` + + NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g. + when you are distributing different elements to multiple devices in a single + step. However, a common pitfall arises when users call `Iterator.get_next()` + in each iteration of their training loop. `Iterator.get_next()` adds ops to + the graph, and executing each op allocates resources (including threads); as + a consequence, invoking it in every iteration of a training loop causes + slowdown and eventual resource exhaustion. To guard against this outcome, we + log a warning when the number of uses crosses a fixed threshold of + suspiciousness. Args: name: (Optional.) A name for the created operation. diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py index 830dc5cec4a54469d001f0ba57d1adc7bc5efd11..fa7601741b11f018e9b53ed3b77a7561be50d3f4 100644 --- a/tensorflow/python/data/ops/readers.py +++ b/tensorflow/python/data/ops/readers.py @@ -23,12 +23,14 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util.tf_export import tf_export # TODO(b/64974358): Increase default buffer size to 256 MB. _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB +@tf_export("data.TextLineDataset") class TextLineDataset(Dataset): """A `Dataset` comprising lines from one or more text files.""" @@ -71,6 +73,7 @@ class TextLineDataset(Dataset): return dtypes.string +@tf_export("data.TFRecordDataset") class TFRecordDataset(Dataset): """A `Dataset` comprising records from one or more TFRecord files.""" @@ -115,6 +118,7 @@ class TFRecordDataset(Dataset): return dtypes.string +@tf_export("data.FixedLengthRecordDataset") class FixedLengthRecordDataset(Dataset): """A `Dataset` of fixed-length records from one or more binary files.""" diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index 2455395635c4c8fa5d157a38d4e7a118f554fd9f..e90ce3fb40af68fb68d6ee8bac6892848d8c5a79 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -266,7 +266,7 @@ def map_structure(func, *structure, **check_types_dict): and the return value will contain the results in the same structure. Args: - func: A callable that acceps as many arguments are there are structures. + func: A callable that accepts as many arguments are there are structures. *structure: scalar, or tuple or list of constructed scalars and/or other tuples/lists, or scalars. Note: numpy arrays are considered scalars. **check_types_dict: only valid keyword argument is `check_types`. If set to @@ -383,8 +383,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): "structure has keys %s, while shallow structure has keys %s." % (list(_six.iterkeys(input_tree)), list(_six.iterkeys(shallow_tree)))) - input_tree = list(_six.iteritems(input_tree)) - shallow_tree = list(_six.iteritems(shallow_tree)) + input_tree = list(sorted(_six.iteritems(input_tree))) + shallow_tree = list(sorted(_six.iteritems(shallow_tree))) for shallow_branch, input_branch in zip(shallow_tree, input_tree): assert_shallow_structure(shallow_branch, input_branch, @@ -479,8 +479,8 @@ def map_structure_up_to(shallow_tree, func, *inputs): The `inputs`, can be thought of as having the same structure as `shallow_tree`, but with leaf nodes that are themselves tree structures. - This function therefore will return something with the same base structure as - `shallow_tree`. + This function, therefore, will return something with the same base structure + as `shallow_tree`. Examples: diff --git a/tensorflow/python/data/util/nest_test.py b/tensorflow/python/data/util/nest_test.py index 90dd7dfe7775b2f10611e5579784fbda63fc9669..ff380815a4a32192de621888199e66355f9b4635 100644 --- a/tensorflow/python/data/util/nest_test.py +++ b/tensorflow/python/data/util/nest_test.py @@ -277,6 +277,10 @@ class NestTest(test.TestCase): with self.assertRaisesRegexp(ValueError, expected_message): nest.assert_shallow_structure(inp_ab2, inp_ab1) + inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) + inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) + nest.assert_shallow_structure(inp_ab, inp_ba) + def testFlattenUpTo(self): input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5))) shallow_tree = ((True, True), (False, True)) diff --git a/tensorflow/python/data/util/sparse.py b/tensorflow/python/data/util/sparse.py index 5ebcb4ea81b23b60dc46bae78bfa792f4a8ab6d8..5e6d22470978d97c5e73640e86d3f8b82cbc1b60 100644 --- a/tensorflow/python/data/util/sparse.py +++ b/tensorflow/python/data/util/sparse.py @@ -141,7 +141,7 @@ def serialize_sparse_tensors(tensors): tensors: a tensor structure to serialize. Returns: - `tensors` with any sparse tensors replaced by the their serialized version. + `tensors` with any sparse tensors replaced by their serialized version. """ ret = nest.pack_sequence_as(tensors, [ diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py index a0fe6066acd1462a94e93d6091db237d01cfede3..dea019fef58015fbd7982a81319dcabe4e5f4930 100644 --- a/tensorflow/python/debug/cli/cli_shared.py +++ b/tensorflow/python/debug/cli/cli_shared.py @@ -175,7 +175,7 @@ def format_tensor(tensor, include_numeric_summary: Whether a text summary of the numeric values (if applicable) will be included. write_path: A path to save the tensor value (after any slicing) to - (optinal). `numpy.save()` is used to save the value. + (optional). `numpy.save()` is used to save the value. Returns: An instance of `debugger_cli_common.RichTextLines` representing the diff --git a/tensorflow/python/debug/cli/tensor_format.py b/tensorflow/python/debug/cli/tensor_format.py index d4aea76d652e7606939f3d8a89ff0378da0774d2..9ba84e3f2261de277361d503e9189583494a5084 100644 --- a/tensorflow/python/debug/cli/tensor_format.py +++ b/tensorflow/python/debug/cli/tensor_format.py @@ -134,7 +134,7 @@ def format_tensor(tensor, if include_metadata: lines.append(" dtype: %s" % str(tensor.dtype)) - lines.append(" shape: %s" % str(tensor.shape)) + lines.append(" shape: %s" % str(tensor.shape).replace("L", "")) if lines: lines.append("") @@ -535,7 +535,7 @@ def numeric_summary(tensor): if not isinstance(tensor, np.ndarray) or not np.size(tensor): return debugger_cli_common.RichTextLines([ "No numeric summary available due to empty tensor."]) - elif (np.issubdtype(tensor.dtype, np.float) or + elif (np.issubdtype(tensor.dtype, np.floating) or np.issubdtype(tensor.dtype, np.complex) or np.issubdtype(tensor.dtype, np.integer)): counts = [ diff --git a/tensorflow/python/debug/examples/debug_fibonacci.py b/tensorflow/python/debug/examples/debug_fibonacci.py index 704dbda357d1208d0663da41eb7aef4b299dedb8..3821b393ec6847db71b7c4b7396b1ed448ae9538 100644 --- a/tensorflow/python/debug/examples/debug_fibonacci.py +++ b/tensorflow/python/debug/examples/debug_fibonacci.py @@ -44,6 +44,10 @@ def main(_): sess.run(tf.global_variables_initializer()) # Wrap the TensorFlow Session object for debugging. + if FLAGS.debug and FLAGS.tensorboard_debug_address: + raise ValueError( + "The --debug and --tensorboard_debug_address flags are mutually " + "exclusive.") if FLAGS.debug: sess = tf_debug.LocalCLIDebugWrapperSession(sess) @@ -52,6 +56,9 @@ def main(_): sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan) sess.add_tensor_filter("has_negative", has_negative) + elif FLAGS.tensorboard_debug_address: + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, FLAGS.tensorboard_debug_address) print("Fibonacci number at position %d:\n%s" % (FLAGS.length, sess.run(n1))) @@ -82,7 +89,15 @@ if __name__ == "__main__": "--debug", dest="debug", action="store_true", - help="Use TensorFlow Debugger (tfdbg).") + help="Use TensorFlow Debugger (tfdbg). Mutually exclusive with the " + "--tensorboard_debug_address flag.") + parser.add_argument( + "--tensorboard_debug_address", + type=str, + default=None, + help="Connect to the TensorBoard Debugger Plugin backend specified by " + "the gRPC address (e.g., localhost:1234). Mutually exclusive with the " + "--debug flag.") FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/examples/debug_mnist.py b/tensorflow/python/debug/examples/debug_mnist.py index 0a6dbf311d8e7a0377363d74b57ef2b1d7d00e1d..ab1c90371cd18bbaf278b72248bcc7e9e9c34b06 100644 --- a/tensorflow/python/debug/examples/debug_mnist.py +++ b/tensorflow/python/debug/examples/debug_mnist.py @@ -120,8 +120,15 @@ def main(_): sess.run(tf.global_variables_initializer()) + if FLAGS.debug and FLAGS.tensorboard_debug_address: + raise ValueError( + "The --debug and --tensorboard_debug_address flags are mutually " + "exclusive.") if FLAGS.debug: sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type) + elif FLAGS.tensorboard_debug_address: + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, FLAGS.tensorboard_debug_address) # Add this point, sess is a debug wrapper around the actual Session if # FLAGS.debug is true. In that case, calling run() will launch the CLI. @@ -173,6 +180,14 @@ if __name__ == "__main__": nargs="?", const=True, default=False, - help="Use debugger to track down bad values during training") + help="Use debugger to track down bad values during training. " + "Mutually exclusive with the --tensorboard_debug_address flag.") + parser.add_argument( + "--tensorboard_debug_address", + type=str, + default=None, + help="Connect to the TensorBoard Debugger Plugin backend specified by " + "the gRPC address (e.g., localhost:1234). Mutually exclusive with the " + "--debug flag.") FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py index 92314d8dd9f64f48ffe0bc921f99a4661c4c0e93..4f4666ee4fa51ef085d31ee8396dffaf9e38f49e 100644 --- a/tensorflow/python/debug/examples/debug_tflearn_iris.py +++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py @@ -110,10 +110,16 @@ def main(_): model_dir=model_dir) hooks = None + if FLAGS.debug and FLAGS.tensorboard_debug_address: + raise ValueError( + "The --debug and --tensorboard_debug_address flags are mutually " + "exclusive.") if FLAGS.debug: debug_hook = tf_debug.LocalCLIDebugHook(ui_type=FLAGS.ui_type, dump_root=FLAGS.dump_root) - hooks = [debug_hook] + elif FLAGS.tensorboard_debug_address: + debug_hook = tf_debug.TensorBoardDebugHook(FLAGS.tensorboard_debug_address) + hooks = [debug_hook] if not FLAGS.use_experiment: # Fit model. @@ -185,11 +191,19 @@ if __name__ == "__main__": nargs="?", const=True, default=False, - help="Use debugger to track down bad values during training") + help="Use debugger to track down bad values during training. " + "Mutually exclusive with the --tensorboard_debug_address flag.") parser.add_argument( "--dump_root", type=str, default="", help="Optional custom root directory for temporary debug dump data") + parser.add_argument( + "--tensorboard_debug_address", + type=str, + default=None, + help="Connect to the TensorBoard Debugger Plugin backend specified by " + "the gRPC address (e.g., localhost:1234). Mutually exclusive with the " + "--debug flag.") FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index c4b13a1045dac4966b0e841155a2932216881d34..8d355aa27f6fa10a1889420a9087800be12a81ce 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -222,7 +222,7 @@ def has_inf_or_nan(datum, tensor): # Also return False for data types that cannot be represented as numpy # arrays. return False - elif (np.issubdtype(tensor.dtype, np.float) or + elif (np.issubdtype(tensor.dtype, np.floating) or np.issubdtype(tensor.dtype, np.complex) or np.issubdtype(tensor.dtype, np.integer)): return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor)) diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py index b6c7280a415b367751c4900a302e5af61f260cb0..01867fc69d0782b34edb1e8eb873b19f5dfc8529 100644 --- a/tensorflow/python/debug/lib/debug_gradients_test.py +++ b/tensorflow/python/debug/lib/debug_gradients_test.py @@ -22,6 +22,7 @@ import shutil import tempfile from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.lib import debug_gradients @@ -38,7 +39,12 @@ from tensorflow.python.training import gradient_descent class IdentifyGradientTest(test_util.TensorFlowTestCase): def setUp(self): - self.sess = session.Session() + rewriter_config = rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) + config = config_pb2.ConfigProto(graph_options=graph_options) + self.sess = session.Session(config=config) with self.sess.as_default(): self.u = variables.Variable(2.0, name="u") self.v = variables.Variable(3.0, name="v") @@ -112,8 +118,8 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): def testCallingIdentifyGradientTwiceWithTheSameGradientsDebuggerErrors(self): grad_debugger = debug_gradients.GradientsDebugger() grad_debugger.identify_gradient(self.w) - with self.assertRaisesRegexp( - ValueError, "The graph already contains an op named .*"): + with self.assertRaisesRegexp(ValueError, + "The graph already contains an op named .*"): grad_debugger.identify_gradient(self.w) def testIdentifyGradientWorksOnMultipleLosses(self): @@ -139,10 +145,10 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): self.assertIsNot(dz1_dy, dz2_dy) self.sess.run(variables.global_variables_initializer()) - self.assertAllClose(5.0 ** 2, self.sess.run(z1)) - self.assertAllClose(5.0 ** 0.5, self.sess.run(z2)) + self.assertAllClose(5.0**2, self.sess.run(z1)) + self.assertAllClose(5.0**0.5, self.sess.run(z2)) self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy)) - self.assertAllClose(0.5 * (5.0 ** -0.5), self.sess.run(dz2_dy)) + self.assertAllClose(0.5 * (5.0**-0.5), self.sess.run(dz2_dy)) def testIdentifyGradientRaisesLookupErrorForUnknownXTensor(self): grad_debugger_1 = debug_gradients.GradientsDebugger() @@ -254,8 +260,8 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): self.sess.run(variables.global_variables_initializer()) self.assertAllClose(3.0, self.sess.run(u_grad)) self.assertAllClose(2.0, self.sess.run(v_grad)) - self.assertAllClose( - 3.0, self.sess.run(grad_debugger.gradient_tensor("u:0"))) + self.assertAllClose(3.0, self.sess.run( + grad_debugger.gradient_tensor("u:0"))) def testWatchGradientsWorksOnMultipleTensors(self): y = math_ops.add(self.w, -1.0, name="y") @@ -272,10 +278,10 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): self.assertIsInstance(grad_debugger.gradient_tensor("w:0"), ops.Tensor) self.sess.run(variables.global_variables_initializer()) - self.assertAllClose( - 1.0, self.sess.run(grad_debugger.gradient_tensor("w:0"))) - self.assertAllClose( - 3.0, self.sess.run(grad_debugger.gradient_tensor("u:0"))) + self.assertAllClose(1.0, self.sess.run( + grad_debugger.gradient_tensor("w:0"))) + self.assertAllClose(3.0, self.sess.run( + grad_debugger.gradient_tensor("u:0"))) def testWatchGradientsByXTensorsWorks(self): y = math_ops.add(self.w, -1.0, name="foo/y") @@ -285,8 +291,8 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): # But we can still get the gradient tensors by using # watch_gradients_by_x_tensors(). grad_debugger = debug_gradients.GradientsDebugger() - with grad_debugger.watch_gradients_by_tensors( - self.sess.graph, [self.w, self.u, y]): + with grad_debugger.watch_gradients_by_tensors(self.sess.graph, + [self.w, self.u, y]): gradient_descent.GradientDescentOptimizer(0.1).minimize(z) self.assertEqual(3, len(grad_debugger.gradient_tensors())) @@ -319,18 +325,18 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): self.assertIsNot(dz1_dy, dz2_dy) self.sess.run(variables.global_variables_initializer()) - self.assertAllClose(5.0 ** 2, self.sess.run(z1)) - self.assertAllClose(5.0 ** 0.5, self.sess.run(z2)) + self.assertAllClose(5.0**2, self.sess.run(z1)) + self.assertAllClose(5.0**0.5, self.sess.run(z2)) self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy)) - self.assertAllClose(0.5 * (5.0 ** -0.5), self.sess.run(dz2_dy)) + self.assertAllClose(0.5 * (5.0**-0.5), self.sess.run(dz2_dy)) def testGradientsValuesFromDumpWorks(self): y = math_ops.add(self.w, -1.0, name="y") z = math_ops.square(y, name="z") grad_debugger = debug_gradients.GradientsDebugger() - with grad_debugger.watch_gradients_by_tensors( - self.sess.graph, [self.w, self.u, y]): + with grad_debugger.watch_gradients_by_tensors(self.sess.graph, + [self.w, self.u, y]): train_op = gradient_descent.GradientDescentOptimizer(0.1).minimize(z) self.sess.run(variables.global_variables_initializer()) @@ -338,10 +344,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): run_options = config_pb2.RunOptions(output_partition_graphs=True) dump_dir = tempfile.mkdtemp() debug_url = "file://" + dump_dir - debug_utils.watch_graph( - run_options, - self.sess.graph, - debug_urls=debug_url) + debug_utils.watch_graph(run_options, self.sess.graph, debug_urls=debug_url) run_metadata = config_pb2.RunMetadata() self.assertAllClose(2.0, self.sess.run(self.u)) self.sess.run(train_op, options=run_options, run_metadata=run_metadata) diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py index 367b3535450ac4bd17d4c5dba0eaf149aa4b68b3..b623ee31c5dc59894373ec7952e53acd0f6e1126 100644 --- a/tensorflow/python/debug/lib/session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py @@ -54,7 +54,8 @@ from tensorflow.python.training import monitored_session def no_rewrite_session_config(): rewriter_config = rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, - arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py index 20a40018bf9c67c5b743963489c8fc5616efa2db..f4fac1401918ccacd38aae5ad2ef8d686c9204b9 100644 --- a/tensorflow/python/debug/lib/session_debug_testlib.py +++ b/tensorflow/python/debug/lib/session_debug_testlib.py @@ -988,7 +988,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase): def testWatchingVariableUpdateOpsSeesUpdatedValues(self): """Watch output slots on Variable-updating ops, with no emitted edges.""" - with session.Session() as sess: + with session.Session(config=no_rewrite_session_config()) as sess: u_init = constant_op.constant(10.0) u = variables.Variable(u_init, name="gdo/u") v_init = constant_op.constant(20.0) diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py index acea9433e22203d56f4ceb6cd92b681e35876a09..254201c39371e2034b08fad927e98418c8086ea5 100644 --- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py @@ -389,6 +389,11 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase): r"mode\."): sess.invoke_node_stepper(node_stepper) + def testDumpingWrapperWithEmptyFetchWorks(self): + sess = dumping_wrapper.DumpingDebugWrapperSession( + self.sess, session_root=self.session_root, log_usage=False) + sess.run([]) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index 909150eb6aa21b45af39f7cbfd6248c701ae1fb5..c530204bbf6959f56a72c6e67add91f1e575f067 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -121,7 +121,9 @@ from tensorflow.python.debug.lib import debug_utils from tensorflow.python.debug.lib import stepper from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.platform import tf_logging from tensorflow.python.training import monitored_session +from tensorflow.python.util import nest # Helper function. @@ -439,7 +441,12 @@ class BaseDebugWrapperSession(session.SessionInterface): "callable_runner and fetches/feed_dict are mutually exclusive, but " "are used simultaneously.") - if self._is_disabled_thread(): + empty_fetches = not nest.flatten(fetches) + if empty_fetches: + tf_logging.info( + "Due to empty fetches, tfdbg Session wrapper is letting a " + "Session.run pass through without any debugging actions.") + if self._is_disabled_thread() or empty_fetches: if callable_runner: return callable_runner(*callable_runner_args) else: diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py index 74d7c2b9e242f947a33c0bdb6508847808d69c0b..fb9494f57636e46e54ef230cf4803dbb6ccad0c7 100644 --- a/tensorflow/python/debug/wrappers/grpc_wrapper.py +++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import signal +import sys import traceback # Google-internal import(s). @@ -137,6 +139,29 @@ class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): if not address.startswith(common.GRPC_URL_PREFIX) else address) +def _signal_handler(unused_signal, unused_frame): + try: + input_func = raw_input + except NameError: + # Python 3 does not have raw_input. + input_func = input + + while True: + response = input_func("\nSIGINT received. Quit program? (Y/n): ").strip() + if response in ("", "Y", "y"): + sys.exit(0) + elif response in ("N", "n"): + break + + +def register_signal_handler(): + try: + signal.signal(signal.SIGINT, _signal_handler) + except ValueError: + # This can happen if we are not in the MainThread. + pass + + class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): """A tfdbg Session wrapper that can be used with TensorBoard Debugger Plugin. @@ -185,6 +210,8 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): # sent to the debug servers. self._sent_graph_version = -1 + register_signal_handler() + def run(self, fetches, feed_dict=None, diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py index 989ad801e53615f7bd26b8b4fb850b8a56cd193c..6705cd31e291d2eab7aa8179e9b2b829f8970c18 100644 --- a/tensorflow/python/debug/wrappers/hooks.py +++ b/tensorflow/python/debug/wrappers/hooks.py @@ -35,10 +35,7 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook): `tf.contrib.learn`'s `Estimator`s and `Experiment`s. """ - def __init__(self, - ui_type="curses", - dump_root=None, - thread_name_filter=None): + def __init__(self, ui_type="curses", dump_root=None, thread_name_filter=None): """Create a local debugger command-line interface (CLI) hook. Args: @@ -62,7 +59,8 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook): """Add a tensor filter. See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()` for details. - Override default behavior to accommodate the possibility of this method being + Override default behavior to accommodate the possibility of this method + being called prior to the initialization of the underlying `LocalCLIDebugWrapperSession` object. @@ -137,9 +135,7 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook): # pylint: enable=protected-access with stepper.NodeStepper( - run_context.session, - run_context.original_args. - fetches, + run_context.session, run_context.original_args.fetches, run_context.original_args.feed_dict) as node_stepper: self._session_wrapper.invoke_node_stepper( node_stepper, restore_variable_values_on_exit=True) @@ -149,8 +145,8 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook): def after_run(self, run_context, run_values): # Adapt run_context and run_values to OnRunEndRequest and invoke superclass # on_run_end() - on_run_end_request = framework.OnRunEndRequest( - self._performed_action, run_values.run_metadata) + on_run_end_request = framework.OnRunEndRequest(self._performed_action, + run_values.run_metadata) self._session_wrapper.on_run_end(on_run_end_request) @@ -260,8 +256,8 @@ class GrpcDebugHook(session_run_hook.SessionRunHook): self._thread_name_filter = thread_name_filter self._grpc_debug_server_addresses = ( grpc_debug_server_addresses - if isinstance(grpc_debug_server_addresses, list) - else [grpc_debug_server_addresses]) + if isinstance(grpc_debug_server_addresses, list) else + [grpc_debug_server_addresses]) self._watch_fn = watch_fn self._log_usage = log_usage @@ -334,6 +330,7 @@ class TensorBoardDebugHook(GrpcDebugHook): log_usage: Whether the usage of this class is to be logged (if applicable). """ + def _gated_grpc_watch_fn(fetches, feeds): del fetches, feeds # Unused. return framework.WatchOptions( @@ -348,6 +345,7 @@ class TensorBoardDebugHook(GrpcDebugHook): self._grpc_debug_server_addresses = grpc_debug_server_addresses self._send_traceback_and_source_code = send_traceback_and_source_code self._sent_graph_version = -1 + grpc_wrapper.register_signal_handler() def before_run(self, run_context): if self._send_traceback_and_source_code: diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py index 770a496aa9d2f4bb8bee0f51526ba8c3d4278b81..490812c96d83791cdc20c56f16c968f1a1851af8 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py @@ -664,6 +664,20 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): [["run"], ["run"]], monitored_sess) self.assertFalse(wrapped_monitored_sess.should_stop()) + def testRunsWithEmptyFetchWorks(self): + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]], self.sess, dump_root="") + + run_output = wrapped_sess.run([]) + self.assertEqual([], run_output) + + def testRunsWithEmptyNestedFetchWorks(self): + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]], self.sess, dump_root="") + + run_output = wrapped_sess.run({"foo": {"baz": []}, "bar": ()}) + self.assertEqual({"foo": {"baz": []}, "bar": ()}, run_output) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 9e3382d4f301529cd2b476bc76efe7dfd2be9298..ab81d40148476735492890f608315b19eaa0a33f 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -206,29 +206,6 @@ cc_library( ], ) -cc_library( - name = "python_eager_op_gen_main", - srcs = [ - "python_eager_op_gen_main.cc", - ], - visibility = ["//visibility:public"], - deps = [ - ":python_eager_op_gen", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:op_gen_lib", - "//tensorflow/core:protos_all_cc", - ], -) - -tf_cc_binary( - name = "python_eager_op_gen_demo", - deps = [ - ":python_eager_op_gen_main", - "//tensorflow/core:ops", - ], -) - py_library( name = "custom_gradient", srcs = ["custom_gradient.py"], diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index a2a3e230bbb4232fe916c658a6b0ac8d6d33658d..5c235382652811ff83ec800c0a28a3beccd45f0f 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import functools import operator import threading @@ -42,6 +43,26 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect +class _TensorCache(object): + """Simple cache which evicts items based on length in a FIFO manner.""" + + def __init__(self, max_items=256): + self._data = collections.OrderedDict() + self._max_items = max_items if max_items else 256 + + def put(self, key, value): + self._data[key] = value + + if len(self._data) > self._max_items: + self._data.popitem(last=False) + + def get(self, key): + return self._data.get(key, None) + + def flush(self): + self._data = {} + + _op_attr_type_cache = {} @@ -157,6 +178,8 @@ _ops_which_dont_need_outputs = set([ "SegmentMax", "UnsortedSegmentSum", "UnsortedSegmentMax", + "UnsortedSegmentMin", + "UnsortedSegmentProd", "Abs", "Neg", "ReciprocalGrad", @@ -734,8 +757,7 @@ def _num_elements(grad): raise ValueError("`grad` not a Tensor or IndexedSlices.") -_last_shape_dtype = [None, None] -_last_zero = [None] +_zeros_cache = _TensorCache() def _fast_fill(value, shape, dtype): @@ -744,17 +766,22 @@ def _fast_fill(value, shape, dtype): def _zeros(shape, dtype): """Wraps array_ops.zeros to cache last zero for a given shape and dtype.""" + device = context.context().device_name if dtype == dtypes.variant: # TODO(apassos): need to save enough information about variant tensors to do # a zeros return None - if [shape, dtype] != _last_shape_dtype: - _last_shape_dtype[:] = [shape, dtype] - _last_zero[0] = _fast_fill(0, shape, dtype) - return _last_zero[0] + cache_key = shape, dtype, device + cached = _zeros_cache.get(cache_key) + if cached is None: + cached = _fast_fill(0, shape, dtype) + _zeros_cache.put(cache_key, cached) + return cached def _ones(shape, dtype): + if shape == (): # pylint: disable=g-explicit-bool-comparison + return constant_op.constant(1, dtype=dtype) return _fast_fill(1, shape, dtype) @@ -856,13 +883,18 @@ class GradientTape(object): t = t.handle tape.watch(t) - def gradient(self, target, sources): + def watched_variables(self): + return self._tape.watched_variables() + + def gradient(self, target, sources, output_gradients=None): """Computes the gradient using information traced by the tape. Args: target: the tensor to be differentiated. sources: a list of Tensors or Variables, the target will be differentiated with respect to the sources. + output_gradients: a list of gradients, one for each element of + target. Defaults to None. Returns: a list of Tensors (or IndexedSlices, or None), one for each element in @@ -880,7 +912,8 @@ class GradientTape(object): else x for x in sources] grad = imperative_grad.imperative_grad( - _default_vspace, self._tape, [target], sources) + _default_vspace, self._tape, [target], sources, + output_gradients=output_gradients) if not self._persistent: self._tape = None return grad diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 75526ba9c139e78dfe9e3de271f1316924539371..b56cbe80a7ab6b90d715187b0f0a44847038fc37 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -28,11 +28,14 @@ from __future__ import print_function import time import numpy as np +import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop # pylint: disable=unused-import from tensorflow.python.eager import context +from tensorflow.python.eager import core +from tensorflow.python.eager import execute from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import dtypes @@ -41,24 +44,31 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops CPU = "/device:CPU:0" GPU = "/device:GPU:0" -def record_gradient_callback(inputs, attrs, results): - return backprop._record_gradient("MatMul", inputs, attrs, results, None) - - -def c_tfe_py_fastpath_execute(a, b, transpose_a=False, transpose_b=False): +def c_tfe_py_fastpath_execute(a, + b, + transpose_a=False, + transpose_b=False, + name=None): ctx = context.context() assert not ctx.in_graph_mode( ), "The prototype doesn't contain C code for graph construction" - ctx_handle = ctx._handle # pylint: disable=protected-access - - return pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx_handle, None, "MatMul", record_gradient_callback, a, b, - "transpose_a", transpose_a, "transpose_b", transpose_b)[0] + try: + return pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, name, + ctx._post_execution_callbacks, a, b, "transpose_a", transpose_a, + "transpose_b", transpose_b) + except core._NotOkStatusException as e: + if name is not None: + message = e.message + " name: " + name + else: + message = e.message + six.raise_from(core._status_to_exception(e.code, message), None) class MicroBenchmarks(test.Benchmark): @@ -262,6 +272,14 @@ class MicroBenchmarks(test.Benchmark): func = lambda: f(m, m, transpose_b) self._run(func, num_iters) + def _benchmark_read_variable(self, m, num_iters): + self._run(m.value, num_iters) + + def _benchmark_read_variable_with_tape(self, m, num_iters): + with backprop.GradientTape() as tape: + tape.watch(m) + self._run(m.value, num_iters) + # Benchmarks for A^2, A of dimension 2 by 2. def benchmark_np_matmul_2_by_2(self): self._benchmark_np_matmul( @@ -398,6 +416,32 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_defun_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) + def benchmark_read_variable_op_2_by_2_CPU(self): + with context.device(CPU): + m = resource_variable_ops.ResourceVariable(self._m_2_by_2) + self._benchmark_read_variable(m, num_iters=self._num_iters_2_by_2) + + def benchmark_read_variable_op_2_by_2_GPU(self): + if not context.num_gpus(): + return + with context.device(GPU): + m = resource_variable_ops.ResourceVariable(self._m_2_by_2.gpu()) + self._benchmark_read_variable(m, num_iters=self._num_iters_2_by_2) + + def benchmark_read_variable_op_with_tape_2_by_2_CPU(self): + with context.device(CPU): + m = resource_variable_ops.ResourceVariable(self._m_2_by_2) + self._benchmark_read_variable_with_tape( + m, num_iters=self._num_iters_2_by_2) + + def benchmark_read_variable_op_with_tape_2_by_2_GPU(self): + if not context.num_gpus(): + return + with context.device(GPU): + m = resource_variable_ops.ResourceVariable(self._m_2_by_2.gpu()) + self._benchmark_read_variable_with_tape( + m, num_iters=self._num_iters_2_by_2) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index b6c7d823237231a138f6a25bb9d03954b69d58d9..07652d3e02b6364e23b6579a64dcadf02dc5eb99 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors from tensorflow.python.util import compat +from tensorflow.python.util import is_in_graph_mode from tensorflow.python.util import tf_contextlib GRAPH_MODE = 0 @@ -59,7 +60,8 @@ class _EagerContext(threading.local): def __init__(self): super(_EagerContext, self).__init__() - self.device_spec = pydev.DeviceSpec.from_string("") + self.device_spec = pydev.DeviceSpec.from_string( + "/job:localhost/replica:0/task:0/device:CPU:0") self.device_name = self.device_spec.to_string() self.mode = _default_mode self.scope_name = "" @@ -599,3 +601,10 @@ def export_run_metadata(): A RunMetadata protocol buffer. """ return context().export_run_metadata() + + +# Not every user creates a Context via context.context() +# (for example, enable_eager_execution in python/framework/ops.py), +# but they do all import this file. Note that IS_IN_GRAPH_MODE and +# in_graph_mode are both parameterless functions. +is_in_graph_mode.IS_IN_GRAPH_MODE = in_graph_mode diff --git a/tensorflow/python/eager/core.py b/tensorflow/python/eager/core.py index 483b7172107838a0069831f2347b0c644c05c000..8fb69300209d74a164c38654d737432cdfb7884a 100644 --- a/tensorflow/python/eager/core.py +++ b/tensorflow/python/eager/core.py @@ -47,3 +47,17 @@ class _NotOkStatusException(Exception): pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException) + + +class _FallbackException(Exception): + """Exception class to handle fallback from the fastpath. + + The fastpath that we refer to here is the one implemented to reduce per-op + overheads (TFE_Py_FastPathExecute_C). If the conditions for executing the op + on the fastpath are not met, we fallback to a safer (and more complete) + slowpath, and this Exception is raised to signal that transition. + """ + pass + + +pywrap_tensorflow.TFE_Py_RegisterFallbackExceptionClass(_FallbackException) diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index ee3c10633e1cb849e319f2f5490e5beb5dd15c80..c68e2f422eb81a915d4f941ffb920f221d9be250 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -65,7 +65,8 @@ class TFETest(test_util.TensorFlowTestCase): ctx.summary_writer_resource = 'mock' self.assertEqual('mock', ctx.summary_writer_resource) - self.assertEqual('', ctx.device_name) + self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0', + ctx.device_name) self.assertEqual(ctx.device_name, ctx.device_spec.to_string()) with ctx.device('GPU:0'): self.assertEqual('/job:localhost/replica:0/task:0/device:GPU:0', diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py index 306cf07aabe1c214d02da5f077a57043cc1f4089..2ff5b8d8f489731c14d8abb81652a17026ed4935 100644 --- a/tensorflow/python/eager/execute.py +++ b/tensorflow/python/eager/execute.py @@ -72,7 +72,7 @@ def execute_with_callbacks(op_name, num_outputs, inputs, attrs, ctx, name=None): """Monkey-patch to execute to enable execution callbacks.""" tensors = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name) for callback in ctx.post_execution_callbacks: - callback(op_name, name, attrs, inputs, tensors) + callback(op_name, inputs, attrs, tensors, name) return tensors diff --git a/tensorflow/python/eager/execution_callbacks.py b/tensorflow/python/eager/execution_callbacks.py index 2f1654dda499583fe4766cbe2e330399defc96fd..535361498a9dd33003d0479051e97d7ff2553067 100644 --- a/tensorflow/python/eager/execution_callbacks.py +++ b/tensorflow/python/eager/execution_callbacks.py @@ -104,10 +104,10 @@ class InfOrNanError(Exception): def inf_nan_callback(op_type, - op_name, - attrs, inputs, + attrs, outputs, + op_name, check_inf=True, check_nan=True, action=_DEFAULT_CALLBACK_ACTION): @@ -121,14 +121,14 @@ def inf_nan_callback(op_type, Args: op_type: Name of the TFE operation type (e.g., `MatMul`). - op_name: Name of the TFE operation. This name is set by client and can be - `None` if it unset. - attrs: Attributes of the TFE operation, as a tuple of alternating attribute - names and attribute values. inputs: The `list` of input tensors to the operation, currently unused by this callback. + attrs: Attributes of the TFE operation, as a tuple of alternating attribute + names and attribute values. outputs: The `list` of output tensors from the operation, checked by this callback for `inf` and `nan` values. + op_name: Name of the TFE operation. This name is set by client and can be + `None` if it unset. check_inf: (`bool`) Whether this callback should check for `inf` values in the output tensor values. check_nan: (`bool`) Whether this callback should check for `nan` values in @@ -153,7 +153,7 @@ def inf_nan_callback(op_type, continue numpy_dtype = output.dtype.as_numpy_dtype - if (np.issubdtype(numpy_dtype, np.float) or + if (np.issubdtype(numpy_dtype, np.floating) or np.issubdtype(numpy_dtype, np.complex) or np.issubdtype(numpy_dtype, np.integer)): try: @@ -187,26 +187,38 @@ def inf_nan_callback(op_type, def inf_callback(op_type, - op_name, - attrs, inputs, + attrs, outputs, + op_name, action=_DEFAULT_CALLBACK_ACTION): """A specialization of `inf_nan_callback` that checks for `inf`s only.""" inf_nan_callback( - op_type, op_name, attrs, inputs, outputs, check_inf=True, check_nan=False, + op_type, + inputs, + attrs, + outputs, + op_name, + check_inf=True, + check_nan=False, action=action) def nan_callback(op_type, - op_name, - attrs, inputs, + attrs, outputs, + op_name, action=_DEFAULT_CALLBACK_ACTION): """A specialization of `inf_nan_callback` that checks for `nan`s only.""" inf_nan_callback( - op_type, op_name, attrs, inputs, outputs, check_inf=False, check_nan=True, + op_type, + inputs, + attrs, + outputs, + op_name, + check_inf=False, + check_nan=True, action=action) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 81b1f6f12a1899ddccb711a81122905bfd363748..28f5289ffc0ace6f9b6cad7cdd1160a184f882c7 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.util import compat from tensorflow.python.util import nest @@ -292,6 +293,22 @@ def _map_sequence_obj_to_idx(sequence): return {id(x): i for i, x in enumerate(sequence)} +def _flatten(sequence): + """A wrapper around `nest.flatten` that also unpacks `IndexedSlices`.""" + # TODO(akshayka): Support `SparseTensor` in a similar fashion. + flat_sequence = nest.flatten(sequence) + outputs = [] + for item in flat_sequence: + if isinstance(item, ops.IndexedSlices): + if item.dense_shape is not None: + outputs.extend([item.values, item.indices, item.dense_shape]) + else: + outputs.extend([item.values, item.indices]) + else: + outputs.append(item) + return outputs + + class GraphModeFunction(object): """Callable object representing a graph-mode function. @@ -333,14 +350,14 @@ class GraphModeFunction(object): self._input_placeholders = input_placeholders self._extra_inputs = list(extra_inputs) self._graph = graph - self._has_backprop = False + self._backward_function = None self._func_name = name self._function_def = defined_function self._num_outputs = len(defined_function.signature.output_arg) self._ops = operations self._func_outputs = func_outputs self._returns = [func_outputs] if isinstance( - func_outputs, (ops.Tensor, type(None))) else list(func_outputs) + func_outputs, (ops.Tensor, type(None))) else _flatten(func_outputs) self._output_shapes = output_shapes self._variables = variables if variables is not None else [] @@ -348,9 +365,8 @@ class GraphModeFunction(object): def variables(self): return self._variables - def _compute_backprop(self): - """Computes the backprop function object for this function.""" - self._has_backprop = True + def _construct_backprop_function(self): + """Constructs the backprop function object for this function.""" with self._graph.as_default(), context.graph_mode(): c = _CapturingContext() with c: @@ -361,13 +377,16 @@ class GraphModeFunction(object): filtered_outputs, self._input_placeholders, grad_ys=self._out_grad_placeholders) - shapes = tuple(x.shape for x in in_gradients if x is not None) + + backward_outputs = tuple( + grad for grad in _flatten(in_gradients) if grad is not None) + output_shapes = tuple(grad.shape for grad in backward_outputs) + captures = list(sorted(c.captured_tensors, key=lambda x: x.name)) forward_name = _forward_name(self._func_name) self._forward_fdef = _EagerDefinedFunction( forward_name, self._graph, self._ops, self._input_placeholders, filtered_outputs + captures) - backward_outputs = tuple(x for x in in_gradients if x is not None) all_inputs = self._out_grad_placeholders + captures # Excluding input ops from the body as we do not intend to execute these # operations when the function is executed. @@ -381,7 +400,7 @@ class GraphModeFunction(object): bname = _backward_name(self._func_name) self._backward_function = GraphModeFunction( bname, all_inputs, [], self._graph, function_def_ops, - backward_outputs, in_gradients, shapes) + backward_outputs, in_gradients, output_shapes) def _backprop_call(self, args): """Calls the wrapped function and records the result on a tape.""" @@ -426,9 +445,24 @@ class GraphModeFunction(object): @property def output_shapes(self): + """The function's output shapes.""" # TODO(ebrevdo): Should we only keep the output shapes associated # with len(self._returns) outputs? - return nest.pack_sequence_as(self._func_outputs, self._output_shapes) + outputs_list = nest.flatten(self._func_outputs) + j = 0 + for i, o in enumerate(outputs_list): + if o is not None: + if isinstance(o, ops.IndexedSlices): + # Extract the shape of the `IndexedSlices` object's `values` field. + outputs_list[i] = self._output_shapes[j] # the `values` shape + if o.dense_shape is not None: + j += 3 # skip over shapes for `values`, `indices`, `dense_shape` + else: + j += 2 # skip over shapes for `values`, `indices` + else: + outputs_list[i] = self._output_shapes[j] + j += 1 + return nest.pack_sequence_as(self._func_outputs, outputs_list) @property def output_dtypes(self): @@ -457,12 +491,11 @@ class GraphModeFunction(object): if v._trainable: # pylint: disable=protected-access tape.watch_variable(v) - tensor_inputs = [x for x in nest.flatten(args) - if isinstance(x, ops.Tensor)] + tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] if tape.should_record(tensor_inputs) or tape.should_record( self._extra_inputs): - if not self._has_backprop: - self._compute_backprop() + if self._backward_function is None: + self._construct_backprop_function() return self._backprop_call(tensor_inputs) ctx = context.context() @@ -503,13 +536,30 @@ class GraphModeFunction(object): """ if self._func_outputs is None: return None + # Use `nest.flatten` instead of `_flatten` in order to preserve any + # IndexedSlices in `self._func_outputs`. outputs_list = nest.flatten(self._func_outputs) j = 0 for i, o in enumerate(outputs_list): if o is not None: - outputs_list[i] = result[j] - j += 1 - return nest.pack_sequence_as(self._func_outputs, outputs_list) + if isinstance(o, ops.IndexedSlices): + # Repack Tensors for IndexedSlices. + if o.dense_shape is not None: + outputs_list[i] = ops.IndexedSlices( + values=result[j], + indices=result[j + 1], + dense_shape=result[j + 2]) + j += 3 + else: + outputs_list[i] = ops.IndexedSlices( + values=result[j], + indices=result[j + 1]) + j += 2 + else: + outputs_list[i] = result[j] + j += 1 + ret = nest.pack_sequence_as(self._func_outputs, outputs_list) + return ret def _get_defun_inputs(args): @@ -526,15 +576,13 @@ def _get_defun_inputs(args): def _defun_internal(name, func, args, kwds): """Defines and returns graph-mode version of func.""" - container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access + graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access with context.graph_mode(): captures = {} tmp_graph = CapturingGraph(captures) - # Inherit the container prefix, since this is used for error checking when - # isolating eager execution (the container prefix at creation must match the - # container prefix when used, and variables accessed in the defun will be - # used in the outside context). - tmp_graph._container_prefix = container_prefix # pylint: disable=protected-access + # Inherit the graph key, since this is used for matching variables in + # optimizers. + tmp_graph._graph_key = graph_key # pylint: disable=protected-access # Copy the graph collections to ensure summaries and other things work. This # lets the function access (but not mutate) collections of the containing # graph, such as the global step and the summary writer collections. @@ -545,17 +593,23 @@ def _defun_internal(name, func, args, kwds): with tmp_graph.as_default(): func_inputs = _get_defun_inputs(args) + def convert(x): + if x is None: + return None + return ops.convert_to_tensor_or_indexed_slices(x) + with capture_tensors(captures): this_tape = tape.push_new_tape() try: func_outputs = func(*func_inputs, **kwds) + func_outputs = nest.map_structure(convert, func_outputs) finally: tape.pop_tape(this_tape) variables = this_tape.watched_variables() # Returning a closed-over tensor as an output does not trigger a # call to convert_to_tensor, so we manually capture all such tensors. - outputs_list = nest.flatten(func_outputs) + outputs_list = _flatten(func_outputs) func_def_outputs = [ _convert_to_graph_tensor(x) for x in outputs_list if x is not None ] @@ -600,6 +654,18 @@ def _cache_key(x): """Cache key for tfe functions.""" if isinstance(x, ops.Tensor): return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access + if isinstance(x, ops.IndexedSlices): + if x.dense_shape is not None: + return tuple([ + _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access + _TensorDtype(x.indices.dtype, x.indices._shape_tuple()), # pylint: disable=protected-access + _TensorDtype(x.dense_shape.dtype, x.dense_shape._shape_tuple()) # pylint: disable=protected-access + ]) + else: + return tuple([ + _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access + _TensorDtype(x.indices.dtype, x.indices._shape_tuple()) # pylint: disable=protected-access + ]) if isinstance(x, np.ndarray): return ("array", x.shape, tuple(x.reshape(-1))) if isinstance(x, (list, tuple)): @@ -697,7 +763,11 @@ def defun(func): or more Tensor objects). """ # TODO(apassos): deal with captured global state. Deal with control flow. - return tf_decorator.make_decorator(func, named_defun(func, func.__name__)) + try: + name = func.__name__ + except AttributeError: + name = "function" + return tf_decorator.make_decorator(func, named_defun(func, name)) def make_defun_op(func, *args, **kwds): @@ -750,3 +820,208 @@ def make_defun_op(func, *args, **kwds): if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): raise ValueError("Tensor keyword arguments are not supported.") return _defun_internal(name, func, args, kwds) + + +class AutomaticControlDependencies(object): + """Context manager to automatically add control dependencies. + + Code under this context manager will act as if a sensible set of control + dependencies were present. More specifically: + 1. All stateful ops in the scope will execute + 2. Stateful ops which modify the same resource will execute in program order + + Note: creating variables in an automatic control dependencies context is not + supported (the value of the variables will never change as they will keep + getting reinitialized). + + NOT THREAD SAFE + """ + + def __init__(self): + self._returned_tensors = set() + + def mark_as_return(self, tensor): + self._returned_tensors.add(tensor) + + def __enter__(self): + if context.in_eager_mode(): + return self + # This code assumes no other thread is adding ops to the graph while + # we're adding ops to the graph. + # TODO(apassos): Fix this by locking the graph or using a temporary + # graph (but that would mess up devices and collections at least, + # probably other things as well). + self._graph = ops.get_default_graph() + self._n_operations = len(self._graph.get_operations()) + return self + + def _process_switch(self, switch_op, ops_which_must_run, + last_op_using_resource_tensor, merge_for_resource): + """Processes a switch node for a resource input. + + When tensorflow creates a cond, it creates a control flow context for each + branch of the cond. Each external tensor accessed by that branch is routed + through a switch op, which gets created in the graph _after_ the op which + uses that tensor get created. + + If the resource comes from another switch op we process that one first. + + _process_switch creates a corresponding merge node for the switch node. This + merge node is added to the outer control flow context of the switch + node. We also ensure that: + + 1. The switch node executes after the previous op which used the resource + tensor + + 2. Any op which uses a resource output of the switch node executes before + the merge for the switch node. + + 3. The next op which uses the input resource to the switch node (which + might be another switch node for the other branch of the conditional) + will execute after the merge node is done. + + 4. The merge node is marked as must_run so it will run even if no + subsequent operation uses the resource. + + Args: + switch_op: the switch op to be processed + ops_which_must_run: the set of ops which must run + last_op_using_resource_tensor: map from resource tensor to last op using + it + merge_for_resource: map from resource tensor to merge which must follow + all usages of it. + """ + inp = switch_op.inputs[0] + if inp.dtype == dtypes_module.resource and inp.op.type == "Switch": + self._process_switch(inp.op, ops_which_must_run, + last_op_using_resource_tensor, merge_for_resource) + if switch_op.outputs[0] in merge_for_resource: + return + new_merge = control_flow_ops.merge(switch_op.outputs, + name="artificial_merge") + new_merge[0].op._control_flow_context = ( # pylint: disable=protected-access + switch_op._control_flow_context.outer_context) # pylint: disable=protected-access + # Ensures the merge always runs + ops_which_must_run.add(new_merge[0].op) + if inp in last_op_using_resource_tensor: + # Ensures the switch exectutes after the previous op using the resource. + switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access + # Ensure the next op outside the cond happens after the merge. + last_op_using_resource_tensor[inp] = new_merge[0].op + if inp in merge_for_resource: + merge_for_resource[inp]._add_control_input(new_merge[0].op) # pylint: disable=protected-access + for o in switch_op.outputs: + # Ensures the merge will execute after all ops inside the cond + merge_for_resource[o] = new_merge[0].op + + def __exit__(self, unused_type, unused_value, unused_traceback): + if context.in_eager_mode(): + return + + if self._graph is not ops.get_default_graph(): + raise RuntimeError( + "Graph changed while trying to add control dependencies.") + + # map from resource tensor to the last op which used it + last_op_using_resource_tensor = {} + # set of conditional and loop exits + ops_which_must_run = set() + # merge which must depend on ops which use this resource + merge_for_resource = {} + + new_operations = self._graph.get_operations()[self._n_operations:] + + # Ensures that uses of resource tensors get serialized properly and all + # execute. This is done by keeping a map from resource tensor to the last op + # in graph-construction order which used it (last_op_using_resource_tensor). + # + # Conditionals are written in TensorFlow such that every external tensor + # accessed in the conditional goes through a switch op and every return + # tensor (it's guaranteed that there will be at least one) goes through a + # merge op. + # + # To handle conditionals, switches are handled in a special way (see + # comments for _process_switch). Merge nodes created by TF's conditional + # logic (as opposed to by _process_switch) are forced to run and also get a + # control dependency added to them to ensure all stateful ops inside their + # control flow context run. + # + # We also ensure that if an op is using a resource output by a switch node + # (that is, a resource tensor for which there's a value in + # merge_for_resource) this op will run before the merge for that resource. + # + # We try to add control inputs to nodes respecting their control flow + # contexts to avoid dead nodes propagating everywhere and leading to + # "retval[0] doesn't have value" errors. If a node gets a control dependency + # on a dead node (i.e. a note from an untaken control flow branch) that node + # will be marked as dead unless it's a merge node. + # + # TODO(apassos): serialize non-resource-taking stateful ops as well, and + # test that it works. Support while loops. Support init_scope escaping from + # this. + for op in new_operations: + control_inputs = set() + # Ensure stateful ops run + if self._graph._registered_ops[op.type].is_stateful: # pylint: disable=protected-access + ops_which_must_run.add(op) + # Ignore switches (they're handled separately) + if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: + continue + # Make merges trigger all other computation which must run + if op.type == "Merge": + for o in ops_which_must_run: + op._add_control_input(o) # pylint: disable=protected-access + for inp in o.inputs: + if inp in last_op_using_resource_tensor: + last_op_using_resource_tensor[inp] = op + ops_which_must_run = set([op]) + continue + for inp in op.inputs: + if inp.dtype == dtypes_module.resource: + # Deal with switches, finally. + if inp.op.type == "Switch": + self._process_switch(inp.op, ops_which_must_run, + last_op_using_resource_tensor, + merge_for_resource) + # Ensure uses of resources are serialized + if inp in last_op_using_resource_tensor: + if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access + is op._control_flow_context): # pylint: disable=protected-access + control_inputs.add(last_op_using_resource_tensor[inp]) + # Ensure merges happen after the closing of a cond block + if inp in merge_for_resource: + merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access + last_op_using_resource_tensor[inp] = op + control_inputs = [c for c in control_inputs + if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access + op._add_control_inputs(control_inputs) # pylint: disable=protected-access + + # Ensure all ops which must run do run + for r in self._returned_tensors: + r.op._add_control_inputs( # pylint: disable=protected-access + [o for o in ops_which_must_run + if o._control_flow_context is r.op._control_flow_context]) # pylint: disable=protected-access + + +def automatic_control_dependencies(f): + """Wraps f to automatically insert control dependencies. + + The inserted dependencies ensure that: + 1. All stateful ops in f run when the result of f runs + 2. Updates to the same resources happen in order. + + Args: + f: the function to be wrapped. + + Returns: + The wrapped function. + """ + + def wrapper(*args, **kwds): + with AutomaticControlDependencies() as a: + result = f(*args, **kwds) + for t in nest.flatten(result): + a.mark_as_return(t) + return result + + return tf_decorator.make_decorator(f, wrapper) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 0babc29f17b21ee663cdd5bd170875247353e70b..431d9388c0ee97eda197142ec97b9448d985b04b 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -374,6 +375,78 @@ class FunctionTest(test.TestCase): self.assertAllEqual(f(constant_op.constant(1.0)), 2.0) + def testGradientOfGatherWithDefun(self): + + v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) + + def sum_gather(): + return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) + + grad_fn = backprop.implicit_grad(sum_gather) + gradient = grad_fn() + defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather)) + defun_gradient = defun_grad_fn() + self.assertEqual(len(gradient), len(defun_gradient)) + + gradient = gradient[0][0] + defun_gradient = defun_gradient[0][0] + self.assertAllEqual(gradient.values, defun_gradient.values) + self.assertAllEqual(gradient.indices, defun_gradient.indices) + self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape) + + def testReturningIndexedSlicesWithDefun(self): + + def validate(indexed_slice): + def f(): + return indexed_slice + + output = function.defun(f)() + self.assertTrue(isinstance(output, ops.IndexedSlices)) + self.assertAllEqual(indexed_slice.values, output.values) + self.assertAllEqual(indexed_slice.indices, output.indices) + self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape) + + self.assertEqual( + function.make_defun_op(f).output_shapes, indexed_slice.values.shape) + + arg = ops.IndexedSlices( + values=constant_op.constant([1, 2]), + indices=constant_op.constant([0, 1]), + dense_shape=constant_op.constant([2])) + validate(arg) + + arg = ops.IndexedSlices( + values=constant_op.constant([1, 2]), + indices=constant_op.constant([0, 1]), + dense_shape=None) + validate(arg) + + def testIndexedSliceAsArgumentWithDefun(self): + + @function.defun + def f(indexed_slice): + return indexed_slice + + def validate(arg): + output = f(arg) + self.assertTrue(isinstance(output, ops.IndexedSlices)) + self.assertAllEqual(arg.values, output.values) + self.assertAllEqual(arg.indices, output.indices) + self.assertAllEqual(arg.dense_shape, output.dense_shape) + + indexed_slice = ops.IndexedSlices( + values=constant_op.constant([1]), + indices=constant_op.constant([0]), + dense_shape=constant_op.constant([1])) + validate(indexed_slice) + + # Test that `f` works even when `dense_shape` is None. + indexed_slice = ops.IndexedSlices( + values=constant_op.constant([1]), + indices=constant_op.constant([0]), + dense_shape=None) + validate(indexed_slice) + def testFunctionOnDevice(self): if not context.context().num_gpus(): self.skipTest('No GPUs found') @@ -504,6 +577,191 @@ class FunctionTest(test.TestCase): self.assertAllEqual(ret[0][2], 10) self.assertAllEqual(ret[1], 15) + def testVariableNamesRespectNameScopesWithDefun(self): + @function.defun + def create_variable(): + with ops.name_scope('foo'): + v = resource_variable_ops.ResourceVariable(0.0, name='bar') + self.assertEqual(v.name, 'foo/bar:0') + create_variable() + + def testVariableNamesRespectNameScopesWithDefunInGraph(self): + with context.graph_mode(): + @function.defun + def create_variable(): + with ops.name_scope('foo'): + v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar') + self.assertEqual(v.name, 'foo/bar:0') + with ops.get_default_graph().as_default(): + create_variable() + + +class AutomaticControlDependenciesTest(test.TestCase): + + def testBasic(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + with function.AutomaticControlDependencies() as c: + v.assign(v + 1) + v.assign(2 * v) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(), 4.0) + + def testCondMustRun(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + v.assign(v + 1) + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0) + + def testCondMustRunSeparateRead(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + v.assign(v + 1) + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + one = constant_op.constant(1.0) + c.mark_as_return(one) + one.eval(feed_dict={p: False}) + self.assertAllEqual(v.read_value().eval(), 5.0) + one.eval(feed_dict={p: True}) + self.assertAllEqual(v.read_value().eval(), 6.0) + + def testCondNested(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + q = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + v.assign(v + 1, name='true') + return 1.0 + + def false_fn(): + + def inner_true_fn(): + v.assign(v * 2, name='false_true') + return 2.0 + + def inner_false_fn(): + v.assign(v * 3, name='false_false') + return 3.0 + + control_flow_ops.cond(q, inner_true_fn, inner_false_fn) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + with ops.name_scope('final'): + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0) + self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0) + self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0) + self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0) + + def testCondOneBranch(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0) + + def testCondOneBranchUpdateBefore(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + v.assign(v * 2) + + def true_fn(): + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 6.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0) + + def testCondOneBranchUpdateAfter(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + v.assign(v * 2) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0) + + def testDecorator(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + + @function.automatic_control_dependencies + def f(): + v.assign(v + 1) + v.assign(2 * v) + return v.read_value() + + self.assertAllEqual(f().eval(), 4.0) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/gen_op.bzl b/tensorflow/python/eager/gen_op.bzl deleted file mode 100644 index 8bc1d6c10a60b89a026cb34dbf6fd98d29e909c2..0000000000000000000000000000000000000000 --- a/tensorflow/python/eager/gen_op.bzl +++ /dev/null @@ -1,65 +0,0 @@ -"""For eager-mode Python.""" - -load("//tensorflow:tensorflow.bzl", - "clean_dep", - "tf_binary_additional_srcs", - "tf_copts", - "tf_cc_binary") - -def tfe_gen_op_wrapper_py(name, - out=None, - visibility=None, - deps=[], - generated_target_name=None, - # ApiDefs will be loaded in the order specified in this list. - api_def_srcs=[]): - """Generate an eager-mode Python op wrapper for an op library.""" - # Construct a cc_binary containing the specified ops. - tool_name = "gen_" + name + "_py_wrappers_cc" - if not deps: - deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))] - tf_cc_binary( - name=tool_name, - linkopts=["-lm"], - copts=tf_copts(), - linkstatic=1, - deps=([ - clean_dep("//tensorflow/python/eager:python_eager_op_gen_main") - ] + deps), - visibility=[clean_dep("//visibility:public")],) - - # Invoke the previous cc_binary to generate a python file. - if not out: - out = "gen_" + name + ".py" - - if not api_def_srcs: - api_def_args_str = "," - else: - api_def_args = [] - for api_def_src in api_def_srcs: - # Add directory of the first ApiDef source to args. - # We are assuming all ApiDefs in a single api_def_src are in the - # same directory. - api_def_args.append( - "$$(dirname $$(echo $(locations " + api_def_src + - ") | cut -d\" \" -f1))") - api_def_args_str = ",".join(api_def_args) - - native.genrule( - name=name + "_pygenrule", - outs=[out], - srcs=api_def_srcs, - tools=[tool_name] + tf_binary_additional_srcs(), - cmd=("$(location " + tool_name + ") " + api_def_args_str + " > $@")) - - # Make a py_library out of the generated python file. - if not generated_target_name: - generated_target_name = name - native.py_library( - name=generated_target_name, - srcs=[out], - srcs_version="PY2AND3", - visibility=visibility, - deps=[ - clean_dep("//tensorflow/python/eager:framework_for_generated_wrappers"), - ],) diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 5c13ea89081a7d060c0ed1201f0169b739a204c2..62106bf0e2809e3c056e4a357f3d05251b7dca68 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -252,21 +252,17 @@ def _graph_callable_internal(func, shape_and_dtypes): Callable graph object. """ container = tf_ops.get_default_graph()._container # pylint: disable=protected-access - container_prefix = tf_ops.get_default_graph()._container_prefix # pylint: disable=protected-access + graph_key = tf_ops.get_default_graph()._graph_key # pylint: disable=protected-access with context.graph_mode(): # This graph will store both the initialization and the call version of the # wrapped function. It will later be used by the backprop code to build the # backprop graph, if necessary. captures = {} tmp_graph = function.CapturingGraph(captures) - # Inherit the container from the original graph to create resources at user - # expected containers. Also inherits the container prefix, since this is - # used for error checking when isolating Eager execution (the container - # prefix at creation must match the container prefix when used, and - # variables returned from the graph callable will be used in the outside - # context). + # Inherit the graph key from the original graph to ensure optimizers don't + # misbehave. tmp_graph._container = container # pylint: disable=protected-access - tmp_graph._container_prefix = container_prefix # pylint: disable=protected-access + tmp_graph._graph_key = graph_key # pylint: disable=protected-access with tmp_graph.as_default(): # Placeholders for the non-variable inputs. func_inputs = _get_graph_callable_inputs(shape_and_dtypes) diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index 90a8779ff845b2fd63d1ba1019e8601fef257e42..e6d03297e0b85856ff165af310149c79e494ab36 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -42,6 +42,8 @@ namespace { const int kRightMargin = 78; +constexpr char kEagerFallbackSuffix[] = "_eager_fallback"; + string AttrVarName(const string& attr_name, std::unordered_map* attr_expressions) { const string var = strings::StrCat("_attr_", attr_name); @@ -49,11 +51,12 @@ string AttrVarName(const string& attr_name, return var; } -void AddInferredAttr(const string& attr_name, const string& value_expression, - string* result, +void AddInferredAttr(const string& indentation, const string& attr_name, + const string& value_expression, string* result, std::unordered_map* attr_expressions) { - strings::StrAppend(result, " ", AttrVarName(attr_name, attr_expressions), - " = ", value_expression, "\n"); + strings::StrAppend(result, indentation, + AttrVarName(attr_name, attr_expressions), " = ", + value_expression, "\n"); } string VectorToTuple(const std::vector& l) { @@ -121,11 +124,33 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { string Code() override; protected: - void ExpectListArg(const string& arg_name); - void AddEagerInferredAttrs(); - void AddEagerInputCasts(); - void AddEagerAttrs(); - void AddEagerExecute(const string& num_outputs_expr); + void HandleGraphMode(const string& function_setup); + + string GetEagerNotAllowedError(); + void ExpectListArg(const string& indentation, const string& arg_name, + string* output); + bool GetEagerFunctionSetup(const string& indentation, string* function_setup); + void GetOutputSizesAndNumOutputsExpr(std::vector* output_sizes, + string* num_outputs_expr); + + void AddEagerFunctionTeardown(const string& indentation, + const std::vector& output_sizes, + bool execute_record_gradient); + + bool AddEagerFastPathAndGraphCode(const string& parameters, + const std::vector& output_sizes, + const string& eager_not_allowed_error); + bool AddEagerFallbackCode(const string& parameters, + const std::vector& output_sizes, + const string& num_outputs_expr, + const string& eager_not_allowed_error); + void AddEagerFastPathExecute(); + + void AddEagerInferredAttrs(const string& indentation); + void AddEagerInputCasts(const string& indentation); + void AddEagerAttrs(const string& indentation); + void AddEagerExecute(const string& indentation, + const string& num_outputs_expr); void AddAttrForArg(const string& attr, int arg_index) { gtl::InsertIfNotPresent(&inferred_attrs_, attr, @@ -148,6 +173,13 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { typedef std::unordered_map> AttrToArgMap; AttrToArgMap attr_to_args_; std::unordered_map attr_expressions_; + // This has all the input args followed by those attrs that don't have + // defaults. + std::vector params_no_default_; + // The parameters with defaults (these have to be listed after those without). + // No input args are included, just attrs. + std::vector> + params_with_default_; }; string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, @@ -207,18 +239,12 @@ string GenEagerPythonOp::Code() { if (api_def_.visibility() == ApiDef::SKIP) { return ""; } - // This has all the input args followed by those attrs that don't have - // defaults. - std::vector params_no_default; - // The parameters with defaults (these have to be listed after those without). - // No input args are included, just attrs. - std::vector> - params_with_default; for (int i = 0; i < api_def_.arg_order_size(); ++i) { const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); - params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to()); + params_no_default_.emplace_back(api_def_arg.name(), + api_def_arg.rename_to()); if (!arg.type_attr().empty()) { AddAttrForArg(arg.type_attr(), i); } else if (!arg.type_list_attr().empty()) { @@ -235,7 +261,7 @@ string GenEagerPythonOp::Code() { if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { if (api_def_attr.has_default_value()) { if (attr.type() == "tensor") { - params_with_default.emplace_back( + params_with_default_.emplace_back( python_op_gen_internal::ParamNames(api_def_attr.name(), api_def_attr.rename_to()), strings::StrCat( @@ -247,22 +273,22 @@ string GenEagerPythonOp::Code() { for (const auto& pb : api_def_attr.default_value().list().tensor()) { pbtxt.emplace_back(TensorPBString(pb)); } - params_with_default.emplace_back( + params_with_default_.emplace_back( python_op_gen_internal::ParamNames(api_def_attr.name(), api_def_attr.rename_to()), strings::StrCat("[_execute.make_tensor(_pb, \"", api_def_attr.rename_to(), "\") for _pb in ", VectorToTuple(pbtxt), "]")); } else { - params_with_default.emplace_back( + params_with_default_.emplace_back( python_op_gen_internal::ParamNames(api_def_attr.name(), api_def_attr.rename_to()), python_op_gen_internal::AttrValueToPython( attr.type(), api_def_attr.default_value(), "_dtypes.")); } } else { - params_no_default.emplace_back(api_def_attr.name(), - api_def_attr.rename_to()); + params_no_default_.emplace_back(api_def_attr.name(), + api_def_attr.rename_to()); } } } @@ -270,29 +296,29 @@ string GenEagerPythonOp::Code() { // Save the list of attr parameters (attrs that won't be inferred), // those with defaults go at the end. // Get the attrs in the order we want by taking the attrs without defaults - // from the end of params_no_default, and adding params_no_default. - attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() + - params_with_default.size()); - for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) { - attrs_.push_back(params_no_default[i].GetName()); + // from the end of params_no_default_, and adding params_no_default_. + attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() + + params_with_default_.size()); + for (int i = op_def_.input_arg_size(); i < params_no_default_.size(); ++i) { + attrs_.push_back(params_no_default_[i].GetName()); } - for (const auto& p : params_with_default) { + for (const auto& p : params_with_default_) { attrs_.push_back(p.first.GetName()); } - param_names_.reserve(params_no_default.size() + params_with_default.size()); - param_names_.insert(param_names_.begin(), params_no_default.begin(), - params_no_default.end()); - for (const auto& param_and_default : params_with_default) { + param_names_.reserve(params_no_default_.size() + params_with_default_.size()); + param_names_.insert(param_names_.begin(), params_no_default_.begin(), + params_no_default_.end()); + for (const auto& param_and_default : params_with_default_) { param_names_.push_back(param_and_default.first); } string parameters; - for (const auto& param : params_no_default) { + for (const auto& param : params_no_default_) { if (!parameters.empty()) strings::StrAppend(¶meters, ", "); strings::StrAppend(¶meters, param.GetRenameTo()); } - for (const auto& param_and_default : params_with_default) { + for (const auto& param_and_default : params_with_default_) { if (!parameters.empty()) strings::StrAppend(¶meters, ", "); strings::StrAppend(¶meters, param_and_default.first.GetRenameTo(), "=", param_and_default.second); @@ -300,19 +326,125 @@ string GenEagerPythonOp::Code() { if (!parameters.empty()) strings::StrAppend(¶meters, ", "); strings::StrAppend(¶meters, "name=None"); - AddExport(); - AddDefLine(parameters); - AddDocStringDescription(); - AddDocStringArgs(); - AddDocStringInputs(); - AddDocStringAttrs(); - AddDocStringNameArg(); - AddOutputGlobals(); - AddDocStringOutputs(); - strings::StrAppend(&result_, " \"\"\"\n"); + // Add attr_expressions_ for attrs that are params. + for (int i = 0; i < attrs_.size(); ++i) { + const string& attr_name = attrs_[i]; + const string& attr_api_name = + param_names_[i + op_def_.input_arg_size()].GetRenameTo(); + attr_expressions_[attr_name] = attr_api_name; + } + // Add attr_expressions_ for attrs that are inferred. + for (int i = 0; i < op_def_.attr_size(); ++i) { + const auto& attr(op_def_.attr(i)); + if (attr.type() == "int") { + auto arg_list = attr_to_args_.find(attr.name()); + if (arg_list != attr_to_args_.end()) { + AttrVarName(attr.name(), &attr_expressions_); + } + } + } + + string num_outputs_expr; + std::vector output_sizes(num_outs_); + GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr); + + string eager_not_allowed_error = GetEagerNotAllowedError(); - // Function body. + if (!AddEagerFastPathAndGraphCode(parameters, output_sizes, + eager_not_allowed_error)) { + return result_; + } + if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr, + eager_not_allowed_error)) { + return result_; + } + + return prelude_ + result_; +} + +void GenEagerPythonOp::HandleGraphMode(const string& function_setup) { + // Handle graph-mode case + strings::StrAppend(&result_, + " _ctx = _context.context()\n" + " if _ctx.in_graph_mode():\n", + function_setup, + " _, _, _op = _op_def_lib._apply_op_helper(\n"); + AddBodyNoReturn(" "); + if (num_outs_ > 0) { + strings::StrAppend(&result_, " _result = _op.outputs[:]\n"); + // Special case handling for stateful op with single list output + // that might be empty. + if (num_outs_ == 1 && op_def_.is_stateful() && + (!op_def_.output_arg(0).number_attr().empty() || + !op_def_.output_arg(0).type_list_attr().empty())) { + // TODO(josh11b): Can skip this if the number_attr/type_list_attr has + // a constraint indicating that this can never be empty. + strings::StrAppend(&result_, + " if not _result:\n" + " return _op\n"); + } + strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n"); + + // Compute graph-mode attrs. + if (op_def_.attr_size() > 0) { + string attr_values; + for (int i = 0; i < op_def_.attr_size(); ++i) { + if (i > 0) strings::StrAppend(&attr_values, ", "); + const auto& attr_name(op_def_.attr(i).name()); + strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"", + attr_name, "\")"); + } + strings::StrAppend(&attr_values, ")"); + strings::StrAppend(&result_, + WordWrap(" _attrs = (", attr_values, kRightMargin), + "\n"); + } else { + strings::StrAppend(&result_, " _attrs = None\n"); + } + } else { + strings::StrAppend(&result_, " return _op\n"); + } +} + +string GenEagerPythonOp::GetEagerNotAllowedError() { + bool eager_allowed = true; + string ref_arg; + for (int i = 0; i < op_def_.input_arg_size(); ++i) { + const auto& arg = op_def_.input_arg(i); + if (arg.is_ref()) { + eager_allowed = false; + DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name()); + ref_arg = api_def_.in_arg(i).rename_to(); + } + } + for (int i = 0; i < op_def_.output_arg_size(); ++i) { + const auto& arg = op_def_.output_arg(i); + if (arg.is_ref()) { + eager_allowed = false; + DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name()); + ref_arg = api_def_.out_arg(i).rename_to(); + } + } + + if (eager_allowed) return ""; + + return strings::StrCat("raise RuntimeError(\"", op_name_, + " op does not support eager execution. ", "Arg '", + ref_arg, "' is a ref.\")\n"); +} + +void GenEagerPythonOp::ExpectListArg(const string& indentation, + const string& arg_name, string* output) { + strings::StrAppend(output, indentation, "if not isinstance(", arg_name, + ", (list, tuple)):\n", indentation, " raise TypeError(\n", + indentation, " \"Expected list for '", arg_name, + "' argument to \"\n", indentation, " \"'", op_name_, + "' Op, not %r.\" % ", arg_name, ")\n"); +} + +bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation, + string* function_setup) { // Validate list inputs, infer length attrs. for (int i = 0; i < op_def_.attr_size(); ++i) { const auto& attr(op_def_.attr(i)); @@ -324,32 +456,27 @@ string GenEagerPythonOp::Code() { for (auto iter = arg_list->second.begin(); iter != arg_list->second.end(); ++iter) { const string& arg_api_name = param_names_[*iter].GetRenameTo(); - ExpectListArg(arg_api_name); + ExpectListArg(indentation, arg_api_name, function_setup); if (iter == arg_list->second.begin()) { - AddInferredAttr(attr.name(), + AddInferredAttr(indentation, attr.name(), strings::StrCat("len(", arg_api_name, ")"), - &result_, &attr_expressions_); + function_setup, &attr_expressions_); } else { const auto& attr_var = attr_expressions_[attr.name()]; - strings::StrAppend(&result_, " if len(", arg_api_name, - ") != ", attr_var, - ":\n" - " raise ValueError(\n" - " \"List argument '", - arg_api_name, "' to '", op_name_, - "' Op with length %d \"\n" - " \"must match length %d of argument '", - inferred_attrs_[attr.name()], - "'.\" %\n" - " (len(", - arg_api_name, "), ", attr_var, "))\n"); + strings::StrAppend( + function_setup, indentation, "if len(", arg_api_name, + ") != ", attr_var, ":\n", indentation, " raise ValueError(\n", + indentation, " \"List argument '", arg_api_name, "' to '", + op_name_, "' Op with length %d \"\n", indentation, + " \"must match length %d of argument '", + inferred_attrs_[attr.name()], "'.\" %\n", indentation, + " (len(", arg_api_name, "), ", attr_var, "))\n"); } } } } } - // Values for non-inferred attrs. for (int i = 0; i < attrs_.size(); ++i) { const string& attr_name = attrs_[i]; const auto& param = param_names_[i + op_def_.input_arg_size()]; @@ -357,241 +484,300 @@ string GenEagerPythonOp::Code() { const string& attr_api_name = param.GetRenameTo(); StringPiece attr_type = attr.type(); attr_expressions_[attr_name] = attr_api_name; - const int default_index = i - (attrs_.size() - params_with_default.size()); + const int default_index = i - (attrs_.size() - params_with_default_.size()); if (default_index >= 0) { - const string& default_value = params_with_default[default_index].second; - strings::StrAppend(&result_, " if ", attr_api_name, " is None:\n"); - strings::StrAppend(&result_, " ", attr_api_name, " = ", default_value, - "\n"); + const string& default_value = params_with_default_[default_index].second; + strings::StrAppend(function_setup, indentation, "if ", attr_api_name, + " is None:\n"); + strings::StrAppend(function_setup, indentation, " ", attr_api_name, + " = ", default_value, "\n"); } if (attr_type.starts_with("list(")) { - ExpectListArg(attr_api_name); + ExpectListArg(indentation, attr_api_name, function_setup); } if (attr_type == "string") { - strings::StrAppend(&result_, " ", attr_api_name, " = _execute.make_str(", - attr_api_name, ", \"", attr_api_name, "\")\n"); + strings::StrAppend(function_setup, indentation, attr_api_name, + " = _execute.make_str(", attr_api_name, ", \"", + attr_api_name, "\")\n"); } else if (attr_type == "list(string)") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = [_execute.make_str(_s, \"", attr_api_name, "\") for _s in ", attr_api_name, "]\n"); } else if (attr_type == "int") { - strings::StrAppend(&result_, " ", attr_api_name, " = _execute.make_int(", - attr_api_name, ", \"", attr_api_name, "\")\n"); + strings::StrAppend(function_setup, indentation, attr_api_name, + " = _execute.make_int(", attr_api_name, ", \"", + attr_api_name, "\")\n"); } else if (attr_type == "list(int)") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = [_execute.make_int(_i, \"", attr_api_name, "\") for _i in ", attr_api_name, "]\n"); } else if (attr_type == "float") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = _execute.make_float(", attr_api_name, ", \"", attr_api_name, "\")\n"); } else if (attr_type == "list(float)") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = [_execute.make_float(_f, \"", attr_api_name, "\") for _f in ", attr_api_name, "]\n"); } else if (attr_type == "bool") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = _execute.make_bool(", attr_api_name, ", \"", attr_api_name, "\")\n"); } else if (attr_type == "list(bool)") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = [_execute.make_bool(_b, \"", attr_api_name, "\") for _b in ", attr_api_name, "]\n"); } else if (attr_type == "type") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = _execute.make_type(", attr_api_name, ", \"", attr_api_name, "\")\n"); } else if (attr_type == "list(type)") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = [_execute.make_type(_t, \"", attr_api_name, "\") for _t in ", attr_api_name, "]\n"); } else if (attr_type == "shape") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = _execute.make_shape(", attr_api_name, ", \"", attr_api_name, "\")\n"); } else if (attr_type == "list(shape)") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = [_execute.make_shape(_s, \"", attr_api_name, "\") for _s in ", attr_api_name, "]\n"); } else if (attr_type == "tensor") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = _execute.make_tensor(", attr_api_name, ", \"", attr_api_name, "\")\n"); } else if (attr_type == "list(tensor)") { - strings::StrAppend(&result_, " ", attr_api_name, + strings::StrAppend(function_setup, indentation, attr_api_name, " = [_execute.make_tensor(_t, \"", attr_api_name, "\") for _t in ", attr_api_name, "]\n"); } else if (attr_type != "func") { - return strings::StrCat("# No definition for ", function_name_, - " since we don't support attrs with type\n" - "# '", - attr_type, "' right now.\n\n"); + *function_setup = + strings::StrCat("# No definition for ", function_name_, + " since we don't support attrs with type\n" + "# '", + attr_type, "' right now.\n\n"); + return false; } } + return true; +} - // Figure out the list of inputs. - const string inputs = FlattenInputs(nullptr, nullptr); - - // Handle graph-mode case - strings::StrAppend(&result_, - " _ctx = _context.context()\n" - - " if _ctx.in_graph_mode():\n" - " _, _, _op = _op_def_lib._apply_op_helper(\n"); - AddBodyNoReturn(" "); - if (num_outs_ > 0) { - strings::StrAppend(&result_, " _result = _op.outputs[:]\n"); - // Special case handling for stateful op with single list output - // that might be empty. - if (num_outs_ == 1 && op_def_.is_stateful() && - (!op_def_.output_arg(0).number_attr().empty() || - !op_def_.output_arg(0).type_list_attr().empty())) { - // TODO(josh11b): Can skip this if the number_attr/type_list_attr has - // a constraint indicating that this can never be empty. - strings::StrAppend(&result_, - " if not _result:\n" - " return _op\n"); - } - strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n"); - - // Compute graph-mode attrs. - if (op_def_.attr_size() > 0) { - string attr_values; - for (int i = 0; i < op_def_.attr_size(); ++i) { - if (i > 0) strings::StrAppend(&attr_values, ", "); - const auto& attr_name(op_def_.attr(i).name()); - strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"", - attr_name, "\")"); - } - strings::StrAppend(&attr_values, ")"); - strings::StrAppend(&result_, - WordWrap(" _attrs = (", attr_values, kRightMargin), - "\n"); - } else { - strings::StrAppend(&result_, " _attrs = None\n"); - } - } else { - strings::StrAppend(&result_, " return _op\n"); - } - - // Handle eager-mode case - strings::StrAppend(&result_, " else:\n"); - +// If output i is list output, output_sizes[i] will be set to a +// string with the python expression that will evaluate to its +// length. output_sizes[i] is empty for non-list outputs. +void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr( + std::vector* output_sizes, string* num_outputs_expr) { // Expression representing the number of outputs. int num_fixed_outputs = 0; - string num_outputs_expr; - // If output i is list output, output_sizes[i] will be set to a - // string with the python expression that will evaluate to its - // length. output_sizes[i] is empty for non-list outputs. - std::vector output_sizes(num_outs_); for (int i = 0; i < num_outs_; ++i) { const auto& arg(op_def_.output_arg(i)); if (!arg.number_attr().empty()) { - if (!num_outputs_expr.empty()) { - strings::StrAppend(&num_outputs_expr, " + "); + if (!num_outputs_expr->empty()) { + strings::StrAppend(num_outputs_expr, " + "); } - output_sizes[i] = attr_expressions_[arg.number_attr()]; - strings::StrAppend(&num_outputs_expr, output_sizes[i]); + (*output_sizes)[i] = attr_expressions_[arg.number_attr()]; + strings::StrAppend(num_outputs_expr, (*output_sizes)[i]); } else if (!arg.type_list_attr().empty()) { - if (!num_outputs_expr.empty()) { - strings::StrAppend(&num_outputs_expr, " + "); + if (!num_outputs_expr->empty()) { + strings::StrAppend(num_outputs_expr, " + "); } // Have to be careful to use an expression that works in both // graph and eager paths here. const auto iter = inferred_attrs_.find(arg.type_list_attr()); if (iter == inferred_attrs_.end()) { - output_sizes[i] = strings::StrCat( + (*output_sizes)[i] = strings::StrCat( "len(", attr_expressions_[arg.type_list_attr()], ")"); } else { - output_sizes[i] = strings::StrCat("len(", iter->second, ")"); + (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")"); } - strings::StrAppend(&num_outputs_expr, output_sizes[i]); + strings::StrAppend(num_outputs_expr, (*output_sizes)[i]); } else { ++num_fixed_outputs; } } if (num_fixed_outputs > 0) { - if (!num_outputs_expr.empty()) { - strings::StrAppend(&num_outputs_expr, " + "); - } - strings::StrAppend(&num_outputs_expr, num_fixed_outputs); - } else if (num_outputs_expr.empty()) { - num_outputs_expr = "0"; - } - - bool eager_allowed = true; - string ref_arg; - for (int i = 0; i < op_def_.input_arg_size(); ++i) { - const auto& arg = op_def_.input_arg(i); - if (arg.is_ref()) { - eager_allowed = false; - DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name()); - ref_arg = api_def_.in_arg(i).rename_to(); - } - } - for (int i = 0; i < op_def_.output_arg_size(); ++i) { - const auto& arg = op_def_.output_arg(i); - if (arg.is_ref()) { - eager_allowed = false; - DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name()); - ref_arg = api_def_.out_arg(i).rename_to(); + if (!num_outputs_expr->empty()) { + strings::StrAppend(num_outputs_expr, " + "); } + strings::StrAppend(num_outputs_expr, num_fixed_outputs); + } else if (num_outputs_expr->empty()) { + *num_outputs_expr = "0"; } +} - if (eager_allowed) { - AddEagerInferredAttrs(); - AddEagerInputCasts(); - strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n"); - AddEagerAttrs(); - AddEagerExecute(num_outputs_expr); - } else { - strings::StrAppend(&result_, - " raise RuntimeError(\n" - " \"", - op_name_, " op does not support eager execution. ", - "Arg '", ref_arg, "'' is a ref.\")\n"); - } - +void GenEagerPythonOp::AddEagerFunctionTeardown( + const string& indentation, const std::vector& output_sizes, + bool execute_record_gradient) { if (num_outs_ > 0) { - strings::StrAppend(&result_, " _execute.record_gradient(\n", " \"", - op_def_.name(), - "\", _inputs_flat, _attrs, _result, name)\n"); + if (execute_record_gradient) { + strings::StrAppend(&result_, indentation, "_execute.record_gradient(\n", + " \"", op_def_.name(), + "\", _inputs_flat, _attrs, _result, name)\n"); + } if (num_outs_ == 1 && !output_sizes[0].empty()) { // Single list result. } else if (num_outs_ == 1) { // Execute returns a single-element list which we need to destructure. - strings::StrAppend(&result_, " _result, = _result\n"); + strings::StrAppend(&result_, indentation, "_result, = _result\n"); } else { // Have multiple outputs, so we will need to reformat the return // value of execute() to be a list with one entry per op output // (that entry will be a list of tensors if that output is of list // type). // For list outputs, convert the right subrange of _result into a list. - Unflatten(" ", output_sizes, "_result", &result_); + Unflatten(indentation, output_sizes, "_result", &result_); // Convert to a named tuple. - strings::StrAppend(&result_, " _result = _", op_def_.name(), + strings::StrAppend(&result_, indentation, "_result = _", op_def_.name(), "Output._make(_result)\n"); } } else { - strings::StrAppend(&result_, " _result = None\n"); + strings::StrAppend(&result_, indentation, "_result = None\n"); } - strings::StrAppend(&result_, " return _result\n\n"); - return prelude_ + result_; + strings::StrAppend(&result_, indentation, "return _result\n\n"); } -void GenEagerPythonOp::ExpectListArg(const string& arg_name) { - strings::StrAppend(&result_, " if not isinstance(", arg_name, - ", (list, tuple)):\n" - " raise TypeError(\n" - " \"Expected list for '", - arg_name, - "' argument to \"\n" - " \"'", - op_name_, "' Op, not %r.\" % ", arg_name, ")\n"); +bool GenEagerPythonOp::AddEagerFastPathAndGraphCode( + const string& parameters, const std::vector& output_sizes, + const string& eager_not_allowed_error) { + AddExport(); + AddDefLine(function_name_, parameters); + AddDocStringDescription(); + AddDocStringArgs(); + AddDocStringInputs(); + AddDocStringAttrs(); + AddDocStringNameArg(); + AddOutputGlobals(); // Added to prelude_ + AddDocStringOutputs(); + strings::StrAppend(&result_, " \"\"\"\n"); + + // Handle graph-mode case + string function_setup; + if (!GetEagerFunctionSetup(" ", &function_setup)) { + result_ = function_setup; + return false; + } + HandleGraphMode(function_setup); + AddEagerFunctionTeardown(" ", output_sizes, + true /* execute_record_gradient */); + + // Handle eager-mode case + strings::StrAppend(&result_, " else:\n"); + + if (eager_not_allowed_error.empty()) { + AddEagerFastPathExecute(); + } else { + strings::StrAppend(&result_, " ", eager_not_allowed_error); + } + + strings::StrAppend(&result_, "\n\n"); + return true; } -void GenEagerPythonOp::AddEagerInferredAttrs() { +bool GenEagerPythonOp::AddEagerFallbackCode( + const string& parameters, const std::vector& output_sizes, + const string& num_outputs_expr, const string& eager_not_allowed_error) { + if (!eager_not_allowed_error.empty()) { + strings::StrAppend(&result_, " ", eager_not_allowed_error); + return true; + } + + AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix), parameters); + strings::StrAppend( + &result_, " r\"\"\"This is the slowpath function for Eager mode.\n"); + strings::StrAppend(&result_, " This is for function ", function_name_, + "\n \"\"\"\n"); + + strings::StrAppend(&result_, " _ctx = _context.context()\n"); + + string function_setup; + if (!GetEagerFunctionSetup(" ", &function_setup)) { + result_ = function_setup; + return false; + } + strings::StrAppend(&result_, function_setup); + + AddEagerInferredAttrs(" "); + AddEagerInputCasts(" "); + strings::StrAppend( + &result_, " _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n"); + AddEagerAttrs(" "); + AddEagerExecute(" ", num_outputs_expr); + + AddEagerFunctionTeardown(" ", output_sizes, + true /* execute_record_gradient */); + + return true; +} + +void GenEagerPythonOp::AddEagerFastPathExecute() { + string fastpath_execute_params = strings::StrCat( + "_ctx._handle, _ctx.device_name, \"", op_def_.name(), "\", ", + "_execute.record_gradient, name, _ctx._post_execution_callbacks"); + string fallback_params; + + for (int i = 0; i < api_def_.in_arg_size(); i++) { + const string param_name = param_names_[i].GetRenameTo(); + strings::StrAppend(&fastpath_execute_params, ", ", param_name); + if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", "); + strings::StrAppend(&fallback_params, param_name); + } + + for (const auto& attr : api_def_.attr()) { + if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { + strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ", + attr.rename_to()); + + if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", "); + strings::StrAppend(&fallback_params, attr.rename_to(), "=", + attr.rename_to()); + } + } + + if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", "); + strings::StrAppend(&fallback_params, "name=name"); + + strings::StrAppend(&result_, " try:\n"); + strings::StrAppend( + &result_, " ", + "_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n", + WordWrap(strings::StrCat(" "), + strings::StrCat(fastpath_execute_params, ")"), kRightMargin), + "\n"); + + if (op_def_.output_arg_size() > 1) { + const string output_tuple_name = + strings::StrCat("_", op_def_.name(), "Output"); + strings::StrAppend(&result_, " ", "_result = ", output_tuple_name, + "._make(_result)\n"); + } + strings::StrAppend(&result_, " ", "return _result\n"); + + // Handle fallback. + strings::StrAppend(&result_, " ", "except _core._FallbackException:\n"); + strings::StrAppend( + &result_, " ", "return ", function_name_, kEagerFallbackSuffix, + "(\n", + WordWrap(strings::StrCat(" "), + strings::StrCat(fallback_params, ")"), kRightMargin), + "\n"); + + // Any errors thrown from execute need to be unwrapped from + // _NotOkStatusException. + strings::StrAppend(&result_, " ", + "except _core._NotOkStatusException as e:\n"); + strings::StrAppend(&result_, " ", "if name is not None:\n"); + strings::StrAppend(&result_, " ", + "message = e.message + \" name: \" + name\n"); + strings::StrAppend(&result_, " ", "else:\n"); + strings::StrAppend(&result_, " ", "message = e.message\n"); + strings::StrAppend( + &result_, " ", + "_six.raise_from(_core._status_to_exception(e.code, message), None)\n"); +} + +void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) { // Figure out values for inferred attrs, and cast to eager tensors. for (int i = 0; i < op_def_.attr_size(); ++i) { const auto& attr(op_def_.attr(i)); @@ -618,24 +804,24 @@ void GenEagerPythonOp::AddEagerInferredAttrs() { const string inputs_var = param_names_[arg_list->second.front()].GetRenameTo(); if (output_sizes.front().empty()) { - strings::StrAppend(&result_, " ", var_name, ", (", inputs_var, - ",) = ", conversion, "\n"); + strings::StrAppend(&result_, indentation, var_name, ", (", + inputs_var, ",) = ", conversion, "\n"); } else { - strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, - " = ", conversion, "\n"); + strings::StrAppend(&result_, indentation, var_name, ", ", + inputs_var, " = ", conversion, "\n"); } } else { const string inputs_var = strings::StrCat("_inputs_", attr.name()); - strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, + strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var, " = ", conversion, "\n"); // Convert from a flat list of eager tensors back to the // parameter variables. - Unflatten(" ", output_sizes, inputs_var, &result_); + Unflatten(indentation, output_sizes, inputs_var, &result_); std::vector p; for (int j : arg_list->second) { p.emplace_back(param_names_[j].GetRenameTo()); } - strings::StrAppend(&result_, " ", VectorToTuple(p), " = ", + strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ", inputs_var, "\n"); } } else if (attr.type() == "list(type)") { @@ -662,14 +848,14 @@ void GenEagerPythonOp::AddEagerInferredAttrs() { inputs_var = param_names_[arg_list->second.front()].GetRenameTo(); conversion = "_execute.convert_to_mixed_eager_tensors"; } - strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, " = ", - conversion, "(", inputs_var, ", _ctx)\n"); + strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var, + " = ", conversion, "(", inputs_var, ", _ctx)\n"); } } } } -void GenEagerPythonOp::AddEagerInputCasts() { +void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) { // Cast remaining args to eager tensors for (int i = 0; i < op_def_.input_arg_size(); ++i) { const auto& arg(op_def_.input_arg(i)); @@ -678,12 +864,12 @@ void GenEagerPythonOp::AddEagerInputCasts() { const string fn = arg.number_attr().empty() ? "" : "n_"; const string dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes."); - strings::StrAppend(&result_, " ", param, " = _ops.convert_", fn, + strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn, "to_tensor(", param, ", ", dtype, ")\n"); } } -void GenEagerPythonOp::AddEagerAttrs() { +void GenEagerPythonOp::AddEagerAttrs(const string& indentation) { // Compute eager attrs if (op_def_.attr_size() > 0) { string attr_values; @@ -695,14 +881,19 @@ void GenEagerPythonOp::AddEagerAttrs() { } strings::StrAppend(&attr_values, ")"); strings::StrAppend( - &result_, WordWrap(" _attrs = (", attr_values, kRightMargin), "\n"); + &result_, + WordWrap(indentation, strings::StrCat("_attrs = (", attr_values), + kRightMargin), + "\n"); } else { - strings::StrAppend(&result_, " _attrs = None\n"); + strings::StrAppend(&result_, indentation, "_attrs = None\n"); } } -void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) { - const string return_prefix = " _result = _execute.execute("; +void GenEagerPythonOp::AddEagerExecute(const string& indentation, + const string& num_outputs_expr) { + const string return_prefix = + strings::StrCat(indentation, "_result = _execute.execute("); const string return_args = strings::StrCat( "b\"", op_def_.name(), "\", ", num_outputs_expr, ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)"); @@ -723,8 +914,8 @@ string GetEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, This file is MACHINE GENERATED! Do not edit. )"); - // Mention the original source file so someone tracing back through generated - // Python code will know where to look next. + // Mention the original source file so someone tracing back through + // generated Python code will know where to look next. if (!source_file_name.empty()) { strings::StrAppend(&result, "Original C++ source file: "); strings::StrAppend(&result, source_file_name); @@ -734,11 +925,14 @@ This file is MACHINE GENERATED! Do not edit. strings::StrAppend(&result, R"(""" import collections as _collections +import six as _six -from tensorflow.python.eager import execute as _execute +from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow from tensorflow.python.eager import context as _context from tensorflow.python.eager import core as _core +from tensorflow.python.eager import execute as _execute from tensorflow.python.framework import dtypes as _dtypes +from tensorflow.python.framework import errors as _errors from tensorflow.python.framework import tensor_shape as _tensor_shape from tensorflow.core.framework import op_def_pb2 as _op_def_pb2 @@ -756,11 +950,21 @@ from tensorflow.python.util.tf_export import tf_export auto out = cleaned_ops.mutable_op(); out->Reserve(ops.op_size()); for (const auto& op_def : ops.op()) { - bool is_hidden = false; - for (const string& hidden : hidden_ops) { - if (op_def.name() == hidden) { - is_hidden = true; - break; + const auto* api_def = api_defs.GetApiDef(op_def.name()); + + if (api_def->visibility() == ApiDef::SKIP) { + continue; + } + + // An op is hidden if either its ApiDef visibility is HIDDEN + // or it is in the hidden_ops list. + bool is_hidden = api_def->visibility() == ApiDef::HIDDEN; + if (!is_hidden) { + for (const string& hidden : hidden_ops) { + if (op_def.name() == hidden) { + is_hidden = true; + break; + } } } @@ -777,7 +981,6 @@ from tensorflow.python.util.tf_export import tf_export continue; } - const auto* api_def = api_defs.GetApiDef(op_def.name()); strings::StrAppend(&result, GetEagerPythonOp(op_def, *api_def, function_name)); diff --git a/tensorflow/python/eager/python_eager_op_gen_main.cc b/tensorflow/python/eager/python_eager_op_gen_main.cc deleted file mode 100644 index 05351bd8b115ae07482b82166974e86758bc7712..0000000000000000000000000000000000000000 --- a/tensorflow/python/eager/python_eager_op_gen_main.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/python/eager/python_eager_op_gen.h" - -#include -#include -#include - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/framework/op_gen_lib.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" - -namespace tensorflow { -namespace { - -void PrintAllPythonOps(const std::vector& hidden_ops, - const std::vector& api_def_dirs) { - OpList ops; - OpRegistry::Global()->Export(false, &ops); - - ApiDefMap api_def_map(ops); - if (!api_def_dirs.empty()) { - Env* env = Env::Default(); - - for (const auto& api_def_dir : api_def_dirs) { - std::vector api_files; - TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"), - &api_files)); - TF_CHECK_OK(api_def_map.LoadFileList(env, api_files)); - } - api_def_map.UpdateDocs(); - } - - PrintEagerPythonOps(ops, api_def_map, hidden_ops, true /* require_shapes */); -} - -} // namespace -} // namespace tensorflow - -int main(int argc, char* argv[]) { - tensorflow::port::InitMain(argv[0], &argc, &argv); - - // Usage: - // python_eager_op_gen_main api_def_dir1,api_def_dir2,... - if (argc == 1) { - tensorflow::PrintAllPythonOps({}, {}); - } else if (argc == 2) { - const std::vector api_def_dirs = - tensorflow::str_util::Split(argv[1], ",", - tensorflow::str_util::SkipEmpty()); - tensorflow::PrintAllPythonOps({}, api_def_dirs); - } else { - return -1; - } - return 0; -} diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 4aea134fa9df845fe2a84f32d56a17a8766bde9b..16b7d1a119a409d1d0a77b220d5d0945b280b638 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -47,8 +47,18 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, // Registers e as the Exception class for handling not ok Status. Returns // Py_None if registration succeeds, else throws a TypeError and returns NULL. +// +// This function is not thread-safe. PyObject* TFE_Py_RegisterExceptionClass(PyObject* e); +// Registers e as the Exception to be raised when the conditions of +// TFE_Py_FastPathExecute_C have not been met. When this exception is set, it +// is a signal to the calling code that it should fall back to the safer (and +// more complete) code path. +// +// This function is not thread-safe. +PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e); + // Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using // `exception` if not nullptr, else using the class registered via // TFE_Py_RegisterExceptionClass), and returns -1. @@ -142,10 +152,12 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, // or NULL for automatic selection. // Item 3: op_name: Name of the TensorFlow op to execute. // Item 4: record_gradient_callback: Callback that records the gradient of the -// result. -// The callback takes (inputs, attrs, result) - all sequences and -// records the gradient. -// Item 5 onwards: inputs - This is a list of inputs followed by a list of +// result. The callback takes (op_name, inputs, attrs, result, name) +// - all sequences and records the gradient. +// Item 5: name: An optional name for the operation. +// Item 6: List representing all callbacks to execute after successful +// op execute. +// Item 7 onwards: inputs - This is a list of inputs followed by a list of // attrs. It is not necessary for type attrs to be present. // // This is named _C since there doesn't seem to be any way to make it visible diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 6162644036998bfaa97ac4a37680b661d844ff7a..cabbcc48fd56563a50591cc6adabc3af75918401 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -23,9 +23,11 @@ limitations under the License. #include "tensorflow/c/eager/tape.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/compactptrset.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/python/eager/pywrap_tensor.h" @@ -56,10 +58,22 @@ PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong) #else PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong) PARSE_VALUE(ParseInt64Value, int64_t, PyInt_Check, PyInt_AsLong) +PARSE_VALUE(ParseInt64LongValue, int64_t, PyLong_Check, PyLong_AsLong) #endif PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble) #undef PARSE_VALUE +Py_ssize_t TensorShapeNumDims(PyObject* value) { + const auto size = PySequence_Size(value); + if (size == -1) { + // TensorShape.__len__ raises an error in the scenario where the shape is an + // unknown, which needs to be cleared. + // TODO(nareshmodi): ensure that this is actually a TensorShape. + PyErr_Clear(); + } + return size; +} + bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status, const char** value) { if (PyBytes_Check(py_value)) { @@ -86,32 +100,40 @@ bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status, return true; } -const char* ParseProtoValue(const string& key, const char* proto_name, - PyObject* py_value, size_t* size, - TF_Status* status) { - char* output = nullptr; - Py_ssize_t py_size; - if (PyBytes_Check(py_value) && - PyBytes_AsStringAndSize(py_value, &output, &py_size) >= 0) { - *size = static_cast(py_size); - return output; - } +bool IsInteger(PyObject* py_value) { #if PY_MAJOR_VERSION >= 3 - if (PyUnicode_Check(py_value) && - (output = PyUnicode_AsUTF8AndSize(py_value, &py_size)) != nullptr) { - *size = static_cast(py_size); - return output; - } + return PyLong_Check(py_value); +#else + return PyInt_Check(py_value); #endif - TF_SetStatus(status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat("Expecting a string (serialized ", - proto_name, ") value for attr ", key) - .c_str()); - return nullptr; } -bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list, - TF_AttrType type, TF_Status* status) { +// The passed in py_value is expected to be an object of the python type +// dtypes.DType or an int. +bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status, + int* value) { + if (IsInteger(py_value)) { + return ParseIntValue(key, py_value, status, value); + } + + PyObject* py_type_enum = PyObject_GetAttrString(py_value, "_type_enum"); + if (py_type_enum == nullptr) { + return false; + } + + if (!ParseIntValue(key, py_type_enum, status, value)) { + Py_DECREF(py_type_enum); + return false; + } + + Py_DECREF(py_type_enum); + return true; +} + +bool SetOpAttrList( + TFE_Op* op, const char* key, PyObject* py_list, TF_AttrType type, + tensorflow::gtl::FlatMap* attr_list_sizes, + TF_Status* status) { if (!PySequence_Check(py_list)) { TF_SetStatus( status, TF_INVALID_ARGUMENT, @@ -121,6 +143,7 @@ bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list, return false; } const int num_values = PySequence_Size(py_list); + if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values; #define PARSE_LIST(c_type, parse_fn) \ std::unique_ptr values(new c_type[num_values]); \ @@ -142,7 +165,7 @@ bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list, PARSE_LIST(unsigned char, ParseBoolValue); TFE_OpSetAttrBoolList(op, key, values.get(), num_values); } else if (type == TF_ATTR_TYPE) { - PARSE_LIST(int, ParseIntValue); + PARSE_LIST(int, ParseTypeValue); TFE_OpSetAttrTypeList(op, key, reinterpret_cast(values.get()), num_values); @@ -162,8 +185,10 @@ bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list, .c_str()); return false; } - const auto size = PySequence_Size(py_value); - total_dims += size; + const auto size = TensorShapeNumDims(py_value); + if (size >= 0) { + total_dims += size; + } } } // Allocate a buffer that can fit all of the dims together. @@ -179,7 +204,12 @@ bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list, dims[i] = nullptr; num_dims[i] = -1; } else { - const auto size = PySequence_Size(py_value); + const auto size = TensorShapeNumDims(py_value); + if (size == -1) { + dims[i] = nullptr; + num_dims[i] = -1; + continue; + } dims[i] = offset; num_dims[i] = size; for (int j = 0; j < size; ++j) { @@ -207,8 +237,123 @@ bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list, return true; } -bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key, - PyObject* py_value, TF_AttrType type, TF_Status* status) { +// This is only declared here since GetFunc makes a recursive call to +// SetOpAttrScalarDefault. +void SetOpAttrScalarDefault( + TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, + const char* attr_name, + tensorflow::gtl::FlatMap* attr_list_sizes, + TF_Status* status); + +TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, + TF_Status* status) { + TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); + for (const auto& attr : func.attr()) { + if (TF_GetCode(status) != TF_OK) return nullptr; + SetOpAttrScalarDefault(ctx, func_op, attr.second, attr.first.data(), + nullptr, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + return func_op; +} + +void SetOpAttrListDefault( + TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, + const char* key, TF_AttrType type, + tensorflow::gtl::FlatMap* attr_list_sizes, + TF_Status* status) { + if (type == TF_ATTR_STRING) { + int num_values = attr.default_value().list().s_size(); + std::unique_ptr values(new const char*[num_values]); + (*attr_list_sizes)[key] = num_values; + for (int i = 0; i < num_values; i++) { + values[i] = attr.default_value().list().s(i).data(); + } + TFE_OpSetAttrStringList(op, key, values.get(), num_values); + } else if (type == TF_ATTR_INT) { + int num_values = attr.default_value().list().i_size(); + std::unique_ptr values(new int64_t[num_values]); + (*attr_list_sizes)[key] = num_values; + for (int i = 0; i < num_values; i++) { + values[i] = attr.default_value().list().i(i); + } + TFE_OpSetAttrIntList(op, key, values.get(), num_values); + } else if (type == TF_ATTR_FLOAT) { + int num_values = attr.default_value().list().f_size(); + std::unique_ptr values(new float[num_values]); + (*attr_list_sizes)[key] = num_values; + for (int i = 0; i < num_values; i++) { + values[i] = attr.default_value().list().f(i); + } + TFE_OpSetAttrFloatList(op, key, values.get(), num_values); + } else if (type == TF_ATTR_BOOL) { + int num_values = attr.default_value().list().b_size(); + std::unique_ptr values(new unsigned char[num_values]); + (*attr_list_sizes)[key] = num_values; + for (int i = 0; i < num_values; i++) { + values[i] = attr.default_value().list().b(i); + } + TFE_OpSetAttrBoolList(op, key, values.get(), num_values); + } else if (type == TF_ATTR_TYPE) { + int num_values = attr.default_value().list().type_size(); + std::unique_ptr values(new int[num_values]); + (*attr_list_sizes)[key] = num_values; + for (int i = 0; i < num_values; i++) { + values[i] = attr.default_value().list().type(i); + } + TFE_OpSetAttrTypeList(op, key, + reinterpret_cast(values.get()), + attr.default_value().list().type_size()); + } else if (type == TF_ATTR_SHAPE) { + int num_values = attr.default_value().list().shape_size(); + (*attr_list_sizes)[key] = num_values; + int total_dims = 0; + for (int i = 0; i < num_values; ++i) { + if (!attr.default_value().list().shape(i).unknown_rank()) { + total_dims += attr.default_value().list().shape(i).dim_size(); + } + } + // Allocate a buffer that can fit all of the dims together. + std::unique_ptr buffer(new int64_t[total_dims]); + // Copy the input dims into the buffer and set dims to point to + // the start of each list's dims. + std::unique_ptr dims(new const int64_t*[num_values]); + std::unique_ptr num_dims(new int[num_values]); + int64_t* offset = buffer.get(); + for (int i = 0; i < num_values; ++i) { + const auto& shape = attr.default_value().list().shape(i); + if (shape.unknown_rank()) { + dims[i] = nullptr; + num_dims[i] = -1; + } else { + for (int j = 0; j < shape.dim_size(); j++) { + *offset = shape.dim(j).size(); + ++offset; + } + } + } + TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values, + status); + } else if (type == TF_ATTR_FUNC) { + int num_values = attr.default_value().list().func_size(); + (*attr_list_sizes)[key] = num_values; + std::unique_ptr funcs(new const TFE_Op*[num_values]); + for (int i = 0; i < num_values; i++) { + funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status); + } + TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values); + } else { + TF_SetStatus(status, TF_UNIMPLEMENTED, + "Lists of tensors are not yet implemented for default valued " + "attributes for an operation."); + } +} + +bool SetOpAttrScalar( + TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value, + TF_AttrType type, + tensorflow::gtl::FlatMap* attr_list_sizes, + TF_Status* status) { if (type == TF_ATTR_STRING) { const char* value; if (!ParseStringValue(key, py_value, status, &value)) return false; @@ -217,6 +362,10 @@ bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key, int64_t value; if (!ParseInt64Value(key, py_value, status, &value)) return false; TFE_OpSetAttrInt(op, key, value); + // attr_list_sizes is set for all int attributes (since at this point we are + // not aware if that attribute might be used to calculate the size of an + // output list or not). + if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value; } else if (type == TF_ATTR_FLOAT) { float value; if (!ParseFloatValue(key, py_value, status, &value)) return false; @@ -227,7 +376,7 @@ bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key, TFE_OpSetAttrBool(op, key, value); } else if (type == TF_ATTR_TYPE) { int value; - if (!ParseIntValue(key, py_value, status, &value)) return false; + if (!ParseTypeValue(key, py_value, status, &value)) return false; TFE_OpSetAttrType(op, key, static_cast(value)); } else if (type == TF_ATTR_SHAPE) { if (py_value == Py_None) { @@ -241,7 +390,11 @@ bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key, .c_str()); return false; } - const auto num_dims = PySequence_Size(py_value); + const auto num_dims = TensorShapeNumDims(py_value); + if (num_dims == -1) { + TFE_OpSetAttrShape(op, key, nullptr, -1, status); + return true; + } std::unique_ptr dims(new int64_t[num_dims]); for (int i = 0; i < num_dims; ++i) { auto inner_py_value = PySequence_ITEM(py_value, i); @@ -293,6 +446,65 @@ bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key, return true; } +void SetOpAttrScalarDefault( + TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, + const char* attr_name, + tensorflow::gtl::FlatMap* attr_list_sizes, + TF_Status* status) { + switch (default_value.value_case()) { + case tensorflow::AttrValue::kS: + TFE_OpSetAttrString(op, attr_name, default_value.s().data()); + break; + case tensorflow::AttrValue::kI: + TFE_OpSetAttrInt(op, attr_name, static_cast(default_value.i())); + (*attr_list_sizes)[attr_name] = default_value.i(); + break; + case tensorflow::AttrValue::kF: + TFE_OpSetAttrFloat(op, attr_name, default_value.f()); + break; + case tensorflow::AttrValue::kB: + TFE_OpSetAttrBool(op, attr_name, default_value.b()); + break; + case tensorflow::AttrValue::kType: + TFE_OpSetAttrType(op, attr_name, + static_cast(default_value.type())); + break; + case tensorflow::AttrValue::kShape: { + const auto& tensor_shape = default_value.shape(); + if (tensor_shape.unknown_rank()) { + TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status); + } else { + const auto num_dims = tensor_shape.dim_size(); + std::unique_ptr dims(new int64_t[num_dims]); + for (int i = 0; i < num_dims; ++i) { + dims[i] = tensor_shape.dim(i).size(); + } + TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status); + } + } break; + case tensorflow::AttrValue::kFunc: { + const auto func_op = GetFunc(ctx, default_value.func(), status); + if (TF_GetCode(status) != TF_OK) return; + // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList + // require TFE_Op* and just convert it internally a NameAttrValue, so + // consider adding an overload to the C API to make this case easier. + TFE_OpSetAttrFunction(op, attr_name, func_op); + } break; + case tensorflow::AttrValue::kList: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kTensor: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kPlaceholder: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::VALUE_NOT_SET: + TF_SetStatus( + status, TF_UNIMPLEMENTED, + tensorflow::strings::StrCat("Unable to get setfor default value: ", + default_value.DebugString()) + .data()); + } +} + // start_index is the index at which the Tuple/List attrs will start getting // processed. void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index, @@ -318,9 +530,40 @@ void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index, const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status); if (TF_GetCode(out_status) != TF_OK) return; if (is_list != 0) { - if (!SetOpAttrList(op, key, py_value, type, out_status)) return; + if (!SetOpAttrList(op, key, py_value, type, nullptr, out_status)) return; + } else { + if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status)) + return; + } + } +} + +// This function will set the op attrs required. If an attr has the value of +// None, then it will read the AttrDef to get the default value and set that +// instead. Any failure in this function will simply fall back to the slow +// path. +void SetOpAttrWithDefaults( + TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, + const char* attr_name, PyObject* attr_value, + tensorflow::gtl::FlatMap* attr_list_sizes, + TF_Status* status) { + unsigned char is_list = 0; + const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status); + if (TF_GetCode(status) != TF_OK) return; + if (attr_value == Py_None) { + if (is_list != 0) { + SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes, + status); + } else { + SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name, + attr_list_sizes, status); + } + } else { + if (is_list != 0) { + SetOpAttrList(op, attr_name, attr_value, type, attr_list_sizes, status); } else { - if (!SetOpAttrScalar(ctx, op, key, py_value, type, out_status)) return; + SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes, + status); } } } @@ -329,8 +572,12 @@ void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index, tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED); PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr; -static tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED); -static tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0; +// Python subclass of Exception that is created to signal fallback. +PyObject* fallback_exception_class = nullptr; + +tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED); +tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0; + } // namespace void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, @@ -383,6 +630,37 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) { } } +PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) { + if (fallback_exception_class != nullptr) { + Py_DECREF(fallback_exception_class); + } + if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) { + fallback_exception_class = nullptr; + PyErr_SetString(PyExc_TypeError, + "TFE_Py_RegisterFallbackExceptionClass: " + "Registered class should be subclass of Exception."); + return nullptr; + } else { + Py_INCREF(e); + fallback_exception_class = e; + Py_RETURN_NONE; + } +} + +void RaiseFallbackException(const char* message) { + if (fallback_exception_class != nullptr) { + PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message)); + return; + } + + PyErr_SetString( + PyExc_RuntimeError, + tensorflow::strings::StrCat( + "Fallback exception type not set, attempting to fallback due to ", + message) + .data()); +} + int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) { if (TF_GetCode(status) == TF_OK) return 0; const char* msg = TF_Message(status); @@ -551,6 +829,34 @@ tensorflow::gtl::CompactPointerSet* GetTapeSet() { return tape_set; } +// A safe copy of the current tapeset. Does not get affected by other python +// threads changing the set of active tapes. +class SafeTapeSet { + public: + SafeTapeSet() : tape_set_(*GetTapeSet()) { + for (auto* tape : tape_set_) { + Py_INCREF(tape); + } + } + + ~SafeTapeSet() { + for (auto* tape : tape_set_) { + Py_DECREF(tape); + } + } + + tensorflow::gtl::CompactPointerSet::const_iterator begin() { + return tape_set_.begin(); + } + + tensorflow::gtl::CompactPointerSet::const_iterator end() { + return tape_set_.end(); + } + + private: + tensorflow::gtl::CompactPointerSet tape_set_; +}; + // xcode 7 doesn't define thread_local, so for compatibility we implement our // own. TODO(apassos) remove once we can deprecate xcode 7. #ifndef __APPLE__ @@ -741,10 +1047,7 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) { if (*ThreadTapeIsStopped()) { return; } - // Note: making a copy because watching a variable can trigger a change to the - // set of tapes by allowing python's garbage collector to run. - auto tape_set = *GetTapeSet(); - for (TFE_Py_Tape* tape : tape_set) { + for (TFE_Py_Tape* tape : SafeTapeSet()) { tape->tape->WatchVariable(variable); } } @@ -766,6 +1069,9 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, return; } std::vector input_ids = MakeTensorIDList(input_tensors); + if (PyErr_Occurred()) { + return; + } std::vector output_info; PyObject* seq = PySequence_Fast(output_tensors, "expected a sequence of integer tensor ids"); @@ -797,8 +1103,7 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, return; } - auto set = *GetTapeSet(); - for (TFE_Py_Tape* tape : set) { + for (TFE_Py_Tape* tape : SafeTapeSet()) { Py_INCREF(backward_function); tape->tape->RecordOperation( op_type_str, output_info, input_ids, backward_function, @@ -807,10 +1112,7 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, } void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { - // Note: making a copy because deleting the trace can trigger a change to the - // set of tapes by allowing python's garbage collector to run. - auto tape_set = *GetTapeSet(); - for (TFE_Py_Tape* tape : tape_set) { + for (TFE_Py_Tape* tape : SafeTapeSet()) { tape->tape->DeleteTrace(tensor_id); } } @@ -1030,22 +1332,66 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, } return py_result; } - Py_INCREF(Py_None); - return Py_None; + return PyList_New(0); } namespace { -static const int kFastPathExecuteInputStartIndex = 4; +static const int kFastPathExecuteInputStartIndex = 6; -bool CheckEagerTensors(PyObject* seq, int start_index, int num_to_check) { - for (int i = start_index; i < start_index + num_to_check; i++) { - PyObject* item = PyTuple_GET_ITEM(seq, i); - if (!EagerTensor_CheckExact(item)) return false; +PyObject* GetPythonObjectFromString(const char* s) { +#if PY_MAJOR_VERSION >= 3 + return PyUnicode_FromString(s); +#else + return PyBytes_FromString(s); +#endif +} + +bool CheckEagerTensors(PyObject* seq, int start_index, + const tensorflow::OpDef& op_def) { + for (int i = 0; i < op_def.input_arg_size(); i++) { + PyObject* item = PyTuple_GET_ITEM(seq, i + start_index); + if (!op_def.input_arg(i).number_attr().empty() || + !op_def.input_arg(i).type_list_attr().empty()) { + // This item should be a list input. + if (!PyList_Check(item)) return false; + for (Py_ssize_t j = 0; j < PyList_Size(item); j++) { + if (!EagerTensor_CheckExact(PyList_GET_ITEM(item, j))) return false; + } + } else if (!EagerTensor_CheckExact(item)) { + return false; + } } return true; } +// Adds input and type attr to the op, and to the list of flattened +// inputs/attrs. +bool AddInputToOp(PyObject* input, const tensorflow::OpDef::ArgDef* input_arg, + std::vector* flattened_attrs, + std::vector* flattened_inputs, TFE_Op* op, + TF_Status* status) { + TFE_TensorHandle* input_handle = EagerTensor_Handle(input); + if (input_arg != nullptr && !input_arg->type_attr().empty()) { + auto dtype = TFE_TensorHandleDataType(input_handle); + TFE_OpSetAttrType(op, input_arg->type_attr().data(), dtype); + if (flattened_attrs != nullptr) { + flattened_attrs->push_back( + GetPythonObjectFromString(input_arg->type_attr().data())); + flattened_attrs->push_back(PyLong_FromLong(dtype)); + } + } + + if (flattened_inputs != nullptr) { + flattened_inputs->push_back(input); + } + TFE_OpAddInput(op, input_handle, status); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { + return false; + } + return true; +} + const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) { const char* op_name = TFE_GetPythonString(py_op_name); if (op_name == nullptr) { @@ -1072,59 +1418,103 @@ const char* GetDeviceName(PyObject* py_device_name) { return nullptr; } -bool MaybeRunRecordGradientCallback(const tensorflow::OpDef* op_def, - PyObject* args, PyObject* result, - PyObject* record_gradient_callback) { - if (*ThreadTapeIsStopped() || GetTapeSet()->empty() || - record_gradient_callback == Py_None) { - return true; - } - if (!PyCallable_Check(record_gradient_callback)) { - PyErr_SetString( - PyExc_TypeError, - Printf( - "expected a function for record_gradient_callback, got %s instead", - record_gradient_callback->ob_type->tp_name) - .c_str()); +bool RaiseIfNotPyList(PyObject* list, const string& attr_name) { + if (!PyList_Check(list)) { + PyErr_SetString(PyExc_TypeError, + Printf("expected a list for attr %s, got %s instead", + attr_name.data(), list->ob_type->tp_name) + .data()); + return false; } + return true; +} - PyObject* inputs = PyTuple_New(op_def->input_arg_size()); - for (int i = 0; i < op_def->input_arg_size(); i++) { - auto* input = PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i); +bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks, + const tensorflow::OpDef* op_def, PyObject* args, + const std::vector& flattened_inputs, + const std::vector& flattened_attrs, + PyObject* flattened_result, PyObject* op_name, PyObject* name, + PyObject* record_gradient_callback, PyObject* callbacks) { + PyObject* inputs = PyTuple_New(flattened_inputs.size()); + for (int i = 0; i < flattened_inputs.size(); i++) { + PyObject* input = flattened_inputs[i]; Py_INCREF(input); PyTuple_SET_ITEM(inputs, i, input); } - int args_size = PyTuple_GET_SIZE(args); - int num_attrs = - args_size - op_def->input_arg_size() - kFastPathExecuteInputStartIndex; + int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - + op_def->input_arg_size() - + kFastPathExecuteInputStartIndex; + int num_attrs = flattened_attrs.size() + num_non_inferred_attrs; PyObject* attrs = PyTuple_New(num_attrs); - for (int i = 0; i < num_attrs; i++) { + + for (int i = 0; i < num_non_inferred_attrs; i++) { auto* attr = PyTuple_GET_ITEM( args, kFastPathExecuteInputStartIndex + op_def->input_arg_size() + i); Py_INCREF(attr); PyTuple_SET_ITEM(attrs, i, attr); } + for (int i = num_non_inferred_attrs; i < num_attrs; i++) { + // Not INCREFing anything in flattened_attrs as each of those is a new + // reference, so allow the attrs tuple to steal the reference. + PyTuple_SET_ITEM(attrs, i, flattened_attrs.at(i - num_non_inferred_attrs)); + } - PyObject* callback_args = Py_BuildValue("OOO", inputs, attrs, result); - PyObject_CallObject(record_gradient_callback, callback_args); + PyObject* callback_args = + Py_BuildValue("OOOOO", op_name, inputs, attrs, flattened_result, name); + + auto cleaner = tensorflow::gtl::MakeCleanup([inputs, attrs, callback_args] { + Py_DECREF(inputs); + Py_DECREF(attrs); + Py_DECREF(callback_args); + }); + + if (run_gradient_callback) { + if (!PyCallable_Check(record_gradient_callback)) { + PyErr_SetString(PyExc_TypeError, + Printf("expected a function for " + "record_gradient_callback, got %s instead", + record_gradient_callback->ob_type->tp_name) + .c_str()); + return false; + } + + PyObject* callback_result = + PyObject_CallObject(record_gradient_callback, callback_args); + if (!callback_result) { + return false; + } + Py_DECREF(callback_result); + } + + if (run_post_exec_callbacks) { + for (Py_ssize_t i = 0; i < PyList_Size(callbacks); i++) { + PyObject* callback_fn = PyList_GET_ITEM(callbacks, i); + if (!PyCallable_Check(callback_fn)) { + PyErr_SetString( + PyExc_TypeError, + Printf("expected a function for " + "post execution callback in index %ld, got %s instead", + i, callback_fn->ob_type->tp_name) + .c_str()); + return false; + } + PyObject* callback_result = + PyObject_CallObject(callback_fn, callback_args); + if (!callback_result) { + return false; + } + Py_DECREF(callback_result); + } + } - Py_DECREF(inputs); - Py_DECREF(callback_args); - Py_DECREF(attrs); return true; } + } // namespace PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { - TFE_Context* ctx = reinterpret_cast( - PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr)); - const tensorflow::OpDef* op_def = GetOpDef(PyTuple_GET_ITEM(args, 2)); - if (op_def == nullptr) return nullptr; - const char* device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1)); - PyObject* record_gradient_callback = PyTuple_GET_ITEM(args, 3); - Py_ssize_t args_size = PyTuple_GET_SIZE(args); if (args_size < kFastPathExecuteInputStartIndex) { PyErr_SetString( @@ -1135,6 +1525,16 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { return nullptr; } + TFE_Context* ctx = reinterpret_cast( + PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr)); + const char* device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1)); + PyObject* op_name = PyTuple_GET_ITEM(args, 2); + const tensorflow::OpDef* op_def = GetOpDef(op_name); + if (op_def == nullptr) return nullptr; + PyObject* record_gradient_callback = PyTuple_GET_ITEM(args, 3); + PyObject* name = PyTuple_GET_ITEM(args, 4); + PyObject* callbacks = PyTuple_GET_ITEM(args, 5); + if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) { PyErr_SetString( PyExc_ValueError, @@ -1146,13 +1546,10 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { return nullptr; } - if (!CheckEagerTensors(args, kFastPathExecuteInputStartIndex, - op_def->input_arg_size())) { - // TODO(nareshmodi): Maybe some other way of signalling that this should - // fall back? - PyErr_SetString(PyExc_NotImplementedError, - "This function does not handle the case of the path where " - "all inputs are not already EagerTensors."); + if (!CheckEagerTensors(args, kFastPathExecuteInputStartIndex, *op_def)) { + RaiseFallbackException( + "This function does not handle the case of the path where " + "all inputs are not already EagerTensors."); return nullptr; } @@ -1166,62 +1563,236 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { return nullptr; } - TFE_OpSetDevice(op, device_name, status); - if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { - return nullptr; + // Mapping of attr name to size - used to calculate the number of values + // to be expected by the TFE_Execute run. + tensorflow::gtl::FlatMap attr_list_sizes; + + // Set non-inferred attrs, including setting defaults if the attr is passed in + // as None. + for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size(); + i < args_size; i += 2) { + PyObject* py_attr_name = PyTuple_GET_ITEM(args, i); + const tensorflow::StringPiece attr_name(TFE_GetPythonString(py_attr_name)); + PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1); + + // Not creating an index since most of the time there are not more than a + // few attrs. + // TODO(nareshmodi): Maybe include the index as part of the + // OpRegistrationData. + for (const auto& attr : op_def->attr()) { + if (attr_name == attr.name()) { + SetOpAttrWithDefaults(ctx, op, attr, attr_name.data(), py_attr_value, + &attr_list_sizes, status); + + if (TF_GetCode(status) != TF_OK) { + RaiseFallbackException(TF_Message(status)); + return nullptr; + } + + break; + } + } } - // Add non-type attrs. - SetOpAttrs(ctx, op, args, - kFastPathExecuteInputStartIndex + op_def->input_arg_size(), - status); + TFE_OpSetDevice(op, device_name, status); if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { return nullptr; } - // Add type attrs and inputs. + // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks + // (similar to benchmark_tf_gradient_function_*). Also consider using an + // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks + // point out problems with heap allocs. + bool run_gradient_callback = !*ThreadTapeIsStopped() && + !GetTapeSet()->empty() && + record_gradient_callback != Py_None; + bool run_post_exec_callbacks = + callbacks != Py_None && PyList_Size(callbacks) > 0; + bool run_callbacks = run_gradient_callback || run_post_exec_callbacks; + // Flat attrs and inputs as required by the record_gradient call. The attrs + // here only contain inferred attrs (non-inferred attrs are added directly + // from the input args). + // All items in flattened_attrs contain new references. + // All items in flattened_inputs contain borrowed references. + // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work + // directly. + std::unique_ptr> flattened_attrs = nullptr; + std::unique_ptr> flattened_inputs = nullptr; + + if (run_callbacks) { + flattened_attrs.reset(new std::vector); + flattened_inputs.reset(new std::vector); + } + + // Add inferred attrs and inputs. + // The following code might set duplicate type attrs. This will result in + // the CacheKey for the generated AttrBuilder possibly differing from + // those where the type attrs are correctly set. Inconsistent CacheKeys + // for ops means that there might be unnecessarily duplicated kernels. + // TODO(nareshmodi): Fix this. for (int i = 0; i < op_def->input_arg_size(); i++) { const auto& input_arg = op_def->input_arg(i); PyObject* input = PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i); - TFE_TensorHandle* input_handle = EagerTensor_Handle(input); - - // The following code might set duplicate type attrs. This will result in - // the CacheKey for the generated AttrBuilder possibly differing from those - // where the type attrs are correctly set. Inconsistent CacheKeys for ops - // means that there might be unnecessarily duplicated kernels. - // TODO(nareshmodi): Fix this. - if (!input_arg.type_attr().empty()) { - TFE_OpSetAttrType(op, input_arg.type_attr().data(), - TFE_TensorHandleDataType(input_handle)); + if (!input_arg.number_attr().empty()) { + // The item is a homogeneous list. + if (!RaiseIfNotPyList(input, input_arg.number_attr())) return nullptr; + Py_ssize_t len = PyList_Size(input); + + TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len); + if (run_callbacks) { + flattened_attrs->push_back( + GetPythonObjectFromString(input_arg.number_attr().data())); + flattened_attrs->push_back(PyLong_FromLong(len)); + } + attr_list_sizes[input_arg.number_attr()] = len; + + if (len > 0) { + // First item adds the type attr. + if (!AddInputToOp(PyList_GET_ITEM(input, 0), &input_arg, + flattened_attrs.get(), flattened_inputs.get(), op, + status)) { + return nullptr; + } + + for (Py_ssize_t j = 1; j < len; j++) { + // Since the list is homogeneous, we don't need to re-add the attr. + if (!AddInputToOp(PyList_GET_ITEM(input, j), nullptr /* input_arg */, + nullptr /* flattened_attrs */, + flattened_inputs.get(), op, status)) { + return nullptr; + } + } + } + } else if (!input_arg.type_list_attr().empty()) { + // The item is a heterogeneous list. + if (!RaiseIfNotPyList(input, input_arg.type_list_attr())) return nullptr; + const string& attr_name = input_arg.type_list_attr(); + Py_ssize_t len = PyList_Size(input); + tensorflow::gtl::InlinedVector attr_value(len); + PyObject* py_attr_value = nullptr; + if (run_callbacks) { + py_attr_value = PyTuple_New(len); + } + for (Py_ssize_t j = 0; j < len; j++) { + PyObject* py_input = PyList_GET_ITEM(input, j); + TFE_TensorHandle* input_handle = EagerTensor_Handle(py_input); + attr_value[j] = TFE_TensorHandleDataType(input_handle); + + TFE_OpAddInput(op, input_handle, status); + if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { + return nullptr; + } + + if (run_callbacks) { + flattened_inputs->push_back(py_input); + + PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j])); + } + } + if (run_callbacks) { + flattened_attrs->push_back(GetPythonObjectFromString(attr_name.data())); + flattened_attrs->push_back(py_attr_value); + } + TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(), + attr_value.size()); + attr_list_sizes[attr_name] = len; + } else { + // The item is a single item. + if (!AddInputToOp(input, &input_arg, flattened_attrs.get(), + flattened_inputs.get(), op, status)) { + return nullptr; + } } + } - TFE_OpAddInput(op, input_handle, status); - if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { - return nullptr; + int num_retvals = 0; + for (int i = 0; i < op_def->output_arg_size(); i++) { + const auto& output_arg = op_def->output_arg(i); + if (!output_arg.number_attr().empty()) { + num_retvals += attr_list_sizes[output_arg.number_attr()]; + } else if (!output_arg.type_list_attr().empty()) { + num_retvals += attr_list_sizes[output_arg.type_list_attr()]; + } else { + num_retvals++; } } - int num_retvals = op_def->output_arg_size(); tensorflow::gtl::InlinedVector retvals(num_retvals); Py_BEGIN_ALLOW_THREADS; TFE_Execute(op, retvals.data(), &num_retvals, status); Py_END_ALLOW_THREADS; - if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { + if (TF_GetCode(status) != TF_OK) { + // Augment the status with the op_name for easier debugging similar to + // TFE_Py_Execute. + TF_SetStatus(status, TF_GetCode(status), + tensorflow::strings::StrCat(TF_Message(status), " [Op:", + TFE_GetPythonString(op_name), "]") + .c_str()); + + MaybeRaiseExceptionFromTFStatus(status, nullptr); return nullptr; } - PyObject* result = PyTuple_New(num_retvals); + PyObject* flat_result = PyList_New(num_retvals); for (int i = 0; i < num_retvals; ++i) { - PyTuple_SET_ITEM(result, i, EagerTensorFromHandle(retvals[i])); + PyList_SET_ITEM(flat_result, i, EagerTensorFromHandle(retvals[i])); } - if (!MaybeRunRecordGradientCallback(op_def, args, result, - record_gradient_callback)) { + if (run_callbacks && + !RunCallbacks(run_gradient_callback, run_post_exec_callbacks, op_def, + args, *flattened_inputs, *flattened_attrs, flat_result, + op_name, name, record_gradient_callback, callbacks)) { return nullptr; } + // Unflatten results. + if (op_def->output_arg_size() == 0) { + Py_RETURN_NONE; + } + + if (op_def->output_arg_size() == 1) { + if (!op_def->output_arg(0).number_attr().empty() || + !op_def->output_arg(0).type_list_attr().empty()) { + return flat_result; + } else { + auto* result = PyList_GET_ITEM(flat_result, 0); + Py_INCREF(result); + Py_DECREF(flat_result); + return result; + } + } + + // Correctly output the results that are made into a namedtuple. + PyObject* result = PyList_New(op_def->output_arg_size()); + int flat_result_index = 0; + for (int i = 0; i < op_def->output_arg_size(); i++) { + if (!op_def->output_arg(i).number_attr().empty()) { + int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()]; + PyObject* inner_list = PyList_New(list_length); + for (int j = 0; j < list_length; j++) { + PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++); + Py_INCREF(obj); + PyList_SET_ITEM(inner_list, j, obj); + } + PyList_SET_ITEM(result, i, inner_list); + } else if (!op_def->output_arg(i).type_list_attr().empty()) { + int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()]; + PyObject* inner_list = PyList_New(list_length); + for (int j = 0; j < list_length; j++) { + PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++); + Py_INCREF(obj); + PyList_SET_ITEM(inner_list, j, obj); + } + PyList_SET_ITEM(result, i, inner_list); + } else { + PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++); + Py_INCREF(obj); + PyList_SET_ITEM(result, i, obj); + } + } + Py_DECREF(flat_result); return result; } diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index d4f4ed592fb99e475af4652a33e5364d9abeea1a..49323e6640e664ef5f98b227964f9dd4e248ca39 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -21,28 +21,15 @@ from __future__ import print_function from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import execute from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -def record_gradient_callback(inputs, attrs, results): - return backprop._record_gradient("MatMul", inputs, attrs, results, None) - - -def c_tfe_py_fastpath_execute(a, b, transpose_a=False, transpose_b=False): - ctx = context.context() - assert not ctx.in_graph_mode( - ), "The prototype doesn't contain C code for graph construction" - ctx_handle = ctx._handle # pylint: disable=protected-access - - return pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx_handle, ctx.device_name, "MatMul", record_gradient_callback, a, b, - "transpose_a", transpose_a, "transpose_b", transpose_b)[0] - - class Tests(test.TestCase): @test_util.assert_no_new_tensors @@ -54,31 +41,100 @@ class Tests(test.TestCase): a_100_by_784 = random_ops.random_uniform((100, 784)) b_100_by_784 = random_ops.random_uniform((100, 784)) + ctx = context.context() + self.assertAllClose( math_ops.matmul(a_2_by_2, b_2_by_2), - c_tfe_py_fastpath_execute(a_2_by_2, b_2_by_2)) + pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, + None, None, a_2_by_2, b_2_by_2, "transpose_a", False, "transpose_b", + False)) self.assertAllClose( math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True), - c_tfe_py_fastpath_execute(a_100_by_784, b_100_by_784, transpose_b=True)) + pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, + None, None, a_100_by_784, b_100_by_784, "transpose_a", False, + "transpose_b", True)) @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created def testFastpathExecute_TapeWrite(self): + ctx = context.context() with backprop.GradientTape(persistent=True) as tape: a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) tape.watch(a_2_by_2) - z = c_tfe_py_fastpath_execute(a_2_by_2, a_2_by_2) + z = pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, None, + None, a_2_by_2, a_2_by_2, "transpose_a", False, "transpose_b", False) dz_dy = tape.gradient(z, [a_2_by_2])[0] self.assertAllEqual(dz_dy.numpy(), constant_op.constant(4.0, shape=[2, 2]).numpy()) + # Tests homogeneous list op @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created - def testFastpathExecute_MatMulSlowPath(self): - a_2_by_2 = random_ops.random_uniform((2, 2)).cpu().numpy() + def testFastpathExecute_AddNCorrectResponse(self): + ctx = context.context() + a_2_by_2 = random_ops.random_uniform((2, 2)) + b_2_by_2 = random_ops.random_uniform((2, 2)) + + self.assertAllClose( + math_ops.add_n([a_2_by_2, b_2_by_2]), + pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "AddN", execute.record_gradient, None, + None, [a_2_by_2, b_2_by_2])) + + # Tests homogeneous list op + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_AddNTapeWrite(self): + ctx = context.context() + a_2_by_2 = random_ops.random_uniform((2, 2)) + b_2_by_2 = random_ops.random_uniform((2, 2)) - with self.assertRaises(NotImplementedError): - c_tfe_py_fastpath_execute(a_2_by_2, a_2_by_2) + with backprop.GradientTape(persistent=True) as tape: + tape.watch(a_2_by_2) + tape.watch(b_2_by_2) + z1 = pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "AddN", execute.record_gradient, None, + None, [a_2_by_2, b_2_by_2]) + z2 = math_ops.add_n([a_2_by_2, b_2_by_2]) + dz1_dy = tape.gradient(z1, [a_2_by_2])[0] + dz2_dy = tape.gradient(z2, [a_2_by_2])[0] + self.assertAllEqual(dz1_dy.numpy(), dz2_dy.numpy()) + + # Tests heterogeneous list op + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_IdentityNCorrectResponse(self): + ctx = context.context() + a_2_by_2 = random_ops.random_uniform((2, 2)) + b_2_by_2 = random_ops.random_uniform((2, 2)) + + self.assertAllClose( + array_ops.identity_n([a_2_by_2, b_2_by_2]), + pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "IdentityN", execute.record_gradient, + None, None, [a_2_by_2, b_2_by_2])) + + # Tests heterogeneous list op + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_IdentityNTapeWrite(self): + ctx = context.context() + a_2_by_2 = random_ops.random_uniform((2, 2)) + b_2_by_2 = random_ops.random_uniform((2, 2)) + + with backprop.GradientTape(persistent=True) as tape: + tape.watch(a_2_by_2) + tape.watch(b_2_by_2) + z1 = pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "IdentityN", execute.record_gradient, + None, None, [a_2_by_2, b_2_by_2]) + z2 = array_ops.identity_n([a_2_by_2, b_2_by_2]) + dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0] + dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0] + self.assertAllEqual(dz1_dy.numpy(), dz2_dy.numpy()) @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created @@ -89,20 +145,24 @@ class Tests(test.TestCase): ), "The prototype doesn't contain C code for graph construction" ctx_handle = ctx._handle # pylint: disable=protected-access + # Not enough base params with self.assertRaisesRegexp(ValueError, - "at least 4 items in the input tuple"): + "at least 6 items in the input tuple"): pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name, "Identity") + # Not enough inputs with self.assertRaisesRegexp(ValueError, - "Expected to be at least 5, was 4"): + "Expected to be at least 7, was 6"): pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx_handle, ctx_handle, "Identity", record_gradient_callback) + ctx_handle, ctx_handle, "Identity", backprop._record_gradient, None, + []) + # Bad type with self.assertRaisesRegexp(TypeError, "expected a string for op_name"): pywrap_tensorflow.TFE_Py_FastPathExecute( - ctx_handle, ctx.device_name, ctx_handle, record_gradient_callback, - a_2_by_2) + ctx_handle, ctx.device_name, ctx_handle, backprop._record_gradient, + None, [], a_2_by_2) if __name__ == "__main__": diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 41f55b12af893e3207ad1ffa45098d12b1c4fff6..c519fd557a9319d6ef5522b26198e5b4202917fc 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -604,6 +604,7 @@ py_library( ":metric_keys", ":model_fn", ":prediction_keys", + ":util", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py index 96e4ecd29fbcd4f4335077e9f81c5704ae2b9bec..3e92a77543e3d2162497e9f995f3adc2a01cb4dd 100644 --- a/tensorflow/python/estimator/canned/baseline.py +++ b/tensorflow/python/estimator/canned/baseline.py @@ -57,7 +57,9 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.training import training_util +from tensorflow.python.util.tf_export import tf_export # The default learning rate of 0.3 is a historical artifact of the initial # implementation, but seems a reasonable choice. @@ -172,6 +174,7 @@ def _baseline_model_fn(features, labels, mode, head, optimizer, train_op_fn=train_op_fn) +@tf_export('estimator.BaselineClassifier') class BaselineClassifier(estimator.Estimator): """A classifier that can establish a simple baseline. @@ -220,7 +223,8 @@ class BaselineClassifier(estimator.Estimator): weight_column=None, label_vocabulary=None, optimizer='Ftrl', - config=None): + config=None, + loss_reduction=losses.Reduction.SUM): """Initializes a BaselineClassifier instance. Args: @@ -240,6 +244,8 @@ class BaselineClassifier(estimator.Estimator): optimizer to use for training. If not specified, will use `FtrlOptimizer` with a default learning rate of 0.3. config: `RunConfig` object to configure the runtime settings. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. Returns: A `BaselineClassifier` estimator. @@ -249,11 +255,13 @@ class BaselineClassifier(estimator.Estimator): if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) else: head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access n_classes, weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) def _model_fn(features, labels, mode, config): return _baseline_model_fn( features=features, @@ -269,6 +277,7 @@ class BaselineClassifier(estimator.Estimator): config=config) +@tf_export('estimator.BaselineRegressor') class BaselineRegressor(estimator.Estimator): """A regressor that can establish a simple baseline. @@ -311,7 +320,8 @@ class BaselineRegressor(estimator.Estimator): label_dimension=1, weight_column=None, optimizer='Ftrl', - config=None): + config=None, + loss_reduction=losses.Reduction.SUM): """Initializes a BaselineRegressor instance. Args: @@ -328,13 +338,16 @@ class BaselineRegressor(estimator.Estimator): optimizer to use for training. If not specified, will use `FtrlOptimizer` with a default learning rate of 0.3. config: `RunConfig` object to configure the runtime settings. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. Returns: A `BaselineRegressor` estimator. """ head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access label_dimension=label_dimension, - weight_column=weight_column) + weight_column=weight_column, + loss_reduction=loss_reduction) def _model_fn(features, labels, mode, config): return _baseline_model_fn( features=features, diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py index 96639e88ea4a07e14121049d78f07e03fcb22156..18c955f5a0e998de983b31fc4cc595895e6bbcbd 100644 --- a/tensorflow/python/estimator/canned/baseline_test.py +++ b/tensorflow/python/estimator/canned/baseline_test.py @@ -1075,7 +1075,7 @@ class BaselineClassifierEvaluationTest(test.TestCase): metric_keys.MetricKeys.LABEL_MEAN: 1., metric_keys.MetricKeys.ACCURACY_BASELINE: 1, metric_keys.MetricKeys.AUC: 0., - metric_keys.MetricKeys.AUC_PR: 1., + metric_keys.MetricKeys.AUC_PR: 0.5, } else: # Multi classes: loss = 1 * -log ( softmax(logits)[label] ) @@ -1136,7 +1136,7 @@ class BaselineClassifierEvaluationTest(test.TestCase): metric_keys.MetricKeys.LABEL_MEAN: 0.5, metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, metric_keys.MetricKeys.AUC: 0.5, - metric_keys.MetricKeys.AUC_PR: 0.75, + metric_keys.MetricKeys.AUC_PR: 0.25, } else: # Expand logits since batch_size=2 @@ -1212,7 +1212,7 @@ class BaselineClassifierEvaluationTest(test.TestCase): metric_keys.MetricKeys.ACCURACY_BASELINE: ( max(label_mean, 1-label_mean)), metric_keys.MetricKeys.AUC: 0.5, - metric_keys.MetricKeys.AUC_PR: 2. / (1. + 2.), + metric_keys.MetricKeys.AUC_PR: 0.16666645, } else: # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] ) diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 0f274a23c03426fc431c15ac0a14617a4a65bb79..7043da8de036e5be27d223271c37e065d9ffbcdd 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import training_util +from tensorflow.python.util.tf_export import tf_export # The default learning rate of 0.05 is a historical artifact of the initial # implementation, but seems a reasonable choice. @@ -149,9 +150,7 @@ def _dnn_model_fn(features, config: `RunConfig` object to configure the runtime settings. Returns: - predictions: A dict of `Tensor` objects. - loss: A scalar containing the loss of the step. - train_op: The op for training. + An `EstimatorSpec` instance. Raises: ValueError: If features has the wrong type. @@ -198,6 +197,7 @@ def _dnn_model_fn(features, logits=logits) +@tf_export('estimator.DNNClassifier') class DNNClassifier(estimator.Estimator): """A classifier for TensorFlow DNN models. @@ -358,6 +358,7 @@ class DNNClassifier(estimator.Estimator): warm_start_from=warm_start_from) +@tf_export('estimator.DNNRegressor') class DNNRegressor(estimator.Estimator): """A regressor for TensorFlow DNN models. diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 1a0f4c5c3931a6b41026470f30e7bdd381e5b37a..6d0fb96057ee93964ee3571bae3b878faad88882 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -37,6 +37,7 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training_util +from tensorflow.python.util.tf_export import tf_export # The default learning rates are a historical artifact of the initial # implementation. @@ -116,7 +117,7 @@ def _dnn_linear_combined_model_fn(features, config: `RunConfig` object to configure the runtime settings. Returns: - `ModelFnOps` + An `EstimatorSpec` instance. Raises: ValueError: If both `linear_feature_columns` and `dnn_features_columns` @@ -225,6 +226,7 @@ def _dnn_linear_combined_model_fn(features, logits=logits) +@tf_export('estimator.DNNLinearCombinedClassifier') class DNNLinearCombinedClassifier(estimator.Estimator): """An estimator for TensorFlow Linear and DNN joined classification models. @@ -405,6 +407,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator): warm_start_from=warm_start_from) +@tf_export('estimator.DNNLinearCombinedRegressor') class DNNLinearCombinedRegressor(estimator.Estimator): """An estimator for TensorFlow Linear and DNN joined models for regression. diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py index 706575985ff9e0fef94f110825ec11af33031ea3..cbae43e4f7fef0271de20a4ec54449989455d4bd 100644 --- a/tensorflow/python/estimator/canned/dnn_testing_utils.py +++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py @@ -1041,7 +1041,7 @@ class BaseDNNClassifierEvaluateTest(object): # There is no good way to calculate AUC for only two data points. But # that is what the algorithm returns. metric_keys.MetricKeys.AUC: 0.5, - metric_keys.MetricKeys.AUC_PR: 0.75, + metric_keys.MetricKeys.AUC_PR: 0.25, ops.GraphKeys.GLOBAL_STEP: global_step }, dnn_classifier.evaluate(input_fn=_input_fn, steps=1)) diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index 94a5d3a342dd7bad49d5fb4b91166c67a2705ff3..cb9e3fc6ca116ac0f48a37cea92fa4119754f324 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -24,6 +24,7 @@ import collections import six from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export_output @@ -371,6 +372,64 @@ def _check_logits_final_dim(logits, expected_logits_dimension): return array_ops.identity(logits, name=scope) +def _validate_loss_fn_args(loss_fn): + """Validates loss_fn arguments. + + Required arguments: labels, logits. + Optional arguments: features. + + Args: + loss_fn: The loss function. + Raises: + ValueError: If the signature is unexpected. + """ + loss_fn_args = util.fn_args(loss_fn) + for required_arg in ['labels', 'logits']: + if required_arg not in loss_fn_args: + raise ValueError( + 'loss_fn must contain argument: {}. ' + 'Given arguments: {}'.format(required_arg, loss_fn_args)) + invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features'])) + if invalid_args: + raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args)) + + +def _call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1): + """Calls loss_fn and checks the returned shape. + + Args: + loss_fn: The loss function. + labels: Processed labels Tensor. + logits: Logits Tensor of shape [D0, D1, ... DN, logits_dimension]. + features: Features dict. + expected_loss_dim: The expected last dimension of loss Tensor. + Returns: + Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim]. + """ + loss_fn_args = util.fn_args(loss_fn) + kwargs = {} + if 'features' in loss_fn_args: + kwargs['features'] = features + with ops.name_scope( + None, 'call_loss_fn', + values=[labels, logits] + list(six.itervalues(features))): + unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs) + logits_shape = array_ops.shape(logits, name='logits_shape') + expected_loss_shape = array_ops.concat( + [logits_shape[:-1], [expected_loss_dim]], axis=0, + name='expected_loss_shape') + loss_shape = array_ops.shape(unweighted_loss, name='loss_shape') + check_loss_shape_op = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(loss_shape, expected_loss_shape)), + data=[ + 'loss_fn must return Tensor of shape ' + '[D0, D1, ... DN, {}]. '.format(expected_loss_dim), + 'logits_shape: ', logits_shape, 'loss_shape: ', loss_shape], + name='check_loss_shape') + with ops.control_dependencies([check_loss_shape_op]): + return array_ops.identity(unweighted_loss) + + def _indicator_labels_mean(labels, weights=None, name=None): with ops.name_scope(name, 'labels_mean', (labels, weights)) as scope: labels = math_ops.to_float(labels, name='labels') @@ -467,6 +526,7 @@ def _multi_class_head_with_softmax_cross_entropy_loss( weight_column=None, label_vocabulary=None, loss_reduction=losses.Reduction.SUM, + loss_fn=None, name=None): """Creates a '_Head' for multi class classification. @@ -485,6 +545,12 @@ def _multi_class_head_with_softmax_cross_entropy_loss( labels have shape `[batch_size, 1]`, the loss is the weighted sum over `batch_size`. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with + shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to + the input labels before passing them to `loss_fn`. + Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`). @@ -499,6 +565,7 @@ def _multi_class_head_with_softmax_cross_entropy_loss( `label_vocabulary` is not provided but labels are strings. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -517,11 +584,14 @@ def _multi_class_head_with_softmax_cross_entropy_loss( if (loss_reduction not in losses.Reduction.all() or loss_reduction == losses.Reduction.NONE): raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) + if loss_fn: + _validate_loss_fn_args(loss_fn) return _MultiClassHeadWithSoftmaxCrossEntropyLoss( n_classes=n_classes, weight_column=weight_column, label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) @@ -533,6 +603,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): weight_column=None, label_vocabulary=None, loss_reduction=losses.Reduction.SUM, + loss_fn=None, name=None): if (n_classes is None) or (n_classes <= 2): raise ValueError('n_classes must be > 2: %s.' % n_classes) @@ -540,6 +611,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): self._weight_column = weight_column self._label_vocabulary = label_vocabulary self._loss_reduction = loss_reduction + self._loss_fn = loss_fn self._name = name @property @@ -602,10 +674,15 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): labels = _check_dense_labels_match_logits_and_reshape( labels=labels, logits=logits, expected_labels_dimension=1) label_ids = self._label_ids(labels) - unweighted_loss = losses.sparse_softmax_cross_entropy( - labels=label_ids, logits=logits, reduction=losses.Reduction.NONE) - # Restore the squeezed dim, so unweighted_loss matches the weights shape. - unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1) + if self._loss_fn: + unweighted_loss = _call_loss_fn( + loss_fn=self._loss_fn, labels=label_ids, logits=logits, + features=features, expected_loss_dim=1) + else: + unweighted_loss = losses.sparse_softmax_cross_entropy( + labels=label_ids, logits=logits, reduction=losses.Reduction.NONE) + # Restore the squeezed dim, so unweighted_loss matches the weights shape. + unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1) weights = _get_weights_and_check_match_logits( features=features, weight_column=self._weight_column, logits=logits) training_loss = losses.compute_weighted_loss( @@ -734,8 +811,12 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): def _binary_logistic_head_with_sigmoid_cross_entropy_loss( - weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, name=None): + weight_column=None, + thresholds=None, + label_vocabulary=None, + loss_reduction=losses.Reduction.SUM, + loss_fn=None, + name=None): """Creates a `_Head` for single label binary classification. This head uses `sigmoid_cross_entropy_with_logits` loss. @@ -755,6 +836,12 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( labels have shape `[batch_size, 1]`, the loss is the weighted sum over `batch_size`. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with + shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to + the input labels before passing them to `loss_fn`. + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -772,6 +859,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( is not provided but labels are strings. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -795,11 +883,14 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( if (loss_reduction not in losses.Reduction.all() or loss_reduction == losses.Reduction.NONE): raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) + if loss_fn: + _validate_loss_fn_args(loss_fn) return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss( weight_column=weight_column, thresholds=thresholds, label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) @@ -811,11 +902,13 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): thresholds=None, label_vocabulary=None, loss_reduction=losses.Reduction.SUM, + loss_fn=None, name=None): self._weight_column = weight_column self._thresholds = thresholds self._label_vocabulary = label_vocabulary self._loss_reduction = loss_reduction + self._loss_fn = loss_fn self._name = name @property @@ -916,8 +1009,13 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): name='class_id_lookup').lookup(labels) labels = math_ops.to_float(labels) labels = _assert_range(labels, 2) - unweighted_loss = nn.sigmoid_cross_entropy_with_logits( - labels=labels, logits=logits) + if self._loss_fn: + unweighted_loss = _call_loss_fn( + loss_fn=self._loss_fn, labels=labels, logits=logits, + features=features, expected_loss_dim=1) + else: + unweighted_loss = nn.sigmoid_cross_entropy_with_logits( + labels=labels, logits=logits) weights = _get_weights_and_check_match_logits( features=features, weight_column=self._weight_column, logits=logits) training_loss = losses.compute_weighted_loss( @@ -1057,6 +1155,7 @@ def _regression_head_with_mean_squared_error_loss( weight_column=None, label_dimension=1, loss_reduction=losses.Reduction.SUM, + loss_fn=None, name=None): """Creates a `_Head` for regression using the `mean_squared_error` loss. @@ -1075,6 +1174,10 @@ def _regression_head_with_mean_squared_error_loss( `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN, label_dimension]`. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[D0, D1, ... DN, label_dimension]`. + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -1085,6 +1188,7 @@ def _regression_head_with_mean_squared_error_loss( `[batch_size, label_dimension]`). loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -1097,10 +1201,13 @@ def _regression_head_with_mean_squared_error_loss( if (loss_reduction not in losses.Reduction.all() or loss_reduction == losses.Reduction.NONE): raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) + if loss_fn: + _validate_loss_fn_args(loss_fn) return _RegressionHeadWithMeanSquaredErrorLoss( weight_column=weight_column, label_dimension=label_dimension, loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) @@ -1112,6 +1219,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): label_dimension, weight_column=None, loss_reduction=losses.Reduction.SUM, + loss_fn=None, name=None): """`Head` for regression.""" if label_dimension < 1: @@ -1119,6 +1227,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): self._logits_dimension = label_dimension self._weight_column = weight_column self._loss_reduction = loss_reduction + self._loss_fn = loss_fn self._name = name @property @@ -1137,8 +1246,13 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): labels=labels, logits=logits, expected_labels_dimension=self._logits_dimension) labels = math_ops.to_float(labels) - unweighted_loss = losses.mean_squared_error( - labels=labels, predictions=logits, reduction=losses.Reduction.NONE) + if self._loss_fn: + unweighted_loss = _call_loss_fn( + loss_fn=self._loss_fn, labels=labels, logits=logits, + features=features, expected_loss_dim=self._logits_dimension) + else: + unweighted_loss = losses.mean_squared_error( + labels=labels, predictions=logits, reduction=losses.Reduction.NONE) weights = _get_weights_and_check_match_logits( features=features, weight_column=self._weight_column, logits=logits, allow_per_logit_weights=True) diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 4e871e8f375f346bfd1b0be2cade97c34871f31c..c09f88262af3cdbb952a2ebadf2b2bdaf2a651cb 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -111,6 +111,41 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): head_lib._multi_class_head_with_softmax_cross_entropy_loss( n_classes=3, loss_reduction=losses.Reduction.NONE) + def test_loss_fn_arg_labels_missing(self): + def _loss_fn(logits): + del logits # Unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn must contain argument: labels\. ' + r'Given arguments: \(\'logits\',\)'): + head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_fn=_loss_fn) + + def test_loss_fn_arg_logits_missing(self): + def _loss_fn(labels): + del labels # unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn must contain argument: logits\. ' + r'Given arguments: \(\'labels\',\)'): + head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_fn=_loss_fn) + + def test_loss_fn_arg_features_ok(self): + def _loss_fn(labels, logits, features): + del labels, logits, features # Unused + head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_fn=_loss_fn) + + def test_loss_fn_arg_invalid(self): + def _loss_fn(labels, logits, name=None): + del labels, logits, name # Unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn has unexpected args: \[\'name\'\]'): + head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_fn=_loss_fn) + def test_invalid_logits_shape(self): n_classes = 3 head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes) @@ -406,6 +441,56 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): self.assertAllClose( expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2) + def test_eval_create_loss_loss_fn(self): + """Tests head.create_loss for eval mode and custom loss_fn.""" + loss = np.array([[1.], [2.]], dtype=np.float32) + logits_input = np.array([[-10., 10., 0.], [-15., 10., 0]], dtype=np.float32) + labels_input = np.array([[1], [2]], dtype=np.int64) + def _loss_fn(labels, logits): + check_labels = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(labels, labels_input)), + data=[labels]) + check_logits = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(logits, logits_input)), + data=[logits]) + with ops.control_dependencies([check_labels, check_logits]): + return constant_op.constant(loss) + head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_fn=_loss_fn) + + actual_training_loss = head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits_input, + labels=labels_input)[0] + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + self.assertAllClose(np.sum(loss), actual_training_loss.eval()) + + def test_eval_create_loss_loss_fn_wrong_shape(self): + """Tests custom loss_fn that returns Tensor of unexpected shape.""" + loss = np.array([1., 2.], dtype=np.float32) + def _loss_fn(labels, logits): + del labels, logits # Unused + return constant_op.constant(loss) + head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_fn=_loss_fn) + + logits = np.array([[-10., 10., 0.], [-15., 10., 0.]], dtype=np.float32) + labels = np.array([[1], [2]], dtype=np.int64) + actual_training_loss = head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels)[0] + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] ' + r'\[logits_shape: \] \[2 3\] \[loss_shape: \] \[2\]'): + actual_training_loss.eval() + def test_eval_labels_none(self): """Tests that error is raised when labels is None.""" head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( @@ -1204,6 +1289,41 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( loss_reduction=losses.Reduction.NONE) + def test_loss_fn_arg_labels_missing(self): + def _loss_fn(logits): + del logits # Unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn must contain argument: labels\. ' + r'Given arguments: \(\'logits\',\)'): + head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_fn=_loss_fn) + + def test_loss_fn_arg_logits_missing(self): + def _loss_fn(labels): + del labels # unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn must contain argument: logits\. ' + r'Given arguments: \(\'labels\',\)'): + head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_fn=_loss_fn) + + def test_loss_fn_arg_features_ok(self): + def _loss_fn(labels, logits, features): + del labels, logits, features # Unused + head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_fn=_loss_fn) + + def test_loss_fn_arg_invalid(self): + def _loss_fn(labels, logits, name=None): + del labels, logits, name # Unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn has unexpected args: \[\'name\'\]'): + head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_fn=_loss_fn) + def test_invalid_logits_shape(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() self.assertEqual(1, head.logits_dimension) @@ -1438,7 +1558,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.LABEL_MEAN: 2./2, keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., - keys.AUC_PR: 1., + keys.AUC_PR: 0.74999905, } # Assert spec contains expected tensors. @@ -1516,7 +1636,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.LABEL_MEAN: 2./2, keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., - keys.AUC_PR: 1., + keys.AUC_PR: 0.75, } # Assert predictions, loss, and metrics. @@ -1621,7 +1741,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.LABEL_MEAN: 2./2, keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., - keys.AUC_PR: 1., + keys.AUC_PR: 0.74999905, keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 1., keys.PRECISION_AT_THRESHOLD % thresholds[0]: 1., keys.RECALL_AT_THRESHOLD % thresholds[0]: 1., @@ -1699,6 +1819,56 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval()) self.assertAllClose(expected_weights, actual_weights) + def test_eval_create_loss_loss_fn(self): + """Tests head.create_loss for eval mode and custom loss_fn.""" + loss = np.array([[1.], [2.]], dtype=np.float32) + logits_input = np.array([[-10.], [10.]], dtype=np.float32) + labels_input = np.array([[1], [0]], dtype=np.int64) + def _loss_fn(labels, logits): + check_labels = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(labels, labels_input)), + data=[labels]) + check_logits = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(logits, logits_input)), + data=[logits]) + with ops.control_dependencies([check_labels, check_logits]): + return constant_op.constant(loss) + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_fn=_loss_fn) + + actual_training_loss = head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits_input, + labels=labels_input)[0] + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + self.assertAllClose(np.sum(loss), actual_training_loss.eval()) + + def test_eval_create_loss_loss_fn_wrong_shape(self): + """Tests custom loss_fn that returns Tensor of unexpected shape.""" + loss = np.array([1., 2.], dtype=np.float32) + def _loss_fn(labels, logits): + del labels, logits # Unused + return constant_op.constant(loss) + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_fn=_loss_fn) + + logits = np.array([[-10.], [10.]], dtype=np.float32) + labels = np.array([[1], [0]], dtype=np.int64) + actual_training_loss = head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels)[0] + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] ' + r'\[logits_shape: \] \[2 1\] \[loss_shape: \] \[2\]'): + actual_training_loss.eval() + def test_train_labels_none(self): """Tests that error is raised when labels is None.""" head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() @@ -2018,7 +2188,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.LABEL_MEAN: expected_label_mean, keys.ACCURACY_BASELINE: 1 - expected_label_mean, keys.AUC: .45454565, - keys.AUC_PR: .6737757325172424, + keys.AUC_PR: .21923049, } # Assert spec contains expected tensors. @@ -2317,7 +2487,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): # We cannot reliably calculate AUC with only 4 data points, but the # values should not change because of backwards-compatibility. keys.AUC: 0.5222, - keys.AUC_PR: 0.7341, + keys.AUC_PR: 0.5119, } tol = 1e-2 @@ -2355,6 +2525,37 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): head_lib._regression_head_with_mean_squared_error_loss( loss_reduction=losses.Reduction.NONE) + def test_loss_fn_arg_labels_missing(self): + def _loss_fn(logits): + del logits # Unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn must contain argument: labels\. ' + r'Given arguments: \(\'logits\',\)'): + head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn) + + def test_loss_fn_arg_logits_missing(self): + def _loss_fn(labels): + del labels # unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn must contain argument: logits\. ' + r'Given arguments: \(\'labels\',\)'): + head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn) + + def test_loss_fn_arg_features_ok(self): + def _loss_fn(labels, logits, features): + del labels, logits, features # Unused + head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn) + + def test_loss_fn_arg_invalid(self): + def _loss_fn(labels, logits, name=None): + del labels, logits, name # Unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn has unexpected args: \[\'name\'\]'): + head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn) + def test_invalid_logits(self): head = head_lib._regression_head_with_mean_squared_error_loss( label_dimension=3) @@ -2530,6 +2731,56 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): # loss = [(43-45)^2, (44-41)] = [4, 9] self.assertAllClose(13., training_loss.eval()) + def test_eval_create_loss_loss_fn(self): + """Tests head.create_loss for eval mode and custom loss_fn.""" + loss = np.array([[0., 1.], [2., 3.]], dtype=np.float32) + logits_input = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32) + labels_input = np.array([[1., 0.], [2., -1.]], dtype=np.float32) + def _loss_fn(labels, logits): + check_labels = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(labels, labels_input)), + data=[labels]) + check_logits = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(logits, logits_input)), + data=[logits]) + with ops.control_dependencies([check_labels, check_logits]): + return constant_op.constant(loss) + head = head_lib._regression_head_with_mean_squared_error_loss( + label_dimension=2, loss_fn=_loss_fn) + + actual_training_loss = head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits_input, + labels=labels_input)[0] + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + self.assertAllClose(np.sum(loss), actual_training_loss.eval()) + + def test_eval_create_loss_loss_fn_wrong_shape(self): + """Tests custom loss_fn that returns Tensor of unexpected shape.""" + loss = np.array([[1.], [2.]], dtype=np.float32) + def _loss_fn(labels, logits): + del labels, logits # Unused + return constant_op.constant(loss) + head = head_lib._regression_head_with_mean_squared_error_loss( + label_dimension=2, loss_fn=_loss_fn) + + logits = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32) + labels = np.array([[1., 0.], [2., -1.]], dtype=np.float32) + actual_training_loss = head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels)[0] + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 2\]\. \] ' + r'\[logits_shape: \] \[2 2\] \[loss_shape: \] \[2 1\]'): + actual_training_loss.eval() + def test_eval_labels_none(self): """Tests that error is raised when labels is None.""" head = head_lib._regression_head_with_mean_squared_error_loss() diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index a5b1172e729240a2ea02fa1d4330420786c2686c..a2f24ef27044680fe93b176b5207593165d0d109 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -34,6 +34,7 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import ftrl from tensorflow.python.training import training_util +from tensorflow.python.util.tf_export import tf_export # The default learning rate of 0.2 is a historical artifact of the initial @@ -170,6 +171,7 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, logits=logits) +@tf_export('estimator.LinearClassifier') class LinearClassifier(estimator.Estimator): """Linear classifier model. @@ -322,6 +324,7 @@ class LinearClassifier(estimator.Estimator): warm_start_from=warm_start_from) +@tf_export('estimator.LinearRegressor') class LinearRegressor(estimator.Estimator): """An estimator for TensorFlow Linear regression problems. diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py index 3e9183cf1b633757074377472e9b4cac953e04a1..e88fcbbd2e0e3617dde428662e58b1d86c4eddd0 100644 --- a/tensorflow/python/estimator/canned/linear_testing_utils.py +++ b/tensorflow/python/estimator/canned/linear_testing_utils.py @@ -1342,7 +1342,7 @@ class BaseLinearClassifierEvaluationTest(object): metric_keys.MetricKeys.LABEL_MEAN: 1., metric_keys.MetricKeys.ACCURACY_BASELINE: 1, metric_keys.MetricKeys.AUC: 0., - metric_keys.MetricKeys.AUC_PR: 1., + metric_keys.MetricKeys.AUC_PR: 0.5, } else: # Multi classes: loss = 1 * -log ( soft_max(logits)[label] ) diff --git a/tensorflow/python/estimator/canned/parsing_utils.py b/tensorflow/python/estimator/canned/parsing_utils.py index f153272947ca427b25b00e6df4741d7ada5790df..74e5e5a1bed80229c68daa3ff33ee7af4004bf47 100644 --- a/tensorflow/python/estimator/canned/parsing_utils.py +++ b/tensorflow/python/estimator/canned/parsing_utils.py @@ -23,8 +23,10 @@ import six from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import dtypes from tensorflow.python.ops import parsing_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export('estimator.classifier_parse_example_spec') def classifier_parse_example_spec(feature_columns, label_key, label_dtype=dtypes.int64, @@ -164,6 +166,7 @@ def classifier_parse_example_spec(feature_columns, return parsing_spec +@tf_export('estimator.regressor_parse_example_spec') def regressor_parse_example_spec(feature_columns, label_key, label_dtype=dtypes.float32, diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 96555b5e03c7a291480b3c30fe1f2c641c5c75e1..1167b3834eb6a79abf670f629ec2cbc37957d191 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -57,12 +57,14 @@ from tensorflow.python.training import training_util from tensorflow.python.util import compat from tensorflow.python.util import compat_internal from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export _VALID_MODEL_FN_ARGS = set( ['features', 'labels', 'mode', 'params', 'self', 'config']) +@tf_export('estimator.Estimator') class Estimator(object): """Estimator class to train and evaluate TensorFlow models. @@ -425,7 +427,8 @@ class Estimator(object): input_fn, predict_keys=None, hooks=None, - checkpoint_path=None): + checkpoint_path=None, + yield_single_examples=True): """Yields predictions for given features. Args: @@ -451,13 +454,18 @@ class Estimator(object): inside the prediction call. checkpoint_path: Path of a specific checkpoint to predict. If `None`, the latest checkpoint in `model_dir` is used. + yield_single_examples: If False, yield the whole batch as returned by the + model_fn instead of decomposing the batch into individual elements. This + is useful if model_fn return some tensor with first dimension not + equal to the batch size Yields: Evaluated values of `predictions` tensors. Raises: ValueError: Could not find a trained model in model_dir. - ValueError: if batch length of predictions are not same. + ValueError: if batch length of predictions are not same and + yield_single_examples is True. ValueError: If there is a conflict between `predict_keys` and `predictions`. For example if `predict_keys` is not `None` but `EstimatorSpec.predictions` is not a `dict`. @@ -478,16 +486,21 @@ class Estimator(object): estimator_spec = self._call_model_fn( features, None, model_fn_lib.ModeKeys.PREDICT, self.config) predictions = self._extract_keys(estimator_spec.predictions, predict_keys) + all_hooks = list(input_hooks) + all_hooks.extend(hooks) + all_hooks.extend(list(estimator_spec.prediction_hooks or [])) with training.MonitoredSession( session_creator=training.ChiefSessionCreator( checkpoint_filename_with_path=checkpoint_path, master=self._config.master, scaffold=estimator_spec.scaffold, config=self._session_config), - hooks=input_hooks + hooks) as mon_sess: + hooks=all_hooks) as mon_sess: while not mon_sess.should_stop(): preds_evaluated = mon_sess.run(predictions) - if not isinstance(predictions, dict): + if not yield_single_examples: + yield preds_evaluated + elif not isinstance(predictions, dict): for pred in preds_evaluated: yield pred else: @@ -499,9 +512,11 @@ class Estimator(object): def _assert_members_are_not_overridden(self): """Asserts members of `Estimator` are not overridden.""" - allowed_overrides = set(['_call_input_fn', '_create_global_step', - '_convert_train_steps_to_hooks', - '_convert_eval_steps_to_hooks']) + allowed_overrides = set([ + '_call_input_fn', '_create_global_step', + '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks', + '_tf_api_names' + ]) estimator_members = set([m for m in Estimator.__dict__.keys() if not m.startswith('__')]) subclass_members = set(self.__class__.__dict__.keys()) @@ -610,7 +625,6 @@ class Estimator(object): sharded=True) saver_for_restore.restore(session, checkpoint_path) - # TODO(b/36111876): replace legacy_init_op with main_op mechanism # pylint: disable=protected-access local_init_op = ( estimator_spec.scaffold.local_init_op or @@ -997,7 +1011,7 @@ def _get_replica_device_setter(config): 'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable', 'MutableHashTableV2', 'MutableHashTableOfTensors', 'MutableHashTableOfTensorsV2', 'MutableDenseHashTable', - 'MutableDenseHashTableV2' + 'MutableDenseHashTableV2', 'VarHandleOp' ] if config.task_type: @@ -1100,7 +1114,7 @@ def _write_dict_to_summary(output_dir, isinstance(dictionary[key], np.int32) or isinstance(dictionary[key], int)): summary_proto.value.add(tag=key, simple_value=int(dictionary[key])) - elif isinstance(dictionary[key], six.string_types): + elif isinstance(dictionary[key], six.binary_type): try: summ = summary_pb2.Summary.FromString(dictionary[key]) for i, _ in enumerate(summ.value): diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 833f3dcac3b97962c967cba9ac7ab53a3b9c61f1..b0a7752ec74913c959fc176e3eb9001f7418b4a2 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -80,18 +80,18 @@ def dummy_model_fn(features, labels, params): _, _, _ = features, labels, params -def check_eventfile_for_keyword(keyword, est): +def check_eventfile_for_keyword(keyword, dir_): """Checks event files for the keyword.""" writer_cache.FileWriterCache.clear() # Get last Event written. - event_paths = glob.glob(os.path.join(est.model_dir, 'events*')) + event_paths = glob.glob(os.path.join(dir_, 'events*')) last_event = None for last_event in summary_iterator.summary_iterator(event_paths[-1]): if last_event.summary is not None: - if last_event.summary.value: - if keyword in last_event.summary.value[0].tag: + for value in last_event.summary.value: + if keyword in value.tag: return True return False @@ -610,7 +610,7 @@ class EstimatorTrainTest(test.TestCase): # Make sure nothing is stuck in limbo. writer_cache.FileWriterCache.clear() - if check_eventfile_for_keyword('loss', est): + if check_eventfile_for_keyword('loss', est.model_dir): return self.fail('{} should be part of reported summaries.'.format('loss')) @@ -1290,8 +1290,9 @@ class EstimatorEvaluateTest(test.TestCase): # Make sure nothing is stuck in limbo. writer_cache.FileWriterCache.clear() - # Get last Event written. - if check_eventfile_for_keyword('image', est): + # Get last evaluation Event written. + if check_eventfile_for_keyword('image', + os.path.join(est.model_dir, 'eval')): return self.fail('{} should be part of reported summaries.'.format('image')) @@ -1355,6 +1356,25 @@ class EstimatorPredictTest(test.TestCase): est.train(dummy_input_fn, steps=1) self.assertEqual(10., next(est.predict(dummy_input_fn))) + def test_predictionhooks_are_used(self): + hook = test.mock.MagicMock( + wraps=training.SessionRunHook(), spec=training.SessionRunHook) + + def _model_fn_hooks(features, labels, mode): + _, _ = features, labels + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + predictions=constant_op.constant([[10.]]), + prediction_hooks=[hook]) + + est = estimator.Estimator(model_fn=_model_fn_hooks) + est.train(dummy_input_fn, steps=1) + self.assertFalse(hook.begin.called) + next(est.predict(dummy_input_fn)) + self.assertTrue(hook.begin.called) + def test_warn_if_no_queue_runner(self): def _model_fn(features, labels, mode): @@ -1453,6 +1473,27 @@ class EstimatorPredictTest(test.TestCase): 'Batch length of predictions should be same'): next(est.predict(dummy_input_fn)) + def test_iterate_batches(self): + + def _model_fn(features, labels, mode): + _, _ = features, labels + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + predictions={ + # First dim is different but the prediction should still work + 'y1': array_ops.zeros(shape=[3]), + 'y2': array_ops.zeros(shape=[5, 3]) + }) + + est = estimator.Estimator(model_fn=_model_fn) + est.train(dummy_input_fn, steps=1) + + predictions = next(est.predict(dummy_input_fn, yield_single_examples=False)) + self.assertAllEqual(predictions['y1'].shape, [3]) + self.assertAllEqual(predictions['y2'].shape, [5, 3]) + def test_predict_keys_defined_for_tensor(self): def _model_fn(features, labels, mode): diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 51075731ddc52a55799958c3bfa6140f77404541..83251c79fc561e16ebddb638668b92b3c69b8af4 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -36,12 +36,14 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export _SINGLE_FEATURE_DEFAULT_NAME = 'feature' _SINGLE_RECEIVER_DEFAULT_NAME = 'input' +@tf_export('estimator.export.ServingInputReceiver') class ServingInputReceiver(collections.namedtuple( 'ServingInputReceiver', ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])): @@ -118,6 +120,7 @@ class ServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives=receiver_tensors_alternatives) +@tf_export('estimator.export.build_parsing_serving_input_receiver_fn') def build_parsing_serving_input_receiver_fn(feature_spec, default_batch_size=None): """Build a serving_input_receiver_fn expecting fed tf.Examples. @@ -146,6 +149,7 @@ def build_parsing_serving_input_receiver_fn(feature_spec, return serving_input_receiver_fn +@tf_export('estimator.export.build_raw_serving_input_receiver_fn') def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """Build a serving_input_receiver_fn expecting feature Tensors. diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 863af6d41d985043542b03375372fe564c283b82..87b964be37197dac99b8ce4398cbdaf3b4989c7f 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -26,8 +26,10 @@ import six from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.util.tf_export import tf_export +@tf_export('estimator.export.ExportOutput') class ExportOutput(object): """Represents an output of a model that can be served. @@ -50,6 +52,7 @@ class ExportOutput(object): pass +@tf_export('estimator.export.ClassificationOutput') class ClassificationOutput(ExportOutput): """Represents the output of a classification head. @@ -118,6 +121,7 @@ class ClassificationOutput(ExportOutput): examples, self.classes, self.scores) +@tf_export('estimator.export.RegressionOutput') class RegressionOutput(ExportOutput): """Represents the output of a regression head.""" @@ -153,6 +157,7 @@ class RegressionOutput(ExportOutput): _SINGLE_OUTPUT_DEFAULT_NAME = 'output' +@tf_export('estimator.export.PredictOutput') class PredictOutput(ExportOutput): """Represents the output of a generic prediction head. diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py index ba522f396d0eda1bb3d13b21acfddcc3d593e21b..a3f04626d1e5ed7ca7fb09a5dcc2457a0cf5ab82 100644 --- a/tensorflow/python/estimator/exporter.py +++ b/tensorflow/python/estimator/exporter.py @@ -25,8 +25,10 @@ from tensorflow.python.estimator import gc from tensorflow.python.framework import errors_impl from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging +from tensorflow.python.util.tf_export import tf_export +@tf_export('estimator.Exporter') class Exporter(object): """A class representing a type of model export.""" @@ -123,6 +125,7 @@ class _SavedModelExporter(Exporter): return export_result +@tf_export('estimator.FinalExporter') class FinalExporter(Exporter): """This class exports the serving graph and checkpoints in the end. @@ -174,6 +177,7 @@ class FinalExporter(Exporter): is_the_final_export) +@tf_export('estimator.LatestExporter') class LatestExporter(Exporter): """This class regularly exports the serving graph and checkpoints. diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py index c4c2e30e8771c5cb1e492fed751c71583dcf477b..a6f471291008e3c27dea1aeea5865e334f76e5c8 100644 --- a/tensorflow/python/estimator/inputs/numpy_io.py +++ b/tensorflow/python/estimator/inputs/numpy_io.py @@ -24,6 +24,7 @@ import numpy as np from six import string_types from tensorflow.python.estimator.inputs.queues import feeding_functions +from tensorflow.python.util.tf_export import tf_export # Key name to pack the target into dict of `features`. See # `_get_unique_target_key` for details. @@ -86,6 +87,7 @@ def _validate_and_convert_features(x): return ordered_dict_data +@tf_export('estimator.inputs.numpy_input_fn') def numpy_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py index 90d6145377d8f931b94793f8a912f77f1620f16e..bd06843021f47f81fc0c22d0fcee43530dc10098 100644 --- a/tensorflow/python/estimator/inputs/pandas_io.py +++ b/tensorflow/python/estimator/inputs/pandas_io.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.estimator.inputs.queues import feeding_functions +from tensorflow.python.util.tf_export import tf_export try: # pylint: disable=g-import-not-at-top @@ -34,6 +35,7 @@ except ImportError: HAS_PANDAS = False +@tf_export('estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py index 75c0e61d47b37110b14aa57f6a185cab822a70bb..8e5d8141a1a15d8cb28aefc0f24c02495337245d 100644 --- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py +++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py @@ -47,10 +47,9 @@ except ImportError: def _fill_array(arr, seq, fillvalue=0): - """ - Recursively fills padded arr with elements from seq. - If length of seq is less than arr padded length, fillvalue used. + """Recursively fills padded arr with elements from seq. + If length of seq is less than arr padded length, fillvalue used. Args: arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len]. seq: Non-padded list of data sampels of shape @@ -84,28 +83,30 @@ def _pad_if_needed(batch_key_item, fillvalue=0): Raises: ValueError if data samples have different shapes (except last padded dim). """ - shapes = [seq.shape[:-1] if len(seq.shape) > 0 else -1 - for seq in batch_key_item] + shapes = [ + seq.shape[:-1] if len(seq.shape) > 0 else -1 for seq in batch_key_item + ] if not all(shapes[0] == x for x in shapes): raise ValueError("Array shapes must match.") - last_length = [seq.shape[-1] if len(seq.shape) > 0 else 0 - for seq in batch_key_item] + last_length = [ + seq.shape[-1] if len(seq.shape) > 0 else 0 for seq in batch_key_item + ] if all([x == last_length[0] for x in last_length]): return batch_key_item batch_size = len(batch_key_item) max_sequence_length = max(last_length) result_batch = np.zeros( - shape=[batch_size] + list(shapes[0]) + [max_sequence_length], - dtype=batch_key_item[0].dtype) + shape=[batch_size] + list(shapes[0]) + [max_sequence_length], + dtype=batch_key_item[0].dtype) _fill_array(result_batch, batch_key_item, fillvalue) return result_batch -def _get_integer_indices_for_next_batch( - batch_indices_start, batch_size, epoch_end, array_length, - current_epoch, total_epochs): +def _get_integer_indices_for_next_batch(batch_indices_start, batch_size, + epoch_end, array_length, current_epoch, + total_epochs): """Returns the integer indices for next batch. If total epochs is not None and current epoch is the final epoch, the end @@ -135,8 +136,9 @@ def _get_integer_indices_for_next_batch( "Already emitted %s epochs." % current_epoch) batch_indices_end = batch_indices_start + batch_size - batch_indices = [j % array_length for j in - range(batch_indices_start, batch_indices_end)] + batch_indices = [ + j % array_length for j in range(batch_indices_start, batch_indices_end) + ] epoch_end_indices = [i for i, x in enumerate(batch_indices) if x == epoch_end] current_epoch += len(epoch_end_indices) @@ -320,16 +322,20 @@ class _GeneratorFeedFn(object): raise KeyError("key mismatch between dicts emitted by GenFun " "Expected {} keys; got {}".format( self._keys, data_row.keys())) - list_dict.setdefault(self._col_placeholders[index], - list()).append(data_row[key]) + list_dict.setdefault(self._col_placeholders[index], list()).append( + data_row[key]) list_dict_size += 1 if self._pad_value is not None: - feed_dict = {key: np.asarray(_pad_if_needed(item, self._pad_value)) - for key, item in list(list_dict.items())} + feed_dict = { + key: np.asarray(_pad_if_needed(item, self._pad_value)) + for key, item in list(list_dict.items()) + } else: - feed_dict = {key: np.asarray(item) - for key, item in list(list_dict.items())} + feed_dict = { + key: np.asarray(item) + for key, item in list(list_dict.items()) + } return feed_dict @@ -382,9 +388,8 @@ def _enqueue_data(data, queue_shapes = [(), data.shape[1:]] get_feed_fn = _ArrayFeedFn elif isinstance(data, collections.OrderedDict): - types = [dtypes.int64] + [ - dtypes.as_dtype(col.dtype) for col in data.values() - ] + types = [dtypes.int64 + ] + [dtypes.as_dtype(col.dtype) for col in data.values()] queue_shapes = [()] + [col.shape[1:] for col in data.values()] get_feed_fn = _OrderedDictNumpyFeedFn elif isinstance(data, tp.FunctionType): @@ -447,11 +452,11 @@ def _enqueue_data(data, seed=seed) elif pad_data: min_after_dequeue = 0 # just for the summary text - queue_shapes = list(map( - lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x, - queue_shapes)) + queue_shapes = list( + map(lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x, + queue_shapes)) queue = data_flow_ops.PaddingFIFOQueue( - capacity, dtypes=types, shapes=queue_shapes) + capacity, dtypes=types, shapes=queue_shapes) else: min_after_dequeue = 0 # just for the summary text queue = data_flow_ops.FIFOQueue( @@ -470,31 +475,35 @@ def _enqueue_data(data, if not pad_data: feed_fns.append( - get_feed_fn( - placeholders, - data, - enqueue_size, - random_start=shuffle, - seed=seed_i, - num_epochs=num_epochs)) + get_feed_fn( + placeholders, + data, + enqueue_size, + random_start=shuffle, + seed=seed_i, + num_epochs=num_epochs)) else: feed_fns.append( - get_feed_fn( - placeholders, - data, - enqueue_size, - random_start=shuffle, - seed=seed_i, - num_epochs=num_epochs, - pad_value=pad_value)) + get_feed_fn( + placeholders, + data, + enqueue_size, + random_start=shuffle, + seed=seed_i, + num_epochs=num_epochs, + pad_value=pad_value)) runner = fqr._FeedingQueueRunner( # pylint: disable=protected-access - queue=queue, enqueue_ops=enqueue_ops, feed_fns=feed_fns) + queue=queue, + enqueue_ops=enqueue_ops, + feed_fns=feed_fns) queue_runner.add_queue_runner(runner) - full = (math_ops.cast( - math_ops.maximum(0, queue.size() - min_after_dequeue), - dtypes.float32) * (1. / (capacity - min_after_dequeue))) + full = ( + math_ops.cast( + math_ops.maximum(0, + queue.size() - min_after_dequeue), dtypes.float32) + * (1. / (capacity - min_after_dequeue))) # Note that name contains a '/' at the end so we intentionally do not place # a '/' after %s below. summary_name = ("queue/%sfraction_over_%d_of_%d_full" % diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index da202408c3680b397994620e221fa4937d7c65e4..8111ab564c017175b3f7bc1020d850db74587958 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -31,8 +31,10 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export +@tf_export('estimator.ModeKeys') class ModeKeys(object): """Standard names for model modes. @@ -52,11 +54,12 @@ LOSS_METRIC_KEY = 'loss' AVERAGE_LOSS_METRIC_KEY = 'average_loss' +@tf_export('estimator.EstimatorSpec') class EstimatorSpec( collections.namedtuple('EstimatorSpec', [ 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops', 'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold', - 'evaluation_hooks' + 'evaluation_hooks', 'prediction_hooks' ])): """Ops and objects returned from a `model_fn` and passed to an `Estimator`. @@ -73,7 +76,8 @@ class EstimatorSpec( training_chief_hooks=None, training_hooks=None, scaffold=None, - evaluation_hooks=None): + evaluation_hooks=None, + prediction_hooks=None): """Creates a validated `EstimatorSpec` instance. Depending on the value of `mode`, different arguments are required. Namely @@ -154,6 +158,8 @@ class EstimatorSpec( initialization, saver, and more to be used in training. evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to run during evaluation. + prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to + run during predictions. Returns: A validated `EstimatorSpec` object. @@ -282,7 +288,10 @@ class EstimatorSpec( training_chief_hooks = tuple(training_chief_hooks or []) training_hooks = tuple(training_hooks or []) evaluation_hooks = tuple(evaluation_hooks or []) - for hook in training_hooks + training_chief_hooks + evaluation_hooks: + prediction_hooks = tuple(prediction_hooks or []) + + for hook in (training_hooks + training_chief_hooks + evaluation_hooks + + prediction_hooks): if not isinstance(hook, session_run_hook.SessionRunHook): raise TypeError( 'All hooks must be SessionRunHook instances, given: {}'.format( @@ -305,7 +314,8 @@ class EstimatorSpec( training_chief_hooks=training_chief_hooks, training_hooks=training_hooks, scaffold=scaffold, - evaluation_hooks=evaluation_hooks) + evaluation_hooks=evaluation_hooks, + prediction_hooks=prediction_hooks) def _replace(self, **kwds): """Return a new EstimatorSpec replacing specified fields with new values.""" diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index d67c4b716161816d941eef94a4b9aeb0643de55e..b7eeeb437cb4a624cdee552be3032364b18a8290 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -72,7 +72,8 @@ class EstimatorSpecTrainTest(test.TestCase): training_chief_hooks=[_FakeHook()], training_hooks=[_FakeHook()], scaffold=monitored_session.Scaffold(), - evaluation_hooks=[_FakeHook()]) + evaluation_hooks=[_FakeHook()], + prediction_hooks=[_FakeHook()]) def testLossNumber(self): """Tests that error is raised when loss is a number (not Tensor).""" @@ -465,7 +466,17 @@ class EstimatorSpecInferTest(test.TestCase): training_chief_hooks=[_FakeHook()], training_hooks=[_FakeHook()], scaffold=monitored_session.Scaffold(), - evaluation_hooks=[_FakeHook()]) + evaluation_hooks=[_FakeHook()], + prediction_hooks=[_FakeHook()]) + + def testPredictionHookInvalid(self): + with ops.Graph().as_default(), self.test_session(): + with self.assertRaisesRegexp( + TypeError, 'All hooks must be SessionRunHook instances'): + model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.PREDICT, + predictions=constant_op.constant(1.), + prediction_hooks=[_InvalidHook()]) def testPredictionsMissing(self): with ops.Graph().as_default(), self.test_session(): diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index e446b3e03a262e0d4abe69df73cb8604b0dab9f9..3e021242c4cc914990c6b38736b8f725213b5b7e 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -27,8 +27,8 @@ import six from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib -from tensorflow.python.util import compat from tensorflow.python.util import compat_internal +from tensorflow.python.util.tf_export import tf_export _USE_DEFAULT = object() @@ -287,6 +287,7 @@ class TaskType(object): EVALUATOR = 'evaluator' +@tf_export('estimator.RunConfig') class RunConfig(object): """This class specifies the configurations for an `Estimator` run.""" diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 52fb1d39ae2e9c84e4269785a72be4f9a495b73c..63328dcfb55646ce2aaf8929d5517c8522c418f2 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Classes and functions related to train_and_evaluate.""" from __future__ import absolute_import @@ -36,7 +35,7 @@ from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import server_lib from tensorflow.python.training import session_run_hook from tensorflow.python.util import compat - +from tensorflow.python.util.tf_export import tf_export _MAX_DELAY_SECS = 60 _DELAY_SECS_PER_WORKER = 5 @@ -50,8 +49,7 @@ _TRAINER_JOBS = (run_config_lib.TaskType.CHIEF, run_config_lib.TaskType.MASTER, def _validate_input_fn(input_fn): """Validates the `input_fn`.""" if not callable(input_fn): - raise TypeError( - '`input_fn` must be callable, given: {}'.format(input_fn)) + raise TypeError('`input_fn` must be callable, given: {}'.format(input_fn)) def _validate_hooks(hooks): @@ -117,6 +115,7 @@ def _is_google_env(): return tf_config.get(_ENVIRONMENT_KEY) == _ENVIRONMENT_GOOGLE_VALUE +@tf_export('estimator.TrainSpec') class TrainSpec( collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])): """Configuration for the "train" part for the `train_and_evaluate` call. @@ -125,10 +124,7 @@ class TrainSpec( duration. Optional hooks run at various stages of training. """ - def __new__(cls, - input_fn, - max_steps=None, - hooks=None): + def __new__(cls, input_fn, max_steps=None, hooks=None): """Creates a validated `TrainSpec` instance. Args: @@ -161,16 +157,14 @@ class TrainSpec( hooks = _validate_hooks(hooks) return super(TrainSpec, cls).__new__( - cls, - input_fn=input_fn, - max_steps=max_steps, - hooks=hooks) + cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks) +@tf_export('estimator.EvalSpec') class EvalSpec( collections.namedtuple('EvalSpec', [ - 'input_fn', 'steps', 'name', 'hooks', 'exporters', - 'start_delay_secs', 'throttle_secs' + 'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs', + 'throttle_secs' ])): """Configuration for the "eval" part for the `train_and_evaluate` call. @@ -255,6 +249,7 @@ class EvalSpec( throttle_secs=throttle_secs) +@tf_export('estimator.train_and_evaluate') def train_and_evaluate(estimator, train_spec, eval_spec): """Train and evaluate the `estimator`. @@ -417,8 +412,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec): Raises: ValueError: if environment variable `TF_CONFIG` is incorrectly set. """ - executor = _TrainingExecutor(estimator=estimator, train_spec=train_spec, - eval_spec=eval_spec) + executor = _TrainingExecutor( + estimator=estimator, train_spec=train_spec, eval_spec=eval_spec) config = estimator.config if (config.task_type == run_config_lib.TaskType.EVALUATOR and @@ -561,9 +556,8 @@ class _TrainingExecutor(object): self._timer.update_last_triggered_step(global_step_value) self._evaluator.evaluate_and_export() else: - logging.info( - 'Skip the current checkpoint eval due to throttle secs ' - '({} secs).'.format(self._eval_throttle_secs)) + logging.info('Skip the current checkpoint eval due to throttle secs ' + '({} secs).'.format(self._eval_throttle_secs)) # Final export signal: For any eval result with global_step >= train # max_steps, the evaluator will send the final export signal. There is a @@ -576,8 +570,8 @@ class _TrainingExecutor(object): # # But here, throttle_secs will skip the next intermediate checkpoint and, # so, the double final export chance is very small. - evaluator = _TrainingExecutor._Evaluator( - self._estimator, self._eval_spec, self._train_spec.max_steps) + evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec, + self._train_spec.max_steps) # When the underlying `Estimator` object saves a new checkpoint, we would # like this callback to be called so that evaluation and export can trigger. @@ -617,8 +611,7 @@ class _TrainingExecutor(object): raise ValueError('eval_spec.throttle_secs should be positive, given: {}.' 'It is used do determine how long each training ' 'iteration should go when train and evaluate ' - 'locally.'.format( - self._eval_spec.throttle_secs)) + 'locally.'.format(self._eval_spec.throttle_secs)) stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs) train_hooks = ( @@ -663,8 +656,9 @@ class _TrainingExecutor(object): if not config.master: jobs = config.cluster_spec.jobs - if (len(jobs) == 1 and len(config.cluster_spec.job_tasks(jobs[0])) == 1 - and config.task_type in _TRAINER_JOBS): + if (len(jobs) == 1 and + len(config.cluster_spec.job_tasks(jobs[0])) == 1 and + config.task_type in _TRAINER_JOBS): # For distributed training, config.master is empty if and only if it has # a single node in the cluster spec. In this case, we should not start # the server. @@ -679,9 +673,9 @@ class _TrainingExecutor(object): logging.info('Start Tensorflow server.') if config.session_config is None: - session_config=config_pb2.ConfigProto(log_device_placement=False) + session_config = config_pb2.ConfigProto(log_device_placement=False) else: - session_config=config_pb2.ConfigProto( + session_config = config_pb2.ConfigProto( log_device_placement=False, gpu_options=config.session_config.gpu_options) @@ -744,8 +738,7 @@ class _TrainingExecutor(object): global_step >= self._train_spec.max_steps): logging.info( 'Exiting evaluation, global_step=%s >= train max_steps=%s', - global_step, - self._train_spec.max_steps) + global_step, self._train_spec.max_steps) return latest_eval_result, should_early_stop = self._execute_evaluator_once( @@ -781,10 +774,9 @@ class _TrainingExecutor(object): # Throttle if necessary. elapsed_time = time.time() - start - difference = throttle_secs - elapsed_time + difference = throttle_secs - elapsed_time if difference > 0: - logging.info('Waiting %f secs before starting next eval run.', - difference) + logging.info('Waiting %f secs before starting next eval run.', difference) time.sleep(difference) return (eval_result, should_early_stop) @@ -929,8 +921,8 @@ class _EvalResult( if checkpoint_path: raise ValueError( 'checkpoint must be `None` if status is not {}; got status {}, ' - 'checkpoint_path {}'.format( - _EvalStatus.EVALUATED, status, checkpoint_path)) + 'checkpoint_path {}'.format(_EvalStatus.EVALUATED, status, + checkpoint_path)) return super(_EvalResult, cls).__new__(cls, status, metrics, checkpoint_path) diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index b7ba76d8714e6b13551bb3e18083f45e53d2afc3..3ce8eea84b6bf601ce89dfaa7d8e3a5d193468b3 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -21,10 +21,12 @@ from __future__ import print_function import functools +from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect def _is_bounded_method(fn): + _, fn = tf_decorator.unwrap(fn) return tf_inspect.ismethod(fn) and (fn.__self__ is not None) diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py index ad95c71234f82457cb938ca55214b28086b033a2..adb013f5c653c4967a743047fef4e805946e0f59 100644 --- a/tensorflow/python/estimator/warm_starting_util.py +++ b/tensorflow/python/estimator/warm_starting_util.py @@ -30,8 +30,10 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_ops from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver +from tensorflow.python.util.tf_export import tf_export +@tf_export("estimator.VocabInfo") class VocabInfo( collections.namedtuple("VocabInfo", [ "new_vocab", @@ -81,6 +83,7 @@ class VocabInfo( ) +@tf_export("estimator.WarmStartSettings") class WarmStartSettings( collections.namedtuple("WarmStartSettings", [ "ckpt_to_initialize_from", @@ -117,21 +120,13 @@ class WarmStartSettings( ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000") ``` - Warm-start only the embeddings (input layer) and their accumulator variables: + Warm-start only the embeddings (input layer): ``` ws = WarmStartSettings(ckpt_to_initialize_from="/tmp", vars_to_warm_start=".*input_layer.*") ``` - Warm-start everything except the optimizer accumulator variables - (DNN defaults to Adagrad): - - ``` - ws = WarmStartSettings(ckpt_to_initialize_from="/tmp", - vars_to_warm_start="^(?!.*(Adagrad))") - ``` - Warm-start all weights but the embedding parameters corresponding to `sc_vocab_file` have a different vocab from the one used in the current model: @@ -415,14 +410,16 @@ def _warm_start(warm_start_settings): a stronger check for variable configuration than relying on users to examine the logs. """ - logging.info("Warm-starting from: ", - warm_start_settings.ckpt_to_initialize_from) + logging.info("Warm-starting from: %s", + (warm_start_settings.ckpt_to_initialize_from,)) # We have to deal with partitioned variables, since get_collection flattens # out the list. grouped_variables = {} # Both warm_start_settings.vars_to_warm_start = '.*' and # warm_start_settings.vars_to_warm_start = None will match everything here. for v in ops.get_collection( + # TODO(eddz): Allow for different collections here (to support + # warm-starting accumulators). ops.GraphKeys.TRAINABLE_VARIABLES, scope=warm_start_settings.vars_to_warm_start): if not isinstance(v, list): diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 7feb209cc49c4be70387c44168dbdeea6d108d66..c416881c3119c160d28f4b8e37cd2aeb22f239a6 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -157,6 +157,8 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_utils from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import tf_export def _internal_input_layer(features, @@ -209,6 +211,7 @@ def _internal_input_layer(features, return array_ops.concat(output_tensors, 1) +@tf_export('feature_column.input_layer') def input_layer(features, feature_columns, weight_collections=None, @@ -329,6 +332,7 @@ class InputLayer(object): return self._input_layer_template.weights +@tf_export('feature_column.linear_model') def linear_model(features, feature_columns, units=1, @@ -498,6 +502,7 @@ def _transform_features(features, feature_columns): return outputs +@tf_export('feature_column.make_parse_example_spec') def make_parse_example_spec(feature_columns): """Creates parsing spec dictionary from input feature_columns. @@ -507,6 +512,7 @@ def make_parse_example_spec(feature_columns): ```python # Define features and transformations + feature_a = categorical_column_with_vocabulary_file(...) feature_b = numeric_column(...) feature_c_bucketized = bucketized_column(numeric_column("feature_c"), ...) feature_a_x_feature_c = crossed_column( @@ -557,6 +563,7 @@ def make_parse_example_spec(feature_columns): return result +@tf_export('feature_column.embedding_column') def embedding_column( categorical_column, dimension, combiner='mean', initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None, @@ -657,6 +664,7 @@ def embedding_column( trainable=trainable) +@tf_export('feature_column.shared_embedding_columns') def shared_embedding_columns( categorical_columns, dimension, combiner='mean', initializer=None, shared_embedding_collection_name=None, ckpt_to_load_from=None, @@ -807,6 +815,7 @@ def shared_embedding_columns( return result +@tf_export('feature_column.numeric_column') def numeric_column(key, shape=(1,), default_value=None, @@ -881,6 +890,7 @@ def numeric_column(key, normalizer_fn=normalizer_fn) +@tf_export('feature_column.bucketized_column') def bucketized_column(source_column, boundaries): """Represents discretized dense input. @@ -970,6 +980,7 @@ def _assert_string_or_int(dtype, prefix): '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype)) +@tf_export('feature_column.categorical_column_with_hash_bucket') def categorical_column_with_hash_bucket(key, hash_bucket_size, dtype=dtypes.string): @@ -1026,6 +1037,7 @@ def categorical_column_with_hash_bucket(key, return _HashedCategoricalColumn(key, hash_bucket_size, dtype) +@tf_export('feature_column.categorical_column_with_vocabulary_file') def categorical_column_with_vocabulary_file(key, vocabulary_file, vocabulary_size=None, @@ -1145,6 +1157,7 @@ def categorical_column_with_vocabulary_file(key, dtype=dtype) +@tf_export('feature_column.categorical_column_with_vocabulary_list') def categorical_column_with_vocabulary_list( key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0): """A `_CategoricalColumn` with in-memory vocabulary. @@ -1255,6 +1268,7 @@ def categorical_column_with_vocabulary_list( default_value=default_value, num_oov_buckets=num_oov_buckets) +@tf_export('feature_column.categorical_column_with_identity') def categorical_column_with_identity(key, num_buckets, default_value=None): """A `_CategoricalColumn` that returns identity values. @@ -1322,6 +1336,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None): key=key, num_buckets=num_buckets, default_value=default_value) +@tf_export('feature_column.indicator_column') def indicator_column(categorical_column): """Represents multi-hot representation of given categorical column. @@ -1350,6 +1365,7 @@ def indicator_column(categorical_column): return _IndicatorColumn(categorical_column) +@tf_export('feature_column.weighted_categorical_column') def weighted_categorical_column( categorical_column, weight_feature_key, dtype=dtypes.float32): """Applies weight values to a `_CategoricalColumn`. @@ -1424,6 +1440,7 @@ def weighted_categorical_column( dtype=dtype) +@tf_export('feature_column.crossed_column') def crossed_column(keys, hash_bucket_size, hash_key=None): """Returns a column for performing crosses of categorical features. diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py index 3b1092f923112dbd9a081942d40162ae384bf167..3c5aebbce8af117aa1e216f1ef07ded181c997ea 100644 --- a/tensorflow/python/framework/common_shapes.py +++ b/tensorflow/python/framework/common_shapes.py @@ -34,7 +34,7 @@ def scalar_shape(unused_op): def unchanged_shape(op): - """Shape function for ops that output an tensor like their first input.""" + """Shape function for ops that output a tensor like their first input.""" return [op.inputs[0].get_shape()] diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 67ccf990d6a0e59c965ff76c2ba601be2a64060a..99ae8b24f11c4955379ae532ba7b921ebec63385 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Library of dtypes (Tensor element types).""" from __future__ import absolute_import from __future__ import division from __future__ import print_function - import numpy as np from tensorflow.core.framework import types_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.util.tf_export import tf_export - _np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type() @@ -83,8 +80,8 @@ class DType(object): # TODO(mrry): Make the necessary changes (using __new__) to ensure # that calling this returns one of the interned values. type_enum = int(type_enum) - if (type_enum not in types_pb2.DataType.values() - or type_enum == types_pb2.DT_INVALID): + if (type_enum not in types_pb2.DataType.values() or + type_enum == types_pb2.DT_INVALID): raise TypeError( "type_enum is not a valid types_pb2.DataType: %s" % type_enum) self._type_enum = type_enum @@ -123,10 +120,10 @@ class DType(object): @property def is_numpy_compatible(self): - numpy_incompatible = [types_pb2.DT_VARIANT, - types_pb2.DT_VARIANT_REF, - types_pb2.DT_RESOURCE, - types_pb2.DT_RESOURCE_REF] + numpy_incompatible = [ + types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE, + types_pb2.DT_RESOURCE_REF + ] return self._type_enum not in numpy_incompatible @property @@ -153,9 +150,9 @@ class DType(object): @property def is_floating(self): """Returns whether this is a (non-quantized, real) floating point type.""" - return ((self.is_numpy_compatible and np.issubdtype(self.as_numpy_dtype, - np.floating)) - or self.base_dtype == bfloat16) + return ((self.is_numpy_compatible and + np.issubdtype(self.as_numpy_dtype, np.floating)) or + self.base_dtype == bfloat16) @property def is_complex(self): @@ -190,8 +187,8 @@ class DType(object): TypeError: if this is a non-numeric, unordered, or quantized type. """ - if (self.is_quantized or self.base_dtype in - (bool, string, complex64, complex128)): + if (self.is_quantized or + self.base_dtype in (bool, string, complex64, complex128)): raise TypeError("Cannot find minimum value of %s." % self) # there is no simple way to get the min value of a dtype, we have to check @@ -214,8 +211,8 @@ class DType(object): TypeError: if this is a non-numeric, unordered, or quantized type. """ - if (self.is_quantized or self.base_dtype in - (bool, string, complex64, complex128)): + if (self.is_quantized or + self.base_dtype in (bool, string, complex64, complex128)): raise TypeError("Cannot find maximum value of %s." % self) # there is no simple way to get the max value of a dtype, we have to check @@ -241,9 +238,9 @@ class DType(object): min, max : tuple Lower and upper intensity limits. """ - min, max = dtype_range[self.as_numpy_dtype] + min, max = dtype_range[self.as_numpy_dtype] # pylint: disable=redefined-builtin if clip_negative: - min = 0 + min = 0 # pylint: disable=redefined-builtin return min, max def is_compatible_with(self, other): @@ -266,8 +263,8 @@ class DType(object): this `DType`. """ other = as_dtype(other) - return self._type_enum in ( - other.as_datatype_enum, other.base_dtype.as_datatype_enum) + return self._type_enum in (other.as_datatype_enum, + other.base_dtype.as_datatype_enum) def __eq__(self, other): """Returns True iff this DType refers to the same type as `other`.""" @@ -307,19 +304,22 @@ class DType(object): return 1 return np.dtype(self.as_numpy_dtype).itemsize + # Define data type range of numpy dtype -dtype_range = {np.bool_: (False, True), - np.bool8: (False, True), - np.uint8: (0, 255), - np.uint16: (0, 65535), - np.int8: (-128, 127), - np.int16: (-32768, 32767), - np.int64: (-2**63, 2**63 - 1), - np.uint64: (0, 2**64 - 1), - np.int32: (-2**31, 2**31 - 1), - np.uint32: (0, 2**32 - 1), - np.float32: (-1, 1), - np.float64: (-1, 1)} +dtype_range = { + np.bool_: (False, True), + np.bool8: (False, True), + np.uint8: (0, 255), + np.uint16: (0, 65535), + np.int8: (-128, 127), + np.int16: (-32768, 32767), + np.int64: (-2**63, 2**63 - 1), + np.uint64: (0, 2**64 - 1), + np.int32: (-2**31, 2**31 - 1), + np.uint32: (0, 2**32 - 1), + np.float32: (-1, 1), + np.float64: (-1, 1) +} # Define standard wrappers for the types_pb2.DataType enum. resource = DType(types_pb2.DT_RESOURCE) @@ -356,7 +356,7 @@ complex128 = DType(types_pb2.DT_COMPLEX128) tf_export("complex128").export_constant(__name__, "complex128") int64 = DType(types_pb2.DT_INT64) tf_export("int64").export_constant(__name__, "int64") -bool = DType(types_pb2.DT_BOOL) +bool = DType(types_pb2.DT_BOOL) # pylint: disable=redefined-builtin tf_export("bool").export_constant(__name__, "bool") qint8 = DType(types_pb2.DT_QINT8) tf_export("qint8").export_constant(__name__, "qint8") @@ -396,7 +396,6 @@ quint16_ref = DType(types_pb2.DT_QUINT16_REF) qint32_ref = DType(types_pb2.DT_QINT32_REF) bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF) - # Maintain an intern table so that we don't have to create a large # number of small objects. _INTERN_TABLE = { @@ -448,7 +447,6 @@ _INTERN_TABLE = { types_pb2.DT_VARIANT_REF: variant_ref, } - # Standard mappings between types_pb2.DataType values and string names. _TYPE_TO_STRING = { types_pb2.DT_HALF: "float16", @@ -498,8 +496,10 @@ _TYPE_TO_STRING = { types_pb2.DT_RESOURCE_REF: "resource_ref", types_pb2.DT_VARIANT_REF: "variant_ref", } -_STRING_TO_TF = {value: _INTERN_TABLE[key] - for key, value in _TYPE_TO_STRING.items()} +_STRING_TO_TF = { + value: _INTERN_TABLE[key] + for key, value in _TYPE_TO_STRING.items() +} # Add non-canonical aliases. _STRING_TO_TF["half"] = float16 _STRING_TO_TF["half_ref"] = float16_ref @@ -508,7 +508,6 @@ _STRING_TO_TF["float_ref"] = float32_ref _STRING_TO_TF["double"] = float64 _STRING_TO_TF["double_ref"] = float64_ref - # Numpy representation for quantized dtypes. # # These are magic strings that are used in the swig wrapper to identify @@ -551,58 +550,100 @@ _NP_TO_TF = frozenset([ (_np_bfloat16, bfloat16), ]) _TF_TO_NP = { - types_pb2.DT_HALF: np.float16, - types_pb2.DT_FLOAT: np.float32, - types_pb2.DT_DOUBLE: np.float64, - types_pb2.DT_INT32: np.int32, - types_pb2.DT_UINT8: np.uint8, - types_pb2.DT_UINT16: np.uint16, - types_pb2.DT_UINT32: np.uint32, - types_pb2.DT_UINT64: np.uint64, - types_pb2.DT_INT16: np.int16, - types_pb2.DT_INT8: np.int8, + types_pb2.DT_HALF: + np.float16, + types_pb2.DT_FLOAT: + np.float32, + types_pb2.DT_DOUBLE: + np.float64, + types_pb2.DT_INT32: + np.int32, + types_pb2.DT_UINT8: + np.uint8, + types_pb2.DT_UINT16: + np.uint16, + types_pb2.DT_UINT32: + np.uint32, + types_pb2.DT_UINT64: + np.uint64, + types_pb2.DT_INT16: + np.int16, + types_pb2.DT_INT8: + np.int8, # NOTE(touts): For strings we use np.object as it supports variable length # strings. - types_pb2.DT_STRING: np.object, - types_pb2.DT_COMPLEX64: np.complex64, - types_pb2.DT_COMPLEX128: np.complex128, - types_pb2.DT_INT64: np.int64, - types_pb2.DT_BOOL: np.bool, - types_pb2.DT_QINT8: _np_qint8, - types_pb2.DT_QUINT8: _np_quint8, - types_pb2.DT_QINT16: _np_qint16, - types_pb2.DT_QUINT16: _np_quint16, - types_pb2.DT_QINT32: _np_qint32, - types_pb2.DT_BFLOAT16: _np_bfloat16, + types_pb2.DT_STRING: + np.object, + types_pb2.DT_COMPLEX64: + np.complex64, + types_pb2.DT_COMPLEX128: + np.complex128, + types_pb2.DT_INT64: + np.int64, + types_pb2.DT_BOOL: + np.bool, + types_pb2.DT_QINT8: + _np_qint8, + types_pb2.DT_QUINT8: + _np_quint8, + types_pb2.DT_QINT16: + _np_qint16, + types_pb2.DT_QUINT16: + _np_quint16, + types_pb2.DT_QINT32: + _np_qint32, + types_pb2.DT_BFLOAT16: + _np_bfloat16, # Ref types - types_pb2.DT_HALF_REF: np.float16, - types_pb2.DT_FLOAT_REF: np.float32, - types_pb2.DT_DOUBLE_REF: np.float64, - types_pb2.DT_INT32_REF: np.int32, - types_pb2.DT_UINT32_REF: np.uint32, - types_pb2.DT_UINT8_REF: np.uint8, - types_pb2.DT_UINT16_REF: np.uint16, - types_pb2.DT_INT16_REF: np.int16, - types_pb2.DT_INT8_REF: np.int8, - types_pb2.DT_STRING_REF: np.object, - types_pb2.DT_COMPLEX64_REF: np.complex64, - types_pb2.DT_COMPLEX128_REF: np.complex128, - types_pb2.DT_INT64_REF: np.int64, - types_pb2.DT_UINT64_REF: np.uint64, - types_pb2.DT_BOOL_REF: np.bool, - types_pb2.DT_QINT8_REF: _np_qint8, - types_pb2.DT_QUINT8_REF: _np_quint8, - types_pb2.DT_QINT16_REF: _np_qint16, - types_pb2.DT_QUINT16_REF: _np_quint16, - types_pb2.DT_QINT32_REF: _np_qint32, - types_pb2.DT_BFLOAT16_REF: _np_bfloat16, + types_pb2.DT_HALF_REF: + np.float16, + types_pb2.DT_FLOAT_REF: + np.float32, + types_pb2.DT_DOUBLE_REF: + np.float64, + types_pb2.DT_INT32_REF: + np.int32, + types_pb2.DT_UINT32_REF: + np.uint32, + types_pb2.DT_UINT8_REF: + np.uint8, + types_pb2.DT_UINT16_REF: + np.uint16, + types_pb2.DT_INT16_REF: + np.int16, + types_pb2.DT_INT8_REF: + np.int8, + types_pb2.DT_STRING_REF: + np.object, + types_pb2.DT_COMPLEX64_REF: + np.complex64, + types_pb2.DT_COMPLEX128_REF: + np.complex128, + types_pb2.DT_INT64_REF: + np.int64, + types_pb2.DT_UINT64_REF: + np.uint64, + types_pb2.DT_BOOL_REF: + np.bool, + types_pb2.DT_QINT8_REF: + _np_qint8, + types_pb2.DT_QUINT8_REF: + _np_quint8, + types_pb2.DT_QINT16_REF: + _np_qint16, + types_pb2.DT_QUINT16_REF: + _np_quint16, + types_pb2.DT_QINT32_REF: + _np_qint32, + types_pb2.DT_BFLOAT16_REF: + _np_bfloat16, } - -QUANTIZED_DTYPES = frozenset( - [qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref, - quint16_ref, qint32_ref]) +QUANTIZED_DTYPES = frozenset([ + qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref, + quint16_ref, qint32_ref +]) tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES") @@ -613,7 +654,8 @@ def as_dtype(type_value): Args: type_value: A value that can be converted to a `tf.DType` object. This may currently be a `tf.DType` object, a - [`DataType` enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto), + [`DataType` + enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto), a string type name, or a `numpy.dtype`. Returns: @@ -650,5 +692,4 @@ def as_dtype(type_value): except TypeError as e: raise TypeError("Cannot convert {} to a dtype. {}".format(type_value, e)) - raise TypeError( - "Cannot convert value %r to a TensorFlow DType." % type_value) + raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value) diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py index d16fe979e6ef9a41063c3a2b3e8a3e18de2aa9d7..3172f3c2c3d259d2c3f2b340b101aef043d0fc33 100644 --- a/tensorflow/python/framework/framework_lib.py +++ b/tensorflow/python/framework/framework_lib.py @@ -118,7 +118,7 @@ from tensorflow.python.framework.ops import register_tensor_conversion_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.python.framework.dtypes import * +from tensorflow.python.framework.dtypes import * # pylint: disable=redefined-builtin # Load a TensorFlow plugin from tensorflow.python.framework.load_library import * diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index a4ca3f9a89bd4cce2240d90895c43dda1acb849b..301a7f682dde8dbeccd1e81675b0059433990a09 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function import re -import time import sys +import time import numpy as np @@ -86,6 +86,21 @@ class FunctionTest(test.TestCase): with session.Session() as sess: self.assertAllEqual([18.0], sess.run(call)) + def testIdentityImplicitDeref(self): + + @function.Defun(dtypes.float32, func_name="MyIdentity") + def MyIdentityFunc(a): + return a + + with ops.Graph().as_default(): + var = variables.Variable([18.0]) + call = MyIdentityFunc(var._ref()) # pylint: disable=protected-access + self.assertEqual("MyIdentity", call.op.name) + for cfg in _OptimizerOptions(): + with session.Session(config=cfg) as sess: + sess.run(var.initializer) + self.assertAllEqual([18.0], sess.run(call)) + def testIdentityOutputName(self): @function.Defun( @@ -771,7 +786,7 @@ class FunctionTest(test.TestCase): # We added more randomness to function names in C API. # TODO(iga): Remove this if statement when we switch to C API. if ops._USE_C_API: # pylint: disable=protected-access - if sys.byteorder == 'big': + if sys.byteorder == "big": self.assertEqual("Foo_kEdkAG8SJvg", Foo.instantiate([dtypes.float32] * 3).name) else: @@ -1443,7 +1458,7 @@ class FunctionInlineControlTest(test.TestCase): def Cell(v): # If v is a vector [n, 1], x is a big square matrix. x = math_ops.tanh(v + array_ops.transpose(v, [1, 0])) - return math_ops.reduce_sum(x, 1, keep_dims=True) + return math_ops.reduce_sum(x, 1, keepdims=True) @function.Defun(dtype) def Forward(x): diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 00fff8d040d6facfc81359061f6cf9a1cf6d3d3c..6ecc1a40ae14760dd39242aaf595b32a9decdc9f 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """A utility function for importing TensorFlow graphs.""" from __future__ import absolute_import from __future__ import division @@ -43,8 +42,8 @@ from tensorflow.python.util.tf_export import tf_export # the logic here. def _GetNodeAttr(node_def, attr_name): if attr_name not in node_def.attr: - raise ValueError('Expected one attr with name %r in %s.' - % (attr_name, str(node_def))) + raise ValueError('Expected one attr with name %r in %s.' % (attr_name, + str(node_def))) return node_def.attr[attr_name] @@ -151,7 +150,7 @@ def _MaybeDevice(device): yield -def _ProcessGraphDefParam(graph_def): +def _ProcessGraphDefParam(graph_def, op_dict): """Type-checks and possibly canonicalizes `graph_def`.""" if not isinstance(graph_def, graph_pb2.GraphDef): # `graph_def` could be a dynamically-created message, so try a duck-typed @@ -162,6 +161,22 @@ def _ProcessGraphDefParam(graph_def): graph_def.MergeFrom(old_graph_def) except TypeError: raise TypeError('graph_def must be a GraphDef proto.') + else: + # If we're using the graph_def provided by the caller, modify graph_def + # in-place to add attr defaults to the NodeDefs (this is visible to the + # caller). + # NOTE(skyewm): this is undocumented behavior that at least meta_graph.py + # depends on. It might make sense to move this to meta_graph.py and have + # import_graph_def not modify the graph_def argument (we'd have to make sure + # this doesn't break anything else.) + for node in graph_def.node: + if node.op not in op_dict: + # Assume unrecognized ops are functions for now. TF_ImportGraphDef will + # report an error if the op is actually missing. + continue + op_def = op_dict[node.op] + _SetDefaultAttrValues(node, op_def) + return graph_def @@ -170,9 +185,8 @@ def _ProcessInputMapParam(input_map): if input_map is None: input_map = {} else: - if not (isinstance(input_map, dict) - and all(isinstance(k, compat.bytes_or_text_types) - for k in input_map.keys())): + if not (isinstance(input_map, dict) and all( + isinstance(k, compat.bytes_or_text_types) for k in input_map.keys())): raise TypeError('input_map must be a dictionary mapping strings to ' 'Tensor objects.') return input_map @@ -180,9 +194,10 @@ def _ProcessInputMapParam(input_map): def _ProcessReturnElementsParam(return_elements): """Type-checks and possibly canonicalizes `return_elements`.""" - if return_elements is None: return None - if not all(isinstance(x, compat.bytes_or_text_types) - for x in return_elements): + if return_elements is None: + return None + if not all( + isinstance(x, compat.bytes_or_text_types) for x in return_elements): raise TypeError('return_elements must be a list of strings.') return tuple(compat.as_str(x) for x in return_elements) @@ -255,21 +270,20 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map, """Populates the TF_ImportGraphDefOptions `options`.""" c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix) c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True) - c_api.TF_ImportGraphDefOptionsSetUniquifyPrefix(options, True) for input_src, input_dst in input_map.items(): input_src = compat.as_str(input_src) if input_src.startswith('^'): src_name = compat.as_bytes(input_src[1:]) dst_op = input_dst._as_tf_output().oper # pylint: disable=protected-access - c_api.TF_ImportGraphDefOptionsRemapControlDependency(options, src_name, - dst_op) + c_api.TF_ImportGraphDefOptionsRemapControlDependency( + options, src_name, dst_op) else: src_name, src_idx = _ParseTensorName(input_src) src_name = compat.as_str(src_name) dst_output = input_dst._as_tf_output() # pylint: disable=protected-access - c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, - src_idx, dst_output) + c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_idx, + dst_output) for name in return_elements or []: if ':' in name: op_name, index = _ParseTensorName(name) @@ -315,8 +329,8 @@ def _ProcessNewOps(graph): coloc_op = graph._get_operation_by_name_unsafe(coloc_op_name) # pylint: disable=protected-access except KeyError: raise ValueError('Specified colocation to an op that ' - 'does not exist during import: %s in %s' % ( - coloc_op_name, op.name)) + 'does not exist during import: %s in %s' % + (coloc_op_name, op.name)) if coloc_op.device: coloc_device = pydev.DeviceSpec.from_string(coloc_op.device) break @@ -370,13 +384,27 @@ def _GatherReturnElements(requested_return_elements, graph, results): return combined_return_elements +def _SetDefaultAttrValues(node_def, op_def): + """Set any default attr values in `node_def` that aren't present.""" + assert node_def.op == op_def.name + for attr_def in op_def.attr: + key = attr_def.name + if attr_def.HasField('default_value'): + value = node_def.attr[key] + if value is None or value.WhichOneof('value') is None: + node_def.attr[key].CopyFrom(attr_def.default_value) + + @tf_export('import_graph_def') @deprecated_args(None, 'Please file an issue at ' 'https://github.com/tensorflow/tensorflow/issues if you depend' - ' on this feature.', - 'op_dict') -def import_graph_def(graph_def, input_map=None, return_elements=None, - name=None, op_dict=None, producer_op_list=None): + ' on this feature.', 'op_dict') +def import_graph_def(graph_def, + input_map=None, + return_elements=None, + name=None, + op_dict=None, + producer_op_list=None): """Imports the graph from `graph_def` into the current default `Graph`. This function provides a way to import a serialized TensorFlow @@ -418,12 +446,12 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, do not appear in `graph_def`, or `graph_def` is not well-formed (e.g. it refers to an unknown tensor). """ - graph_def = _ProcessGraphDefParam(graph_def) + op_dict = op_def_registry.get_registered_ops() + + graph_def = _ProcessGraphDefParam(graph_def, op_dict) input_map = _ProcessInputMapParam(input_map) return_elements = _ProcessReturnElementsParam(return_elements) - op_dict = op_def_registry.get_registered_ops() - if producer_op_list is not None: # TODO(skyewm): make a copy of graph_def so we're not mutating the argument? _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def) @@ -480,11 +508,12 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( results)) if missing_unused_input_keys: - missing_unused_input_keys = [compat.as_str(s) - for s in missing_unused_input_keys] + missing_unused_input_keys = [ + compat.as_str(s) for s in missing_unused_input_keys + ] raise ValueError( - 'Attempted to map inputs that were not found in graph_def: [%s]' - % ', '.join(missing_unused_input_keys)) + 'Attempted to map inputs that were not found in graph_def: [%s]' % + ', '.join(missing_unused_input_keys)) if return_elements is None: return None @@ -532,16 +561,9 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, # Check to see if this op's name matches a previously seen op if node.name in name_to_op: raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name) - # Set any default attr values that aren't present. if node.op not in op_dict: raise ValueError('No op named %s in defined operations.' % node.op) op_def = op_dict[node.op] - for attr_def in op_def.attr: - key = attr_def.name - if attr_def.HasField('default_value'): - value = node.attr[key] - if value is None or value.WhichOneof('value') is None: - node.attr[key].CopyFrom(attr_def.default_value) output_types = _OutputTypes(node, op_dict) name_to_op[node.name] = g.create_op( diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index acaec37f810cb00daa9bae17ffbcb675648b9fe1..bf5d9fe0936882c242198bdc7118f9f3a4e79260 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -154,6 +154,25 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(b3.name, "A_3/B") self.assertEqual(list(b3.inputs), [a3.outputs[0]]) + # Import with an already-used name but with a '/' to indicate an + # "absolute" name scope (see the Graph.name_scope docstring). + a_a, a_b = importer.import_graph_def( + graph_def, + return_elements=["A", "B"], + name="A/") + self.assertEqual(a_a.name, "A/A") + self.assertEqual(a_b.name, "A/B") + self.assertEqual(list(a_b.inputs), [a_a.outputs[0]]) + + # Repeat the same import. + a_a1, a_b1 = importer.import_graph_def( + graph_def, + return_elements=["A", "B"], + name="A/") + self.assertEqual(a_a1.name, "A/A_1") + self.assertEqual(a_b1.name, "A/B_1") + self.assertEqual(list(a_b1.inputs), [a_a1.outputs[0]]) + # Import with existing de-duped node names a1_1, b1_1 = importer.import_graph_def( self._MakeGraphDef(""" diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py index c997ead829855f33efdb3efe947c3f59b5dbe76c..1f2aa264c110930b318f30e3a24010a96ebce47e 100644 --- a/tensorflow/python/framework/load_library.py +++ b/tensorflow/python/framework/load_library.py @@ -21,10 +21,10 @@ from __future__ import print_function import hashlib import imp import sys -import threading +import threading # pylint: disable=unused-import from tensorflow.core.framework import op_def_pb2 -from tensorflow.core.lib.core import error_codes_pb2 +from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import from tensorflow.python import pywrap_tensorflow as py_tf from tensorflow.python.framework import errors_impl from tensorflow.python.util import compat diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index fc1a82361ba59cddc02a65a96da98283d871fd2c..4c1bd736d727e974375ad9008a579361137fb9d6 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -87,6 +87,10 @@ def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False): compat.as_str(s).split("@")[1].startswith(export_scope)] node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue(s=new_s))) + elif node_def.op in ("Enter", "RefEnter") and k == "frame_name": + if not export_scope or compat.as_str(v.s).startswith(export_scope): + new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope)) + node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s)) else: node_def.attr[k].CopyFrom(v) @@ -737,6 +741,7 @@ def import_scoped_meta_graph(meta_graph_or_file, producer_op_list=producer_op_list) # Restores all the other collections. + variable_objects = {} for key, col_def in sorted(meta_graph_def.collection_def.items()): # Don't add unbound_inputs to the new graph. if key == unbound_inputs_col_name: @@ -752,11 +757,23 @@ def import_scoped_meta_graph(meta_graph_or_file, from_proto = ops.get_from_proto_function(key) if from_proto and kind == "bytes_list": proto_type = ops.get_collection_proto_type(key) - for value in col_def.bytes_list.value: - proto = proto_type() - proto.ParseFromString(value) - graph.add_to_collection( - key, from_proto(proto, import_scope=scope_to_prepend_to_names)) + if key in ops.GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access + for value in col_def.bytes_list.value: + variable = variable_objects.get(value, None) + if variable is None: + proto = proto_type() + proto.ParseFromString(value) + variable = from_proto( + proto, import_scope=scope_to_prepend_to_names) + variable_objects[value] = variable + graph.add_to_collection(key, variable) + else: + for value in col_def.bytes_list.value: + proto = proto_type() + proto.ParseFromString(value) + graph.add_to_collection( + key, from_proto( + proto, import_scope=scope_to_prepend_to_names)) else: field = getattr(col_def, kind) if key in _COMPAT_COLLECTION_LIST: @@ -959,5 +976,3 @@ def copy_scoped_meta_graph(from_scope, to_scope, graph=to_graph, import_scope=to_scope) return var_list - - diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index b5ed1352843eac31b3e34eb96385acd13a5bc7a9..19dcd6a1b34741290b2578d93b79883c103fdb1b 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -25,6 +25,7 @@ import shutil from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -34,6 +35,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import nn_ops @@ -259,6 +261,29 @@ class SimpleMetaGraphTest(test.TestCase): self.assertEqual(node_def.attr["attr_1"].i, 1) self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) + def testVariableObjectsAreSharedAmongCollections(self): + with ops.Graph().as_default() as graph1: + v = variables.Variable(3.0) + # A single instance of Variable is shared among the collections: + global_vars = graph1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + trainable_vars = graph1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertEqual(len(global_vars), 1) + self.assertEqual(len(trainable_vars), 1) + self.assertIs(global_vars[0], trainable_vars[0]) + self.assertIs(v, global_vars[0]) + + orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(graph=graph1) + del graph1 # To avoid accidental references in code involving graph2. + + with ops.Graph().as_default() as graph2: + meta_graph.import_scoped_meta_graph(orig_meta_graph) + global_vars = graph2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + trainable_vars = graph2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertEqual(len(global_vars), 1) + self.assertEqual(len(trainable_vars), 1) + # A single instance of Variable is shared among the collections: + self.assertIs(global_vars[0], trainable_vars[0]) + @test_util.with_c_api class ScopedMetaGraphTest(test.TestCase): @@ -447,6 +472,56 @@ class ScopedMetaGraphTest(test.TestCase): del b.collection_def["unbound_inputs"] test_util.assert_meta_graph_protos_equal(self, a, b) + def testWhileLoopGradients(self): + # Create a simple while loop. + with ops.Graph().as_default(): + with ops.name_scope("export"): + var = variables.Variable(0) + var_name = var.name + _, output = control_flow_ops.while_loop(lambda i, x: i < 5, + lambda i, x: (i + 1, x + i), + [0, var]) + output_name = output.name + + # Generate a MetaGraphDef containing the while loop with an export scope. + meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + export_scope="export") + + # Build and run the gradients of the while loop. We use this below to + # verify that the gradients are correct with the imported MetaGraphDef. + init_op = variables.global_variables_initializer() + grad = gradients_impl.gradients([output], [var]) + with session.Session() as sess: + sess.run(init_op) + expected_grad_value = sess.run(grad) + + # Restore the MetaGraphDef into a new Graph with an import scope. + with ops.Graph().as_default(): + meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope="import") + + # Re-export and make sure we get the same MetaGraphDef. + new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + export_scope="import") + test_util.assert_meta_graph_protos_equal( + self, meta_graph_def, new_meta_graph_def) + + # Make sure we can still build gradients and get the same result. + + def new_name(tensor_name): + base_tensor_name = tensor_name.replace("export/", "") + return "import/" + base_tensor_name + + var = ops.get_default_graph().get_tensor_by_name(new_name(var_name)) + output = ops.get_default_graph().get_tensor_by_name(new_name(output_name)) + grad = gradients_impl.gradients([output], [var]) + + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + actual_grad_value = sess.run(grad) + self.assertEqual(expected_grad_value, actual_grad_value) + def testScopedImportUnderNameScope(self): graph = ops.Graph() with graph.as_default(): @@ -831,21 +906,25 @@ class ExportImportAcrossScopesTest(test.TestCase): graph_fn(use_resource=use_resource) if use_resource: - # Bringing in a collection that contains ResourceVariables adds ops - # to the graph, so mimic the same behavior. + # Bringing in collections that contain ResourceVariables will adds ops + # to the graph the first time a variable is encountered, so mimic the + # same behavior. + seen_variables = set() for collection_key in sorted([ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES, ]): for var in expected_graph.get_collection(collection_key): - var._read_variable_op() + if var not in seen_variables: + var._read_variable_op() + seen_variables.add(var) result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0] expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0] if use_resource: # Clear all shared_name attributes before comparing, since they are - # supposed to be orthogonal to scopes. + # orthogonal to scopes and are not updated on export/import. for meta_graph_def in [result, expected]: for node in meta_graph_def.graph_def.node: shared_name_attr = "shared_name" diff --git a/tensorflow/python/framework/op_def_library_test.py b/tensorflow/python/framework/op_def_library_test.py index 817007ce6c18e11d19038e09d77a8f27bd7eca91..84ca062ade3b32c37212ba2d5b7eb9c64fb1dfa5 100644 --- a/tensorflow/python/framework/op_def_library_test.py +++ b/tensorflow/python/framework/op_def_library_test.py @@ -42,7 +42,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def setUp(self): self._lib = test_ops._op_def_lib - def _add_op(self, ascii): + def _add_op(self, ascii): # pylint: disable=redefined-builtin op_def = op_def_pb2.OpDef() text_format.Merge(ascii, op_def) self._lib.add_op(op_def) @@ -1336,7 +1336,7 @@ class OpDefLibraryGraphTest(test_util.TensorFlowTestCase): def setUp(self): self._lib = test_ops._op_def_lib - def _add_op(self, ascii): + def _add_op(self, ascii): # pylint: disable=redefined-builtin op_def = op_def_pb2.OpDef() text_format.Merge(ascii, op_def) self._lib.add_op(op_def) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index b107670275c87e2ee711c1a10fbe6bacc334ad5f..398b3f67e20660dc23f8fb339774ad0e3b2eff9d 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1618,7 +1618,7 @@ class Operation(object): for i, x in zip(inputs, input_types)): raise TypeError("In op '%s', input types (%s) are not compatible " "with expected types (%s)" % - (self.node_def.name, [i.dtype for i in inputs], + (node_def.name, [i.dtype for i in inputs], input_types)) # Build the list of control inputs. @@ -1657,7 +1657,7 @@ class Operation(object): self._c_op = c_op elif self._graph._c_graph: # pylint: disable=protected-access if op_def is None: - op_def = self._graph._registered_ops[node_def.op] + op_def = self._graph._get_op_def(node_def.op) # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs. # Refactor so we don't have to do this here. grouped_inputs = self._reconstruct_sequence_inputs( @@ -2103,6 +2103,10 @@ class Operation(object): logging.warning("Operation._control_inputs is private, use " "Operation.control_inputs instead. " "Operation._control_inputs will eventually be removed.") + # Copy value because it may be self._control_inputs_val (in particular if + # this is called from self._control_inputs += ...), and we don't want to + # clear value below. + value = copy.copy(value) self._remove_all_control_inputs() self._add_control_inputs(value) @@ -2160,16 +2164,7 @@ class Operation(object): """ # pylint: enable=line-too-long if self._c_op: - with c_api_util.tf_buffer() as buf: - with errors.raise_exception_on_not_ok_status() as status: - # pylint: disable=protected-access - c_api.TF_GraphGetOpDef(self._graph._c_graph, - compat.as_bytes(self.type), buf, status) - # pylint: enable=protected-access - data = c_api.TF_GetBuffer(buf) - op_def = op_def_pb2.OpDef() - op_def.ParseFromString(compat.as_bytes(data)) - return op_def + return self._graph._get_op_def(self.type) else: return self._op_def_val @@ -2756,15 +2751,12 @@ class Graph(object): self._handle_movers = {} # A map from tensor handle to its delete op. self._handle_deleters = {} - # Resource container. - if context.in_graph_mode(): - self._container_prefix = "" - else: - # In Eager mode, isolate resources (particularly ResourceVariables) in - # Graphs by default. This prevents unintended variable sharing. Graph mode - # gets this kind of isolation from Sessions. - self._container_prefix = "eager-execution-%d/" % (uid(),) - self._container = self._container_prefix + # Allow optimizers and other objects to pseudo-uniquely key graphs (this key + # will be shared when defining function graphs, for example, so optimizers + # being called inside function definitions behave as if they were seeing the + # actual outside graph). + self._graph_key = "grap-key-%d/" % (uid(),) + self._container = "" self._registered_ops = op_def_registry.get_registered_ops() # TODO(skyewm): fold as much of the above as possible into the C @@ -3376,8 +3368,8 @@ class Graph(object): # (2) "is_stateful" is set in OpDef # (3) "container" attribute is in OpDef # (4) "container" attribute is None - if (self._container and op.type in self._registered_ops and - self._registered_ops[op.type].is_stateful): + # TODO(skyewm): remove op.op_def check when _USE_C_API is removed. + if self._container and op.op_def and op.op_def.is_stateful: try: container_attr = op.get_attr("container") except ValueError: @@ -3657,6 +3649,22 @@ class Graph(object): def _last_id(self): return self._next_id_counter + def _get_op_def(self, type): # pylint: disable=redefined-builtin + """Returns the `OpDef` proto for `type`. `type` is a string.""" + if self._c_graph: + with c_api_util.tf_buffer() as buf: + with errors.raise_exception_on_not_ok_status() as status: + # pylint: disable=protected-access + c_api.TF_GraphGetOpDef(self._c_graph, + compat.as_bytes(type), buf, status) + # pylint: enable=protected-access + data = c_api.TF_GetBuffer(buf) + op_def = op_def_pb2.OpDef() + op_def.ParseFromString(compat.as_bytes(data)) + return op_def + else: + return self._registered_ops[type] + def as_default(self): """Returns a context manager that makes this `Graph` the default graph. @@ -4225,7 +4233,7 @@ class Graph(object): """ original_container = self._container try: - self._container = self._container_prefix + container_name + self._container = container_name yield self._container finally: self._container = original_container @@ -5004,9 +5012,22 @@ def init_scope(): """ # pylint: enable=g-doc-return-or-yield,line-too-long + in_graph_mode = context.in_graph_mode() + # Retrieve the active name scope: entering an `init_scope` preserves + # the name scope of the current context. + if in_graph_mode: + default_graph = get_default_graph() + scope = default_graph.get_name_scope() + else: + scope = context.context().scope_name + if scope and scope[-1] != '/': + # Names that end with trailing slashes are treated by `name_scope` as + # absolute. + scope = scope + '/' + outer_context = None - if context.in_graph_mode() and not _default_graph_stack.stack: - outer_context = get_default_graph().as_default + if in_graph_mode and not _default_graph_stack.stack: + outer_context = default_graph.as_default else: for stack_entry in reversed(context.context_stack.stack): if not stack_entry.is_building_function: @@ -5018,7 +5039,8 @@ def init_scope(): "eager context was previously active.") try: - with outer_context(), control_dependencies(None), tape.stop_recording(): + with outer_context(), name_scope(scope), control_dependencies( + None), tape.stop_recording(): yield finally: pass @@ -5515,6 +5537,9 @@ def get_all_collection_keys(): return get_default_graph().get_all_collection_keys() +name_scope_cache = {} + + # Named like a function for backwards compatibility with the # @tf_contextlib.contextmanager version, which was switched to a class to avoid # some object creation overhead. @@ -5574,7 +5599,11 @@ class name_scope(object): # pylint: disable=invalid-name if not self._name: scope_name = "" else: - if self._name[-1] == "/": + cache_key = self._name, self._old_name, self._default_name + if cache_key in name_scope_cache: + self._ctx.scope_name = name_scope_cache[cache_key] + return self._ctx.scope_name + elif self._name[-1] == "/": # A trailing slash breaks out of nested name scopes, indicating a # fully specified scope name, for compatibility with Graph.name_scope. scope_name = self._name @@ -5583,6 +5612,7 @@ class name_scope(object): # pylint: disable=invalid-name scope_name = ( self._old_name + name_with_trailing_slash if self._old_name else name_with_trailing_slash) + name_scope_cache[cache_key] = scope_name self._ctx.scope_name = scope_name return scope_name else: diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 78519f108ba69a8f3f296debf2e199d6613bf86a..c6deafd89eb1bdc4892a65ba3ab8c7900915390f 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -916,7 +916,6 @@ class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase): op = g.get_operation_by_name("myloop/myop") self.assertIsNotNone(op) - self.assertEqual(len(op.control_inputs), 1) # External control dep is removed and replaced with internal control dep self.assertNotEqual(op.control_inputs[0], c.op) self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) @@ -2072,10 +2071,34 @@ class InitScopeTest(test_util.TensorFlowTestCase): # pylint: disable=protected-access self.assertEqual(len(ops._default_graph_stack.stack), 0) with ops.init_scope(): - self.assertEqual(len(ops._default_graph_stack.stack), 1) + self.assertGreater(len(ops._default_graph_stack.stack), 0) self.assertEqual(len(ops._default_graph_stack.stack), 0) # pylint: enable=protected-access + def testPreservesNameScopeInGraphConstruction(self): + with ops.Graph().as_default(): + function_graph = ops.Graph() + with function_graph.as_default(): + with ops.name_scope("inner"), ops.init_scope(): + self.assertEqual(ops.get_name_scope(), "inner") + self.assertEqual(ops.get_name_scope(), "") + + def testPreservesNameScopeInEagerExecution(self): + with context.eager_mode(): + def foo(): + with ops.name_scope("inner"), ops.init_scope(): + if context.in_graph_mode(): + self.assertEqual(ops.get_name_scope(), "inner") + else: + # A trailing slash is always appended when eager execution is + # enabled. + self.assertEqual(context.context().scope_name, "inner/") + foo() + self.assertEqual(ops.get_name_scope(), "") + foo_compiled = eager_function.defun(foo) + foo_compiled() + self.assertEqual(ops.get_name_scope(), "") + @test_util.with_c_api class GraphTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 65810fa7094409c7429dbaaa6c1e62efb263eafc..c95149d177990e364c3d6b9daeae5dc535cf0070 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -476,9 +476,6 @@ GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def, GenPythonOp::~GenPythonOp() {} string GenPythonOp::Code() { - if (api_def_.visibility() == ApiDef::SKIP) { - return ""; - } // This has all the input args followed by those attrs that don't have // defaults. std::vector params_no_default; @@ -583,8 +580,13 @@ void GenPythonOp::AddExport() { strings::StrAppend(&result_, ")\n"); } +void GenPythonOp::AddDefLine(const string& function_name, + const string& parameters) { + strings::StrAppend(&result_, "def ", function_name, "(", parameters, "):\n"); +} + void GenPythonOp::AddDefLine(const string& parameters) { - strings::StrAppend(&result_, "def ", function_name_, "(", parameters, "):\n"); + AddDefLine(function_name_, parameters); } void GenPythonOp::AddDocStringDescription() { @@ -805,11 +807,21 @@ from tensorflow.python.util.tf_export import tf_export auto out = cleaned_ops.mutable_op(); out->Reserve(ops.op_size()); for (const auto& op_def : ops.op()) { - bool is_hidden = false; - for (const string& hidden : hidden_ops) { - if (op_def.name() == hidden) { - is_hidden = true; - break; + const auto* api_def = api_defs.GetApiDef(op_def.name()); + + if (api_def->visibility() == ApiDef::SKIP) { + continue; + } + + // An op is hidden if either its ApiDef visibility is HIDDEN + // or it is in the hidden_ops list. + bool is_hidden = api_def->visibility() == ApiDef::HIDDEN; + if (!is_hidden) { + for (const string& hidden : hidden_ops) { + if (op_def.name() == hidden) { + is_hidden = true; + break; + } } } @@ -826,7 +838,6 @@ from tensorflow.python.util.tf_export import tf_export continue; } - const auto* api_def = api_defs.GetApiDef(op_def.name()); strings::StrAppend(&result, GetPythonOp(op_def, *api_def, function_name)); if (!require_shapes) { diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h index d09b36a3e8247241420649c6a4a950be6edc3c00..4319e5a7820b33283df8153fdc76e0e567813a17 100644 --- a/tensorflow/python/framework/python_op_gen_internal.h +++ b/tensorflow/python/framework/python_op_gen_internal.h @@ -73,6 +73,7 @@ class GenPythonOp { protected: // Print: def Function(parameters): + void AddDefLine(const string& function_name, const string& parameters); void AddDefLine(const string& parameters); // Format the Op's descriptions so that it can be a Python docstring. diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..a0411bc3d9b4b2b87e5a31e9f201154f28ccf1cc --- /dev/null +++ b/tensorflow/python/framework/tensor_spec.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TensorSpec class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class TensorSpec(object): + """Describes a tf.Tensor. + + A TensorSpec allows an API to describe the Tensors that it accepts or + returns, before that Tensor exists. This allows dynamic and flexible graph + construction and configuration. + """ + + __slots__ = ["_shape", "_dtype", "_name"] + + def __init__(self, shape, dtype, name=None): + """Creates a TensorSpec. + + Args: + shape: Value convertible to `tf.TensorShape`. The shape of the tensor. + dtype: Value convertible to `tf.DType`. The type of the tensor values. + name: Optional name for the Tensor. + + Raises: + TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is + not convertible to a `tf.DType`. + """ + self._shape = tensor_shape.TensorShape(shape) + self._dtype = dtypes.as_dtype(dtype) + self._name = name + + @classmethod + def from_spec(cls, spec, name=None): + return cls(spec.shape, spec.dtype, name or spec.name) + + @classmethod + def from_tensor(cls, tensor, name=None): + if isinstance(tensor, ops.EagerTensor): + return TensorSpec(tensor.shape, tensor.dtype, name) + elif isinstance(tensor, ops.Tensor): + return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name) + else: + raise ValueError("`tensor` should be a tf.Tensor") + + @property + def shape(self): + """Returns the `TensorShape` that represents the shape of the tensor.""" + return self._shape + + @property + def dtype(self): + """Returns the `dtype` of elements in the tensor.""" + return self._dtype + + @property + def name(self): + """Returns the name of the described tensor.""" + return self._name + + def is_compatible_with(self, spec_or_tensor): + """True if the shape and dtype of `spec_or_tensor` are compatible.""" + return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and + self._shape.is_compatible_with(spec_or_tensor.shape)) + + def __repr__(self): + return "TensorSpec(shape={}, dtype={}, name={})".format( + self.shape, repr(self.dtype), repr(self.name)) + + def __eq__(self, other): + return self.shape == other.shape and self.dtype == other.dtype + + def __ne__(self, other): + return not self == other + + +class BoundedTensorSpec(TensorSpec): + """A `TensorSpec` that specifies minimum and maximum values. + + Example usage: + ```python + spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5)) + tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype) + tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype) + ``` + + Bounds are meant to be inclusive. This is especially important for + integer types. The following spec will be satisfied by tensors + with values in the set {0, 1, 2}: + ```python + spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2) + ``` + """ + + __slots__ = ("_minimum", "_maximum") + + def __init__(self, shape, dtype, minimum, maximum, name=None): + """Initializes a new `BoundedTensorSpec`. + + Args: + shape: Value convertible to `tf.TensorShape`. The shape of the tensor. + dtype: Value convertible to `tf.DType`. The type of the tensor values. + minimum: Number or sequence specifying the minimum element bounds + (inclusive). Must be broadcastable to `shape`. + maximum: Number or sequence specifying the maximum element bounds + (inclusive). Must be broadcastable to `shape`. + name: Optional string containing a semantic name for the corresponding + array. Defaults to `None`. + + Raises: + ValueError: If `minimum` or `maximum` are not provided or not + broadcastable to `shape`. + TypeError: If the shape is not an iterable or if the `dtype` is an invalid + numpy dtype. + """ + super(BoundedTensorSpec, self).__init__(shape, dtype, name) + + if minimum is None or maximum is None: + raise ValueError("minimum and maximum must be provided; but saw " + "'%s' and '%s'" % (minimum, maximum)) + + try: + minimum_shape = np.shape(minimum) + common_shapes.broadcast_shape( + tensor_shape.TensorShape(minimum_shape), self.shape) + except ValueError as exception: + raise ValueError("minimum is not compatible with shape. " + "Message: {!r}.".format(exception)) + + try: + maximum_shape = np.shape(maximum) + common_shapes.broadcast_shape( + tensor_shape.TensorShape(maximum_shape), self.shape) + except ValueError as exception: + raise ValueError("maximum is not compatible with shape. " + "Message: {!r}.".format(exception)) + + self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype()) + self._minimum.setflags(write=False) + + self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype()) + self._maximum.setflags(write=False) + + @classmethod + def from_spec(cls, spec): + dtype = dtypes.as_dtype(spec.dtype) + if dtype in [dtypes.float64, dtypes.float32]: + # Avoid under/over-flow for `dtype.maximum - dtype.minimum`. + low = dtype.min / 2 + high = dtype.max / 2 + else: + low = dtype.min + high = dtype.max + + minimum = getattr(spec, "minimum", low) + maximum = getattr(spec, "maximum", high) + return BoundedTensorSpec(spec.shape, dtype, minimum, maximum, spec.name) + + @property + def minimum(self): + """Returns a NumPy array specifying the minimum bounds (inclusive).""" + return self._minimum + + @property + def maximum(self): + """Returns a NumPy array specifying the maximum bounds (inclusive).""" + return self._maximum + + def __repr__(self): + s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})" + return s.format(self.shape, repr(self.dtype), repr(self.name), + repr(self.minimum), repr(self.maximum)) + + def __eq__(self, other): + tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other) + return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and + np.allclose(self.maximum, other.maximum)) + + diff --git a/tensorflow/python/framework/tensor_spec_test.py b/tensorflow/python/framework/tensor_spec_test.py new file mode 100644 index 0000000000000000000000000000000000000000..54ca4d9a19c2e1c879c05cfb828085951bdd8444 --- /dev/null +++ b/tensorflow/python/framework/tensor_spec_test.py @@ -0,0 +1,227 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensor_spec.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class TensorSpecTest(test_util.TensorFlowTestCase): + + def testAcceptsNumpyDType(self): + desc = tensor_spec.TensorSpec([1], np.float32) + self.assertEqual(desc.dtype, dtypes.float32) + + def testAcceptsTensorShape(self): + desc = tensor_spec.TensorSpec(tensor_shape.TensorShape([1]), dtypes.float32) + self.assertEqual(desc.shape, tensor_shape.TensorShape([1])) + + def testUnknownShape(self): + desc = tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32) + self.assertEqual(desc.shape, tensor_shape.TensorShape(None)) + + def testShapeCompatibility(self): + unknown = array_ops.placeholder(dtypes.int64) + partial = array_ops.placeholder(dtypes.int64, shape=[None, 1]) + full = array_ops.placeholder(dtypes.int64, shape=[2, 3]) + rank3 = array_ops.placeholder(dtypes.int64, shape=[4, 5, 6]) + + desc_unknown = tensor_spec.TensorSpec(None, dtypes.int64) + self.assertTrue(desc_unknown.is_compatible_with(unknown)) + self.assertTrue(desc_unknown.is_compatible_with(partial)) + self.assertTrue(desc_unknown.is_compatible_with(full)) + self.assertTrue(desc_unknown.is_compatible_with(rank3)) + + desc_partial = tensor_spec.TensorSpec([2, None], dtypes.int64) + self.assertTrue(desc_partial.is_compatible_with(unknown)) + self.assertTrue(desc_partial.is_compatible_with(partial)) + self.assertTrue(desc_partial.is_compatible_with(full)) + self.assertFalse(desc_partial.is_compatible_with(rank3)) + + desc_full = tensor_spec.TensorSpec([2, 3], dtypes.int64) + self.assertTrue(desc_full.is_compatible_with(unknown)) + self.assertFalse(desc_full.is_compatible_with(partial)) + self.assertTrue(desc_full.is_compatible_with(full)) + self.assertFalse(desc_full.is_compatible_with(rank3)) + + desc_rank3 = tensor_spec.TensorSpec([4, 5, 6], dtypes.int64) + self.assertTrue(desc_rank3.is_compatible_with(unknown)) + self.assertFalse(desc_rank3.is_compatible_with(partial)) + self.assertFalse(desc_rank3.is_compatible_with(full)) + self.assertTrue(desc_rank3.is_compatible_with(rank3)) + + def testTypeCompatibility(self): + floats = array_ops.placeholder(dtypes.float32, shape=[10, 10]) + ints = array_ops.placeholder(dtypes.int32, shape=[10, 10]) + desc = tensor_spec.TensorSpec(shape=(10, 10), dtype=dtypes.float32) + self.assertTrue(desc.is_compatible_with(floats)) + self.assertFalse(desc.is_compatible_with(ints)) + + def testName(self): + desc = tensor_spec.TensorSpec([1], dtypes.float32, name="beep") + self.assertEqual(desc.name, "beep") + + def testRepr(self): + desc1 = tensor_spec.TensorSpec([1], dtypes.float32, name="beep") + self.assertEqual( + repr(desc1), + "TensorSpec(shape=(1,), dtype=tf.float32, name='beep')") + desc2 = tensor_spec.TensorSpec([1, None], dtypes.int32) + self.assertEqual( + repr(desc2), + "TensorSpec(shape=(1, ?), dtype=tf.int32, name=None)") + + def testFromTensorSpec(self): + spec_1 = tensor_spec.TensorSpec((1, 2), dtypes.int32) + spec_2 = tensor_spec.TensorSpec.from_spec(spec_1) + self.assertEqual(spec_1, spec_2) + + def testFromTensor(self): + zero = constant_op.constant(0) + spec = tensor_spec.TensorSpec.from_tensor(zero) + self.assertEqual(spec.dtype, dtypes.int32) + self.assertEqual(spec.shape, []) + self.assertEqual(spec.name, "Const") + + def testFromPlaceholder(self): + unknown = array_ops.placeholder(dtypes.int64, name="unknown") + partial = array_ops.placeholder(dtypes.float32, + shape=[None, 1], + name="partial") + spec_1 = tensor_spec.TensorSpec.from_tensor(unknown) + self.assertEqual(spec_1.dtype, dtypes.int64) + self.assertEqual(spec_1.shape, None) + self.assertEqual(spec_1.name, "unknown") + spec_2 = tensor_spec.TensorSpec.from_tensor(partial) + self.assertEqual(spec_2.dtype, dtypes.float32) + self.assertEqual(spec_2.shape.as_list(), [None, 1]) + self.assertEqual(spec_2.name, "partial") + + def testFromBoundedTensorSpec(self): + bounded_spec = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, 0, 1) + spec = tensor_spec.TensorSpec.from_spec(bounded_spec) + self.assertEqual(bounded_spec.shape, spec.shape) + self.assertEqual(bounded_spec.dtype, spec.dtype) + self.assertEqual(bounded_spec.name, spec.name) + + +class BoundedTensorSpecTest(test_util.TensorFlowTestCase): + + def testInvalidMinimum(self): + with self.assertRaisesRegexp(ValueError, "not compatible"): + tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, (0, 0, 0), (1, 1)) + + def testInvalidMaximum(self): + with self.assertRaisesRegexp(ValueError, "not compatible"): + tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, 0, (1, 1, 1)) + + def testMinimumMaximumAttributes(self): + spec = tensor_spec.BoundedTensorSpec( + (1, 2, 3), dtypes.float32, 0, (5, 5, 5)) + self.assertEqual(type(spec.minimum), np.ndarray) + self.assertEqual(type(spec.maximum), np.ndarray) + self.assertAllEqual(spec.minimum, np.array(0, dtype=np.float32)) + self.assertAllEqual(spec.maximum, np.array([5, 5, 5], dtype=np.float32)) + + def testNotWriteableNP(self): + spec = tensor_spec.BoundedTensorSpec( + (1, 2, 3), dtypes.float32, 0, (5, 5, 5)) + with self.assertRaisesRegexp(ValueError, "read-only"): + spec.minimum[0] = -1 + with self.assertRaisesRegexp(ValueError, "read-only"): + spec.maximum[0] = 100 + + def testReuseSpec(self): + spec_1 = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, + minimum=0, maximum=1) + spec_2 = tensor_spec.BoundedTensorSpec( + spec_1.shape, spec_1.dtype, spec_1.minimum, spec_1.maximum) + self.assertEqual(spec_1, spec_2) + + def testScalarBounds(self): + spec = tensor_spec.BoundedTensorSpec( + (), dtypes.float32, minimum=0.0, maximum=1.0) + + self.assertIsInstance(spec.minimum, np.ndarray) + self.assertIsInstance(spec.maximum, np.ndarray) + + # Sanity check that numpy compares correctly to a scalar for an empty shape. + self.assertEqual(0.0, spec.minimum) + self.assertEqual(1.0, spec.maximum) + + # Check that the spec doesn't fail its own input validation. + _ = tensor_spec.BoundedTensorSpec( + spec.shape, spec.dtype, spec.minimum, spec.maximum) + + def testFromBoundedTensorSpec(self): + spec_1 = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, + minimum=0, maximum=1) + spec_2 = tensor_spec.BoundedTensorSpec.from_spec(spec_1) + self.assertEqual(spec_1, spec_2) + + def testEquality(self): + spec_1_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + 0, (5, 5, 5)) + spec_1_2 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + 0.00000001, + (5, 5, 5.00000000000000001)) + spec_2_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + 1, (5, 5, 5)) + spec_2_2 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + (1, 1, 1), (5, 5, 5)) + spec_2_3 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + (1, 1, 1), 5) + spec_3_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + (2, 1, 1), (5, 5, 5)) + + self.assertEqual(spec_1_1, spec_1_2) + self.assertEqual(spec_1_2, spec_1_1) + + self.assertNotEqual(spec_1_1, spec_2_2) + self.assertNotEqual(spec_1_1, spec_2_1) + self.assertNotEqual(spec_2_2, spec_1_1) + self.assertNotEqual(spec_2_1, spec_1_1) + + self.assertEqual(spec_2_1, spec_2_2) + self.assertEqual(spec_2_2, spec_2_1) + self.assertEqual(spec_2_2, spec_2_3) + + self.assertNotEqual(spec_1_1, spec_3_1) + self.assertNotEqual(spec_2_1, spec_3_1) + self.assertNotEqual(spec_2_2, spec_3_1) + + def testFromTensorSpec(self): + spec = tensor_spec.TensorSpec((1, 2), dtypes.int32) + bounded_spec = tensor_spec.BoundedTensorSpec.from_spec(spec) + self.assertEqual(spec.shape, bounded_spec.shape) + self.assertEqual(spec.dtype, bounded_spec.dtype) + self.assertEqual(spec.dtype.min, bounded_spec.minimum) + self.assertEqual(spec.dtype.max, bounded_spec.maximum) + self.assertEqual(spec.name, bounded_spec.name) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index d2b8e80305724fd12341bc089d8e0a63c40b6688..27afaa074a6becd5c8b7db94be59e8da1611c13a 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Utilities to create TensorProtos.""" from __future__ import absolute_import from __future__ import division @@ -39,6 +38,7 @@ except ImportError: from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.util.tf_export import tf_export + # pylint: enable=g-import-not-at-top @@ -47,8 +47,8 @@ def ExtractBitsFromFloat16(x): def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.half_val.extend([ - ExtractBitsFromFloat16(x) for x in proto_values]) + tensor_proto.half_val.extend( + [ExtractBitsFromFloat16(x) for x in proto_values]) def ExtractBitsFromBFloat16(x): @@ -57,31 +57,47 @@ def ExtractBitsFromBFloat16(x): def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.half_val.extend([ - ExtractBitsFromBFloat16(x) for x in proto_values]) + tensor_proto.half_val.extend( + [ExtractBitsFromBFloat16(x) for x in proto_values]) if _FAST_TENSOR_UTIL_AVAILABLE: _NP_TO_APPEND_FN = { - dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto, + dtypes.bfloat16.as_numpy_dtype: + SlowAppendBFloat16ArrayToTensorProto, # TODO(sesse): We should have a # fast_tensor_util.AppendFloat16ArrayToTensorProto, # but it seems np.float16_t doesn't exist? - np.float16: SlowAppendFloat16ArrayToTensorProto, - np.float32: fast_tensor_util.AppendFloat32ArrayToTensorProto, - np.float64: fast_tensor_util.AppendFloat64ArrayToTensorProto, - np.int32: fast_tensor_util.AppendInt32ArrayToTensorProto, - np.int64: fast_tensor_util.AppendInt64ArrayToTensorProto, - np.uint8: fast_tensor_util.AppendUInt8ArrayToTensorProto, - np.uint16: fast_tensor_util.AppendUInt16ArrayToTensorProto, - np.uint32: fast_tensor_util.AppendUInt32ArrayToTensorProto, - np.uint64: fast_tensor_util.AppendUInt64ArrayToTensorProto, - np.int8: fast_tensor_util.AppendInt8ArrayToTensorProto, - np.int16: fast_tensor_util.AppendInt16ArrayToTensorProto, - np.complex64: fast_tensor_util.AppendComplex64ArrayToTensorProto, - np.complex128: fast_tensor_util.AppendComplex128ArrayToTensorProto, - np.object: fast_tensor_util.AppendObjectArrayToTensorProto, - np.bool: fast_tensor_util.AppendBoolArrayToTensorProto, + np.float16: + SlowAppendFloat16ArrayToTensorProto, + np.float32: + fast_tensor_util.AppendFloat32ArrayToTensorProto, + np.float64: + fast_tensor_util.AppendFloat64ArrayToTensorProto, + np.int32: + fast_tensor_util.AppendInt32ArrayToTensorProto, + np.int64: + fast_tensor_util.AppendInt64ArrayToTensorProto, + np.uint8: + fast_tensor_util.AppendUInt8ArrayToTensorProto, + np.uint16: + fast_tensor_util.AppendUInt16ArrayToTensorProto, + np.uint32: + fast_tensor_util.AppendUInt32ArrayToTensorProto, + np.uint64: + fast_tensor_util.AppendUInt64ArrayToTensorProto, + np.int8: + fast_tensor_util.AppendInt8ArrayToTensorProto, + np.int16: + fast_tensor_util.AppendInt16ArrayToTensorProto, + np.complex64: + fast_tensor_util.AppendComplex64ArrayToTensorProto, + np.complex128: + fast_tensor_util.AppendComplex128ArrayToTensorProto, + np.object: + fast_tensor_util.AppendObjectArrayToTensorProto, + np.bool: + fast_tensor_util.AppendBoolArrayToTensorProto, dtypes.qint8.as_numpy_dtype: fast_tensor_util.AppendInt8ArrayToTensorProto, dtypes.quint8.as_numpy_dtype: @@ -118,14 +134,12 @@ else: tensor_proto.uint64_val.extend([np.asscalar(x) for x in proto_values]) def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.scomplex_val.extend([np.asscalar(v) - for x in proto_values - for v in [x.real, x.imag]]) + tensor_proto.scomplex_val.extend( + [np.asscalar(v) for x in proto_values for v in [x.real, x.imag]]) def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.dcomplex_val.extend([np.asscalar(v) - for x in proto_values - for v in [x.real, x.imag]]) + tensor_proto.dcomplex_val.extend( + [np.asscalar(v) for x in proto_values for v in [x.real, x.imag]]) def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values): tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values]) @@ -252,15 +266,16 @@ def _FilterTuple(v): return None if isinstance(v, list): if not any(isinstance(x, (list, tuple)) for x in v): - return _FirstNotNone([None if isinstance(x, (list, tuple)) else x for x in v]) + return _FirstNotNone( + [None if isinstance(x, (list, tuple)) else x for x in v]) return _FirstNotNone([_FilterTuple(x) for x in v]) def _FilterInt(v): if isinstance(v, (list, tuple)): return _FirstNotNone([_FilterInt(x) for x in v]) - return None if isinstance(v, (compat.integral_types, - tensor_shape.Dimension)) else _NotNone(v) + return None if isinstance( + v, (compat.integral_types, tensor_shape.Dimension)) else _NotNone(v) def _FilterFloat(v): @@ -380,8 +395,11 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): if dtype: dtype = dtypes.as_dtype(dtype) - is_quantized = (dtype in [dtypes.qint8, dtypes.quint8, dtypes.qint16, - dtypes.quint16, dtypes.qint32]) + is_quantized = ( + dtype in [ + dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16, + dtypes.qint32 + ]) # We first convert value to a numpy array or scalar. if isinstance(values, (np.ndarray, np.generic)): @@ -419,9 +437,9 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): if (list(nparray.shape) != _GetDenseDimensions(values) and not is_quantized): raise ValueError("""Argument must be a dense tensor: %s""" - """ - got shape %s, but wanted %s.""" % ( - values, list(nparray.shape), - _GetDenseDimensions(values))) + """ - got shape %s, but wanted %s.""" % + (values, list(nparray.shape), + _GetDenseDimensions(values))) # python/numpy default float type is float64. We prefer float32 instead. if (nparray.dtype == np.float64) and dtype is None: @@ -446,8 +464,8 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): if dtype is not None and (not hasattr(dtype, "base_dtype") or dtype.base_dtype != numpy_dtype.base_dtype): - raise TypeError("Incompatible types: %s vs. %s. Value is %s" - % (dtype, nparray.dtype, values)) + raise TypeError("Incompatible types: %s vs. %s. Value is %s" % + (dtype, nparray.dtype, values)) # If shape is not given, get the shape from the numpy array. if shape is None: @@ -510,8 +528,8 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): append_fn = GetNumpyAppendFn(proto_values.dtype) if append_fn is None: - raise TypeError("Element type not supported in TensorProto: %s" % - numpy_dtype.name) + raise TypeError( + "Element type not supported in TensorProto: %s" % numpy_dtype.name) append_fn(tensor_proto, proto_values) return tensor_proto @@ -539,7 +557,8 @@ def MakeNdarray(tensor): dtype = tensor_dtype.as_numpy_dtype if tensor.tensor_content: - return np.fromstring(tensor.tensor_content, dtype=dtype).reshape(shape) + return (np.frombuffer(tensor.tensor_content, dtype=dtype).copy() + .reshape(shape)) elif tensor_dtype == dtypes.float16: # the half_val field of the TensorProto stores the binary representation # of the fp16: we need to reinterpret this as a proper float16 @@ -553,19 +572,23 @@ def MakeNdarray(tensor): return tmp.reshape(shape) elif tensor_dtype == dtypes.float32: if len(tensor.float_val) == 1: - return np.repeat(np.array(tensor.float_val[0], dtype=dtype), - num_elements).reshape(shape) + return np.repeat( + np.array(tensor.float_val[0], dtype=dtype), + num_elements).reshape(shape) else: return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape) elif tensor_dtype == dtypes.float64: if len(tensor.double_val) == 1: - return np.repeat(np.array(tensor.double_val[0], dtype=dtype), - num_elements).reshape(shape) + return np.repeat( + np.array(tensor.double_val[0], dtype=dtype), + num_elements).reshape(shape) else: return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape) - elif tensor_dtype in [dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, - dtypes.int8, dtypes.qint32, dtypes.quint8, dtypes.qint8, - dtypes.qint16, dtypes.quint16, dtypes.bfloat16]: + elif tensor_dtype in [ + dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8, + dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16, + dtypes.bfloat16 + ]: if len(tensor.int_val) == 1: return np.repeat(np.array(tensor.int_val[0], dtype=dtype), num_elements).reshape(shape) @@ -573,35 +596,41 @@ def MakeNdarray(tensor): return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape) elif tensor_dtype == dtypes.int64: if len(tensor.int64_val) == 1: - return np.repeat(np.array(tensor.int64_val[0], dtype=dtype), - num_elements).reshape(shape) + return np.repeat( + np.array(tensor.int64_val[0], dtype=dtype), + num_elements).reshape(shape) else: return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape) elif tensor_dtype == dtypes.string: if len(tensor.string_val) == 1: - return np.repeat(np.array(tensor.string_val[0], dtype=dtype), - num_elements).reshape(shape) + return np.repeat( + np.array(tensor.string_val[0], dtype=dtype), + num_elements).reshape(shape) else: - return np.array([x for x in tensor.string_val], - dtype=dtype).reshape(shape) + return np.array( + [x for x in tensor.string_val], dtype=dtype).reshape(shape) elif tensor_dtype == dtypes.complex64: it = iter(tensor.scomplex_val) if len(tensor.scomplex_val) == 2: - return np.repeat(np.array(complex(tensor.scomplex_val[0], - tensor.scomplex_val[1]), dtype=dtype), - num_elements).reshape(shape) + return np.repeat( + np.array( + complex(tensor.scomplex_val[0], tensor.scomplex_val[1]), + dtype=dtype), num_elements).reshape(shape) else: - return np.array([complex(x[0], x[1]) for x in zip(it, it)], - dtype=dtype).reshape(shape) + return np.array( + [complex(x[0], x[1]) for x in zip(it, it)], + dtype=dtype).reshape(shape) elif tensor_dtype == dtypes.complex128: it = iter(tensor.dcomplex_val) if len(tensor.dcomplex_val) == 2: - return np.repeat(np.array(complex(tensor.dcomplex_val[0], - tensor.dcomplex_val[1]), dtype=dtype), - num_elements).reshape(shape) + return np.repeat( + np.array( + complex(tensor.dcomplex_val[0], tensor.dcomplex_val[1]), + dtype=dtype), num_elements).reshape(shape) else: - return np.array([complex(x[0], x[1]) for x in zip(it, it)], - dtype=dtype).reshape(shape) + return np.array( + [complex(x[0], x[1]) for x in zip(it, it)], + dtype=dtype).reshape(shape) elif tensor_dtype == dtypes.bool: if len(tensor.bool_val) == 1: return np.repeat(np.array(tensor.bool_val[0], dtype=dtype), @@ -645,8 +674,9 @@ def _ConstantValue(tensor, partial): elif tensor.op.type == "Shape": input_shape = tensor.op.inputs[0].get_shape() if input_shape.is_fully_defined(): - return np.array([dim.value for dim in input_shape.dims], - dtype=tensor.dtype.as_numpy_dtype) + return np.array( + [dim.value for dim in input_shape.dims], + dtype=tensor.dtype.as_numpy_dtype) else: return None elif tensor.op.type == "Size": @@ -658,8 +688,10 @@ def _ConstantValue(tensor, partial): elif tensor.op.type == "Rank": input_shape = tensor.op.inputs[0].get_shape() if input_shape.ndims is not None: - return np.ndarray(shape=(), buffer=np.array([input_shape.ndims], dtype=np.int32), - dtype=np.int32) + return np.ndarray( + shape=(), + buffer=np.array([input_shape.ndims], dtype=np.int32), + dtype=np.int32) else: return None elif tensor.op.type == "Range": @@ -861,8 +893,8 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name new_axis_mask = tensor.op.get_attr("new_axis_mask") shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask") valid_attributes = (not ellipsis_mask and not new_axis_mask and - not shrink_axis_mask and - (not begin_mask or (begin_mask == 1)) and + not shrink_axis_mask and (not begin_mask or + (begin_mask == 1)) and (not end_mask or (end_mask == 1))) if valid_attributes: # additional inputs not supported prev = constant_value_as_shape(tensor.op.inputs[0]) @@ -878,8 +910,8 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name ret = tensor_shape.unknown_shape(shape[0].value) value = constant_value(tensor) if value is not None: - ret = ret.merge_with(tensor_shape.TensorShape( - [d if d >= 0 else None for d in value])) + ret = ret.merge_with( + tensor_shape.TensorShape([d if d >= 0 else None for d in value])) return ret diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index f2de69e159646b4a085645fa1bfef7782e78cd59..bea0ee34fd4900cc9d4d5d52348ba4512368e81f 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -199,6 +199,25 @@ class TensorUtilTest(test.TestCase): dtype=nptype), a) + def testFloatMutateArray(self): + t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32) + a = tensor_util.MakeNdarray(t) + a[0] = 5.0 + self.assertEquals(np.float32, a.dtype) + self.assertAllClose(np.array([5.0, 20.0, 30.0], dtype=np.float32), a) + if sys.byteorder == "big": + self.assertProtoEquals(""" + dtype: DT_FLOAT + tensor_shape { dim { size: 3 } } + tensor_content: "A \000\000A\240\000\000A\360\000\000" + """, t) + else: + self.assertProtoEquals(""" + dtype: DT_FLOAT + tensor_shape { dim { size: 3 } } + tensor_content: "\000\000 A\000\000\240A\000\000\360A" + """, t) + def testHalf(self): t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=np.float16)) self.assertProtoEquals(""" diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc index c6c6c2233c9a81467f57abe2d42f0df9b7ce7106..070b5ac11f563443a97b304ddcdaabd2f4338445 100644 --- a/tensorflow/python/framework/test_ops.cc +++ b/tensorflow/python/framework/test_ops.cc @@ -76,6 +76,11 @@ REGISTER_OP("TestStringOutput") .Output("output2: string") .SetShapeFn(shape_inference::UnknownShape); +REGISTER_OP("TestAttr") + .Output("out: T") + .Attr("T: {float, double}") + .SetShapeFn(shape_inference::UnknownShape); + namespace { enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL }; } // namespace @@ -188,6 +193,20 @@ class ResourceUsingOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("ResourceUsingOp").Device(DEVICE_CPU), ResourceUsingOp); +class TestAttrOp : public OpKernel { + public: + explicit TestAttrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + Tensor* output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = 1.0; + } +}; + +REGISTER_KERNEL_BUILDER( + Name("TestAttr").Device(DEVICE_CPU).TypeConstraint("T"), TestAttrOp); + // Various test ops without kernels. These are used to test graph construction. REGISTER_OP("A") diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 0133318456219b35be11bc5ef128406292bc2feb..682b2b366724294de3cf222633c95f69600823ba 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -49,10 +49,11 @@ from tensorflow.python.client import device_lib from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import tape +from tensorflow.python.eager import tape # pylint: disable=unused-import from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import versions @@ -122,11 +123,11 @@ def assert_equal_graph_def(actual, expected, checkpoint_v2=False): TypeError: If either argument is not a `GraphDef`. """ if not isinstance(actual, graph_pb2.GraphDef): - raise TypeError("Expected tf.GraphDef for actual, got %s" % - type(actual).__name__) + raise TypeError( + "Expected tf.GraphDef for actual, got %s" % type(actual).__name__) if not isinstance(expected, graph_pb2.GraphDef): - raise TypeError("Expected tf.GraphDef for expected, got %s" % - type(expected).__name__) + raise TypeError( + "Expected tf.GraphDef for expected, got %s" % type(expected).__name__) if checkpoint_v2: _strip_checkpoint_v2_randomized(actual) @@ -151,11 +152,10 @@ def assert_meta_graph_protos_equal(tester, a, b): a_proto = proto_type() b_proto = proto_type() # Number of entries in the collections is the same - tester.assertEqual(len(a_value.bytes_list.value), - len(b_value.bytes_list.value)) - for (a_value_item, b_value_item) in zip( - a_value.bytes_list.value, - b_value.bytes_list.value): + tester.assertEqual( + len(a_value.bytes_list.value), len(b_value.bytes_list.value)) + for (a_value_item, b_value_item) in zip(a_value.bytes_list.value, + b_value.bytes_list.value): a_proto.ParseFromString(a_value_item) b_proto.ParseFromString(b_value_item) tester.assertProtoEquals(a_proto, b_proto) @@ -219,10 +219,7 @@ def NHWCToNCHW(input_tensor): converted tensor or shape array """ # tensor dim -> new axis order - new_axes = { - 4: [0, 3, 1, 2], - 5: [0, 4, 1, 2, 3] - } + new_axes = {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]} if isinstance(input_tensor, ops.Tensor): ndims = input_tensor.shape.ndims return array_ops.transpose(input_tensor, new_axes[ndims]) @@ -249,8 +246,9 @@ def NHWCToNCHW_VECT_C(input_shape_or_tensor): """ permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]} is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) - temp_shape = (input_shape_or_tensor.shape.as_list() - if is_tensor else input_shape_or_tensor) + temp_shape = ( + input_shape_or_tensor.shape.as_list() + if is_tensor else input_shape_or_tensor) if temp_shape[-1] % 4 != 0: raise ValueError( "Last dimension of input must be evenly divisible by 4 to convert to " @@ -282,8 +280,9 @@ def NCHW_VECT_CToNHWC(input_shape_or_tensor): """ permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]} is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) - input_shape = (input_shape_or_tensor.shape.as_list() - if is_tensor else input_shape_or_tensor) + input_shape = ( + input_shape_or_tensor.shape.as_list() + if is_tensor else input_shape_or_tensor) if input_shape[-1] != 4: raise ValueError("Last dimension of NCHW_VECT_C must be 4.") permutation = permutations[len(input_shape)] @@ -306,10 +305,7 @@ def NCHWToNHWC(input_tensor): converted tensor or shape array """ # tensor dim -> new axis order - new_axes = { - 4: [0, 2, 3, 1], - 5: [0, 2, 3, 4, 1] - } + new_axes = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]} if isinstance(input_tensor, ops.Tensor): ndims = input_tensor.shape.ndims return array_ops.transpose(input_tensor, new_axes[ndims]) @@ -324,10 +320,17 @@ def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs): prev_value = ops._USE_C_API ops._USE_C_API = use_c_api try: - with ops.Graph().as_default(): - fn(*args, **kwargs) + # Reset the default graph so it has the C API enabled. We call + # reset_default_graph() instead of creating a new default Graph context to + # make this robust to tests that call reset_default_graph(), which requires + # that the current default graph isn't nested. + ops.reset_default_graph() + fn(*args, **kwargs) finally: ops._USE_C_API = prev_value + # Make sure default graph reflects prev_value in case next test doesn't call + # reset_default_graph(). + ops.reset_default_graph() # pylint: disable=protected-access @@ -344,7 +347,9 @@ def skip_if(condition): Returns: The wrapped function """ + def real_skip_if(fn): + def wrapper(*args, **kwargs): if callable(condition): skip = condition() @@ -352,7 +357,9 @@ def skip_if(condition): skip = condition if not skip: fn(*args, **kwargs) + return wrapper + return real_skip_if @@ -369,8 +376,10 @@ def disable_c_api(fn): Returns: The wrapped function """ + def wrapper(*args, **kwargs): _use_c_api_wrapper(fn, False, *args, **kwargs) + return wrapper @@ -387,8 +396,10 @@ def enable_c_api(fn): Returns: The wrapped function """ + def wrapper(*args, **kwargs): _use_c_api_wrapper(fn, True, *args, **kwargs) + return wrapper @@ -414,66 +425,6 @@ def with_c_api(cls): return cls -class IsolateTest(object): - """A context manager which isolates resources in its block. - - Provides an Eager-agnostic abstraction for preventing the sharing of - variables and other resources. - - In graph mode, resource handle ops are only executed in a particular Session, - isolating them from resources with the same name in other Graphs. In Eager, - separate Sessions do not exist, so resources (particularly ResourceVariables) - would be shared implicitly if a resource of the same name were created - anywhere in a Python process. Multiple handles to the same resource would - cause several issues, and so this type of sharing will raise an exception. - - Using resources with the same name in a single Python process may be useful - (especially for unit tests), so this context manager provides an abstraction - for isolating resources. Using a resource created in one Isolation environment - in another is an error. - - Example usage in Eager mode: - - ```python - import tensorflow as tf - # Import subject to change - from tensorflow.contrib.eager.python import tfe - - tfe.enable_eager_execution() - - for hyperparameter in [1, 2, 3]: - with tfe.IsolateTest(): - v = tfe.Variable(name="v", initial_value=hyperparameter) - # train model, test results ... - ``` - - IsolateTest is currently exposed through contrib.eager, but it creates a new - default Graph and provides equivalent safety in graph mode. - """ - - def __init__(self): - if context.in_eager_mode() and tape.could_possibly_record(): - raise ValueError("Cannot isolate Eager execution with an active tape.") - # In Eager, Graphs set a container which isolates resources, and maintain a - # VariableStore which caches ResourceVariable objects created through - # get_variable. So setting the default Graph has the side effect of - # isolating Eager resources. - with context.eager_mode(): - # Create the graph in Eager mode, as this provides stricter semantics - # (i.e. has a unique container prefix). This prevents implicit sharing - # when a Graph-mode graph is created and then Eager mode is enabled (an - # error through enable_eager_execution, but common with context managers - # in unit tests). - self._graph_as_default_context_manager = ops.Graph().as_default() - - def __enter__(self): - self._graph_as_default_context_manager.__enter__() - - def __exit__(self, type_arg, value_arg, traceback_arg): - return self._graph_as_default_context_manager.__exit__( - type_arg, value_arg, traceback_arg) - - def assert_no_new_tensors(f): """Decorator for asserting that no new Tensors persist after a test. @@ -504,17 +455,15 @@ def assert_no_new_tensors(f): return False tensors_before = set(id(obj) for obj in gc.get_objects() if _is_tensor(obj)) - outside_container_prefix = ops.get_default_graph()._container_prefix - with IsolateTest(): + outside_graph_key = ops.get_default_graph()._graph_key + with ops.Graph().as_default(): # Run the test in a new graph so that collections get cleared when it's - # done, but inherit the container prefix so that we can print the values - # of variables which get leaked when executing eagerly. - ops.get_default_graph()._container_prefix = outside_container_prefix + # done, but inherit the graph key so optimizers behave. + ops.get_default_graph()._graph_key = outside_graph_key f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. - backprop._last_zero = [None] - backprop._shape_dtype = [None, None] + backprop._zeros_cache.flush() context.get_default_context().scalar_cache().clear() gc.collect() tensors_after = [ @@ -560,13 +509,17 @@ def assert_no_garbage_created(f): # not hold on to every object in other tests. gc.set_debug(previous_debug_flags) gc.enable() + return decorator -def run_in_graph_and_eager_modes( - __unused__=None, graph=None, config=None, - use_gpu=False, force_gpu=False, - reset_test=True, assert_no_eager_garbage=False): +def run_in_graph_and_eager_modes(__unused__=None, + graph=None, + config=None, + use_gpu=False, + force_gpu=False, + reset_test=True, + assert_no_eager_garbage=False): """Runs the test in both graph and eager modes. Args: @@ -595,6 +548,7 @@ def run_in_graph_and_eager_modes( def decorator(f): """Test method decorator.""" + def decorated(self, **kwargs): """Decorated the test method.""" with context.graph_mode(): @@ -626,10 +580,11 @@ def run_in_graph_and_eager_modes( assert_no_garbage_created(run_eager_mode)) with context.eager_mode(): - with IsolateTest(): + with ops.Graph().as_default(): run_eager_mode(self, **kwargs) return decorated + return decorator @@ -736,7 +691,7 @@ class TensorFlowTestCase(googletest.TestCase): self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir()) return self._tempdir - def _AssertProtoEquals(self, a, b): + def _AssertProtoEquals(self, a, b, msg=None): """Asserts that a and b are the same proto. Uses ProtoEq() first, as it returns correct results @@ -746,11 +701,12 @@ class TensorFlowTestCase(googletest.TestCase): Args: a: a proto. b: another proto. + msg: Optional message to report on failure. """ if not compare.ProtoEq(a, b): - compare.assertProtoEqual(self, a, b, normalize_numbers=True) + compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg) - def assertProtoEquals(self, expected_message_maybe_ascii, message): + def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None): """Asserts that message is same as parsed expected_message_ascii. Creates another prototype of message, reads the ascii message into it and @@ -759,29 +715,33 @@ class TensorFlowTestCase(googletest.TestCase): Args: expected_message_maybe_ascii: proto message in original or ascii form. message: the message to validate. + msg: Optional message to report on failure. """ - + msg = msg if msg else "" if isinstance(expected_message_maybe_ascii, type(message)): expected_message = expected_message_maybe_ascii self._AssertProtoEquals(expected_message, message) elif isinstance(expected_message_maybe_ascii, str): expected_message = type(message)() - text_format.Merge(expected_message_maybe_ascii, expected_message, - descriptor_pool=descriptor_pool.Default()) - self._AssertProtoEquals(expected_message, message) + text_format.Merge( + expected_message_maybe_ascii, + expected_message, + descriptor_pool=descriptor_pool.Default()) + self._AssertProtoEquals(expected_message, message, msg=msg) else: - assert False, ("Can't compare protos of type %s and %s" % - (type(expected_message_maybe_ascii), type(message))) + assert False, ("Can't compare protos of type %s and %s. %s" % + (type(expected_message_maybe_ascii), type(message), msg)) def assertProtoEqualsVersion( self, expected, actual, producer=versions.GRAPH_DEF_VERSION, - min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER): + min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER, + msg=None): expected = "versions { producer: %d min_consumer: %d };\n%s" % ( producer, min_consumer, expected) - self.assertProtoEquals(expected, actual) + self.assertProtoEquals(expected, actual, msg=msg) def assertStartsWith(self, actual, expected_start, msg=None): """Assert that actual.startswith(expected_start) is True. @@ -851,7 +811,8 @@ class TensorFlowTestCase(googletest.TestCase): trigger the creation of a new session. Use the `use_gpu` and `force_gpu` options to control where ops are run. If - `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if `use_gpu` + `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if + `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to the CPU. @@ -1050,6 +1011,7 @@ class TensorFlowTestCase(googletest.TestCase): self._threads.append(ret) return ret + # pylint: enable=invalid-name def assertNear(self, f1, f2, err, msg=None): @@ -1069,7 +1031,7 @@ class TensorFlowTestCase(googletest.TestCase): "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg if msg is not None else "")) - def assertArrayNear(self, farray1, farray2, err): + def assertArrayNear(self, farray1, farray2, err, msg=None): """Asserts that two float arrays are near each other. Checks that for all elements of farray1 and farray2 @@ -1079,23 +1041,25 @@ class TensorFlowTestCase(googletest.TestCase): farray1: a list of float values. farray2: a list of float values. err: a float value. + msg: Optional message to report on failure. """ - self.assertEqual(len(farray1), len(farray2)) + self.assertEqual(len(farray1), len(farray2), msg=msg) for f1, f2 in zip(farray1, farray2): - self.assertNear(float(f1), float(f2), err) + self.assertNear(float(f1), float(f2), err, msg=msg) def _NDArrayNear(self, ndarray1, ndarray2, err): return np.linalg.norm(ndarray1 - ndarray2) < err - def assertNDArrayNear(self, ndarray1, ndarray2, err): + def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None): """Asserts that two numpy arrays have near values. Args: ndarray1: a numpy ndarray. ndarray2: a numpy ndarray. err: a float. The maximum absolute difference allowed. + msg: Optional message to report on failure. """ - self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err)) + self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg) def _GetNdArray(self, a): if not isinstance(a, np.ndarray): @@ -1117,7 +1081,8 @@ class TensorFlowTestCase(googletest.TestCase): # the absolute difference between a and b. Here, we want to # print out which elements violate such conditions. cond = np.logical_or( - np.abs(a - b) > atol + rtol * np.abs(b), np.isnan(a) != np.isnan(b)) + np.abs(a - b) > atol + rtol * np.abs(b), + np.isnan(a) != np.isnan(b)) if a.ndim: x = a[np.where(cond)] y = b[np.where(cond)] @@ -1136,9 +1101,16 @@ class TensorFlowTestCase(googletest.TestCase): np.testing.assert_allclose( a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True) - def _assertAllCloseRecursive(self, a, b, rtol=1e-6, atol=1e-6, path=None): + def _assertAllCloseRecursive(self, + a, + b, + rtol=1e-6, + atol=1e-6, + path=None, + msg=None): path = path or [] path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "") + msg = msg if msg else "" # Check if a and/or b are namedtuples. if hasattr(a, "_asdict"): @@ -1147,18 +1119,18 @@ class TensorFlowTestCase(googletest.TestCase): b = b._asdict() a_is_dict = isinstance(a, dict) if a_is_dict != isinstance(b, dict): - raise ValueError("Can't compare dict to non-dict, a%s vs b%s." % - (path_str, path_str)) + raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" % + (path_str, path_str, msg)) if a_is_dict: self.assertItemsEqual( a.keys(), b.keys(), - msg="mismatched keys: a%s has keys %s, but b%s has keys %s" % - (path_str, a.keys(), path_str, b.keys())) + msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" % + (path_str, a.keys(), path_str, b.keys(), msg)) for k in a: path.append(k) self._assertAllCloseRecursive( - a[k], b[k], rtol=rtol, atol=atol, path=path) + a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg) del path[-1] elif isinstance(a, (list, tuple)): # Try to directly compare a, b as ndarrays; if not work, then traverse @@ -1171,29 +1143,35 @@ class TensorFlowTestCase(googletest.TestCase): b_as_ndarray, rtol=rtol, atol=atol, - msg="Mismatched value: a%s is different from b%s." % (path_str, - path_str)) + msg="Mismatched value: a%s is different from b%s. %s" % + (path_str, path_str, msg)) except (ValueError, TypeError) as e: if len(a) != len(b): raise ValueError( - "Mismatched length: a%s has %d items, but b%s has %d items" % - (path_str, len(a), path_str, len(b))) + "Mismatched length: a%s has %d items, but b%s has %d items. %s" % + (path_str, len(a), path_str, len(b), msg)) for idx, (a_ele, b_ele) in enumerate(zip(a, b)): path.append(str(idx)) self._assertAllCloseRecursive( - a_ele, b_ele, rtol=rtol, atol=atol, path=path) + a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg) del path[-1] # a and b are ndarray like objects else: - self._assertArrayLikeAllClose( - a, - b, - rtol=rtol, - atol=atol, - msg="Mismatched value: a%s is different from b%s." % (path_str, - path_str)) - - def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6): + try: + self._assertArrayLikeAllClose( + a, + b, + rtol=rtol, + atol=atol, + msg="Mismatched value: a%s is different from b%s." % (path_str, + path_str)) + except TypeError as e: + msg = "Error: a%s has %s, but b%s has %s" % ( + path_str, type(a), path_str, type(b)) + e.args = ((e.args[0] + ' : ' + msg,) + e.args[1:]) + raise + + def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): """Asserts that two structures of numpy arrays, have near values. `a` and `b` can be arbitrarily nested structures. A layer of a nested @@ -1206,6 +1184,7 @@ class TensorFlowTestCase(googletest.TestCase): numpy `ndarray`, or any arbitrarily nested of structure of these. rtol: relative tolerance. atol: absolute tolerance. + msg: Optional message to report on failure. Raises: ValueError: if only one of `a[p]` and `b[p]` is a dict or @@ -1213,7 +1192,7 @@ class TensorFlowTestCase(googletest.TestCase): to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and `[p] = [1]['d']`, then `a[p] = (6, 7)`. """ - self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol) + self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg) def assertAllCloseAccordingToType(self, a, @@ -1225,7 +1204,8 @@ class TensorFlowTestCase(googletest.TestCase): half_rtol=1e-3, half_atol=1e-3, bfloat16_rtol=1e-2, - bfloat16_atol=1e-2): + bfloat16_atol=1e-2, + msg=None): """Like assertAllClose, but also suitable for comparing fp16 arrays. In particular, the tolerance is reduced to 1e-3 if at least @@ -1242,6 +1222,7 @@ class TensorFlowTestCase(googletest.TestCase): half_atol: absolute tolerance for float16. bfloat16_rtol: relative tolerance for bfloat16. bfloat16_atol: absolute tolerance for bfloat16. + msg: Optional message to report on failure. """ a = self._GetNdArray(a) b = self._GetNdArray(b) @@ -1258,19 +1239,21 @@ class TensorFlowTestCase(googletest.TestCase): rtol = max(rtol, bfloat16_rtol) atol = max(atol, bfloat16_atol) - self.assertAllClose(a, b, rtol=rtol, atol=atol) + self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) - def assertAllEqual(self, a, b): + def assertAllEqual(self, a, b, msg=None): """Asserts that two numpy arrays have the same values. Args: a: the expected numpy ndarray or anything can be converted to one. b: the actual numpy ndarray or anything can be converted to one. + msg: Optional message to report on failure. """ + msg = msg if msg else "" a = self._GetNdArray(a) b = self._GetNdArray(b) - self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." % - (a.shape, b.shape)) + self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." + " %s" % (a.shape, b.shape, msg)) same = (a == b) if a.dtype == np.float32 or a.dtype == np.float64: @@ -1287,7 +1270,7 @@ class TensorFlowTestCase(googletest.TestCase): x, y = a, b print("not equal lhs = ", x) print("not equal rhs = ", y) - np.testing.assert_array_equal(a, b) + np.testing.assert_array_equal(a, b, err_msg=msg) # pylint: disable=g-doc-return-or-yield @contextlib.contextmanager @@ -1337,12 +1320,13 @@ class TensorFlowTestCase(googletest.TestCase): return self.assertRaisesWithPredicateMatch(errors.OpError, expected_err_re_or_predicate) - def assertShapeEqual(self, np_array, tf_tensor): + def assertShapeEqual(self, np_array, tf_tensor, msg=None): """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape. Args: np_array: A Numpy ndarray or Numpy scalar. tf_tensor: A Tensor. + msg: Optional message to report on failure. Raises: TypeError: If the arguments have the wrong type. @@ -1351,19 +1335,22 @@ class TensorFlowTestCase(googletest.TestCase): raise TypeError("np_array must be a Numpy ndarray or Numpy scalar") if not isinstance(tf_tensor, ops.Tensor): raise TypeError("tf_tensor must be a Tensor") - self.assertAllEqual(np_array.shape, tf_tensor.get_shape().as_list()) + self.assertAllEqual( + np_array.shape, tf_tensor.get_shape().as_list(), msg=msg) - def assertDeviceEqual(self, device1, device2): + def assertDeviceEqual(self, device1, device2, msg=None): """Asserts that the two given devices are the same. Args: device1: A string device name or TensorFlow `DeviceSpec` object. device2: A string device name or TensorFlow `DeviceSpec` object. + msg: Optional message to report on failure. """ device1 = pydev.canonical_name(device1) device2 = pydev.canonical_name(device2) self.assertEqual(device1, device2, - "Devices %s and %s are not equal" % (device1, device2)) + "Devices %s and %s are not equal. %s" % + (device1, device2, msg)) # Fix Python 3 compatibility issues if six.PY3: @@ -1379,8 +1366,11 @@ class TensorFlowTestCase(googletest.TestCase): @tf_export("test.create_local_cluster") -def create_local_cluster(num_workers, num_ps, protocol="grpc", - worker_config=None, ps_config=None): +def create_local_cluster(num_workers, + num_ps, + protocol="grpc", + worker_config=None, + ps_config=None): """Create and start local servers and return the associated `Server` objects. Example: @@ -1430,15 +1420,21 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc", workers = [ server_lib.Server( - cs, job_name="worker", protocol=protocol, task_index=ix, - config=worker_config, start=True) - for ix in range(num_workers) + cs, + job_name="worker", + protocol=protocol, + task_index=ix, + config=worker_config, + start=True) for ix in range(num_workers) ] ps_servers = [ server_lib.Server( - cs, job_name="ps", protocol=protocol, task_index=ix, - config=ps_config, start=True) - for ix in range(num_ps) + cs, + job_name="ps", + protocol=protocol, + task_index=ix, + config=ps_config, + start=True) for ix in range(num_ps) ] return workers, ps_servers @@ -1460,3 +1456,14 @@ def get_node_def_from_graph(node_name, graph_def): if node_def.name == node_name: return node_def return None + + +def set_producer_version(graph, producer_version): + """Sets graph.graph_def_versions.producer to `producer_version`.""" + # The C API doesn't expose altering GraphDefVersions. We can indirectly set + # it via import_graph_def though. + graph_def = graph_pb2.GraphDef() + graph_def.versions.producer = producer_version + with graph.as_default(): + importer.import_graph_def(graph_def) + assert graph.graph_def_versions.producer, producer_version diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 3594d125bf616917727bea4958eaabf159d0aee0..a717eb39513ac3369ae133b6090ff82597f12eb7 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -29,7 +29,6 @@ from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors @@ -39,7 +38,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -443,71 +441,5 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase): LeakedTensorTest().test_has_no_leak() -@test_util.with_c_api -class IsolationTest(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes() - def test_variable_reuse_exception(self): - with test_util.IsolateTest(), session.Session(): - first_container_variable = resource_variable_ops.ResourceVariable( - name="first_container_variable", - initial_value=1) - if context.in_graph_mode(): - self.evaluate([variables.global_variables_initializer()]) - with test_util.IsolateTest(): - if context.in_graph_mode(): - with self.assertRaises(RuntimeError): - self.evaluate(first_container_variable.read_value()) - else: - with self.assertRaises(ValueError): - first_container_variable.read_value() - - @test_util.run_in_graph_and_eager_modes() - def test_variable_reuse_exception_nested(self): - with test_util.IsolateTest(), session.Session(): - first_container_variable = resource_variable_ops.ResourceVariable( - name="first_container_variable", - initial_value=1) - if context.in_graph_mode(): - self.evaluate([variables.global_variables_initializer()]) - with test_util.IsolateTest(), session.Session(): - if context.in_graph_mode(): - with self.assertRaises(RuntimeError): - self.evaluate(first_container_variable.read_value()) - else: - with self.assertRaises(ValueError): - first_container_variable.read_value() - - @test_util.run_in_graph_and_eager_modes() - def test_no_sharing(self): - with test_util.IsolateTest(), session.Session(): - first_container_variable = resource_variable_ops.ResourceVariable( - name="same_name", - initial_value=1) - if context.in_graph_mode(): - self.evaluate([variables.global_variables_initializer()]) - with test_util.IsolateTest(), session.Session(): - second_container_variable = resource_variable_ops.ResourceVariable( - name="same_name", - initial_value=2) - if context.in_graph_mode(): - self.evaluate([variables.global_variables_initializer()]) - self.assertEqual( - 2, self.evaluate(second_container_variable.read_value())) - self.assertEqual(1, self.evaluate(first_container_variable.read_value())) - - def test_graph_mode_isolation(self): - with context.graph_mode(): - # Even if we've (accidentally) called IsolateTest in Graph mode, it should - # provide Eager isolation. - with test_util.IsolateTest(): - with context.eager_mode(): - first_container_variable = resource_variable_ops.ResourceVariable( - name="first_container_variable", - initial_value=1) - with context.eager_mode(): - with self.assertRaises(ValueError): - first_container_variable.read_value() - if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/framework/versions.py b/tensorflow/python/framework/versions.py index bdcbc15af63c57d712abfac97537f86b3bbe1737..06955b885852a641bc814f88c99838effe03bfd4 100644 --- a/tensorflow/python/framework/versions.py +++ b/tensorflow/python/framework/versions.py @@ -35,7 +35,9 @@ tf_export("GIT_VERSION").export_constant(__name__, "GIT_VERSION") COMPILER_VERSION = __compiler_version__ tf_export("COMPILER_VERSION").export_constant(__name__, "COMPILER_VERSION") CXX11_ABI_FLAG = __cxx11_abi_flag__ +tf_export("CXX11_ABI_FLAG").export_constant(__name__, "CXX11_ABI_FLAG") MONOLITHIC_BUILD = __monolithic_build__ +tf_export("MONOLITHIC_BUILD").export_constant(__name__, "MONOLITHIC_BUILD") GRAPH_DEF_VERSION = pywrap_tensorflow.GRAPH_DEF_VERSION tf_export("GRAPH_DEF_VERSION").export_constant(__name__, "GRAPH_DEF_VERSION") diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i index 0c8d04ff29518d587079a76e3fee3b2e327c6c5c..8079cb307bb1f5904b71bae891d5ef5f1e749e66 100644 --- a/tensorflow/python/grappler/cluster.i +++ b/tensorflow/python/grappler/cluster.i @@ -140,6 +140,7 @@ static GCluster TF_NewCluster(bool allow_soft_placement, timeout_s, num_cpu_cores, num_gpus); cluster_->DisableDetailedStats(disable_detailed_stats); cluster_->AllowSoftPlacement(allow_soft_placement); + cluster_->SetNumWarmupSteps(10); tensorflow::Status status = cluster_->Provision(); tensorflow::Set_TF_Status_from_Status(out_status, status); return GCluster(cluster_); diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py index 2292b2c732b2d5d0d40b44d8ca831f4e72b057c6..caae5b114e1f896fc758613380a6f702853d98a5 100644 --- a/tensorflow/python/grappler/cluster_test.py +++ b/tensorflow/python/grappler/cluster_test.py @@ -45,7 +45,7 @@ class ClusterTest(test.TestCase): op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts( grappler_item) self.assertTrue(run_time > 0) - self.assertEqual(len(op_perfs), 9) + self.assertEqual(len(op_perfs), 8) self.assertTrue(step_stats.dev_stats) def testNoDetailedStats(self): @@ -125,7 +125,7 @@ class ClusterTest(test.TestCase): disable_detailed_stats=False, disable_timeline=False) as gcluster: op_perfs, run_time, step_stats = gcluster.MeasureCosts(grappler_item) self.assertTrue(run_time > 0) - self.assertEqual(len(op_perfs), 9) + self.assertEqual(len(op_perfs), 8) self.assertTrue(step_stats.dev_stats) def testAvailableOps(self): diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py index 146bb4311cb5a44d5739821db19f33a41e6e9ce2..86db87d51530621d71b9b90f145e0f10d0b72443 100644 --- a/tensorflow/python/grappler/cost_analyzer_tool.py +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -22,34 +22,76 @@ import argparse import sys from google.protobuf import text_format - +from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op # pylint: disable=unused-import +from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops from tensorflow.python.grappler import cost_analyzer from tensorflow.python.grappler import tf_optimizer from tensorflow.python.platform import app from tensorflow.python.platform import gfile +from tensorflow.python.training import saver def main(_): - with gfile.GFile(FLAGS.input) as input_file: - metagraph = meta_graph_pb2.MetaGraphDef() - metagraph.ParseFromString(input_file.read()) + if FLAGS.metagraphdef: + with gfile.GFile(FLAGS.metagraphdef) as meta_file: + metagraph = meta_graph_pb2.MetaGraphDef() + if FLAGS.metagraphdef.endswith(".pbtxt"): + text_format.Merge(meta_file.read(), metagraph) + else: + metagraph.ParseFromString(meta_file.read()) + if FLAGS.fetch is not None: + fetch_collection = meta_graph_pb2.CollectionDef() + fetch_collection.node_list.value.append(FLAGS.fetch) + metagraph.collection_def["train_op"].CopyFrom(fetch_collection) + else: + with gfile.GFile(FLAGS.graphdef) as graph_file: + graph_def = graph_pb2.GraphDef() + if FLAGS.graphdef.endswith(".pbtxt"): + text_format.Merge(graph_file.read(), graph_def) + else: + graph_def.ParseFromString(graph_file.read()) + importer.import_graph_def(graph_def, name="") + graph = ops.get_default_graph() + fetch = graph.get_operation_by_name(FLAGS.fetch) + graph.add_to_collection("train_op", fetch) + metagraph = saver.export_meta_graph( + graph_def=graph.as_graph_def(), graph=graph) + rewriter_config = rewriter_config_pb2.RewriterConfig() if FLAGS.rewriter_config is not None: - rewriter_config = rewriter_config_pb2.RewriterConfig() text_format.Merge(FLAGS.rewriter_config, rewriter_config) - optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph) - metagraph.graph_def.CopyFrom(optimized_graph) + optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph) + metagraph.graph_def.CopyFrom(optimized_graph) report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report) print(report) + report = cost_analyzer.GenerateMemoryReport(metagraph) + print(report) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--input", type=str, default=None, help="Input .meta file path.") + "--metagraphdef", + type=str, + default=None, + help="Input .meta MetaGraphDef file path.") + parser.add_argument( + "--graphdef", + type=str, + default=None, + help="Input .pb GraphDef file path.") + parser.add_argument( + "--fetch", + type=str, + default=None, + help= + "The name of the fetch node." + ) parser.add_argument( "--rewriter_config", type=str, diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 578f86ca5a0c1f2446dbf26ce412e34f3bdbd23a..0f5150174049250e86bbac0a49eb998339058326 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -157,6 +157,7 @@ def _get_config(layout_optimizer=True): graph_options = config_pb2.GraphOptions( rewrite_options=rewrite_options, build_cost_model=1) config = config_pb2.ConfigProto(graph_options=graph_options) + config.graph_options.optimizer_options.opt_level = -1 return config @@ -179,6 +180,8 @@ def _get_cluster(): named_device = device_properties_pb2.NamedDevice() named_device.name = '/GPU:0' named_device.properties.type = 'GPU' + named_device.properties.num_cores = 24 + named_device.properties.frequency = 1000 named_device.properties.environment['architecture'] = '4' cluster = gcluster.Cluster(devices=[named_device]) return cluster @@ -253,7 +256,7 @@ class LayoutOptimizerTest(test.TestCase): x = random_ops.truncated_normal([1, 784], seed=0) output = _two_layer_model(x) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -290,7 +293,7 @@ class LayoutOptimizerTest(test.TestCase): add = bn0[0] + bn1[0] output = array_ops.identity(add) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={dim: 3}) with session.Session(config=_get_config()) as sess: @@ -322,7 +325,7 @@ class LayoutOptimizerTest(test.TestCase): value=conv, size_splits=sizes, axis=dim, num_split=3) output = math_ops.reduce_sum(split[0]) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={dim: 3}) with session.Session(config=_get_config()) as sess: @@ -356,7 +359,7 @@ class LayoutOptimizerTest(test.TestCase): pad = array_ops.pad(conv, paddings) output = array_ops.identity(pad) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -387,7 +390,7 @@ class LayoutOptimizerTest(test.TestCase): reduce_sum = math_ops.reduce_sum(conv) output = array_ops.identity(reduce_sum) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -416,7 +419,7 @@ class LayoutOptimizerTest(test.TestCase): cast = math_ops.cast(conv, dtype='bool') output = array_ops.identity(cast) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -447,7 +450,67 @@ class LayoutOptimizerTest(test.TestCase): squeeze = array_ops.squeeze(reduce_sum) output = array_ops.identity(squeeze) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Three transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 1 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + def testSqueezeAlongHW(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + reduce_sum = math_ops.reduce_sum(conv, axis=[1, 2], keep_dims=True) + squeeze = array_ops.squeeze(reduce_sum, axis=[1, 2]) + output = array_ops.identity(squeeze) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Three transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 1 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + def testSqueezeAlongNHW(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + reduce_sum = math_ops.reduce_sum(conv, axis=[0, 1, 2], keep_dims=True) + squeeze = array_ops.squeeze(reduce_sum, axis=[0, 1, 2]) + output = array_ops.identity(squeeze) + + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -476,7 +539,7 @@ class LayoutOptimizerTest(test.TestCase): reduce_sum = math_ops.reduce_sum(conv, axis=[1, 2, 3]) output = array_ops.identity(reduce_sum) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -505,7 +568,7 @@ class LayoutOptimizerTest(test.TestCase): reduce_sum = math_ops.reduce_sum(conv, axis=[0, 1, 2]) output = array_ops.identity(reduce_sum) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -534,7 +597,7 @@ class LayoutOptimizerTest(test.TestCase): reduce_sum = math_ops.reduce_sum(conv, axis=[3]) output = array_ops.identity(reduce_sum) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -555,6 +618,94 @@ class LayoutOptimizerTest(test.TestCase): self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testReduceSumAlongCKeepDims(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + reduce_sum = math_ops.reduce_sum(conv, axis=[3], keep_dims=True) + output = array_ops.identity(reduce_sum) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Four transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self._assert_trans_nchw_to_nhwc('Sum-0-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + def testReduceSumAlongHKeepDims(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + reduce_sum = math_ops.reduce_sum(conv, axis=[2], keep_dims=True) + output = array_ops.identity(reduce_sum) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Four transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + def testReduceSumAlongWCKeepDims(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + reduce_sum = math_ops.reduce_sum(conv, axis=[2, 3], keep_dims=True) + output = array_ops.identity(reduce_sum) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Four transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testConcatWithControlDependency(self): if test.is_gpu_available(cuda_only=True): random_seed.set_random_seed(0) @@ -567,7 +718,7 @@ class LayoutOptimizerTest(test.TestCase): concat = array_ops.concat([conv, conv], axis) output = array_ops.identity(concat) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -601,7 +752,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(fill) x_val = [3.4] * 784 - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={x: x_val}) with session.Session(config=_get_config()) as sess: @@ -643,7 +794,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(tile) multiple_val = [2, 3, 4, 1] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={multiple: multiple_val}) with session.Session(config=_get_config()) as sess: @@ -678,7 +829,7 @@ class LayoutOptimizerTest(test.TestCase): reverse = array_ops.reverse(conv, dims) output = array_ops.identity(reverse) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -711,7 +862,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(reverse) dims_val = [2, 3] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={dims: dims_val}) with session.Session(config=_get_config()) as sess: @@ -748,7 +899,7 @@ class LayoutOptimizerTest(test.TestCase): select = gen_math_ops._select(condition, conv, add) output = array_ops.identity(select) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -768,6 +919,37 @@ class LayoutOptimizerTest(test.TestCase): self._assert_trans_nchw_to_nhwc('Select-0-0', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testSelectOpConditionUnknownShape(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + add = math_ops.add(conv, conv) + condition = array_ops.placeholder(dtype='bool') + select = gen_math_ops._select(condition, conv, add) + output = array_ops.identity(select) + + condition_val = np.zeros((1, 7, 7, 64)) + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output, feed_dict={condition: condition_val}) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run( + output, run_metadata=metadata, feed_dict={condition: condition_val}) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 3 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testSelectOpScalarCondition(self): if test.is_gpu_available(cuda_only=True): random_seed.set_random_seed(0) @@ -778,7 +960,7 @@ class LayoutOptimizerTest(test.TestCase): select = gen_math_ops._select(condition, conv, add) output = array_ops.identity(select) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -808,7 +990,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(pad) paddings_val = [[1, 2], [3, 4], [5, 6], [7, 8]] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={paddings: paddings_val}) with session.Session(config=_get_config()) as sess: @@ -845,7 +1027,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(max_pool) strides_val = [1, 3, 2, 1] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={strides: strides_val}) with session.Session(config=_get_config()) as sess: @@ -882,7 +1064,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(max_pool_grad) strides_val = [1, 3, 2, 1] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={strides: strides_val}) with session.Session(config=_get_config()) as sess: @@ -917,7 +1099,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(s) size_val = [1, 2, 3, 4] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={size: size_val}) with session.Session(config=_get_config()) as sess: @@ -953,7 +1135,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(s) end_val = [1, 2, 3, 4] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={end: end_val}) with session.Session(config=_get_config()) as sess: @@ -991,7 +1173,7 @@ class LayoutOptimizerTest(test.TestCase): s = conv[:, :, 1:-1, :] output = array_ops.identity(s) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -1026,7 +1208,7 @@ class LayoutOptimizerTest(test.TestCase): s = conv[:, :, :, 1:-1] output = array_ops.identity(s) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -1065,7 +1247,7 @@ class LayoutOptimizerTest(test.TestCase): [1, 2, 3, 1], s) output = array_ops.identity(s_grad) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={end: end_val}) with session.Session(config=_get_config()) as sess: @@ -1101,7 +1283,7 @@ class LayoutOptimizerTest(test.TestCase): output = math_ops.add(shapen[0], shapen[1]) x_val = [1.7] * 784 - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={x: x_val}) with session.Session(config=_get_config()) as sess: @@ -1124,11 +1306,42 @@ class LayoutOptimizerTest(test.TestCase): self._assert_vec_nchw_to_nhwc('ShapeN-0-0', nodes) self.assertAllEqual(output_val_ref, output_val) + def testShapeNFollowedByNotConvertibleNodeReshape(self): + if test.is_gpu_available(cuda_only=True): + x = array_ops.placeholder(dtype='float32') + conv = _two_layer_model(x) + conv_reshape = array_ops.reshape(conv, [1, 1, 1, -1]) + shapen = array_ops.shape_n([conv, conv_reshape]) + shape = array_ops.identity(shapen[1]) + ones = array_ops.ones(shape) + output = math_ops.add_n([conv_reshape, ones]) + + x_val = [1.7] * 784 + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output, feed_dict={x: x_val}) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run( + output, run_metadata=metadata, feed_dict={x: x_val}) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllEqual(output_val_ref, output_val) + def testLoop(self): if test.is_gpu_available(cuda_only=True): output = _loop() - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -1155,7 +1368,7 @@ class LayoutOptimizerTest(test.TestCase): if test.is_gpu_available(cuda_only=True): output = _loop_with_branch() - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -1169,7 +1382,7 @@ class LayoutOptimizerTest(test.TestCase): num_transposes += 1 nodes.append(node.name) - expected_num_transposes = 2 + expected_num_transposes = 3 self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes) self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes) @@ -1179,7 +1392,7 @@ class LayoutOptimizerTest(test.TestCase): if test.is_gpu_available(cuda_only=True): output = _loop_with_vec_and_4d() - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -1203,7 +1416,7 @@ class LayoutOptimizerTest(test.TestCase): if test.is_gpu_available(cuda_only=True): output = _model_with_second_port() - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i index f0dd4483a635ddf39e7f51ad0008390c1feb2e13..1b657983a4690dd0ddb7f569ce514b08cb10400a 100644 --- a/tensorflow/python/grappler/tf_optimizer.i +++ b/tensorflow/python/grappler/tf_optimizer.i @@ -103,6 +103,11 @@ PyObject* TF_OptimizeGraph( std::unique_ptr grappler_item = tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config); + if (!grappler_item) { + TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Failed to import metagraph, check error log for more info."); + return nullptr; + } + tensorflow::DeviceBase* cpu_device = nullptr; tensorflow::GraphDef out_graph; tensorflow::grappler::MetaOptimizer optimizer(cpu_device, rewriter_config); diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 1f20b3ae0eb1ddf981f12f9a12c4e8153711c7f9..1956478f39a6b5d4dea720436d9b87f66ca20426 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -3,6 +3,8 @@ licenses(["notice"]) # Apache 2.0 +exports_files(["LICENSE"]) + package(default_visibility = ["//visibility:public"]) load("//tensorflow:tensorflow.bzl", "py_test") @@ -14,10 +16,12 @@ py_library( "_impl/keras/__init__.py", "_impl/keras/activations.py", "_impl/keras/applications/__init__.py", + "_impl/keras/applications/densenet.py", "_impl/keras/applications/imagenet_utils.py", "_impl/keras/applications/inception_resnet_v2.py", "_impl/keras/applications/inception_v3.py", "_impl/keras/applications/mobilenet.py", + "_impl/keras/applications/nasnet.py", "_impl/keras/applications/resnet50.py", "_impl/keras/applications/vgg16.py", "_impl/keras/applications/vgg19.py", @@ -37,6 +41,7 @@ py_library( "_impl/keras/engine/__init__.py", "_impl/keras/engine/topology.py", "_impl/keras/engine/training.py", + "_impl/keras/engine/training_eager.py", "_impl/keras/estimator.py", "_impl/keras/initializers.py", "_impl/keras/layers/__init__.py", @@ -76,9 +81,11 @@ py_library( "_impl/keras/wrappers/scikit_learn.py", "activations/__init__.py", "applications/__init__.py", + "applications/densenet/__init__.py", "applications/inception_resnet_v2/__init__.py", "applications/inception_v3/__init__.py", "applications/mobilenet/__init__.py", + "applications/nasnet/__init__.py", "applications/resnet50/__init__.py", "applications/vgg16/__init__.py", "applications/vgg19/__init__.py", @@ -256,6 +263,18 @@ py_test( ], ) +py_test( + name = "densenet_test", + size = "large", + srcs = ["_impl/keras/applications/densenet_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "inception_resnet_v2_test", size = "medium", @@ -292,6 +311,18 @@ py_test( ], ) +py_test( + name = "nasnet_test", + size = "large", + srcs = ["_impl/keras/applications/nasnet_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "resnet50_test", size = "small", @@ -364,7 +395,7 @@ py_test( py_test( name = "convolutional_test", - size = "medium", + size = "large", srcs = ["_impl/keras/layers/convolutional_test.py"], srcs_version = "PY2AND3", tags = [ @@ -453,6 +484,7 @@ py_test( size = "small", srcs = ["_impl/keras/layers/normalization_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":keras", "//tensorflow/python:client_testlib", @@ -504,7 +536,7 @@ py_test( py_test( name = "recurrent_test", - size = "small", + size = "medium", srcs = ["_impl/keras/layers/recurrent_test.py"], srcs_version = "PY2AND3", deps = [ @@ -527,7 +559,7 @@ py_test( py_test( name = "wrappers_test", - size = "small", + size = "medium", srcs = ["_impl/keras/layers/wrappers_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], @@ -691,6 +723,32 @@ py_test( ], ) +py_test( + name = "training_eager_test", + size = "medium", + srcs = ["_impl/keras/engine/training_eager_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +py_test( + name = "model_subclassing_test", + size = "medium", + srcs = ["_impl/keras/model_subclassing_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "topology_test", size = "small", diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py index a70250d796b4dd8d08ac65ebdac84b307b917b13..b63907b2e60acfc80ee411b9193b2829f0224c3e 100644 --- a/tensorflow/python/keras/_impl/keras/__init__.py +++ b/tensorflow/python/keras/_impl/keras/__init__.py @@ -40,4 +40,4 @@ from tensorflow.python.keras._impl.keras.layers import Input from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.models import Sequential -__version__ = '2.1.2-tf' +__version__ = '2.1.4-tf' diff --git a/tensorflow/python/keras/_impl/keras/activations.py b/tensorflow/python/keras/_impl/keras/activations.py index f017d2ae85548211070ececf48e977dd7d2f6a25..236e17653e1b762e1e6962f453b714d1bf7bcbf7 100644 --- a/tensorflow/python/keras/_impl/keras/activations.py +++ b/tensorflow/python/keras/_impl/keras/activations.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras built-in activation functions. +"""Built-in activation functions. """ from __future__ import absolute_import from __future__ import division @@ -24,8 +24,10 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.layers.base import Layer from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.activations.softmax') def softmax(x, axis=-1): """Softmax activation function. @@ -50,10 +52,12 @@ def softmax(x, axis=-1): raise ValueError('Cannot apply softmax to a tensor that is 1D') +@tf_export('keras.activations.elu') def elu(x, alpha=1.0): return K.elu(x, alpha) +@tf_export('keras.activations.selu') def selu(x): """Scaled Exponential Linear Unit. (Klambauer et al., 2017). @@ -61,48 +65,59 @@ def selu(x): x: A tensor or variable to compute the activation function for. Returns: - Tensor with the same shape and dtype as `x`. + Tensor with the same shape and dtype as `x`. + + # Note + - To be used together with the initialization "lecun_normal". + - To be used together with the dropout variant "AlphaDropout". - References: - - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) """ alpha = 1.6732632423543772848170429916717 scale = 1.0507009873554804934193349852946 return scale * K.elu(x, alpha) +@tf_export('keras.activations.softplus') def softplus(x): return K.softplus(x) +@tf_export('keras.activations.softsign') def softsign(x): return K.softsign(x) +@tf_export('keras.activations.relu') def relu(x, alpha=0., max_value=None): return K.relu(x, alpha=alpha, max_value=max_value) +@tf_export('keras.activations.tanh') def tanh(x): return K.tanh(x) +@tf_export('keras.activations.sigmoid') def sigmoid(x): return K.sigmoid(x) +@tf_export('keras.activations.hard_sigmoid') def hard_sigmoid(x): return K.hard_sigmoid(x) +@tf_export('keras.activations.linear') def linear(x): return x +@tf_export('keras.activations.serialize') def serialize(activation): return activation.__name__ +@tf_export('keras.activations.deserialize') def deserialize(name, custom_objects=None): return deserialize_keras_object( name, @@ -111,6 +126,7 @@ def deserialize(name, custom_objects=None): printable_module_name='activation function') +@tf_export('keras.activations.get') def get(identifier): if identifier is None: return linear diff --git a/tensorflow/python/keras/_impl/keras/applications/__init__.py b/tensorflow/python/keras/_impl/keras/applications/__init__.py index c11c52b71e9bff1cfd595a9dbc0e86dcaa8506c8..206a769b377483c65a78b76fe44055eb50bdc7c4 100644 --- a/tensorflow/python/keras/_impl/keras/applications/__init__.py +++ b/tensorflow/python/keras/_impl/keras/applications/__init__.py @@ -18,9 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121 +from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169 +from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201 from tensorflow.python.keras._impl.keras.applications.inception_resnet_v2 import InceptionResNetV2 from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet +from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetLarge +from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetMobile from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50 from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16 from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19 diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet.py b/tensorflow/python/keras/_impl/keras/applications/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..6521f8410435fd13393b9991d3ee9a6342a912d0 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/applications/densenet.py @@ -0,0 +1,354 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=invalid-name +# pylint: disable=unused-import +"""DenseNet models for Keras. + +# Reference paper + +- [Densely Connected Convolutional Networks] + (https://arxiv.org/abs/1608.06993) (CVPR 2017 Best Paper Award) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras.applications import imagenet_utils +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.layers import Activation +from tensorflow.python.keras._impl.keras.layers import AveragePooling2D +from tensorflow.python.keras._impl.keras.layers import BatchNormalization +from tensorflow.python.keras._impl.keras.layers import Concatenate +from tensorflow.python.keras._impl.keras.layers import Conv2D +from tensorflow.python.keras._impl.keras.layers import Dense +from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras._impl.keras.layers import Input +from tensorflow.python.keras._impl.keras.layers import MaxPooling2D +from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D +from tensorflow.python.keras._impl.keras.models import Model +from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.util.tf_export import tf_export + + +DENSENET121_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet121_weights_tf_dim_ordering_tf_kernels.h5' +DENSENET121_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5' +DENSENET169_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet169_weights_tf_dim_ordering_tf_kernels.h5' +DENSENET169_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5' +DENSENET201_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet201_weights_tf_dim_ordering_tf_kernels.h5' +DENSENET201_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5' + + +def dense_block(x, blocks, name): + """A dense block. + + Arguments: + x: input tensor. + blocks: integer, the number of building blocks. + name: string, block label. + + Returns: + output tensor for the block. + """ + for i in range(blocks): + x = conv_block(x, 32, name=name + '_block' + str(i + 1)) + return x + + +def transition_block(x, reduction, name): + """A transition block. + + Arguments: + x: input tensor. + reduction: float, compression rate at transition layers. + name: string, block label. + + Returns: + output tensor for the block. + """ + bn_axis = 3 if K.image_data_format() == 'channels_last' else 1 + x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + '_bn')(x) + x = Activation('relu', name=name + '_relu')(x) + x = Conv2D( + int(K.int_shape(x)[bn_axis] * reduction), + 1, + use_bias=False, + name=name + '_conv')( + x) + x = AveragePooling2D(2, strides=2, name=name + '_pool')(x) + return x + + +def conv_block(x, growth_rate, name): + """A building block for a dense block. + + Arguments: + x: input tensor. + growth_rate: float, growth rate at dense layers. + name: string, block label. + + Returns: + output tensor for the block. + """ + bn_axis = 3 if K.image_data_format() == 'channels_last' else 1 + x1 = BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=name + '_0_bn')( + x) + x1 = Activation('relu', name=name + '_0_relu')(x1) + x1 = Conv2D(4 * growth_rate, 1, use_bias=False, name=name + '_1_conv')(x1) + x1 = BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')( + x1) + x1 = Activation('relu', name=name + '_1_relu')(x1) + x1 = Conv2D( + growth_rate, 3, padding='same', use_bias=False, name=name + '_2_conv')( + x1) + x = Concatenate(axis=bn_axis, name=name + '_concat')([x, x1]) + return x + + +def DenseNet(blocks, + include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000): + """Instantiates the DenseNet architecture. + + Optionally loads weights pre-trained + on ImageNet. Note that when using TensorFlow, + for best performance you should set + `image_data_format='channels_last'` in your Keras config + at ~/.keras/keras.json. + + The model and the weights are compatible with + TensorFlow, Theano, and CNTK. The data format + convention used by the model is the one + specified in your Keras config file. + + Arguments: + blocks: numbers of building blocks for the four dense layers. + include_top: whether to include the fully-connected + layer at the top of the network. + weights: one of `None` (random initialization), + 'imagenet' (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified + if `include_top` is False (otherwise the input shape + has to be `(224, 224, 3)` (with `channels_last` data format) + or `(3, 224, 224)` (with `channels_first` data format). + It should have exactly 3 inputs channels. + pooling: optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is True, and + if no `weights` argument is specified. + + Returns: + A Keras model instance. + + Raises: + ValueError: in case of invalid argument for `weights`, + or invalid input shape. + """ + if not (weights in {'imagenet', None} or os.path.exists(weights)): + raise ValueError('The `weights` argument should be either ' + '`None` (random initialization), `imagenet` ' + '(pre-training on ImageNet), ' + 'or the path to the weights file to be loaded.') + + if weights == 'imagenet' and include_top and classes != 1000: + raise ValueError('If using `weights` as imagenet with `include_top`' + ' as true, `classes` should be 1000') + + # Determine proper input shape + input_shape = _obtain_input_shape( + input_shape, + default_size=224, + min_size=221, + data_format=K.image_data_format(), + require_flatten=include_top, + weights=weights) + + if input_tensor is None: + img_input = Input(shape=input_shape) + else: + if not K.is_keras_tensor(input_tensor): + img_input = Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + bn_axis = 3 if K.image_data_format() == 'channels_last' else 1 + + x = ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input) + x = Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x) + x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x) + x = Activation('relu', name='conv1/relu')(x) + x = ZeroPadding2D(padding=((1, 1), (1, 1)))(x) + x = MaxPooling2D(3, strides=2, name='pool1')(x) + + x = dense_block(x, blocks[0], name='conv2') + x = transition_block(x, 0.5, name='pool2') + x = dense_block(x, blocks[1], name='conv3') + x = transition_block(x, 0.5, name='pool3') + x = dense_block(x, blocks[2], name='conv4') + x = transition_block(x, 0.5, name='pool4') + x = dense_block(x, blocks[3], name='conv5') + + x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x) + + if include_top: + x = GlobalAveragePooling2D(name='avg_pool')(x) + x = Dense(classes, activation='softmax', name='fc1000')(x) + else: + if pooling == 'avg': + x = GlobalAveragePooling2D(name='avg_pool')(x) + elif pooling == 'max': + x = GlobalMaxPooling2D(name='max_pool')(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + if blocks == [6, 12, 24, 16]: + model = Model(inputs, x, name='densenet121') + elif blocks == [6, 12, 32, 32]: + model = Model(inputs, x, name='densenet169') + elif blocks == [6, 12, 48, 32]: + model = Model(inputs, x, name='densenet201') + else: + model = Model(inputs, x, name='densenet') + + # Load weights. + if weights == 'imagenet': + if include_top: + if blocks == [6, 12, 24, 16]: + weights_path = get_file( + 'densenet121_weights_tf_dim_ordering_tf_kernels.h5', + DENSENET121_WEIGHT_PATH, + cache_subdir='models', + file_hash='0962ca643bae20f9b6771cb844dca3b0') + elif blocks == [6, 12, 32, 32]: + weights_path = get_file( + 'densenet169_weights_tf_dim_ordering_tf_kernels.h5', + DENSENET169_WEIGHT_PATH, + cache_subdir='models', + file_hash='bcf9965cf5064a5f9eb6d7dc69386f43') + elif blocks == [6, 12, 48, 32]: + weights_path = get_file( + 'densenet201_weights_tf_dim_ordering_tf_kernels.h5', + DENSENET201_WEIGHT_PATH, + cache_subdir='models', + file_hash='7bb75edd58cb43163be7e0005fbe95ef') + else: + if blocks == [6, 12, 24, 16]: + weights_path = get_file( + 'densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5', + DENSENET121_WEIGHT_PATH_NO_TOP, + cache_subdir='models', + file_hash='4912a53fbd2a69346e7f2c0b5ec8c6d3') + elif blocks == [6, 12, 32, 32]: + weights_path = get_file( + 'densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5', + DENSENET169_WEIGHT_PATH_NO_TOP, + cache_subdir='models', + file_hash='50662582284e4cf834ce40ab4dfa58c6') + elif blocks == [6, 12, 48, 32]: + weights_path = get_file( + 'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5', + DENSENET201_WEIGHT_PATH_NO_TOP, + cache_subdir='models', + file_hash='1c2de60ee40562448dbac34a0737e798') + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +@tf_export('keras.applications.DenseNet121', + 'keras.applications.densenet.DenseNet121') +def DenseNet121(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000): + return DenseNet([6, 12, 24, 16], include_top, weights, input_tensor, + input_shape, pooling, classes) + + +@tf_export('keras.applications.DenseNet169', + 'keras.applications.densenet.DenseNet169') +def DenseNet169(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000): + return DenseNet([6, 12, 32, 32], include_top, weights, input_tensor, + input_shape, pooling, classes) + + +@tf_export('keras.applications.DenseNet201', + 'keras.applications.densenet.DenseNet201') +def DenseNet201(include_top=True, + weights='imagenet', + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000): + return DenseNet([6, 12, 48, 32], include_top, weights, input_tensor, + input_shape, pooling, classes) + + +@tf_export('keras.applications.densenet.preprocess_input') +def preprocess_input(x, data_format=None): + """Preprocesses a numpy array encoding a batch of images. + + Arguments: + x: a 3D or 4D numpy array consists of RGB values within [0, 255]. + data_format: data format of the image tensor. + + Returns: + Preprocessed array. + """ + return imagenet_utils.preprocess_input(x, data_format, mode='torch') + + +setattr(DenseNet121, '__doc__', DenseNet.__doc__) +setattr(DenseNet169, '__doc__', DenseNet.__doc__) +setattr(DenseNet201, '__doc__', DenseNet.__doc__) diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet_test.py b/tensorflow/python/keras/_impl/keras/applications/densenet_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3b92287a1e77a944c069a6c234e11e4a79ad7d32 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/applications/densenet_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for DenseNet application.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test + + +class DenseNet121Test(test.TestCase): + + def test_with_top(self): + model = keras.applications.DenseNet121(weights=None) + self.assertEqual(model.output_shape, (None, 1000)) + + def test_no_top(self): + model = keras.applications.DenseNet121(weights=None, include_top=False) + self.assertEqual(model.output_shape, (None, None, None, 1024)) + + def test_with_pooling(self): + model = keras.applications.DenseNet121(weights=None, + include_top=False, + pooling='avg') + self.assertEqual(model.output_shape, (None, 1024)) + + def test_weight_loading(self): + with self.assertRaises(ValueError): + keras.applications.DenseNet121(weights='unknown', + include_top=False) + with self.assertRaises(ValueError): + keras.applications.DenseNet121(weights='imagenet', + classes=2000) + + +class DenseNet169Test(test.TestCase): + + def test_with_top(self): + model = keras.applications.DenseNet169(weights=None) + self.assertEqual(model.output_shape, (None, 1000)) + + def test_no_top(self): + model = keras.applications.DenseNet169(weights=None, include_top=False) + self.assertEqual(model.output_shape, (None, None, None, 1664)) + + def test_with_pooling(self): + model = keras.applications.DenseNet169(weights=None, + include_top=False, + pooling='max') + self.assertEqual(model.output_shape, (None, 1664)) + + def test_weight_loading(self): + with self.assertRaises(ValueError): + keras.applications.DenseNet169(weights='unknown', + include_top=False) + with self.assertRaises(ValueError): + keras.applications.DenseNet169(weights='imagenet', + classes=2000) + + +class DenseNet201(test.TestCase): + + def test_with_top(self): + model = keras.applications.DenseNet201(weights=None) + self.assertEqual(model.output_shape, (None, 1000)) + + def test_no_top(self): + model = keras.applications.DenseNet201(weights=None, include_top=False) + self.assertEqual(model.output_shape, (None, None, None, 1920)) + + def test_with_pooling(self): + model = keras.applications.DenseNet201(weights=None, + include_top=False, + pooling='avg') + self.assertEqual(model.output_shape, (None, 1920)) + + def test_weight_loading(self): + with self.assertRaises(ValueError): + keras.applications.DenseNet201(weights='unknown', + include_top=False) + with self.assertRaises(ValueError): + keras.applications.DenseNet201(weights='imagenet', + classes=2000) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py index 63ee83cb51e8366f391f192a9408566076cad468..c26a28ed4087e30968585ec8ac0b64b51513bcae 100644 --- a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py +++ b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities used by models pre-trained on ImageNet. +"""Utilities for ImageNet data preprocessing & prediction decoding. """ from __future__ import absolute_import from __future__ import division @@ -25,6 +25,7 @@ import numpy as np from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export CLASS_INDEX = None @@ -35,63 +36,92 @@ _IMAGENET_MEAN = None def _preprocess_numpy_input(x, data_format, mode): - """Preprocesses a image tensor as a Numpy array. + """Preprocesses a Numpy array encoding a batch of images. Arguments: - x: input Numpy, 3D or 4D. - data_format: data format of the image tensor. - mode: One of "caffe", "tf". + x: Input array, 3D or 4D. + data_format: Data format of the image array. + mode: One of "caffe", "tf" or "torch". - caffe: will convert the images from RGB to BGR, then will zero-center each color channel with respect to the ImageNet dataset, without scaling. - tf: will scale pixels between -1 and 1, sample-wise. + - torch: will scale pixels between 0 and 1 and then + will normalize each channel with respect to the + ImageNet dataset. Returns: - Preprocessed array. + Preprocessed Numpy array. """ if mode == 'tf': x /= 127.5 x -= 1. return x + if mode == 'torch': + x /= 255. + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + else: + if data_format == 'channels_first': + # 'RGB'->'BGR' + if x.ndim == 3: + x = x[::-1, ...] + else: + x = x[:, ::-1, ...] + else: + # 'RGB'->'BGR' + x = x[..., ::-1] + mean = [103.939, 116.779, 123.68] + std = None + + # Zero-center by mean pixel if data_format == 'channels_first': if x.ndim == 3: - # 'RGB'->'BGR' - x = x[::-1, ...] - # Zero-center by mean pixel - x[0, :, :] -= 103.939 - x[1, :, :] -= 116.779 - x[2, :, :] -= 123.68 + x[0, :, :] -= mean[0] + x[1, :, :] -= mean[1] + x[2, :, :] -= mean[2] + if std is not None: + x[0, :, :] /= std[0] + x[1, :, :] /= std[1] + x[2, :, :] /= std[2] else: - x = x[:, ::-1, ...] - x[:, 0, :, :] -= 103.939 - x[:, 1, :, :] -= 116.779 - x[:, 2, :, :] -= 123.68 + x[:, 0, :, :] -= mean[0] + x[:, 1, :, :] -= mean[1] + x[:, 2, :, :] -= mean[2] + if std is not None: + x[:, 0, :, :] /= std[0] + x[:, 1, :, :] /= std[1] + x[:, 2, :, :] /= std[2] else: - # 'RGB'->'BGR' - x = x[..., ::-1] - # Zero-center by mean pixel - x[..., 0] -= 103.939 - x[..., 1] -= 116.779 - x[..., 2] -= 123.68 + x[..., 0] -= mean[0] + x[..., 1] -= mean[1] + x[..., 2] -= mean[2] + if std is not None: + x[..., 0] /= std[0] + x[..., 1] /= std[1] + x[..., 2] /= std[2] return x def _preprocess_symbolic_input(x, data_format, mode): - """Preprocesses a symbolic image tensor. + """Preprocesses a tensor encoding a batch of images. Arguments: - x: symoblic tensor, 3D or 4D. - data_format: data format of the image tensor. - mode: One of "caffe", "tf". + x: Input tensor, 3D or 4D. + data_format: Data format of the image tensor. + mode: One of "caffe", "tf" or "torch". - caffe: will convert the images from RGB to BGR, then will zero-center each color channel with respect to the ImageNet dataset, without scaling. - tf: will scale pixels between -1 and 1, sample-wise. + - torch: will scale pixels between 0 and 1 and then + will normalize each channel with respect to the + ImageNet dataset. Returns: Preprocessed tensor. @@ -103,32 +133,45 @@ def _preprocess_symbolic_input(x, data_format, mode): x -= 1. return x - if data_format == 'channels_first': - # 'RGB'->'BGR' - if K.ndim(x) == 3: - x = x[::-1, ...] - else: - x = x[:, ::-1, ...] + if mode == 'torch': + x /= 255. + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] else: - # 'RGB'->'BGR' - x = x[..., ::-1] + if data_format == 'channels_first': + # 'RGB'->'BGR' + if K.ndim(x) == 3: + x = x[::-1, ...] + else: + x = x[:, ::-1, ...] + else: + # 'RGB'->'BGR' + x = x[..., ::-1] + mean = [103.939, 116.779, 123.68] + std = None if _IMAGENET_MEAN is None: - _IMAGENET_MEAN = K.constant(-np.array([103.939, 116.779, 123.68])) + _IMAGENET_MEAN = K.constant(-np.array(mean)) + # Zero-center by mean pixel if K.dtype(x) != K.dtype(_IMAGENET_MEAN): x = K.bias_add(x, K.cast(_IMAGENET_MEAN, K.dtype(x)), data_format) else: x = K.bias_add(x, _IMAGENET_MEAN, data_format) + if std is not None: + x /= std return x +@tf_export('keras.applications.resnet50.preprocess_input', + 'keras.applications.vgg19.preprocess_input', + 'keras.applications.vgg16.preprocess_input') def preprocess_input(x, data_format=None, mode='caffe'): - """Preprocesses a tensor encoding a batch of images. + """Preprocesses a tensor or Numpy array encoding a batch of images. Arguments: - x: input Numpy or symoblic tensor, 3D or 4D. - data_format: data format of the image tensor. + x: Input Numpy or symbolic tensor, 3D or 4D. + data_format: Data format of the image tensor/array. mode: One of "caffe", "tf". - caffe: will convert the images from RGB to BGR, then will zero-center each color channel with @@ -138,10 +181,10 @@ def preprocess_input(x, data_format=None, mode='caffe'): sample-wise. Returns: - Preprocessed tensor. + Preprocessed tensor or Numpy array. Raises: - ValueError: in case of incorrect data_format. + ValueError: In case of unknown `data_format` argument. """ if data_format is None: data_format = K.image_data_format() @@ -154,12 +197,21 @@ def preprocess_input(x, data_format=None, mode='caffe'): return _preprocess_symbolic_input(x, data_format=data_format, mode=mode) +@tf_export('keras.applications.nasnet.decode_predictions', + 'keras.applications.resnet50.decode_predictions', + 'keras.applications.vgg19.decode_predictions', + 'keras.applications.vgg16.decode_predictions', + 'keras.applications.inception_resnet_v2.decode_predictions', + 'keras.applications.inception_v3.decode_predictions', + 'keras.applications.densenet.decode_predictions', + 'keras.applications.mobilenet.decode_predictions', + 'keras.applications.xception.decode_predictions') def decode_predictions(preds, top=5): """Decodes the prediction of an ImageNet model. Arguments: preds: Numpy tensor encoding a batch of predictions. - top: integer, how many top-guesses to return. + top: Integer, how many top-guesses to return. Returns: A list of lists of top class prediction tuples @@ -167,7 +219,7 @@ def decode_predictions(preds, top=5): One list of tuples per sample in batch input. Raises: - ValueError: in case of invalid shape of the `pred` array + ValueError: In case of invalid shape of the `pred` array (must be 2D). """ global CLASS_INDEX @@ -177,11 +229,13 @@ def decode_predictions(preds, top=5): '(i.e. a 2D array of shape (samples, 1000)). ' 'Found array with shape: ' + str(preds.shape)) if CLASS_INDEX is None: - fpath = get_file('imagenet_class_index.json', - CLASS_INDEX_PATH, - cache_subdir='models', - file_hash='c2c37ea517e94d9795004a39431a14cb') - CLASS_INDEX = json.load(open(fpath)) + fpath = get_file( + 'imagenet_class_index.json', + CLASS_INDEX_PATH, + cache_subdir='models', + file_hash='c2c37ea517e94d9795004a39431a14cb') + with open(fpath) as f: + CLASS_INDEX = json.load(f) results = [] for pred in preds: top_indices = pred.argsort()[-top:][::-1] @@ -197,17 +251,17 @@ def _obtain_input_shape(input_shape, data_format, require_flatten, weights=None): - """Internal utility to compute/validate an ImageNet model's input shape. + """Internal utility to compute/validate a model's input shape. Arguments: - input_shape: either None (will return the default network input shape), + input_shape: Either None (will return the default network input shape), or a user-provided shape to be validated. - default_size: default input width/height for the model. - min_size: minimum input width/height accepted by the model. - data_format: image data format to use. - require_flatten: whether the model is expected to + default_size: Default input width/height for the model. + min_size: Minimum input width/height accepted by the model. + data_format: Image data format to use. + require_flatten: Whether the model is expected to be linked to a classifier via a Flatten layer. - weights: one of `None` (random initialization) + weights: One of `None` (random initialization) or 'imagenet' (pre-training on ImageNet). If weights='imagenet' input channels must be equal to 3. @@ -215,7 +269,7 @@ def _obtain_input_shape(input_shape, An integer shape tuple (may include None entries). Raises: - ValueError: in case of invalid argument values. + ValueError: In case of invalid argument values. """ if weights != 'imagenet' and input_shape and len(input_shape) == 3: if data_format == 'channels_first': @@ -252,8 +306,8 @@ def _obtain_input_shape(input_shape, '`input_shape=' + str(input_shape) + '`') if ((input_shape[1] is not None and input_shape[1] < min_size) or (input_shape[2] is not None and input_shape[2] < min_size)): - raise ValueError('Input size must be at least ' + str(min_size) + 'x' - + str(min_size) + '; got ' + raise ValueError('Input size must be at least ' + str(min_size) + + 'x' + str(min_size) + '; got ' '`input_shape=' + str(input_shape) + '`') else: if input_shape is not None: @@ -264,8 +318,8 @@ def _obtain_input_shape(input_shape, '`input_shape=' + str(input_shape) + '`') if ((input_shape[0] is not None and input_shape[0] < min_size) or (input_shape[1] is not None and input_shape[1] < min_size)): - raise ValueError('Input size must be at least ' + str(min_size) + 'x' - + str(min_size) + '; got ' + raise ValueError('Input size must be at least ' + str(min_size) + + 'x' + str(min_size) + '; got ' '`input_shape=' + str(input_shape) + '`') else: if require_flatten: diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py index 2e73cefb6ce32c2a770eb9bde5ffb220be2da92c..bf3901fc54419c2b401bf9c4d6311b39a18f1aba 100644 --- a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py +++ b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# pylint: disable=invalid-name +# pylint: disable=unused-import """Inception-ResNet V2 model for Keras. # Reference @@ -28,7 +30,7 @@ import os from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D @@ -43,10 +45,14 @@ from tensorflow.python.keras._impl.keras.layers import Lambda from tensorflow.python.keras._impl.keras.layers import MaxPooling2D from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export + BASE_WEIGHT_URL = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.7/' +@tf_export('keras.applications.inception_resnet_v2.preprocess_input') def preprocess_input(x): """Preprocesses a numpy array encoding a batch of images. @@ -116,7 +122,8 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): scale: scaling factor to scale the residuals (i.e., the output of passing `x` through an inception module) before adding them to the shortcut branch. Let `r` be the output from the residual - branch, the output of this block will be `x + scale * r`. + branch, + the output of this block will be `x + scale * r`. block_type: `'block35'`, `'block17'` or `'block8'`, determines the network structure in the residual branch. block_idx: an `int` used for generating layer names. The Inception-ResNet @@ -128,8 +135,7 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): will have `block_type='block35', block_idx=0`, ane the layer names will have a common prefix `'block35_0'`. - activation: activation function to use at the end of the block - (see [activations](../activations.md)). + activation: activation function to use at the end of the block. When `activation=None`, no activation is applied (i.e., "linear" activation: `a(x) = x`). @@ -178,6 +184,7 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): x = Lambda( lambda inputs, scale: inputs[0] + inputs[1] * scale, + output_shape=K.int_shape(x)[1:], arguments={'scale': scale}, name=block_name)([x, up]) if activation is not None: @@ -185,7 +192,9 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): return x -def InceptionResNetV2(include_top=True, # pylint: disable=invalid-name +@tf_export('keras.applications.InceptionResNetV2', + 'keras.applications.inception_resnet_v2.InceptionResNetV2') +def InceptionResNetV2(include_top=True, weights='imagenet', input_tensor=None, input_shape=None, @@ -211,8 +220,8 @@ def InceptionResNetV2(include_top=True, # pylint: disable=invalid-name include_top: whether to include the fully-connected layer at the top of the network. weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. + 'imagenet' (pre-training on ImageNet), + or the path to the weights file to be loaded. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model. input_shape: optional shape tuple, only to be specified diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py index 4424b9280413bb8e556ab376b0c0acccf4030c73..e268e97bc663773a218f01b958b08f8e43c74ee2 100644 --- a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py +++ b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name +# pylint: disable=unused-import """Inception V3 model for Keras. Note that the input image format for this model is different than for @@ -35,7 +36,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import layers from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D @@ -48,6 +49,8 @@ from tensorflow.python.keras._impl.keras.layers import Input from tensorflow.python.keras._impl.keras.layers import MaxPooling2D from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.5/inception_v3_weights_tf_dim_ordering_tf_kernels.h5' @@ -92,12 +95,15 @@ def conv2d_bn(x, strides=strides, padding=padding, use_bias=False, - name=conv_name)(x) + name=conv_name)( + x) x = BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x) x = Activation('relu', name=name)(x) return x +@tf_export('keras.applications.InceptionV3', + 'keras.applications.inception_v3.InceptionV3') def InceptionV3(include_top=True, weights='imagenet', input_tensor=None, @@ -109,7 +115,7 @@ def InceptionV3(include_top=True, Optionally loads weights pre-trained on ImageNet. Note that when using TensorFlow, for best performance you should set - `image_data_format="channels_last"` in your Keras config + `image_data_format='channels_last'` in your Keras config at ~/.keras/keras.json. The model and the weights are compatible with both TensorFlow and Theano. The data format @@ -121,15 +127,15 @@ def InceptionV3(include_top=True, include_top: whether to include the fully-connected layer at the top of the network. weights: one of `None` (random initialization), - "imagenet" (pre-training on ImageNet), - or the path to the weights file to be loaded. + 'imagenet' (pre-training on ImageNet), + or the path to the weights file to be loaded. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model. input_shape: optional shape tuple, only to be specified if `include_top` is False (otherwise the input shape has to be `(299, 299, 3)` (with `channels_last` data format) or `(3, 299, 299)` (with `channels_first` data format). - It should have exactly 3 input channels, + It should have exactly 3 inputs channels, and width and height should be no smaller than 139. E.g. `(150, 150, 3)` would be one valid value. pooling: Optional pooling mode for feature extraction @@ -176,7 +182,10 @@ def InceptionV3(include_top=True, if input_tensor is None: img_input = Input(shape=input_shape) else: - img_input = Input(tensor=input_tensor, shape=input_shape) + if not K.is_keras_tensor(input_tensor): + img_input = Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor if K.image_data_format() == 'channels_first': channel_axis = 1 @@ -389,9 +398,12 @@ def InceptionV3(include_top=True, model.load_weights(weights_path) elif weights is not None: model.load_weights(weights) + return model +@tf_export('keras.applications.nasnet.preprocess_input', + 'keras.applications.inception_v3.preprocess_input') def preprocess_input(x): """Preprocesses a numpy array encoding a batch of images. diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py index 5f97c138fc038688a009dfa83b48c8f367ee8df2..1bbbedb85e47902b9e6d3dd741e9d52ab9209080 100644 --- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# pylint: disable=invalid-name +# pylint: disable=unused-import """MobileNet v1 models for Keras. MobileNet is a general architecture and can be used for multiple use cases. @@ -56,7 +58,7 @@ the 100 % MobileNet on various input sizes: ------------------------------------------------------------------------ The weights for all 16 models are obtained and translated -from Tensorflow checkpoints found at +from TensorFlow checkpoints found at https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md # Reference @@ -75,9 +77,10 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import BatchNormalization from tensorflow.python.keras._impl.keras.layers import Conv2D @@ -90,6 +93,8 @@ from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.keras._impl.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export + BASE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/' @@ -98,6 +103,7 @@ def relu6(x): return K.relu(x, max_value=6) +@tf_export('keras.applications.mobilenet.preprocess_input') def preprocess_input(x): """Preprocesses a numpy array encoding a batch of images. @@ -130,7 +136,7 @@ class DepthwiseConv2D(Conv2D): all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying any `dilation_rate` value != 1. - padding: one of `"valid"` or `"same"` (case-insensitive). + padding: one of `'valid'` or `'same'` (case-insensitive). depth_multiplier: The number of depthwise convolution output channels for each input channel. The total number of depthwise convolution output @@ -144,29 +150,21 @@ class DepthwiseConv2D(Conv2D): `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. - If you never set it, then it will be "channels_last". - activation: Activation function to use - (see [activations](../activations.md)). + If you never set it, then it will be 'channels_last'. + activation: Activation function to use. If you don't specify anything, no activation is applied - (ie. "linear" activation: `a(x) = x`). + (ie. 'linear' activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. - depthwise_initializer: Initializer for the depthwise kernel matrix - (see [initializers](../initializers.md)). - bias_initializer: Initializer for the bias vector - (see [initializers](../initializers.md)). + depthwise_initializer: Initializer for the depthwise kernel matrix. + bias_initializer: Initializer for the bias vector. depthwise_regularizer: Regularizer function applied to - the depthwise kernel matrix - (see [regularizer](../regularizers.md)). - bias_regularizer: Regularizer function applied to the bias vector - (see [regularizer](../regularizers.md)). + the depthwise kernel matrix. + bias_regularizer: Regularizer function applied to the bias vector. activity_regularizer: Regularizer function applied to - the output of the layer (its "activation"). - (see [regularizer](../regularizers.md)). + the output of the layer (its 'activation').. depthwise_constraint: Constraint function applied to - the depthwise kernel matrix - (see [constraints](../constraints.md)). - bias_constraint: Constraint function applied to the bias vector - (see [constraints](../constraints.md)). + the depthwise kernel matrix. + bias_constraint: Constraint function applied to the bias vector. Input shape: 4D tensor with shape: @@ -216,6 +214,7 @@ class DepthwiseConv2D(Conv2D): self.depthwise_constraint = constraints.get(depthwise_constraint) self.bias_initializer = initializers.get(bias_initializer) + @shape_type_conversion def build(self, input_shape): if len(input_shape) < 4: raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. ' @@ -269,6 +268,7 @@ class DepthwiseConv2D(Conv2D): return outputs + @shape_type_conversion def compute_output_shape(self, input_shape): if self.data_format == 'channels_first': rows = input_shape[2] @@ -305,7 +305,9 @@ class DepthwiseConv2D(Conv2D): return config -def MobileNet(input_shape=None, # pylint: disable=invalid-name +@tf_export('keras.applications.MobileNet', + 'keras.applications.mobilenet.MobileNet') +def MobileNet(input_shape=None, alpha=1.0, depth_multiplier=1, dropout=1e-3, @@ -334,7 +336,7 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name if `include_top` is False (otherwise the input shape has to be `(224, 224, 3)` (with `channels_last` data format) or (3, 224, 224) (with `channels_first` data format). - It should have exactly 3 input channels, + It should have exactly 3 inputs channels, and width and height should be no smaller than 32. E.g. `(200, 200, 3)` would be one valid value. alpha: controls the width of the network. @@ -350,8 +352,8 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name include_top: whether to include the fully-connected layer at the top of the network. weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. + 'imagenet' (pre-training on ImageNet), + or the path to the weights file to be loaded. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model. @@ -380,6 +382,12 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name RuntimeError: If attempting to run this model with a backend that does not support separable convolutions. """ + + if K.backend() != 'tensorflow': + raise RuntimeError('Only TensorFlow backend is currently supported, ' + 'as other backends do not support ' + 'depthwise convolution.') + if not (weights in {'imagenet', None} or os.path.exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' @@ -390,7 +398,7 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name raise ValueError('If using `weights` as ImageNet with `include_top` ' 'as true, `classes` should be 1000') - # Determine proper input shape. + # Determine proper input shape and default size. if input_shape is None: default_size = 224 else: @@ -400,10 +408,12 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name else: rows = input_shape[0] cols = input_shape[1] + if rows == cols and rows in [128, 160, 192, 224]: default_size = rows else: default_size = 224 + input_shape = _obtain_input_shape( input_shape, default_size=default_size, @@ -411,6 +421,7 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name data_format=K.image_data_format(), require_flatten=include_top, weights=weights) + if K.image_data_format() == 'channels_last': row_axis, col_axis = (0, 1) else: @@ -536,8 +547,6 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name if old_data_format: K.set_image_data_format(old_data_format) - elif weights is not None: - model.load_weights(weights) return model @@ -552,7 +561,7 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)): and width and height should be no smaller than 32. E.g. `(224, 224, 3)` would be one valid value. filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). alpha: controls the width of the network. - If `alpha` < 1.0, proportionally decreases the number of filters in each layer. @@ -595,7 +604,8 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)): padding='same', use_bias=False, strides=strides, - name='conv1')(inputs) + name='conv1')( + inputs) x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x) return Activation(relu6, name='conv1_relu')(x) @@ -617,7 +627,7 @@ def _depthwise_conv_block(inputs, (with `channels_last` data format) or (channels, rows, cols) (with `channels_first` data format). pointwise_conv_filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the pointwise convolution). + (i.e. the number of output filters in the pointwise convolution). alpha: controls the width of the network. - If `alpha` < 1.0, proportionally decreases the number of filters in each layer. @@ -662,7 +672,8 @@ def _depthwise_conv_block(inputs, depth_multiplier=depth_multiplier, strides=strides, use_bias=False, - name='conv_dw_%d' % block_id)(inputs) + name='conv_dw_%d' % block_id)( + inputs) x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x) x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x) @@ -671,6 +682,7 @@ def _depthwise_conv_block(inputs, padding='same', use_bias=False, strides=(1, 1), - name='conv_pw_%d' % block_id)(x) + name='conv_pw_%d' % block_id)( + x) x = BatchNormalization(axis=channel_axis, name='conv_pw_%d_bn' % block_id)(x) return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x) diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet.py b/tensorflow/python/keras/_impl/keras/applications/nasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..08dae57f006c64021cbca26404770cd89b1ce176 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/applications/nasnet.py @@ -0,0 +1,788 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=line-too-long +# pylint: disable=invalid-name +# pylint: disable=unused-import +"""NASNet-A models for Keras. + +NASNet refers to Neural Architecture Search Network, a family of models +that were designed automatically by learning the model architectures +directly on the dataset of interest. + +Here we consider NASNet-A, the highest performance model that was found +for the CIFAR-10 dataset, and then extended to ImageNet 2012 dataset, +obtaining state of the art performance on CIFAR-10 and ImageNet 2012. +Only the NASNet-A models, and their respective weights, which are suited +for ImageNet 2012 are provided. + +The below table describes the performance on ImageNet 2012: +-------------------------------------------------------------------------------- + Architecture | Top-1 Acc | Top-5 Acc | Multiply-Adds | Params (M) +-------------------------------------------------------------------------------- +| NASNet-A (4 @ 1056) | 74.0 % | 91.6 % | 564 M | 5.3 | +| NASNet-A (6 @ 4032) | 82.7 % | 96.2 % | 23.8 B | 88.9 | +-------------------------------------------------------------------------------- + +References: + - [Learning Transferable Architectures for Scalable Image Recognition] + (https://arxiv.org/abs/1707.07012) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input +from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.layers import Activation +from tensorflow.python.keras._impl.keras.layers import add +from tensorflow.python.keras._impl.keras.layers import AveragePooling2D +from tensorflow.python.keras._impl.keras.layers import BatchNormalization +from tensorflow.python.keras._impl.keras.layers import concatenate +from tensorflow.python.keras._impl.keras.layers import Conv2D +from tensorflow.python.keras._impl.keras.layers import Cropping2D +from tensorflow.python.keras._impl.keras.layers import Dense +from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D +from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D +from tensorflow.python.keras._impl.keras.layers import Input +from tensorflow.python.keras._impl.keras.layers import MaxPooling2D +from tensorflow.python.keras._impl.keras.layers import SeparableConv2D +from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D +from tensorflow.python.keras._impl.keras.models import Model +from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export + + +NASNET_MOBILE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-mobile.h5' +NASNET_MOBILE_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-mobile-no-top.h5' +NASNET_LARGE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-large.h5' +NASNET_LARGE_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-large-no-top.h5' + + +def NASNet(input_shape=None, + penultimate_filters=4032, + num_blocks=6, + stem_block_filters=96, + skip_reduction=True, + filter_multiplier=2, + include_top=True, + weights=None, + input_tensor=None, + pooling=None, + classes=1000, + default_size=None): + """Instantiates a NASNet model. + + Note that only TensorFlow is supported for now, + therefore it only works with the data format + `image_data_format='channels_last'` in your Keras config + at `~/.keras/keras.json`. + + Arguments: + input_shape: Optional shape tuple, only to be specified + if `include_top` is False (otherwise the input shape + has to be `(331, 331, 3)` for NASNetLarge or + `(224, 224, 3)` for NASNetMobile + It should have exactly 3 inputs channels, + and width and height should be no smaller than 32. + E.g. `(224, 224, 3)` would be one valid value. + penultimate_filters: Number of filters in the penultimate layer. + NASNet models use the notation `NASNet (N @ P)`, where: + - N is the number of blocks + - P is the number of penultimate filters + num_blocks: Number of repeated blocks of the NASNet model. + NASNet models use the notation `NASNet (N @ P)`, where: + - N is the number of blocks + - P is the number of penultimate filters + stem_block_filters: Number of filters in the initial stem block + skip_reduction: Whether to skip the reduction step at the tail + end of the network. Set to `False` for CIFAR models. + filter_multiplier: Controls the width of the network. + - If `filter_multiplier` < 1.0, proportionally decreases the number + of filters in each layer. + - If `filter_multiplier` > 1.0, proportionally increases the number + of filters in each layer. + - If `filter_multiplier` = 1, default number of filters from the + paper are used at each layer. + include_top: Whether to include the fully-connected + layer at the top of the network. + weights: `None` (random initialization) or + `imagenet` (ImageNet weights) + input_tensor: Optional Keras tensor (i.e. output of + `layers.Input()`) + to use as image input for the model. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model + will be the 4D tensor output of the + last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a + 2D tensor. + - `max` means that global max pooling will + be applied. + classes: Optional number of classes to classify images + into, only to be specified if `include_top` is True, and + if no `weights` argument is specified. + default_size: Specifies the default image size of the model + + Returns: + A Keras model instance. + + Raises: + ValueError: In case of invalid argument for `weights`, + invalid input shape or invalid `penultimate_filters` value. + RuntimeError: If attempting to run this model with a + backend that does not support separable convolutions. + """ + if K.backend() != 'tensorflow': + raise RuntimeError('Only Tensorflow backend is currently supported, ' + 'as other backends do not support ' + 'separable convolution.') + + if not (weights in {'imagenet', None} or os.path.exists(weights)): + raise ValueError('The `weights` argument should be either ' + '`None` (random initialization), `imagenet` ' + '(pre-training on ImageNet), ' + 'or the path to the weights file to be loaded.') + + if weights == 'imagenet' and include_top and classes != 1000: + raise ValueError('If using `weights` as ImageNet with `include_top` ' + 'as true, `classes` should be 1000') + + if default_size is None: + default_size = 331 + + # Determine proper input shape and default size. + input_shape = _obtain_input_shape( + input_shape, + default_size=default_size, + min_size=32, + data_format=K.image_data_format(), + require_flatten=include_top or weights, + weights=weights) + + if K.image_data_format() != 'channels_last': + logging.warning('The NASNet family of models is only available ' + 'for the input data format "channels_last" ' + '(width, height, channels). ' + 'However your settings specify the default ' + 'data format "channels_first" (channels, width, height).' + ' You should set `image_data_format="channels_last"` ' + 'in your Keras config located at ~/.keras/keras.json. ' + 'The model being returned right now will expect inputs ' + 'to follow the "channels_last" data format.') + K.set_image_data_format('channels_last') + old_data_format = 'channels_first' + else: + old_data_format = None + + if input_tensor is None: + img_input = Input(shape=input_shape) + else: + if not K.is_keras_tensor(input_tensor): + img_input = Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + if penultimate_filters % 24 != 0: + raise ValueError( + 'For NASNet-A models, the value of `penultimate_filters` ' + 'needs to be divisible by 24. Current value: %d' % penultimate_filters) + + channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 + filters = penultimate_filters // 24 + + if not skip_reduction: + x = Conv2D( + stem_block_filters, (3, 3), + strides=(2, 2), + padding='valid', + use_bias=False, + name='stem_conv1', + kernel_initializer='he_normal')( + img_input) + else: + x = Conv2D( + stem_block_filters, (3, 3), + strides=(1, 1), + padding='same', + use_bias=False, + name='stem_conv1', + kernel_initializer='he_normal')( + img_input) + + x = BatchNormalization( + axis=channel_dim, momentum=0.9997, epsilon=1e-3, name='stem_bn1')( + x) + + p = None + if not skip_reduction: # imagenet / mobile mode + x, p = _reduction_a_cell( + x, p, filters // (filter_multiplier**2), block_id='stem_1') + x, p = _reduction_a_cell( + x, p, filters // filter_multiplier, block_id='stem_2') + + for i in range(num_blocks): + x, p = _normal_a_cell(x, p, filters, block_id='%d' % (i)) + + x, p0 = _reduction_a_cell( + x, p, filters * filter_multiplier, block_id='reduce_%d' % (num_blocks)) + + p = p0 if not skip_reduction else p + + for i in range(num_blocks): + x, p = _normal_a_cell( + x, p, filters * filter_multiplier, block_id='%d' % (num_blocks + i + 1)) + + x, p0 = _reduction_a_cell( + x, + p, + filters * filter_multiplier**2, + block_id='reduce_%d' % (2 * num_blocks)) + + p = p0 if not skip_reduction else p + + for i in range(num_blocks): + x, p = _normal_a_cell( + x, + p, + filters * filter_multiplier**2, + block_id='%d' % (2 * num_blocks + i + 1)) + + x = Activation('relu')(x) + + if include_top: + x = GlobalAveragePooling2D()(x) + x = Dense(classes, activation='softmax', name='predictions')(x) + else: + if pooling == 'avg': + x = GlobalAveragePooling2D()(x) + elif pooling == 'max': + x = GlobalMaxPooling2D()(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = get_source_inputs(input_tensor) + else: + inputs = img_input + + model = Model(inputs, x, name='NASNet') + + # load weights + if weights == 'imagenet': + if default_size == 224: # mobile version + if include_top: + weight_path = NASNET_MOBILE_WEIGHT_PATH + model_name = 'nasnet_mobile.h5' + else: + weight_path = NASNET_MOBILE_WEIGHT_PATH_NO_TOP + model_name = 'nasnet_mobile_no_top.h5' + + weights_file = get_file(model_name, weight_path, cache_subdir='models') + model.load_weights(weights_file) + + elif default_size == 331: # large version + if include_top: + weight_path = NASNET_LARGE_WEIGHT_PATH + model_name = 'nasnet_large.h5' + else: + weight_path = NASNET_LARGE_WEIGHT_PATH_NO_TOP + model_name = 'nasnet_large_no_top.h5' + + weights_file = get_file(model_name, weight_path, cache_subdir='models') + model.load_weights(weights_file) + else: + raise ValueError('ImageNet weights can only be loaded with NASNetLarge' + ' or NASNetMobile') + elif weights is not None: + model.load_weights(weights) + + if old_data_format: + K.set_image_data_format(old_data_format) + + return model + + +@tf_export('keras.applications.NASNetLarge', + 'keras.applications.nasnet.NASNetLarge') +def NASNetLarge(input_shape=None, + include_top=True, + weights='imagenet', + input_tensor=None, + pooling=None, + classes=1000): + """Instantiates a NASNet model in ImageNet mode. + + Note that only TensorFlow is supported for now, + therefore it only works with the data format + `image_data_format='channels_last'` in your Keras config + at `~/.keras/keras.json`. + + Arguments: + input_shape: Optional shape tuple, only to be specified + if `include_top` is False (otherwise the input shape + has to be `(331, 331, 3)` for NASNetLarge. + It should have exactly 3 inputs channels, + and width and height should be no smaller than 32. + E.g. `(224, 224, 3)` would be one valid value. + include_top: Whether to include the fully-connected + layer at the top of the network. + weights: `None` (random initialization) or + `imagenet` (ImageNet weights) + input_tensor: Optional Keras tensor (i.e. output of + `layers.Input()`) + to use as image input for the model. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model + will be the 4D tensor output of the + last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a + 2D tensor. + - `max` means that global max pooling will + be applied. + classes: Optional number of classes to classify images + into, only to be specified if `include_top` is True, and + if no `weights` argument is specified. + + Returns: + A Keras model instance. + + Raises: + ValueError: in case of invalid argument for `weights`, + or invalid input shape. + RuntimeError: If attempting to run this model with a + backend that does not support separable convolutions. + """ + return NASNet( + input_shape, + penultimate_filters=4032, + num_blocks=6, + stem_block_filters=96, + skip_reduction=False, + filter_multiplier=2, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + pooling=pooling, + classes=classes, + default_size=331) + + +@tf_export('keras.applications.NASNetMobile', + 'keras.applications.nasnet.NASNetMobile') +def NASNetMobile(input_shape=None, + include_top=True, + weights='imagenet', + input_tensor=None, + pooling=None, + classes=1000): + """Instantiates a Mobile NASNet model in ImageNet mode. + + Note that only TensorFlow is supported for now, + therefore it only works with the data format + `image_data_format='channels_last'` in your Keras config + at `~/.keras/keras.json`. + + Arguments: + input_shape: Optional shape tuple, only to be specified + if `include_top` is False (otherwise the input shape + has to be `(224, 224, 3)` for NASNetMobile + It should have exactly 3 inputs channels, + and width and height should be no smaller than 32. + E.g. `(224, 224, 3)` would be one valid value. + include_top: Whether to include the fully-connected + layer at the top of the network. + weights: `None` (random initialization) or + `imagenet` (ImageNet weights) + input_tensor: Optional Keras tensor (i.e. output of + `layers.Input()`) + to use as image input for the model. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model + will be the 4D tensor output of the + last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a + 2D tensor. + - `max` means that global max pooling will + be applied. + classes: Optional number of classes to classify images + into, only to be specified if `include_top` is True, and + if no `weights` argument is specified. + + Returns: + A Keras model instance. + + Raises: + ValueError: In case of invalid argument for `weights`, + or invalid input shape. + RuntimeError: If attempting to run this model with a + backend that does not support separable convolutions. + """ + return NASNet( + input_shape, + penultimate_filters=1056, + num_blocks=4, + stem_block_filters=32, + skip_reduction=False, + filter_multiplier=2, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + pooling=pooling, + classes=classes, + default_size=224) + + +def _separable_conv_block(ip, + filters, + kernel_size=(3, 3), + strides=(1, 1), + block_id=None): + """Adds 2 blocks of [relu-separable conv-batchnorm]. + + Arguments: + ip: Input tensor + filters: Number of output filters per layer + kernel_size: Kernel size of separable convolutions + strides: Strided convolution for downsampling + block_id: String block_id + + Returns: + A Keras tensor + """ + channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 + + with K.name_scope('separable_conv_block_%s' % block_id): + x = Activation('relu')(ip) + x = SeparableConv2D( + filters, + kernel_size, + strides=strides, + name='separable_conv_1_%s' % block_id, + padding='same', + use_bias=False, + kernel_initializer='he_normal')( + x) + x = BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name='separable_conv_1_bn_%s' % (block_id))( + x) + x = Activation('relu')(x) + x = SeparableConv2D( + filters, + kernel_size, + name='separable_conv_2_%s' % block_id, + padding='same', + use_bias=False, + kernel_initializer='he_normal')( + x) + x = BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name='separable_conv_2_bn_%s' % (block_id))( + x) + return x + + +def _adjust_block(p, ip, filters, block_id=None): + """Adjusts the input `previous path` to match the shape of the `input`. + + Used in situations where the output number of filters needs to be changed. + + Arguments: + p: Input tensor which needs to be modified + ip: Input tensor whose shape needs to be matched + filters: Number of output filters to be matched + block_id: String block_id + + Returns: + Adjusted Keras tensor + """ + channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 + img_dim = 2 if K.image_data_format() == 'channels_first' else -2 + + ip_shape = K.int_shape(ip) + + if p is not None: + p_shape = K.int_shape(p) + + with K.name_scope('adjust_block'): + if p is None: + p = ip + + elif p_shape[img_dim] != ip_shape[img_dim]: + with K.name_scope('adjust_reduction_block_%s' % block_id): + p = Activation('relu', name='adjust_relu_1_%s' % block_id)(p) + + p1 = AveragePooling2D( + (1, 1), + strides=(2, 2), + padding='valid', + name='adjust_avg_pool_1_%s' % block_id)( + p) + p1 = Conv2D( + filters // 2, (1, 1), + padding='same', + use_bias=False, + name='adjust_conv_1_%s' % block_id, + kernel_initializer='he_normal')( + p1) + + p2 = ZeroPadding2D(padding=((0, 1), (0, 1)))(p) + p2 = Cropping2D(cropping=((1, 0), (1, 0)))(p2) + p2 = AveragePooling2D( + (1, 1), + strides=(2, 2), + padding='valid', + name='adjust_avg_pool_2_%s' % block_id)( + p2) + p2 = Conv2D( + filters // 2, (1, 1), + padding='same', + use_bias=False, + name='adjust_conv_2_%s' % block_id, + kernel_initializer='he_normal')( + p2) + + p = concatenate([p1, p2], axis=channel_dim) + p = BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name='adjust_bn_%s' % block_id)( + p) + + elif p_shape[channel_dim] != filters: + with K.name_scope('adjust_projection_block_%s' % block_id): + p = Activation('relu')(p) + p = Conv2D( + filters, (1, 1), + strides=(1, 1), + padding='same', + name='adjust_conv_projection_%s' % block_id, + use_bias=False, + kernel_initializer='he_normal')( + p) + p = BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name='adjust_bn_%s' % block_id)( + p) + return p + + +def _normal_a_cell(ip, p, filters, block_id=None): + """Adds a Normal cell for NASNet-A (Fig. 4 in the paper). + + Arguments: + ip: Input tensor `x` + p: Input tensor `p` + filters: Number of output filters + block_id: String block_id + + Returns: + A Keras tensor + """ + channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 + + with K.name_scope('normal_A_block_%s' % block_id): + p = _adjust_block(p, ip, filters, block_id) + + h = Activation('relu')(ip) + h = Conv2D( + filters, (1, 1), + strides=(1, 1), + padding='same', + name='normal_conv_1_%s' % block_id, + use_bias=False, + kernel_initializer='he_normal')( + h) + h = BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name='normal_bn_1_%s' % block_id)( + h) + + with K.name_scope('block_1'): + x1_1 = _separable_conv_block( + h, filters, kernel_size=(5, 5), block_id='normal_left1_%s' % block_id) + x1_2 = _separable_conv_block( + p, filters, block_id='normal_right1_%s' % block_id) + x1 = add([x1_1, x1_2], name='normal_add_1_%s' % block_id) + + with K.name_scope('block_2'): + x2_1 = _separable_conv_block( + p, filters, (5, 5), block_id='normal_left2_%s' % block_id) + x2_2 = _separable_conv_block( + p, filters, (3, 3), block_id='normal_right2_%s' % block_id) + x2 = add([x2_1, x2_2], name='normal_add_2_%s' % block_id) + + with K.name_scope('block_3'): + x3 = AveragePooling2D( + (3, 3), + strides=(1, 1), + padding='same', + name='normal_left3_%s' % (block_id))( + h) + x3 = add([x3, p], name='normal_add_3_%s' % block_id) + + with K.name_scope('block_4'): + x4_1 = AveragePooling2D( + (3, 3), + strides=(1, 1), + padding='same', + name='normal_left4_%s' % (block_id))( + p) + x4_2 = AveragePooling2D( + (3, 3), + strides=(1, 1), + padding='same', + name='normal_right4_%s' % (block_id))( + p) + x4 = add([x4_1, x4_2], name='normal_add_4_%s' % block_id) + + with K.name_scope('block_5'): + x5 = _separable_conv_block( + h, filters, block_id='normal_left5_%s' % block_id) + x5 = add([x5, h], name='normal_add_5_%s' % block_id) + + x = concatenate( + [p, x1, x2, x3, x4, x5], + axis=channel_dim, + name='normal_concat_%s' % block_id) + return x, ip + + +def _reduction_a_cell(ip, p, filters, block_id=None): + """Adds a Reduction cell for NASNet-A (Fig. 4 in the paper). + + Arguments: + ip: Input tensor `x` + p: Input tensor `p` + filters: Number of output filters + block_id: String block_id + + Returns: + A Keras tensor + """ + channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 + + with K.name_scope('reduction_A_block_%s' % block_id): + p = _adjust_block(p, ip, filters, block_id) + + h = Activation('relu')(ip) + h = Conv2D( + filters, (1, 1), + strides=(1, 1), + padding='same', + name='reduction_conv_1_%s' % block_id, + use_bias=False, + kernel_initializer='he_normal')( + h) + h = BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name='reduction_bn_1_%s' % block_id)( + h) + + with K.name_scope('block_1'): + x1_1 = _separable_conv_block( + h, + filters, (5, 5), + strides=(2, 2), + block_id='reduction_left1_%s' % block_id) + x1_2 = _separable_conv_block( + p, + filters, (7, 7), + strides=(2, 2), + block_id='reduction_1_%s' % block_id) + x1 = add([x1_1, x1_2], name='reduction_add_1_%s' % block_id) + + with K.name_scope('block_2'): + x2_1 = MaxPooling2D( + (3, 3), + strides=(2, 2), + padding='same', + name='reduction_left2_%s' % block_id)( + h) + x2_2 = _separable_conv_block( + p, + filters, (7, 7), + strides=(2, 2), + block_id='reduction_right2_%s' % block_id) + x2 = add([x2_1, x2_2], name='reduction_add_2_%s' % block_id) + + with K.name_scope('block_3'): + x3_1 = AveragePooling2D( + (3, 3), + strides=(2, 2), + padding='same', + name='reduction_left3_%s' % block_id)( + h) + x3_2 = _separable_conv_block( + p, + filters, (5, 5), + strides=(2, 2), + block_id='reduction_right3_%s' % block_id) + x3 = add([x3_1, x3_2], name='reduction_add3_%s' % block_id) + + with K.name_scope('block_4'): + x4 = AveragePooling2D( + (3, 3), + strides=(1, 1), + padding='same', + name='reduction_left4_%s' % block_id)( + x1) + x4 = add([x2, x4]) + + with K.name_scope('block_5'): + x5_1 = _separable_conv_block( + x1, filters, (3, 3), block_id='reduction_left4_%s' % block_id) + x5_2 = MaxPooling2D( + (3, 3), + strides=(2, 2), + padding='same', + name='reduction_right5_%s' % block_id)( + h) + x5 = add([x5_1, x5_2], name='reduction_add4_%s' % block_id) + + x = concatenate( + [x2, x3, x4, x5], + axis=channel_dim, + name='reduction_concat_%s' % block_id) + return x, ip diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py b/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1dec670cb995e47bdcf88bd69594c532781b18 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Nasnet application.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test + + +class NASNetMobileTest(test.TestCase): + + def test_with_top(self): + model = keras.applications.NASNetMobile(weights=None) + self.assertEqual(model.output_shape, (None, 1000)) + + def test_no_top(self): + model = keras.applications.NASNetMobile(weights=None, include_top=False) + self.assertEqual(model.output_shape, (None, None, None, 1056)) + + def test_with_pooling(self): + model = keras.applications.NASNetMobile(weights=None, + include_top=False, + pooling='avg') + self.assertEqual(model.output_shape, (None, 1056)) + + def test_weight_loading(self): + with self.assertRaises(ValueError): + keras.applications.NASNetMobile(weights='unknown', + include_top=False) + with self.assertRaises(ValueError): + keras.applications.NASNetMobile(weights='imagenet', + classes=2000) + + +class NASNetLargeTest(test.TestCase): + + def test_with_top(self): + model = keras.applications.NASNetLarge(weights=None) + self.assertEqual(model.output_shape, (None, 1000)) + + def test_no_top(self): + model = keras.applications.NASNetLarge(weights=None, include_top=False) + self.assertEqual(model.output_shape, (None, None, None, 4032)) + + def test_with_pooling(self): + model = keras.applications.NASNetLarge(weights=None, + include_top=False, + pooling='avg') + self.assertEqual(model.output_shape, (None, 4032)) + + def test_weight_loading(self): + with self.assertRaises(ValueError): + keras.applications.NASNetLarge(weights='unknown', + include_top=False) + with self.assertRaises(ValueError): + keras.applications.NASNetLarge(weights='imagenet', + classes=2000) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py index 8ab46693aa6e46de6c6df320c745ca9ed01fbe0b..a47dd657bb9ea0627d82831b7ee5d0b33788b5b7 100644 --- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py +++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name +# pylint: disable=unused-import """ResNet50 model for Keras. # Reference: @@ -31,8 +32,8 @@ import os from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import layers from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input # pylint: disable=unused-import +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D @@ -45,7 +46,10 @@ from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D from tensorflow.python.keras._impl.keras.layers import Input from tensorflow.python.keras._impl.keras.layers import MaxPooling2D from tensorflow.python.keras._impl.keras.models import Model +from tensorflow.python.keras._impl.keras.utils import layer_utils from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5' @@ -78,7 +82,8 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): x = Activation('relu')(x) x = Conv2D( - filters2, kernel_size, padding='same', name=conv_name_base + '2b')(x) + filters2, kernel_size, padding='same', name=conv_name_base + '2b')( + x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) x = Activation('relu')(x) @@ -92,7 +97,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)): - """conv_block is the block that has a conv layer at shortcut. + """A block that has a conv layer at shortcut. Arguments: input_tensor: input tensor @@ -100,14 +105,14 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, filters: list of integers, the filters of 3 conv layer at main path stage: integer, current stage label, used for generating layer names block: 'a','b'..., current block label, used for generating layer names - strides: Tuple of integers. + strides: Strides for the first conv layer in the block. Returns: Output tensor for the block. - Note that from stage 3, the first conv layer at main path is with - strides=(2,2) - And the shortcut should have strides=(2,2) as well + Note that from stage 3, + the first conv layer at main path is with strides=(2, 2) + And the shortcut should have strides=(2, 2) as well """ filters1, filters2, filters3 = filters if K.image_data_format() == 'channels_last': @@ -118,13 +123,14 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, bn_name_base = 'bn' + str(stage) + block + '_branch' x = Conv2D( - filters1, (1, 1), strides=strides, - name=conv_name_base + '2a')(input_tensor) + filters1, (1, 1), strides=strides, name=conv_name_base + '2a')( + input_tensor) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) x = Activation('relu')(x) x = Conv2D( - filters2, kernel_size, padding='same', name=conv_name_base + '2b')(x) + filters2, kernel_size, padding='same', name=conv_name_base + '2b')( + x) x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) x = Activation('relu')(x) @@ -132,8 +138,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) shortcut = Conv2D( - filters3, (1, 1), strides=strides, - name=conv_name_base + '1')(input_tensor) + filters3, (1, 1), strides=strides, name=conv_name_base + '1')( + input_tensor) shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) x = layers.add([x, shortcut]) @@ -141,6 +147,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, return x +@tf_export('keras.applications.ResNet50', + 'keras.applications.resnet50.ResNet50') def ResNet50(include_top=True, weights='imagenet', input_tensor=None, @@ -152,7 +160,7 @@ def ResNet50(include_top=True, Optionally loads weights pre-trained on ImageNet. Note that when using TensorFlow, for best performance you should set - `image_data_format="channels_last"` in your Keras config + `image_data_format='channels_last'` in your Keras config at ~/.keras/keras.json. The model and the weights are compatible with both @@ -164,15 +172,15 @@ def ResNet50(include_top=True, include_top: whether to include the fully-connected layer at the top of the network. weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. + 'imagenet' (pre-training on ImageNet), + or the path to the weights file to be loaded. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model. input_shape: optional shape tuple, only to be specified if `include_top` is False (otherwise the input shape has to be `(224, 224, 3)` (with `channels_last` data format) or `(3, 224, 224)` (with `channels_first` data format). - It should have exactly 3 input channels, + It should have exactly 3 inputs channels, and width and height should be no smaller than 197. E.g. `(200, 200, 3)` would be one valid value. pooling: Optional pooling mode for feature extraction @@ -219,15 +227,18 @@ def ResNet50(include_top=True, if input_tensor is None: img_input = Input(shape=input_shape) else: - img_input = Input(tensor=input_tensor, shape=input_shape) - + if not K.is_keras_tensor(input_tensor): + img_input = Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor if K.image_data_format() == 'channels_last': bn_axis = 3 else: bn_axis = 1 - x = Conv2D(64, (7, 7), - strides=(2, 2), padding='same', name='conv1')(img_input) + x = Conv2D( + 64, (7, 7), strides=(2, 2), padding='same', name='conv1')( + img_input) x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) x = Activation('relu')(x) x = MaxPooling2D((3, 3), strides=(2, 2))(x) @@ -289,4 +300,5 @@ def ResNet50(include_top=True, model.load_weights(weights_path) elif weights is not None: model.load_weights(weights) + return model diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg16.py b/tensorflow/python/keras/_impl/keras/applications/vgg16.py index 38dbbdc809e708cc19d5529665352fe4807fad90..9da74253abc2124844ab89b7727ddda4f754d8e2 100644 --- a/tensorflow/python/keras/_impl/keras/applications/vgg16.py +++ b/tensorflow/python/keras/_impl/keras/applications/vgg16.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name +# pylint: disable=unused-import """VGG16 model for Keras. # Reference @@ -29,8 +30,8 @@ import os from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input # pylint: disable=unused-import +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Conv2D from tensorflow.python.keras._impl.keras.layers import Dense @@ -42,12 +43,15 @@ from tensorflow.python.keras._impl.keras.layers import MaxPooling2D from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.utils import layer_utils from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5' WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5' +@tf_export('keras.applications.VGG16', 'keras.applications.vgg16.VGG16') def VGG16(include_top=True, weights='imagenet', input_tensor=None, @@ -59,7 +63,7 @@ def VGG16(include_top=True, Optionally loads weights pre-trained on ImageNet. Note that when using TensorFlow, for best performance you should set - `image_data_format="channels_last"` in your Keras config + `image_data_format='channels_last'` in your Keras config at ~/.keras/keras.json. The model and the weights are compatible with both @@ -71,8 +75,8 @@ def VGG16(include_top=True, include_top: whether to include the 3 fully-connected layers at the top of the network. weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. + 'imagenet' (pre-training on ImageNet), + or the path to the weights file to be loaded. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model. input_shape: optional shape tuple, only to be specified @@ -125,48 +129,62 @@ def VGG16(include_top=True, if input_tensor is None: img_input = Input(shape=input_shape) else: - img_input = Input(tensor=input_tensor, shape=input_shape) - + if not K.is_keras_tensor(input_tensor): + img_input = Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor # Block 1 x = Conv2D( - 64, (3, 3), activation='relu', padding='same', - name='block1_conv1')(img_input) + 64, (3, 3), activation='relu', padding='same', name='block1_conv1')( + img_input) x = Conv2D( - 64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x) + 64, (3, 3), activation='relu', padding='same', name='block1_conv2')( + x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) # Block 2 x = Conv2D( - 128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x) + 128, (3, 3), activation='relu', padding='same', name='block2_conv1')( + x) x = Conv2D( - 128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x) + 128, (3, 3), activation='relu', padding='same', name='block2_conv2')( + x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) # Block 3 x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x) + 256, (3, 3), activation='relu', padding='same', name='block3_conv1')( + x) x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x) + 256, (3, 3), activation='relu', padding='same', name='block3_conv2')( + x) x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x) + 256, (3, 3), activation='relu', padding='same', name='block3_conv3')( + x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) # Block 4 x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x) + 512, (3, 3), activation='relu', padding='same', name='block4_conv1')( + x) x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x) + 512, (3, 3), activation='relu', padding='same', name='block4_conv2')( + x) x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x) + 512, (3, 3), activation='relu', padding='same', name='block4_conv3')( + x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) # Block 5 x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x) + 512, (3, 3), activation='relu', padding='same', name='block5_conv1')( + x) x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x) + 512, (3, 3), activation='relu', padding='same', name='block5_conv2')( + x) x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x) + 512, (3, 3), activation='relu', padding='same', name='block5_conv3')( + x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) if include_top: @@ -215,6 +233,8 @@ def VGG16(include_top=True, dense = model.get_layer(name='fc1') layer_utils.convert_dense_weights_data_format(dense, shape, 'channels_first') + elif weights is not None: model.load_weights(weights) + return model diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg19.py b/tensorflow/python/keras/_impl/keras/applications/vgg19.py index 126c64260b51a7d4e6ca653e850e22c03799dcb0..961c1f991893dbc0df858e9f72b61202c9fee500 100644 --- a/tensorflow/python/keras/_impl/keras/applications/vgg19.py +++ b/tensorflow/python/keras/_impl/keras/applications/vgg19.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name +# pylint: disable=unused-import """VGG19 model for Keras. # Reference @@ -29,8 +30,8 @@ import os from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input # pylint: disable=unused-import +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Conv2D from tensorflow.python.keras._impl.keras.layers import Dense @@ -42,12 +43,15 @@ from tensorflow.python.keras._impl.keras.layers import MaxPooling2D from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.utils import layer_utils from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5' WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5' +@tf_export('keras.applications.VGG19', 'keras.applications.vgg19.VGG19') def VGG19(include_top=True, weights='imagenet', input_tensor=None, @@ -59,7 +63,7 @@ def VGG19(include_top=True, Optionally loads weights pre-trained on ImageNet. Note that when using TensorFlow, for best performance you should set - `image_data_format="channels_last"` in your Keras config + `image_data_format='channels_last'` in your Keras config at ~/.keras/keras.json. The model and the weights are compatible with both @@ -71,15 +75,15 @@ def VGG19(include_top=True, include_top: whether to include the 3 fully-connected layers at the top of the network. weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. + 'imagenet' (pre-training on ImageNet), + or the path to the weights file to be loaded. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model. input_shape: optional shape tuple, only to be specified if `include_top` is False (otherwise the input shape has to be `(224, 224, 3)` (with `channels_last` data format) or `(3, 224, 224)` (with `channels_first` data format). - It should have exactly 3 input channels, + It should have exactly 3 inputs channels, and width and height should be no smaller than 48. E.g. `(200, 200, 3)` would be one valid value. pooling: Optional pooling mode for feature extraction @@ -125,54 +129,71 @@ def VGG19(include_top=True, if input_tensor is None: img_input = Input(shape=input_shape) else: - img_input = Input(tensor=input_tensor, shape=input_shape) - + if not K.is_keras_tensor(input_tensor): + img_input = Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor # Block 1 x = Conv2D( - 64, (3, 3), activation='relu', padding='same', - name='block1_conv1')(img_input) + 64, (3, 3), activation='relu', padding='same', name='block1_conv1')( + img_input) x = Conv2D( - 64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x) + 64, (3, 3), activation='relu', padding='same', name='block1_conv2')( + x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) # Block 2 x = Conv2D( - 128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x) + 128, (3, 3), activation='relu', padding='same', name='block2_conv1')( + x) x = Conv2D( - 128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x) + 128, (3, 3), activation='relu', padding='same', name='block2_conv2')( + x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) # Block 3 x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x) + 256, (3, 3), activation='relu', padding='same', name='block3_conv1')( + x) x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x) + 256, (3, 3), activation='relu', padding='same', name='block3_conv2')( + x) x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x) + 256, (3, 3), activation='relu', padding='same', name='block3_conv3')( + x) x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv4')(x) + 256, (3, 3), activation='relu', padding='same', name='block3_conv4')( + x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) # Block 4 x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x) + 512, (3, 3), activation='relu', padding='same', name='block4_conv1')( + x) x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x) + 512, (3, 3), activation='relu', padding='same', name='block4_conv2')( + x) x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x) + 512, (3, 3), activation='relu', padding='same', name='block4_conv3')( + x) x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv4')(x) + 512, (3, 3), activation='relu', padding='same', name='block4_conv4')( + x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) # Block 5 x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x) + 512, (3, 3), activation='relu', padding='same', name='block5_conv1')( + x) x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x) + 512, (3, 3), activation='relu', padding='same', name='block5_conv2')( + x) x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x) + 512, (3, 3), activation='relu', padding='same', name='block5_conv3')( + x) x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv4')(x) + 512, (3, 3), activation='relu', padding='same', name='block5_conv4')( + x) x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) if include_top: @@ -211,6 +232,8 @@ def VGG19(include_top=True, cache_subdir='models', file_hash='253f8cb515780f3b799900260a226db6') model.load_weights(weights_path) + if K.backend() == 'theano': + layer_utils.convert_all_kernels_in_model(model) if K.image_data_format() == 'channels_first': if include_top: @@ -219,6 +242,8 @@ def VGG19(include_top=True, dense = model.get_layer(name='fc1') layer_utils.convert_dense_weights_data_format(dense, shape, 'channels_first') + elif weights is not None: model.load_weights(weights) + return model diff --git a/tensorflow/python/keras/_impl/keras/applications/xception.py b/tensorflow/python/keras/_impl/keras/applications/xception.py index 821983140852b9f1ab505376d824db2392f54391..7e7ca5a18a31622ac79d61ab01ce65341a4a46c5 100644 --- a/tensorflow/python/keras/_impl/keras/applications/xception.py +++ b/tensorflow/python/keras/_impl/keras/applications/xception.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name +# pylint: disable=unused-import """Xception V1 model for Keras. On ImageNet, this model gets to a top-1 validation accuracy of 0.790 @@ -42,7 +43,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import layers from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import +from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import BatchNormalization @@ -56,12 +57,15 @@ from tensorflow.python.keras._impl.keras.layers import SeparableConv2D from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.utils.data_utils import get_file from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels.h5' TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels_notop.h5' +@tf_export('keras.applications.Xception', + 'keras.applications.xception.Xception') def Xception(include_top=True, weights='imagenet', input_tensor=None, @@ -74,7 +78,7 @@ def Xception(include_top=True, on ImageNet. This model is available for TensorFlow only, and can only be used with inputs following the TensorFlow data format `(width, height, channels)`. - You should set `image_data_format="channels_last"` in your Keras config + You should set `image_data_format='channels_last'` in your Keras config located at ~/.keras/keras.json. Note that the default input image size for this model is 299x299. @@ -83,14 +87,14 @@ def Xception(include_top=True, include_top: whether to include the fully-connected layer at the top of the network. weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. + 'imagenet' (pre-training on ImageNet), + or the path to the weights file to be loaded. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model. input_shape: optional shape tuple, only to be specified if `include_top` is False (otherwise the input shape has to be `(299, 299, 3)`. - It should have exactly 3 input channels, + It should have exactly 3 inputs channels, and width and height should be no smaller than 71. E.g. `(150, 150, 3)` would be one valid value. pooling: Optional pooling mode for feature extraction @@ -155,11 +159,14 @@ def Xception(include_top=True, if input_tensor is None: img_input = Input(shape=input_shape) else: - img_input = Input(tensor=input_tensor, shape=input_shape) + if not K.is_keras_tensor(input_tensor): + img_input = Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor x = Conv2D( - 32, (3, 3), strides=(2, 2), use_bias=False, - name='block1_conv1')(img_input) + 32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')( + img_input) x = BatchNormalization(name='block1_conv1_bn')(x) x = Activation('relu', name='block1_conv1_act')(x) x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x) @@ -167,53 +174,65 @@ def Xception(include_top=True, x = Activation('relu', name='block1_conv2_act')(x) residual = Conv2D( - 128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) + 128, (1, 1), strides=(2, 2), padding='same', use_bias=False)( + x) residual = BatchNormalization()(residual) x = SeparableConv2D( - 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x) + 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')( + x) x = BatchNormalization(name='block2_sepconv1_bn')(x) x = Activation('relu', name='block2_sepconv2_act')(x) x = SeparableConv2D( - 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x) + 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')( + x) x = BatchNormalization(name='block2_sepconv2_bn')(x) x = MaxPooling2D( - (3, 3), strides=(2, 2), padding='same', name='block2_pool')(x) + (3, 3), strides=(2, 2), padding='same', name='block2_pool')( + x) x = layers.add([x, residual]) residual = Conv2D( - 256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) + 256, (1, 1), strides=(2, 2), padding='same', use_bias=False)( + x) residual = BatchNormalization()(residual) x = Activation('relu', name='block3_sepconv1_act')(x) x = SeparableConv2D( - 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x) + 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')( + x) x = BatchNormalization(name='block3_sepconv1_bn')(x) x = Activation('relu', name='block3_sepconv2_act')(x) x = SeparableConv2D( - 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x) + 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')( + x) x = BatchNormalization(name='block3_sepconv2_bn')(x) x = MaxPooling2D( - (3, 3), strides=(2, 2), padding='same', name='block3_pool')(x) + (3, 3), strides=(2, 2), padding='same', name='block3_pool')( + x) x = layers.add([x, residual]) residual = Conv2D( - 728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) + 728, (1, 1), strides=(2, 2), padding='same', use_bias=False)( + x) residual = BatchNormalization()(residual) x = Activation('relu', name='block4_sepconv1_act')(x) x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x) + 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')( + x) x = BatchNormalization(name='block4_sepconv1_bn')(x) x = Activation('relu', name='block4_sepconv2_act')(x) x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x) + 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')( + x) x = BatchNormalization(name='block4_sepconv2_bn')(x) x = MaxPooling2D( - (3, 3), strides=(2, 2), padding='same', name='block4_pool')(x) + (3, 3), strides=(2, 2), padding='same', name='block4_pool')( + x) x = layers.add([x, residual]) for i in range(8): @@ -222,46 +241,52 @@ def Xception(include_top=True, x = Activation('relu', name=prefix + '_sepconv1_act')(x) x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, - name=prefix + '_sepconv1')(x) + 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')( + x) x = BatchNormalization(name=prefix + '_sepconv1_bn')(x) x = Activation('relu', name=prefix + '_sepconv2_act')(x) x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, - name=prefix + '_sepconv2')(x) + 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')( + x) x = BatchNormalization(name=prefix + '_sepconv2_bn')(x) x = Activation('relu', name=prefix + '_sepconv3_act')(x) x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, - name=prefix + '_sepconv3')(x) + 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')( + x) x = BatchNormalization(name=prefix + '_sepconv3_bn')(x) x = layers.add([x, residual]) residual = Conv2D( - 1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x) + 1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)( + x) residual = BatchNormalization()(residual) x = Activation('relu', name='block13_sepconv1_act')(x) x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x) + 728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')( + x) x = BatchNormalization(name='block13_sepconv1_bn')(x) x = Activation('relu', name='block13_sepconv2_act')(x) x = SeparableConv2D( - 1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x) + 1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')( + x) x = BatchNormalization(name='block13_sepconv2_bn')(x) x = MaxPooling2D( - (3, 3), strides=(2, 2), padding='same', name='block13_pool')(x) + (3, 3), strides=(2, 2), padding='same', name='block13_pool')( + x) x = layers.add([x, residual]) x = SeparableConv2D( - 1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x) + 1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')( + x) x = BatchNormalization(name='block14_sepconv1_bn')(x) x = Activation('relu', name='block14_sepconv1_act')(x) x = SeparableConv2D( - 2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x) + 2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')( + x) x = BatchNormalization(name='block14_sepconv2_bn')(x) x = Activation('relu', name='block14_sepconv2_act')(x) @@ -303,11 +328,10 @@ def Xception(include_top=True, if old_data_format: K.set_image_data_format(old_data_format) - elif weights is not None: - model.load_weights(weights) return model +@tf_export('keras.applications.xception.preprocess_input') def preprocess_input(x): """Preprocesses a numpy array encoding a batch of images. diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py index 9476085bd8cbc36f63d3c6c8ecad732b557a4f8a..a238a3f7483685ca0b08c746b55bbc9e868cb2d3 100644 --- a/tensorflow/python/keras/_impl/keras/backend.py +++ b/tensorflow/python/keras/_impl/keras/backend.py @@ -29,6 +29,7 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as session_module +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops @@ -47,6 +48,7 @@ from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import @@ -54,6 +56,7 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables as variables_module from tensorflow.python.training import moving_averages from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export py_all = all @@ -85,7 +88,7 @@ _MANUAL_VAR_INIT = False _FLOATX = 'float32' # Epsilon fuzz factor used throughout the codebase. -_EPSILON = 10e-8 +_EPSILON = 1e-7 # Default image data format, one of "channels_last", "channels_first". _IMAGE_DATA_FORMAT = 'channels_last' @@ -96,6 +99,7 @@ _IMAGE_DATA_FORMAT = 'channels_last' _LOCAL_DEVICES = None +@tf_export('keras.backend.backend') def backend(): """Publicly accessible method for determining the current backend. @@ -107,6 +111,7 @@ def backend(): return 'tensorflow' +@tf_export('keras.backend.epsilon') def epsilon(): """Returns the value of the fuzz factor used in numeric expressions. @@ -116,12 +121,13 @@ def epsilon(): Example: ```python >>> keras.backend.epsilon() - 1e-08 + 1e-07 ``` """ return _EPSILON +@tf_export('keras.backend.set_epsilon') def set_epsilon(value): """Sets the value of the fuzz factor used in numeric expressions. @@ -132,7 +138,7 @@ def set_epsilon(value): ```python >>> from keras import backend as K >>> K.epsilon() - 1e-08 + 1e-07 >>> K.set_epsilon(1e-05) >>> K.epsilon() 1e-05 @@ -142,6 +148,7 @@ def set_epsilon(value): _EPSILON = value +@tf_export('keras.backend.floatx') def floatx(): """Returns the default float type, as a string. @@ -159,6 +166,7 @@ def floatx(): return _FLOATX +@tf_export('keras.backend.set_floatx') def set_floatx(value): """Sets the default float type. @@ -184,6 +192,7 @@ def set_floatx(value): _FLOATX = str(value) +@tf_export('keras.backend.cast_to_floatx') def cast_to_floatx(x): """Cast a Numpy array to the default Keras float type. @@ -211,6 +220,7 @@ def cast_to_floatx(x): return np.asarray(x, dtype=_FLOATX) +@tf_export('keras.backend.image_data_format') def image_data_format(): """Returns the default image data format convention. @@ -226,6 +236,7 @@ def image_data_format(): return _IMAGE_DATA_FORMAT +@tf_export('keras.backend.set_image_data_format') def set_image_data_format(data_format): """Sets the value of the image data format convention. @@ -247,10 +258,11 @@ def set_image_data_format(data_format): """ global _IMAGE_DATA_FORMAT if data_format not in {'channels_last', 'channels_first'}: - raise ValueError('Unknown data_format:', data_format) + raise ValueError('Unknown data_format: ' + str(data_format)) _IMAGE_DATA_FORMAT = str(data_format) +@tf_export('keras.backend.get_uid') def get_uid(prefix=''): """Associates a string prefix with an integer counter in a TensorFlow graph. @@ -278,6 +290,7 @@ def get_uid(prefix=''): return layer_name_uids[prefix] +@tf_export('keras.backend.reset_uids') def reset_uids(): per_graph_layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS keys = list(per_graph_layer_name_uids.keys()) @@ -285,6 +298,7 @@ def reset_uids(): del per_graph_layer_name_uids[key] +@tf_export('keras.backend.clear_session') def clear_session(): """Destroys the current TF graph and creates a new one. @@ -295,11 +309,13 @@ def clear_session(): ops.reset_default_graph() reset_uids() _SESSION = None - phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase') + phase = array_ops.placeholder_with_default( + False, shape=(), name='keras_learning_phase') _GRAPH_LEARNING_PHASES = {} _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase +@tf_export('keras.backend.manual_variable_initialization') def manual_variable_initialization(value): """Sets the manual variable initialization flag. @@ -316,6 +332,7 @@ def manual_variable_initialization(value): _MANUAL_VAR_INIT = value +@tf_export('keras.backend.learning_phase') def learning_phase(): """Returns the learning phase flag. @@ -326,13 +343,21 @@ def learning_phase(): Returns: Learning phase (scalar integer tensor or Python integer). """ + if context.in_eager_mode(): + if 'eager' not in _GRAPH_LEARNING_PHASES: + # Fallback to inference mode as default. + return 0 + return _GRAPH_LEARNING_PHASES['eager'] + graph = ops.get_default_graph() if graph not in _GRAPH_LEARNING_PHASES: - phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase') + phase = array_ops.placeholder_with_default( + False, shape=(), name='keras_learning_phase') _GRAPH_LEARNING_PHASES[graph] = phase return _GRAPH_LEARNING_PHASES[graph] +@tf_export('keras.backend.set_learning_phase') def set_learning_phase(value): """Sets the learning phase to a fixed value. @@ -345,9 +370,13 @@ def set_learning_phase(value): global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned if value not in {0, 1}: raise ValueError('Expected learning phase to be ' '0 or 1.') - _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value + if context.in_eager_mode(): + _GRAPH_LEARNING_PHASES['eager'] = value + else: + _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value +@tf_export('keras.backend.get_session') def get_session(): """Returns the TF session to be used by the backend. @@ -383,6 +412,7 @@ def get_session(): return session +@tf_export('keras.backend.set_session') def set_session(session): """Sets the global TensorFlow session. @@ -456,7 +486,7 @@ def _get_available_gpus(): def _has_nchw_support(): """Check whether the current scope supports NCHW ops. - Tensorflow does not support NCHW on CPU. Therefore we check if we are not + TensorFlow does not support NCHW on CPU. Therefore we check if we are not explicitly put on CPU, and have GPUs available. In this case there will be soft-placing on the GPU device. @@ -485,6 +515,7 @@ def _to_tensor(x, dtype): return ops.convert_to_tensor(x, dtype=dtype) +@tf_export('keras.backend.is_sparse') def is_sparse(tensor): """Returns whether a tensor is a sparse tensor. @@ -508,6 +539,7 @@ def is_sparse(tensor): return isinstance(tensor, sparse_tensor.SparseTensor) +@tf_export('keras.backend.to_dense') def to_dense(tensor): """Converts a sparse tensor into a dense tensor and returns it. @@ -537,6 +569,7 @@ def to_dense(tensor): name_scope = ops.name_scope +@tf_export('keras.backend.variable') def variable(value, dtype=None, name=None, constraint=None): """Instantiates a variable and returns it. @@ -575,7 +608,7 @@ def variable(value, dtype=None, name=None, constraint=None): v._keras_shape = sparse_coo.shape v._uses_learning_phase = False return v - v = variables_module.Variable( + v = resource_variable_ops.ResourceVariable( value, dtype=dtypes_module.as_dtype(dtype), name=name, @@ -609,6 +642,7 @@ def _initialize_variables(session): session.run(variables_module.variables_initializer(uninitialized_vars)) +@tf_export('keras.backend.constant') def constant(value, dtype=None, shape=None, name=None): """Creates a constant tensor. @@ -677,6 +711,7 @@ def is_keras_tensor(x): return hasattr(x, '_keras_history') +@tf_export('keras.backend.placeholder') def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None): """Instantiates a placeholder tensor and returns it. @@ -729,6 +764,7 @@ def is_placeholder(x): return False +@tf_export('keras.backend.shape') def shape(x): """Returns the symbolic shape of a tensor or variable. @@ -761,6 +797,7 @@ def shape(x): return array_ops.shape(x) +@tf_export('keras.backend.int_shape') def int_shape(x): """Returns the shape of tensor or variable as a tuple of int or None entries. @@ -788,6 +825,7 @@ def int_shape(x): return None +@tf_export('keras.backend.ndim') def ndim(x): """Returns the number of axes in a tensor, as an integer. @@ -815,6 +853,7 @@ def ndim(x): return None +@tf_export('keras.backend.dtype') def dtype(x): """Returns the dtype of a Keras tensor or variable, as a string. @@ -845,6 +884,7 @@ def dtype(x): return x.dtype.base_dtype.name +@tf_export('keras.backend.eval') def eval(x): """Evaluates the value of a variable. @@ -866,6 +906,7 @@ def eval(x): return to_dense(x).eval(session=get_session()) +@tf_export('keras.backend.zeros') def zeros(shape, dtype=None, name=None): """Instantiates an all-zeros variable and returns it. @@ -876,6 +917,8 @@ def zeros(shape, dtype=None, name=None): Returns: A variable (including Keras metadata), filled with `0.0`. + Note that if `shape` was symbolic, we cannot return a variable, + and will return a dynamically-shaped tensor instead. Example: ```python @@ -890,12 +933,15 @@ def zeros(shape, dtype=None, name=None): if dtype is None: dtype = floatx() tf_dtype = dtypes_module.as_dtype(dtype) - return variable( - init_ops.constant_initializer(0., dtype=tf_dtype)(shape), dtype, name) + v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name) + if py_all(v.get_shape().as_list()): + return variable(v, dtype=dtype, name=name) + return v +@tf_export('keras.backend.ones') def ones(shape, dtype=None, name=None): - """Instantiates an all-ones tensor variable and returns it. + """Instantiates an all-ones variable and returns it. Arguments: shape: Tuple of integers, shape of returned Keras variable. @@ -904,6 +950,8 @@ def ones(shape, dtype=None, name=None): Returns: A Keras variable, filled with `1.0`. + Note that if `shape` was symbolic, we cannot return a variable, + and will return a dynamically-shaped tensor instead. Example: ```python @@ -918,10 +966,13 @@ def ones(shape, dtype=None, name=None): if dtype is None: dtype = floatx() tf_dtype = dtypes_module.as_dtype(dtype) - return variable( - init_ops.constant_initializer(1., dtype=tf_dtype)(shape), dtype, name) + v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name) + if py_all(v.get_shape().as_list()): + return variable(v, dtype=dtype, name=name) + return v +@tf_export('keras.backend.eye') def eye(size, dtype=None, name=None): """Instantiate an identity matrix and returns it. @@ -950,6 +1001,7 @@ def eye(size, dtype=None, name=None): return variable(linalg_ops.eye(size, dtype=tf_dtype), dtype, name) +@tf_export('keras.backend.zeros_like') def zeros_like(x, dtype=None, name=None): """Instantiates an all-zeros variable of the same shape as another tensor. @@ -975,6 +1027,7 @@ def zeros_like(x, dtype=None, name=None): return array_ops.zeros_like(x, dtype=dtype, name=name) +@tf_export('keras.backend.ones_like') def ones_like(x, dtype=None, name=None): """Instantiates an all-ones variable of the same shape as another tensor. @@ -1013,6 +1066,7 @@ def identity(x, name=None): return array_ops.identity(x, name=name) +@tf_export('keras.backend.random_uniform_variable') def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None): """Instantiates a variable with values drawn from a uniform distribution. @@ -1049,6 +1103,7 @@ def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None): return variable(value, dtype=dtype, name=name) +@tf_export('keras.backend.random_normal_variable') def random_normal_variable(shape, mean, scale, dtype=None, name=None, seed=None): """Instantiates a variable with values drawn from a normal distribution. @@ -1086,6 +1141,7 @@ def random_normal_variable(shape, mean, scale, dtype=None, name=None, return variable(value, dtype=dtype, name=name) +@tf_export('keras.backend.count_params') def count_params(x): """Returns the static number of elements in a variable or tensor. @@ -1108,6 +1164,7 @@ def count_params(x): return np.prod(x.get_shape().as_list()) +@tf_export('keras.backend.cast') def cast(x, dtype): """Casts a tensor to a different dtype and returns it. @@ -1143,10 +1200,12 @@ def cast(x, dtype): # UPDATES OPS +@tf_export('keras.backend.update') def update(x, new_x): return state_ops.assign(x, new_x) +@tf_export('keras.backend.update_add') def update_add(x, increment): """Update the value of `x` by adding `increment`. @@ -1160,6 +1219,7 @@ def update_add(x, increment): return state_ops.assign_add(x, increment) +@tf_export('keras.backend.update_sub') def update_sub(x, decrement): """Update the value of `x` by subtracting `decrement`. @@ -1173,6 +1233,7 @@ def update_sub(x, decrement): return state_ops.assign_sub(x, decrement) +@tf_export('keras.backend.moving_average_update') def moving_average_update(x, value, momentum): """Compute the moving average of a variable. @@ -1185,12 +1246,13 @@ def moving_average_update(x, value, momentum): An Operation to update the variable. """ return moving_averages.assign_moving_average( - x, value, momentum, zero_debias=False) + x, value, momentum, zero_debias=True) # LINEAR ALGEBRA +@tf_export('keras.backend.dot') def dot(x, y): """Multiplies 2 tensors (and/or variables) and returns a *tensor*. @@ -1262,6 +1324,7 @@ def dot(x, y): return out +@tf_export('keras.backend.batch_dot') def batch_dot(x, y, axes=None): """Batchwise dot product. @@ -1354,6 +1417,7 @@ def batch_dot(x, y, axes=None): return out +@tf_export('keras.backend.transpose') def transpose(x): """Transposes a tensor and returns it. @@ -1389,6 +1453,7 @@ def transpose(x): return array_ops.transpose(x) +@tf_export('keras.backend.gather') def gather(reference, indices): """Retrieves the elements of indices `indices` in the tensor `reference`. @@ -1405,6 +1470,7 @@ def gather(reference, indices): # ELEMENT-WISE OPERATIONS +@tf_export('keras.backend.max') def max(x, axis=None, keepdims=False): """Maximum value in a tensor. @@ -1419,9 +1485,10 @@ def max(x, axis=None, keepdims=False): Returns: A tensor with maximum values of `x`. """ - return math_ops.reduce_max(x, axis=axis, keep_dims=keepdims) + return math_ops.reduce_max(x, axis, keepdims) +@tf_export('keras.backend.min') def min(x, axis=None, keepdims=False): """Minimum value in a tensor. @@ -1436,9 +1503,10 @@ def min(x, axis=None, keepdims=False): Returns: A tensor with miminum values of `x`. """ - return math_ops.reduce_min(x, axis=axis, keep_dims=keepdims) + return math_ops.reduce_min(x, axis, keepdims) +@tf_export('keras.backend.sum') def sum(x, axis=None, keepdims=False): """Sum of the values in a tensor, alongside the specified axis. @@ -1453,9 +1521,10 @@ def sum(x, axis=None, keepdims=False): Returns: A tensor with sum of `x`. """ - return math_ops.reduce_sum(x, axis=axis, keep_dims=keepdims) + return math_ops.reduce_sum(x, axis, keepdims) +@tf_export('keras.backend.prod') def prod(x, axis=None, keepdims=False): """Multiplies the values in a tensor, alongside the specified axis. @@ -1470,7 +1539,7 @@ def prod(x, axis=None, keepdims=False): Returns: A tensor with the product of elements of `x`. """ - return math_ops.reduce_prod(x, axis=axis, keep_dims=keepdims) + return math_ops.reduce_prod(x, axis, keepdims) def cumsum(x, axis=0): @@ -1499,6 +1568,7 @@ def cumprod(x, axis=0): return math_ops.cumprod(x, axis=axis) +@tf_export('keras.backend.var') def var(x, axis=None, keepdims=False): """Variance of a tensor, alongside the specified axis. @@ -1515,12 +1585,13 @@ def var(x, axis=None, keepdims=False): """ if x.dtype.base_dtype == dtypes_module.bool: x = math_ops.cast(x, floatx()) - m = math_ops.reduce_mean(x, axis=axis, keep_dims=True) + m = math_ops.reduce_mean(x, axis, True) devs_squared = math_ops.square(x - m) return math_ops.reduce_mean( - devs_squared, axis=axis, keep_dims=keepdims) + devs_squared, axis, keepdims) +@tf_export('keras.backend.std') def std(x, axis=None, keepdims=False): """Standard deviation of a tensor, alongside the specified axis. @@ -1538,6 +1609,7 @@ def std(x, axis=None, keepdims=False): return math_ops.sqrt(var(x, axis=axis, keepdims=keepdims)) +@tf_export('keras.backend.mean') def mean(x, axis=None, keepdims=False): """Mean of a tensor, alongside the specified axis. @@ -1546,7 +1618,7 @@ def mean(x, axis=None, keepdims=False): axis: A list of integer. Axes to compute the mean. keepdims: A boolean, whether to keep the dimensions or not. If `keepdims` is `False`, the rank of the tensor is reduced - by 1 for each entry in `axis`. If `keep_dims` is `True`, + by 1 for each entry in `axis`. If `keepdims` is `True`, the reduced dimensions are retained with length 1. Returns: @@ -1554,9 +1626,10 @@ def mean(x, axis=None, keepdims=False): """ if x.dtype.base_dtype == dtypes_module.bool: x = math_ops.cast(x, floatx()) - return math_ops.reduce_mean(x, axis=axis, keep_dims=keepdims) + return math_ops.reduce_mean(x, axis, keepdims) +@tf_export('keras.backend.any') def any(x, axis=None, keepdims=False): """Bitwise reduction (logical OR). @@ -1569,9 +1642,10 @@ def any(x, axis=None, keepdims=False): A uint8 tensor (0s and 1s). """ x = math_ops.cast(x, dtypes_module.bool) - return math_ops.reduce_any(x, axis=axis, keep_dims=keepdims) + return math_ops.reduce_any(x, axis, keepdims) +@tf_export('keras.backend.all') def all(x, axis=None, keepdims=False): """Bitwise reduction (logical AND). @@ -1584,9 +1658,10 @@ def all(x, axis=None, keepdims=False): A uint8 tensor (0s and 1s). """ x = math_ops.cast(x, dtypes_module.bool) - return math_ops.reduce_all(x, axis=axis, keep_dims=keepdims) + return math_ops.reduce_all(x, axis, keepdims) +@tf_export('keras.backend.argmax') def argmax(x, axis=-1): """Returns the index of the maximum value along an axis. @@ -1600,6 +1675,7 @@ def argmax(x, axis=-1): return math_ops.argmax(x, axis) +@tf_export('keras.backend.argmin') def argmin(x, axis=-1): """Returns the index of the minimum value along an axis. @@ -1613,6 +1689,7 @@ def argmin(x, axis=-1): return math_ops.argmin(x, axis) +@tf_export('keras.backend.square') def square(x): """Element-wise square. @@ -1625,6 +1702,7 @@ def square(x): return math_ops.square(x) +@tf_export('keras.backend.abs') def abs(x): """Element-wise absolute value. @@ -1637,6 +1715,7 @@ def abs(x): return math_ops.abs(x) +@tf_export('keras.backend.sqrt') def sqrt(x): """Element-wise square root. @@ -1652,6 +1731,7 @@ def sqrt(x): return math_ops.sqrt(x) +@tf_export('keras.backend.exp') def exp(x): """Element-wise exponential. @@ -1664,6 +1744,7 @@ def exp(x): return math_ops.exp(x) +@tf_export('keras.backend.log') def log(x): """Element-wise log. @@ -1694,9 +1775,10 @@ def logsumexp(x, axis=None, keepdims=False): Returns: The reduced tensor. """ - return math_ops.reduce_logsumexp(x, axis=axis, keep_dims=keepdims) + return math_ops.reduce_logsumexp(x, axis, keepdims) +@tf_export('keras.backend.round') def round(x): """Element-wise rounding to the closest integer. @@ -1711,6 +1793,7 @@ def round(x): return math_ops.round(x) +@tf_export('keras.backend.sign') def sign(x): """Element-wise sign. @@ -1723,6 +1806,7 @@ def sign(x): return math_ops.sign(x) +@tf_export('keras.backend.pow') def pow(x, a): """Element-wise exponentiation. @@ -1736,6 +1820,7 @@ def pow(x, a): return math_ops.pow(x, a) +@tf_export('keras.backend.clip') def clip(x, min_value, max_value): """Element-wise value clipping. @@ -1756,6 +1841,7 @@ def clip(x, min_value, max_value): return clip_ops.clip_by_value(x, min_value, max_value) +@tf_export('keras.backend.equal') def equal(x, y): """Element-wise equality between two tensors. @@ -1769,6 +1855,7 @@ def equal(x, y): return math_ops.equal(x, y) +@tf_export('keras.backend.not_equal') def not_equal(x, y): """Element-wise inequality between two tensors. @@ -1782,6 +1869,7 @@ def not_equal(x, y): return math_ops.not_equal(x, y) +@tf_export('keras.backend.greater') def greater(x, y): """Element-wise truth value of (x > y). @@ -1795,6 +1883,7 @@ def greater(x, y): return math_ops.greater(x, y) +@tf_export('keras.backend.greater_equal') def greater_equal(x, y): """Element-wise truth value of (x >= y). @@ -1808,6 +1897,7 @@ def greater_equal(x, y): return math_ops.greater_equal(x, y) +@tf_export('keras.backend.less') def less(x, y): """Element-wise truth value of (x < y). @@ -1821,6 +1911,7 @@ def less(x, y): return math_ops.less(x, y) +@tf_export('keras.backend.less_equal') def less_equal(x, y): """Element-wise truth value of (x <= y). @@ -1834,6 +1925,7 @@ def less_equal(x, y): return math_ops.less_equal(x, y) +@tf_export('keras.backend.maximum') def maximum(x, y): """Element-wise maximum of two tensors. @@ -1847,6 +1939,7 @@ def maximum(x, y): return math_ops.maximum(x, y) +@tf_export('keras.backend.minimum') def minimum(x, y): """Element-wise minimum of two tensors. @@ -1860,6 +1953,7 @@ def minimum(x, y): return math_ops.minimum(x, y) +@tf_export('keras.backend.sin') def sin(x): """Computes sin of x element-wise. @@ -1872,6 +1966,7 @@ def sin(x): return math_ops.sin(x) +@tf_export('keras.backend.cos') def cos(x): """Computes cos of x element-wise. @@ -1884,6 +1979,109 @@ def cos(x): return math_ops.cos(x) +def _regular_normalize_batch_in_training(x, + gamma, + beta, + reduction_axes, + epsilon=1e-3): + """Non-fused version of `normalize_batch_in_training`. + + Arguments: + x: Input tensor or variable. + gamma: Tensor by which to scale the input. + beta: Tensor with which to center the input. + reduction_axes: iterable of integers, + axes over which to normalize. + epsilon: Fuzz factor. + + Returns: + A tuple length of 3, `(normalized_tensor, mean, variance)`. + """ + mean, var = nn.moments(x, reduction_axes, None, None, False) + normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon) + return normed, mean, var + + +def _broadcast_normalize_batch_in_training(x, + gamma, + beta, + reduction_axes, + epsilon=1e-3): + """Non-fused, broadcast version of `normalize_batch_in_training`. + + Arguments: + x: Input tensor or variable. + gamma: Tensor by which to scale the input. + beta: Tensor with which to center the input. + reduction_axes: iterable of integers, + axes over which to normalize. + epsilon: Fuzz factor. + + Returns: + A tuple length of 3, `(normalized_tensor, mean, variance)`. + """ + mean, var = nn.moments(x, reduction_axes, None, None, False) + target_shape = [] + for axis in range(ndim(x)): + if axis in reduction_axes: + target_shape.append(1) + else: + target_shape.append(array_ops.shape(x)[axis]) + target_shape = array_ops.stack(target_shape) + + broadcast_mean = array_ops.reshape(mean, target_shape) + broadcast_var = array_ops.reshape(var, target_shape) + if gamma is None: + broadcast_gamma = None + else: + broadcast_gamma = array_ops.reshape(gamma, target_shape) + if beta is None: + broadcast_beta = None + else: + broadcast_beta = array_ops.reshape(beta, target_shape) + + normed = nn.batch_normalization(x, broadcast_mean, broadcast_var, + broadcast_beta, broadcast_gamma, epsilon) + return normed, mean, var + + +def _fused_normalize_batch_in_training(x, + gamma, + beta, + reduction_axes, + epsilon=1e-3): + """Fused version of `normalize_batch_in_training`. + + Arguments: + x: Input tensor or variable. + gamma: Tensor by which to scale the input. + beta: Tensor with which to center the input. + reduction_axes: iterable of integers, + axes over which to normalize. + epsilon: Fuzz factor. + + Returns: + A tuple length of 3, `(normalized_tensor, mean, variance)`. + """ + if list(reduction_axes) == [0, 1, 2]: + normalization_axis = 3 + tf_data_format = 'NHWC' + else: + normalization_axis = 1 + tf_data_format = 'NCHW' + + if gamma is None: + gamma = constant_op.constant( + 1.0, dtype=x.dtype, shape=[x.get_shape()[normalization_axis]]) + if beta is None: + beta = constant_op.constant( + 0.0, dtype=x.dtype, shape=[x.get_shape()[normalization_axis]]) + + return nn.fused_batch_norm( + x, gamma, beta, epsilon=epsilon, data_format=tf_data_format) + + +@tf_export('keras.backend.normalize_batch_in_training') def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3): """Computes mean and std for batch then apply batch_normalization on batch. @@ -1898,35 +2096,22 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3): Returns: A tuple length of 3, `(normalized_tensor, mean, variance)`. """ - mean, var = nn.moments( - x, reduction_axes, shift=None, name=None, keep_dims=False) - if sorted(reduction_axes) == list(range(ndim(x)))[:-1]: - normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon) + if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]: + if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]: + return _broadcast_normalize_batch_in_training( + x, gamma, beta, reduction_axes, epsilon=epsilon) + return _fused_normalize_batch_in_training( + x, gamma, beta, reduction_axes, epsilon=epsilon) else: - # need broadcasting - target_shape = [] - for axis in range(ndim(x)): - if axis in reduction_axes: - target_shape.append(1) - else: - target_shape.append(array_ops.shape(x)[axis]) - target_shape = array_ops.stack(target_shape) - - broadcast_mean = array_ops.reshape(mean, target_shape) - broadcast_var = array_ops.reshape(var, target_shape) - if gamma is None: - broadcast_gamma = None + if sorted(reduction_axes) == list(range(ndim(x)))[:-1]: + return _regular_normalize_batch_in_training( + x, gamma, beta, reduction_axes, epsilon=epsilon) else: - broadcast_gamma = array_ops.reshape(gamma, target_shape) - if beta is None: - broadcast_beta = None - else: - broadcast_beta = array_ops.reshape(beta, target_shape) - normed = nn.batch_normalization(x, broadcast_mean, broadcast_var, - broadcast_beta, broadcast_gamma, epsilon) - return normed, mean, var + return _broadcast_normalize_batch_in_training( + x, gamma, beta, reduction_axes, epsilon=epsilon) +@tf_export('keras.backend.batch_normalization') def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3): """Applies batch normalization on x given mean, var, beta and gamma. @@ -1950,6 +2135,7 @@ def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3): # SHAPE OPERATIONS +@tf_export('keras.backend.concatenate') def concatenate(tensors, axis=-1): """Concatenates a list of tensors alongside the specified axis. @@ -1973,6 +2159,7 @@ def concatenate(tensors, axis=-1): return array_ops.concat([to_dense(x) for x in tensors], axis) +@tf_export('keras.backend.reshape') def reshape(x, shape): """Reshapes a tensor to the specified shape. @@ -1986,6 +2173,7 @@ def reshape(x, shape): return array_ops.reshape(x, shape) +@tf_export('keras.backend.permute_dimensions') def permute_dimensions(x, pattern): """Permutes axes in a tensor. @@ -2000,6 +2188,7 @@ def permute_dimensions(x, pattern): return array_ops.transpose(x, perm=pattern) +@tf_export('keras.backend.resize_images') def resize_images(x, height_factor, width_factor, data_format): """Resizes the images contained in a 4D tensor. @@ -2041,9 +2230,10 @@ def resize_images(x, height_factor, width_factor, data_format): if original_shape[2] is not None else None, None)) return x else: - raise ValueError('Invalid data_format:', data_format) + raise ValueError('Invalid data_format: ' + str(data_format)) +@tf_export('keras.backend.resize_volumes') def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): """Resizes the volume contained in a 5D tensor. @@ -2072,9 +2262,10 @@ def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): output = repeat_elements(output, width_factor, axis=3) return output else: - raise ValueError('Invalid data_format:', data_format) + raise ValueError('Invalid data_format: ' + str(data_format)) +@tf_export('keras.backend.repeat_elements') def repeat_elements(x, rep, axis): """Repeats the elements of a tensor along an axis, like `np.repeat`. @@ -2127,6 +2318,7 @@ def repeat_elements(x, rep, axis): return x_rep +@tf_export('keras.backend.repeat') def repeat(x, n): """Repeats a 2D tensor. @@ -2146,12 +2338,13 @@ def repeat(x, n): return array_ops.tile(x, pattern) +@tf_export('keras.backend.arange') def arange(start, stop=None, step=1, dtype='int32'): """Creates a 1D tensor containing a sequence of integers. The function arguments use the same convention as Theano's arange: if only one argument is provided, - it is in fact the "stop" argument. + it is in fact the "stop" argument and "start" is 0. The default type of the returned tensor is `'int32'` to match TensorFlow's default. @@ -2166,7 +2359,7 @@ def arange(start, stop=None, step=1, dtype='int32'): An integer tensor. """ - # Match the behavior of numpy and Theano by returning an empty seqence. + # Match the behavior of numpy and Theano by returning an empty sequence. if stop is None and start < 0: start = 0 result = math_ops.range(start, limit=stop, delta=step, name='arange') @@ -2191,6 +2384,7 @@ def tile(x, n): return array_ops.tile(x, n) +@tf_export('keras.backend.flatten') def flatten(x): """Flatten a tensor. @@ -2203,6 +2397,7 @@ def flatten(x): return array_ops.reshape(x, [-1]) +@tf_export('keras.backend.batch_flatten') def batch_flatten(x): """Turn a nD tensor into a 2D tensor with same 0th dimension. @@ -2218,6 +2413,7 @@ def batch_flatten(x): return x +@tf_export('keras.backend.expand_dims') def expand_dims(x, axis=-1): """Adds a 1-sized dimension at index "axis". @@ -2231,6 +2427,7 @@ def expand_dims(x, axis=-1): return array_ops.expand_dims(x, axis) +@tf_export('keras.backend.squeeze') def squeeze(x, axis): """Removes a 1-dimension from the tensor at index "axis". @@ -2244,6 +2441,7 @@ def squeeze(x, axis): return array_ops.squeeze(x, [axis]) +@tf_export('keras.backend.temporal_padding') def temporal_padding(x, padding=(1, 1)): """Pads the middle dimension of a 3D tensor. @@ -2260,6 +2458,7 @@ def temporal_padding(x, padding=(1, 1)): return array_ops.pad(x, pattern) +@tf_export('keras.backend.spatial_2d_padding') def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): """Pads the 2nd and 3rd dimensions of a 4D tensor. @@ -2281,7 +2480,7 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) if data_format == 'channels_first': pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])] @@ -2290,6 +2489,7 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): return array_ops.pad(x, pattern) +@tf_export('keras.backend.spatial_3d_padding') def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): """Pads 5D tensor with zeros along the depth, height, width dimensions. @@ -2321,7 +2521,7 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) if data_format == 'channels_first': pattern = [[0, 0], [0, 0], [padding[0][0], padding[0][1]], @@ -2333,6 +2533,7 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): return array_ops.pad(x, pattern) +@tf_export('keras.backend.stack') def stack(x, axis=0): """Stacks a list of rank `R` tensors into a rank `R+1` tensor. @@ -2346,6 +2547,7 @@ def stack(x, axis=0): return array_ops.stack(x, axis=axis) +@tf_export('keras.backend.one_hot') def one_hot(indices, num_classes): """Computes the one-hot representation of an integer tensor. @@ -2364,6 +2566,7 @@ def one_hot(indices, num_classes): return array_ops.one_hot(indices, depth=num_classes, axis=-1) +@tf_export('keras.backend.reverse') def reverse(x, axes): """Reverse a tensor along the specified axes. @@ -2383,6 +2586,7 @@ def reverse(x, axes): # VALUE MANIPULATION +@tf_export('keras.backend.get_value') def get_value(x): """Returns the value of a variable. @@ -2392,9 +2596,12 @@ def get_value(x): Returns: A Numpy array. """ + if context.in_eager_mode(): + return x.numpy() return x.eval(session=get_session()) +@tf_export('keras.backend.batch_get_value') def batch_get_value(tensors): """Returns the value of more than one tensor variable. @@ -2404,12 +2611,15 @@ def batch_get_value(tensors): Returns: A list of Numpy arrays. """ + if context.in_eager_mode(): + return [x.numpy() for x in tensors] if tensors: return get_session().run(tensors) else: return [] +@tf_export('keras.backend.set_value') def set_value(x, value): """Sets the value of a variable, from a Numpy array. @@ -2419,18 +2629,22 @@ def set_value(x, value): (of the same shape). """ value = np.asarray(value, dtype=dtype(x)) - tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0]) - if hasattr(x, '_assign_placeholder'): - assign_placeholder = x._assign_placeholder - assign_op = x._assign_op + if context.in_eager_mode(): + x.assign(value) else: - assign_placeholder = array_ops.placeholder(tf_dtype, shape=value.shape) - assign_op = x.assign(assign_placeholder) - x._assign_placeholder = assign_placeholder - x._assign_op = assign_op - get_session().run(assign_op, feed_dict={assign_placeholder: value}) + tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0]) + if hasattr(x, '_assign_placeholder'): + assign_placeholder = x._assign_placeholder + assign_op = x._assign_op + else: + assign_placeholder = array_ops.placeholder(tf_dtype, shape=value.shape) + assign_op = x.assign(assign_placeholder) + x._assign_placeholder = assign_placeholder + x._assign_op = assign_op + get_session().run(assign_op, feed_dict={assign_placeholder: value}) +@tf_export('keras.backend.batch_set_value') def batch_set_value(tuples): """Sets the values of many tensor variables at once. @@ -2438,25 +2652,31 @@ def batch_set_value(tuples): tuples: a list of tuples `(tensor, value)`. `value` should be a Numpy array. """ - if tuples: - assign_ops = [] - feed_dict = {} + if context.in_eager_mode(): for x, value in tuples: - value = np.asarray(value, dtype=dtype(x)) - tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0]) - if hasattr(x, '_assign_placeholder'): - assign_placeholder = x._assign_placeholder - assign_op = x._assign_op - else: - assign_placeholder = array_ops.placeholder(tf_dtype, shape=value.shape) - assign_op = x.assign(assign_placeholder) - x._assign_placeholder = assign_placeholder - x._assign_op = assign_op - assign_ops.append(assign_op) - feed_dict[assign_placeholder] = value - get_session().run(assign_ops, feed_dict=feed_dict) + x.assign(np.asarray(value, dtype=dtype(x))) + else: + if tuples: + assign_ops = [] + feed_dict = {} + for x, value in tuples: + value = np.asarray(value, dtype=dtype(x)) + tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0]) + if hasattr(x, '_assign_placeholder'): + assign_placeholder = x._assign_placeholder + assign_op = x._assign_op + else: + assign_placeholder = array_ops.placeholder(tf_dtype, + shape=value.shape) + assign_op = x.assign(assign_placeholder) + x._assign_placeholder = assign_placeholder + x._assign_op = assign_op + assign_ops.append(assign_op) + feed_dict[assign_placeholder] = value + get_session().run(assign_ops, feed_dict=feed_dict) +@tf_export('keras.backend.print_tensor') def print_tensor(x, message=''): """Prints `message` and the tensor value when evaluated. @@ -2554,6 +2774,7 @@ class Function(object): return updated[:len(self.outputs)] +@tf_export('keras.backend.function') def function(inputs, outputs, updates=None, **kwargs): """Instantiates a Keras function. @@ -2573,12 +2794,13 @@ def function(inputs, outputs, updates=None, **kwargs): for key in kwargs: if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and key not in tf_inspect.getargspec(Function.__init__)[0]): - msg = ('Invalid argument "%s" passed to K.function with Tensorflow ' + msg = ('Invalid argument "%s" passed to K.function with TensorFlow ' 'backend') % key raise ValueError(msg) return Function(inputs, outputs, updates=updates, **kwargs) +@tf_export('keras.backend.gradients') def gradients(loss, variables): """Returns the gradients of `variables` w.r.t. `loss`. @@ -2593,6 +2815,7 @@ def gradients(loss, variables): loss, variables, colocate_gradients_with_ops=True) +@tf_export('keras.backend.stop_gradient') def stop_gradient(variables): """Returns `variables` but with zero gradient w.r.t. every other variable. @@ -2613,13 +2836,15 @@ def stop_gradient(variables): # CONTROL FLOW +@tf_export('keras.backend.rnn') def rnn(step_function, inputs, initial_states, go_backwards=False, mask=None, constants=None, - unroll=False): + unroll=False, + input_length=None): """Iterates over the time dimension of a tensor. Arguments: @@ -2648,6 +2873,7 @@ def rnn(step_function, constants: a list of constant values passed at each step. unroll: whether to unroll the RNN or to use a symbolic loop (`while_loop` or `scan` depending on backend). + input_length: Unused; exists for API compatibility. Returns: A tuple, `(last_output, outputs, new_states)`. @@ -2665,9 +2891,11 @@ def rnn(step_function, ValueError: if `mask` is provided (not `None`) but states is not provided (`len(states)` == 0). """ + del input_length ndim = len(inputs.get_shape()) if ndim < 3: raise ValueError('Input should be at least 3D.') + inputs_shape = inputs.get_shape() axes = [1, 0] + list(range(2, ndim)) inputs = array_ops.transpose(inputs, (axes)) @@ -2686,7 +2914,7 @@ def rnn(step_function, if unroll: if not inputs.get_shape()[0]: - raise ValueError('Unrolling requires a ' 'fixed number of timesteps.') + raise ValueError('Unrolling requires a fixed number of timesteps.') states = initial_states successive_states = [] successive_outputs = [] @@ -2852,10 +3080,18 @@ def rnn(step_function, axes = [1, 0] + list(range(2, len(outputs.get_shape()))) outputs = array_ops.transpose(outputs, axes) + + # Static shape inference: (samples, time, ...) + outputs_shape = outputs.get_shape().as_list() + outputs_shape[0] = inputs_shape[0] + outputs_shape[1] = inputs_shape[1] + outputs.set_shape(outputs_shape) + last_output._uses_learning_phase = uses_learning_phase return last_output, outputs, new_states +@tf_export('keras.backend.switch') def switch(condition, then_expression, else_expression): """Switches between two operations depending on a scalar value. @@ -2919,6 +3155,7 @@ def switch(condition, then_expression, else_expression): return x +@tf_export('keras.backend.in_train_phase') def in_train_phase(x, alt, training=None): """Selects `x` in train phase, and `alt` otherwise. @@ -2962,6 +3199,7 @@ def in_train_phase(x, alt, training=None): return x +@tf_export('keras.backend.in_test_phase') def in_test_phase(x, alt, training=None): """Selects `x` in test phase, and `alt` otherwise. @@ -2985,6 +3223,7 @@ def in_test_phase(x, alt, training=None): # NN OPERATIONS +@tf_export('keras.backend.relu') def relu(x, alpha=0., max_value=None): """Rectified linear unit. @@ -3011,12 +3250,13 @@ def relu(x, alpha=0., max_value=None): return x +@tf_export('keras.backend.elu') def elu(x, alpha=1.): """Exponential linear unit. Arguments: x: A tensor or variable to compute the activation function for. - alpha: A scalar, slope of positive section. + alpha: A scalar, slope of negative section. Returns: A tensor. @@ -3028,6 +3268,7 @@ def elu(x, alpha=1.): return array_ops.where(x > 0, res, alpha * res) +@tf_export('keras.backend.softmax') def softmax(x): """Softmax of a tensor. @@ -3040,6 +3281,7 @@ def softmax(x): return nn.softmax(x) +@tf_export('keras.backend.softplus') def softplus(x): """Softplus of a tensor. @@ -3052,6 +3294,7 @@ def softplus(x): return nn.softplus(x) +@tf_export('keras.backend.softsign') def softsign(x): """Softsign of a tensor. @@ -3064,6 +3307,7 @@ def softsign(x): return nn.softsign(x) +@tf_export('keras.backend.categorical_crossentropy') def categorical_crossentropy(target, output, from_logits=False): """Categorical crossentropy between an output tensor and a target tensor. @@ -3082,8 +3326,8 @@ def categorical_crossentropy(target, output, from_logits=False): # expects logits, Keras expects probabilities. if not from_logits: # scale preds so that the class probas of each sample sum to 1 - output /= math_ops.reduce_sum( - output, axis=len(output.get_shape()) - 1, keep_dims=True) + output = output / math_ops.reduce_sum( # pylint: disable=g-no-augmented-assignment + output, len(output.get_shape()) - 1, True) # manual computation of crossentropy epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype) output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_) @@ -3094,6 +3338,7 @@ def categorical_crossentropy(target, output, from_logits=False): return nn.softmax_cross_entropy_with_logits(labels=target, logits=output) +@tf_export('keras.backend.sparse_categorical_crossentropy') def sparse_categorical_crossentropy(target, output, from_logits=False): """Categorical crossentropy with integer targets. @@ -3127,6 +3372,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False): return res +@tf_export('keras.backend.binary_crossentropy') def binary_crossentropy(target, output, from_logits=False): """Binary crossentropy between an output tensor and a target tensor. @@ -3150,6 +3396,7 @@ def binary_crossentropy(target, output, from_logits=False): return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output) +@tf_export('keras.backend.sigmoid') def sigmoid(x): """Element-wise sigmoid. @@ -3162,6 +3409,7 @@ def sigmoid(x): return nn.sigmoid(x) +@tf_export('keras.backend.hard_sigmoid') def hard_sigmoid(x): """Segment-wise linear approximation of sigmoid. @@ -3182,6 +3430,7 @@ def hard_sigmoid(x): return x +@tf_export('keras.backend.tanh') def tanh(x): """Element-wise tanh. @@ -3194,6 +3443,7 @@ def tanh(x): return nn.tanh(x) +@tf_export('keras.backend.dropout') def dropout(x, level, noise_shape=None, seed=None): """Sets entries in `x` to zero at random, while scaling the entire tensor. @@ -3216,6 +3466,7 @@ def dropout(x, level, noise_shape=None, seed=None): return nn.dropout(x * 1., retain_prob, noise_shape, seed=seed) +@tf_export('keras.backend.l2_normalize') def l2_normalize(x, axis=None): """Normalizes a tensor wrt the L2 norm alongside the specified axis. @@ -3229,6 +3480,7 @@ def l2_normalize(x, axis=None): return nn.l2_normalize(x, dim=axis) +@tf_export('keras.backend.in_top_k') def in_top_k(predictions, targets, k): """Returns whether the `targets` are in the top `k` `predictions`. @@ -3248,6 +3500,25 @@ def in_top_k(predictions, targets, k): # CONVOLUTIONS +def _preprocess_conv1d_input(x, data_format): + """Transpose and cast the input before the conv1d. + + Arguments: + x: input tensor. + data_format: string, `"channels_last"` or `"channels_first"`. + + Returns: + A tensor. + """ + tf_data_format = 'NHWC' # to pass TF Conv2dNative operations + if data_format == 'channels_first': + if not _has_nchw_support(): + x = array_ops.transpose(x, (0, 2, 1)) # NCW -> NWC + else: + tf_data_format = 'NCHW' + return x, tf_data_format + + def _preprocess_conv2d_input(x, data_format): """Transpose and cast the input before the conv2d. @@ -3287,7 +3558,7 @@ def _preprocess_conv3d_input(x, data_format): def _preprocess_padding(padding): - """Convert keras' padding to tensorflow's padding. + """Convert keras' padding to TensorFlow's padding. Arguments: padding: string, one of 'same' , 'valid' @@ -3303,10 +3574,11 @@ def _preprocess_padding(padding): elif padding == 'valid': padding = 'VALID' else: - raise ValueError('Invalid padding:', padding) + raise ValueError('Invalid padding: ' + str(padding)) return padding +@tf_export('keras.backend.conv1d') def conv1d(x, kernel, strides=1, @@ -3333,7 +3605,7 @@ def conv1d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) kernel_shape = kernel.get_shape().as_list() if padding == 'causal': @@ -3356,6 +3628,7 @@ def conv1d(x, return x +@tf_export('keras.backend.conv2d') def conv2d(x, kernel, strides=(1, 1), @@ -3384,7 +3657,7 @@ def conv2d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv2d_input(x, data_format) padding = _preprocess_padding(padding) @@ -3400,6 +3673,7 @@ def conv2d(x, return x +@tf_export('keras.backend.conv2d_transpose') def conv2d_transpose(x, kernel, output_shape, @@ -3430,7 +3704,7 @@ def conv2d_transpose(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) if isinstance(output_shape, (tuple, list)): output_shape = array_ops.stack(output_shape) @@ -3461,6 +3735,69 @@ def conv2d_transpose(x, return x +def separable_conv1d(x, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding='valid', + data_format=None, + dilation_rate=1): + """1D convolution with separable filters. + + Arguments: + x: input tensor + depthwise_kernel: convolution kernel for the depthwise convolution. + pointwise_kernel: kernel for the 1x1 convolution. + strides: stride integer. + padding: string, `"same"` or `"valid"`. + data_format: string, `"channels_last"` or `"channels_first"`. + dilation_rate: integer dilation rate. + + Returns: + Output tensor. + + Raises: + ValueError: if `data_format` is neither `channels_last` or + `channels_first`. + """ + if data_format is None: + data_format = image_data_format() + if data_format not in {'channels_first', 'channels_last'}: + raise ValueError('Unknown data_format: ' + str(data_format)) + + x, tf_data_format = _preprocess_conv1d_input(x, data_format) + padding = _preprocess_padding(padding) + if not isinstance(strides, tuple): + strides = tuple(strides) + if tf_data_format == 'NHWC': + spatial_start_dim = 1 + strides = (1,) + strides * 2 + (1,) + else: + spatial_start_dim = 2 + strides = (1, 1) + strides * 2 + x = array_ops.expand_dims(x, spatial_start_dim) + depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0) + pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0) + dilation_rate = (1,) + dilation_rate + + x = nn.separable_conv2d( + x, + depthwise_kernel, + pointwise_kernel, + strides=strides, + padding=padding, + rate=dilation_rate, + data_format=tf_data_format) + + x = array_ops.squeeze(x, [spatial_start_dim]) + + if data_format == 'channels_first' and tf_data_format == 'NHWC': + x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW + + return x + + +@tf_export('keras.backend.separable_conv2d') def separable_conv2d(x, depthwise_kernel, pointwise_kernel, @@ -3490,10 +3827,12 @@ def separable_conv2d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv2d_input(x, data_format) padding = _preprocess_padding(padding) + if not isinstance(strides, tuple): + strides = tuple(strides) if tf_data_format == 'NHWC': strides = (1,) + strides + (1,) else: @@ -3539,7 +3878,7 @@ def depthwise_conv2d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv2d_input(x, data_format) padding = _preprocess_padding(padding) @@ -3560,6 +3899,7 @@ def depthwise_conv2d(x, return x +@tf_export('keras.backend.conv3d') def conv3d(x, kernel, strides=(1, 1, 1), @@ -3588,7 +3928,7 @@ def conv3d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv3d_input(x, data_format) padding = _preprocess_padding(padding) @@ -3634,7 +3974,7 @@ def conv3d_transpose(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) if isinstance(output_shape, (tuple, list)): output_shape = array_ops.stack(output_shape) @@ -3665,6 +4005,7 @@ def conv3d_transpose(x, return x +@tf_export('keras.backend.pool2d') def pool2d(x, pool_size, strides=(1, 1), @@ -3692,7 +4033,7 @@ def pool2d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv2d_input(x, data_format) padding = _preprocess_padding(padding) @@ -3710,13 +4051,14 @@ def pool2d(x, x = nn.avg_pool( x, pool_size, strides, padding=padding, data_format=tf_data_format) else: - raise ValueError('Invalid pooling mode:', pool_mode) + raise ValueError('Invalid pooling mode: ' + str(pool_mode)) if data_format == 'channels_first' and tf_data_format == 'NHWC': x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW return x +@tf_export('keras.backend.pool3d') def pool3d(x, pool_size, strides=(1, 1, 1), @@ -3744,7 +4086,7 @@ def pool3d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv3d_input(x, data_format) padding = _preprocess_padding(padding) @@ -3762,7 +4104,7 @@ def pool3d(x, x = nn.avg_pool3d( x, pool_size, strides, padding=padding, data_format=tf_data_format) else: - raise ValueError('Invalid pooling mode:', pool_mode) + raise ValueError('Invalid pooling mode: ' + str(pool_mode)) if data_format == 'channels_first' and tf_data_format == 'NDHWC': x = array_ops.transpose(x, (0, 4, 1, 2, 3)) @@ -3793,7 +4135,7 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) stride = strides[0] kernel_shape = int_shape(kernel) @@ -3849,7 +4191,7 @@ def local_conv2d(inputs, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) stride_row, stride_col = strides output_row, output_col = output_shape @@ -3880,6 +4222,7 @@ def local_conv2d(inputs, return output +@tf_export('keras.backend.bias_add') def bias_add(x, bias, data_format=None): """Adds a bias vector to a tensor. @@ -3901,53 +4244,59 @@ def bias_add(x, bias, data_format=None): if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) bias_shape = int_shape(bias) if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1: raise ValueError( 'Unexpected bias dimensions %d, expect to be 1 or %d dimensions' % (len(bias_shape), ndim(x))) + # pylint: disable=g-no-augmented-assignment if ndim(x) == 5: if data_format == 'channels_first': if len(bias_shape) == 1: - x += reshape(bias, (1, bias_shape[0], 1, 1, 1)) + x = x + reshape(bias, (1, bias_shape[0], 1, 1, 1)) else: - x += reshape(bias, (1, bias_shape[3]) + bias_shape[:3]) + x = x + reshape(bias, (1, bias_shape[3]) + bias_shape[:3]) elif data_format == 'channels_last': if len(bias_shape) == 1: - x += reshape(bias, (1, 1, 1, bias_shape[0])) + x = x + reshape(bias, (1, 1, 1, bias_shape[0])) else: - x += reshape(bias, (1,) + bias_shape) + x = x + reshape(bias, (1,) + bias_shape) elif ndim(x) == 4: if data_format == 'channels_first': if len(bias_shape) == 1: - x += reshape(bias, (1, bias_shape[0], 1, 1)) + if _has_nchw_support(): + x = nn.bias_add(x, bias, data_format='NCHW') + else: + x = x + reshape(bias, (1, bias_shape[0], 1, 1)) else: - x += reshape(bias, (1, bias_shape[2]) + bias_shape[:2]) + x = x + reshape(bias, (1, bias_shape[2]) + bias_shape[:2]) elif data_format == 'channels_last': if len(bias_shape) == 1: x = nn.bias_add(x, bias, data_format='NHWC') else: - x += reshape(bias, (1,) + bias_shape) + x = x + reshape(bias, (1,) + bias_shape) elif ndim(x) == 3: if data_format == 'channels_first': if len(bias_shape) == 1: - x += reshape(bias, (1, bias_shape[0], 1)) + x = x + reshape(bias, (1, bias_shape[0], 1)) else: - x += reshape(bias, (1, bias_shape[1], bias_shape[0])) + x = x + reshape(bias, (1, bias_shape[1], bias_shape[0])) elif data_format == 'channels_last': if len(bias_shape) == 1: - x += reshape(bias, (1, 1, bias_shape[0])) + x = x + reshape(bias, (1, 1, bias_shape[0])) else: - x += reshape(bias, (1,) + bias_shape) + x = x + reshape(bias, (1,) + bias_shape) else: x = nn.bias_add(x, bias) + # pylint: enable=g-no-augmented-assignment return x # RANDOMNESS +@tf_export('keras.backend.random_normal') def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): """Returns a tensor with normal distribution of values. @@ -3970,6 +4319,7 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed) +@tf_export('keras.backend.random_uniform') def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): """Returns a tensor with uniform distribution of values. @@ -3993,6 +4343,7 @@ def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed) +@tf_export('keras.backend.random_binomial') def random_binomial(shape, p=0.0, dtype=None, seed=None): """Returns a tensor with random binomial distribution of values. @@ -4014,6 +4365,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None): array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype)) +@tf_export('keras.backend.truncated_normal') def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): """Returns a tensor with truncated random normal distribution of values. @@ -4047,6 +4399,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): # in TensorFlow's CTC implementation +@tf_export('keras.backend.ctc_label_dense_to_sparse') def ctc_label_dense_to_sparse(labels, label_lengths): """Converts CTC labels from dense to sparse. @@ -4091,6 +4444,7 @@ def ctc_label_dense_to_sparse(labels, label_lengths): math_ops.to_int64(indices), vals_sparse, math_ops.to_int64(label_shape)) +@tf_export('keras.backend.ctc_batch_cost') def ctc_batch_cost(y_true, y_pred, input_length, label_length): """Runs CTC loss algorithm on each batch element. @@ -4113,13 +4467,14 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length): sparse_labels = math_ops.to_int32( ctc_label_dense_to_sparse(y_true, label_length)) - y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + 1e-8) + y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon()) return array_ops.expand_dims( ctc.ctc_loss( inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1) +@tf_export('keras.backend.ctc_decode') def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): """Decodes the output of a softmax. @@ -4148,7 +4503,7 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): Tensor `(top_paths, )` that contains the log probability of each decoded sequence. """ - y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + 1e-8) + y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon()) input_length = math_ops.to_int32(input_length) if greedy: @@ -4171,6 +4526,7 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): # HIGH ORDER FUNCTIONS +@tf_export('keras.backend.map_fn') def map_fn(fn, elems, name=None, dtype=None): """Map the function fn over the elements elems and return the outputs. @@ -4186,6 +4542,7 @@ def map_fn(fn, elems, name=None, dtype=None): return functional_ops.map_fn(fn, elems, name=name, dtype=dtype) +@tf_export('keras.backend.foldl') def foldl(fn, elems, initializer=None, name=None): """Reduce elems using fn to combine them from left to right. @@ -4202,6 +4559,7 @@ def foldl(fn, elems, initializer=None, name=None): return functional_ops.foldl(fn, elems, initializer=initializer, name=name) +@tf_export('keras.backend.foldr') def foldr(fn, elems, initializer=None, name=None): """Reduce elems using fn to combine them from right to left. diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py index e34f1b6926a8fd2c472664d330fe3cd9d714f021..f29ca49378bc43385b9e90d3f1cefb7937df64cd 100644 --- a/tensorflow/python/keras/_impl/keras/backend_test.py +++ b/tensorflow/python/keras/_impl/keras/backend_test.py @@ -915,6 +915,15 @@ class BackendNNOpsTest(test.TestCase): last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs, initial_states, **kwargs) + # check static shape inference + self.assertEquals(last_output.get_shape().as_list(), + [num_samples, output_dim]) + self.assertEquals(outputs.get_shape().as_list(), + [num_samples, timesteps, output_dim]) + for state in new_states: + self.assertEquals(state.get_shape().as_list(), + [num_samples, output_dim]) + last_output_list[i].append(keras.backend.eval(last_output)) outputs_list[i].append(keras.backend.eval(outputs)) self.assertEqual(len(new_states), 1) @@ -954,7 +963,6 @@ class BackendNNOpsTest(test.TestCase): x = keras.backend.variable(val) reduction_axes = (0, 2, 3) - # case: need broadcasting g_val = np.random.random((3,)) b_val = np.random.random((3,)) gamma = keras.backend.variable(g_val) @@ -965,17 +973,6 @@ class BackendNNOpsTest(test.TestCase): self.assertEqual(mean.get_shape().as_list(), [3,]) self.assertEqual(var.get_shape().as_list(), [3,]) - # case: doesn't need broadcasting - g_val = np.random.random((1, 3, 1, 1)) - b_val = np.random.random((1, 3, 1, 1)) - gamma = keras.backend.variable(g_val) - beta = keras.backend.variable(b_val) - normed, mean, var = keras.backend.normalize_batch_in_training( - x, gamma, beta, reduction_axes, epsilon=1e-3) - self.assertEqual(normed.get_shape().as_list(), [10, 3, 10, 10]) - self.assertEqual(mean.get_shape().as_list(), [3,]) - self.assertEqual(var.get_shape().as_list(), [3,]) - # case: gamma=None gamma = None normed, mean, var = keras.backend.normalize_batch_in_training( diff --git a/tensorflow/python/keras/_impl/keras/callbacks.py b/tensorflow/python/keras/_impl/keras/callbacks.py index 8da3b857182237a47daa0f00a2340959a448160e..f6c466142522927135d66f73f9f5c697671649ec 100644 --- a/tensorflow/python/keras/_impl/keras/callbacks.py +++ b/tensorflow/python/keras/_impl/keras/callbacks.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras callbacks: utilities called at certain points during model training. +# pylint: disable=g-import-not-at-top +"""Callbacks: utilities called at certain points during model training. """ from __future__ import absolute_import from __future__ import division @@ -34,14 +35,13 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as tf_summary +from tensorflow.python.util.tf_export import tf_export -# pylint: disable=g-import-not-at-top try: import requests except ImportError: requests = None -# pylint: enable=g-import-not-at-top class CallbackList(object): @@ -109,9 +109,9 @@ class CallbackList(object): delta_t_median = np.median(self._delta_ts_batch_begin) if (self._delta_t_batch > 0. and delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1): - logging.warning( - 'Method on_batch_begin() is slow compared ' - 'to the batch update (%f). Check your callbacks.' % delta_t_median) + logging.warning('Method on_batch_begin() is slow compared ' + 'to the batch update (%f). Check your callbacks.', + delta_t_median) self._t_enter_batch = time.time() def on_batch_end(self, batch, logs=None): @@ -132,9 +132,9 @@ class CallbackList(object): delta_t_median = np.median(self._delta_ts_batch_end) if (self._delta_t_batch > 0. and (delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)): - logging.warning( - 'Method on_batch_end() is slow compared ' - 'to the batch update (%f). Check your callbacks.' % delta_t_median) + logging.warning('Method on_batch_end() is slow compared ' + 'to the batch update (%f). Check your callbacks.', + delta_t_median) def on_train_begin(self, logs=None): """Called at the beginning of training. @@ -160,10 +160,11 @@ class CallbackList(object): return iter(self.callbacks) +@tf_export('keras.callbacks.Callback') class Callback(object): """Abstract base class used to build new callbacks. - # Properties + Attributes: params: dict. Training parameters (eg. verbosity, batch size, number of epochs...). model: instance of `keras.models.Model`. @@ -216,12 +217,23 @@ class Callback(object): pass +@tf_export('keras.callbacks.BaseLogger') class BaseLogger(Callback): """Callback that accumulates epoch averages of metrics. This callback is automatically applied to every Keras model. + + Arguments: + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over an epoch. + Metrics in this list will be logged as-is in `on_epoch_end`. + All others will be averaged in `on_epoch_end`. """ + def __init__(self, stateful_metrics=None): + super(BaseLogger, self).__init__() + self.stateful_metrics = set(stateful_metrics or []) + def on_epoch_begin(self, epoch, logs=None): self.seen = 0 self.totals = {} @@ -232,21 +244,29 @@ class BaseLogger(Callback): self.seen += batch_size for k, v in logs.items(): - if k in self.totals: - self.totals[k] += v * batch_size + if k in self.stateful_metrics: + self.totals[k] = v else: - self.totals[k] = v * batch_size + if k in self.totals: + self.totals[k] += v * batch_size + else: + self.totals[k] = v * batch_size def on_epoch_end(self, epoch, logs=None): if logs is not None: for k in self.params['metrics']: if k in self.totals: # Make value available to next callbacks. - logs[k] = self.totals[k] / self.seen + if k in self.stateful_metrics: + logs[k] = self.totals[k] + else: + logs[k] = self.totals[k] / self.seen +@tf_export('keras.callbacks.TerminateOnNaN') class TerminateOnNaN(Callback): - """Callback that terminates training when a NaN loss is encountered.""" + """Callback that terminates training when a NaN loss is encountered. + """ def __init__(self): super(TerminateOnNaN, self).__init__() @@ -260,6 +280,7 @@ class TerminateOnNaN(Callback): self.model.stop_training = True +@tf_export('keras.callbacks.ProgbarLogger') class ProgbarLogger(Callback): """Callback that prints metrics to stdout. @@ -267,12 +288,16 @@ class ProgbarLogger(Callback): count_mode: One of "steps" or "samples". Whether the progress bar should count samples seen or steps (batches) seen. + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over an epoch. + Metrics in this list will be logged as-is. + All others will be averaged over time (e.g. loss, etc). Raises: ValueError: In case of invalid `count_mode`. """ - def __init__(self, count_mode='samples'): + def __init__(self, count_mode='samples', stateful_metrics=None): super(ProgbarLogger, self).__init__() if count_mode == 'samples': self.use_steps = False @@ -280,6 +305,7 @@ class ProgbarLogger(Callback): self.use_steps = True else: raise ValueError('Unknown `count_mode`: ' + str(count_mode)) + self.stateful_metrics = set(stateful_metrics or []) def on_train_begin(self, logs=None): self.verbose = self.params['verbose'] @@ -293,7 +319,10 @@ class ProgbarLogger(Callback): else: target = self.params['samples'] self.target = target - self.progbar = Progbar(target=self.target, verbose=self.verbose) + self.progbar = Progbar( + target=self.target, + verbose=self.verbose, + stateful_metrics=self.stateful_metrics) self.seen = 0 def on_batch_begin(self, batch, logs=None): @@ -323,9 +352,10 @@ class ProgbarLogger(Callback): if k in logs: self.log_values.append((k, logs[k])) if self.verbose: - self.progbar.update(self.seen, self.log_values, force=True) + self.progbar.update(self.seen, self.log_values) +@tf_export('keras.callbacks.History') class History(Callback): """Callback that records events into a `History` object. @@ -345,6 +375,7 @@ class History(Callback): self.history.setdefault(k, []).append(v) +@tf_export('keras.callbacks.ModelCheckpoint') class ModelCheckpoint(Callback): """Save the model after every epoch. @@ -396,7 +427,7 @@ class ModelCheckpoint(Callback): if mode not in ['auto', 'min', 'max']: logging.warning('ModelCheckpoint mode %s is unknown, ' - 'fallback to auto mode.' % mode) + 'fallback to auto mode.', (mode), RuntimeWarning) mode = 'auto' if mode == 'min': @@ -423,11 +454,11 @@ class ModelCheckpoint(Callback): current = logs.get(self.monitor) if current is None: logging.warning('Can save best model only with %s available, ' - 'skipping.' % (self.monitor)) + 'skipping.', self.monitor, RuntimeWarning) else: if self.monitor_op(current, self.best): if self.verbose > 0: - print('Epoch %05d: %s improved from %0.5f to %0.5f,' + print('\nEpoch %05d: %s improved from %0.5f to %0.5f,' ' saving model to %s' % (epoch + 1, self.monitor, self.best, current, filepath)) self.best = current @@ -437,17 +468,18 @@ class ModelCheckpoint(Callback): self.model.save(filepath, overwrite=True) else: if self.verbose > 0: - print('Epoch %05d: %s did not improve' % (epoch + 1, - self.monitor)) + print('\nEpoch %05d: %s did not improve' % (epoch + 1, + self.monitor)) else: if self.verbose > 0: - print('Epoch %05d: saving model to %s' % (epoch + 1, filepath)) + print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath)) if self.save_weights_only: self.model.save_weights(filepath, overwrite=True) else: self.model.save(filepath, overwrite=True) +@tf_export('keras.callbacks.EarlyStopping') class EarlyStopping(Callback): """Stop training when a monitored quantity has stopped improving. @@ -486,7 +518,7 @@ class EarlyStopping(Callback): if mode not in ['auto', 'min', 'max']: logging.warning('EarlyStopping mode %s is unknown, ' - 'fallback to auto mode.' % mode) + 'fallback to auto mode.', mode, RuntimeWarning) mode = 'auto' if mode == 'min': @@ -514,8 +546,8 @@ class EarlyStopping(Callback): current = logs.get(self.monitor) if current is None: logging.warning('Early stopping conditioned on metric `%s` ' - 'which is not available. Available metrics are: %s' % - (self.monitor, ','.join(list(logs.keys())))) + 'which is not available. Available metrics are: %s', + self.monitor, ','.join(list(logs.keys())), RuntimeWarning) return if self.monitor_op(current - self.min_delta, self.best): self.best = current @@ -531,6 +563,7 @@ class EarlyStopping(Callback): print('Epoch %05d: early stopping' % (self.stopped_epoch + 1)) +@tf_export('keras.callbacks.RemoteMonitor') class RemoteMonitor(Callback): """Callback used to stream events to a server. @@ -544,8 +577,6 @@ class RemoteMonitor(Callback): path: String; path relative to `root` to which the events will be sent. field: String; JSON field under which the data will be stored. headers: Dictionary; optional custom HTTP headers. - Defaults to: - `{'Accept': 'application/json', 'Content-Type': 'application/json'}` """ def __init__(self, @@ -554,11 +585,7 @@ class RemoteMonitor(Callback): field='data', headers=None): super(RemoteMonitor, self).__init__() - if headers is None: - headers = { - 'Accept': 'application/json', - 'Content-Type': 'application/json' - } + self.root = root self.path = path self.field = field @@ -581,6 +608,7 @@ class RemoteMonitor(Callback): 'root server at ' + str(self.root)) +@tf_export('keras.callbacks.LearningRateScheduler') class LearningRateScheduler(Callback): """Learning rate scheduler. @@ -588,11 +616,13 @@ class LearningRateScheduler(Callback): schedule: a function that takes an epoch index as input (integer, indexed from 0) and returns a new learning rate as output (float). + verbose: int. 0: quiet, 1: update messages. """ - def __init__(self, schedule): + def __init__(self, schedule, verbose=0): super(LearningRateScheduler, self).__init__() self.schedule = schedule + self.verbose = verbose def on_epoch_begin(self, epoch, logs=None): if not hasattr(self.model.optimizer, 'lr'): @@ -602,8 +632,12 @@ class LearningRateScheduler(Callback): raise ValueError('The output of the "schedule" function ' 'should be float.') K.set_value(self.model.optimizer.lr, lr) + if self.verbose > 0: + print('\nEpoch %05d: LearningRateScheduler reducing learning ' + 'rate to %s.' % (epoch + 1, lr)) +@tf_export('keras.callbacks.TensorBoard') class TensorBoard(Callback): # pylint: disable=line-too-long """Tensorboard basic visualizations. @@ -773,6 +807,7 @@ class TensorBoard(Callback): self.writer.close() +@tf_export('keras.callbacks.ReduceLROnPlateau') class ReduceLROnPlateau(Callback): """Reduce learning rate when a metric has stopped improving. @@ -842,7 +877,7 @@ class ReduceLROnPlateau(Callback): """ if self.mode not in ['auto', 'min', 'max']: logging.warning('Learning Rate Plateau Reducing mode %s is unknown, ' - 'fallback to auto mode.' % (self.mode)) + 'fallback to auto mode.', self.mode, RuntimeWarning) self.mode = 'auto' if (self.mode == 'min' or (self.mode == 'auto' and 'acc' not in self.monitor)): @@ -853,7 +888,6 @@ class ReduceLROnPlateau(Callback): self.best = -np.Inf self.cooldown_counter = 0 self.wait = 0 - self.lr_epsilon = self.min_lr * 1e-4 def on_train_begin(self, logs=None): self._reset() @@ -864,8 +898,9 @@ class ReduceLROnPlateau(Callback): current = logs.get(self.monitor) if current is None: logging.warning('Reduce LR on plateau conditioned on metric `%s` ' - 'which is not available. Available metrics are: %s' % - (self.monitor, ','.join(list(logs.keys())))) + 'which is not available. Available metrics are: %s', + self.monitor, ','.join(list(logs.keys())), RuntimeWarning) + else: if self.in_cooldown(): self.cooldown_counter -= 1 @@ -877,13 +912,13 @@ class ReduceLROnPlateau(Callback): elif not self.in_cooldown(): if self.wait >= self.patience: old_lr = float(K.get_value(self.model.optimizer.lr)) - if old_lr > self.min_lr + self.lr_epsilon: + if old_lr > self.min_lr: new_lr = old_lr * self.factor new_lr = max(new_lr, self.min_lr) K.set_value(self.model.optimizer.lr, new_lr) if self.verbose > 0: - print('\nEpoch %05d: reducing learning rate to %s.' % (epoch, - new_lr)) + print('\nEpoch %05d: ReduceLROnPlateau reducing learning ' + 'rate to %s.' % (epoch + 1, new_lr)) self.cooldown_counter = self.cooldown self.wait = 0 self.wait += 1 @@ -892,6 +927,7 @@ class ReduceLROnPlateau(Callback): return self.cooldown_counter > 0 +@tf_export('keras.callbacks.CSVLogger') class CSVLogger(Callback): """Callback that streams epoch results to a csv file. @@ -899,10 +935,11 @@ class CSVLogger(Callback): including 1D iterables such as np.ndarray. Example: - ```python - csv_logger = CSVLogger('training.log') - model.fit(X_train, Y_train, callbacks=[csv_logger]) - ``` + + ```python + csv_logger = CSVLogger('training.log') + model.fit(X_train, Y_train, callbacks=[csv_logger]) + ``` Arguments: filename: filename of the csv file, e.g. 'run/log.csv'. @@ -942,12 +979,14 @@ class CSVLogger(Callback): else: return k + if self.keys is None: + self.keys = sorted(logs.keys()) + if self.model.stop_training: # We set NA so that csv parsers do not fail for this last epoch. logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys]) if not self.writer: - self.keys = sorted(logs.keys()) class CustomDialect(csv.excel): delimiter = self.sep @@ -969,6 +1008,7 @@ class CSVLogger(Callback): self.writer = None +@tf_export('keras.callbacks.LambdaCallback') class LambdaCallback(Callback): r"""Callback for creating simple, custom callbacks on-the-fly. @@ -993,32 +1033,32 @@ class LambdaCallback(Callback): Example: - ```python - # Print the batch number at the beginning of every batch. - batch_print_callback = LambdaCallback( - on_batch_begin=lambda batch,logs: print(batch)) - - # Stream the epoch loss to a file in JSON format. The file content - # is not well-formed JSON but rather has a JSON object per line. - import json - json_log = open('loss_log.json', mode='wt', buffering=1) - json_logging_callback = LambdaCallback( - on_epoch_end=lambda epoch, logs: json_log.write( - json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'), - on_train_end=lambda logs: json_log.close() - ) - - # Terminate some processes after having finished model training. - processes = ... - cleanup_callback = LambdaCallback( - on_train_end=lambda logs: [ - p.terminate() for p in processes if p.is_alive()]) - - model.fit(..., - callbacks=[batch_print_callback, - json_logging_callback, - cleanup_callback]) - ``` + ```python + # Print the batch number at the beginning of every batch. + batch_print_callback = LambdaCallback( + on_batch_begin=lambda batch,logs: print(batch)) + + # Stream the epoch loss to a file in JSON format. The file content + # is not well-formed JSON but rather has a JSON object per line. + import json + json_log = open('loss_log.json', mode='wt', buffering=1) + json_logging_callback = LambdaCallback( + on_epoch_end=lambda epoch, logs: json_log.write( + json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'), + on_train_end=lambda logs: json_log.close() + ) + + # Terminate some processes after having finished model training. + processes = ... + cleanup_callback = LambdaCallback( + on_train_end=lambda logs: [ + p.terminate() for p in processes if p.is_alive()]) + + model.fit(..., + callbacks=[batch_print_callback, + json_logging_callback, + cleanup_callback]) + ``` """ def __init__(self, diff --git a/tensorflow/python/keras/_impl/keras/constraints.py b/tensorflow/python/keras/_impl/keras/constraints.py index e58e3b0377b4b0fcad923095177c54d9c3ee1c0b..271fbbb63d3dfd50507837e190860d48315a14f2 100644 --- a/tensorflow/python/keras/_impl/keras/constraints.py +++ b/tensorflow/python/keras/_impl/keras/constraints.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Constraints: functions that impose constraints on weights values. +# pylint: disable=invalid-name +"""Constraints: functions that impose constraints on weight values. """ from __future__ import absolute_import from __future__ import division @@ -23,8 +24,10 @@ import six from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.constraints.Constraint') class Constraint(object): def __call__(self, w): @@ -34,6 +37,7 @@ class Constraint(object): return {} +@tf_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm') class MaxNorm(Constraint): """MaxNorm weight constraint. @@ -54,10 +58,6 @@ class MaxNorm(Constraint): to constrain the weights of each filter tensor of size `(rows, cols, input_depth)`. - References: - - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting - Srivastava, Hinton, et al. - 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) """ def __init__(self, max_value=2, axis=0): @@ -67,22 +67,22 @@ class MaxNorm(Constraint): def __call__(self, w): norms = K.sqrt(K.sum(K.square(w), axis=self.axis, keepdims=True)) desired = K.clip(norms, 0, self.max_value) - w *= (desired / (K.epsilon() + norms)) - return w + return w * (desired / (K.epsilon() + norms)) def get_config(self): return {'max_value': self.max_value, 'axis': self.axis} +@tf_export('keras.constraints.NonNeg', 'keras.constraints.non_neg') class NonNeg(Constraint): """Constrains the weights to be non-negative. """ def __call__(self, w): - w *= K.cast(w >= 0., K.floatx()) - return w + return w * K.cast(K.greater_equal(w, 0.), K.floatx()) +@tf_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm') class UnitNorm(Constraint): """Constrains the weights incident to each hidden unit to have unit norm. @@ -111,6 +111,7 @@ class UnitNorm(Constraint): return {'axis': self.axis} +@tf_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm') class MinMaxNorm(Constraint): """MinMaxNorm weight constraint. @@ -132,7 +133,7 @@ class MinMaxNorm(Constraint): has shape `(input_dim, output_dim)`, set `axis` to `0` to constrain each weight vector of length `(input_dim,)`. - In a `Conv2D` layer with `dim_ordering="channels_last"`, + In a `Conv2D` layer with `data_format="channels_last"`, the weight tensor has shape `(rows, cols, input_depth, output_depth)`, set `axis` to `[0, 1, 2]` @@ -148,10 +149,10 @@ class MinMaxNorm(Constraint): def __call__(self, w): norms = K.sqrt(K.sum(K.square(w), axis=self.axis, keepdims=True)) - desired = (self.rate * K.clip(norms, self.min_value, self.max_value) + - (1 - self.rate) * norms) - w *= (desired / (K.epsilon() + norms)) - return w + desired = ( + self.rate * K.clip(norms, self.min_value, self.max_value) + + (1 - self.rate) * norms) + return w * (desired / (K.epsilon() + norms)) def get_config(self): return { @@ -164,19 +165,23 @@ class MinMaxNorm(Constraint): # Aliases. -# pylint: disable=invalid-name max_norm = MaxNorm non_neg = NonNeg unit_norm = UnitNorm min_max_norm = MinMaxNorm -# pylint: enable=invalid-name +# Legacy aliases. +maxnorm = max_norm +nonneg = non_neg +unitnorm = unit_norm +@tf_export('keras.constraints.serialize') def serialize(constraint): return serialize_keras_object(constraint) +@tf_export('keras.constraints.deserialize') def deserialize(config, custom_objects=None): return deserialize_keras_object( config, @@ -185,6 +190,7 @@ def deserialize(config, custom_objects=None): printable_module_name='constraint') +@tf_export('keras.constraints.get') def get(identifier): if identifier is None: return None @@ -196,4 +202,5 @@ def get(identifier): elif callable(identifier): return identifier else: - raise ValueError('Could not interpret constraint identifier:', identifier) + raise ValueError('Could not interpret constraint identifier: ' + + str(identifier)) diff --git a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py index 0570e9bc0c7344641edf44cd5ef03a4f09005061..13fa9aed2b8da124af4e9f68c779e08d3094cb5d 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py +++ b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py @@ -25,25 +25,25 @@ from tensorflow.python.util.tf_export import tf_export @tf_export('keras.datasets.boston_housing.load_data') -def load_data(path='boston_housing.npz', seed=113, test_split=0.2): +def load_data(path='boston_housing.npz', test_split=0.2, seed=113): """Loads the Boston Housing dataset. Arguments: path: path where to cache the dataset locally (relative to ~/.keras/datasets). + test_split: fraction of the data to reserve as test set. seed: Random seed for shuffling the data before computing the test split. - test_split: fraction of the data to reserve as test set. Returns: Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. """ assert 0 <= test_split < 1 - fh = 'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5' path = get_file( path, origin='https://s3.amazonaws.com/keras-datasets/boston_housing.npz', - file_hash=fh) + file_hash= + 'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5') f = np.load(path) x = f['x'] y = f['y'] diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar.py b/tensorflow/python/keras/_impl/keras/datasets/cifar.py index 564709c0eed6778b9809eb8c23556cac3c4702d9..02344897f774723d0ad690ae641952cb63022bdf 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar.py +++ b/tensorflow/python/keras/_impl/keras/datasets/cifar.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities used by the CIFAR10 and CIFAR100 datasets. +"""Utilities common to CIFAR10 and CIFAR100 datasets. """ - from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -35,17 +34,16 @@ def load_batch(fpath, label_key='labels'): Returns: A tuple `(data, labels)`. """ - f = open(fpath, 'rb') - if sys.version_info < (3,): - d = cPickle.load(f) - else: - d = cPickle.load(f, encoding='bytes') - # decode utf8 - d_decoded = {} - for k, v in d.items(): - d_decoded[k.decode('utf8')] = v - d = d_decoded - f.close() + with open(fpath, 'rb') as f: + if sys.version_info < (3,): + d = cPickle.load(f) + else: + d = cPickle.load(f, encoding='bytes') + # decode utf8 + d_decoded = {} + for k, v in d.items(): + d_decoded[k.decode('utf8')] = v + d = d_decoded data = d['data'] labels = d[label_key] diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py index 1971f434b9af820af287a3848ef538f5163a2a9a..6b772433822474c06efcce1701226a4a67abe361 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py +++ b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""CIFAR10 small image classification dataset. +"""CIFAR10 small images classification dataset. """ from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py index f4039e935076a55baaf471ad544986082a4e4ad8..28d74116a50979abab207dbec88e384210dfc070 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py +++ b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""CIFAR100 small image classification dataset. +"""CIFAR100 small images classification dataset. """ from __future__ import absolute_import from __future__ import division @@ -42,7 +42,7 @@ def load_data(label_mode='fine'): ValueError: in case of invalid `label_mode`. """ if label_mode not in ['fine', 'coarse']: - raise ValueError('label_mode must be one of "fine" "coarse".') + raise ValueError('`label_mode` must be one of `"fine"`, `"coarse"`.') dirname = 'cifar-100-python' origin = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' diff --git a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py index 17be684e4f8bdb800c6b0883649da25f18fa0402..b9ae41a0d4d0e8d9df70e3fc1952e81c5f57e8d9 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py +++ b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py @@ -20,7 +20,9 @@ from __future__ import print_function import gzip import os + import numpy as np + from tensorflow.python.keras._impl.keras.utils.data_utils import get_file @@ -38,9 +40,8 @@ def load_data(): ] paths = [] - for given_file in files: - paths.append( - get_file(given_file, origin=base + given_file, cache_subdir=dirname)) + for fname in files: + paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname)) with gzip.open(paths[0], 'rb') as lbpath: y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) diff --git a/tensorflow/python/keras/_impl/keras/datasets/imdb.py b/tensorflow/python/keras/_impl/keras/datasets/imdb.py index 7946c46960ef15fdcaff6b5ad9f0bc2623a84b17..7467bb24646227705972262381aa5cf1de809f1c 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/imdb.py +++ b/tensorflow/python/keras/_impl/keras/datasets/imdb.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""IMDB movie review sentiment classification dataset. +"""IMDB sentiment classification dataset. """ from __future__ import absolute_import from __future__ import division @@ -21,9 +21,10 @@ from __future__ import print_function import json import numpy as np -from six.moves import zip # pylint: disable=redefined-builtin +from tensorflow.python.keras._impl.keras.preprocessing.sequence import _remove_long_seq from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -35,7 +36,8 @@ def load_data(path='imdb.npz', seed=113, start_char=1, oov_char=2, - index_from=3): + index_from=3, + **kwargs): """Loads the IMDB dataset. Arguments: @@ -52,6 +54,7 @@ def load_data(path='imdb.npz', oov_char: words that were cut out because of the `num_words` or `skip_top` limit will be replaced with this character. index_from: index actual words with this index and higher. + **kwargs: Used for backwards compatibility. Returns: Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. @@ -66,14 +69,21 @@ def load_data(path='imdb.npz', Words that were not seen in the training set but are in the test set have simply been skipped. """ + # Legacy support + if 'nb_words' in kwargs: + logging.warning('The `nb_words` argument in `load_data` ' + 'has been renamed `num_words`.') + num_words = kwargs.pop('nb_words') + if kwargs: + raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) + path = get_file( path, origin='https://s3.amazonaws.com/text-datasets/imdb.npz', file_hash='599dadb1135973df5b59232a0e9a887c') - f = np.load(path) - x_train, labels_train = f['x_train'], f['y_train'] - x_test, labels_test = f['x_test'], f['y_test'] - f.close() + with np.load(path) as f: + x_train, labels_train = f['x_train'], f['y_train'] + x_test, labels_test = f['x_test'], f['y_test'] np.random.seed(seed) indices = np.arange(len(x_train)) @@ -95,14 +105,7 @@ def load_data(path='imdb.npz', xs = [[w + index_from for w in x] for x in xs] if maxlen: - new_xs = [] - new_labels = [] - for x, y in zip(xs, labels): - if len(x) < maxlen: - new_xs.append(x) - new_labels.append(y) - xs = new_xs - labels = new_labels + xs, labels = _remove_long_seq(maxlen, xs, labels) if not xs: raise ValueError('After filtering for sequences shorter than maxlen=' + str(maxlen) + ', no sequence was kept. ' @@ -114,23 +117,15 @@ def load_data(path='imdb.npz', # reserve 'index_from' (=3 by default) characters: # 0 (padding), 1 (start), 2 (OOV) if oov_char is not None: - xs = [[oov_char if (w >= num_words or w < skip_top) else w for w in x] - for x in xs] + xs = [ + [w if (skip_top <= w < num_words) else oov_char for w in x] for x in xs + ] else: - new_xs = [] - for x in xs: - nx = [] - for w in x: - if skip_top <= w < num_words: - nx.append(w) - new_xs.append(nx) - xs = new_xs - - x_train = np.array(xs[:len(x_train)]) - y_train = np.array(labels[:len(x_train)]) + xs = [[w for w in x if skip_top <= w < num_words] for x in xs] - x_test = np.array(xs[len(x_train):]) - y_test = np.array(labels[len(x_train):]) + idx = len(x_train) + x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx]) + x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:]) return (x_train, y_train), (x_test, y_test) @@ -147,8 +142,7 @@ def get_word_index(path='imdb_word_index.json'): """ path = get_file( path, - origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.json') - f = open(path) - data = json.load(f) - f.close() - return data + origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.json', + file_hash='bfafd718b763782e994055a2d397834f') + with open(path) as f: + return json.load(f) diff --git a/tensorflow/python/keras/_impl/keras/datasets/mnist.py b/tensorflow/python/keras/_impl/keras/datasets/mnist.py index e9f53480150034d3e83f85cfad67f63e61422f3e..e30691373e9aafad61b101476e21d6860527ce98 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/mnist.py +++ b/tensorflow/python/keras/_impl/keras/datasets/mnist.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""MNIST handwritten digits classification dataset. +"""MNIST handwritten digits dataset. """ from __future__ import absolute_import from __future__ import division @@ -40,9 +40,7 @@ def load_data(path='mnist.npz'): origin='https://s3.amazonaws.com/img-datasets/mnist.npz', file_hash='8a61469f7ea1b51cbae51d4f78837e45') f = np.load(path) - x_train = f['x_train'] - y_train = f['y_train'] - x_test = f['x_test'] - y_test = f['y_test'] + x_train, y_train = f['x_train'], f['y_train'] + x_test, y_test = f['x_test'], f['y_test'] f.close() return (x_train, y_train), (x_test, y_test) diff --git a/tensorflow/python/keras/_impl/keras/datasets/reuters.py b/tensorflow/python/keras/_impl/keras/datasets/reuters.py index 6da5aa4b5eb8b8eb5dcd8c75c3f1f86340436601..b711696b5eecf9ba07a66cef25c1811c182b3b60 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/reuters.py +++ b/tensorflow/python/keras/_impl/keras/datasets/reuters.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Reuters newswire topic classification dataset. +"""Reuters topic classification dataset. """ - from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -22,9 +21,10 @@ from __future__ import print_function import json import numpy as np -from six.moves import zip # pylint: disable=redefined-builtin +from tensorflow.python.keras._impl.keras.preprocessing.sequence import _remove_long_seq from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -37,7 +37,8 @@ def load_data(path='reuters.npz', seed=113, start_char=1, oov_char=2, - index_from=3): + index_from=3, + **kwargs): """Loads the Reuters newswire classification dataset. Arguments: @@ -55,6 +56,7 @@ def load_data(path='reuters.npz', oov_char: words that were cut out because of the `num_words` or `skip_top` limit will be replaced with this character. index_from: index actual words with this index and higher. + **kwargs: Used for backwards compatibility. Returns: Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. @@ -65,14 +67,20 @@ def load_data(path='reuters.npz', Words that were not seen in the training set but are in the test set have simply been skipped. """ + # Legacy support + if 'nb_words' in kwargs: + logging.warning('The `nb_words` argument in `load_data` ' + 'has been renamed `num_words`.') + num_words = kwargs.pop('nb_words') + if kwargs: + raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) + path = get_file( path, origin='https://s3.amazonaws.com/text-datasets/reuters.npz', file_hash='87aedbeb0cb229e378797a632c1997b6') - npzfile = np.load(path) - xs = npzfile['x'] - labels = npzfile['y'] - npzfile.close() + with np.load(path) as f: + xs, labels = f['x'], f['y'] np.random.seed(seed) indices = np.arange(len(xs)) @@ -80,22 +88,13 @@ def load_data(path='reuters.npz', xs = xs[indices] labels = labels[indices] - np.random.shuffle(labels) - if start_char is not None: xs = [[start_char] + [w + index_from for w in x] for x in xs] elif index_from: xs = [[w + index_from for w in x] for x in xs] if maxlen: - new_xs = [] - new_labels = [] - for x, y in zip(xs, labels): - if len(x) < maxlen: - new_xs.append(x) - new_labels.append(y) - xs = new_xs - labels = new_labels + xs, labels = _remove_long_seq(maxlen, xs, labels) if not num_words: num_words = max([max(x) for x in xs]) @@ -104,23 +103,13 @@ def load_data(path='reuters.npz', # reserve 'index_from' (=3 by default) characters: # 0 (padding), 1 (start), 2 (OOV) if oov_char is not None: - xs = [[oov_char if (w >= num_words or w < skip_top) else w for w in x] - for x in xs] + xs = [[w if skip_top <= w < num_words else oov_char for w in x] for x in xs] else: - new_xs = [] - for x in xs: - nx = [] - for w in x: - if skip_top <= w < num_words: - nx.append(w) - new_xs.append(nx) - xs = new_xs - - x_train = np.array(xs[:int(len(xs) * (1 - test_split))]) - y_train = np.array(labels[:int(len(xs) * (1 - test_split))]) - - x_test = np.array(xs[int(len(xs) * (1 - test_split)):]) - y_test = np.array(labels[int(len(xs) * (1 - test_split)):]) + xs = [[w for w in x if skip_top <= w < num_words] for x in xs] + + idx = int(len(xs) * (1 - test_split)) + x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx]) + x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:]) return (x_train, y_train), (x_test, y_test) diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py index d6e0be8e432eb535a053a4c09fda35a32f6c70f3..7de5af41c5e04e046e7d6798706f630374d5640f 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology.py +++ b/tensorflow/python/keras/_impl/keras/engine/topology.py @@ -27,6 +27,7 @@ import numpy as np from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.eager import context +from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers @@ -38,6 +39,8 @@ from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.layers import network as tf_network from tensorflow.python.layers import utils as tf_layers_util from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export # pylint: disable=g-import-not-at-top @@ -59,6 +62,7 @@ TFBaseLayer = tf_base_layers.Layer # pylint: enable=invalid-name +@tf_export('keras.layers.Layer') class Layer(tf_base_layers.Layer): """Abstract base layer class. @@ -258,6 +262,10 @@ class Layer(tf_base_layers.Layer): if context.in_eager_mode(): return output + # Un-built subclassed network: build it + if isinstance(self, Network) and not self.inputs: + self._set_inputs(inputs, training=kwargs.get('training')) + # Update learning phase info. output_tensors = _to_list(output) uses_lp = any( @@ -489,6 +497,7 @@ class Layer(tf_base_layers.Layer): self._activity_regularizer = activity_regularizer +@tf_export('keras.layers.InputLayer') class InputLayer(tf_network.InputLayer, Layer): """Layer to be used as an entry point into a graph. @@ -551,6 +560,7 @@ class InputLayer(tf_network.InputLayer, Layer): return config +@tf_export('keras.layers.Input', 'keras.Input') def Input( # pylint: disable=invalid-name shape=None, batch_size=None, @@ -676,10 +686,28 @@ class Network(tf_network.GraphNetwork, Layer): from_config """ - def __init__(self, inputs, outputs, name=None): + def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called + # Signature detection + if (len(args) == 2 or + len(args) == 1 and 'outputs' in kwargs or + 'inputs' in kwargs and 'outputs' in kwargs): + # Graph network + self._init_graph_network(*args, **kwargs) + else: + # Subclassed network + self._init_subclassed_network(**kwargs) + + def _init_graph_network(self, inputs, outputs, name=None): + # TODO(fchollet): merge back tf.layers.Network and tf.keras.Network + # into a single class tf.keras.Network super(Network, self).__init__(inputs, outputs, name=name) + self._is_compiled = False + self._expects_training_arg = False + self.supports_masking = False + self.optimizer = None + # Fill in the output mask cache. masks = [] for x in self.inputs: @@ -707,13 +735,68 @@ class Network(tf_network.GraphNetwork, Layer): self.input_names.append(layer.name) if layer.is_placeholder: self._feed_input_names.append(layer.name) - self._feed_inputs.append(layer.input) self._feed_input_shapes.append(K.int_shape(self.inputs[i])) + # layer.input gives an error in eager mode + if context.in_graph_mode(): + self._feed_inputs.append(layer.input) for layer in self._output_layers: self.output_names.append(layer.name) - self.internal_input_shapes = [K.int_shape(x) for x in self.inputs] - self.internal_output_shapes = [K.int_shape(x) for x in self.outputs] + def _init_subclassed_network(self, name=None): + self._init_set_name(name) + self._layers = [] + self._is_graph_network = False + self._is_compiled = False + if 'training' in tf_inspect.getargspec(self.call).args: + self._expects_training_arg = True + else: + self._expects_training_arg = False + + self.outputs = None + self.inputs = None + self.trainable = True + self.supports_masking = False + self.built = False + self.optimizer = None + + # Not used, exists for compatibility purposes due to implementation of + # the base layer tf.layers.Layer - TODO(fchollet): clean up when refactoring + self._scope = None + self._reuse = None + self._dtype = None + self._graph = None + self._activity_regularizer = None + + # Used in symbolic mode only + self._updates = [] + self._losses = [] + + # Used in symbolic mode only, only in conjonction with graph-networks + self._outbound_nodes = [] + self._inbound_nodes = [] + + def __setattr__(self, name, value): + if isinstance(value, (tf_base_layers.Layer, Network)): + try: + is_graph_network = self._is_graph_network + except AttributeError: + raise RuntimeError('It looks like you are subclassing `Model` and you ' + 'forgot to call `super(YourClass, self).__init__()`.' + ' Always start with this line.') + if not is_graph_network: + if value not in self._layers: + self._layers.append(value) + super(Network, self).__setattr__(name, value) + + def add_variable(self, name, shape, dtype=None, initializer=None, + regularizer=None, trainable=True, constraint=None): + raise NotImplementedError('`add_variable` is not supported on Networks') + + def add_loss(self, *args, **kwargs): + if context.in_eager_mode(): + raise NotImplementedError('`add_loss` is not supported in eager-mode ' + 'on Networks') + super(Network, self).add_loss(*args, **kwargs) @property def uses_learning_phase(self): @@ -776,13 +859,16 @@ class Network(tf_network.GraphNetwork, Layer): K.batch_set_value(tuples) def compute_mask(self, inputs, mask): + if not self._is_graph_network: + return None + inputs = _to_list(inputs) if mask is None: masks = [None for _ in range(len(inputs))] else: masks = _to_list(mask) - cache_key = ','.join([str(id(x)) for x in inputs]) - cache_key += '_' + ','.join([str(id(x)) for x in masks]) + cache_key = (tf_layers_util.object_list_uid(inputs) + + '_' + tf_layers_util.object_list_uid(masks)) if cache_key in self._output_mask_cache: return self._output_mask_cache[cache_key] else: @@ -790,6 +876,9 @@ class Network(tf_network.GraphNetwork, Layer): return output_masks def get_config(self): + if not self._is_graph_network: + raise NotImplementedError + config = { 'name': self.name, } @@ -1039,6 +1128,9 @@ class Network(tf_network.GraphNetwork, Layer): model = load_model('my_model.h5') ``` """ + if not self._is_graph_network: + raise NotImplementedError + from tensorflow.python.keras._impl.keras.models import save_model # pylint: disable=g-import-not-at-top save_model(self, filepath, overwrite, include_optimizer) @@ -1070,10 +1162,8 @@ class Network(tf_network.GraphNetwork, Layer): proceed = ask_to_proceed_with_overwrite(filepath) if not proceed: return - f = h5py.File(filepath, 'w') - save_weights_to_hdf5_group(f, self.layers) - f.flush() - f.close() + with h5py.File(filepath, 'w') as f: + save_weights_to_hdf5_group(f, self.layers) def load_weights(self, filepath, by_name=False): """Loads all layer weights from a HDF5 save file. @@ -1100,16 +1190,13 @@ class Network(tf_network.GraphNetwork, Layer): """ if h5py is None: raise ImportError('`load_weights` requires h5py.') - f = h5py.File(filepath, mode='r') - if 'layer_names' not in f.attrs and 'model_weights' in f: - f = f['model_weights'] - if by_name: - load_weights_from_hdf5_group_by_name(f, self.layers) - else: - load_weights_from_hdf5_group(f, self.layers) - - if hasattr(f, 'close'): - f.close() + with h5py.File(filepath, 'r') as f: + if 'layer_names' not in f.attrs and 'model_weights' in f: + f = f['model_weights'] + if by_name: + load_weights_from_hdf5_group_by_name(f, self.layers) + else: + load_weights_from_hdf5_group(f, self.layers) def _updated_config(self): """Util hared between different serialization methods. @@ -1141,6 +1228,8 @@ class Network(tf_network.GraphNetwork, Layer): Returns: A JSON string. """ + if not self._is_graph_network: + raise NotImplementedError def get_json_type(obj): # If obj is any numpy type @@ -1176,6 +1265,9 @@ class Network(tf_network.GraphNetwork, Layer): Raises: ImportError: if yaml module is not found. """ + if not self._is_graph_network: + raise NotImplementedError + if yaml is None: raise ImportError('Requires yaml module installed.') return yaml.dump(self._updated_config(), **kwargs) @@ -1303,18 +1395,17 @@ def preprocess_weights_for_loading(layer, Returns: A list of weights values (Numpy arrays). """ - if original_keras_version == '1': - if layer.__class__.__name__ == 'Bidirectional': - num_weights_per_layer = len(weights) // 2 - - forward_weights = preprocess_weights_for_loading( - layer.forward_layer, weights[:num_weights_per_layer], - original_keras_version, original_backend) - backward_weights = preprocess_weights_for_loading( - layer.backward_layer, weights[num_weights_per_layer:], - original_keras_version, original_backend) - weights = forward_weights + backward_weights + if layer.__class__.__name__ == 'Bidirectional': + num_weights_per_layer = len(weights) // 2 + forward_weights = preprocess_weights_for_loading( + layer.forward_layer, weights[:num_weights_per_layer], + original_keras_version, original_backend) + backward_weights = preprocess_weights_for_loading( + layer.backward_layer, weights[num_weights_per_layer:], + original_keras_version, original_backend) + weights = forward_weights + backward_weights + if original_keras_version == '1': if layer.__class__.__name__ == 'TimeDistributed': weights = preprocess_weights_for_loading( layer.layer, weights, original_keras_version, original_backend) @@ -1418,7 +1509,7 @@ def preprocess_weights_for_loading(layer, conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] if layer.__class__.__name__ in conv_layers: - if original_backend and K.backend() != original_backend: + if original_backend == 'theano': weights[0] = conv_utils.convert_kernel(weights[0]) if layer.__class__.__name__ == 'ConvLSTM2D': weights[1] = conv_utils.convert_kernel(weights[1]) @@ -1427,10 +1518,9 @@ def preprocess_weights_for_loading(layer, if layer.__class__.__name__ == 'ConvLSTM2D': weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) - # convert the weights of CuDNNLSTM so that they could be loaded into LSTM + # Convert the weights of CuDNNLSTM so that they could be loaded into LSTM if layer.__class__.__name__ == 'LSTM' and len(weights) == 3: - # determine if we're loading a CuDNNLSTM layer from the number of bias - # weights: + # Determine if loading a CuDNNLSTM layer from the number of bias weights: # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) # if there's no bias weight in the file, skip this conversion units = weights[1].shape[0] @@ -1572,3 +1662,31 @@ def load_weights_from_hdf5_group_by_name(f, layers): for i in range(len(weight_values)): weight_value_tuples.append((symbolic_weights[i], weight_values[i])) K.batch_set_value(weight_value_tuples) + + +def shape_type_conversion(fn): + """Decorator that handles tuple/TensorShape conversion. + + Used in `compute_output_shape` and `build`. + + Arguments: + fn: function to wrap. + + Returns: + Wrapped function. + """ + + def wrapper(instance, input_shape): + if input_shape is not None: + if isinstance(input_shape, list): + input_shape = [ + tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape] + else: + input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) + output_shape = fn(instance, input_shape) + if output_shape is not None: + if isinstance(output_shape, list): + return [tensor_shape.TensorShape(x) for x in output_shape] + return tensor_shape.TensorShape(output_shape) + + return wrapper diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py index 479ee877fd2471a67b5b5b81e8fbf338ce755a7b..28ddc094ee585ca4011d0cdaf190cfe826a2f0ce 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py @@ -26,6 +26,8 @@ import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.keras._impl import keras from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import test try: @@ -42,22 +44,28 @@ except ImportError: class TopologyConstructionTest(test.TestCase): def test_get_updates_for(self): - a = keras.layers.Input(shape=(2,)) + a = keras.layers.Input(shape=(1,)) dense_layer = keras.layers.Dense(1) - dense_layer.add_update(0, inputs=a) - dense_layer.add_update(1, inputs=None) + dense_layer.build((None, 1)) + update_1 = state_ops.assign_add(dense_layer.kernel, a) + update_2 = state_ops.assign_add(dense_layer.kernel, [[1.]]) + dense_layer.add_update(update_1, inputs=a) + dense_layer.add_update(update_2, inputs=None) - self.assertListEqual(dense_layer.get_updates_for(a), [0]) - self.assertListEqual(dense_layer.get_updates_for(None), [1]) + self.assertListEqual(dense_layer.get_updates_for(a), [update_1]) + self.assertListEqual(dense_layer.get_updates_for(None), [update_2]) def test_get_losses_for(self): - a = keras.layers.Input(shape=(2,)) + a = keras.layers.Input(shape=(1,)) dense_layer = keras.layers.Dense(1) - dense_layer.add_loss(0, inputs=a) - dense_layer.add_loss(1, inputs=None) + dense_layer.build((None, 1)) + loss_1 = math_ops.reduce_sum(a) + loss_2 = math_ops.reduce_sum(dense_layer.kernel) + dense_layer.add_loss(loss_1, inputs=a) + dense_layer.add_loss(loss_2, inputs=None) - self.assertListEqual(dense_layer.get_losses_for(a), [0]) - self.assertListEqual(dense_layer.get_losses_for(None), [1]) + self.assertListEqual(dense_layer.get_losses_for(a), [loss_1]) + self.assertListEqual(dense_layer.get_losses_for(None), [loss_2]) def test_trainable_weights(self): a = keras.layers.Input(shape=(2,)) @@ -340,6 +348,7 @@ class TopologyConstructionTest(test.TestCase): e = keras.layers.Input(shape=(32,), name='input_e') f = keras.layers.Input(shape=(32,), name='input_f') g, h = model([e, f]) + self.assertEqual(g.name, 'model_1/dense_2/BiasAdd:0') self.assertListEqual(g.get_shape().as_list(), c.get_shape().as_list()) self.assertListEqual(h.get_shape().as_list(), d.get_shape().as_list()) @@ -546,6 +555,27 @@ class TopologyConstructionTest(test.TestCase): model = keras.models.Model(a, b) self.assertEqual(model.output_mask.get_shape().as_list(), [None, 10]) + def test_activity_regularization_with_model_composition(self): + + def reg(x): + return keras.backend.sum(x) + + net_a_input = keras.Input((2,)) + net_a = net_a_input + net_a = keras.layers.Dense(2, kernel_initializer='ones', + use_bias=False, + activity_regularizer=reg)(net_a) + model_a = keras.Model([net_a_input], [net_a]) + + net_b_input = keras.Input((2,)) + net_b = model_a(net_b_input) + model_b = keras.Model([net_b_input], [net_b]) + + model_b.compile(optimizer='sgd', loss=None) + x = np.ones((1, 2)) + loss = model_b.evaluate(x) + self.assertEqual(loss, 4.) + def test_weight_preprocessing(self): input_dim = 3 output_dim = 3 diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index debea2503ee2e440000847c0ce92185e3d230138..d8ea2fe3db500d3b52d80e46b0cff22a3d1c5915 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras training and evaluation routines. +"""Training-related part of the Keras engine. """ - from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -23,17 +22,32 @@ import copy import numpy as np +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import callbacks as cbks from tensorflow.python.keras._impl.keras import losses from tensorflow.python.keras._impl.keras import metrics as metrics_module from tensorflow.python.keras._impl.keras import optimizers +from tensorflow.python.keras._impl.keras.engine import training_eager +from tensorflow.python.keras._impl.keras.engine.topology import Layer from tensorflow.python.keras._impl.keras.engine.topology import Network from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence +from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays +from tensorflow.python.layers.base import _DeferredTensor from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import optimizer as tf_optimizer_module +from tensorflow.python.util.tf_export import tf_export + +try: + from scipy.sparse import issparse # pylint: disable=g-import-not-at-top +except ImportError: + issparse = None def _standardize_input_data(data, @@ -70,89 +84,75 @@ def _standardize_input_data(data, return [] if data is None: return [None for _ in range(len(names))] + if isinstance(data, dict): - for key, value in data.items(): - if value.__class__.__name__ == 'DataFrame': - data[key] = value.values - arrays = [] - for name in names: - if name not in data: - raise ValueError('No data provided for "' + name + - '". Need data for each key in: ' + str(names)) - arrays.append(data[name]) + try: + data = [ + data[x].values + if data[x].__class__.__name__ == 'DataFrame' else data[x] + for x in names + ] + except KeyError as e: + raise ValueError('No data provided for "' + e.args[0] + '". Need data ' + 'for each key in: ' + str(names)) elif isinstance(data, list): - for key, value in enumerate(data): - if value.__class__.__name__ == 'DataFrame': - data[key] = value.values - if len(data) != len(names): - if data and hasattr(data[0], 'shape'): - raise ValueError( - 'Error when checking model ' + exception_prefix + - ': the list of Numpy arrays ' - 'that you are passing to your model ' - 'is not the size the model expected. ' - 'Expected to see ' + str(len(names)) + ' array(s), but instead got ' - 'the following list of ' + str(len(data)) + ' arrays: ' + - str(data)[:200] + '...') - else: - if len(names) == 1: - data = [np.asarray(data)] - else: - raise ValueError('Error when checking model ' + exception_prefix + - ': you are passing a list as ' - 'input to your model, ' - 'but the model expects ' - 'a list of ' + str(len(names)) + - ' Numpy arrays instead. ' - 'The list you passed was: ' + str(data)[:200]) - arrays = data - elif data.__class__.__name__ == 'DataFrame': - # test if data is a DataFrame, without pandas installed - arrays = data.values + if isinstance(data[0], list): + data = [np.asarray(d) for d in data] + elif len(names) == 1 and isinstance(data[0], (float, int)): + data = [np.asarray(data)] + else: + data = [ + x.values if x.__class__.__name__ == 'DataFrame' else x for x in data + ] else: - if not hasattr(data, 'shape'): + data = data.values if data.__class__.__name__ == 'DataFrame' else data + data = [data] + data = [ + np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data + ] + + if len(data) != len(names): + if data and hasattr(data[0], 'shape'): + raise ValueError('Error when checking model ' + exception_prefix + + ': the list of Numpy arrays that you are passing to ' + 'your model is not the size the model expected. ' + 'Expected to see ' + str(len(names)) + ' array(s), ' + 'but instead got the following list of ' + + str(len(data)) + ' arrays: ' + str(data)[:200] + '...') + elif len(names) > 1: + raise ValueError( + 'Error when checking model ' + exception_prefix + + ': you are passing a list as input to your model, ' + 'but the model expects a list of ' + str(len(names)) + + ' Numpy arrays instead. The list you passed was: ' + str(data)[:200]) + elif len(data) == 1 and not hasattr(data[0], 'shape'): raise TypeError('Error when checking model ' + exception_prefix + - ': data should be a Numpy array, ' - 'or list/dict of Numpy arrays. ' - 'Found: ' + str(data)[:200] + '...') - if len(names) > 1: - # Case: model expects multiple inputs but only received - # a single Numpy array. - raise ValueError('The model expects ' + str(len(names)) + ' ' + - exception_prefix + - ' arrays, but only received one array. ' - 'Found: array with shape ' + str(data.shape)) - arrays = [data] - - # Make arrays at least 2D. - for i in range(len(names)): - array = arrays[i] - if len(array.shape) == 1: - array = np.expand_dims(array, 1) - arrays[i] = array + ': data should be a Numpy array, or list/dict of ' + 'Numpy arrays. Found: ' + str(data)[:200] + '...') + elif len(names) == 1: + data = [np.asarray(data)] # Check shapes compatibility. if shapes: for i in range(len(names)): - if shapes[i] is None: - continue - array = arrays[i] - if len(array.shape) != len(shapes[i]): - raise ValueError( - 'Error when checking ' + exception_prefix + ': expected ' + names[i] - + ' to have ' + str(len(shapes[i])) + - ' dimensions, but got array with shape ' + str(array.shape)) - for j, (dim, ref_dim) in enumerate(zip(array.shape, shapes[i])): - if not j and not check_batch_axis: - # skip the first axis - continue - if ref_dim: - if ref_dim != dim: - raise ValueError('Error when checking ' + exception_prefix + - ': expected ' + names[i] + ' to have shape ' + - str(shapes[i]) + ' but got array with shape ' + - str(array.shape)) - return arrays + if shapes[i] is not None: + data_shape = data[i].shape + shape = shapes[i] + if data[i].ndim != len(shape): + raise ValueError('Error when checking ' + exception_prefix + + ': expected ' + names[i] + ' to have ' + + str(len(shape)) + ' dimensions, but got array ' + 'with shape ' + str(data_shape)) + if not check_batch_axis: + data_shape = data_shape[1:] + shape = shape[1:] + for dim, ref_dim in zip(data_shape, shape): + if ref_dim != dim and ref_dim: + raise ValueError( + 'Error when checking ' + exception_prefix + ': expected ' + + names[i] + ' to have shape ' + str(shape) + + ' but got array with shape ' + str(data_shape)) + return data def _standardize_sample_or_class_weights(x_weight, output_names, weight_type): @@ -193,10 +193,10 @@ def _standardize_sample_or_class_weights(x_weight, output_names, weight_type): x_weights.append(x_weight.get(name)) return x_weights else: - raise TypeError('The model has multiple outputs, so `' + weight_type + '` ' - 'should be either a list or a dict. ' - 'Provided `' + weight_type + '` type not understood: ' + - str(x_weight)) + raise TypeError( + 'The model has multiple outputs, so `' + weight_type + '` ' + 'should be either a list or a dict. ' + 'Provided `' + weight_type + '` type not understood: ' + str(x_weight)) def _standardize_class_weights(class_weight, output_names): @@ -234,12 +234,12 @@ def _check_array_lengths(inputs, targets, weights=None): set_w = set_of_lengths(weights) if len(set_x) > 1: raise ValueError('All input arrays (x) should have ' - 'the same number of samples. Got array shapes: ' + str( - [x.shape for x in inputs])) + 'the same number of samples. Got array shapes: ' + + str([x.shape for x in inputs])) if len(set_y) > 1: raise ValueError('All target arrays (y) should have ' - 'the same number of samples. Got array shapes: ' + str( - [y.shape for y in targets])) + 'the same number of samples. Got array shapes: ' + + str([y.shape for y in targets])) if set_x and set_y and list(set_x)[0] != list(set_y)[0]: raise ValueError('Input arrays should have ' 'the same number of samples as target arrays. ' @@ -247,8 +247,8 @@ def _check_array_lengths(inputs, targets, weights=None): 'and ' + str(list(set_y)[0]) + ' target samples.') if len(set_w) > 1: raise ValueError('All sample_weight arrays should have ' - 'the same number of samples. Got array shapes: ' + str( - [w.shape for w in weights])) + 'the same number of samples. Got array shapes: ' + + str([w.shape for w in weights])) if set_y and set_w and list(set_y)[0] != list(set_w)[0]: raise ValueError('Sample_weight arrays should have ' 'the same number of samples as target arrays. Got ' + @@ -275,7 +275,7 @@ def _check_loss_and_target_compatibility(targets, loss_fns, output_shapes): losses.categorical_crossentropy } for y, loss, shape in zip(targets, loss_fns, output_shapes): - if loss is None: + if y is None or loss is None: continue if loss is losses.categorical_crossentropy: if y.shape[-1] == 1: @@ -365,62 +365,6 @@ def _batch_shuffle(index_array, batch_size): return np.append(index_array, last_batch) -def _make_batches(size, batch_size): - """Returns a list of batch indices (tuples of indices). - - Arguments: - size: Integer, total size of the data to slice into batches. - batch_size: Integer, batch size. - - Returns: - A list of tuples of array indices. - """ - num_batches = (size + batch_size - 1) // batch_size # round up - return [(i * batch_size, min(size, (i + 1) * batch_size)) - for i in range(num_batches)] - - -def _slice_arrays(arrays, start=None, stop=None): - """Slice an array or list of arrays. - - This takes an array-like, or a list of - array-likes, and outputs: - - arrays[start:stop] if `arrays` is an array-like - - [x[start:stop] for x in arrays] if `arrays` is a list - - Can also work on list/array of indices: `_slice_arrays(x, indices)` - - Arguments: - arrays: Single array or list of arrays. - start: can be an integer index (start index) - or a list/array of indices - stop: integer (stop index); should be None if - `start` was a list. - - Returns: - A slice of the array(s). - """ - if arrays is None: - return [None] - elif isinstance(arrays, list): - if hasattr(start, '__len__'): - # hdf5 datasets only support list objects as indices - if hasattr(start, 'shape'): - start = start.tolist() - return [None if x is None else x[start] for x in arrays] - else: - return [None if x is None else x[start:stop] for x in arrays] - else: - if hasattr(start, '__len__'): - if hasattr(start, 'shape'): - start = start.tolist() - return arrays[start] - elif hasattr(start, '__getitem__'): - return arrays[start:stop] - else: - return [None] - - def _weighted_masked_objective(fn): """Adds support for masking and sample-weighting to an objective function. @@ -528,23 +472,23 @@ def _standardize_weights(y, if sample_weight is not None: if len(sample_weight.shape) > len(y.shape): - raise ValueError('Found a sample_weight with shape' + - str(sample_weight.shape) + '.' - 'Expected sample_weight with rank ' - 'less than or equal to ' + str(len(y.shape))) + raise ValueError( + 'Found a sample_weight with shape' + str(sample_weight.shape) + '.' + 'Expected sample_weight with rank ' + 'less than or equal to ' + str(len(y.shape))) if y.shape[:sample_weight.ndim] != sample_weight.shape: - raise ValueError('Found a sample_weight array with shape ' + - str(sample_weight.shape) + ' for an input with shape ' + - str(y.shape) + '. ' - 'sample_weight cannot be broadcast.') + raise ValueError( + 'Found a sample_weight array with shape ' + str(sample_weight.shape) + + ' for an input with shape ' + str(y.shape) + '. ' + 'sample_weight cannot be broadcast.') return sample_weight elif isinstance(class_weight, dict): if len(y.shape) > 2: raise ValueError('`class_weight` not supported for ' '3+ dimensional targets.') if y.shape[1] > 1: - y_classes = y.argmax(axis=1) + y_classes = np.argmax(y, axis=1) elif y.shape[1] == 1: y_classes = np.reshape(y, y.shape[0]) else: @@ -569,13 +513,72 @@ def _standardize_weights(y, return np.ones((y.shape[0], y.shape[1]), dtype=K.floatx()) +@tf_export('keras.models.Model', 'keras.Model') class Model(Network): - """The `Model` class adds training & evaluation routines to a `Network`. + """`Model` groups layers into an object with training and inference features. + + There are two ways to instantiate a `Model`: + + 1 - With the "functional API", where you start from `Input`, + you chain layer calls to specify the model's forward pass, + and finally you create your model from inputs and outputs: + + ```python + import tensorflow as tf + + inputs = tf.keras.Input(shape=(3,)) + x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) + outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) + model = tf.keras.Model(inputs=inputs, outputs=outputs) + ``` + + 2 - By subclassing the `Model` class: in that case, you should define your + layers in `__init__` and you should implement the model's forward pass + in `call`. + + ```python + import tensorflow as tf + + class MyModel(tf.keras.Model): + + def __init__(self): + self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) + self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) + + def call(self, inputs): + x = self.dense1(inputs) + return self.dense2(x) + + model = MyModel() + ``` + + If you subclass `Model`, you can optionally have + a `training` argument (boolean) in `call`, which you can use to specify + a different behavior in training and inference: + + ```python + import tensorflow as tf + + class MyModel(tf.keras.Model): + + def __init__(self): + self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) + self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) + self.dropout = tf.keras.layers.Dropout(0.5) + + def call(self, inputs, training=False): + x = self.dense1(inputs) + if training: + x = self.dropout(x, training=training) + return self.dense2(x) + + model = MyModel() + ``` """ def compile(self, optimizer, - loss, + loss=None, metrics=None, loss_weights=None, sample_weight_mode=None, @@ -631,20 +634,39 @@ class Model(Network): `optimizer`, `loss`, `metrics` or `sample_weight_mode`. """ loss = loss or {} + if context.in_eager_mode() and not isinstance( + optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): + raise ValueError('Only TF native optimizers are supported in Eager mode.') + self.optimizer = optimizers.get(optimizer) - self.sample_weight_mode = sample_weight_mode self.loss = loss + self.metrics = metrics or [] self.loss_weights = loss_weights + if context.in_eager_mode() and sample_weight_mode is not None: + raise ValueError('sample_weight_mode is not supported in Eager mode.') self.sample_weight_mode = sample_weight_mode + if context.in_eager_mode() and weighted_metrics is not None: + raise ValueError('weighted_metrics is not supported in Eager mode.') + self.weighted_metrics = weighted_metrics + if context.in_eager_mode() and target_tensors is not None: + raise ValueError('target_tensors is not supported in Eager mode.') + self.target_tensors = target_tensors + + if not self.built: + # Model is not compilable because it does not know its number of inputs + # and outputs, nor their shapes and names. We will compile after the first + # time the model gets called on training data. + return + self._is_compiled = True # Prepare loss functions. if isinstance(loss, dict): for name in loss: if name not in self.output_names: - raise ValueError('Unknown entry in loss ' - 'dictionary: "' + name + '". ' - 'Only expected the following keys: ' + - str(self.output_names)) + raise ValueError( + 'Unknown entry in loss ' + 'dictionary: "' + name + '". ' + 'Only expected the following keys: ' + str(self.output_names)) loss_functions = [] for name in self.output_names: if name not in loss: @@ -657,7 +679,7 @@ class Model(Network): elif isinstance(loss, list): if len(loss) != len(self.outputs): raise ValueError('When passing a list as loss, ' - 'it should have one entry per model output. ' + 'it should have one entry per model outputs. ' 'The model has ' + str(len(self.outputs)) + ' outputs, but you passed loss=' + str(loss)) loss_functions = [losses.get(l) for l in loss] @@ -665,6 +687,7 @@ class Model(Network): loss_function = losses.get(loss) loss_functions = [loss_function for _ in range(len(self.outputs))] self.loss_functions = loss_functions + weighted_losses = [_weighted_masked_objective(fn) for fn in loss_functions] skip_target_indices = [] skip_target_weighing_indices = [] @@ -678,11 +701,12 @@ class Model(Network): skip_target_weighing_indices.append(i) # Prepare output masks. - masks = self.compute_mask(self.inputs, mask=None) - if masks is None: - masks = [None for _ in self.outputs] - if not isinstance(masks, list): - masks = [masks] + if context.in_graph_mode(): + masks = self.compute_mask(self.inputs, mask=None) + if masks is None: + masks = [None for _ in self.outputs] + if not isinstance(masks, list): + masks = [masks] # Prepare loss weights. if loss_weights is None: @@ -690,57 +714,81 @@ class Model(Network): elif isinstance(loss_weights, dict): for name in loss_weights: if name not in self.output_names: - raise ValueError('Unknown entry in loss_weights ' - 'dictionary: "' + name + '". ' - 'Only expected the following keys: ' + - str(self.output_names)) + raise ValueError( + 'Unknown entry in loss_weights ' + 'dictionary: "' + name + '". ' + 'Only expected the following keys: ' + str(self.output_names)) loss_weights_list = [] for name in self.output_names: loss_weights_list.append(loss_weights.get(name, 1.)) elif isinstance(loss_weights, list): if len(loss_weights) != len(self.outputs): - raise ValueError('When passing a list as loss_weights, ' - 'it should have one entry per model output. ' - 'The model has ' + str(len(self.outputs)) + - ' outputs, but you passed loss_weights=' + - str(loss_weights)) + raise ValueError( + 'When passing a list as loss_weights, ' + 'it should have one entry per model output. ' + 'The model has ' + str(len(self.outputs)) + + ' outputs, but you passed loss_weights=' + str(loss_weights)) loss_weights_list = loss_weights else: raise TypeError('Could not interpret loss_weights argument: ' + str(loss_weights) + ' - expected a list of dicts.') + self.loss_weights_list = loss_weights_list + + # initialization for Eager mode execution + if context.in_eager_mode(): + if target_tensors is not None: + raise ValueError('target_tensors are not currently supported in Eager' + 'mode.') + self.total_loss = None + self.metrics_tensors = [] + self.metrics_names = ['loss'] + for i in range(len(self.outputs)): + if len(self.outputs) > 1: + self.metrics_names.append(self.output_names[i] + '_loss') + self.nested_metrics = _collect_metrics(metrics, self.output_names) + self._feed_sample_weight_modes = [] + for i in range(len(self.outputs)): + self._feed_sample_weight_modes.append(None) + self.sample_weights = [] + self.targets = [] + for i in range(len(self.outputs)): + self._feed_output_names.append(self.output_names[i]) + self._collected_trainable_weights = self.trainable_weights + return # Prepare targets of model. self.targets = [] self._feed_targets = [] - if target_tensors is not None: + if target_tensors not in (None, []): if isinstance(target_tensors, list): if len(target_tensors) != len(self.outputs): - raise ValueError('When passing a list as `target_tensors`, ' - 'it should have one entry per model output. ' - 'The model has ' + str(len(self.outputs)) + - ' outputs, but you passed target_tensors=' + - str(target_tensors)) + raise ValueError( + 'When passing a list as `target_tensors`, ' + 'it should have one entry per model output. ' + 'The model has ' + str(len(self.outputs)) + + ' outputs, but you passed target_tensors=' + str(target_tensors)) elif isinstance(target_tensors, dict): for name in target_tensors: if name not in self.output_names: - raise ValueError('Unknown entry in `target_tensors` ' - 'dictionary: "' + name + '". ' - 'Only expected the following keys: ' + - str(self.output_names)) - target_tensors_ = [] + raise ValueError( + 'Unknown entry in `target_tensors` ' + 'dictionary: "' + name + '". ' + 'Only expected the following keys: ' + str(self.output_names)) + tmp_target_tensors = [] for name in self.output_names: - target_tensors_.append(target_tensors.get(name, None)) - target_tensors = target_tensors_ + tmp_target_tensors.append(target_tensors.get(name, None)) + target_tensors = tmp_target_tensors else: raise TypeError('Expected `target_tensors` to be ' 'a list or dict, but got:', target_tensors) + for i in range(len(self.outputs)): if i in skip_target_indices: self.targets.append(None) else: - shape = self.internal_output_shapes[i] + shape = K.int_shape(self.outputs[i]) name = self.output_names[i] - if target_tensors is not None: + if target_tensors not in (None, []): target = target_tensors[i] else: target = None @@ -766,24 +814,24 @@ class Model(Network): if isinstance(sample_weight_mode, dict): for name in sample_weight_mode: if name not in self.output_names: - raise ValueError('Unknown entry in ' - 'sample_weight_mode dictionary: "' + name + '". ' - 'Only expected the following keys: ' + - str(self.output_names)) + raise ValueError( + 'Unknown entry in ' + 'sample_weight_mode dictionary: "' + name + '". ' + 'Only expected the following keys: ' + str(self.output_names)) for i, name in enumerate(self.output_names): if i in skip_target_weighing_indices: weight = None sample_weight_modes.append(None) else: if name not in sample_weight_mode: - raise ValueError('Output "' + name + - '" missing from sample_weight_modes ' - 'dictionary') + raise ValueError( + 'Output "' + name + '" missing from sample_weight_modes ' + 'dictionary') if sample_weight_mode.get(name) == 'temporal': weight = K.placeholder(ndim=2, name=name + '_sample_weights') sample_weight_modes.append('temporal') else: - weight = K.placeholder(ndim=1, name=name + '_sample_weights') + weight = K.placeholder(ndim=1, name=name + 'sample_weights') sample_weight_modes.append(None) sample_weights.append(weight) elif isinstance(sample_weight_mode, list): @@ -828,7 +876,6 @@ class Model(Network): self._feed_sample_weight_modes.append(self.sample_weight_modes[i]) # Prepare metrics. - self.metrics = metrics self.weighted_metrics = weighted_metrics self.metrics_names = ['loss'] self.metrics_tensors = [] @@ -871,14 +918,8 @@ class Model(Network): nested_metrics = _collect_metrics(metrics, self.output_names) nested_weighted_metrics = _collect_metrics(weighted_metrics, self.output_names) - - def append_metric(layer_index, metric_name, metric_tensor): - """Helper function used in loop below.""" - if len(self.output_names) > 1: - metric_name = self.output_names[layer_index] + '_' + metric_name - self.metrics_names.append(metric_name) - self.metrics_tensors.append(metric_tensor) - + self.metrics_updates = [] + self.stateful_metric_names = [] with K.name_scope('metrics'): for i in range(len(self.outputs)): if i in skip_target_indices: @@ -894,32 +935,68 @@ class Model(Network): metric_name_prefix = 'weighted_' if weights is not None else '' for metric in metrics: - if metric == 'accuracy' or metric == 'acc': - # custom handling of accuracy + if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): + # custom handling of accuracy/crossentropy # (because of class mode duality) - output_shape = self.internal_output_shapes[i] + output_shape = self.outputs[i].get_shape().as_list() if (output_shape[-1] == 1 or self.loss_functions[i] == losses.binary_crossentropy): - # case: binary accuracy - acc_fn = metrics_module.binary_accuracy + # case: binary accuracy/crossentropy + if metric in ('accuracy', 'acc'): + metric_fn = metrics_module.binary_accuracy + elif metric in ('crossentropy', 'ce'): + metric_fn = metrics_module.binary_crossentropy elif self.loss_functions[ i] == losses.sparse_categorical_crossentropy: - # case: categorical accuracy with sparse targets - acc_fn = metrics_module.sparse_categorical_accuracy + # case: categorical accuracy/crossentropy with sparse targets + if metric in ('accuracy', 'acc'): + metric_fn = metrics_module.sparse_categorical_accuracy + elif metric in ('crossentropy', 'ce'): + metric_fn = metrics_module.sparse_categorical_crossentropy else: - acc_fn = metrics_module.categorical_accuracy - - weighted_metric_fn = _weighted_masked_objective(acc_fn) - metric_name = metric_name_prefix + 'acc' + # case: categorical accuracy/crossentropy + if metric in ('accuracy', 'acc'): + metric_fn = metrics_module.categorical_accuracy + elif metric in ('crossentropy', 'ce'): + metric_fn = metrics_module.categorical_crossentropy + if metric in ('accuracy', 'acc'): + suffix = 'acc' + elif metric in ('crossentropy', 'ce'): + suffix = 'ce' + weighted_metric_fn = _weighted_masked_objective(metric_fn) + metric_name = metric_name_prefix + suffix else: metric_fn = metrics_module.get(metric) weighted_metric_fn = _weighted_masked_objective(metric_fn) - metric_name = metric_name_prefix + metric_fn.__name__ + # Get metric name as string + if hasattr(metric_fn, 'name'): + metric_name = metric_fn.name + else: + metric_name = metric_fn.__name__ + metric_name = metric_name_prefix + metric_name with K.name_scope(metric_name): metric_result = weighted_metric_fn( y_true, y_pred, weights=weights, mask=masks[i]) - append_metric(i, metric_name, metric_result) + + # Append to self.metrics_names, self.metric_tensors, + # self.stateful_metric_names + if len(self.output_names) > 1: + metric_name = '%s_%s' % (self.output_names[i], metric_name) + # Dedupe name + j = 1 + base_metric_name = metric_name + while metric_name in self.metrics_names: + metric_name = '%s_%d' % (base_metric_name, j) + j += 1 + self.metrics_names.append(metric_name) + self.metrics_tensors.append(metric_result) + + # Keep track of state updates created by + # stateful metrics (i.e. metrics layers). + if isinstance(metric_fn, Layer): + self.stateful_metric_names.append(metric_name) + self.metrics_updates += metric_fn.updates handle_metrics(output_metrics) handle_metrics(output_weighted_metrics, weights=weights) @@ -930,7 +1007,7 @@ class Model(Network): self._feed_sample_weights = [] for i in range(len(self.sample_weights)): if i not in skip_target_weighing_indices: - self._feed_sample_weights.append(sample_weights[i]) + self._feed_sample_weights.append(self.sample_weights[i]) # Functions for train, test and predict will # be compiled lazily when required. @@ -949,7 +1026,7 @@ class Model(Network): """Check trainable weights count consistency. This will raise a warning if `trainable_weights` and - `_collected_trainable_weights` are consistent (i.e. have the same + `_collected_trainable_weights` are inconsistent (i.e. have different number of parameters). Inconsistency will typically arise when one modifies `model.trainable` without calling `model.compile` again. @@ -959,9 +1036,10 @@ class Model(Network): if len(self.trainable_weights) != len(self._collected_trainable_weights): logging.warning( - 'Discrepancy between trainable weights and collected trainable' - ' weights, did you set `model.trainable` without calling' - ' `model.compile` after ?') + UserWarning( + 'Discrepancy between trainable weights and collected trainable' + ' weights, did you set `model.trainable` without calling' + ' `model.compile` after ?')) def _make_train_function(self): if not hasattr(self, 'train_function'): @@ -976,9 +1054,15 @@ class Model(Network): with K.name_scope('training'): with K.name_scope(self.optimizer.__class__.__name__): - training_updates = self.optimizer.get_updates( + # Training updates + updates = self.optimizer.get_updates( params=self._collected_trainable_weights, loss=self.total_loss) - updates = self.updates + training_updates + # Unconditional updates + updates += self.get_updates_for(None) + # Conditional updates relevant to this model + updates += self.get_updates_for(self._feed_inputs) + # Stateful metrics updates + updates += self.metrics_updates # Gets loss and metrics. Updates weights at each call. self.train_function = K.function( inputs, [self.total_loss] + self.metrics_tensors, @@ -999,7 +1083,7 @@ class Model(Network): # Does update the network states. self.test_function = K.function( inputs, [self.total_loss] + self.metrics_tensors, - updates=self.state_updates, + updates=self.state_updates + self.metrics_updates, name='test_function', **self._function_kwargs) @@ -1050,18 +1134,21 @@ class Model(Network): processed based on the size of the first dimension of the first input numpy array. When steps is not `None` and `batch_size` is `None`, returns `None`. + + Raises: + ValueError: In case of invalid arguments. """ if steps is not None: num_samples = None if batch_size is not None: - raise ValueError('If ' + steps_name + - ' is set, the `batch_size` must be None.') + raise ValueError( + 'If ' + steps_name + ' is set, the `batch_size` must be None.') elif ins and hasattr(ins[0], 'shape'): num_samples = ins[0].shape[0] else: - raise ValueError('Either the input data should have ' - 'a defined shape, or ' + steps_name + - ' should be specified.') + raise ValueError( + 'Either the input data should have ' + 'a defined shape, or ' + steps_name + ' should be specified.') return num_samples def _fit_loop(self, @@ -1104,43 +1191,49 @@ class Model(Network): steps_per_epoch: Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. Ignored with the default value of `None`. - validation_steps: Number of steps to run validation for (only if doing - validation from data tensors). Ignored with default value of `None`. + validation_steps: Number of steps to run validation for + (only if doing validation from data tensors). + Ignored with the default value of `None`. Returns: `History` object. Raises: - ValueError: In case of invalid argument values. + ValueError: in case of invalid arguments. """ do_validation = False if val_f and val_ins: do_validation = True - if (verbose and ins and - hasattr(ins[0], 'shape') and hasattr(val_ins[0], 'shape')): + if verbose and ins and hasattr(ins[0], 'shape') and hasattr( + val_ins[0], 'shape'): print('Train on %d samples, validate on %d samples' % (ins[0].shape[0], val_ins[0].shape[0])) if validation_steps: - if steps_per_epoch is None: - raise ValueError('Can only use `validation_steps` when doing step-wise ' - 'training, i.e. `steps_per_epoch` must be set.') do_validation = True + if steps_per_epoch is None: + raise ValueError('Can only use `validation_steps` ' + 'when doing step-wise ' + 'training, i.e. `steps_per_epoch` ' + 'must be set.') num_train_samples = self._check_num_samples( ins, batch_size, steps_per_epoch, 'steps_per_epoch') - if num_train_samples is not None: index_array = np.arange(num_train_samples) self.history = cbks.History() - callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history] + all_callbacks = [cbks.BaseLogger( + stateful_metrics=self.stateful_metric_names)] if verbose: if steps_per_epoch is not None: count_mode = 'steps' else: count_mode = 'samples' - callbacks += [cbks.ProgbarLogger(count_mode)] - callbacks = cbks.CallbackList(callbacks) + all_callbacks.append( + cbks.ProgbarLogger( + count_mode, stateful_metrics=self.stateful_metric_names)) + all_callbacks += (callbacks or []) + [self.history] + callbacks = cbks.CallbackList(all_callbacks) out_labels = out_labels or [] # it's possible to callback a different model than self @@ -1151,6 +1244,7 @@ class Model(Network): callback_model = self callbacks.set_model(callback_model) + callbacks.set_params({ 'batch_size': batch_size, 'epochs': epochs, @@ -1165,7 +1259,19 @@ class Model(Network): for cbk in callbacks: cbk.validation_data = val_ins + # To prevent a slowdown, we find beforehand the arrays that need conversion. + feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights + indices_for_conversion_to_dense = [] + for i in range(len(feed)): + if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]): + indices_for_conversion_to_dense.append(i) + for epoch in range(initial_epoch, epochs): + # Reset stateful metrics + for m in self.metrics: + if isinstance(m, Layer): + m.reset_states() + # Update callbacks callbacks.on_epoch_begin(epoch) epoch_logs = {} if steps_per_epoch is not None: @@ -1203,15 +1309,16 @@ class Model(Network): elif shuffle: np.random.shuffle(index_array) - batches = _make_batches(num_train_samples, batch_size) + batches = make_batches(num_train_samples, batch_size) + for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] try: if isinstance(ins[-1], float): # Do not slice the training phase flag. - ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] else: - ins_batch = _slice_arrays(ins, batch_ids) + ins_batch = slice_arrays(ins, batch_ids) except TypeError: raise TypeError('TypeError while preparing batch. ' 'If using HDF5 input data, ' @@ -1220,6 +1327,9 @@ class Model(Network): batch_logs['batch'] = batch_index batch_logs['size'] = len(batch_ids) callbacks.on_batch_begin(batch_index, batch_logs) + for i in indices_for_conversion_to_dense: + ins_batch[i] = ins_batch[i].toarray() + outs = f(ins_batch) if not isinstance(outs, list): outs = [outs] @@ -1262,12 +1372,26 @@ class Model(Network): or list of arrays of predictions (if the model has multiple outputs). """ + if hasattr(self, 'metrics'): + for m in self.metrics: + if isinstance(m, Layer): + m.reset_states() + num_samples = self._check_num_samples(ins, batch_size, steps, 'steps') if verbose == 1: if steps is not None: - progbar = Progbar(target=steps) + progbar = Progbar(target=steps, + stateful_metrics=self.stateful_metric_names) else: - progbar = Progbar(target=num_samples) + progbar = Progbar(target=num_samples, + stateful_metrics=self.stateful_metric_names) + + indices_for_conversion_to_dense = [] + for i in range(len(self._feed_inputs)): + if (issparse is not None and issparse(ins[i]) and + not K.is_sparse(self._feed_inputs[i])): + indices_for_conversion_to_dense.append(i) + if steps is not None: # Step-based predictions. # Since we do not know how many samples @@ -1296,15 +1420,18 @@ class Model(Network): else: # Sample-based predictions. outs = [] - batches = _make_batches(num_samples, batch_size) + batches = make_batches(num_samples, batch_size) index_array = np.arange(num_samples) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] if ins and isinstance(ins[-1], float): # Do not slice the training phase flag. - ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] else: - ins_batch = _slice_arrays(ins, batch_ids) + ins_batch = slice_arrays(ins, batch_ids) + for i in indices_for_conversion_to_dense: + ins_batch[i] = ins_batch[i].toarray() + batch_outs = f(ins_batch) if not isinstance(batch_outs, list): batch_outs = [batch_outs] @@ -1339,14 +1466,32 @@ class Model(Network): and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. """ + if hasattr(self, 'metrics'): + for m in self.metrics: + if isinstance(m, Layer): + m.reset_states() + stateful_metric_indices = [ + i for i, name in enumerate(self.metrics_names) + if str(name) in self.stateful_metric_names + ] + else: + stateful_metric_indices = [] + num_samples = self._check_num_samples(ins, batch_size, steps, 'steps') outs = [] - if verbose == 1: if steps is not None: progbar = Progbar(target=steps) else: progbar = Progbar(target=num_samples) + + # To prevent a slowdown, we find beforehand the arrays that need conversion. + feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights + indices_for_conversion_to_dense = [] + for i in range(len(feed)): + if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]): + indices_for_conversion_to_dense.append(i) + if steps is not None: for step in range(steps): batch_outs = f(ins) @@ -1355,7 +1500,10 @@ class Model(Network): for _ in enumerate(batch_outs): outs.append(0.) for i, batch_out in enumerate(batch_outs): - outs[i] += batch_out + if i in stateful_metric_indices: + outs[i] = batch_out + else: + outs[i] += batch_out else: if step == 0: outs.append(0.) @@ -1363,91 +1511,254 @@ class Model(Network): if verbose == 1: progbar.update(step + 1) for i in range(len(outs)): - outs[i] /= steps + if i not in stateful_metric_indices: + outs[i] /= steps else: - if verbose == 1: - progbar = Progbar(target=num_samples) - batches = _make_batches(num_samples, batch_size) + batches = make_batches(num_samples, batch_size) index_array = np.arange(num_samples) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] if isinstance(ins[-1], float): # Do not slice the training phase flag. - ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] else: - ins_batch = _slice_arrays(ins, batch_ids) + ins_batch = slice_arrays(ins, batch_ids) + for i in indices_for_conversion_to_dense: + ins_batch[i] = ins_batch[i].toarray() batch_outs = f(ins_batch) + if isinstance(batch_outs, list): if batch_index == 0: for batch_out in enumerate(batch_outs): outs.append(0.) for i, batch_out in enumerate(batch_outs): - outs[i] += batch_out * len(batch_ids) + if i in stateful_metric_indices: + outs[i] = batch_out + else: + outs[i] += batch_out * len(batch_ids) else: if batch_index == 0: outs.append(0.) outs[0] += batch_outs * len(batch_ids) - if verbose == 1: progbar.update(batch_end) for i in range(len(outs)): - outs[i] /= num_samples + if i not in stateful_metric_indices: + outs[i] /= num_samples if len(outs) == 1: return outs[0] return outs def _standardize_user_data(self, x, - y, + y=None, sample_weight=None, class_weight=None, - check_batch_axis=True, batch_size=None): - if not hasattr(self, 'optimizer'): - raise RuntimeError('You must compile a model before ' - 'training/testing. ' - 'Use `model.compile(optimizer, loss)`.') - - output_shapes = [] - for output_shape, loss_fn in zip(self._feed_output_shapes, - self._feed_loss_fns): - if loss_fn is losses.sparse_categorical_crossentropy: - output_shapes.append(output_shape[:-1] + (1,)) - elif (not hasattr(loss_fn, '__name__') or - getattr(losses, loss_fn.__name__, None) is None): - # If `loss_fn` is not a function (e.g. callable class) - # or if it not in the `losses` module, then - # it is a user-defined loss and we make no assumptions - # about it. - output_shapes.append(None) + """Runs validation checks on input and target data passed by the user. + + Also standardizes the data to lists of arrays, in order. + + Also builds and compiles the model on the fly if it is a subclassed model + that has never been called before (and thus has no inputs/outputs). + + This is a purely internal method, subject to refactoring at any time. + + Args: + x: An array or list of arrays, to be used as input data. If the model + has known, named inputs, this could also be a dict mapping input names + to the corresponding array. + y: An array or list of arrays, to be used as target data. If the model + has known, named outputs, this could also be a dict mapping output names + to the corresponding array. + sample_weight: An optional sample-weight array passed by the user to + weight the importance of each sample in `x`. + class_weight: An optional class-weight array by the user to + weight the importance of samples in `x` based on the class they belong + to, as conveyed by `y`. + batch_size: Integer batch size. If provided, it is used to run additional + validation checks on stateful models. + + Returns: + A tuple of 3 lists: input arrays, target arrays, sample-weight arrays. + If the model's input and targets are symbolic, these lists are empty + (since the model takes no user-provided data, instead the data comes + from the symbolic inputs/targets). + + Raises: + ValueError: In case of invalid user-provided data. + RuntimeError: If the model was never compiled. + """ + # First, we build/compile the model on the fly if necessary. + all_inputs = [] + if not self.built: + # We need to use `x` to set the model inputs. + # We type-check that `x` and `y` are either single arrays + # or lists of arrays. + if isinstance(x, (list, tuple)): + if not all(isinstance(v, np.ndarray) or + tensor_util.is_tensor(v) for v in x): + raise ValueError('Please provide as model inputs either a single ' + 'array or a list of arrays. You passed: x=' + str(x)) + all_inputs += list(x) + elif isinstance(x, dict): + raise ValueError('Please do not pass a dictionary as model inputs.') else: - output_shapes.append(output_shape) + if not isinstance(x, np.ndarray) and not tensor_util.is_tensor(x): + raise ValueError('Please provide as model inputs either a single ' + 'array or a list of arrays. You passed: x=' + str(x)) + all_inputs.append(x) + + # Build the model using the retrieved inputs (value or symbolic). + # If values, then in symbolic-mode placeholders will be created + # to match the value shapes. + if not self.inputs: + self._set_inputs(x) + + if y is not None: + if not self.optimizer: + raise RuntimeError('You must compile a model before ' + 'training/testing. ' + 'Use `model.compile(optimizer, loss)`.') + if not self._is_compiled: + # On-the-fly compilation of the model. + # We need to use `y` to set the model targets. + if isinstance(y, (list, tuple)): + if not all(isinstance(v, np.ndarray) or + tensor_util.is_tensor(v) for v in y): + raise ValueError('Please provide as model targets either a single ' + 'array or a list of arrays. ' + 'You passed: y=' + str(y)) + elif isinstance(y, dict): + raise ValueError('Please do not pass a dictionary as model targets.') + else: + if not isinstance(y, np.ndarray) and not tensor_util.is_tensor(y): + raise ValueError('Please provide as model targets either a single ' + 'array or a list of arrays. ' + 'You passed: y=' + str(y)) + + # Typecheck that all inputs are *either* value *or* symbolic. + # TODO(fchollet): this check could be removed in Eager mode? + if y is not None: + if isinstance(y, (list, tuple)): + all_inputs += list(y) + else: + all_inputs.append(y) + if any(tensor_util.is_tensor(v) for v in all_inputs): + if not all(tensor_util.is_tensor(v) for v in all_inputs): + raise ValueError('Do not pass inputs that mix Numpy arrays and ' + 'TensorFlow tensors. ' + 'You passed: x=' + str(x) + '; y=' + str(y)) + + if context.in_graph_mode(): + # Handle target tensors if any passed. + if not isinstance(y, (list, tuple)): + y = [y] + target_tensors = [v for v in y if tensor_util.is_tensor(v)] + else: + target_tensors = None + self.compile(optimizer=self.optimizer, + loss=self.loss, + metrics=self.metrics, + loss_weights=self.loss_weights, + target_tensors=target_tensors) + + # If `x` and `y` were all symbolic, then no model should not be fed any + # inputs and targets. + # Note: in this case, `any` and `all` are equivalent since we disallow + # mixed symbolic/value inputs. + if any(tensor_util.is_tensor(v) for v in all_inputs): + return [], [], [] + + # What follows is input validation and standardization to list format, + # in the case where all inputs are value arrays. + + if context.in_eager_mode(): + # In eager mode, do not do shape validation. + feed_input_names = self.input_names + feed_input_shapes = None + elif not self._is_graph_network: + # Case: symbolic-mode subclassed network. Do not do shape validation. + feed_input_names = self._feed_input_names + feed_input_shapes = None + else: + # Case: symbolic-mode graph network. + # In this case, we run extensive shape validation checks. + feed_input_names = self._feed_input_names + feed_input_shapes = self._feed_input_shapes + + # Standardize the inputs. x = _standardize_input_data( x, - self._feed_input_names, - self._feed_input_shapes, - check_batch_axis=False, + feed_input_names, + feed_input_shapes, + check_batch_axis=False, # Don't enforce the batch size. exception_prefix='input') - y = _standardize_input_data( - y, - self._feed_output_names, - output_shapes, - check_batch_axis=False, - exception_prefix='target') - sample_weights = _standardize_sample_weights(sample_weight, - self._feed_output_names) - class_weights = _standardize_class_weights(class_weight, - self._feed_output_names) - sample_weights = [ - _standardize_weights(ref, sw, cw, mode) - for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, - self._feed_sample_weight_modes) - ] - _check_array_lengths(x, y, sample_weights) - _check_loss_and_target_compatibility(y, self._feed_loss_fns, - self._feed_output_shapes) + + if y is not None: + if context.in_eager_mode(): + feed_output_names = self.output_names + feed_output_shapes = None + # Sample weighting not supported in this case. + # TODO(fchollet): consider supporting it. + feed_sample_weight_modes = [None for _ in self.outputs] + elif not self._is_graph_network: + feed_output_names = self._feed_output_names + feed_output_shapes = None + # Sample weighting not supported in this case. + # TODO(fchollet): consider supporting it. + feed_sample_weight_modes = [None for _ in self.outputs] + else: + feed_output_names = self._feed_output_names + feed_sample_weight_modes = self._feed_sample_weight_modes + feed_output_shapes = [] + for output_shape, loss_fn in zip(self._feed_output_shapes, + self._feed_loss_fns): + if loss_fn is losses.sparse_categorical_crossentropy: + feed_output_shapes.append(output_shape[:-1] + (1,)) + elif (not hasattr(loss_fn, '__name__') or + getattr(losses, loss_fn.__name__, None) is None): + # If `loss_fn` is not a function (e.g. callable class) + # or if it not in the `losses` module, then + # it is a user-defined loss and we make no assumptions + # about it. + feed_output_shapes.append(None) + else: + feed_output_shapes.append(output_shape) + + # Standardize the outputs. + y = _standardize_input_data( + y, + feed_output_names, + feed_output_shapes, + check_batch_axis=False, # Don't enforce the batch size. + exception_prefix='target') + + # Generate sample-wise weight values given the `sample_weight` and + # `class_weight` arguments. + sample_weights = _standardize_sample_weights(sample_weight, + feed_output_names) + class_weights = _standardize_class_weights(class_weight, + feed_output_names) + sample_weights = [ + _standardize_weights(ref, sw, cw, mode) + for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, + feed_sample_weight_modes) + ] + # Check that all arrays have the same length. + _check_array_lengths(x, y, sample_weights) + if self._is_graph_network and not context.in_eager_mode(): + # Additional checks to avoid users mistakenly using improper loss fns. + _check_loss_and_target_compatibility(y, self._feed_loss_fns, + feed_output_shapes) + else: + y = [] + sample_weights = [] + if self.stateful and batch_size: + # Check that for stateful networks, number of samples is a multiple + # of the static batch size. if x[0].shape[0] % batch_size != 0: raise ValueError('In a stateful network, ' 'you should only pass inputs with ' @@ -1456,19 +1767,151 @@ class Model(Network): str(x[0].shape[0]) + ' samples') return x, y, sample_weights - def _get_deduped_metrics_names(self): - out_labels = self.metrics_names + def _set_inputs(self, inputs, training=None): + """Set model's input and output specs based on the input data received. + + This is to be used for Model subclasses, which do not know at instantiation + time what their inputs look like. + + Args: + inputs: Single array, or list of arrays. The arrays could be placeholders, + Numpy arrays, or data tensors. + - if placeholders: the model is built on top of these placeholders, + and we expect Numpy data to be fed for them when calling `fit`/etc. + - if Numpy data: we create placeholders matching the shape of the Numpy + arrays. We expect Numpy data to be fed for these placeholders + when calling `fit`/etc. + - if data tensors: the model is built on top of these tensors. + We do not expect any Numpy data to be provided when calling `fit`/etc. + training: Boolean or None. Only relevant in symbolic mode. Specifies + whether to build the model's graph in inference mode (False), training + mode (True), or using the Keras learning phase (None). + """ + if context.in_eager_mode(): + self._eager_set_inputs(inputs) + else: + self._symbolic_set_inputs(inputs, training=training) + + def _eager_set_inputs(self, inputs): + """Set model's input and output specs based on the input data received. + + This is to be used for Model subclasses, which do not know at instantiation + time what their inputs look like. + + We assume the number and ndim of outputs + does not change over different calls. - # Rename duplicated metrics name - # (can happen with an output layer shared among multiple dataflows). - deduped_out_labels = [] - for i, label in enumerate(out_labels): - new_label = label - if out_labels.count(label) > 1: - dup_idx = out_labels[:i].count(label) - new_label += '_' + str(dup_idx + 1) - deduped_out_labels.append(new_label) - return deduped_out_labels + Args: + inputs: Argument `x` (input data) passed by the user upon first model use. + + Raises: + ValueError: If the model's inputs are already set. + """ + assert context.in_eager_mode() + if self.inputs: + raise ValueError('Model inputs are already set.') + # On-the-fly setting of model inputs/outputs as DeferredTensors, + # to keep track of number of inputs and outputs and their ndim. + if isinstance(inputs, (list, tuple)): + dummy_output_values = self.call( + [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs]) + dummy_input_values = list(inputs) + else: + dummy_output_values = self.call( + ops.convert_to_tensor(inputs, dtype=K.floatx())) + dummy_input_values = [inputs] + if isinstance(dummy_output_values, (list, tuple)): + dummy_output_values = list(dummy_output_values) + else: + dummy_output_values = [dummy_output_values] + self.outputs = [ + _DeferredTensor(shape=(None for _ in v.shape), + dtype=v.dtype) for v in dummy_output_values] + self.inputs = [ + _DeferredTensor(shape=(None for _ in v.shape), + dtype=v.dtype) for v in dummy_input_values] + self.input_names = [ + 'input_%d' % (i + 1) for i in range(len(dummy_input_values))] + self.output_names = [ + 'output_%d' % (i + 1) for i in range(len(dummy_output_values))] + self.built = True + + def _symbolic_set_inputs(self, inputs, training=None): + """Set model's inputs based on the input data received from the user. + + This is to be used for Model subclasses, which do not know at instantiation + time what their inputs look like. + + Args: + inputs: Argument `x` (input data) passed by the user upon first model use. + training: Boolean or None. Only relevant in symbolic mode. Specifies + whether to build the model's graph in inference mode (False), training + mode (True), or using the Keras learning phase (None). + + Raises: + ValueError: If the model's inputs are already set. + """ + assert context.in_graph_mode() + if self.inputs: + raise ValueError('Model inputs are already set.') + + # On-the-fly setting of symbolic model inputs (either by using the tensor + # provided, or by creating a placeholder if Numpy data was provided). + self.inputs = [] + self.input_names = [] + self._feed_inputs = [] + self._feed_input_names = [] + self._feed_input_shapes = [] + if isinstance(inputs, (list, tuple)): + inputs = list(inputs) + else: + inputs = [inputs] + + for i, v in enumerate(inputs): + name = 'input_%d' % (i + 1) + self.input_names.append(name) + if isinstance(v, list): + v = np.asarray(v) + if v.ndim == 1: + v = np.expand_dims(v, 1) + if isinstance(v, (np.ndarray)): + # We fix the placeholder shape except the batch size. + # This is suboptimal, but it is the best we can do with the info + # we have. The user should call `model._set_inputs(placeholders)` + # to specify custom placeholders if the need arises. + shape = (None,) + v.shape[1:] + placeholder = K.placeholder(shape=shape, name=name) + self.inputs.append(placeholder) + self._feed_inputs.append(placeholder) + self._feed_input_names.append(name) + self._feed_input_shapes.append(shape) + else: + # Assumed tensor - TODO(fchollet) additional type check? + self.inputs.append(v) + if K.is_placeholder(v): + self._feed_inputs.append(v) + self._feed_input_names.append(name) + self._feed_input_shapes.append(K.int_shape(v)) + + # Obtain symbolic outputs by calling the model. + if len(self.inputs) == 1: + if self._expects_training_arg: + outputs = self.call(self.inputs[0], training=training) + else: + outputs = self.call(self.inputs[0]) + else: + if self._expects_training_arg: + outputs = self.call(self.inputs, training=training) + else: + outputs = self.call(self.inputs) + if isinstance(outputs, (list, tuple)): + outputs = list(outputs) + else: + outputs = [outputs] + self.outputs = outputs + self.output_names = [ + 'output_%d' % (i + 1) for i in range(len(self.outputs))] + self.built = True def fit(self, x=None, @@ -1484,7 +1927,8 @@ class Model(Network): sample_weight=None, initial_epoch=0, steps_per_epoch=None, - validation_steps=None): + validation_steps=None, + **kwargs): """Trains the model for a fixed number of epochs (iterations on a dataset). Arguments: @@ -1501,10 +1945,9 @@ class Model(Network): dictionary mapping output names to Numpy arrays. `y` can be `None` (default) if feeding from TensorFlow data tensors. - Can be `None` (default) if feeding from framework-native tensors. batch_size: Integer or `None`. Number of samples per gradient update. - If unspecified, it will default to 32. + If unspecified, `batch_size` will default to 32. epochs: Integer. Number of epochs to train the model. An epoch is an iteration over the entire `x` and `y` data provided. @@ -1513,7 +1956,7 @@ class Model(Network): The model is not trained for a number of iterations given by `epochs`, but merely until the epoch of index `epochs` is reached. - verbose: 0, 1, or 2. Verbosity mode. + verbose: Integer. 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch. callbacks: List of `keras.callbacks.Callback` instances. List of callbacks to apply during training. @@ -1530,7 +1973,7 @@ class Model(Network): `(x_val, y_val, val_sample_weights)` on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. - This will override `validation_split`. + `validation_data` will override `validation_split`. shuffle: Boolean (whether to shuffle the training data before each epoch) or str (for 'batch'). 'batch' is a special option for dealing with the @@ -1553,17 +1996,20 @@ class Model(Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. - initial_epoch: Epoch at which to start training + initial_epoch: Integer. + Epoch at which to start training (useful for resuming a previous training run). - steps_per_epoch: Total number of steps (batches of samples) + steps_per_epoch: Integer or `None`. + Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. When training with input tensors such as TensorFlow data tensors, the default `None` is equal to - the number of unique samples in your dataset divided by + the number of samples in your dataset divided by the batch size, or 1 if that cannot be determined. validation_steps: Only relevant if `steps_per_epoch` is specified. Total number of steps (batches of samples) to validate before stopping. + **kwargs: Used for backwards compatibility. Returns: A `History` object. Its `History.history` attribute is @@ -1572,25 +2018,36 @@ class Model(Network): and validation metrics values (if applicable). Raises: + RuntimeError: If the model was never compiled. ValueError: In case of mismatch between the provided input data and what the model expects. """ + # TODO(fchollet): this method may be creating reference cycles, which would + # lead to accumulating garbage in memory when called in a loop. Investigate. + # Backwards compatibility if batch_size is None and steps_per_epoch is None: batch_size = 32 + # Legacy support + if 'nb_epoch' in kwargs: + logging.warning( + 'The `nb_epoch` argument in `fit` ' + 'has been renamed `epochs`.') + epochs = kwargs.pop('nb_epoch') + if kwargs: + raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) if x is None and y is None and steps_per_epoch is None: raise ValueError('If fitting from data tensors, ' 'you should specify the `steps_per_epoch` ' 'argument.') + # Validate user data. x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, class_weight=class_weight, - check_batch_axis=False, batch_size=batch_size) - # Prepare validation data. do_validation = False val_ins = [] @@ -1612,7 +2069,6 @@ class Model(Network): val_x, val_y, sample_weight=val_sample_weight, - check_batch_axis=False, batch_size=batch_size) if self.uses_learning_phase and not isinstance(K.learning_phase(), int): val_ins = val_x + val_y + val_sample_weights + [0.] @@ -1625,10 +2081,10 @@ class Model(Network): split_at = int(x[0].shape[0] * (1. - validation_split)) else: split_at = int(len(x[0]) * (1. - validation_split)) - x, val_x = (_slice_arrays(x, 0, split_at), _slice_arrays(x, split_at)) - y, val_y = (_slice_arrays(y, 0, split_at), _slice_arrays(y, split_at)) - sample_weights, val_sample_weights = (_slice_arrays( - sample_weights, 0, split_at), _slice_arrays(sample_weights, split_at)) + x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at)) + y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at)) + sample_weights, val_sample_weights = (slice_arrays( + sample_weights, 0, split_at), slice_arrays(sample_weights, split_at)) if self.uses_learning_phase and not isinstance(K.learning_phase(), int): val_ins = val_x + val_y + val_sample_weights + [0.] else: @@ -1644,38 +2100,65 @@ class Model(Network): ins = x + y + sample_weights + [1.] else: ins = x + y + sample_weights - self._make_train_function() - f = self.train_function # Prepare display labels. - out_labels = self._get_deduped_metrics_names() + out_labels = self.metrics_names - if do_validation: - self._make_test_function() - val_f = self.test_function - callback_metrics = copy.copy(out_labels) + [ - 'val_' + n for n in out_labels - ] + if context.in_eager_mode(): + if do_validation: + callback_metrics = copy.copy(out_labels) + [ + 'val_' + n for n in out_labels + ] + else: + callback_metrics = copy.copy(out_labels) + + return training_eager.fit_loop( + self, + ins, + out_labels=out_labels, + batch_size=batch_size, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + val_ins=val_ins, + shuffle=shuffle, + callback_metrics=callback_metrics, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps) else: - val_f = None - callback_metrics = copy.copy(out_labels) - - # Delegate logic to `_fit_loop`. - return self._fit_loop( - f, - ins, - out_labels=out_labels, - batch_size=batch_size, - epochs=epochs, - verbose=verbose, - callbacks=callbacks, - val_f=val_f, - val_ins=val_ins, - shuffle=shuffle, - callback_metrics=callback_metrics, - initial_epoch=initial_epoch, - steps_per_epoch=steps_per_epoch, - validation_steps=validation_steps) + self._make_train_function() + f = self.train_function + + if do_validation: + if context.in_graph_mode(): + self._make_test_function() + val_f = self.test_function + else: + val_f = None + callback_metrics = copy.copy(out_labels) + [ + 'val_' + n for n in out_labels + ] + else: + val_f = None + callback_metrics = copy.copy(out_labels) + + # Delegate logic to `_fit_loop`. + return self._fit_loop( + f, + ins, + out_labels=out_labels, + batch_size=batch_size, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + val_f=val_f, + val_ins=val_ins, + shuffle=shuffle, + callback_metrics=callback_metrics, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps) def evaluate(self, x=None, @@ -1694,14 +2177,14 @@ class Model(Network): If input layers in the model are named, you can also pass a dictionary mapping input names to Numpy arrays. `x` can be `None` (default) if feeding from - framework-native tensors (e.g. TensorFlow data tensors). + TensorFlow data tensors. y: Numpy array of target (label) data (if the model has a single output), or list of Numpy arrays (if the model has multiple outputs). If output layers in the model are named, you can also pass a dictionary mapping output names to Numpy arrays. `y` can be `None` (default) if feeding from - framework-native tensors (e.g. TensorFlow data tensors). + TensorFlow data tensors. batch_size: Integer or `None`. Number of samples per evaluation step. If unspecified, `batch_size` will default to 32. @@ -1721,8 +2204,7 @@ class Model(Network): steps: Integer or `None`. Total number of steps (batches of samples) before declaring the evaluation round finished. - The default `None` is equal to the number of unique samples in - your dataset divided by the batch size. + Ignored with the default value of `None`. Returns: Scalar test loss (if the model has a single output and no metrics) @@ -1731,7 +2213,7 @@ class Model(Network): the display labels for the scalar outputs. Raises: - ValueError: In case of invalid arguments. + ValueError: in case of invalid arguments. """ # Backwards compatibility. if batch_size is None and steps is None: @@ -1740,22 +2222,27 @@ class Model(Network): raise ValueError('If evaluating from data tensors, ' 'you should specify the `steps` ' 'argument.') + # Validate user data. x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, - check_batch_axis=False, batch_size=batch_size) # Prepare inputs, delegate logic to `_test_loop`. if self.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = x + y + sample_weights + [0.] else: ins = x + y + sample_weights - self._make_test_function() - f = self.test_function - return self._test_loop( - f, ins, batch_size=batch_size, verbose=verbose, steps=steps) + + if context.in_eager_mode(): + return training_eager.test_loop( + self, ins, batch_size=batch_size, verbose=verbose, steps=steps) + else: + self._make_test_function() + f = self.test_function + return self._test_loop( + f, ins, batch_size=batch_size, verbose=verbose, steps=steps) def predict(self, x, batch_size=None, verbose=0, steps=None): """Generates output predictions for the input samples. @@ -1787,30 +2274,23 @@ class Model(Network): raise ValueError('If predicting from data tensors, ' 'you should specify the `steps` ' 'argument.') - # Validate user data. - x = _standardize_input_data( - x, - self._feed_input_names, - self._feed_input_shapes, - check_batch_axis=False) - if self.stateful: - if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0: - raise ValueError('In a stateful network, ' - 'you should only pass inputs with ' - 'a number of samples that can be ' - 'divided by the batch size. Found: ' + - str(x[0].shape[0]) + ' samples. ' - 'Batch size: ' + str(batch_size) + '.') + x, _, _ = self._standardize_user_data(x) # Prepare inputs, delegate logic to `_predict_loop`. if self.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = x + [0.] else: ins = x - self._make_predict_function() - f = self.predict_function - return self._predict_loop( - f, ins, batch_size=batch_size, verbose=verbose, steps=steps) + + if context.in_eager_mode(): + return training_eager.predict_loop( + self, ins, batch_size=batch_size, verbose=verbose, steps=steps) + else: + self._make_predict_function() + f = self.predict_function + + return self._predict_loop( + f, ins, batch_size=batch_size, verbose=verbose, steps=steps) def train_on_batch(self, x, y, sample_weight=None, class_weight=None): """Runs a single gradient update on a single batch of data. @@ -1846,19 +2326,24 @@ class Model(Network): or list of scalars (if the model has multiple outputs and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. + """ x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, - class_weight=class_weight, - check_batch_axis=True) + class_weight=class_weight) if self.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = x + y + sample_weights + [1.] else: ins = x + y + sample_weights - self._make_train_function() - outputs = self.train_function(ins) + + if context.in_eager_mode(): + outputs = training_eager.train_on_batch(self, ins) + else: + self._make_train_function() + outputs = self.train_function(ins) + if len(outputs) == 1: return outputs[0] return outputs @@ -1890,15 +2375,23 @@ class Model(Network): or list of scalars (if the model has multiple outputs and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. + + Raises: + ValueError: in case of invalid arguments. """ x, y, sample_weights = self._standardize_user_data( - x, y, sample_weight=sample_weight, check_batch_axis=True) + x, y, sample_weight=sample_weight) if self.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = x + y + sample_weights + [0.] else: ins = x + y + sample_weights - self._make_test_function() - outputs = self.test_function(ins) + + if context.in_eager_mode(): + outputs = training_eager.test_on_batch(self, ins) + else: + self._make_test_function() + outputs = self.test_function(ins) + if len(outputs) == 1: return outputs[0] return outputs @@ -1911,18 +2404,33 @@ class Model(Network): Returns: Numpy array(s) of predictions. + """ - x = _standardize_input_data(x, self._feed_input_names, - self._feed_input_shapes) + x, _, _ = self._standardize_user_data(x) + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = x + [0.] else: ins = x - self._make_predict_function() - outputs = self.predict_function(ins) - if len(outputs) == 1: - return outputs[0] - return outputs + + if context.in_eager_mode(): + ins_batch_converted = [] + for ib in ins: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + + eager_model_inputs = [] + for i in range(len(self.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + + outs = self(eager_model_inputs) # pylint: disable=not-callable + return outs + + if context.in_graph_mode(): + self._make_predict_function() + outputs = self.predict_function(ins) + if len(outputs) == 1: + return outputs[0] + return outputs def fit_generator(self, generator, @@ -1937,8 +2445,7 @@ class Model(Network): workers=1, use_multiprocessing=False, shuffle=True, - initial_epoch=0, - **kwargs): + initial_epoch=0): """Fits the model on data yielded batch-by-batch by a Python generator. The generator is run in parallel to the model, for efficiency. @@ -1950,22 +2457,31 @@ class Model(Network): using `use_multiprocessing=True`. Arguments: - generator: A generator or an instance of Sequence (keras.utils.Sequence) - object in order to avoid duplicate data when using multiprocessing. + generator: A generator or an instance of `Sequence` + (`keras.utils.Sequence`) + object in order to avoid duplicate data + when using multiprocessing. The output of the generator must be either - - a tuple (inputs, targets) - - a tuple (inputs, targets, sample_weights). - All arrays should contain the same number of samples. + - a tuple `(inputs, targets)` + - a tuple `(inputs, targets, sample_weights)`. + This tuple (a single output of the generator) makes a single batch. + Therefore, all arrays in this tuple must have the same length (equal + to the size of this batch). Different batches may have different + sizes. + For example, the last batch of the epoch is commonly smaller than + the + others, if the size of the dataset is not divisible by the batch + size. The generator is expected to loop over its data indefinitely. An epoch finishes when `steps_per_epoch` batches have been seen by the model. steps_per_epoch: Total number of steps (batches of samples) to yield from `generator` before declaring one epoch finished and starting the next epoch. It should typically - be equal to the number of unique samples of your dataset + be equal to the number of samples of your dataset divided by the batch size. Optional for `Sequence`: if unspecified, will use - `len(generator)` as a number of steps. + the `len(generator)` as a number of steps. epochs: Integer, total number of iterations on the data. verbose: Verbosity mode, 0, 1, or 2. callbacks: List of callbacks to be called during training. @@ -1977,27 +2493,28 @@ class Model(Network): is a generator. Total number of steps (batches of samples) to yield from `generator` before stopping. Optional for `Sequence`: if unspecified, will use - `len(generator)` as a number of steps. + the `len(validation_data)` as a number of steps. class_weight: Dictionary mapping class indices to a weight for the class. - max_queue_size: Maximum size for the generator queue. + max_queue_size: Integer. Maximum size for the generator queue. + If unspecified, `max_queue_size` will default to 10. workers: Integer. Maximum number of processes to spin up when using process based threading. If unspecified, `workers` will default to 1. If 0, will execute the generator on the main thread. - use_multiprocessing: If True, use process based threading. + use_multiprocessing: Boolean. If True, use process based threading. + If unspecified, `workers` will default to False. Note that because this implementation relies on multiprocessing, you should not pass non picklable arguments to the generator as they can't be passed easily to children processes. - shuffle: Whether to shuffle the data at the beginning of each - epoch. Only used with instances of `Sequence` - (`keras.utils.Sequence`). + shuffle: Whether to shuffle the order of the batches at + the beginning of each epoch. Only used with instances + of `Sequence` (keras.utils.Sequence). initial_epoch: Epoch at which to start training (useful for resuming a previous training run) - **kwargs: support for legacy arguments. Returns: A `History` object. @@ -2018,23 +2535,13 @@ class Model(Network): model.fit_generator(generate_arrays_from_file('/my_file.txt'), steps_per_epoch=10000, epochs=10) ``` - Raises: ValueError: In case the generator yields data in an invalid format. """ - # Legacy support - if 'max_q_size' in kwargs: - max_queue_size = kwargs.pop('max_q_size') - logging.warning('The argument `max_q_size` has been renamed ' - '`max_queue_size`. Update your method calls accordingly.') - if 'pickle_safe' in kwargs: - use_multiprocessing = kwargs.pop('pickle_safe') - logging.warning('The argument `pickle_safe` has been renamed ' - '`use_multiprocessing`. ' - 'Update your method calls accordingly.') - if kwargs: - raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) + if not self._is_graph_network: + raise NotImplementedError( + '`fit_generator` is not yet enabled for Model subclasses') wait_time = 0.01 # in seconds epoch = initial_epoch @@ -2046,10 +2553,11 @@ class Model(Network): is_sequence = isinstance(generator, Sequence) if not is_sequence and use_multiprocessing and workers > 1: - logging.warning('Using a generator with `use_multiprocessing=True`' + logging.warning( + UserWarning('Using a generator with `use_multiprocessing=True`' ' and multiple workers may duplicate your data.' ' Please consider using the`keras.utils.Sequence' - ' class.') + ' class.')) if steps_per_epoch is None: if is_sequence: steps_per_epoch = len(generator) @@ -2073,8 +2581,8 @@ class Model(Network): ' the `keras.utils.Sequence` class.') # Prepare display labels. - out_labels = self._get_deduped_metrics_names() - callback_metrics = out_labels + ['val_' + n for n in out_labels] + out_labels = self.metrics_names + callback_metrics = out_labels + ['val_%s' % n for n in out_labels] # prepare callbacks self.history = cbks.History() @@ -2098,26 +2606,47 @@ class Model(Network): }) callbacks.on_train_begin() - if do_validation and not val_gen: - if len(validation_data) == 2: - val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence - val_sample_weight = None - elif len(validation_data) == 3: - val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence - else: - raise ValueError('`validation_data` should be a tuple ' - '`(val_x, val_y, val_sample_weight)` ' - 'or `(val_x, val_y)`. Found: ' + str(validation_data)) - val_x, val_y, val_sample_weights = self._standardize_user_data( - val_x, val_y, val_sample_weight) - val_data = val_x + val_y + val_sample_weights - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - val_data += [0.] - for cbk in callbacks: - cbk.validation_data = val_data enqueuer = None + val_enqueuer = None try: + if do_validation: + if val_gen: + if workers > 0: + if isinstance(validation_data, Sequence): + val_enqueuer = OrderedEnqueuer( + validation_data, use_multiprocessing=use_multiprocessing) + if validation_steps is None: + validation_steps = len(validation_data) + else: + val_enqueuer = GeneratorEnqueuer( + validation_data, + use_multiprocessing=use_multiprocessing, + wait_time=wait_time) + val_enqueuer.start(workers=workers, max_queue_size=max_queue_size) + validation_generator = val_enqueuer.get() + else: + validation_generator = validation_data + else: + if len(validation_data) == 2: + val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence + val_sample_weight = None + elif len(validation_data) == 3: + val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence + else: + raise ValueError( + '`validation_data` should be a tuple ' + '`(val_x, val_y, val_sample_weight)` ' + 'or `(val_x, val_y)`. Found: ' + str(validation_data)) + val_x, val_y, val_sample_weights = self._standardize_user_data( + val_x, val_y, val_sample_weight) + val_data = val_x + val_y + val_sample_weights + if self.uses_learning_phase and not isinstance( + K.learning_phase(), int): + val_data += [0.] + for cbk in callbacks: + cbk.validation_data = val_data + if workers > 0: if is_sequence: enqueuer = OrderedEnqueuer( @@ -2135,6 +2664,8 @@ class Model(Network): output_generator = generator callback_model.stop_training = False + # Construct epoch logs. + epoch_logs = {} while epoch < epochs: callbacks.on_epoch_begin(epoch) steps_done = 0 @@ -2178,8 +2709,6 @@ class Model(Network): callbacks.on_batch_end(batch_index, batch_logs) - # Construct epoch logs. - epoch_logs = {} batch_index += 1 steps_done += 1 @@ -2187,11 +2716,7 @@ class Model(Network): if steps_done >= steps_per_epoch and do_validation: if val_gen: val_outs = self.evaluate_generator( - validation_data, - validation_steps, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing) + validation_generator, validation_steps, workers=0) else: # No need for try/except because # data has already been validated. @@ -2216,8 +2741,12 @@ class Model(Network): break finally: - if enqueuer is not None: - enqueuer.stop() + try: + if enqueuer is not None: + enqueuer.stop() + finally: + if val_enqueuer is not None: + val_enqueuer.stop() callbacks.on_train_end() return self.history @@ -2227,8 +2756,7 @@ class Model(Network): steps=None, max_queue_size=10, workers=1, - use_multiprocessing=False, - **kwargs): + use_multiprocessing=False): """Evaluates the model on a data generator. The generator should return the same kind of data @@ -2256,7 +2784,6 @@ class Model(Network): non picklable arguments to the generator as they can't be passed easily to children processes. - **kwargs: support for legacy arguments. Returns: Scalar test loss (if the model has a single output and no metrics) @@ -2264,22 +2791,16 @@ class Model(Network): and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. + Raises: + ValueError: in case of invalid arguments. + Raises: ValueError: In case the generator yields data in an invalid format. """ - # Legacy support - if 'max_q_size' in kwargs: - max_queue_size = kwargs.pop('max_q_size') - logging.warning('The argument `max_q_size` has been renamed ' - '`max_queue_size`. Update your method calls accordingly.') - if 'pickle_safe' in kwargs: - use_multiprocessing = kwargs.pop('pickle_safe') - logging.warning('The argument `pickle_safe` has been renamed ' - '`use_multiprocessing`. ' - 'Update your method calls accordingly.') - if kwargs: - raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) + if not self._is_graph_network: + raise NotImplementedError( + '`evaluate_generator` is not yet enabled for Model subclasses') self._make_test_function() @@ -2289,10 +2810,11 @@ class Model(Network): batch_sizes = [] is_sequence = isinstance(generator, Sequence) if not is_sequence and use_multiprocessing and workers > 1: - logging.warning('Using a generator with `use_multiprocessing=True`' + logging.warning( + UserWarning('Using a generator with `use_multiprocessing=True`' ' and multiple workers may duplicate your data.' ' Please consider using the`keras.utils.Sequence' - ' class.') + ' class.')) if steps is None: if is_sequence: steps = len(generator) @@ -2368,8 +2890,7 @@ class Model(Network): max_queue_size=10, workers=1, use_multiprocessing=False, - verbose=0, - **kwargs): + verbose=0): """Generates predictions for the input samples from a data generator. The generator should return the same kind of data as accepted by @@ -2377,9 +2898,9 @@ class Model(Network): Arguments: generator: Generator yielding batches of input samples - or an instance of Sequence (keras.utils.Sequence) - object in order to avoid duplicate data - when using multiprocessing. + or an instance of Sequence (keras.utils.Sequence) + object in order to avoid duplicate data + when using multiprocessing. steps: Total number of steps (batches of samples) to yield from `generator` before stopping. Optional for `Sequence`: if unspecified, will use @@ -2397,7 +2918,6 @@ class Model(Network): as they can't be passed easily to children processes. verbose: verbosity mode, 0 or 1. - **kwargs: support for legacy arguments. Returns: Numpy array(s) of predictions. @@ -2406,16 +2926,9 @@ class Model(Network): ValueError: In case the generator yields data in an invalid format. """ - # Legacy support - if 'max_q_size' in kwargs: - max_queue_size = kwargs.pop('max_q_size') - logging.warning('The argument `max_q_size` has been renamed ' - '`max_queue_size`. Update your method calls accordingly.') - if 'pickle_safe' in kwargs: - use_multiprocessing = kwargs.pop('pickle_safe') - logging.warning('The argument `pickle_safe` has been renamed ' - '`use_multiprocessing`. ' - 'Update your method calls accordingly.') + if not self._is_graph_network: + raise NotImplementedError( + '`predict_generator` is not yet enabled for Model subclasses') self._make_predict_function() @@ -2424,10 +2937,11 @@ class Model(Network): all_outs = [] is_sequence = isinstance(generator, Sequence) if not is_sequence and use_multiprocessing and workers > 1: - logging.warn('Using a generator with `use_multiprocessing=True`' - ' and multiple workers may duplicate your data.' - ' Please consider using the`keras.utils.Sequence' - ' class.') + logging.warning( + UserWarning('Using a generator with `use_multiprocessing=True`' + ' and multiple workers may duplicate your data.' + ' Please consider using the`keras.utils.Sequence' + ' class.')) if steps is None: if is_sequence: steps = len(generator) @@ -2498,6 +3012,6 @@ class Model(Network): else: return np.concatenate(all_outs[0]) if steps_done == 1: - return [out for out in all_outs] + return [out[0] for out in all_outs] else: return [np.concatenate(out) for out in all_outs] diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py new file mode 100644 index 0000000000000000000000000000000000000000..3507f36e14de28e1049895da5cbfd036dbb414f7 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py @@ -0,0 +1,628 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras training and evaluation routines. +""" +# pylint: disable=protected-access +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from tensorflow.python.eager.backprop import GradientTape +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import callbacks as cbks +from tensorflow.python.keras._impl.keras import losses +from tensorflow.python.keras._impl.keras import metrics as metrics_module +from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches +from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays + + +def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None): + if metric == 'accuracy' or metric == 'acc': + # custom handling of accuracy + # (because of class mode duality) + output_shape = internal_output_shapes + if output_shape[-1] == 1 or loss_func == losses.binary_crossentropy: + # case: binary accuracy + acc_fn = metrics_module.binary_accuracy + elif loss_func == losses.sparse_categorical_crossentropy: + # case: categorical accuracy with sparse targets + acc_fn = metrics_module.sparse_categorical_accuracy + else: + acc_fn = metrics_module.categorical_accuracy + + metric_name = 'acc' + return metric_name, acc_fn + else: + metric_fn = metrics_module.get(metric) + metric_name = metric_fn.__name__ + return metric_name, metric_fn + + +def _eager_loss_fn(outputs, targets, loss_fn, output_name): + with K.name_scope(output_name + '_loss'): + loss = loss_fn(targets, outputs) + return loss + + +def _eager_metrics_fn(model, outputs, targets): + """Calculates the metrics for each output of the given model. + + Arguments: + model: The model on which metrics are being calculated. + outputs: The outputs of the given model. + targets: The predictions or targets of the given model. + + Returns: + Returns the metric names and metric results for each output of the model. + """ + metric_names = [] + metric_results = [] + if not isinstance(outputs, list): + outputs = [outputs] + + if not isinstance(targets, list): + targets = [targets] + + for i in range(len(model.outputs)): + output_metrics = model.nested_metrics[i] + for nested_output_metric in output_metrics: + metric_name, metric_fn = _get_metrics_info( + nested_output_metric, K.int_shape(model.outputs[i]), + model.loss_functions[i]) + + if len(model.output_names) > 1: + metric_name = model.output_names[i] + '_' + metric_name + if metric_name not in model.metrics_names: + model.metrics_names.append(metric_name) + + with K.name_scope(metric_name): + metric_result = metric_fn(outputs[i], targets[i]) + metric_names.append(metric_name) + metric_results.append(K.mean(metric_result)) + + return metric_names, metric_results + + +def _model_loss(model, inputs, targets, training=False): + """Calculates the loss for a given model. + + Arguments: + model: The model on which metrics are being calculated. + inputs: The inputs of the given model. This is typically the mini batch of + data that is fed to the model. + targets: The predictions or targets of the given model. + training: Whether the model should be run in inference or training mode. + + Returns: + Returns the model output, total loss and loss value calculated using the + specified loss function. The total loss includes regularization losses and + applies masking and sample weighting to the loss value. + """ + total_loss = 0 + if len(inputs) == 1: + if model._expects_training_arg: + outs = model.call(inputs[0], training=training) + else: + outs = model.call(inputs[0]) + else: + if model._expects_training_arg: + outs = model.call(inputs, training=training) + else: + outs = model.call(inputs) + if not isinstance(outs, list): + outs = [outs] + + if not isinstance(targets, list): + targets = [targets] + + loss_metrics = [] + with K.name_scope('loss'): + for i, loss_fn in enumerate(model.loss_functions): + # compute the loss + output_loss = _eager_loss_fn(outs[i], targets[i], loss_fn, + model.output_names[i]) + loss_metrics.append(K.mean(output_loss)) + + mask = outs[i]._keras_mask + # adapted from weighted_loss_fn + if mask is not None: + # mask should have the same shape as output_loss + output_loss *= mask + # the loss per batch should be proportional + # to the number of unmasked samples. + output_loss /= K.mean(mask) + + # adapted from weighted_loss_fn + # apply sample weighting + if model.sample_weights: + # reduce score_array to same ndim as weight array + ndim = K.ndim(output_loss) + weight_ndim = K.ndim(model.sample_weights) + output_loss = K.mean(output_loss, axis=list(range(weight_ndim, ndim))) + output_loss *= model.sample_weights + output_loss /= K.mean(K.cast(K.not_equal(model.sample_weights, 0), + K.floatx())) + output_loss = K.mean(output_loss) + + loss_weight = model.loss_weights_list[i] + if total_loss is None: + total_loss = loss_weight * output_loss + else: + total_loss += loss_weight * output_loss + + total_loss = K.mean(total_loss) + # Add regularization losses + custom_losses = [] + for layer in model.layers: + if layer.losses: + custom_losses += layer.losses + + if custom_losses: + total_loss += sum(custom_losses) + + return outs, total_loss, loss_metrics + + +def _process_single_batch(eager_model_inputs, eager_model_outputs, model, + training=False): + """Calculate the loss and gradient for one input batch. + + The model weights are updated if training is set to True. + + Arguments: + eager_model_inputs: Input batch data. + eager_model_outputs: Output batch data. + model: Model whose loss has to be calculated. + training: The boolean represents if the weights of the model are updated. + 'fit' methods will set this to True while 'evaluate' methods will + set this to False. + + Returns: + output of the model, total loss and the loss associated with each output. + + Raises: + ValueError: If the model loss is 0 or if the trainable weights list is + empty when the trainable parameter is set to True. + """ + K.set_learning_phase(training) + with GradientTape() as tape: + outs, loss, loss_metrics = _model_loss(model, eager_model_inputs, + eager_model_outputs, + training=training) + if loss is None: + raise ValueError('The model cannot be run ' + 'because it has no loss to optimize.') + if training: + if not model._collected_trainable_weights: + raise ValueError('The list of trainable weights is empty. Make sure that ' + 'you are not setting model.trainable to False before ' + 'compiling the model.') + grads = tape.gradient(loss, model._collected_trainable_weights) + model.optimizer.apply_gradients(zip(grads, + model._collected_trainable_weights)) + return outs, loss, loss_metrics + + +def train_on_batch(model, ins): + """Calculates the loss and gradient updates for one input batch. + + Arguments: + model: Given model on which loss and gradients are calculated. + ins: Input and output batch numpy arrays. + + Returns: + total loss and the loss associated with each output. + """ + ins_batch_converted = [] + for ib in ins: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + eager_model_inputs = [] + eager_model_outputs = [] + for i in range(len(model.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + for i in range(len(model.inputs), len(ins_batch_converted)): + eager_model_outputs.append(ins_batch_converted[i]) + outs, loss, _ = _process_single_batch( + eager_model_inputs, eager_model_outputs, model, training=True) + if not isinstance(outs, list): + outs = [outs] + _, metrics_results = _eager_metrics_fn( + model, outs, eager_model_outputs) + if not isinstance(loss, list): + loss = [loss] + return loss + metrics_results + + +def test_on_batch(model, ins): + """Calculates the loss for one input batch. + + Arguments: + model: Given model on which loss is calculated. + ins: Input and output batch numpy arrays. + + Returns: + total loss, loss and metrics associated with each output. + """ + ins_batch_converted = [] + for ib in ins: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + eager_model_inputs = [] + eager_model_outputs = [] + for i in range(len(model.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + for i in range(len(model.inputs), len(ins_batch_converted)): + eager_model_outputs.append(ins_batch_converted[i]) + outs, loss, loss_metrics = _process_single_batch( + eager_model_inputs, eager_model_outputs, model, training=False) + if not isinstance(outs, list): + outs = [outs] + metric_names, metrics_results = _eager_metrics_fn( + model, outs, eager_model_outputs) + model.metrics_names.append(metric_names) + if not isinstance(loss, list): + loss = [loss] + return loss + loss_metrics + metrics_results + + +def fit_loop( + model, + ins, + out_labels=None, + batch_size=None, + epochs=100, + verbose=1, + callbacks=None, + val_ins=None, + shuffle=True, + callback_metrics=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None): + """Abstract fit function for `f(ins)`. + + Assume that f returns a list, labeled by out_labels. + + Arguments: + model: Instance of the model that is being executed in Eager mode. + ins: List of tensors to be fed to `f` + out_labels: List of strings, display names of + the outputs of `f` + batch_size: Integer batch size or None if unknown. + epochs: Number of times to iterate over the data + verbose: Verbosity mode, 0, 1 or 2 + callbacks: List of callbacks to be called during training + val_ins: List of tensors to be fed to `val_f` + shuffle: Whether to shuffle the data at the beginning of each epoch + callback_metrics: List of strings, the display names of the metrics + passed to the callbacks. They should be the + concatenation of list the display names of the outputs of + `f` and the list of display names of the outputs of `f_val`. + initial_epoch: Epoch at which to start training + (useful for resuming a previous training run) + steps_per_epoch: Total number of steps (batches of samples) + before declaring one epoch finished and starting the + next epoch. Ignored with the default value of `None`. + validation_steps: Number of steps to run validation for (only if doing + validation from data tensors). Ignored with default value of `None`. + + Returns: + `History` object. + + Raises: + ValueError: In case of invalid argument values. + """ + # Required for Eager mode + K.set_learning_phase(True) + + do_validation = False + if val_ins: + do_validation = True + if (verbose and ins and hasattr(ins[0], 'shape') and + hasattr(val_ins[0], 'shape')): + print('Train on %d samples, validate on %d samples' % + (ins[0].shape[0], val_ins[0].shape[0])) + if validation_steps: + if steps_per_epoch is None: + raise ValueError('Can only use `validation_steps` when doing step-wise ' + 'training, i.e. `steps_per_epoch` must be set.') + do_validation = True + + num_train_samples = model._check_num_samples( + ins, batch_size, steps_per_epoch, 'steps_per_epoch') + + if num_train_samples is not None: + index_array = np.arange(num_train_samples) + + model.history = cbks.History() + callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history] + if verbose: + if steps_per_epoch is not None: + count_mode = 'steps' + else: + count_mode = 'samples' + callbacks += [cbks.ProgbarLogger(count_mode)] + callbacks = cbks.CallbackList(callbacks) + out_labels = out_labels or [] + + # it's possible to callback a different model than self + # (used by Sequential models) + if hasattr(model, 'callback_model') and model.callback_model: + callback_model = model.callback_model + else: + callback_model = model + + callbacks.set_model(callback_model) + + callbacks.set_params({ + 'batch_size': batch_size, + 'epochs': epochs, + 'steps': steps_per_epoch, + 'samples': num_train_samples, + 'verbose': verbose, + 'do_validation': do_validation, + 'metrics': callback_metrics or [], + }) + callbacks.on_train_begin() + callback_model.stop_training = False + for cbk in callbacks: + cbk.validation_data = val_ins + + for epoch in range(initial_epoch, epochs): + callbacks.on_epoch_begin(epoch) + epoch_logs = {} + if shuffle == 'batch': + index_array = model._batch_shuffle(index_array, batch_size) + elif shuffle: + np.random.shuffle(index_array) + + batches = make_batches(num_train_samples, batch_size) + + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + try: + if isinstance(ins[-1], float): + # Do not slice the training phase flag. + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + else: + ins_batch = slice_arrays(ins, batch_ids) + except TypeError: + raise TypeError('TypeError while preparing batch. ' + 'If using HDF5 input data, ' + 'pass shuffle="batch".') + batch_logs = {} + batch_logs['batch'] = batch_index + batch_logs['size'] = len(batch_ids) + + callbacks.on_batch_begin(batch_index, batch_logs) + + ins_batch_converted = [] + for ib in ins_batch: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + eager_model_inputs = [] + eager_model_outputs = [] + for i in range(len(model.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + + for i in range(len(model.inputs), len(ins_batch_converted)): + eager_model_outputs.append(ins_batch_converted[i]) + + outs, loss, loss_metrics = _process_single_batch(eager_model_inputs, + eager_model_outputs, + model, + training=True) + + if not isinstance(outs, list): + outs = [outs] + + for l, o in zip(out_labels, outs): + batch_logs[l] = o + # Required for Eager mode + metrics_names, metrics_results = _eager_metrics_fn(model, outs, + eager_model_outputs) + batch_logs['loss'] = tensor_util.constant_value(K.mean(loss)) + + # TODO(anjalisridhar): Move this to compile to avoid duplicate code. + # In graph mode we set the metric names in compile. However in + # Eager mode we calculate the metrics for each batch in fit_loop. + # We could calculate the metric names and functions in compile. + # This would avoid setting the callback parameters separately. + # We need to do this for the first iteration alone + for m in metrics_names: + if m not in callback_metrics: + callback_metrics.append(m) + + callbacks.set_params({ + 'batch_size': batch_size, + 'epochs': epochs, + 'steps': steps_per_epoch, + 'samples': num_train_samples, + 'verbose': verbose, + 'do_validation': do_validation, + 'metrics': callback_metrics or [], + }) + + for k, v in zip(model.metrics_names, + [K.mean(loss)] + loss_metrics + metrics_results): + batch_logs[k] = tensor_util.constant_value(v) + + callbacks.on_batch_end(batch_index, batch_logs) + if callback_model.stop_training: + break + + if batch_index == len(batches) - 1: # Last batch. + if do_validation: + val_outs = test_loop( + model, val_ins, batch_size=batch_size, verbose=0) + if not isinstance(val_outs, list): + val_outs = [val_outs] + # Same labels assumed. + for l, o in zip(out_labels, val_outs): + epoch_logs['val_' + l] = o + callbacks.on_epoch_end(epoch, epoch_logs) + if callback_model.stop_training: + break + callbacks.on_train_end() + return model.history + + +def test_loop(model, ins, batch_size=None, verbose=0, steps=None): + """Abstract method to loop over some data in batches. + + Arguments: + model: Model instance that is being evaluated in Eager mode. + ins: list of tensors to be fed to `f`. + batch_size: integer batch size or `None`. + verbose: verbosity mode. + steps: Total number of steps (batches of samples) + before declaring predictions finished. + Ignored with the default value of `None`. + + Returns: + Scalar loss (if the model has a single output and no metrics) + or list of scalars (if the model has multiple outputs + and/or metrics). The attribute `model.metrics_names` will give you + the display labels for the scalar outputs. + """ + K.set_learning_phase(False) + num_samples = model._check_num_samples(ins, batch_size, steps, 'steps') + outs = [] + if verbose == 1: + progbar = Progbar(target=num_samples) + batches = make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + if isinstance(ins[-1], float): + # Do not slice the training phase flag. + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + else: + ins_batch = slice_arrays(ins, batch_ids) + + ins_batch_converted = [] + for ib in ins_batch: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + + eager_model_inputs = [] + eager_model_outputs = [] + for i in range(len(model.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + + for i in range(len(model.inputs), len(ins_batch_converted)): + eager_model_outputs.append(ins_batch_converted[i]) + + loss_outs, loss, loss_metrics = _model_loss(model, eager_model_inputs, + eager_model_outputs, + training=False) + _, metrics_results = _eager_metrics_fn(model, loss_outs, + eager_model_outputs) + batch_outs = [] + for _, v in zip(model.metrics_names, + [K.mean(loss)] + loss_metrics + metrics_results): + batch_outs.append(tensor_util.constant_value(v)) + + if isinstance(batch_outs, list): + if batch_index == 0: + for batch_out in enumerate(batch_outs): + outs.append(0.) + for i, batch_out in enumerate(batch_outs): + outs[i] += batch_out * len(batch_ids) + else: + if batch_index == 0: + outs.append(0.) + outs[0] += batch_outs * len(batch_ids) + + if verbose == 1: + progbar.update(batch_end) + for i in range(len(outs)): + outs[i] /= num_samples + if len(outs) == 1: + return outs[0] + return outs + + +def predict_loop(model, ins, batch_size=32, verbose=0, steps=None): + """Abstract method to loop over some data in batches. + + Arguments: + model: + ins: list of tensors to be fed to `f`. + batch_size: integer batch size. + verbose: verbosity mode. + steps: Total number of steps (batches of samples) + before declaring `_predict_loop` finished. + Ignored with the default value of `None`. + + Returns: + Array of predictions (if the model has a single output) + or list of arrays of predictions + (if the model has multiple outputs). + """ + K.set_learning_phase(False) + num_samples = model._check_num_samples(ins, batch_size, steps, 'steps') + if verbose == 1: + if steps is not None: + progbar = Progbar(target=steps) + else: + progbar = Progbar(target=num_samples) + + outs = [] + batches = make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + if ins and isinstance(ins[-1], float): + # Do not slice the training phase flag. + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + else: + ins_batch = slice_arrays(ins, batch_ids) + + ins_batch_converted = [] + for ib in ins_batch: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + + eager_model_inputs = [] + for i in range(len(model.inputs)): + eager_model_inputs.append(ins_batch_converted[i]) + + if len(eager_model_inputs) == 1: + if model._expects_training_arg: + batch_outs = model.call(eager_model_inputs[0], training=False) + else: + batch_outs = model.call(eager_model_inputs[0]) + else: + if model._expects_training_arg: + batch_outs = model.call(eager_model_inputs, training=False) + else: + batch_outs = model.call(eager_model_inputs) + + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if batch_index == 0: + # Pre-allocate the results arrays. + for batch_out in batch_outs: + dims = batch_out.shape[1:].dims + dims_list = [d.value for d in dims] + shape = (num_samples,) + tuple(dims_list) + outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype)) + for i, batch_out in enumerate(batch_outs): + outs[i][batch_start:batch_end] = batch_out + if verbose == 1: + progbar.update(batch_end) + if len(outs) == 1: + return outs[0] + return outs diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py new file mode 100644 index 0000000000000000000000000000000000000000..45601f964a090fd927a22eb525d3c1c154fd71db --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py @@ -0,0 +1,756 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for training routines.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.keras._impl import keras +from tensorflow.python.keras._impl.keras import testing_utils +from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays +from tensorflow.python.platform import test +from tensorflow.python.training.rmsprop import RMSPropOptimizer + + +class TrainingTest(test.TestCase): + + def test_fit_on_arrays(self): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + loss_weights = [1., 0.5] + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + + # Test fit at different verbosity + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + batch_size=5, + verbose=0) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + batch_size=5, + verbose=1) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=2, + batch_size=5, + verbose=2) + + # Test with validation data + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + validation_data=([input_a_np, input_b_np], [output_d_np, + output_e_np]), + epochs=1, + batch_size=5, + verbose=0) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + validation_data=([input_a_np, input_b_np], [output_d_np, + output_e_np]), + epochs=2, + batch_size=5, + verbose=1) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + validation_data=([input_a_np, input_b_np], [output_d_np, + output_e_np]), + epochs=2, + batch_size=5, + verbose=2) + model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) + + # Test with validation split + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=2, + batch_size=5, + verbose=0, + validation_split=0.2) + + # Test with dictionary inputs + model.fit( + { + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}, + epochs=1, + batch_size=5, + verbose=0) + model.fit( + { + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}, + epochs=1, + batch_size=5, + verbose=1) + model.fit( + { + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}, + validation_data=({'input_a': input_a_np, + 'input_b': input_b_np + }, + { + 'dense': output_d_np, + 'dropout': output_e_np + }), + epochs=1, + batch_size=5, + verbose=0) + model.train_on_batch({ + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}) + # Test with lists for loss, metrics + loss = ['mae', 'mse'] + metrics = ['acc', 'mae'] + model.compile(optimizer, loss, metrics=metrics) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + batch_size=5, + verbose=0) + + # Test with dictionaries for loss, metrics, loss weights + loss = {'dense': 'mse', 'dropout': 'mae'} + loss_weights = {'dense': 1., 'dropout': 0.5} + metrics = {'dense': 'mse', 'dropout': 'mae'} + model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + batch_size=5, + verbose=0) + + # Invalid use cases + with self.assertRaises(AttributeError): + model.fit( + [input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + validation_data=([input_a_np, input_b_np], 0, 0), + verbose=0) + with self.assertRaises(ValueError): + model.train_on_batch({'input_a': input_a_np}, + [output_d_np, output_e_np]) + with self.assertRaises(ValueError): + model.train_on_batch([input_a_np], [output_d_np, output_e_np]) + with self.assertRaises(AttributeError): + model.train_on_batch(1, [output_d_np, output_e_np]) + with self.assertRaises(ValueError): + model.train_on_batch(input_a_np, [output_d_np, output_e_np]) + with self.assertRaises(ValueError): + bad_input = np.random.random((11, 3)) + model.train_on_batch([bad_input, input_b_np], + [output_d_np, output_e_np]) + with self.assertRaises(ValueError): + bad_target = np.random.random((11, 4)) + model.train_on_batch([input_a_np, input_b_np], + [bad_target, output_e_np]) + + # Build single-input model + x = keras.layers.Input(shape=(3,), name='input_a') + y = keras.layers.Dense(4)(x) + model = keras.models.Model(x, y) + model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse') + # This will work + model.fit([input_a_np], output_d_np, epochs=1) + with self.assertRaises(ValueError): + model.fit([input_a_np, input_a_np], output_d_np, epochs=1) + + def test_evaluate_predict_on_arrays(self): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + loss_weights = [1., 0.5] + metrics = ['mae'] + model.compile( + optimizer, + loss, + metrics=metrics, + loss_weights=loss_weights, + sample_weight_mode=None) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + + # Test evaluate at different verbosity + out = model.evaluate( + [input_a_np, input_b_np], [output_d_np, output_e_np], + batch_size=5, + verbose=0) + self.assertEqual(len(out), 5) + out = model.evaluate( + [input_a_np, input_b_np], [output_d_np, output_e_np], + batch_size=5, + verbose=1) + self.assertEqual(len(out), 5) + out = model.evaluate( + [input_a_np, input_b_np], [output_d_np, output_e_np], + batch_size=5, + verbose=2) + self.assertEqual(len(out), 5) + out = model.test_on_batch([input_a_np, input_b_np], + [output_d_np, output_e_np]) + self.assertEqual(len(out), 5) + + # Test evaluate with dictionary inputs + model.evaluate( + { + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}, + batch_size=5, + verbose=0) + model.evaluate( + { + 'input_a': input_a_np, + 'input_b': input_b_np + }, {'dense': output_d_np, + 'dropout': output_e_np}, + batch_size=5, + verbose=1) + + # Test predict + out = model.predict([input_a_np, input_b_np], batch_size=5) + self.assertEqual(len(out), 2) + out = model.predict({'input_a': input_a_np, 'input_b': input_b_np}) + self.assertEqual(len(out), 2) + out = model.predict_on_batch({ + 'input_a': input_a_np, + 'input_b': input_b_np + }) + self.assertEqual(len(out), 2) + + def test_invalid_loss_or_metrics(self): + num_classes = 5 + train_samples = 1000 + test_samples = 1000 + input_dim = 5 + + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, input_shape=(input_dim,))) + model.add(keras.layers.Activation('relu')) + model.add(keras.layers.Dense(num_classes)) + model.add(keras.layers.Activation('softmax')) + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001)) + np.random.seed(1337) + + (x_train, y_train), (_, _) = testing_utils.get_test_data( + train_samples=train_samples, + test_samples=test_samples, + input_shape=(input_dim,), + num_classes=num_classes) + + with self.assertRaises(ValueError): + model.fit(x_train, np.concatenate([y_train, y_train], axis=-1)) + + with self.assertRaises(TypeError): + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=set(0)) + + with self.assertRaises(ValueError): + model.compile(loss=None, + optimizer='rms') + + +class LossWeightingTest(test.TestCase): + + def test_class_weights(self): + num_classes = 5 + batch_size = 5 + epochs = 5 + weighted_class = 3 + train_samples = 3000 + test_samples = 3000 + input_dim = 5 + + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, input_shape=(input_dim,))) + model.add(keras.layers.Activation('relu')) + model.add(keras.layers.Dense(num_classes)) + model.add(keras.layers.Activation('softmax')) + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001)) + + np.random.seed(1337) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=train_samples, + test_samples=test_samples, + input_shape=(input_dim,), + num_classes=num_classes) + int_y_test = y_test.copy() + int_y_train = y_train.copy() + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + y_test = keras.utils.to_categorical(y_test, num_classes) + test_ids = np.where(int_y_test == np.array(weighted_class))[0] + + class_weight = dict([(i, 1.) for i in range(num_classes)]) + class_weight[weighted_class] = 2. + + sample_weight = np.ones((y_train.shape[0])) + sample_weight[int_y_train == weighted_class] = 2. + + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs // 3, + verbose=0, + class_weight=class_weight, + validation_data=(x_train, y_train, sample_weight)) + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs // 2, + verbose=0, + class_weight=class_weight) + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs // 2, + verbose=0, + class_weight=class_weight, + validation_split=0.1) + + model.train_on_batch( + x_train[:batch_size], y_train[:batch_size], class_weight=class_weight) + ref_score = model.evaluate(x_test, y_test, verbose=0) + score = model.evaluate( + x_test[test_ids, :], y_test[test_ids, :], verbose=0) + self.assertLess(score, ref_score) + + def test_sample_weights(self): + num_classes = 5 + batch_size = 5 + epochs = 5 + weighted_class = 3 + train_samples = 3000 + test_samples = 3000 + input_dim = 5 + + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, input_shape=(input_dim,))) + model.add(keras.layers.Activation('relu')) + model.add(keras.layers.Dense(num_classes)) + model.add(keras.layers.Activation('softmax')) + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001)) + + np.random.seed(43) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=train_samples, + test_samples=test_samples, + input_shape=(input_dim,), + num_classes=num_classes) + int_y_test = y_test.copy() + int_y_train = y_train.copy() + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + y_test = keras.utils.to_categorical(y_test, num_classes) + test_ids = np.where(int_y_test == np.array(weighted_class))[0] + + class_weight = dict([(i, 1.) for i in range(num_classes)]) + class_weight[weighted_class] = 2. + + sample_weight = np.ones((y_train.shape[0])) + sample_weight[int_y_train == weighted_class] = 2. + + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs // 3, + verbose=0, + sample_weight=sample_weight) + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs // 3, + verbose=0, + sample_weight=sample_weight, + validation_split=0.1) + model.train_on_batch( + x_train[:batch_size], + y_train[:batch_size], + sample_weight=sample_weight[:batch_size]) + model.test_on_batch( + x_train[:batch_size], + y_train[:batch_size], + sample_weight=sample_weight[:batch_size]) + + def test_temporal_sample_weights(self): + num_classes = 5 + weighted_class = 3 + train_samples = 1000 + test_samples = 1000 + input_dim = 5 + timesteps = 3 + + model = keras.models.Sequential() + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(num_classes), + input_shape=(timesteps, input_dim))) + model.add(keras.layers.Activation('softmax')) + + np.random.seed(1337) + (_, y_train), _ = testing_utils.get_test_data( + train_samples=train_samples, + test_samples=test_samples, + input_shape=(input_dim,), + num_classes=num_classes) + int_y_train = y_train.copy() + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + + class_weight = dict([(i, 1.) for i in range(num_classes)]) + class_weight[weighted_class] = 2. + + sample_weight = np.ones((y_train.shape[0])) + sample_weight[int_y_train == weighted_class] = 2. + with self.assertRaises(ValueError): + model.compile( + loss='binary_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001), + sample_weight_mode='temporal') + + def test_class_weight_invalid_use_case(self): + num_classes = 5 + train_samples = 1000 + test_samples = 1000 + input_dim = 5 + timesteps = 3 + + model = keras.models.Sequential() + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(num_classes), + input_shape=(timesteps, input_dim))) + model.add(keras.layers.Activation('softmax')) + model.compile( + loss='binary_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001)) + + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=train_samples, + test_samples=test_samples, + input_shape=(input_dim,), + num_classes=num_classes) + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + class_weight = dict([(i, 1.) for i in range(num_classes)]) + + del class_weight[1] + with self.assertRaises(ValueError): + model.fit(x_train, y_train, + epochs=0, verbose=0, class_weight=class_weight) + + with self.assertRaises(ValueError): + model.compile( + loss='binary_crossentropy', + optimizer=RMSPropOptimizer(learning_rate=0.001), + sample_weight_mode=[]) + + # Build multi-output model + x = keras.Input((3,)) + y1 = keras.layers.Dense(4, name='1')(x) + y2 = keras.layers.Dense(4, name='2')(x) + model = keras.models.Model(x, [y1, y2]) + model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse') + x_np = np.random.random((10, 3)) + y_np = np.random.random((10, 4)) + w_np = np.random.random((10,)) + # This will work + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': w_np}) + # These will not + with self.assertRaises(ValueError): + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=[w_np]) + with self.assertRaises(TypeError): + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=w_np) + with self.assertRaises(ValueError): + bad_w_np = np.random.random((11,)) + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) + with self.assertRaises(ValueError): + bad_w_np = np.random.random((10, 2)) + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) + with self.assertRaises(ValueError): + bad_w_np = np.random.random((10, 2, 2)) + model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) + + +class TestDynamicTrainability(test.TestCase): + + def test_trainable_warning(self): + x = np.random.random((5, 3)) + y = np.random.random((5, 2)) + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_dim=3)) + model.trainable = False + model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') + model.trainable = True + with self.assertRaises(ValueError): + model.train_on_batch(x, y) + + def test_trainable_argument(self): + x = np.random.random((5, 3)) + y = np.random.random((5, 2)) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_dim=3, trainable=False)) + model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') + out = model.predict(x) + with self.assertRaises(ValueError): + model.train_on_batch(x, y) + out_2 = model.predict(x) + self.assertAllClose(out, out_2) + + # test with nesting + inputs = keras.layers.Input(shape=(3,)) + output = model(inputs) + model = keras.models.Model(inputs, output) + model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') + out = model.predict(x) + with self.assertRaises(ValueError): + model.train_on_batch(x, y) + out_2 = model.predict(x) + self.assertAllClose(out, out_2) + + def test_layer_trainability_switch(self): + # with constructor argument, in Sequential + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, trainable=False, input_dim=1)) + self.assertListEqual(model.trainable_weights, []) + + # by setting the `trainable` argument, in Sequential + model = keras.models.Sequential() + layer = keras.layers.Dense(2, input_dim=1) + model.add(layer) + self.assertListEqual(model.trainable_weights, layer.trainable_weights) + layer.trainable = False + self.assertListEqual(model.trainable_weights, []) + + # with constructor argument, in Model + x = keras.layers.Input(shape=(1,)) + y = keras.layers.Dense(2, trainable=False)(x) + model = keras.models.Model(x, y) + self.assertListEqual(model.trainable_weights, []) + + # by setting the `trainable` argument, in Model + x = keras.layers.Input(shape=(1,)) + layer = keras.layers.Dense(2) + y = layer(x) + model = keras.models.Model(x, y) + self.assertListEqual(model.trainable_weights, layer.trainable_weights) + layer.trainable = False + self.assertListEqual(model.trainable_weights, []) + + def test_model_trainability_switch(self): + # a non-trainable model has no trainable weights + x = keras.layers.Input(shape=(1,)) + y = keras.layers.Dense(2)(x) + model = keras.models.Model(x, y) + model.trainable = False + self.assertListEqual(model.trainable_weights, []) + + # same for Sequential + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_dim=1)) + model.trainable = False + self.assertListEqual(model.trainable_weights, []) + + def test_nested_model_trainability(self): + + # a Sequential inside a Model + inner_model = keras.models.Sequential() + inner_model.add(keras.layers.Dense(2, input_dim=1)) + + x = keras.layers.Input(shape=(1,)) + y = inner_model(x) + outer_model = keras.models.Model(x, y) + self.assertListEqual(outer_model.trainable_weights, + inner_model.trainable_weights) + inner_model.trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + inner_model.trainable = True + inner_model.layers[-1].trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + + # a Sequential inside a Sequential + inner_model = keras.models.Sequential() + inner_model.add(keras.layers.Dense(2, input_dim=1)) + outer_model = keras.models.Sequential() + outer_model.add(inner_model) + self.assertListEqual(outer_model.trainable_weights, + inner_model.trainable_weights) + inner_model.trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + inner_model.trainable = True + inner_model.layers[-1].trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + + # a Model inside a Model + x = keras.layers.Input(shape=(1,)) + y = keras.layers.Dense(2)(x) + inner_model = keras.models.Model(x, y) + x = keras.layers.Input(shape=(1,)) + y = inner_model(x) + outer_model = keras.models.Model(x, y) + self.assertListEqual(outer_model.trainable_weights, + inner_model.trainable_weights) + inner_model.trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + inner_model.trainable = True + inner_model.layers[-1].trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + + # a Model inside a Sequential + x = keras.layers.Input(shape=(1,)) + y = keras.layers.Dense(2)(x) + inner_model = keras.models.Model(x, y) + outer_model = keras.models.Sequential() + outer_model.add(inner_model) + self.assertListEqual(outer_model.trainable_weights, + inner_model.trainable_weights) + inner_model.trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + inner_model.trainable = True + inner_model.layers[-1].trainable = False + self.assertListEqual(outer_model.trainable_weights, []) + + +class TestTrainingUtils(test.TestCase): + + def test_check_array_lengths(self): + keras.engine.training._check_array_lengths(None, None, None) + a_np = np.random.random((4, 3, 3)) + keras.engine.training._check_array_lengths(a_np, a_np, a_np) + keras.engine.training._check_array_lengths( + [a_np, a_np], [a_np, a_np], [a_np, a_np]) + keras.engine.training._check_array_lengths([None], [None], [None]) + + b_np = np.random.random((3, 4)) + with self.assertRaises(ValueError): + keras.engine.training._check_array_lengths(a_np, None, None) + with self.assertRaises(ValueError): + keras.engine.training._check_array_lengths(a_np, a_np, None) + with self.assertRaises(ValueError): + keras.engine.training._check_array_lengths([a_np], [None], None) + with self.assertRaises(ValueError): + keras.engine.training._check_array_lengths([a_np], [b_np], None) + with self.assertRaises(ValueError): + keras.engine.training._check_array_lengths([a_np], None, [b_np]) + + def test_slice_arrays(self): + input_a = np.random.random((10, 3)) + slice_arrays(None) + slice_arrays(input_a, 0) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) + input_a = [None, [1, 1], None, [1, 1]] + slice_arrays(input_a, 0) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) + input_a = [None] + slice_arrays(input_a, 0) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) + input_a = None + slice_arrays(input_a, 0) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) + + def test_fit_with_BatchNorm(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, input_dim=4)) + model.add(keras.layers.BatchNormalization()) + model.add(keras.layers.Activation('tanh')) + model.add(keras.layers.Dropout(0.2)) + + input_a_np = np.random.random((10, 4)) + output_b_np = np.random.random((10, 10)) + + model.compile(loss='binary_crossentropy', optimizer=RMSPropOptimizer(0.001)) + model.fit(input_a_np, output_b_np, epochs=1, batch_size=5, verbose=0) + + def test_fit_with_regularization(self): + model = keras.models.Sequential() + with self.assertRaises(ValueError): + model.add( + keras.layers.Dense(4, input_dim=3, + kernel_regularizer=keras.regularizers.l2(0.01), + activity_regularizer=keras.regularizers.l1(0.01))) + + +if __name__ == '__main__': + # Bazel sets these environment variables to very long paths. + # Tempfile uses them to create long paths, and in turn multiprocessing + # library tries to create sockets named after paths. Delete whatever bazel + # writes to these to avoid tests failing due to socket addresses being too + # long. + for var in ('TMPDIR', 'TMP', 'TEMP'): + if var in os.environ: + del os.environ[var] + + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py index 7650bfb6e80aa581f7c14f3c693106bcd6e73740..9651eb9f14f1275dc79c8d3b1fb54690772086a1 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py @@ -26,8 +26,14 @@ import numpy as np from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.keras._impl.keras.engine.training import _weighted_masked_objective +from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays from tensorflow.python.platform import test +try: + import scipy.sparse as scipy_sparse # pylint: disable=g-import-not-at-top +except ImportError: + scipy_sparse = None + class TrainingTest(test.TestCase): @@ -73,6 +79,14 @@ class TrainingTest(test.TestCase): verbose=2) model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) + # Test model with input data as a list of lists + model.fit( + [np.ndarray.tolist(input_a_np), np.ndarray.tolist(input_b_np)], + [output_d_np, output_e_np], + epochs=2, + batch_size=5, + verbose=2) + # Test with validation data model.fit( [input_a_np, input_b_np], [output_d_np, output_e_np], @@ -169,7 +183,7 @@ class TrainingTest(test.TestCase): with self.assertRaises(ValueError): model.train_on_batch({'input_a': input_a_np}, [output_d_np, output_e_np]) - with self.assertRaises(TypeError): + with self.assertRaises(AttributeError): model.fit( [input_a_np, input_b_np], [output_d_np, output_e_np], epochs=1, @@ -177,7 +191,7 @@ class TrainingTest(test.TestCase): verbose=0) with self.assertRaises(ValueError): model.train_on_batch([input_a_np], [output_d_np, output_e_np]) - with self.assertRaises(TypeError): + with self.assertRaises(AttributeError): model.train_on_batch(1, [output_d_np, output_e_np]) with self.assertRaises(ValueError): model.train_on_batch(input_a_np, [output_d_np, output_e_np]) @@ -200,6 +214,16 @@ class TrainingTest(test.TestCase): with self.assertRaises(ValueError): model.fit([input_a_np, input_a_np], output_d_np, epochs=1) + # Test model on a list of floats + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 4)) + + model.fit([np.ndarray.tolist(input_a_np)], + [np.ndarray.tolist(input_b_np)], + epochs=2, + batch_size=5, + verbose=2) + def test_evaluate_predict_on_arrays(self): with self.test_session(): a = keras.layers.Input(shape=(3,), name='input_a') @@ -312,6 +336,63 @@ class TrainingTest(test.TestCase): model.compile(loss=None, optimizer='rmsprop') + def test_training_on_sparse_data_with_dense_placeholders(self): + if scipy_sparse is None: + return + + test_inputs = [ + scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)] + test_outputs = [ + scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)] + in1 = keras.layers.Input(shape=(3,)) + in2 = keras.layers.Input(shape=(3,)) + out1 = keras.layers.Dropout(0.5, name='dropout')(in1) + out2 = keras.layers.Dense(4, name='dense_1')(in2) + model = keras.Model([in1, in2], [out1, out2]) + model.predict(test_inputs, batch_size=2) + model.compile('rmsprop', 'mse') + model.fit(test_inputs, test_outputs, + epochs=1, batch_size=2, validation_split=0.5) + model.evaluate(test_inputs, test_outputs, batch_size=2) + + def test_that_trainable_disables_updates(self): + val_a = np.random.random((10, 4)) + val_out = np.random.random((10, 4)) + + with self.test_session(): + a = keras.layers.Input(shape=(4,)) + layer = keras.layers.BatchNormalization(input_shape=(4,)) + b = layer(a) + model = keras.Model(a, b) + + model.trainable = False + assert not model.updates + + model.compile('sgd', 'mse') + assert not model.updates + + x1 = model.predict(val_a) + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + self.assertAllClose(x1, x2, atol=1e-7) + + model.trainable = True + model.compile('sgd', 'mse') + assert model.updates + + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + assert np.abs(np.sum(x1 - x2)) > 1e-5 + + layer.trainable = False + model.compile('sgd', 'mse') + assert not model.updates + + x1 = model.predict(val_a) + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + self.assertAllClose(x1, x2, atol=1e-7) + class LossWeightingTest(test.TestCase): @@ -869,25 +950,6 @@ class TestGeneratorMethods(test.TestCase): use_multiprocessing=False, workers=0) - # Test legacy API - model.fit_generator(custom_generator(), - steps_per_epoch=5, - epochs=1, - verbose=1, - max_q_size=10, - workers=4, - pickle_safe=True) - model.predict_generator(custom_generator(), - steps=5, - max_q_size=10, - workers=2, - pickle_safe=True) - model.evaluate_generator(custom_generator(), - steps=5, - max_q_size=10, - workers=2, - pickle_safe=True) - def test_generator_methods_with_sample_weights(self): arr_data = np.random.random((50, 2)) arr_labels = np.random.random((50,)) @@ -960,7 +1022,7 @@ class TestGeneratorMethods(test.TestCase): use_multiprocessing=False, validation_data=custom_generator(), validation_steps=10) - with self.assertRaises(TypeError): + with self.assertRaises(AttributeError): model.predict_generator(custom_generator(), steps=5, max_queue_size=10, @@ -996,22 +1058,22 @@ class TestTrainingUtils(test.TestCase): def test_slice_arrays(self): input_a = np.random.random((10, 3)) - keras.engine.training._slice_arrays(None) - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) + slice_arrays(input_a, 0) + slice_arrays(None) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) input_a = [None, [1, 1], None, [1, 1]] - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) + slice_arrays(input_a, 0) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) input_a = [None] - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) + slice_arrays(input_a, 0) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) input_a = None - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) + slice_arrays(input_a, 0) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) class TestTrainingWithDataTensors(test.TestCase): diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py index 624e92a04b8860d9a3974f2edb4a443482958259..db0140c2df4d20f9e18e6c1401c6c6aa197bcf1f 100644 --- a/tensorflow/python/keras/_impl/keras/estimator.py +++ b/tensorflow/python/keras/_impl/keras/estimator.py @@ -37,6 +37,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util +from tensorflow.python.util.tf_export import tf_export _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -244,6 +245,7 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects, saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt')) +@tf_export('keras.estimator.model_to_estimator') def model_to_estimator(keras_model=None, keras_model_path=None, custom_objects=None, diff --git a/tensorflow/python/keras/_impl/keras/initializers.py b/tensorflow/python/keras/_impl/keras/initializers.py index 8752faa534a3d6094ce530e490571ff939f86dbb..300bed5e1437074d010760c427c14f68e58ac363 100644 --- a/tensorflow/python/keras/_impl/keras/initializers.py +++ b/tensorflow/python/keras/_impl/keras/initializers.py @@ -32,8 +32,10 @@ from tensorflow.python.ops.init_ops import RandomUniform from tensorflow.python.ops.init_ops import TruncatedNormal from tensorflow.python.ops.init_ops import VarianceScaling from tensorflow.python.ops.init_ops import Zeros +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.initializers.lecun_normal') def lecun_normal(seed=None): """LeCun normal initializer. @@ -56,6 +58,7 @@ def lecun_normal(seed=None): scale=1., mode='fan_in', distribution='normal', seed=seed) +@tf_export('keras.initializers.lecun_uniform') def lecun_uniform(seed=None): """LeCun uniform initializer. @@ -77,6 +80,7 @@ def lecun_uniform(seed=None): scale=1., mode='fan_in', distribution='uniform', seed=seed) +@tf_export('keras.initializers.glorot_normal') def glorot_normal(seed=None): """Glorot normal initializer, also called Xavier normal initializer. @@ -99,6 +103,7 @@ def glorot_normal(seed=None): scale=1., mode='fan_avg', distribution='normal', seed=seed) +@tf_export('keras.initializers.glorot_uniform') def glorot_uniform(seed=None): """Glorot uniform initializer, also called Xavier uniform initializer. @@ -121,6 +126,7 @@ def glorot_uniform(seed=None): scale=1., mode='fan_avg', distribution='uniform', seed=seed) +@tf_export('keras.initializers.he_normal') def he_normal(seed=None): """He normal initializer. @@ -141,6 +147,7 @@ def he_normal(seed=None): scale=2., mode='fan_in', distribution='normal', seed=seed) +@tf_export('keras.initializers.he_uniform') def he_uniform(seed=None): """He uniform variance scaling initializer. @@ -178,10 +185,12 @@ orthogonal = Orthogonal # Utility functions +@tf_export('keras.initializers.serialize') def serialize(initializer): return serialize_keras_object(initializer) +@tf_export('keras.initializers.deserialize') def deserialize(config, custom_objects=None): return deserialize_keras_object( config, @@ -190,6 +199,7 @@ def deserialize(config, custom_objects=None): printable_module_name='initializer') +@tf_export('keras.initializers.get') def get(identifier): if isinstance(identifier, dict): return deserialize(identifier) @@ -199,4 +209,5 @@ def get(identifier): elif callable(identifier): return identifier else: - raise ValueError('Could not interpret initializer identifier:', identifier) + raise ValueError('Could not interpret initializer identifier: ' + + str(identifier)) diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py index e4b9afd38aa21924693f32b5d0fdf64a97019bce..7cac17c51a9adcf8fc62154b6633de60bab18387 100644 --- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py @@ -14,20 +14,22 @@ # ============================================================================== """Layers that act as activation functions. """ - from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras._impl.keras import activations from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.LeakyReLU') class LeakyReLU(Layer): """Leaky version of a Rectified Linear Unit. @@ -61,10 +63,12 @@ class LeakyReLU(Layer): base_config = super(LeakyReLU, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @shape_type_conversion def compute_output_shape(self, input_shape): return input_shape +@tf_export('keras.layers.PReLU') class PReLU(Layer): """Parametric Rectified Linear Unit. @@ -114,9 +118,9 @@ class PReLU(Layer): else: self.shared_axes = list(shared_axes) + @shape_type_conversion def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() - param_shape = input_shape[1:] + param_shape = list(input_shape[1:]) self.param_broadcast = [False] * len(param_shape) if self.shared_axes is not None: for i in self.shared_axes: @@ -140,15 +144,13 @@ class PReLU(Layer): def call(self, inputs, mask=None): pos = K.relu(inputs) if K.backend() == 'theano': - neg = (K.pattern_broadcast(self.alpha, self.param_broadcast) * - (inputs - K.abs(inputs)) * 0.5) + neg = ( + K.pattern_broadcast(self.alpha, self.param_broadcast) * + (inputs - K.abs(inputs)) * 0.5) else: neg = -self.alpha * K.relu(-inputs) return pos + neg - def compute_output_shape(self, input_shape): - return input_shape - def get_config(self): config = { 'alpha_initializer': initializers.serialize(self.alpha_initializer), @@ -159,7 +161,12 @@ class PReLU(Layer): base_config = super(PReLU, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @shape_type_conversion + def compute_output_shape(self, input_shape): + return input_shape + +@tf_export('keras.layers.ELU') class ELU(Layer): """Exponential Linear Unit. @@ -188,15 +195,17 @@ class ELU(Layer): def call(self, inputs): return K.elu(inputs, self.alpha) - def compute_output_shape(self, input_shape): - return input_shape - def get_config(self): config = {'alpha': float(self.alpha)} base_config = super(ELU, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @shape_type_conversion + def compute_output_shape(self, input_shape): + return input_shape + +@tf_export('keras.layers.ThresholdedReLU') class ThresholdedReLU(Layer): """Thresholded Rectified Linear Unit. @@ -223,12 +232,47 @@ class ThresholdedReLU(Layer): self.theta = K.cast_to_floatx(theta) def call(self, inputs, mask=None): - return inputs * K.cast(inputs > self.theta, K.floatx()) + return inputs * K.cast(K.greater(inputs, self.theta), K.floatx()) + + def get_config(self): + config = {'theta': float(self.theta)} + base_config = super(ThresholdedReLU, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + @shape_type_conversion def compute_output_shape(self, input_shape): return input_shape + +@tf_export('keras.layers.Softmax') +class Softmax(Layer): + """Softmax activation function. + + Input shape: + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. + + Output shape: + Same shape as the input. + + Arguments: + axis: Integer, axis along which the softmax normalization is applied. + """ + + def __init__(self, axis=-1, **kwargs): + super(Softmax, self).__init__(**kwargs) + self.supports_masking = True + self.axis = axis + + def call(self, inputs): + return activations.softmax(inputs, axis=self.axis) + def get_config(self): - config = {'theta': float(self.theta)} - base_config = super(ThresholdedReLU, self).get_config() + config = {'axis': self.axis} + base_config = super(Softmax, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + @shape_type_conversion + def compute_output_shape(self, input_shape): + return input_shape diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py index 91efab30edf99901b25dc0085b7d49e70d1b6d6d..343b7949accf3f0c9ddc5245910aa5faad8335c6 100644 --- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py @@ -56,6 +56,12 @@ class AdvancedActivationsTest(test.TestCase): kwargs={'theta': 0.5}, input_shape=(2, 3, 4)) + def test_softmax(self): + with self.test_session(): + testing_utils.layer_test(keras.layers.Softmax, + kwargs={'axis': 1}, + input_shape=(2, 3, 4)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py index f0f5e1fb463b828afe0f5bc19369408c98b57a08..162ae6c28f1afae1dd8aaf70213b808d9ad9598f 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py @@ -38,8 +38,10 @@ from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D # pylint: enable=unused-import from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.layers import convolutional as tf_convolutional_layers +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.Conv1D', 'keras.layers.Convolution1D') class Conv1D(tf_convolutional_layers.Conv1D, Layer): """1D convolution layer (e.g. temporal convolution). @@ -58,7 +60,7 @@ class Conv1D(tf_convolutional_layers.Conv1D, Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of a single integer, specifying the length of the 1D convolution window. strides: An integer or tuple/list of a single integer, @@ -153,6 +155,7 @@ class Conv1D(tf_convolutional_layers.Conv1D, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Conv2D', 'keras.layers.Convolution2D') class Conv2D(tf_convolutional_layers.Conv2D, Layer): """2D convolution layer (e.g. spatial convolution over images). @@ -170,7 +173,7 @@ class Conv2D(tf_convolutional_layers.Conv2D, Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the width and height of the 2D convolution window. Can be a single integer to specify the same value for @@ -286,6 +289,7 @@ class Conv2D(tf_convolutional_layers.Conv2D, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Conv3D', 'keras.layers.Convolution3D') class Conv3D(tf_convolutional_layers.Conv3D, Layer): """3D convolution layer (e.g. spatial convolution over volumes). @@ -304,7 +308,7 @@ class Conv3D(tf_convolutional_layers.Conv3D, Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 3 integers, specifying the depth, height and width of the 3D convolution window. Can be a single integer to specify the same value for @@ -426,6 +430,8 @@ class Conv3D(tf_convolutional_layers.Conv3D, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Conv2DTranspose', + 'keras.layers.Convolution2DTranspose') class Conv2DTranspose(tf_convolutional_layers.Conv2DTranspose, Layer): """Transposed convolution layer (sometimes called Deconvolution). @@ -563,6 +569,8 @@ class Conv2DTranspose(tf_convolutional_layers.Conv2DTranspose, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Conv3DTranspose', + 'keras.layers.Convolution3DTranspose') class Conv3DTranspose(tf_convolutional_layers.Conv3DTranspose, Layer): """Transposed convolution layer (sometimes called Deconvolution). @@ -711,6 +719,148 @@ class Conv3DTranspose(tf_convolutional_layers.Conv3DTranspose, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.SeparableConv1D', + 'keras.layers.SeparableConvolution1D') +class SeparableConv1D(tf_convolutional_layers.SeparableConv1D, Layer): + """Depthwise separable 1D convolution. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. + If `use_bias` is True and a bias initializer is provided, + it adds a bias vector to the output. + It then optionally applies an activation function to produce the final output. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A single integer specifying the spatial + dimensions of the filters. + strides: A single integer specifying the strides + of the convolution. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: A single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + depth_multiplier: The number of depthwise convolution output channels for + each input channel. The total number of depthwise convolution output + channels will be equal to `num_filters_in * depth_multiplier`. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + depthwise_initializer: An initializer for the depthwise convolution kernel. + pointwise_initializer: An initializer for the pointwise convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used for + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + """ + + def __init__(self, + filters, + kernel_size, + strides=1, + padding='valid', + data_format=None, + dilation_rate=1, + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer='glorot_uniform', + pointwise_initializer='glorot_uniform', + bias_initializer='zeros', + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + **kwargs): + if data_format is None: + data_format = K.image_data_format() + super(SeparableConv1D, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activations.get(activation), + use_bias=use_bias, + depthwise_initializer=initializers.get(depthwise_initializer), + pointwise_initializer=initializers.get(pointwise_initializer), + bias_initializer=initializers.get(bias_initializer), + depthwise_regularizer=regularizers.get(depthwise_regularizer), + pointwise_regularizer=regularizers.get(pointwise_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + depthwise_constraint=constraints.get(depthwise_constraint), + pointwise_constraint=constraints.get(pointwise_constraint), + bias_constraint=constraints.get(bias_constraint), + **kwargs) + + def get_config(self): + config = { + 'filters': self.filters, + 'kernel_size': self.kernel_size, + 'strides': self.strides, + 'padding': self.padding, + 'data_format': self.data_format, + 'dilation_rate': self.dilation_rate, + 'activation': activations.serialize(self.activation), + 'use_bias': self.use_bias, + 'depthwise_initializer': + initializers.serialize(self.depthwise_initializer), + 'pointwise_initializer': + initializers.serialize(self.pointwise_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'depthwise_regularizer': + regularizers.serialize(self.depthwise_regularizer), + 'pointwise_regularizer': + regularizers.serialize(self.pointwise_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), + 'activity_regularizer': + regularizers.serialize(self.activity_regularizer), + 'depthwise_constraint': + constraints.serialize(self.depthwise_constraint), + 'pointwise_constraint': + constraints.serialize(self.pointwise_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint) + } + base_config = super(SeparableConv1D, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@tf_export('keras.layers.SeparableConv2D', + 'keras.layers.SeparableConvolution2D') class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): """Depthwise separable 2D convolution. @@ -727,7 +877,7 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the width and height of the 2D convolution window. Can be a single integer to specify the same value for @@ -874,6 +1024,7 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.UpSampling1D') class UpSampling1D(Layer): """Upsampling layer for 1D inputs. @@ -909,6 +1060,7 @@ class UpSampling1D(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.UpSampling2D') class UpSampling2D(Layer): """Upsampling layer for 2D inputs. @@ -976,6 +1128,7 @@ class UpSampling2D(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.UpSampling3D') class UpSampling3D(Layer): """Upsampling layer for 3D inputs. @@ -1048,6 +1201,7 @@ class UpSampling3D(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.ZeroPadding1D') class ZeroPadding1D(Layer): """Zero-padding layer for 1D input (e.g. temporal sequence). @@ -1088,6 +1242,7 @@ class ZeroPadding1D(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.ZeroPadding2D') class ZeroPadding2D(Layer): """Zero-padding layer for 2D input (e.g. picture). @@ -1189,6 +1344,7 @@ class ZeroPadding2D(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.ZeroPadding3D') class ZeroPadding3D(Layer): """Zero-padding layer for 3D data (spatial or spatio-temporal). @@ -1306,6 +1462,7 @@ class ZeroPadding3D(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Cropping1D') class Cropping1D(Layer): """Cropping layer for 1D input (e.g. temporal sequence). @@ -1350,6 +1507,7 @@ class Cropping1D(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Cropping2D') class Cropping2D(Layer): """Cropping layer for 2D input (e.g. picture). @@ -1481,6 +1639,7 @@ class Cropping2D(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Cropping3D') class Cropping3D(Layer): """Cropping layer for 3D data (e.g. @@ -1663,6 +1822,7 @@ class Cropping3D(Layer): Convolution1D = Conv1D Convolution2D = Conv2D Convolution3D = Conv3D +SeparableConvolution1D = SeparableConv1D SeparableConvolution2D = SeparableConv2D Convolution2DTranspose = Conv2DTranspose Convolution3DTranspose = Conv3DTranspose diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py index 4f0e9fc691133ae7f9a7834e17379cb8e25a8a2c..d2792b9636214d21e9658018f853fb6c0808abb4 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py @@ -20,15 +20,16 @@ from __future__ import print_function import numpy as np -from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import activations from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec +from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent from tensorflow.python.keras._impl.keras.utils import conv_utils +from tensorflow.python.util.tf_export import tf_export class ConvRecurrent2D(Recurrent): @@ -38,7 +39,7 @@ class ConvRecurrent2D(Recurrent): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of n integers, specifying the dimensions of the convolution window. strides: An integer or tuple/list of n integers, @@ -127,10 +128,10 @@ class ConvRecurrent2D(Recurrent): self.input_spec = [InputSpec(ndim=5)] self.state_spec = None + @shape_type_conversion def compute_output_shape(self, input_shape): if isinstance(input_shape, list): input_shape = input_shape[0] - input_shape = tensor_shape.TensorShape(input_shape).as_list() if self.data_format == 'channels_first': rows = input_shape[3] cols = input_shape[4] @@ -151,30 +152,28 @@ class ConvRecurrent2D(Recurrent): dilation=self.dilation_rate[1]) if self.return_sequences: if self.data_format == 'channels_first': - output_shape = [input_shape[0], input_shape[1], - self.filters, rows, cols] + output_shape = (input_shape[0], input_shape[1], self.filters, rows, + cols) elif self.data_format == 'channels_last': - output_shape = [input_shape[0], input_shape[1], - rows, cols, self.filters] + output_shape = (input_shape[0], input_shape[1], rows, cols, + self.filters) else: if self.data_format == 'channels_first': - output_shape = [input_shape[0], self.filters, rows, cols] + output_shape = (input_shape[0], self.filters, rows, cols) elif self.data_format == 'channels_last': - output_shape = [input_shape[0], rows, cols, self.filters] + output_shape = (input_shape[0], rows, cols, self.filters) if self.return_state: if self.data_format == 'channels_first': - output_shapes = [output_shape] + [(input_shape[0], - self.filters, - rows, - cols) for _ in range(2)] + output_shape = [output_shape] + [ + (input_shape[0], self.filters, rows, cols) for _ in range(2) + ] elif self.data_format == 'channels_last': - output_shapes = [output_shape] + [(input_shape[0], - rows, - cols, - self.filters) for _ in range(2)] - return [tensor_shape.TensorShape(shape) for shape in output_shapes] - return tensor_shape.TensorShape(output_shape) + output_shape = [output_shape] + [ + (input_shape[0], rows, cols, self.filters) for _ in range(2) + ] + + return output_shape def get_config(self): config = { @@ -192,6 +191,7 @@ class ConvRecurrent2D(Recurrent): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.ConvLSTM2D') class ConvLSTM2D(ConvRecurrent2D): """Convolutional LSTM. @@ -200,7 +200,7 @@ class ConvLSTM2D(ConvRecurrent2D): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of n integers, specifying the dimensions of the convolution window. strides: An integer or tuple/list of n integers, @@ -294,11 +294,6 @@ class ConvLSTM2D(ConvRecurrent2D): Raises: ValueError: in case of invalid constructor arguments. - References: - - [Convolutional LSTM Network: A Machine Learning Approach for - Precipitation Nowcasting](http://arxiv.org/abs/1506.04214v1) - The current implementation does not include the feedback loop on the - cells output """ def __init__(self, @@ -338,7 +333,6 @@ class ConvLSTM2D(ConvRecurrent2D): return_sequences=return_sequences, go_backwards=go_backwards, stateful=stateful, - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) self.activation = activations.get(activation) self.recurrent_activation = activations.get(recurrent_activation) @@ -352,6 +346,7 @@ class ConvLSTM2D(ConvRecurrent2D): self.kernel_regularizer = regularizers.get(kernel_regularizer) self.recurrent_regularizer = regularizers.get(recurrent_regularizer) self.bias_regularizer = regularizers.get(bias_regularizer) + self.activity_regularizer = regularizers.get(activity_regularizer) self.kernel_constraint = constraints.get(kernel_constraint) self.recurrent_constraint = constraints.get(recurrent_constraint) @@ -361,13 +356,12 @@ class ConvLSTM2D(ConvRecurrent2D): self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.state_spec = [InputSpec(ndim=4), InputSpec(ndim=4)] + @shape_type_conversion def build(self, input_shape): if isinstance(input_shape, list): input_shape = input_shape[0] - input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) batch_size = input_shape[0] if self.stateful else None self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:]) - if self.stateful: self.reset_states() else: diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py index be7da6f2b409aa57e3f1328441f0e37ede924c11..4a6228121b4f8839daa98e35748b2c5867ccca96 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy + import numpy as np from tensorflow.python.keras._impl import keras @@ -27,45 +29,39 @@ from tensorflow.python.platform import test class Convolution1DTest(test.TestCase): - def test_dilated_conv1d(self): - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Conv1D, - input_data=np.reshape(np.arange(4, dtype='float32'), (1, 4, 1)), - kwargs={ - 'filters': 1, - 'kernel_size': 2, - 'dilation_rate': 1, - 'padding': 'valid', - 'kernel_initializer': 'ones', - 'use_bias': False, - }, - expected_output=[[[1], [3], [5]]]) - - def test_conv_1d(self): - batch_size = 2 - steps = 8 - input_dim = 2 - kernel_size = 3 - filters = 3 + def _run_test(self, kwargs, arg, values): + num_samples = 2 + stack_size = 3 + length = 7 - for padding in ['valid', 'same']: - for strides in [1, 2]: - if padding == 'same' and strides != 1: - continue + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.Conv1D, + kwargs=test_kwargs, + input_shape=(num_samples, length, stack_size)) + + def test_conv1d(self): + kwargs = { + 'filters': 2, + 'kernel_size': 3, + } + + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [2]) + self._run_test(kwargs, 'dilation_rate', [2]) - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Conv1D, - kwargs={ - 'filters': filters, - 'kernel_size': kernel_size, - 'padding': padding, - 'strides': strides - }, - input_shape=(batch_size, steps, input_dim)) - - def test_conv_1d_regularizers(self): + kwargs = { + 'filters': 2, + 'kernel_size': 3, + 'padding': 'same', + } + self._run_test(kwargs, 'dilation_rate', [2]) + self._run_test(kwargs, 'dilation_rate', [3]) + + def test_conv1d_regularizers(self): kwargs = { 'filters': 3, 'kernel_size': 3, @@ -82,7 +78,7 @@ class Convolution1DTest(test.TestCase): layer(keras.backend.variable(np.ones((1, 5, 2)))) self.assertEqual(len(layer.losses), 3) - def test_conv_1d_constraints(self): + def test_conv1d_constraints(self): k_constraint = lambda x: x b_constraint = lambda x: x @@ -103,35 +99,43 @@ class Convolution1DTest(test.TestCase): class Conv2DTest(test.TestCase): - def test_convolution_2d(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 2 stack_size = 3 - kernel_size = (3, 2) num_row = 7 num_col = 6 - for padding in ['valid', 'same']: - for strides in [(1, 1), (2, 2)]: - if padding == 'same' and strides != (1, 1): - continue + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.SeparableConv2D, + kwargs=test_kwargs, + input_shape=(num_samples, num_row, num_col, stack_size)) + + def test_conv2d(self): + kwargs = { + 'filters': 2, + 'kernel_size': (3, 3), + } - with self.test_session(use_gpu=True): - # Only runs on GPU with CUDA, channels_first is not supported on CPU. - # TODO(b/62340061): Support channels_first on CPU. - if test.is_gpu_available(cuda_only=True): - testing_utils.layer_test( - keras.layers.Conv2D, - kwargs={ - 'filters': filters, - 'kernel_size': kernel_size, - 'padding': padding, - 'strides': strides, - 'data_format': 'channels_first' - }, - input_shape=(num_samples, stack_size, num_row, num_col)) - - def test_convolution_2d_regularizers(self): + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [(2, 2)]) + if test.is_gpu_available(cuda_only=True): + # Only runs on GPU with CUDA, channels_first is not supported on CPU. + # TODO(b/62340061): Support channels_first on CPU. + self._run_test(kwargs, 'data_format', ['channels_first']) + self._run_test(kwargs, 'dilation_rate', [(2, 2)]) + + kwargs = { + 'filters': 2, + 'kernel_size': 3, + 'padding': 'same', + } + self._run_test(kwargs, 'dilation_rate', [2]) + + def test_conv2d_regularizers(self): kwargs = { 'filters': 3, 'kernel_size': 3, @@ -148,7 +152,7 @@ class Conv2DTest(test.TestCase): layer(keras.backend.variable(np.ones((1, 5, 5, 2)))) self.assertEqual(len(layer.losses), 3) - def test_convolution_2d_constraints(self): + def test_conv2d_constraints(self): k_constraint = lambda x: x b_constraint = lambda x: x @@ -166,51 +170,34 @@ class Conv2DTest(test.TestCase): self.assertEqual(layer.kernel.constraint, k_constraint) self.assertEqual(layer.bias.constraint, b_constraint) - def test_dilated_conv_2d(self): - num_samples = 2 - filters = 2 - stack_size = 3 - kernel_size = (3, 2) - num_row = 7 - num_col = 6 - - # Test dilation - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Conv2D, - kwargs={ - 'filters': filters, - 'kernel_size': kernel_size, - 'dilation_rate': (2, 2) - }, - input_shape=(num_samples, num_row, num_col, stack_size)) - class Conv2DTransposeTest(test.TestCase): - def test_conv2d_transpose(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 2 stack_size = 3 - num_row = 5 + num_row = 7 num_col = 6 - for padding in ['valid', 'same']: - for strides in [(1, 1), (2, 2)]: - if padding == 'same' and strides != (1, 1): - continue + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.Conv2DTranspose, + kwargs=test_kwargs, + input_shape=(num_samples, num_row, num_col, stack_size)) + + def test_conv2dtranspose(self): + kwargs = { + 'filters': 2, + 'kernel_size': (3, 3), + } - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Conv2DTranspose, - kwargs={ - 'filters': filters, - 'kernel_size': 3, - 'padding': padding, - 'strides': strides, - 'data_format': 'channels_last' - }, - input_shape=(num_samples, num_row, num_col, stack_size)) + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [(2, 2)]) + if test.is_gpu_available(cuda_only=True): + self._run_test(kwargs, 'data_format', ['channels_first']) def test_conv2dtranspose_regularizers(self): kwargs = { @@ -250,30 +237,32 @@ class Conv2DTransposeTest(test.TestCase): class Conv3DTransposeTest(test.TestCase): - def test_conv3d_transpose(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 2 stack_size = 3 - num_row = 5 + num_row = 7 num_col = 6 - depth = 4 + depth = 5 + + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.Conv3DTranspose, + kwargs=test_kwargs, + input_shape=(num_samples, depth, num_row, num_col, stack_size)) - for padding in ['valid', 'same']: - for strides in [(1, 1, 1), (2, 2, 2)]: - if padding == 'same' and strides != (1, 1, 1): - continue + def test_conv3dtranspose(self): + kwargs = { + 'filters': 2, + 'kernel_size': (3, 3, 3), + } - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Conv3DTranspose, - kwargs={ - 'filters': filters, - 'kernel_size': 3, - 'padding': padding, - 'strides': strides, - 'data_format': 'channels_last' - }, - input_shape=(num_samples, depth, num_row, num_col, stack_size)) + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [(2, 2, 2)]) + if test.is_gpu_available(cuda_only=True): + self._run_test(kwargs, 'data_format', ['channels_first']) def test_conv3dtranspose_regularizers(self): kwargs = { @@ -311,32 +300,116 @@ class Conv3DTransposeTest(test.TestCase): self.assertEqual(layer.bias.constraint, b_constraint) +class SeparableConv1DTest(test.TestCase): + + def _run_test(self, kwargs, arg, values): + num_samples = 2 + stack_size = 3 + length = 7 + + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.SeparableConv1D, + kwargs=test_kwargs, + input_shape=(num_samples, length, stack_size)) + + def test_separable_conv1d(self): + kwargs = { + 'filters': 2, + 'kernel_size': 3, + } + + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [2]) + self._run_test(kwargs, 'dilation_rate', [2]) + self._run_test(kwargs, 'depth_multiplier', [2]) + + kwargs = { + 'filters': 2, + 'kernel_size': 3, + 'padding': 'same', + } + self._run_test(kwargs, 'dilation_rate', [2]) + + def test_separable_conv1d_regularizers(self): + kwargs = { + 'filters': 3, + 'kernel_size': 3, + 'padding': 'valid', + 'depthwise_regularizer': 'l2', + 'pointwise_regularizer': 'l2', + 'bias_regularizer': 'l2', + 'activity_regularizer': 'l2', + 'strides': 1 + } + with self.test_session(use_gpu=True): + layer = keras.layers.SeparableConv1D(**kwargs) + layer.build((None, 5, 2)) + self.assertEqual(len(layer.losses), 3) + layer(keras.backend.variable(np.ones((1, 5, 2)))) + self.assertEqual(len(layer.losses), 4) + + def test_separable_conv1d_constraints(self): + d_constraint = lambda x: x + p_constraint = lambda x: x + b_constraint = lambda x: x + + kwargs = { + 'filters': 3, + 'kernel_size': 3, + 'padding': 'valid', + 'pointwise_constraint': p_constraint, + 'depthwise_constraint': d_constraint, + 'bias_constraint': b_constraint, + 'strides': 1 + } + with self.test_session(use_gpu=True): + layer = keras.layers.SeparableConv1D(**kwargs) + layer.build((None, 5, 2)) + self.assertEqual(layer.depthwise_kernel.constraint, d_constraint) + self.assertEqual(layer.pointwise_kernel.constraint, p_constraint) + self.assertEqual(layer.bias.constraint, b_constraint) + + class SeparableConv2DTest(test.TestCase): - def test_separable_conv_2d(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 6 stack_size = 3 num_row = 7 num_col = 6 - for padding in ['valid', 'same']: - for strides in [(1, 1), (2, 2)]: - for multiplier in [1, 2]: - if padding == 'same' and strides != (1, 1): - continue + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.SeparableConv2D, + kwargs=test_kwargs, + input_shape=(num_samples, num_row, num_col, stack_size)) - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.SeparableConv2D, - kwargs={ - 'filters': filters, - 'kernel_size': (3, 3), - 'padding': padding, - 'strides': strides, - 'depth_multiplier': multiplier - }, - input_shape=(num_samples, num_row, num_col, stack_size)) + def test_separable_conv2d(self): + kwargs = { + 'filters': 2, + 'kernel_size': 3, + } + + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [2]) + if test.is_gpu_available(cuda_only=True): + self._run_test(kwargs, 'data_format', ['channels_first']) + self._run_test(kwargs, 'dilation_rate', [2]) + self._run_test(kwargs, 'depth_multiplier', [2]) + + kwargs = { + 'filters': 2, + 'kernel_size': 3, + 'padding': 'same', + } + self._run_test(kwargs, 'dilation_rate', [2]) def test_separable_conv2d_regularizers(self): kwargs = { @@ -380,33 +453,35 @@ class SeparableConv2DTest(test.TestCase): class Conv3DTest(test.TestCase): - def test_convolution_3d(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 2 stack_size = 3 + num_row = 7 + num_col = 6 + depth = 5 - input_len_dim1 = 9 - input_len_dim2 = 8 - input_len_dim3 = 8 + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.Conv3D, + kwargs=test_kwargs, + input_shape=(num_samples, depth, num_row, num_col, stack_size)) + + def test_conv3d(self): + kwargs = { + 'filters': 2, + 'kernel_size': (3, 3, 3), + } - for padding in ['valid', 'same']: - for strides in [(1, 1, 1), (2, 2, 2)]: - if padding == 'same' and strides != (1, 1, 1): - continue + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [(2, 2, 2)]) + self._run_test(kwargs, 'dilation_rate', [(2, 2, 2)]) + if test.is_gpu_available(cuda_only=True): + self._run_test(kwargs, 'data_format', ['channels_first']) - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Convolution3D, - kwargs={ - 'filters': filters, - 'kernel_size': 3, - 'padding': padding, - 'strides': strides - }, - input_shape=(num_samples, input_len_dim1, input_len_dim2, - input_len_dim3, stack_size)) - - def test_convolution_3d_regularizers(self): + def test_conv3d_regularizers(self): kwargs = { 'filters': 3, 'kernel_size': 3, @@ -424,7 +499,7 @@ class Conv3DTest(test.TestCase): layer(keras.backend.variable(np.ones((1, 5, 5, 5, 2)))) self.assertEqual(len(layer.losses), 3) - def test_convolution_3d_constraints(self): + def test_conv3d_constraints(self): k_constraint = lambda x: x b_constraint = lambda x: x diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py index 6ee3fb48b2f1426b87c5c1947e90d0797e9b9ff7..50a197c80c3d97f47a071a24297301dddf78a27e 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core.py +++ b/tensorflow/python/keras/_impl/keras/layers/core.py @@ -23,6 +23,7 @@ import types as python_types import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import activations from tensorflow.python.keras._impl.keras import backend as K @@ -36,8 +37,10 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import func_dump from tensorflow.python.keras._impl.keras.utils.generic_utils import func_load from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.layers import core as tf_core_layers +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.Masking') class Masking(Layer): """Masks a sequence by using a mask value to skip timesteps. @@ -88,6 +91,7 @@ class Masking(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Dropout') class Dropout(tf_core_layers.Dropout, Layer): """Applies Dropout to the input. @@ -119,7 +123,8 @@ class Dropout(tf_core_layers.Dropout, Layer): if training is None: training = K.learning_phase() output = super(Dropout, self).call(inputs, training=training) - if training is K.learning_phase(): + # EagerTensor object has no attribute _uses_learning_phase + if not context.in_eager_mode() and training is K.learning_phase(): output._uses_learning_phase = True # pylint: disable=protected-access return output @@ -133,6 +138,7 @@ class Dropout(tf_core_layers.Dropout, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.SpatialDropout1D') class SpatialDropout1D(Dropout): """Spatial 1D version of Dropout. @@ -169,6 +175,7 @@ class SpatialDropout1D(Dropout): return noise_shape +@tf_export('keras.layers.SpatialDropout2D') class SpatialDropout2D(Dropout): """Spatial 2D version of Dropout. @@ -222,6 +229,7 @@ class SpatialDropout2D(Dropout): return (input_shape[0], 1, 1, input_shape[3]) +@tf_export('keras.layers.SpatialDropout3D') class SpatialDropout3D(Dropout): """Spatial 3D version of Dropout. @@ -274,6 +282,7 @@ class SpatialDropout3D(Dropout): return (input_shape[0], 1, 1, 1, input_shape[4]) +@tf_export('keras.layers.Activation') class Activation(Layer): """Applies an activation function to an output. @@ -307,6 +316,7 @@ class Activation(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Reshape') class Reshape(Layer): """Reshapes an output to a certain shape. @@ -412,6 +422,7 @@ class Reshape(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Permute') class Permute(Layer): """Permutes the dimensions of the input according to a given pattern. @@ -464,6 +475,7 @@ class Permute(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Flatten') class Flatten(tf_core_layers.Flatten, Layer): """Flattens the input. Does not affect the batch size. @@ -483,6 +495,7 @@ class Flatten(tf_core_layers.Flatten, Layer): pass +@tf_export('keras.layers.RepeatVector') class RepeatVector(Layer): """Repeats the input n times. @@ -526,6 +539,7 @@ class RepeatVector(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Lambda') class Lambda(Layer): """Wraps arbitrary expression as a `Layer` object. @@ -707,6 +721,7 @@ class Lambda(Layer): return cls(**config) +@tf_export('keras.layers.Dense') class Dense(tf_core_layers.Dense, Layer): """Just your regular densely-connected NN layer. @@ -811,6 +826,7 @@ class Dense(tf_core_layers.Dense, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.ActivityRegularization') class ActivityRegularization(Layer): """Layer that applies an update to the cost function based input activity. diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py index 51c520be38f5fac32fec9e4a13c380a2e477c709..ca92899a455cd28a756e9efff63655d7c43c9f45 100644 --- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py +++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py @@ -18,14 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.Embedding') class Embedding(Layer): """Turns positive integers (indexes) into dense vectors of fixed size. @@ -58,13 +60,13 @@ class Embedding(Layer): output_dim: int >= 0. Dimension of the dense embedding. embeddings_initializer: Initializer for the `embeddings` matrix. embeddings_regularizer: Regularizer function applied to - the `embeddings` matrix. + the `embeddings` matrix. embeddings_constraint: Constraint function applied to - the `embeddings` matrix. + the `embeddings` matrix. mask_zero: Whether or not the input value 0 is a special "padding" value that should be masked out. - This is useful when using recurrent layers, - which may take variable length inputs. + This is useful when using recurrent layers + which may take variable length input. If this is `True` then all subsequent layers in the model need to support masking or an exception will be raised. If mask_zero is set to True, as a consequence, index 0 cannot be @@ -81,9 +83,6 @@ class Embedding(Layer): Output shape: 3D tensor with shape: `(batch_size, sequence_length, output_dim)`. - References: - - [A Theoretically Grounded Application of Dropout in Recurrent Neural - Networks](http://arxiv.org/abs/1512.05287) """ def __init__(self, @@ -101,19 +100,19 @@ class Embedding(Layer): kwargs['input_shape'] = (input_length,) else: kwargs['input_shape'] = (None,) - super(Embedding, self).__init__( - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + super(Embedding, self).__init__(**kwargs) self.input_dim = input_dim self.output_dim = output_dim self.embeddings_initializer = initializers.get(embeddings_initializer) self.embeddings_regularizer = regularizers.get(embeddings_regularizer) + self.activity_regularizer = regularizers.get(activity_regularizer) self.embeddings_constraint = constraints.get(embeddings_constraint) self.mask_zero = mask_zero self.input_length = input_length + @shape_type_conversion def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() self.embeddings = self.add_weight( shape=(self.input_dim, self.output_dim), initializer=self.embeddings_initializer, @@ -129,10 +128,10 @@ class Embedding(Layer): else: return K.not_equal(inputs, 0) + @shape_type_conversion def compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() if self.input_length is None: - return tensor_shape.TensorShape(input_shape + [self.output_dim]) + return input_shape + (self.output_dim,) else: # input_length can be tuple if input is 3D or higher if isinstance(self.input_length, (list, tuple)): @@ -149,8 +148,7 @@ class Embedding(Layer): (str(self.input_length), str(input_shape))) elif s1 is None: in_lens[i] = s2 - return tensor_shape.TensorShape( - (input_shape[0],) + tuple(in_lens) + (self.output_dim,)) + return (input_shape[0],) + tuple(in_lens) + (self.output_dim,) def call(self, inputs): if K.dtype(inputs) != 'int32': diff --git a/tensorflow/python/keras/_impl/keras/layers/local.py b/tensorflow/python/keras/_impl/keras/layers/local.py index 0a31b87fb564b2833c0dea1ebb3a977b07f13a24..df0efe6b8b7eaa0259eb6f4e246269551b3e0c15 100644 --- a/tensorflow/python/keras/_impl/keras/layers/local.py +++ b/tensorflow/python/keras/_impl/keras/layers/local.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import activations from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import constraints @@ -26,9 +25,12 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion from tensorflow.python.keras._impl.keras.utils import conv_utils +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.LocallyConnected1D') class LocallyConnected1D(Layer): """Locally-connected layer for 1D inputs. @@ -51,7 +53,7 @@ class LocallyConnected1D(Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of a single integer, specifying the length of the 1D convolution window. strides: An integer or tuple/list of a single integer, @@ -98,8 +100,7 @@ class LocallyConnected1D(Layer): kernel_constraint=None, bias_constraint=None, **kwargs): - super(LocallyConnected1D, self).__init__( - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + super(LocallyConnected1D, self).__init__(**kwargs) self.filters = filters self.kernel_size = conv_utils.normalize_tuple(kernel_size, 1, 'kernel_size') self.strides = conv_utils.normalize_tuple(strides, 1, 'strides') @@ -114,12 +115,13 @@ class LocallyConnected1D(Layer): self.bias_initializer = initializers.get(bias_initializer) self.kernel_regularizer = regularizers.get(kernel_regularizer) self.bias_regularizer = regularizers.get(bias_regularizer) + self.activity_regularizer = regularizers.get(activity_regularizer) self.kernel_constraint = constraints.get(kernel_constraint) self.bias_constraint = constraints.get(bias_constraint) self.input_spec = InputSpec(ndim=3) + @shape_type_conversion def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() input_dim = input_shape[2] if input_dim is None: raise ValueError('Axis 2 of input should be fully-defined. ' @@ -146,15 +148,14 @@ class LocallyConnected1D(Layer): self.input_spec = InputSpec(ndim=3, axes={2: input_dim}) self.built = True + @shape_type_conversion def compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() length = conv_utils.conv_output_length(input_shape[1], self.kernel_size[0], self.padding, self.strides[0]) - return tensor_shape.TensorShape([input_shape[0], length, self.filters]) + return (input_shape[0], length, self.filters) def call(self, inputs): output = K.local_conv1d(inputs, self.kernel, self.kernel_size, self.strides) - if self.use_bias: output = K.bias_add(output, self.bias) if self.activation is not None: @@ -163,25 +164,38 @@ class LocallyConnected1D(Layer): def get_config(self): config = { - 'filters': self.filters, - 'kernel_size': self.kernel_size, - 'strides': self.strides, - 'padding': self.padding, - 'activation': activations.serialize(self.activation), - 'use_bias': self.use_bias, - 'kernel_initializer': initializers.serialize(self.kernel_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'filters': + self.filters, + 'kernel_size': + self.kernel_size, + 'strides': + self.strides, + 'padding': + self.padding, + 'activation': + activations.serialize(self.activation), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': constraints.serialize(self.kernel_constraint), - 'bias_constraint': constraints.serialize(self.bias_constraint) + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint) } base_config = super(LocallyConnected1D, self).get_config() return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.LocallyConnected2D') class LocallyConnected2D(Layer): """Locally-connected layer for 2D inputs. @@ -208,7 +222,7 @@ class LocallyConnected2D(Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the width and height of the 2D convolution window. Can be a single integer to specify the same value for @@ -273,8 +287,7 @@ class LocallyConnected2D(Layer): kernel_constraint=None, bias_constraint=None, **kwargs): - super(LocallyConnected2D, self).__init__( - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + super(LocallyConnected2D, self).__init__(**kwargs) self.filters = filters self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size') self.strides = conv_utils.normalize_tuple(strides, 2, 'strides') @@ -289,12 +302,13 @@ class LocallyConnected2D(Layer): self.bias_initializer = initializers.get(bias_initializer) self.kernel_regularizer = regularizers.get(kernel_regularizer) self.bias_regularizer = regularizers.get(bias_regularizer) + self.activity_regularizer = regularizers.get(activity_regularizer) self.kernel_constraint = constraints.get(kernel_constraint) self.bias_constraint = constraints.get(bias_constraint) self.input_spec = InputSpec(ndim=4) + @shape_type_conversion def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() if self.data_format == 'channels_last': input_row, input_col = input_shape[1:-1] input_filter = input_shape[3] @@ -306,7 +320,6 @@ class LocallyConnected2D(Layer): ' a LocallyConnected2D layer ' 'should be fully-defined, but layer received ' 'the inputs shape ' + str(input_shape)) - output_row = conv_utils.conv_output_length(input_row, self.kernel_size[0], self.padding, self.strides[0]) output_col = conv_utils.conv_output_length(input_col, self.kernel_size[1], @@ -337,33 +350,30 @@ class LocallyConnected2D(Layer): self.input_spec = InputSpec(ndim=4, axes={-1: input_filter}) self.built = True + @shape_type_conversion def compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() if self.data_format == 'channels_first': rows = input_shape[2] cols = input_shape[3] elif self.data_format == 'channels_last': rows = input_shape[1] cols = input_shape[2] + rows = conv_utils.conv_output_length(rows, self.kernel_size[0], self.padding, self.strides[0]) cols = conv_utils.conv_output_length(cols, self.kernel_size[1], self.padding, self.strides[1]) if self.data_format == 'channels_first': - return tensor_shape.TensorShape( - [input_shape[0], self.filters, rows, cols]) + return (input_shape[0], self.filters, rows, cols) elif self.data_format == 'channels_last': - return tensor_shape.TensorShape( - [input_shape[0], rows, cols, self.filters]) + return (input_shape[0], rows, cols, self.filters) def call(self, inputs): - output = K.local_conv2d(inputs, - self.kernel, - self.kernel_size, - self.strides, + output = K.local_conv2d(inputs, self.kernel, self.kernel_size, self.strides, (self.output_row, self.output_col), self.data_format) + if self.use_bias: output = K.bias_add(output, self.bias, data_format=self.data_format) @@ -372,21 +382,34 @@ class LocallyConnected2D(Layer): def get_config(self): config = { - 'filters': self.filters, - 'kernel_size': self.kernel_size, - 'strides': self.strides, - 'padding': self.padding, - 'data_format': self.data_format, - 'activation': activations.serialize(self.activation), - 'use_bias': self.use_bias, - 'kernel_initializer': initializers.serialize(self.kernel_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'filters': + self.filters, + 'kernel_size': + self.kernel_size, + 'strides': + self.strides, + 'padding': + self.padding, + 'data_format': + self.data_format, + 'activation': + activations.serialize(self.activation), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': constraints.serialize(self.kernel_constraint), - 'bias_constraint': constraints.serialize(self.bias_constraint) + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint) } base_config = super(LocallyConnected2D, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py index 8d359bf17cdb80c98aeeed6d69e301962609ce59..b7af1e8cf03dccfa49ded0fe5b3724bba06e62ba 100644 --- a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py @@ -39,6 +39,22 @@ class LSTMLayerTest(test.TestCase): 'return_sequences': True}, input_shape=(num_samples, timesteps, embedding_dim)) + def test_static_shape_inference_LSTM(self): + # Github issue: 15165 + num_samples = 2 + timesteps = 3 + embedding_dim = 4 + units = 2 + + model = keras.models.Sequential() + inputs = keras.layers.Dense(embedding_dim, + input_shape=(timesteps, embedding_dim)) + model.add(inputs) + layer = keras.layers.LSTM(units, return_sequences=True) + model.add(layer) + outputs = model.layers[-1].output + self.assertEquals(outputs.get_shape().as_list(), [None, timesteps, units]) + def test_dynamic_behavior_LSTM(self): num_samples = 2 timesteps = 3 diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py index 76eb03cf274a648da127b9d3e0c911096d361812..cdf2878e83e32147d30d6b29742b7e9013a1facb 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge.py +++ b/tensorflow/python/keras/_impl/keras/layers/merge.py @@ -14,15 +14,16 @@ # ============================================================================== # pylint: disable=not-callable # pylint: disable=redefined-builtin -"""Layers can merge several input tensors into a single output tensor. +"""Layers that can merge several inputs into one. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine.topology import Layer +from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.util.tf_export import tf_export class _Merge(Layer): @@ -73,12 +74,13 @@ class _Merge(Layer): output_shape.append(i) else: if i != j: - raise ValueError('Operands could not be broadcast ' - 'together with shapes ' + str(shape1) + ' ' + - str(shape2)) + raise ValueError( + 'Operands could not be broadcast ' + 'together with shapes ' + str(shape1) + ' ' + str(shape2)) output_shape.append(i) return tuple(output_shape) + @shape_type_conversion def build(self, input_shape): # Used purely for shape validation. if not isinstance(input_shape, list): @@ -87,14 +89,13 @@ class _Merge(Layer): raise ValueError('A merge layer should be called ' 'on a list of at least 2 inputs. ' 'Got ' + str(len(input_shape)) + ' inputs.') - input_shape = [tensor_shape.TensorShape(s).as_list() for s in input_shape] batch_sizes = [s[0] for s in input_shape if s is not None] batch_sizes = set(batch_sizes) batch_sizes -= set([None]) if len(batch_sizes) > 1: - raise ValueError('Can not merge tensors with different ' - 'batch sizes. Got tensors with shapes : ' + - str(input_shape)) + raise ValueError( + 'Can not merge tensors with different ' + 'batch sizes. Got tensors with shapes : ' + str(input_shape)) if input_shape[0] is None: output_shape = None else: @@ -111,9 +112,10 @@ class _Merge(Layer): self._reshape_required = False else: self._reshape_required = True - self.built = True def call(self, inputs): + if not isinstance(inputs, list): + raise ValueError('A merge layer should be called ' 'on a list of inputs.') if self._reshape_required: reshaped_inputs = [] input_ndims = list(map(K.ndim, inputs)) @@ -172,6 +174,7 @@ class _Merge(Layer): else: return self._merge_function(inputs) + @shape_type_conversion def compute_output_shape(self, input_shape): if input_shape[0] is None: output_shape = None @@ -208,12 +211,29 @@ class _Merge(Layer): return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False) +@tf_export('keras.layers.Add') class Add(_Merge): """Layer that adds a list of inputs. It takes as input a list of tensors, all of the same shape, and returns a single tensor (also of the same shape). + + Examples: + + ```python + import keras + + input1 = keras.layers.Input(shape=(16,)) + x1 = keras.layers.Dense(8, activation='relu')(input1) + input2 = keras.layers.Input(shape=(32,)) + x2 = keras.layers.Dense(8, activation='relu')(input2) + added = keras.layers.Add()([x1, x2]) # equivalent to added = + keras.layers.add([x1, x2]) + + out = keras.layers.Dense(4)(added) + model = keras.models.Model(inputs=[input1, input2], outputs=out) + ``` """ def _merge_function(self, inputs): @@ -247,13 +267,21 @@ class Subtract(_Merge): ``` """ + @shape_type_conversion + def build(self, input_shape): + super(Subtract, self).build(input_shape) + if len(input_shape) != 2: + raise ValueError('A `Subtract` layer should be called ' + 'on exactly 2 inputs') + def _merge_function(self, inputs): if len(inputs) != 2: - raise ValueError('`Subtract` layer should be called ' - 'on exactly 2 inputs. Received: %s' % inputs) + raise ValueError('A `Subtract` layer should be called ' + 'on exactly 2 inputs') return inputs[0] - inputs[1] +@tf_export('keras.layers.Multiply') class Multiply(_Merge): """Layer that multiplies (element-wise) a list of inputs. @@ -269,6 +297,7 @@ class Multiply(_Merge): return output +@tf_export('keras.layers.Average') class Average(_Merge): """Layer that averages a list of inputs. @@ -284,6 +313,7 @@ class Average(_Merge): return output / len(inputs) +@tf_export('keras.layers.Maximum') class Maximum(_Merge): """Layer that computes the maximum (element-wise) a list of inputs. @@ -314,6 +344,7 @@ class Minimum(_Merge): return output +@tf_export('keras.layers.Concatenate') class Concatenate(_Merge): """Layer that concatenates a list of inputs. @@ -330,47 +361,43 @@ class Concatenate(_Merge): super(Concatenate, self).__init__(**kwargs) self.axis = axis self.supports_masking = True + self._reshape_required = False + @shape_type_conversion def build(self, input_shape): # Used purely for shape validation. - if not (isinstance(input_shape, list) and len(input_shape) > 1): - raise ValueError('`Concatenate` layer should be called ' - 'on a list containing at least two inputs') + if not isinstance(input_shape, list) or len(input_shape) < 2: + raise ValueError('A `Concatenate` layer should be called ' + 'on a list of at least 2 inputs') if all([shape is None for shape in input_shape]): return - reduced_inputs_shapes = [ - tensor_shape.TensorShape(shape).as_list() for shape in input_shape - ] + reduced_inputs_shapes = [list(shape) for shape in input_shape] shape_set = set() for i in range(len(reduced_inputs_shapes)): del reduced_inputs_shapes[i][self.axis] shape_set.add(tuple(reduced_inputs_shapes[i])) if len(shape_set) > 1: - raise ValueError('`Concatenate` layer requires ' + raise ValueError('A `Concatenate` layer requires ' 'inputs with matching shapes ' 'except for the concat axis. ' 'Got inputs shapes: %s' % (input_shape)) - self.built = True - def call(self, inputs): - if not isinstance(inputs, list): - raise ValueError('A `Concatenate` layer should be called ' - 'on a list of inputs.') + def _merge_function(self, inputs): return K.concatenate(inputs, axis=self.axis) + @shape_type_conversion def compute_output_shape(self, input_shape): if not isinstance(input_shape, list): raise ValueError('A `Concatenate` layer should be called ' 'on a list of inputs.') input_shapes = input_shape - output_shape = tensor_shape.TensorShape(input_shapes[0]).as_list() + output_shape = list(input_shapes[0]) for shape in input_shapes[1:]: - shape = tensor_shape.TensorShape(shape).as_list() if output_shape[self.axis] is None or shape[self.axis] is None: output_shape[self.axis] = None break output_shape[self.axis] += shape[self.axis] - return tensor_shape.TensorShape(output_shape) + return tuple(output_shape) def compute_mask(self, inputs, mask=None): if mask is None: @@ -390,7 +417,7 @@ class Concatenate(_Merge): masks = [] for input_i, mask_i in zip(inputs, mask): if mask_i is None: - # Input is unmasked. Append all 1s to masks + # Input is unmasked. Append all 1s to masks, masks.append(K.ones_like(input_i, dtype='bool')) elif K.ndim(mask_i) < K.ndim(input_i): # Mask is smaller than the input, expand it @@ -408,6 +435,7 @@ class Concatenate(_Merge): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.Dot') class Dot(_Merge): """Layer that computes a dot product between samples in two tensors. @@ -441,14 +469,16 @@ class Dot(_Merge): self.axes = axes self.normalize = normalize self.supports_masking = True + self._reshape_required = False + @shape_type_conversion def build(self, input_shape): # Used purely for shape validation. if not isinstance(input_shape, list) or len(input_shape) != 2: raise ValueError('A `Dot` layer should be called ' 'on a list of 2 inputs.') - shape1 = tensor_shape.TensorShape(input_shape[0]).as_list() - shape2 = tensor_shape.TensorShape(input_shape[1]).as_list() + shape1 = input_shape[0] + shape2 = input_shape[1] if shape1 is None or shape2 is None: return if isinstance(self.axes, int): @@ -462,9 +492,10 @@ class Dot(_Merge): raise ValueError('Dimension incompatibility ' '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) + 'Layer shapes: %s, %s' % (shape1, shape2)) - self.built = True - def call(self, inputs): + def _merge_function(self, inputs): + if len(inputs) != 2: + raise ValueError('A `Dot` layer should be called ' 'on exactly 2 inputs') x1 = inputs[0] x2 = inputs[1] if isinstance(self.axes, int): @@ -485,12 +516,13 @@ class Dot(_Merge): output = K.batch_dot(x1, x2, axes) return output + @shape_type_conversion def compute_output_shape(self, input_shape): if not isinstance(input_shape, list) or len(input_shape) != 2: raise ValueError('A `Dot` layer should be called ' 'on a list of 2 inputs.') - shape1 = tensor_shape.TensorShape(input_shape[0]).as_list() - shape2 = tensor_shape.TensorShape(input_shape[1]).as_list() + shape1 = list(input_shape[0]) + shape2 = list(input_shape[1]) if isinstance(self.axes, int): if self.axes < 0: axes = [self.axes % len(shape1), self.axes % len(shape2)] @@ -504,7 +536,7 @@ class Dot(_Merge): output_shape = shape1 + shape2 if len(output_shape) == 1: output_shape += [1] - return tensor_shape.TensorShape(output_shape) + return tuple(output_shape) def compute_mask(self, inputs, mask=None): return None @@ -518,6 +550,7 @@ class Dot(_Merge): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.add') def add(inputs, **kwargs): """Functional interface to the `Add` layer. @@ -527,6 +560,21 @@ def add(inputs, **kwargs): Returns: A tensor, the sum of the inputs. + + Examples: + + ```python + import keras + + input1 = keras.layers.Input(shape=(16,)) + x1 = keras.layers.Dense(8, activation='relu')(input1) + input2 = keras.layers.Input(shape=(32,)) + x2 = keras.layers.Dense(8, activation='relu')(input2) + added = keras.layers.add([x1, x2]) + + out = keras.layers.Dense(4)(added) + model = keras.models.Model(inputs=[input1, input2], outputs=out) + ``` """ return Add(**kwargs)(inputs) @@ -559,6 +607,7 @@ def subtract(inputs, **kwargs): return Subtract(**kwargs)(inputs) +@tf_export('keras.layers.multiply') def multiply(inputs, **kwargs): """Functional interface to the `Multiply` layer. @@ -572,6 +621,7 @@ def multiply(inputs, **kwargs): return Multiply(**kwargs)(inputs) +@tf_export('keras.layers.average') def average(inputs, **kwargs): """Functional interface to the `Average` layer. @@ -585,6 +635,7 @@ def average(inputs, **kwargs): return Average(**kwargs)(inputs) +@tf_export('keras.layers.maximum') def maximum(inputs, **kwargs): """Functional interface to the `Maximum` layer. @@ -611,6 +662,7 @@ def minimum(inputs, **kwargs): return Minimum(**kwargs)(inputs) +@tf_export('keras.layers.concatenate') def concatenate(inputs, axis=-1, **kwargs): """Functional interface to the `Concatenate` layer. @@ -625,6 +677,7 @@ def concatenate(inputs, axis=-1, **kwargs): return Concatenate(axis=axis, **kwargs)(inputs) +@tf_export('keras.layers.dot') def dot(inputs, axes, normalize=False, **kwargs): """Functional interface to the `Dot` layer. diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py index 459f13145f090f1942543ec2f5da4e9b8cd71509..9010f4961585af58b7eae43dcd224e0c39606239 100644 --- a/tensorflow/python/keras/_impl/keras/layers/noise.py +++ b/tensorflow/python/keras/_impl/keras/layers/noise.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Layers for regularization models via the addition of noise. +"""Layers that operate regularization via the addition of noise. """ from __future__ import absolute_import from __future__ import division @@ -22,8 +22,11 @@ import numpy as np from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.GaussianNoise') class GaussianNoise(Layer): """Apply additive zero-centered Gaussian noise. @@ -59,15 +62,17 @@ class GaussianNoise(Layer): return K.in_train_phase(noised, inputs, training=training) - def compute_output_shape(self, input_shape): - return input_shape - def get_config(self): config = {'stddev': self.stddev} base_config = super(GaussianNoise, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @shape_type_conversion + def compute_output_shape(self, input_shape): + return input_shape + +@tf_export('keras.layers.GaussianDropout') class GaussianDropout(Layer): """Apply multiplicative 1-centered Gaussian noise. @@ -86,10 +91,6 @@ class GaussianDropout(Layer): Output shape: Same shape as input. - References: - - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting - Srivastava, Hinton, et al. - 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) """ def __init__(self, rate, **kwargs): @@ -108,15 +109,17 @@ class GaussianDropout(Layer): return K.in_train_phase(noised, inputs, training=training) return inputs - def compute_output_shape(self, input_shape): - return input_shape - def get_config(self): config = {'rate': self.rate} base_config = super(GaussianDropout, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @shape_type_conversion + def compute_output_shape(self, input_shape): + return input_shape + +@tf_export('keras.layers.AlphaDropout') class AlphaDropout(Layer): """Applies Alpha Dropout to the input. @@ -140,8 +143,6 @@ class AlphaDropout(Layer): Output shape: Same shape as input. - References: - - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) """ def __init__(self, rate, noise_shape=None, seed=None, **kwargs): @@ -157,26 +158,34 @@ class AlphaDropout(Layer): def call(self, inputs, training=None): if 0. < self.rate < 1.: noise_shape = self._get_noise_shape(inputs) - alpha = 1.6732632423543772848170429916717 - scale = 1.0507009873554804934193349852946 - def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed): + def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed): # pylint: disable=missing-docstring + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 alpha_p = -alpha * scale - kept_idx = K.greater_equal(K.random_uniform(noise_shape, seed=seed), - rate) + + kept_idx = K.greater_equal( + K.random_uniform(noise_shape, seed=seed), rate) kept_idx = K.cast(kept_idx, K.floatx()) - a = ((1 - rate) * (1 + rate * alpha_p ** 2)) ** -0.5 + + # Get affine transformation params + a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5 b = -a * alpha_p * rate + + # Apply mask x = inputs * kept_idx + alpha_p * (1 - kept_idx) + + # Do affine transformation return a * x + b return K.in_train_phase(dropped_inputs, inputs, training=training) return inputs - def compute_output_shape(self, input_shape): - return input_shape - def get_config(self): config = {'rate': self.rate} base_config = super(AlphaDropout, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + @shape_type_conversion + def compute_output_shape(self, input_shape): + return input_shape diff --git a/tensorflow/python/keras/_impl/keras/layers/normalization.py b/tensorflow/python/keras/_impl/keras/layers/normalization.py index 965ef70e6e6cb488aa4832462da4a2cb43e964a6..0dedd5e8daa2974038c90ae2e8c68ca6516ba725 100644 --- a/tensorflow/python/keras/_impl/keras/layers/normalization.py +++ b/tensorflow/python/keras/_impl/keras/layers/normalization.py @@ -18,14 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import Layer from tensorflow.python.layers import normalization as tf_normalization_layers +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.BatchNormalization') class BatchNormalization(tf_normalization_layers.BatchNormalization, Layer): """Batch normalization layer (Ioffe and Szegedy, 2014). @@ -108,7 +111,7 @@ class BatchNormalization(tf_normalization_layers.BatchNormalization, Layer): if training is None: training = K.learning_phase() output = super(BatchNormalization, self).call(inputs, training=training) - if training is K.learning_phase(): + if context.in_graph_mode() and training is K.learning_phase(): output._uses_learning_phase = True # pylint: disable=protected-access return output diff --git a/tensorflow/python/keras/_impl/keras/layers/normalization_test.py b/tensorflow/python/keras/_impl/keras/layers/normalization_test.py index 39a90e597089b30d110f26f074eba5d6895e52df..2b3628c3f1023612297465bdf3286246261992a2 100644 --- a/tensorflow/python/keras/_impl/keras/layers/normalization_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/normalization_test.py @@ -132,13 +132,19 @@ class NormalizationLayersTest(test.TestCase): model.compile('sgd', 'mse') model.train_on_batch(x, x) - assert len(model.updates) == 2 + self.assertEqual(len(bn.updates), 4) + self.assertEqual(len(model.updates), 2) + self.assertEqual(len(model.get_updates_for(x1)), 0) + self.assertEqual(len(model.get_updates_for(x2)), 2) # Test model-level reuse x3 = keras.layers.Input(shape=(10,)) y3 = model(x3) - new_model = keras.models.Model(x3, y3) - assert len(model.updates) == 2 + new_model = keras.models.Model(x3, y3, name='new_model') + + self.assertEqual(len(new_model.updates), 2) + self.assertEqual(len(model.updates), 4) + self.assertEqual(len(new_model.get_updates_for(x3)), 2) new_model.compile('sgd', 'mse') new_model.train_on_batch(x, x) diff --git a/tensorflow/python/keras/_impl/keras/layers/pooling.py b/tensorflow/python/keras/_impl/keras/layers/pooling.py index b133e2dfaf1bcacd055f6a597bd557f696469ffc..15d53379769d8142f5b2755a07479f60751346d2 100644 --- a/tensorflow/python/keras/_impl/keras/layers/pooling.py +++ b/tensorflow/python/keras/_impl/keras/layers/pooling.py @@ -24,8 +24,10 @@ from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.layers import pooling as tf_pooling_layers +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.MaxPool1D', 'keras.layers.MaxPooling1D') class MaxPooling1D(tf_pooling_layers.MaxPooling1D, Layer): """Max pooling operation for temporal data. @@ -58,6 +60,7 @@ class MaxPooling1D(tf_pooling_layers.MaxPooling1D, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.AveragePooling1D', 'keras.layers.AvgPool1D') class AveragePooling1D(tf_pooling_layers.AveragePooling1D, Layer): """Average pooling for temporal data. @@ -91,6 +94,7 @@ class AveragePooling1D(tf_pooling_layers.AveragePooling1D, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.MaxPool2D', 'keras.layers.MaxPooling2D') class MaxPooling2D(tf_pooling_layers.MaxPooling2D, Layer): """Max pooling operation for spatial data. @@ -156,6 +160,7 @@ class MaxPooling2D(tf_pooling_layers.MaxPooling2D, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.AveragePooling2D', 'keras.layers.AvgPool2D') class AveragePooling2D(tf_pooling_layers.AveragePooling2D, Layer): """Average pooling operation for spatial data. @@ -221,6 +226,7 @@ class AveragePooling2D(tf_pooling_layers.AveragePooling2D, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.MaxPool3D', 'keras.layers.MaxPooling3D') class MaxPooling3D(tf_pooling_layers.MaxPooling3D, Layer): """Max pooling operation for 3D data (spatial or spatio-temporal). @@ -282,6 +288,7 @@ class MaxPooling3D(tf_pooling_layers.MaxPooling3D, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.AveragePooling3D', 'keras.layers.AvgPool3D') class AveragePooling3D(tf_pooling_layers.AveragePooling3D, Layer): """Average pooling operation for 3D data (spatial or spatio-temporal). @@ -359,6 +366,8 @@ class _GlobalPooling1D(Layer): raise NotImplementedError +@tf_export('keras.layers.GlobalAveragePooling1D', + 'keras.layers.GlobalAvgPool1D') class GlobalAveragePooling1D(_GlobalPooling1D): """Global average pooling operation for temporal data. @@ -374,6 +383,7 @@ class GlobalAveragePooling1D(_GlobalPooling1D): return K.mean(inputs, axis=1) +@tf_export('keras.layers.GlobalMaxPool1D', 'keras.layers.GlobalMaxPooling1D') class GlobalMaxPooling1D(_GlobalPooling1D): """Global max pooling operation for temporal data. @@ -414,6 +424,8 @@ class _GlobalPooling2D(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.GlobalAveragePooling2D', + 'keras.layers.GlobalAvgPool2D') class GlobalAveragePooling2D(_GlobalPooling2D): """Global average pooling operation for spatial data. @@ -449,6 +461,7 @@ class GlobalAveragePooling2D(_GlobalPooling2D): return K.mean(inputs, axis=[2, 3]) +@tf_export('keras.layers.GlobalMaxPool2D', 'keras.layers.GlobalMaxPooling2D') class GlobalMaxPooling2D(_GlobalPooling2D): """Global max pooling operation for spatial data. @@ -509,6 +522,8 @@ class _GlobalPooling3D(Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.GlobalAveragePooling3D', + 'keras.layers.GlobalAvgPool3D') class GlobalAveragePooling3D(_GlobalPooling3D): """Global Average pooling operation for 3D data. @@ -544,6 +559,7 @@ class GlobalAveragePooling3D(_GlobalPooling3D): return K.mean(inputs, axis=[2, 3, 4]) +@tf_export('keras.layers.GlobalMaxPool3D', 'keras.layers.GlobalMaxPooling3D') class GlobalMaxPooling3D(_GlobalPooling3D): """Global Max pooling operation for 3D data. diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py index 9ea21c9c363455d693cc4d766b5f94ade56838d9..2e9003f52d7617e96950f76637759e577a8b5e4f 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,13 @@ # limitations under the License. # ============================================================================== # pylint: disable=protected-access -"""Recurrent layers. +"""Recurrent layers and their base classes. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numbers import numpy as np from tensorflow.python.framework import tensor_shape @@ -29,10 +30,13 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.StackedRNNCells') class StackedRNNCells(Layer): """Wrapper allowing a stack of RNN cells to behave as a single cell. @@ -84,7 +88,7 @@ class StackedRNNCells(Layer): state_size.append(cell.state_size) return tuple(state_size) - def call(self, inputs, states, **kwargs): + def call(self, inputs, states, constants=None, **kwargs): # Recover per-cell states. nested_states = [] for cell in self.cells[::-1]: @@ -99,7 +103,12 @@ class StackedRNNCells(Layer): # Call the cells in order and store the returned states. new_nested_states = [] for cell, states in zip(self.cells, nested_states): - inputs, states = cell.call(inputs, states, **kwargs) + if has_arg(cell.call, 'constants'): + inputs, states = cell.call(inputs, states, constants=constants, + **kwargs) + else: + inputs, states = cell.call(inputs, states, **kwargs) + new_nested_states.append(states) # Format the new states as a flat list @@ -109,15 +118,22 @@ class StackedRNNCells(Layer): states += cell_states return inputs, states + @shape_type_conversion def build(self, input_shape): + if isinstance(input_shape, list): + constants_shape = input_shape[1:] + input_shape = input_shape[0] for cell in self.cells: if isinstance(cell, Layer): - cell.build(input_shape) + if has_arg(cell.call, 'constants'): + cell.build([input_shape] + constants_shape) + else: + cell.build(input_shape) if hasattr(cell.state_size, '__len__'): output_dim = cell.state_size[0] else: output_dim = cell.state_size - input_shape = (input_shape[0], input_shape[1], output_dim) + input_shape = (input_shape[0], output_dim) self.built = True def get_config(self): @@ -198,19 +214,19 @@ class StackedRNNCells(Layer): losses = [] for cell in self.cells: if isinstance(cell, Layer): - cell_losses = cell.losses - losses += cell_losses - return losses + losses += cell.losses + return losses + self._losses - def get_losses_for(self, inputs=None): - losses = [] + @property + def updates(self): + updates = [] for cell in self.cells: if isinstance(cell, Layer): - cell_losses = cell.get_losses_for(inputs) - losses += cell_losses - return losses + updates += cell.updates + return updates + self._updates +@tf_export('keras.layers.RNN') class RNN(Layer): """Base class for recurrent layers. @@ -262,8 +278,7 @@ class RNN(Layer): (e.g. via the `input_shape` argument) Input shape: - 3D tensor with shape `(batch_size, timesteps, input_dim)`, - (Optional) 2D tensors with shape `(batch_size, output_dim)`. + 3D tensor with shape `(batch_size, timesteps, input_dim)`. Output shape: - if `return_state`: a list of tensors. The first tensor is @@ -370,7 +385,6 @@ class RNN(Layer): go_backwards=False, stateful=False, unroll=False, - activity_regularizer=None, **kwargs): if isinstance(cell, (list, tuple)): cell = StackedRNNCells(cell) @@ -382,8 +396,7 @@ class RNN(Layer): 'an attribute `state_size` ' '(tuple of integers, ' 'one integer per RNN state).') - super(RNN, self).__init__( - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + super(RNN, self).__init__(**kwargs) self.cell = cell self.return_sequences = return_sequences self.return_state = return_state @@ -401,7 +414,7 @@ class RNN(Layer): @property def states(self): if self._states is None: - if isinstance(self.cell.state_size, int): + if isinstance(self.cell.state_size, numbers.Integral): num_states = 1 else: num_states = len(self.cell.state_size) @@ -412,15 +425,16 @@ class RNN(Layer): def states(self, states): self._states = states + @shape_type_conversion def compute_output_shape(self, input_shape): if isinstance(input_shape, list): input_shape = input_shape[0] - input_shape = tensor_shape.TensorShape(input_shape).as_list() if hasattr(self.cell.state_size, '__len__'): - output_dim = self.cell.state_size[0] + state_size = self.cell.state_size else: - output_dim = self.cell.state_size + state_size = [self.cell.state_size] + output_dim = state_size[0] if self.return_sequences: output_shape = (input_shape[0], input_shape[1], output_dim) @@ -428,11 +442,10 @@ class RNN(Layer): output_shape = (input_shape[0], output_dim) if self.return_state: - state_shape = [(input_shape[0], output_dim) for _ in self.states] - output_shape = [output_shape] + state_shape + state_shape = [(input_shape[0], dim) for dim in state_size] + return [output_shape] + state_shape else: - output_shape = output_shape - return tensor_shape.TensorShape(output_shape) + return output_shape def compute_mask(self, inputs, mask): if isinstance(mask, list): @@ -444,6 +457,7 @@ class RNN(Layer): else: return output_mask + @shape_type_conversion def build(self, input_shape): # Note input_shape will be list of shapes of initial states and # constants if these are passed in __call__. @@ -454,7 +468,6 @@ class RNN(Layer): if isinstance(input_shape, list): input_shape = input_shape[0] - input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) batch_size = input_shape[0] if self.stateful else None input_dim = input_shape[-1] @@ -478,9 +491,9 @@ class RNN(Layer): # initial_state was passed in call, check compatibility if [spec.shape[-1] for spec in self.state_spec] != state_size: raise ValueError( - 'An initial_state was passed that is not compatible with ' + 'An `initial_state` was passed that is not compatible with ' '`cell.state_size`. Received `state_spec`={}; ' - 'However `cell.state_size` is ' + 'however `cell.state_size` is ' '{}'.format(self.state_spec, self.cell.state_size)) else: self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size] @@ -526,12 +539,14 @@ class RNN(Layer): self._num_constants = len(constants) additional_specs += self.constants_spec # at this point additional_inputs cannot be empty - is_keras_tensor = hasattr(additional_inputs[0], '_keras_history') + is_keras_tensor = K.is_keras_tensor(additional_inputs[0]) for tensor in additional_inputs: - if hasattr(tensor, '_keras_history') != is_keras_tensor: + if K.is_keras_tensor(tensor) != is_keras_tensor: raise ValueError('The initial state or constants of an RNN' ' layer cannot be specified with a mix of' - ' Keras tensors and non-Keras tensors') + ' Keras tensors and non-Keras tensors' + '(a "Keras tensor" is a tensor that was' + 'returned by a Keras layer, or by `Input`)') if is_keras_tensor: # Compute the full input spec, including state and constants @@ -610,11 +625,12 @@ class RNN(Layer): constants=constants, go_backwards=self.go_backwards, mask=mask, - unroll=self.unroll) + unroll=self.unroll, + input_length=timesteps) if self.stateful: updates = [] for i in range(len(states)): - updates.append((self.states[i], states[i])) + updates.append(K.update(self.states[i], states[i])) self.add_update(updates, inputs) if self.return_sequences: @@ -625,6 +641,8 @@ class RNN(Layer): # Properly set learning phase if getattr(last_output, '_uses_learning_phase', False): output._uses_learning_phase = True + for state in states: + state._uses_learning_phase = True if self.return_state: if not isinstance(states, (list, tuple)): @@ -636,7 +654,7 @@ class RNN(Layer): return output def _standardize_args(self, inputs, initial_state, constants): - """Standardize `__call__` arguments to a single list of tensor inputs. + """Standardize `__call__` to a single list of tensor inputs. When running a model loaded from file, the input tensors `initial_state` and `constants` can be passed to `RNN.__call__` as part @@ -688,7 +706,7 @@ class RNN(Layer): 'a `batch_input_shape` ' 'argument to your first layer.\n' '- If using the functional API, specify ' - 'the time dimension by passing a ' + 'the batch size by passing a ' '`batch_shape` argument to your Input layer.') # initialize state if None if self.states[0] is None: @@ -772,53 +790,46 @@ class RNN(Layer): @property def losses(self): + losses = [] if isinstance(self.cell, Layer): - return self.cell.losses - return [] + losses += self.cell.losses + return losses + self._losses - def get_losses_for(self, inputs=None): + @property + def updates(self): + updates = [] if isinstance(self.cell, Layer): - cell_losses = self.cell.get_losses_for(inputs) - return cell_losses + super(RNN, self).get_losses_for(inputs) - return super(RNN, self).get_losses_for(inputs) + updates += self.cell.updates + return updates + self._updates +@tf_export('keras.layers.SimpleRNNCell') class SimpleRNNCell(Layer): """Cell class for SimpleRNN. Arguments: units: Positive integer, dimensionality of the output space. - activation: Activation function to use - (see [activations](../activations.md)). + activation: Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. - (see [initializers](../initializers.md)). recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. - (see [initializers](../initializers.md)). - bias_initializer: Initializer for the bias vector - (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector. kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix - (see [regularizer](../regularizers.md)). + the `kernel` weights matrix. recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix - (see [regularizer](../regularizers.md)). - bias_regularizer: Regularizer function applied to the bias vector - (see [regularizer](../regularizers.md)). + the `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. kernel_constraint: Constraint function applied to - the `kernel` weights matrix - (see [constraints](../constraints.md)). + the `kernel` weights matrix. recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix - (see [constraints](../constraints.md)). - bias_constraint: Constraint function applied to the bias vector - (see [constraints](../constraints.md)). + the `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. @@ -866,6 +877,7 @@ class SimpleRNNCell(Layer): self._dropout_mask = None self._recurrent_dropout_mask = None + @shape_type_conversion def build(self, input_shape): self.kernel = self.add_weight( shape=(input_shape[-1], self.units), @@ -890,33 +902,21 @@ class SimpleRNNCell(Layer): self.bias = None self.built = True - def _generate_dropout_mask(self, inputs, training=None): - if 0 < self.dropout < 1: - ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1)) - - def dropped_inputs(): - return K.dropout(ones, self.dropout) - - self._dropout_mask = K.in_train_phase( - dropped_inputs, ones, training=training) - else: - self._dropout_mask = None - - def _generate_recurrent_dropout_mask(self, inputs, training=None): - if 0 < self.recurrent_dropout < 1: - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, self.units)) - - def dropped_inputs(): - return K.dropout(ones, self.dropout) - - self._recurrent_dropout_mask = K.in_train_phase( - dropped_inputs, ones, training=training) - else: - self._recurrent_dropout_mask = None - def call(self, inputs, states, training=None): prev_output = states[0] + if 0 < self.dropout < 1 and self._dropout_mask is None: + self._dropout_mask = _generate_dropout_mask( + _generate_dropout_ones(inputs, + K.shape(inputs)[-1]), + self.dropout, + training=training) + if (0 < self.recurrent_dropout < 1 and + self._recurrent_dropout_mask is None): + self._recurrent_dropout_mask = _generate_dropout_mask( + _generate_dropout_ones(inputs, self.units), + self.recurrent_dropout, + training=training) + dp_mask = self._dropout_mask rec_dp_mask = self._recurrent_dropout_mask @@ -939,46 +939,70 @@ class SimpleRNNCell(Layer): output._uses_learning_phase = True return output, [output] + def get_config(self): + config = { + 'units': + self.units, + 'activation': + activations.serialize(self.activation), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), + 'recurrent_initializer': + initializers.serialize(self.recurrent_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), + 'recurrent_regularizer': + regularizers.serialize(self.recurrent_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), + 'recurrent_constraint': + constraints.serialize(self.recurrent_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint), + 'dropout': + self.dropout, + 'recurrent_dropout': + self.recurrent_dropout + } + base_config = super(SimpleRNNCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + +@tf_export('keras.layers.SimpleRNN') class SimpleRNN(RNN): """Fully-connected RNN where the output is to be fed back to input. Arguments: units: Positive integer, dimensionality of the output space. - activation: Activation function to use - (see [activations](../activations.md)). + activation: Activation function to use. Default: hyperbolic tangent (`tanh`). - If you pass `None`, no activation is applied + If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. - (see [initializers](../initializers.md)). recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. - (see [initializers](../initializers.md)). - bias_initializer: Initializer for the bias vector - (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector. kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix - (see [regularizer](../regularizers.md)). + the `kernel` weights matrix. recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix - (see [regularizer](../regularizers.md)). - bias_regularizer: Regularizer function applied to the bias vector - (see [regularizer](../regularizers.md)). + the `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. activity_regularizer: Regularizer function applied to - the output of the layer (its "activation"). - (see [regularizer](../regularizers.md)). + the output of the layer (its "activation").. kernel_constraint: Constraint function applied to - the `kernel` weights matrix - (see [constraints](../constraints.md)). + the `kernel` weights matrix. recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix - (see [constraints](../constraints.md)). - bias_constraint: Constraint function applied to the bias vector - (see [constraints](../constraints.md)). + the `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. @@ -1052,12 +1076,12 @@ class SimpleRNN(RNN): go_backwards=go_backwards, stateful=stateful, unroll=unroll, - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + self.activity_regularizer = regularizers.get(activity_regularizer) def call(self, inputs, mask=None, training=None, initial_state=None): - self.cell._generate_dropout_mask(inputs, training=training) - self.cell._generate_recurrent_dropout_mask(inputs, training=training) + self.cell._dropout_mask = None + self.cell._recurrent_dropout_mask = None return super(SimpleRNN, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @@ -1119,25 +1143,36 @@ class SimpleRNN(RNN): def get_config(self): config = { - 'units': self.units, - 'activation': activations.serialize(self.activation), - 'use_bias': self.use_bias, - 'kernel_initializer': initializers.serialize(self.kernel_initializer), + 'units': + self.units, + 'activation': + activations.serialize(self.activation), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': constraints.serialize(self.kernel_constraint), + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), - 'bias_constraint': constraints.serialize(self.bias_constraint), - 'dropout': self.dropout, - 'recurrent_dropout': self.recurrent_dropout + 'bias_constraint': + constraints.serialize(self.bias_constraint), + 'dropout': + self.dropout, + 'recurrent_dropout': + self.recurrent_dropout } base_config = super(SimpleRNN, self).get_config() del base_config['cell'] @@ -1150,48 +1185,38 @@ class SimpleRNN(RNN): return cls(**config) +@tf_export('keras.layers.GRUCell') class GRUCell(Layer): """Cell class for the GRU layer. Arguments: units: Positive integer, dimensionality of the output space. - activation: Activation function to use - (see [activations](../activations.md)). + activation: Activation function to use. Default: hyperbolic tangent (`tanh`). - If you pass `None`, no activation is applied + If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use - for the recurrent step - (see [activations](../activations.md)). + for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. - (see [initializers](../initializers.md)). recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. - (see [initializers](../initializers.md)). - bias_initializer: Initializer for the bias vector - (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector. kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix - (see [regularizer](../regularizers.md)). + the `kernel` weights matrix. recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix - (see [regularizer](../regularizers.md)). - bias_regularizer: Regularizer function applied to the bias vector - (see [regularizer](../regularizers.md)). + the `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. kernel_constraint: Constraint function applied to - the `kernel` weights matrix - (see [constraints](../constraints.md)). + the `kernel` weights matrix. recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix - (see [constraints](../constraints.md)). - bias_constraint: Constraint function applied to the bias vector - (see [constraints](../constraints.md)). + the `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. @@ -1249,6 +1274,7 @@ class GRUCell(Layer): self._dropout_mask = None self._recurrent_dropout_mask = None + @shape_type_conversion def build(self, input_shape): input_dim = input_shape[-1] self.kernel = self.add_weight( @@ -1292,38 +1318,24 @@ class GRUCell(Layer): self.bias_h = None self.built = True - def _generate_dropout_mask(self, inputs, training=None): - if 0 < self.dropout < 1: - ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1)) - - def dropped_inputs(): - return K.dropout(ones, self.dropout) - - self._dropout_mask = [ - K.in_train_phase(dropped_inputs, ones, training=training) - for _ in range(3) - ] - else: - self._dropout_mask = None - - def _generate_recurrent_dropout_mask(self, inputs, training=None): - if 0 < self.recurrent_dropout < 1: - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, self.units)) - - def dropped_inputs(): - return K.dropout(ones, self.dropout) - - self._recurrent_dropout_mask = [ - K.in_train_phase(dropped_inputs, ones, training=training) - for _ in range(3) - ] - else: - self._recurrent_dropout_mask = None - def call(self, inputs, states, training=None): h_tm1 = states[0] # previous memory + if 0 < self.dropout < 1 and self._dropout_mask is None: + self._dropout_mask = _generate_dropout_mask( + _generate_dropout_ones(inputs, + K.shape(inputs)[-1]), + self.dropout, + training=training, + count=3) + if (0 < self.recurrent_dropout < 1 and + self._recurrent_dropout_mask is None): + self._recurrent_dropout_mask = _generate_dropout_mask( + _generate_dropout_ones(inputs, self.units), + self.recurrent_dropout, + training=training, + count=3) + # dropout matrices for input units dp_mask = self._dropout_mask # dropout matrices for recurrent units @@ -1387,55 +1399,81 @@ class GRUCell(Layer): h._uses_learning_phase = True return h, [h] + def get_config(self): + config = { + 'units': + self.units, + 'activation': + activations.serialize(self.activation), + 'recurrent_activation': + activations.serialize(self.recurrent_activation), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), + 'recurrent_initializer': + initializers.serialize(self.recurrent_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), + 'recurrent_regularizer': + regularizers.serialize(self.recurrent_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), + 'recurrent_constraint': + constraints.serialize(self.recurrent_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint), + 'dropout': + self.dropout, + 'recurrent_dropout': + self.recurrent_dropout, + 'implementation': + self.implementation + } + base_config = super(GRUCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + +@tf_export('keras.layers.GRU') class GRU(RNN): - # pylint: disable=line-too-long """Gated Recurrent Unit - Cho et al. 2014. Arguments: units: Positive integer, dimensionality of the output space. - activation: Activation function to use - (see [activations](../activations.md)). + activation: Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use - for the recurrent step - (see [activations](../activations.md)). + for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. - (see [initializers](../initializers.md)). recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. - (see [initializers](../initializers.md)). - bias_initializer: Initializer for the bias vector - (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector. kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix - (see [regularizer](../regularizers.md)). + the `kernel` weights matrix. recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix - (see [regularizer](../regularizers.md)). - bias_regularizer: Regularizer function applied to the bias vector - (see [regularizer](../regularizers.md)). + the `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. activity_regularizer: Regularizer function applied to - the output of the layer (its "activation"). - (see [regularizer](../regularizers.md)). + the output of the layer (its "activation").. kernel_constraint: Constraint function applied to - the `kernel` weights matrix - (see [constraints](../constraints.md)). + the `kernel` weights matrix. recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix - (see [constraints](../constraints.md)). - bias_constraint: Constraint function applied to the bias vector - (see [constraints](../constraints.md)). + the `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. @@ -1465,12 +1503,7 @@ class GRU(RNN): although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. - References: - - [On the Properties of Neural Machine Translation: Encoder-Decoder Approaches](https://arxiv.org/abs/1409.1259) - - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](http://arxiv.org/abs/1412.3555v1) - - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287) """ - # pylint: enable=line-too-long def __init__(self, units, @@ -1528,8 +1561,8 @@ class GRU(RNN): self.activity_regularizer = regularizers.get(activity_regularizer) def call(self, inputs, mask=None, training=None, initial_state=None): - self.cell._generate_dropout_mask(inputs, training=training) - self.cell._generate_recurrent_dropout_mask(inputs, training=training) + self.cell._dropout_mask = None + self.cell._recurrent_dropout_mask = None return super(GRU, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @@ -1599,28 +1632,40 @@ class GRU(RNN): def get_config(self): config = { - 'units': self.units, - 'activation': activations.serialize(self.activation), + 'units': + self.units, + 'activation': + activations.serialize(self.activation), 'recurrent_activation': activations.serialize(self.recurrent_activation), - 'use_bias': self.use_bias, - 'kernel_initializer': initializers.serialize(self.kernel_initializer), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': constraints.serialize(self.kernel_constraint), + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), - 'bias_constraint': constraints.serialize(self.bias_constraint), - 'dropout': self.dropout, - 'recurrent_dropout': self.recurrent_dropout, - 'implementation': self.implementation + 'bias_constraint': + constraints.serialize(self.bias_constraint), + 'dropout': + self.dropout, + 'recurrent_dropout': + self.recurrent_dropout, + 'implementation': + self.implementation } base_config = super(GRU, self).get_config() del base_config['cell'] @@ -1633,53 +1678,43 @@ class GRU(RNN): return cls(**config) +@tf_export('keras.layers.LSTMCell') class LSTMCell(Layer): """Cell class for the LSTM layer. Arguments: units: Positive integer, dimensionality of the output space. - activation: Activation function to use - (see [activations](../activations.md)). + activation: Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use - for the recurrent step - (see [activations](../activations.md)). + for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied - (ie. "linear" activation: `a(x) = x`). + (ie. "linear" activation: `a(x) = x`).x use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. - (see [initializers](../initializers.md)). recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. - (see [initializers](../initializers.md)). - bias_initializer: Initializer for the bias vector - (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector. unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate at initialization. Setting it to true will also force `bias_initializer="zeros"`. This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix - (see [regularizer](../regularizers.md)). + the `kernel` weights matrix. recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix - (see [regularizer](../regularizers.md)). - bias_regularizer: Regularizer function applied to the bias vector - (see [regularizer](../regularizers.md)). + the `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. kernel_constraint: Constraint function applied to - the `kernel` weights matrix - (see [constraints](../constraints.md)). + the `kernel` weights matrix. recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix - (see [constraints](../constraints.md)). - bias_constraint: Constraint function applied to the bias vector - (see [constraints](../constraints.md)). + the `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. @@ -1739,6 +1774,7 @@ class LSTMCell(Layer): self._dropout_mask = None self._recurrent_dropout_mask = None + @shape_type_conversion def build(self, input_shape): input_dim = input_shape[-1] self.kernel = self.add_weight( @@ -1798,36 +1834,22 @@ class LSTMCell(Layer): self.bias_o = None self.built = True - def _generate_dropout_mask(self, inputs, training=None): - if 0 < self.dropout < 1: - ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1)) - - def dropped_inputs(): - return K.dropout(ones, self.dropout) - - self._dropout_mask = [ - K.in_train_phase(dropped_inputs, ones, training=training) - for _ in range(4) - ] - else: - self._dropout_mask = None - - def _generate_recurrent_dropout_mask(self, inputs, training=None): - if 0 < self.recurrent_dropout < 1: - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, self.units)) - - def dropped_inputs(): - return K.dropout(ones, self.dropout) - - self._recurrent_dropout_mask = [ - K.in_train_phase(dropped_inputs, ones, training=training) - for _ in range(4) - ] - else: - self._recurrent_dropout_mask = None - def call(self, inputs, states, training=None): + if 0 < self.dropout < 1 and self._dropout_mask is None: + self._dropout_mask = _generate_dropout_mask( + _generate_dropout_ones(inputs, + K.shape(inputs)[-1]), + self.dropout, + training=training, + count=4) + if (0 < self.recurrent_dropout < 1 and + self._recurrent_dropout_mask is None): + self._recurrent_dropout_mask = _generate_dropout_mask( + _generate_dropout_ones(inputs, self.units), + self.recurrent_dropout, + training=training, + count=4) + # dropout matrices for input units dp_mask = self._dropout_mask # dropout matrices for recurrent units @@ -1901,59 +1923,86 @@ class LSTMCell(Layer): h._uses_learning_phase = True return h, [h, c] + def get_config(self): + config = { + 'units': + self.units, + 'activation': + activations.serialize(self.activation), + 'recurrent_activation': + activations.serialize(self.recurrent_activation), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), + 'recurrent_initializer': + initializers.serialize(self.recurrent_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'unit_forget_bias': + self.unit_forget_bias, + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), + 'recurrent_regularizer': + regularizers.serialize(self.recurrent_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), + 'recurrent_constraint': + constraints.serialize(self.recurrent_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint), + 'dropout': + self.dropout, + 'recurrent_dropout': + self.recurrent_dropout, + 'implementation': + self.implementation + } + base_config = super(LSTMCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + +@tf_export('keras.layers.LSTM') class LSTM(RNN): - # pylint: disable=line-too-long """Long-Short Term Memory layer - Hochreiter 1997. Arguments: units: Positive integer, dimensionality of the output space. - activation: Activation function to use - (see [activations](../activations.md)). + activation: Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use - for the recurrent step - (see [activations](../activations.md)). - Default: hyperbolic tangent (`tanh`). + for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, - used for the linear transformation of the inputs. - (see [initializers](../initializers.md)). + used for the linear transformation of the inputs.. recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, - used for the linear transformation of the recurrent state. - (see [initializers](../initializers.md)). - bias_initializer: Initializer for the bias vector - (see [initializers](../initializers.md)). + used for the linear transformation of the recurrent state.. + bias_initializer: Initializer for the bias vector. unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate at initialization. Setting it to true will also force `bias_initializer="zeros"`. This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix - (see [regularizer](../regularizers.md)). + the `kernel` weights matrix. recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix - (see [regularizer](../regularizers.md)). - bias_regularizer: Regularizer function applied to the bias vector - (see [regularizer](../regularizers.md)). + the `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. activity_regularizer: Regularizer function applied to - the output of the layer (its "activation"). - (see [regularizer](../regularizers.md)). + the output of the layer (its "activation").. kernel_constraint: Constraint function applied to - the `kernel` weights matrix - (see [constraints](../constraints.md)). + the `kernel` weights matrix. recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix - (see [constraints](../constraints.md)). - bias_constraint: Constraint function applied to the bias vector - (see [constraints](../constraints.md)). + the `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. @@ -1983,13 +2032,7 @@ class LSTM(RNN): although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. - References: - - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf) - - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015) - - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf) - - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287) """ - # pylint: enable=line-too-long def __init__(self, units, @@ -2049,8 +2092,8 @@ class LSTM(RNN): self.activity_regularizer = regularizers.get(activity_regularizer) def call(self, inputs, mask=None, training=None, initial_state=None): - self.cell._generate_dropout_mask(inputs, training=training) - self.cell._generate_recurrent_dropout_mask(inputs, training=training) + self.cell._dropout_mask = None + self.cell._recurrent_dropout_mask = None return super(LSTM, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @@ -2124,29 +2167,42 @@ class LSTM(RNN): def get_config(self): config = { - 'units': self.units, - 'activation': activations.serialize(self.activation), + 'units': + self.units, + 'activation': + activations.serialize(self.activation), 'recurrent_activation': activations.serialize(self.recurrent_activation), - 'use_bias': self.use_bias, - 'kernel_initializer': initializers.serialize(self.kernel_initializer), + 'use_bias': + self.use_bias, + 'kernel_initializer': + initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'unit_forget_bias': self.unit_forget_bias, - 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'unit_forget_bias': + self.unit_forget_bias, + 'kernel_regularizer': + regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': constraints.serialize(self.kernel_constraint), + 'kernel_constraint': + constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), - 'bias_constraint': constraints.serialize(self.bias_constraint), - 'dropout': self.dropout, - 'recurrent_dropout': self.recurrent_dropout, - 'implementation': self.implementation + 'bias_constraint': + constraints.serialize(self.bias_constraint), + 'dropout': + self.dropout, + 'recurrent_dropout': + self.recurrent_dropout, + 'implementation': + self.implementation } base_config = super(LSTM, self).get_config() del base_config['cell'] @@ -2159,6 +2215,23 @@ class LSTM(RNN): return cls(**config) +def _generate_dropout_ones(inputs, dims): + return K.ones((K.shape(inputs)[0], dims)) + + +def _generate_dropout_mask(ones, rate, training=None, count=1): + + def dropped_inputs(): + return K.dropout(ones, rate) + + if count > 1: + return [ + K.in_train_phase(dropped_inputs, ones, training=training) + for _ in range(count) + ] + return K.in_train_phase(dropped_inputs, ones, training=training) + + class Recurrent(Layer): """Deprecated abstract base class for recurrent layers. @@ -2285,6 +2358,7 @@ class Recurrent(Layer): self.dropout = 0 self.recurrent_dropout = 0 + @shape_type_conversion def compute_output_shape(self, input_shape): if isinstance(input_shape, list): input_shape = input_shape[0] @@ -2422,7 +2496,7 @@ class Recurrent(Layer): if self.stateful: updates = [] for i in range(len(states)): - updates.append((self.states[i], states[i])) + updates.append(K.update(self.states[i], states[i])) self.add_update(updates, inputs) # Properly set learning phase diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py index 7dc4c1db9b4b71775bd3c52a863752b34d9dc3ea..de022153f6f07240a0dff70e5faeed5b6d4a8c5f 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py @@ -253,7 +253,7 @@ class RNNTest(test.TestCase): self.assertAllClose(y_np, y_np_2, atol=1e-4) with self.test_session(): - # test flat list inputs + # test flat list inputs. with keras.utils.CustomObjectScope(custom_objects): layer = keras.layers.RNN.from_config(config.copy()) y = layer([x, c]) @@ -262,6 +262,35 @@ class RNNTest(test.TestCase): y_np_3 = model.predict([x_np, c_np]) self.assertAllClose(y_np, y_np_3, atol=1e-4) + with self.test_session(): + # Test stacking. + cells = [keras.layers.recurrent.GRUCell(8), + RNNCellWithConstants(12), + RNNCellWithConstants(32)] + layer = keras.layers.recurrent.RNN(cells) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + [np.zeros((6, 5, 5)), np.zeros((6, 3))], + np.zeros((6, 32)) + ) + + with self.test_session(): + # Test stacked RNN serialization + x_np = np.random.random((6, 5, 5)) + c_np = np.random.random((6, 3)) + y_np = model.predict([x_np, c_np]) + weights = model.get_weights() + config = layer.get_config() + with keras.utils.CustomObjectScope(custom_objects): + layer = keras.layers.recurrent.RNN.from_config(config.copy()) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.set_weights(weights) + y_np_2 = model.predict([x_np, c_np]) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + def test_rnn_cell_with_constants_layer_passing_initial_state(self): class RNNCellWithConstants(keras.layers.Layer): @@ -353,13 +382,10 @@ class RNNTest(test.TestCase): self.assertAllClose(y_np, y_np_3, atol=1e-4) def test_stacked_rnn_attributes(self): - cells = [keras.layers.LSTMCell(3), - keras.layers.LSTMCell(3, kernel_regularizer='l2')] + cells = [keras.layers.LSTMCell(1), + keras.layers.LSTMCell(1)] layer = keras.layers.RNN(cells) - layer.build((None, None, 5)) - - # Test regularization losses - self.assertEqual(len(layer.losses), 1) + layer.build((None, None, 1)) # Test weights self.assertEqual(len(layer.trainable_weights), 6) @@ -367,11 +393,32 @@ class RNNTest(test.TestCase): self.assertEqual(len(layer.trainable_weights), 3) self.assertEqual(len(layer.non_trainable_weights), 3) - # Test `get_losses_for` - x = keras.Input((None, 5)) - y = keras.backend.sum(x) - cells[0].add_loss(y, inputs=x) - self.assertEqual(layer.get_losses_for(x), [y]) + # Test `get_losses_for` and `losses` + x = keras.Input((None, 1)) + loss_1 = keras.backend.sum(x) + loss_2 = keras.backend.sum(cells[0].kernel) + cells[0].add_loss(loss_1, inputs=x) + cells[0].add_loss(loss_2) + self.assertEqual(len(layer.losses), 2) + self.assertEqual(layer.get_losses_for(None), [loss_2]) + self.assertEqual(layer.get_losses_for(x), [loss_1]) + + # Test `get_updates_for` and `updates` + cells = [keras.layers.LSTMCell(1), + keras.layers.LSTMCell(1)] + layer = keras.layers.RNN(cells) + layer.build((None, None, 1)) + + x = keras.Input((None, 1)) + update_1 = keras.backend.update_add( + cells[0].kernel, x[0, 0, 0] * cells[0].kernel) + update_2 = keras.backend.update_add( + cells[0].kernel, keras.backend.ones_like(cells[0].kernel)) + cells[0].add_update(update_1, inputs=x) + cells[0].add_update(update_2) + self.assertEqual(len(layer.updates), 2) + self.assertEqual(layer.get_updates_for(None), [update_2]) + self.assertEqual(layer.get_updates_for(x), [update_1]) def test_rnn_dynamic_trainability(self): layer_class = keras.layers.SimpleRNN @@ -392,6 +439,105 @@ class RNNTest(test.TestCase): self.assertEqual(len(layer.trainable_weights), 3) self.assertEqual(len(layer.non_trainable_weights), 0) + def test_state_reuse_with_dropout(self): + layer_class = keras.layers.SimpleRNN + embedding_dim = 4 + units = 3 + timesteps = 2 + num_samples = 2 + + with self.test_session(): + input1 = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim)) + layer = layer_class(units, + return_state=True, + return_sequences=True, + dropout=0.2) + state = layer(input1)[1:] + + input2 = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim)) + output = layer_class(units)(input2, initial_state=state) + model = keras.Model([input1, input2], output) + + inputs = [np.random.random((num_samples, timesteps, embedding_dim)), + np.random.random((num_samples, timesteps, embedding_dim))] + model.predict(inputs) + + def test_builtin_rnn_cell_serialization(self): + for cell_class in [keras.layers.SimpleRNNCell, + keras.layers.GRUCell, + keras.layers.LSTMCell]: + with self.test_session(): + # Test basic case. + x = keras.Input((None, 5)) + cell = cell_class(32) + layer = keras.layers.RNN(cell) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + + # Test basic case serialization. + x_np = np.random.random((6, 5, 5)) + y_np = model.predict(x_np) + weights = model.get_weights() + config = layer.get_config() + layer = keras.layers.RNN.from_config(config) + y = layer(x) + model = keras.models.Model(x, y) + model.set_weights(weights) + y_np_2 = model.predict(x_np) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + + # Test stacking. + cells = [cell_class(8), + cell_class(12), + cell_class(32)] + layer = keras.layers.RNN(cells) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + + # Test stacked RNN serialization. + x_np = np.random.random((6, 5, 5)) + y_np = model.predict(x_np) + weights = model.get_weights() + config = layer.get_config() + layer = keras.layers.RNN.from_config(config) + y = layer(x) + model = keras.models.Model(x, y) + model.set_weights(weights) + y_np_2 = model.predict(x_np) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + + def test_stacked_rnn_dropout(self): + cells = [keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1), + keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)] + layer = keras.layers.RNN(cells) + + with self.test_session(): + x = keras.Input((None, 5)) + y = layer(x) + model = keras.models.Model(x, y) + model.compile('sgd', 'mse') + x_np = np.random.random((6, 5, 5)) + y_np = np.random.random((6, 3)) + model.train_on_batch(x_np, y_np) + + def test_stacked_rnn_compute_output_shape(self): + cells = [keras.layers.LSTMCell(3), + keras.layers.LSTMCell(6)] + embedding_dim = 4 + timesteps = 2 + layer = keras.layers.RNN(cells, return_state=True, return_sequences=True) + output_shape = layer.compute_output_shape((None, timesteps, embedding_dim)) + expected_output_shape = [(None, timesteps, 6), + (None, 6), + (None, 6), + (None, 3), + (None, 3)] + self.assertEqual( + [tuple(o.as_list()) for o in output_shape], + expected_output_shape) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py index 452801b65639be217ac26d3caa69f070c776634e..61f1a758e4701e6925af88b7fed9c48cf42ca735 100644 --- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py +++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py @@ -25,10 +25,13 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.layers import utils as tf_layers_util +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.layers.Wrapper') class Wrapper(Layer): """Abstract wrapper base class. @@ -58,6 +61,14 @@ class Wrapper(Layer): else: return None + @property + def trainable(self): + return self.layer.trainable + + @trainable.setter + def trainable(self, value): + self.layer.trainable = value + @property def trainable_weights(self): return self.layer.trainable_weights @@ -68,34 +79,11 @@ class Wrapper(Layer): @property def updates(self): - if hasattr(self.layer, 'updates'): - return self.layer.updates - return [] - - def get_updates_for(self, inputs=None): - # If the wrapper modifies the inputs, use the modified inputs to - # get the updates from the inner layer. - inner_inputs = inputs - if inputs is not None: - uid = tf_layers_util.object_list_uid(inputs) - if uid in self._input_map: - inner_inputs = self._input_map[uid] - - updates = self.layer.get_updates_for(inner_inputs) - updates += super(Wrapper, self).get_updates_for(inputs) - return updates + return self.layer.updates + self._updates @property def losses(self): - if hasattr(self.layer, 'losses'): - return self.layer.losses - return [] - - def get_losses_for(self, inputs=None): - if inputs is None: - losses = self.layer.get_losses_for(None) - return losses + super(Wrapper, self).get_losses_for(None) - return super(Wrapper, self).get_losses_for(inputs) + return self.layer.losses + self._losses def get_weights(self): return self.layer.get_weights() @@ -121,6 +109,7 @@ class Wrapper(Layer): return cls(layer, **config) +@tf_export('keras.layers.TimeDistributed') class TimeDistributed(Wrapper): """This wrapper allows to apply a layer to every temporal slice of an input. @@ -245,6 +234,7 @@ class TimeDistributed(Wrapper): return y +@tf_export('keras.layers.Bidirectional') class Bidirectional(Wrapper): """Bidirectional wrapper for RNNs. @@ -273,7 +263,6 @@ class Bidirectional(Wrapper): """ def __init__(self, layer, merge_mode='concat', weights=None, **kwargs): - super(Bidirectional, self).__init__(layer, **kwargs) if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]: raise ValueError('Invalid merge mode. ' 'Merge mode should be one of ' @@ -291,7 +280,21 @@ class Bidirectional(Wrapper): self.backward_layer.initial_weights = weights[nw // 2:] self.stateful = layer.stateful self.return_sequences = layer.return_sequences + self.return_state = layer.return_state self.supports_masking = True + self._trainable = True + super(Bidirectional, self).__init__(layer, **kwargs) + self.input_spec = layer.input_spec + + @property + def trainable(self): + return self._trainable + + @trainable.setter + def trainable(self, value): + self._trainable = value + self.forward_layer.trainable = value + self.backward_layer.trainable = value def get_weights(self): return self.forward_layer.get_weights() + self.backward_layer.get_weights() @@ -301,27 +304,104 @@ class Bidirectional(Wrapper): self.forward_layer.set_weights(weights[:nw // 2]) self.backward_layer.set_weights(weights[nw // 2:]) + @shape_type_conversion def compute_output_shape(self, input_shape): - input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) - if self.merge_mode in ['sum', 'ave', 'mul']: - return self.forward_layer.compute_output_shape(input_shape) - elif self.merge_mode == 'concat': - shape = self.forward_layer.compute_output_shape(input_shape).as_list() - shape[-1] *= 2 - return tensor_shape.TensorShape(shape) + output_shape = tuple(self.forward_layer.compute_output_shape( + input_shape).as_list()) + if self.return_state: + state_shape = output_shape[1:] + output_shape = output_shape[0] + + if self.merge_mode == 'concat': + output_shape = list(output_shape) + output_shape[-1] *= 2 + output_shape = tuple(output_shape) elif self.merge_mode is None: - shape = self.forward_layer.compute_output_shape(input_shape) - return [shape, copy.copy(shape)] + output_shape = [output_shape, copy.copy(output_shape)] - def call(self, inputs, training=None, mask=None): + if self.return_state: + if self.merge_mode is None: + return output_shape + state_shape + copy.copy(state_shape) + return [output_shape] + state_shape + copy.copy(state_shape) + return output_shape + + def __call__(self, inputs, initial_state=None, **kwargs): + if isinstance(inputs, list): + if len(inputs) > 1: + initial_state = inputs[1:] + inputs = inputs[0] + + if initial_state is None: + return super(Bidirectional, self).__call__(inputs, **kwargs) + + # Standardize `initial_state` into list + if isinstance(initial_state, tuple): + initial_state = list(initial_state) + elif not isinstance(initial_state, list): + initial_state = [initial_state] + + # Check if `initial_state` can be splitted into half + num_states = len(initial_state) + if num_states % 2 > 0: + raise ValueError( + 'When passing `initial_state` to a Bidirectional RNN, the state ' + 'should be a list containing the states of the underlying RNNs. ' + 'Found: ' + str(initial_state)) + + # Applies the same workaround as in `RNN.__call__`, without handling + # constants + kwargs['initial_state'] = initial_state + additional_inputs = initial_state + additional_specs = [InputSpec(shape=K.int_shape(state)) + for state in initial_state] + self.forward_layer.state_spec = additional_specs[:num_states // 2] + self.backward_layer.state_spec = additional_specs[num_states // 2:] + + is_keras_tensor = K.is_keras_tensor(additional_inputs[0]) + for tensor in additional_inputs: + if K.is_keras_tensor(tensor) != is_keras_tensor: + raise ValueError('The initial state of a Bidirectional' + ' layer cannot be specified with a mix of' + ' Keras tensors and non-Keras tensors' + ' (a "Keras tensor" is a tensor that was' + ' returned by a Keras layer, or by `Input`)') + + if is_keras_tensor: + # Compute the full input spec, including state + full_input = [inputs] + additional_inputs + full_input_spec = self.input_spec + additional_specs + + # Perform the call with temporarily replaced input_spec + original_input_spec = self.input_spec + self.input_spec = full_input_spec + output = super(Bidirectional, self).__call__(full_input, **kwargs) + self.input_spec = original_input_spec + return output + else: + return super(Bidirectional, self).__call__(inputs, **kwargs) + + def call(self, inputs, training=None, mask=None, initial_state=None): kwargs = {} if has_arg(self.layer.call, 'training'): kwargs['training'] = training if has_arg(self.layer.call, 'mask'): kwargs['mask'] = mask - y = self.forward_layer.call(inputs, **kwargs) - y_rev = self.backward_layer.call(inputs, **kwargs) + if initial_state is not None and has_arg(self.layer.call, 'initial_state'): + forward_state = initial_state[:len(initial_state) // 2] + backward_state = initial_state[len(initial_state) // 2:] + y = self.forward_layer.call(inputs, initial_state=forward_state, **kwargs) + y_rev = self.backward_layer.call( + inputs, initial_state=backward_state, **kwargs) + else: + y = self.forward_layer.call(inputs, **kwargs) + y_rev = self.backward_layer.call(inputs, **kwargs) + + if self.return_state: + states = y[1:] + y_rev[1:] + y = y[0] + y_rev = y_rev[0] + if self.return_sequences: y_rev = K.reverse(y_rev, 1) if self.merge_mode == 'concat': @@ -343,6 +423,11 @@ class Bidirectional(Wrapper): out._uses_learning_phase = True else: output._uses_learning_phase = True + + if self.return_state: + if self.merge_mode is None: + return output + states + return [output] + states return output def reset_states(self): diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py index 0866c4b0aeddc91ba6eeca6395875b4f2574dbc0..c81d6b883cb0aa2b30331e35b387457072dbf3c3 100644 --- a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py @@ -133,6 +133,20 @@ class TimeDistributedTest(test.TestCase): # Verify input_map has one mapping from inputs to reshaped inputs. self.assertEqual(len(td._input_map.keys()), 1) + def test_TimeDistributed_trainable(self): + # test layers that need learning_phase to be set + x = keras.layers.Input(shape=(3, 2)) + layer = keras.layers.TimeDistributed(keras.layers.BatchNormalization()) + _ = layer(x) + assert len(layer.updates) == 2 + assert len(layer.trainable_weights) == 2 + layer.trainable = False + assert not layer.updates + assert not layer.trainable_weights + layer.trainable = True + assert len(layer.updates) == 2 + assert len(layer.trainable_weights) == 2 + class BidirectionalTest(test.TestCase): @@ -238,6 +252,146 @@ class BidirectionalTest(test.TestCase): model.compile(loss='mse', optimizer='sgd') model.fit(x, y, epochs=1, batch_size=1) + def test_Bidirectional_merged_value(self): + rnn = keras.layers.LSTM + samples = 2 + dim = 5 + timesteps = 3 + units = 3 + x = [np.random.rand(samples, timesteps, dim)] + + with self.test_session(): + for merge_mode in ['sum', 'mul', 'ave', 'concat', None]: + if merge_mode == 'sum': + merge_func = lambda y, y_rev: y + y_rev + elif merge_mode == 'mul': + merge_func = lambda y, y_rev: y * y_rev + elif merge_mode == 'ave': + merge_func = lambda y, y_rev: (y + y_rev) / 2 + elif merge_mode == 'concat': + merge_func = lambda y, y_rev: np.concatenate((y, y_rev), axis=-1) + else: + merge_func = lambda y, y_rev: [y, y_rev] + + # basic case + inputs = keras.Input((timesteps, dim)) + layer = keras.layers.Bidirectional( + rnn(units, return_sequences=True), merge_mode=merge_mode) + f_merged = keras.backend.function([inputs], _to_list(layer(inputs))) + f_forward = keras.backend.function([inputs], + [layer.forward_layer.call(inputs)]) + f_backward = keras.backend.function( + [inputs], + [keras.backend.reverse(layer.backward_layer.call(inputs), 1)]) + + y_merged = f_merged(x) + y_expected = _to_list(merge_func(f_forward(x)[0], f_backward(x)[0])) + assert len(y_merged) == len(y_expected) + for x1, x2 in zip(y_merged, y_expected): + self.assertAllClose(x1, x2, atol=1e-5) + + # test return_state + inputs = keras.Input((timesteps, dim)) + layer = keras.layers.Bidirectional( + rnn(units, return_state=True), merge_mode=merge_mode) + f_merged = keras.backend.function([inputs], layer(inputs)) + f_forward = keras.backend.function([inputs], + layer.forward_layer.call(inputs)) + f_backward = keras.backend.function([inputs], + layer.backward_layer.call(inputs)) + n_states = len(layer.layer.states) + + y_merged = f_merged(x) + y_forward = f_forward(x) + y_backward = f_backward(x) + y_expected = _to_list(merge_func(y_forward[0], y_backward[0])) + assert len(y_merged) == len(y_expected) + n_states * 2 + for x1, x2 in zip(y_merged, y_expected): + self.assertAllClose(x1, x2, atol=1e-5) + + y_merged = y_merged[-n_states * 2:] + y_forward = y_forward[-n_states:] + y_backward = y_backward[-n_states:] + for state_birnn, state_inner in zip(y_merged, y_forward + y_backward): + self.assertAllClose(state_birnn, state_inner, atol=1e-5) + + def test_Bidirectional_dropout(self): + rnn = keras.layers.LSTM + samples = 2 + dim = 5 + timesteps = 3 + units = 3 + merge_mode = 'sum' + x = [np.random.rand(samples, timesteps, dim)] + + with self.test_session(): + inputs = keras.Input((timesteps, dim)) + wrapped = keras.layers.Bidirectional( + rnn(units, dropout=0.2, recurrent_dropout=0.2), merge_mode=merge_mode) + outputs = _to_list(wrapped(inputs, training=True)) + assert all(not getattr(x, '_uses_learning_phase') for x in outputs) + + inputs = keras.Input((timesteps, dim)) + wrapped = keras.layers.Bidirectional( + rnn(units, dropout=0.2, return_state=True), merge_mode=merge_mode) + outputs = _to_list(wrapped(inputs)) + assert all(x._uses_learning_phase for x in outputs) + + model = keras.Model(inputs, outputs) + assert model.uses_learning_phase + y1 = _to_list(model.predict(x)) + y2 = _to_list(model.predict(x)) + for x1, x2 in zip(y1, y2): + self.assertAllClose(x1, x2, atol=1e-5) + + def test_Bidirectional_state_reuse(self): + rnn = keras.layers.LSTM + samples = 2 + dim = 5 + timesteps = 3 + units = 3 + + with self.test_session(): + input1 = keras.layers.Input((timesteps, dim)) + layer = keras.layers.Bidirectional( + rnn(units, return_state=True, return_sequences=True)) + state = layer(input1)[1:] + + # test passing invalid initial_state: passing a tensor + input2 = keras.layers.Input((timesteps, dim)) + with self.assertRaises(ValueError): + output = keras.layers.Bidirectional( + rnn(units))(input2, initial_state=state[0]) + + # test valid usage: passing a list + output = keras.layers.Bidirectional(rnn(units))(input2, + initial_state=state) + model = keras.models.Model([input1, input2], output) + assert len(model.layers) == 4 + assert isinstance(model.layers[-1].input, list) + inputs = [np.random.rand(samples, timesteps, dim), + np.random.rand(samples, timesteps, dim)] + model.predict(inputs) + + def test_Bidirectional_trainable(self): + # test layers that need learning_phase to be set + with self.test_session(): + x = keras.layers.Input(shape=(3, 2)) + layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3)) + _ = layer(x) + assert len(layer.trainable_weights) == 6 + layer.trainable = False + assert not layer.trainable_weights + layer.trainable = True + assert len(layer.trainable_weights) == 6 + + +def _to_list(ls): + if isinstance(ls, list): + return ls + else: + return [ls] + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/losses.py b/tensorflow/python/keras/_impl/keras/losses.py index 1d6319abb13619932fe76966a69004dcfcd0e022..1576ed7b999f65992f46b357c8ebeda8935c68d0 100644 --- a/tensorflow/python/keras/_impl/keras/losses.py +++ b/tensorflow/python/keras/_impl/keras/losses.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Built-in Keras loss functions. +# pylint: disable=unused-import +"""Built-in loss functions. """ from __future__ import absolute_import from __future__ import division @@ -23,43 +24,69 @@ import six from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.metrics.mean_squared_error', + 'keras.losses.mean_squared_error') def mean_squared_error(y_true, y_pred): return K.mean(K.square(y_pred - y_true), axis=-1) +@tf_export('keras.metrics.mean_absolute_error', + 'keras.losses.mean_absolute_error') def mean_absolute_error(y_true, y_pred): return K.mean(K.abs(y_pred - y_true), axis=-1) +@tf_export('keras.metrics.mean_absolute_percentage_error', + 'keras.losses.mean_absolute_percentage_error') def mean_absolute_percentage_error(y_true, y_pred): - # Equivalent to MAE, but sometimes easier to interpret. diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true), K.epsilon(), None)) return 100. * K.mean(diff, axis=-1) +@tf_export('keras.metrics.mean_squared_logarithmic_error', + 'keras.losses.mean_squared_logarithmic_error') def mean_squared_logarithmic_error(y_true, y_pred): first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.) second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.) return K.mean(K.square(first_log - second_log), axis=-1) +@tf_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge') def squared_hinge(y_true, y_pred): return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1) +@tf_export('keras.metrics.hinge', 'keras.losses.hinge') def hinge(y_true, y_pred): return K.mean(K.maximum(1. - y_true * y_pred, 0.), axis=-1) +@tf_export('keras.losses.categorical_hinge') def categorical_hinge(y_true, y_pred): pos = K.sum(y_true * y_pred, axis=-1) neg = K.max((1. - y_true) * y_pred, axis=-1) - return K.maximum(neg - pos + 1., 0.) + return K.maximum(0., neg - pos + 1.) +@tf_export('keras.losses.logcosh') def logcosh(y_true, y_pred): + """Logarithm of the hyperbolic cosine of the prediction error. + + `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and + to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly + like the mean squared error, but will not be so strongly affected by the + occasional wildly incorrect prediction. + + Arguments: + y_true: tensor of true targets. + y_pred: tensor of predicted targets. + + Returns: + Tensor with one scalar loss entry per sample. + """ def _logcosh(x): return x + K.softplus(-2. * x) - K.log(2.) @@ -67,28 +94,38 @@ def logcosh(y_true, y_pred): return K.mean(_logcosh(y_pred - y_true), axis=-1) +@tf_export('keras.metrics.categorical_crossentropy', + 'keras.losses.categorical_crossentropy') def categorical_crossentropy(y_true, y_pred): return K.categorical_crossentropy(y_true, y_pred) +@tf_export('keras.metrics.sparse_categorical_crossentropy', + 'keras.losses.sparse_categorical_crossentropy') def sparse_categorical_crossentropy(y_true, y_pred): return K.sparse_categorical_crossentropy(y_true, y_pred) +@tf_export('keras.metrics.binary_crossentropy', + 'keras.losses.binary_crossentropy') def binary_crossentropy(y_true, y_pred): return K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1) +@tf_export('keras.metrics.kullback_leibler_divergence', + 'keras.losses.kullback_leibler_divergence') def kullback_leibler_divergence(y_true, y_pred): y_true = K.clip(y_true, K.epsilon(), 1) y_pred = K.clip(y_pred, K.epsilon(), 1) return K.sum(y_true * K.log(y_true / y_pred), axis=-1) +@tf_export('keras.metrics.poisson', 'keras.losses.poisson') def poisson(y_true, y_pred): return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1) +@tf_export('keras.metrics.cosine_proximity', 'keras.losses.cosine_proximity') def cosine_proximity(y_true, y_pred): y_true = K.l2_normalize(y_true, axis=-1) y_pred = K.l2_normalize(y_pred, axis=-1) @@ -105,10 +142,12 @@ kld = KLD = kullback_leibler_divergence cosine = cosine_proximity +@tf_export('keras.losses.serialize') def serialize(loss): return serialize_keras_object(loss) +@tf_export('keras.losses.deserialize') def deserialize(name, custom_objects=None): return deserialize_keras_object( name, @@ -117,6 +156,7 @@ def deserialize(name, custom_objects=None): printable_module_name='loss function') +@tf_export('keras.losses.get') def get(identifier): if identifier is None: return None diff --git a/tensorflow/python/keras/_impl/keras/metrics.py b/tensorflow/python/keras/_impl/keras/metrics.py index 202048f26d2ad201b4762d3b2b32638f9d041e88..82778a3dc4fbdc13bb6682d01e28ff68882b6dd9 100644 --- a/tensorflow/python/keras/_impl/keras/metrics.py +++ b/tensorflow/python/keras/_impl/keras/metrics.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Built-in Keras metrics functions. +# pylint: disable=unused-import +"""Built-in metrics. """ from __future__ import absolute_import from __future__ import division @@ -21,7 +22,6 @@ from __future__ import print_function import six from tensorflow.python.keras._impl.keras import backend as K -# pylint: disable=unused-import from tensorflow.python.keras._impl.keras.losses import binary_crossentropy from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy from tensorflow.python.keras._impl.keras.losses import cosine_proximity @@ -35,14 +35,17 @@ from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_ from tensorflow.python.keras._impl.keras.losses import poisson from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy from tensorflow.python.keras._impl.keras.losses import squared_hinge -# pylint: disable=unused-import from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.metrics.binary_accuracy') def binary_accuracy(y_true, y_pred): return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1) +@tf_export('keras.metrics.categorical_accuracy') def categorical_accuracy(y_true, y_pred): return K.cast( K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx()) @@ -55,13 +58,15 @@ def sparse_categorical_accuracy(y_true, y_pred): K.floatx())), K.floatx()) +@tf_export('keras.metrics.top_k_categorical_accuracy') def top_k_categorical_accuracy(y_true, y_pred, k=5): return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1) +@tf_export('keras.metrics.sparse_top_k_categorical_accuracy') def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): - return K.mean(K.in_top_k(y_pred, - K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1) + return K.mean( + K.in_top_k(y_pred, K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1) # Aliases @@ -73,24 +78,29 @@ msle = MSLE = mean_squared_logarithmic_error cosine = cosine_proximity +@tf_export('keras.metrics.serialize') def serialize(metric): - return metric.__name__ + return serialize_keras_object(metric) -def deserialize(name, custom_objects=None): +@tf_export('keras.metrics.deserialize') +def deserialize(config, custom_objects=None): return deserialize_keras_object( - name, + config, module_objects=globals(), custom_objects=custom_objects, printable_module_name='metric function') +@tf_export('keras.metrics.get') def get(identifier): - if isinstance(identifier, six.string_types): - identifier = str(identifier) - return deserialize(identifier) + if isinstance(identifier, dict): + config = {'class_name': str(identifier), 'config': {}} + return deserialize(config) + elif isinstance(identifier, six.string_types): + return deserialize(str(identifier)) elif callable(identifier): return identifier else: raise ValueError('Could not interpret ' - 'metric function identifier:', identifier) + 'metric function identifier: %s' % identifier) diff --git a/tensorflow/python/keras/_impl/keras/metrics_test.py b/tensorflow/python/keras/_impl/keras/metrics_test.py index f4792f3543cc5ca8e5e7ad03d9906bbfadd1fb04..44289ea02abf5ae5f8befbe515552aea3d4b231e 100644 --- a/tensorflow/python/keras/_impl/keras/metrics_test.py +++ b/tensorflow/python/keras/_impl/keras/metrics_test.py @@ -72,6 +72,77 @@ class KerasMetricsTest(test.TestCase): keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) + def test_stateful_metrics(self): + np.random.seed(1334) + + class BinaryTruePositives(keras.layers.Layer): + """Stateful Metric to count the total true positives over all batches. + + Assumes predictions and targets of shape `(samples, 1)`. + + Arguments: + threshold: Float, lower limit on prediction value that counts as a + positive class prediction. + name: String, name for the metric. + """ + + def __init__(self, name='true_positives', **kwargs): + super(BinaryTruePositives, self).__init__(name=name, **kwargs) + self.true_positives = keras.backend.variable(value=0, dtype='int32') + + def reset_states(self): + keras.backend.set_value(self.true_positives, 0) + + def __call__(self, y_true, y_pred): + """Computes the number of true positives in a batch. + + Args: + y_true: Tensor, batch_wise labels + y_pred: Tensor, batch_wise predictions + + Returns: + The total number of true positives seen this epoch at the + completion of the batch. + """ + y_true = keras.backend.cast(y_true, 'int32') + y_pred = keras.backend.cast(keras.backend.round(y_pred), 'int32') + correct_preds = keras.backend.cast( + keras.backend.equal(y_pred, y_true), 'int32') + true_pos = keras.backend.cast( + keras.backend.sum(correct_preds * y_true), 'int32') + current_true_pos = self.true_positives * 1 + self.add_update(keras.backend.update_add(self.true_positives, + true_pos), + inputs=[y_true, y_pred]) + return current_true_pos + true_pos + + metric_fn = BinaryTruePositives() + config = keras.metrics.serialize(metric_fn) + metric_fn = keras.metrics.deserialize( + config, custom_objects={'BinaryTruePositives': BinaryTruePositives}) + + # Test on simple model + inputs = keras.Input(shape=(2,)) + outputs = keras.layers.Dense(1, activation='sigmoid')(inputs) + model = keras.Model(inputs, outputs) + model.compile(optimizer='sgd', + loss='binary_crossentropy', + metrics=['acc', metric_fn]) + + # Test fit, evaluate + samples = 1000 + x = np.random.random((samples, 2)) + y = np.random.randint(2, size=(samples, 1)) + model.fit(x, y, epochs=1, batch_size=10) + outs = model.evaluate(x, y, batch_size=10) + preds = model.predict(x) + + def ref_true_pos(y_true, y_pred): + return np.sum(np.logical_and(y_pred > 0.5, y_true == 1)) + + # Test correctness (e.g. updates should have been run) + self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3d71a620fcb34d21c41f920eed99b1fe22668899 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py @@ -0,0 +1,589 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Model subclassing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +import numpy as np + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.keras._impl import keras +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.training.rmsprop import RMSPropOptimizer + +try: + import h5py # pylint:disable=g-import-not-at-top +except ImportError: + h5py = None + + +class SimpleTestModel(keras.Model): + + def __init__(self, use_bn=False, use_dp=False, num_classes=10): + super(SimpleTestModel, self).__init__(name='test_model') + self.use_bn = use_bn + self.use_dp = use_dp + self.num_classes = num_classes + + self.dense1 = keras.layers.Dense(32, activation='relu') + self.dense2 = keras.layers.Dense(num_classes, activation='softmax') + if self.use_dp: + self.dp = keras.layers.Dropout(0.5) + if self.use_bn: + self.bn = keras.layers.BatchNormalization(axis=-1) + + def call(self, inputs): + x = self.dense1(inputs) + if self.use_dp: + x = self.dp(x) + if self.use_bn: + x = self.bn(x) + return self.dense2(x) + + +class MultiIOTestModel(keras.Model): + + def __init__(self, use_bn=False, use_dp=False, num_classes=(2, 3)): + super(MultiIOTestModel, self).__init__(name='test_model') + self.use_bn = use_bn + self.use_dp = use_dp + self.num_classes = num_classes + + self.dense1 = keras.layers.Dense(32, activation='relu') + self.dense2 = keras.layers.Dense(num_classes[0], activation='softmax') + self.dense3 = keras.layers.Dense(num_classes[1], activation='softmax') + if use_dp: + self.dp = keras.layers.Dropout(0.5) + if use_bn: + self.bn = keras.layers.BatchNormalization() + + def call(self, inputs): + x1, x2 = inputs + x1 = self.dense1(x1) + x2 = self.dense1(x2) + if self.use_dp: + x1 = self.dp(x1) + if self.use_bn: + x2 = self.bn(x2) + return [self.dense2(x1), self.dense3(x2)] + + +class NestedTestModel1(keras.Model): + """A model subclass nested inside a model subclass. + """ + + def __init__(self, num_classes=2): + super(NestedTestModel1, self).__init__(name='nested_model_1') + self.num_classes = num_classes + self.dense1 = keras.layers.Dense(32, activation='relu') + self.dense2 = keras.layers.Dense(num_classes, activation='relu') + self.bn = keras.layers.BatchNormalization() + self.test_net = SimpleTestModel(num_classes=4, + use_bn=True, + use_dp=True) + + def call(self, inputs): + x = self.dense1(inputs) + x = self.bn(x) + x = self.test_net(x) # pylint: disable=not-callable + return self.dense2(x) + + +def get_functional_graph_model(input_dim, num_classes): + # A simple functional-API model (a.k.a. graph network) + inputs = keras.Input(shape=(input_dim,)) + x = keras.layers.Dense(32, activation='relu')(inputs) + x = keras.layers.BatchNormalization()(x) + outputs = keras.layers.Dense(num_classes)(x) + return keras.Model(inputs, outputs) + + +class NestedTestModel2(keras.Model): + """A model subclass with a functional-API graph network inside. + """ + + def __init__(self, num_classes=2): + super(NestedTestModel2, self).__init__(name='nested_model_2') + self.num_classes = num_classes + self.dense1 = keras.layers.Dense(32, activation='relu') + self.dense2 = keras.layers.Dense(num_classes, activation='relu') + self.bn = self.bn = keras.layers.BatchNormalization() + self.test_net = get_functional_graph_model(32, 4) + + def call(self, inputs): + x = self.dense1(inputs) + x = self.bn(x) + x = self.test_net(x) + return self.dense2(x) + + +def get_nested_model_3(input_dim, num_classes): + # A functional-API model with a subclassed model inside. + # NOTE: this requires the inner subclass to implement `compute_output_shape`. + + inputs = keras.Input(shape=(input_dim,)) + x = keras.layers.Dense(32, activation='relu')(inputs) + x = keras.layers.BatchNormalization()(x) + + class Inner(keras.Model): + + def __init__(self): + super(Inner, self).__init__() + self.dense1 = keras.layers.Dense(32, activation='relu') + self.dense2 = keras.layers.Dense(5, activation='relu') + self.bn = keras.layers.BatchNormalization() + + def call(self, inputs): + x = self.dense1(inputs) + x = self.dense2(x) + return self.bn(x) + + def compute_output_shape(self, input_shape): + return tensor_shape.TensorShape((input_shape[0], 5)) + + test_model = Inner() + x = test_model(x) # pylint: disable=not-callable + outputs = keras.layers.Dense(num_classes)(x) + return keras.Model(inputs, outputs, name='nested_model_3') + + +class ModelSubclassingTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_single_io_workflow_with_np_arrays(self): + num_classes = 2 + num_samples = 100 + input_dim = 50 + + with self.test_session(): + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) + + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) + + @test_util.run_in_graph_and_eager_modes() + def test_multi_io_workflow_with_np_arrays(self): + num_classes = (2, 3) + num_samples = 1000 + input_dim = 50 + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + _ = model.evaluate([x1, x2], [y1, y2], verbose=0) + + def test_single_io_workflow_with_tensors(self): + + num_classes = 2 + num_samples = 10 + input_dim = 50 + + with self.test_session(): + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + + x = array_ops.ones((num_samples, input_dim)) + y = array_ops.zeros((num_samples, num_classes)) + + model.fit(x, y, epochs=2, steps_per_epoch=10, verbose=0) + _ = model.evaluate(steps=10, verbose=0) + + def test_multi_io_workflow_with_tensors(self): + + num_classes = (2, 3) + num_samples = 10 + input_dim = 50 + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + + x1 = array_ops.ones((num_samples, input_dim)) + x2 = array_ops.ones((num_samples, input_dim)) + y1 = array_ops.zeros((num_samples, num_classes[0])) + y2 = array_ops.zeros((num_samples, num_classes[1])) + + model.fit([x1, x2], [y1, y2], epochs=2, steps_per_epoch=10, verbose=0) + _ = model.evaluate(steps=10, verbose=0) + + def test_multi_io_workflow_with_numpy_arrays_and_custom_placeholders(self): + + num_classes = (2, 3) + num_samples = 1000 + input_dim = 50 + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + x2_placeholder = array_ops.placeholder( + dtype='float32', shape=(None, input_dim)) + model._set_inputs([x1, x2_placeholder]) + + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + _ = model.evaluate([x1, x2], [y1, y2], verbose=0) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def test_attributes(self): + # layers, weights, trainable_weights, non_trainable_weights, inputs, outputs + + num_classes = (2, 3) + num_samples = 100 + input_dim = 50 + + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + self.assertEqual(model.name, 'test_model') + self.assertEqual(model.built, False) + self.assertEqual(len(model.weights), 0) + + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.train_on_batch([x1, x2], [y1, y2]) + + self.assertEqual(model.built, True) + self.assertEqual(len(model.layers), 4) + self.assertEqual(len(model.weights), 10) + self.assertEqual(len(model.trainable_weights), 8) + self.assertEqual(len(model.non_trainable_weights), 2) + self.assertEqual(len(model.inputs), 2) + self.assertEqual(len(model.outputs), 2) + + @test_util.run_in_graph_and_eager_modes() + def test_updates(self): + # test that updates get run during training + num_samples = 100 + input_dim = 50 + + class BNNet(keras.Model): + + def __init__(self): + super(BNNet, self).__init__() + self.bn = keras.layers.BatchNormalization(beta_initializer='ones', + gamma_initializer='ones') + + def call(self, inputs): + return self.bn(inputs) + + x = np.ones((num_samples, input_dim)) + y = np.ones((num_samples, input_dim)) + + with self.test_session(): + model = BNNet() + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + y_ref = model.predict(x) + + model.train_on_batch(x, y) + y_new = model.predict(x) + self.assertGreater(np.sum(np.abs(y_ref - y_new)), 0.1) + + @test_util.run_in_graph_and_eager_modes() + def test_training_and_inference_behavior(self): + # test that dropout is applied in training and not inference + + num_samples = 100 + input_dim = 50 + + class DPNet(keras.Model): + + def __init__(self): + super(DPNet, self).__init__() + self.dp = keras.layers.Dropout(0.5) + self.dense = keras.layers.Dense(1, + use_bias=False, + kernel_initializer='ones') + + def call(self, inputs): + x = self.dp(inputs) + return self.dense(x) + + with self.test_session(): + model = DPNet() + x = np.ones((num_samples, input_dim)) + y = model.predict(x) + self.assertEqual(np.sum(y), np.sum(x)) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + loss = model.train_on_batch(x, y) + self.assertGreater(loss, 0.1) + + @test_util.run_in_graph_and_eager_modes() + def test_training_methods(self): + # test fit, train_on_batch + # on different input types: list, dict + + num_classes = (2, 3) + num_samples = 100 + input_dim = 50 + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + model.fit({'input_1': x1, 'input_2': x2}, + {'output_1': y1, 'output_2': y2}, + epochs=2, batch_size=32) + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0, + validation_data=([x1, x2], [y1, y2])) + + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.train_on_batch([x1, x2], [y1, y2]) + model.train_on_batch({'input_1': x1, 'input_2': x2}, + {'output_1': y1, 'output_2': y2}) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def test_inference_methods(self): + # test predict, evaluate, test_on_batch, predict_on_batch + # on different input types: list, dict + num_classes = (2, 3) + num_samples = 100 + input_dim = 50 + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.evaluate([x1, x2], [y1, y2]) + model.test_on_batch([x1, x2], [y1, y2]) + + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.predict([x1, x2]) + + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.predict_on_batch([x1, x2]) + + @test_util.run_in_graph_and_eager_modes() + def test_trainable_mutation(self): + # test that you can change `trainable` on a model or layer, and that + # it freezes the model state during training + # TODO(fchollet): add test after we unify BN behavior in eager and symbolic. + pass + + @test_util.run_in_graph_and_eager_modes() + def test_saving(self): + if h5py is None: + return # Skip test if models cannot be saved. + + num_classes = (2, 3) + num_samples = 100 + input_dim = 50 + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + y_ref_1, y_ref_2 = model.predict([x1, x2]) + + fd, fname = tempfile.mkstemp('.h5') + model.save_weights(fname) + + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + # need to build the model before loading weights + # (otherwise no weights to load) + model._set_inputs([x1, x2]) + model.load_weights(fname) + + y1, y2 = model.predict([x1, x2]) + self.assertAllClose(y_ref_1, y1, atol=1e-5) + self.assertAllClose(y_ref_2, y2, atol=1e-5) + os.close(fd) + os.remove(fname) + + @test_util.run_in_graph_and_eager_modes() + def test_summary(self): + + class ToString(object): + + def __init__(self): + self.contents = '' + + def __call__(self, msg): + self.contents += msg + '\n' + + # Single-io + model = SimpleTestModel(num_classes=4, use_bn=True, use_dp=True) + model._set_inputs(np.ones((3, 4))) # need to build model first + print_fn = ToString() + model.summary(print_fn=print_fn) + self.assertTrue('Trainable params: 356' in print_fn.contents) + + # Multi-io + model = MultiIOTestModel(num_classes=(5, 6), use_bn=True, use_dp=True) + model._set_inputs([np.ones((3, 4)), + np.ones((3, 4))]) # need to build model first + print_fn = ToString() + model.summary(print_fn=print_fn) + self.assertTrue('Trainable params: 587' in print_fn.contents) + + @test_util.run_in_graph_and_eager_modes() + def test_subclass_nested_in_subclass(self): + num_classes = 2 + num_samples = 100 + input_dim = 50 + + with self.test_session(): + model = NestedTestModel1(num_classes=num_classes) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) + + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) + + self.assertEqual(len(model.weights), 8 + len(model.test_net.weights)) + self.assertEqual(len(model.non_trainable_weights), + 2 + len(model.test_net.non_trainable_weights)) + self.assertEqual(len(model.trainable_weights), + 6 + len(model.test_net.trainable_weights)) + + @test_util.run_in_graph_and_eager_modes() + def test_graph_nested_in_subclass(self): + num_classes = 2 + num_samples = 100 + input_dim = 50 + + with self.test_session(): + model = NestedTestModel2(num_classes=num_classes) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) + + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) + + self.assertEqual(len(model.weights), 8 + len(model.test_net.weights)) + self.assertEqual(len(model.non_trainable_weights), + 2 + len(model.test_net.non_trainable_weights)) + self.assertEqual(len(model.trainable_weights), + 6 + len(model.test_net.trainable_weights)) + + @test_util.run_in_graph_and_eager_modes() + def test_subclass_nested_in_graph(self): + num_classes = 2 + num_samples = 100 + input_dim = 50 + + with self.test_session(): + model = get_nested_model_3(input_dim=input_dim, num_classes=num_classes) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) + + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) + + self.assertEqual(len(model.weights), 16) + self.assertEqual( + len(model.non_trainable_weights), 4) + self.assertEqual(len(model.trainable_weights), 12) + + @test_util.run_in_graph_and_eager_modes() + def test_support_for_manual_training_arg(self): + # In most cases, the `training` argument is left unspecified, in which + # case it defaults to value corresponding to the Model method being used + # (fit -> True, predict -> False, etc). + # If the user writes their model `call` method to take + # an explicit `training` argument, we must check that the correct value + # is being passed to the model for each method call. + + class DPNet(keras.Model): + + def __init__(self): + super(DPNet, self).__init__() + self.dp = keras.layers.Dropout(0.5) + self.dense = keras.layers.Dense(1, + use_bias=False, + kernel_initializer='ones') + + def call(self, inputs, training=False): + x = self.dp(inputs, training=training) + return self.dense(x) + + with self.test_session(): + model = DPNet() + x = np.ones((10, 10)) + y = model.predict(x) + self.assertEqual(np.sum(y), np.sum(x)) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + loss = model.train_on_batch(x, y) + self.assertGreater(loss, 0.1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/models.py b/tensorflow/python/keras/_impl/keras/models.py index e262cc8c8e9d728c1e7f504ffaf543faa1f3db50..4c3ec7dbe458bfb78d38950b1bad7a474bb55ad3 100644 --- a/tensorflow/python/keras/_impl/keras/models.py +++ b/tensorflow/python/keras/_impl/keras/models.py @@ -38,6 +38,7 @@ from tensorflow.python.keras._impl.keras.engine.training import Model from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export # pylint: disable=g-import-not-at-top @@ -53,6 +54,7 @@ except ImportError: # pylint: enable=g-import-not-at-top +@tf_export('keras.models.save_model') def save_model(model, filepath, overwrite=True, include_optimizer=True): """Save a model to a HDF5 file. @@ -183,6 +185,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True): f.flush() +@tf_export('keras.models.load_model') def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin """Loads a model saved via `save_model`. @@ -302,6 +305,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable= return model +@tf_export('keras.models.model_from_config') def model_from_config(config, custom_objects=None): """Instantiates a Keras model from its config. @@ -324,6 +328,7 @@ def model_from_config(config, custom_objects=None): return layer_module.deserialize(config, custom_objects=custom_objects) +@tf_export('keras.models.model_from_yaml') def model_from_yaml(yaml_string, custom_objects=None): """Parses a yaml model configuration file and returns a model instance. @@ -345,6 +350,7 @@ def model_from_yaml(yaml_string, custom_objects=None): return layer_module.deserialize(config, custom_objects=custom_objects) +@tf_export('keras.models.model_from_json') def model_from_json(json_string, custom_objects=None): """Parses a JSON model configuration file and returns a model instance. @@ -361,6 +367,7 @@ def model_from_json(json_string, custom_objects=None): return layer_module.deserialize(config, custom_objects=custom_objects) +@tf_export('keras.models.Sequential', 'keras.Sequential') class Sequential(Model): """Linear stack of layers. @@ -399,7 +406,9 @@ class Sequential(Model): """ def __init__(self, layers=None, name=None): - self.layers = [] # Stack of layers. + self._is_graph_network = True + self._is_compiled = False + self._layers = [] # Stack of layers. self.model = None # Internal Model instance. self.inputs = [] # List of input tensors self.outputs = [] # List of length 1: the output tensor (unique). @@ -421,8 +430,6 @@ class Sequential(Model): # Used by Layer base class. self._dtype = None self._activity_regularizer = None - self._per_input_losses = {} - self._per_input_updates = {} # The following properties are not actually used by Keras; # they exist for compatibility with TF's variable scoping mechanism. @@ -492,13 +499,13 @@ class Sequential(Model): # to the input layer we just created. layer(x) - if len(layer.inbound_nodes[-1].output_tensors) != 1: + if len(layer._inbound_nodes[-1].output_tensors) != 1: raise ValueError('All layers in a Sequential model ' 'should have a single output tensor. ' 'For multi-output layers, ' 'use the functional API.') - self.outputs = [layer.inbound_nodes[-1].output_tensors[0]] + self.outputs = [layer._inbound_nodes[-1].output_tensors[0]] self.inputs = topology.get_source_inputs(self.outputs[0]) # We create an input node, which we will keep updated @@ -522,7 +529,7 @@ class Sequential(Model): self._inbound_nodes[0].output_tensors = self.outputs self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])] - self.layers.append(layer) + self._layers.append(layer) self.built = False def pop(self): @@ -636,34 +643,6 @@ class Sequential(Model): return trainable_weights + weights return weights - @property - def updates(self): - if not self.built: - self.build() - return self.model.updates - - @property - def state_updates(self): - if not self.built: - self.build() - return self.model.state_updates - - def get_updates_for(self, inputs): - if not self.built: - self.build() - return self.model.get_updates_for(inputs) - - @property - def losses(self): - if not self.built: - self.build() - return self.model.losses - - def get_losses_for(self, inputs): - if not self.built: - self.build() - return self.model.get_losses_for(inputs) - @property def regularizers(self): if not self.built: diff --git a/tensorflow/python/keras/_impl/keras/models_test.py b/tensorflow/python/keras/_impl/keras/models_test.py index edfc0ce0ebc0321589a452e7357c517feeb626cf..04017e4b28b27e52f88a7746fc44510c29edffce 100644 --- a/tensorflow/python/keras/_impl/keras/models_test.py +++ b/tensorflow/python/keras/_impl/keras/models_test.py @@ -340,6 +340,35 @@ class TestSequential(test.TestCase): inner_model.trainable = True self.assertEqual(len(model.trainable_weights), 4) + def test_sequential_update_disabling(self): + val_a = np.random.random((10, 4)) + val_out = np.random.random((10, 4)) + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.BatchNormalization(input_shape=(4,))) + + model.trainable = False + assert not model.updates + + model.compile('sgd', 'mse') + assert not model.updates + assert not model.model.updates + + x1 = model.predict(val_a) + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + self.assertAllClose(x1, x2, atol=1e-7) + + model.trainable = True + model.compile('sgd', 'mse') + assert model.updates + assert model.model.updates + + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + assert np.abs(np.sum(x1 - x2)) > 1e-5 + class TestModelCloning(test.TestCase): diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py index a08073fa86442e0564aa63052bb87b92dc64cdf6..76a97156ed7d9ca89b0d94f31bed3a23eca9609d 100644 --- a/tensorflow/python/keras/_impl/keras/optimizers.py +++ b/tensorflow/python/keras/_impl/keras/optimizers.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras optimizer classes (will eventually be replaced with core optimizers). +# pylint: disable=invalid-name +"""Built-in optimizer classes. """ from __future__ import absolute_import from __future__ import division @@ -31,6 +32,7 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_ke from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer as tf_optimizer_module +from tensorflow.python.util.tf_export import tf_export def clip_norm(g, c, n): @@ -64,6 +66,7 @@ def clip_norm(g, c, n): return g +@tf_export('keras.optimizers.Optimizer') class Optimizer(object): """Abstract optimizer base class. @@ -121,9 +124,9 @@ class Optimizer(object): param_values = K.batch_get_value(params) for pv, p, w in zip(param_values, params, weights): if pv.shape != w.shape: - raise ValueError('Optimizer weight shape ' + str(pv.shape) + - ' not compatible with ' - 'provided weight shape ' + str(w.shape)) + raise ValueError( + 'Optimizer weight shape ' + str(pv.shape) + ' not compatible with ' + 'provided weight shape ' + str(w.shape)) weight_value_tuples.append((p, w)) K.batch_set_value(weight_value_tuples) @@ -148,6 +151,7 @@ class Optimizer(object): return cls(**config) +@tf_export('keras.optimizers.SGD') class SGD(Optimizer): """Stochastic gradient descent optimizer. @@ -156,7 +160,8 @@ class SGD(Optimizer): Arguments: lr: float >= 0. Learning rate. - momentum: float >= 0. Parameter updates momentum. + momentum: float >= 0. Parameter that accelerates SGD + in the relevant direction and dampens oscillations. decay: float >= 0. Learning rate decay over each update. nesterov: boolean. Whether to apply Nesterov momentum. """ @@ -177,9 +182,9 @@ class SGD(Optimizer): lr = self.lr if self.initial_decay > 0: - lr *= (1. / (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) - + lr = lr * (1. / # pylint: disable=g-no-augmented-assignment + (1. + self.decay * K.cast(self.iterations, + K.dtype(self.decay)))) # momentum shapes = [K.int_shape(p) for p in params] moments = [K.zeros(shape) for shape in shapes] @@ -211,6 +216,7 @@ class SGD(Optimizer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.optimizers.RMSprop') class RMSprop(Optimizer): """RMSProp optimizer. @@ -224,32 +230,34 @@ class RMSprop(Optimizer): Arguments: lr: float >= 0. Learning rate. rho: float >= 0. - epsilon: float >= 0. Fuzz factor. + epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. decay: float >= 0. Learning rate decay over each update. + """ - def __init__(self, lr=0.001, rho=0.9, epsilon=1e-8, decay=0., **kwargs): + def __init__(self, lr=0.001, rho=0.9, epsilon=None, decay=0., **kwargs): super(RMSprop, self).__init__(**kwargs) with K.name_scope(self.__class__.__name__): self.lr = K.variable(lr, name='lr') self.rho = K.variable(rho, name='rho') self.decay = K.variable(decay, name='decay') self.iterations = K.variable(0, dtype='int64', name='iterations') + if epsilon is None: + epsilon = K.epsilon() self.epsilon = epsilon self.initial_decay = decay def get_updates(self, loss, params): grads = self.get_gradients(loss, params) - accumulators = [ - K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params - ] + accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] self.weights = accumulators self.updates = [K.update_add(self.iterations, 1)] lr = self.lr if self.initial_decay > 0: - lr *= (1. / (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * (1. / # pylint: disable=g-no-augmented-assignment + (1. + self.decay * K.cast(self.iterations, + K.dtype(self.decay)))) for p, g, a in zip(params, grads, accumulators): # update accumulator @@ -275,6 +283,7 @@ class RMSprop(Optimizer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.optimizers.Adagrad') class Adagrad(Optimizer): """Adagrad optimizer. @@ -283,20 +292,19 @@ class Adagrad(Optimizer): Arguments: lr: float >= 0. Learning rate. - epsilon: float >= 0. + epsilon: float >= 0. If `None`, defaults to `K.epsilon()`. decay: float >= 0. Learning rate decay over each update. - References: - - [Adaptive Subgradient Methods for Online Learning and Stochastic - Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) """ - def __init__(self, lr=0.01, epsilon=1e-8, decay=0., **kwargs): + def __init__(self, lr=0.01, epsilon=None, decay=0., **kwargs): super(Adagrad, self).__init__(**kwargs) with K.name_scope(self.__class__.__name__): self.lr = K.variable(lr, name='lr') self.decay = K.variable(decay, name='decay') self.iterations = K.variable(0, dtype='int64', name='iterations') + if epsilon is None: + epsilon = K.epsilon() self.epsilon = epsilon self.initial_decay = decay @@ -309,8 +317,9 @@ class Adagrad(Optimizer): lr = self.lr if self.initial_decay > 0: - lr *= (1. / (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * (1. / # pylint: disable=g-no-augmented-assignment + (1. + self.decay * K.cast(self.iterations, + K.dtype(self.decay)))) for p, g, a in zip(params, grads, accumulators): new_a = a + K.square(g) # update accumulator @@ -334,6 +343,7 @@ class Adagrad(Optimizer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.optimizers.Adadelta') class Adadelta(Optimizer): """Adadelta optimizer. @@ -344,20 +354,19 @@ class Adadelta(Optimizer): lr: float >= 0. Learning rate. It is recommended to leave it at the default value. rho: float >= 0. - epsilon: float >= 0. Fuzz factor. + epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. decay: float >= 0. Learning rate decay over each update. - References: - - [Adadelta - an adaptive learning rate - method](http://arxiv.org/abs/1212.5701) """ - def __init__(self, lr=1.0, rho=0.95, epsilon=1e-8, decay=0., **kwargs): + def __init__(self, lr=1.0, rho=0.95, epsilon=None, decay=0., **kwargs): super(Adadelta, self).__init__(**kwargs) with K.name_scope(self.__class__.__name__): self.lr = K.variable(lr, name='lr') self.decay = K.variable(decay, name='decay') self.iterations = K.variable(0, dtype='int64', name='iterations') + if epsilon is None: + epsilon = K.epsilon() self.rho = rho self.epsilon = epsilon self.initial_decay = decay @@ -372,8 +381,9 @@ class Adadelta(Optimizer): lr = self.lr if self.initial_decay > 0: - lr *= (1. / (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * (1. / # pylint: disable=g-no-augmented-assignment + (1. + self.decay * K.cast(self.iterations, + K.dtype(self.decay)))) for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators): # update accumulator @@ -406,6 +416,7 @@ class Adadelta(Optimizer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.optimizers.Adam') class Adam(Optimizer): """Adam optimizer. @@ -415,20 +426,21 @@ class Adam(Optimizer): lr: float >= 0. Learning rate. beta_1: float, 0 < beta < 1. Generally close to 1. beta_2: float, 0 < beta < 1. Generally close to 1. - epsilon: float >= 0. Fuzz factor. + epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. decay: float >= 0. Learning rate decay over each update. + amsgrad: boolean. Whether to apply the AMSGrad variant of this + algorithm from the paper "On the Convergence of Adam and + Beyond". - References: - - [Adam - A Method for Stochastic - Optimization](http://arxiv.org/abs/1412.6980v8) """ def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, - epsilon=1e-8, + epsilon=None, decay=0., + amsgrad=False, **kwargs): super(Adam, self).__init__(**kwargs) with K.name_scope(self.__class__.__name__): @@ -437,8 +449,11 @@ class Adam(Optimizer): self.beta_1 = K.variable(beta_1, name='beta_1') self.beta_2 = K.variable(beta_2, name='beta_2') self.decay = K.variable(decay, name='decay') + if epsilon is None: + epsilon = K.epsilon() self.epsilon = epsilon self.initial_decay = decay + self.amsgrad = amsgrad def get_updates(self, loss, params): grads = self.get_gradients(loss, params) @@ -446,21 +461,31 @@ class Adam(Optimizer): lr = self.lr if self.initial_decay > 0: - lr *= (1. / (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * (1. / # pylint: disable=g-no-augmented-assignment + (1. + self.decay * K.cast(self.iterations, + K.dtype(self.decay)))) t = K.cast(self.iterations, K.floatx()) + 1 - lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) / - (1. - K.pow(self.beta_1, t))) + lr_t = lr * ( + K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))) ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] - self.weights = [self.iterations] + ms + vs + if self.amsgrad: + vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] + else: + vhats = [K.zeros(1) for _ in params] + self.weights = [self.iterations] + ms + vs + vhats - for p, g, m, v in zip(params, grads, ms, vs): + for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats): m_t = (self.beta_1 * m) + (1. - self.beta_1) * g v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g) - p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) + if self.amsgrad: + vhat_t = K.maximum(vhat, v_t) + p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon) + self.updates.append(K.update(vhat, vhat_t)) + else: + p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) self.updates.append(K.update(m, m_t)) self.updates.append(K.update(v, v_t)) @@ -479,12 +504,14 @@ class Adam(Optimizer): 'beta_1': float(K.get_value(self.beta_1)), 'beta_2': float(K.get_value(self.beta_2)), 'decay': float(K.get_value(self.decay)), - 'epsilon': self.epsilon + 'epsilon': self.epsilon, + 'amsgrad': self.amsgrad } base_config = super(Adam, self).get_config() return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.optimizers.Adamax') class Adamax(Optimizer): """Adamax optimizer from Adam paper's Section 7. @@ -494,19 +521,16 @@ class Adamax(Optimizer): Arguments: lr: float >= 0. Learning rate. beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1. - epsilon: float >= 0. Fuzz factor. + epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. decay: float >= 0. Learning rate decay over each update. - References: - - [Adam - A Method for Stochastic - Optimization](http://arxiv.org/abs/1412.6980v8) """ def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999, - epsilon=1e-8, + epsilon=None, decay=0., **kwargs): super(Adamax, self).__init__(**kwargs) @@ -516,6 +540,8 @@ class Adamax(Optimizer): self.beta_1 = K.variable(beta_1, name='beta_1') self.beta_2 = K.variable(beta_2, name='beta_2') self.decay = K.variable(decay, name='decay') + if epsilon is None: + epsilon = K.epsilon() self.epsilon = epsilon self.initial_decay = decay @@ -525,8 +551,9 @@ class Adamax(Optimizer): lr = self.lr if self.initial_decay > 0: - lr *= (1. / (1. + self.decay * K.cast(self.iterations, - K.dtype(self.decay)))) + lr = lr * (1. / # pylint: disable=g-no-augmented-assignment + (1. + self.decay * K.cast(self.iterations, + K.dtype(self.decay)))) t = K.cast(self.iterations, K.floatx()) + 1 lr_t = lr / (1. - K.pow(self.beta_1, t)) @@ -567,6 +594,7 @@ class Adamax(Optimizer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.optimizers.Nadam') class Nadam(Optimizer): """Nesterov Adam optimizer. @@ -580,19 +608,15 @@ class Nadam(Optimizer): Arguments: lr: float >= 0. Learning rate. beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1. - epsilon: float >= 0. Fuzz factor. + epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. - References: - - [Nadam report](http://cs229.stanford.edu/proj2015/054_report.pdf) - - [On the importance of initialization and momentum in deep - learning](http://www.cs.toronto.edu/~fritz/absps/momentum.pdf) """ def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999, - epsilon=1e-8, + epsilon=None, schedule_decay=0.004, **kwargs): super(Nadam, self).__init__(**kwargs) @@ -602,12 +626,15 @@ class Nadam(Optimizer): self.lr = K.variable(lr, name='lr') self.beta_1 = K.variable(beta_1, name='beta_1') self.beta_2 = K.variable(beta_2, name='beta_2') + if epsilon is None: + epsilon = K.epsilon() self.epsilon = epsilon self.schedule_decay = schedule_decay def get_updates(self, loss, params): grads = self.get_gradients(loss, params) self.updates = [K.update_add(self.iterations, 1)] + t = K.cast(self.iterations, K.floatx()) + 1 # Due to the recommendations in [2], i.e. warming momentum schedule @@ -670,6 +697,12 @@ class TFOptimizer(Optimizer): with K.name_scope(self.__class__.__name__): self.iterations = K.variable(0, dtype='int64', name='iterations') + def apply_gradients(self, grads): + self.optimizer.apply_gradients(grads) + + def get_grads(self, loss, params): + return self.optimizer.compute_gradients(loss, params) + def get_updates(self, loss, params): grads = self.optimizer.compute_gradients(loss, params) self.updates = [K.update_add(self.iterations, 1)] @@ -691,7 +724,6 @@ class TFOptimizer(Optimizer): # Aliases. -# pylint: disable=invalid-name sgd = SGD rmsprop = RMSprop adagrad = Adagrad @@ -700,13 +732,13 @@ adam = Adam adamax = Adamax nadam = Nadam -# pylint: enable=invalid-name - +@tf_export('keras.optimizers.serialize') def serialize(optimizer): return serialize_keras_object(optimizer) +@tf_export('keras.optimizers.deserialize') def deserialize(config, custom_objects=None): """Inverse of the `serialize` function. @@ -740,6 +772,7 @@ def deserialize(config, custom_objects=None): printable_module_name='optimizer') +@tf_export('keras.optimizers.get') def get(identifier): """Retrieves a Keras Optimizer instance. diff --git a/tensorflow/python/keras/_impl/keras/optimizers_test.py b/tensorflow/python/keras/_impl/keras/optimizers_test.py index 6e9e4e6c99a6ffb0684d20ca001bba98b0d799bc..57636afbf089f27c00cc56c46fdb3ea50f89cc6b 100644 --- a/tensorflow/python/keras/_impl/keras/optimizers_test.py +++ b/tensorflow/python/keras/_impl/keras/optimizers_test.py @@ -102,6 +102,7 @@ class KerasOptimizersTest(test.TestCase): with self.test_session(): _test_optimizer(keras.optimizers.Adam()) _test_optimizer(keras.optimizers.Adam(decay=1e-3)) + _test_optimizer(keras.optimizers.Adam(amsgrad=True)) def test_adamax(self): with self.test_session(): diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image.py b/tensorflow/python/keras/_impl/keras/preprocessing/image.py index 82441de5925cac0d66af95202c613b3e5e9aeb79..d12f10863921ee7d635930f34e8bc701c89864e8 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/image.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# pylint: disable=g-import-not-at-top """Fairly basic set of tools for real-time data augmentation on image data. Can easily be extended to include new transformations, @@ -28,25 +29,23 @@ import re import threading import numpy as np -from six.moves import range # pylint: disable=redefined-builtin - from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export - -# pylint: disable=g-import-not-at-top -try: - from PIL import Image as pil_image -except ImportError: - pil_image = None try: from scipy import linalg import scipy.ndimage as ndi except ImportError: linalg = None ndi = None -# pylint: enable=g-import-not-at-top + + +try: + from PIL import Image as pil_image +except ImportError: + pil_image = None if pil_image is not None: _PIL_INTERPOLATION_METHODS = { @@ -64,6 +63,7 @@ if pil_image is not None: _PIL_INTERPOLATION_METHODS['lanczos'] = pil_image.LANCZOS +@tf_export('keras.preprocessing.image.random_rotation') def random_rotation(x, rg, row_axis=1, @@ -88,7 +88,7 @@ def random_rotation(x, Returns: Rotated Numpy image tensor. """ - theta = np.pi / 180 * np.random.uniform(-rg, rg) + theta = np.deg2rad(np.random.uniform(-rg, rg)) rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) @@ -98,6 +98,7 @@ def random_rotation(x, return x +@tf_export('keras.preprocessing.image.random_shift') def random_shift(x, wrg, hrg, @@ -134,6 +135,7 @@ def random_shift(x, return x +@tf_export('keras.preprocessing.image.random_shear') def random_shear(x, intensity, row_axis=1, @@ -145,7 +147,7 @@ def random_shear(x, Arguments: x: Input tensor. Must be 3D. - intensity: Transformation intensity. + intensity: Transformation intensity in degrees. row_axis: Index of axis for rows in the input tensor. col_axis: Index of axis for columns in the input tensor. channel_axis: Index of axis for channels in the input tensor. @@ -158,7 +160,7 @@ def random_shear(x, Returns: Sheared Numpy image tensor. """ - shear = np.random.uniform(-intensity, intensity) + shear = np.deg2rad(np.random.uniform(-intensity, intensity)) shear_matrix = np.array([[1, -np.sin(shear), 0], [0, np.cos(shear), 0], [0, 0, 1]]) @@ -168,6 +170,7 @@ def random_shear(x, return x +@tf_export('keras.preprocessing.image.random_zoom') def random_zoom(x, zoom_range, row_axis=1, @@ -188,8 +191,10 @@ def random_zoom(x, (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). cval: Value used for points outside the boundaries of the input if `mode='constant'`. + Returns: Zoomed Numpy image tensor. + Raises: ValueError: if `zoom_range` isn't a tuple. """ @@ -209,6 +214,7 @@ def random_zoom(x, return x +@tf_export('keras.preprocessing.image.random_channel_shift') def random_channel_shift(x, intensity, channel_axis=0): x = np.rollaxis(x, channel_axis, 0) min_x, max_x = np.min(x), np.max(x) @@ -230,6 +236,7 @@ def transform_matrix_offset_center(matrix, x, y): return transform_matrix +@tf_export('keras.preprocessing.image.apply_transform') def apply_transform(x, transform_matrix, channel_axis=0, @@ -267,6 +274,7 @@ def apply_transform(x, return x +@tf_export('keras.preprocessing.image.flip_axis') def flip_axis(x, axis): x = np.asarray(x).swapaxes(axis, 0) x = x[::-1, ...] @@ -274,6 +282,7 @@ def flip_axis(x, axis): return x +@tf_export('keras.preprocessing.image.array_to_img') def array_to_img(x, data_format=None, scale=True): """Converts a 3D Numpy array to a PIL Image instance. @@ -324,6 +333,7 @@ def array_to_img(x, data_format=None, scale=True): raise ValueError('Unsupported channel number: ', x.shape[2]) +@tf_export('keras.preprocessing.image.img_to_array') def img_to_array(img, data_format=None): """Converts a PIL Image instance to a Numpy array. @@ -358,6 +368,7 @@ def img_to_array(img, data_format=None): return x +@tf_export('keras.preprocessing.image.load_img') def load_img(path, grayscale=False, target_size=None, interpolation='nearest'): """Loads an image into PIL format. @@ -366,7 +377,7 @@ def load_img(path, grayscale=False, target_size=None, interpolation='nearest'): grayscale: Boolean, whether to load the image as grayscale. target_size: Either `None` (default to original size) or tuple of ints `(img_height, img_width)`. - interpolation: Interpolation method used to resample the image if the + interpolation: Interpolation method used to resample the image if the target size is different from that of the loaded image. Supported methods are "nearest", "bilinear", and "bicubic". If PIL version 1.1.3 or newer is installed, "lanczos" is also @@ -394,11 +405,9 @@ def load_img(path, grayscale=False, target_size=None, interpolation='nearest'): width_height_tuple = (target_size[1], target_size[0]) if img.size != width_height_tuple: if interpolation not in _PIL_INTERPOLATION_METHODS: - raise ValueError( - 'Invalid interpolation method {} specified. Supported ' - 'methods are {}'.format( - interpolation, - ', '.join(_PIL_INTERPOLATION_METHODS.keys()))) + raise ValueError('Invalid interpolation method {} specified. Supported ' + 'methods are {}'.format(interpolation, ', '.join( + _PIL_INTERPOLATION_METHODS.keys()))) resample = _PIL_INTERPOLATION_METHODS[interpolation] img = img.resize(width_height_tuple, resample) return img @@ -407,11 +416,13 @@ def load_img(path, grayscale=False, target_size=None, interpolation='nearest'): def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'): return [ os.path.join(root, f) - for root, _, files in os.walk(directory) for f in files + for root, _, files in os.walk(directory) + for f in files if re.match(r'([\w]+\.(?:' + ext + '))', f) ] +@tf_export('keras.preprocessing.image.ImageDataGenerator') class ImageDataGenerator(object): """Generate minibatches of image data with real-time data augmentation. @@ -423,9 +434,9 @@ class ImageDataGenerator(object): zca_whitening: apply ZCA whitening. zca_epsilon: epsilon for ZCA whitening. Default is 1e-6. rotation_range: degrees (0 to 180). - width_shift_range: fraction of total width. - height_shift_range: fraction of total height. - shear_range: shear intensity (shear angle in radians). + width_shift_range: fraction of total width, if < 1, or pixels if >= 1. + height_shift_range: fraction of total height, if < 1, or pixels if >= 1. + shear_range: shear intensity (shear angle in degrees). zoom_range: amount of zoom. if scalar z, zoom will be randomly picked in the range [1-z, 1+z]. A sequence of two can be passed instead to select this range. @@ -433,6 +444,12 @@ class ImageDataGenerator(object): fill_mode: points outside the boundaries are filled according to the given mode ('constant', 'nearest', 'reflect' or 'wrap'). Default is 'nearest'. + Points outside the boundaries of the input are filled according to the + given mode: + 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k) + 'nearest': aaaaaaaa|abcd|dddddddd + 'reflect': abcddcba|abcd|dcbaabcd + 'wrap': abcdabcd|abcd|abcdabcd cval: value used for points outside the boundaries when fill_mode is 'constant'. Default is 0. horizontal_flip: whether to randomly flip images horizontally. @@ -522,6 +539,32 @@ class ImageDataGenerator(object): raise ValueError('`zoom_range` should be a float or ' 'a tuple or list of two floats. ' 'Received arg: ', zoom_range) + if zca_whitening: + if not featurewise_center: + self.featurewise_center = True + logging.warning('This ImageDataGenerator specifies ' + '`zca_whitening`, which overrides ' + 'setting of `featurewise_center`.') + if featurewise_std_normalization: + self.featurewise_std_normalization = False + logging.warning('This ImageDataGenerator specifies ' + '`zca_whitening` ' + 'which overrides setting of' + '`featurewise_std_normalization`.') + if featurewise_std_normalization: + if not featurewise_center: + self.featurewise_center = True + logging.warning('This ImageDataGenerator specifies ' + '`featurewise_std_normalization`, ' + 'which overrides setting of ' + '`featurewise_center`.') + if samplewise_std_normalization: + if not samplewise_center: + self.samplewise_center = True + logging.warning('This ImageDataGenerator specifies ' + '`samplewise_std_normalization`, ' + 'which overrides setting of ' + '`samplewise_center`.') def flow(self, x, @@ -591,7 +634,7 @@ class ImageDataGenerator(object): if self.samplewise_center: x -= np.mean(x, keepdims=True) if self.samplewise_std_normalization: - x /= np.std(x, keepdims=True) + 1e-7 + x /= (np.std(x, keepdims=True) + K.epsilon()) if self.featurewise_center: if self.mean is not None: @@ -603,7 +646,7 @@ class ImageDataGenerator(object): 'first by calling `.fit(numpy_data)`.') if self.featurewise_std_normalization: if self.std is not None: - x /= (self.std + 1e-7) + x /= (self.std + K.epsilon()) else: logging.warning('This ImageDataGenerator specifies ' '`featurewise_std_normalization`, but it hasn\'t ' @@ -636,7 +679,6 @@ class ImageDataGenerator(object): """ if ndi is None: raise ImportError('Scipy is required for image transformations.') - # x is a single image, so it doesn't have image number at index 0 img_row_axis = self.row_axis - 1 img_col_axis = self.col_axis - 1 @@ -648,25 +690,27 @@ class ImageDataGenerator(object): # use composition of homographies # to generate final transform that needs to be applied if self.rotation_range: - theta = np.pi / 180 * np.random.uniform(-self.rotation_range, - self.rotation_range) + theta = np.deg2rad( + np.random.uniform(-self.rotation_range, self.rotation_range)) else: theta = 0 if self.height_shift_range: - tx = np.random.uniform(-self.height_shift_range, - self.height_shift_range) * x.shape[img_row_axis] + tx = np.random.uniform(-self.height_shift_range, self.height_shift_range) + if self.height_shift_range < 1: + tx *= x.shape[img_row_axis] else: tx = 0 if self.width_shift_range: - ty = np.random.uniform(-self.width_shift_range, - self.width_shift_range) * x.shape[img_col_axis] + ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) + if self.width_shift_range < 1: + ty *= x.shape[img_col_axis] else: ty = 0 if self.shear_range: - shear = np.random.uniform(-self.shear_range, self.shear_range) + shear = np.deg2rad(np.random.uniform(-self.shear_range, self.shear_range)) else: shear = 0 @@ -744,7 +788,7 @@ class ImageDataGenerator(object): if x.ndim != 4: raise ValueError('Input to `.fit()` should have rank 4. ' 'Got array with shape: ' + str(x.shape)) - if x.shape[self.channel_axis] not in {3, 4}: + if x.shape[self.channel_axis] not in {1, 3, 4}: logging.warning( 'Expected input to be images (as Numpy array) ' 'following the data format convention "' + self.data_format + '" ' @@ -784,12 +828,15 @@ class ImageDataGenerator(object): raise ImportError('Scipy is required for zca_whitening.') flat_x = np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])) - sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0] - u, s, _ = linalg.svd(sigma) - self.principal_components = np.dot( - np.dot(u, np.diag(1. / np.sqrt(s + self.zca_epsilon))), u.T) + num_examples = flat_x.shape[0] + _, s, vt = linalg.svd(flat_x / np.sqrt(num_examples)) + s_expand = np.hstack( + (s, np.zeros(vt.shape[0] - num_examples, dtype=flat_x.dtype))) + self.principal_components = ( + vt.T / np.sqrt(s_expand**2 + self.zca_epsilon)).dot(vt) +@tf_export('keras.preprocessing.image.Iterator') class Iterator(Sequence): """Base class for image data iterators. @@ -797,10 +844,10 @@ class Iterator(Sequence): method. Arguments: - n: Integer, total number of samples in the dataset to loop over. - batch_size: Integer, size of a batch. - shuffle: Boolean, whether to shuffle the data between epochs. - seed: Random seeding for data shuffling. + n: Integer, total number of samples in the dataset to loop over. + batch_size: Integer, size of a batch. + shuffle: Boolean, whether to shuffle the data between epochs. + seed: Random seeding for data shuffling. """ def __init__(self, n, batch_size, shuffle, seed): @@ -823,15 +870,14 @@ class Iterator(Sequence): if idx >= len(self): raise ValueError('Asked to retrieve element {idx}, ' 'but the Sequence ' - 'has length {length}'.format(idx=idx, - length=len(self))) + 'has length {length}'.format(idx=idx, length=len(self))) if self.seed is not None: np.random.seed(self.seed + self.total_batches_seen) self.total_batches_seen += 1 if self.index_array is None: self._set_index_array() - index_array = self.index_array[self.batch_size * idx:self.batch_size * - (idx + 1)] + index_array = self.index_array[self.batch_size * idx:self.batch_size * ( + idx + 1)] return self._get_batches_of_transformed_samples(index_array) def __len__(self): @@ -873,12 +919,14 @@ class Iterator(Sequence): Arguments: index_array: array of sample indices to include in batch. + Returns: A batch of transformed samples. """ raise NotImplementedError +@tf_export('keras.preprocessing.image.NumpyArrayIterator') class NumpyArrayIterator(Iterator): """Iterator yielding data from a Numpy array. @@ -948,8 +996,8 @@ class NumpyArrayIterator(Iterator): seed) def _get_batches_of_transformed_samples(self, index_array): - batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]), - dtype=K.floatx()) + batch_x = np.zeros( + tuple([len(index_array)] + list(self.x.shape)[1:]), dtype=K.floatx()) for i, j in enumerate(index_array): x = self.x[j] x = self.image_data_generator.random_transform(x.astype(K.floatx())) @@ -959,7 +1007,9 @@ class NumpyArrayIterator(Iterator): for i, j in enumerate(index_array): img = array_to_img(batch_x[i], self.data_format, scale=True) fname = '{prefix}_{index}_{hash}.{format}'.format( - prefix=self.save_prefix, index=j, hash=np.random.randint(1e4), + prefix=self.save_prefix, + index=j, + hash=np.random.randint(1e4), format=self.save_format) img.save(os.path.join(self.save_to_dir, fname)) if self.y is None: @@ -984,10 +1034,11 @@ class NumpyArrayIterator(Iterator): def _count_valid_files_in_directory(directory, white_list_formats, follow_links): - """Count files with extension in `white_list_formats` in a directory. + """Count files with extension in `white_list_formats` contained in directory. Arguments: - directory: absolute path to the directory containing files to be counted + directory: absolute path to the directory + containing files to be counted white_list_formats: set of strings containing allowed extensions for the files to be counted. follow_links: boolean. @@ -1003,7 +1054,7 @@ def _count_valid_files_in_directory(directory, white_list_formats, samples = 0 for _, _, files in _recursive_list(directory): - for fname in sorted(files): + for fname in files: is_valid = False for extension in white_list_formats: if fname.lower().endswith('.' + extension): @@ -1043,7 +1094,7 @@ def _list_valid_filenames_in_directory(directory, white_list_formats, subdir = os.path.basename(directory) basedir = os.path.dirname(directory) for root, _, files in _recursive_list(directory): - for fname in files: + for fname in sorted(files): is_valid = False for extension in white_list_formats: if fname.lower().endswith('.' + extension): @@ -1057,6 +1108,7 @@ def _list_valid_filenames_in_directory(directory, white_list_formats, return classes, filenames +@tf_export('keras.preprocessing.image.DirectoryIterator') class DirectoryIterator(Iterator): """Iterator capable of reading images from a directory on disk. @@ -1167,8 +1219,8 @@ class DirectoryIterator(Iterator): white_list_formats=white_list_formats, follow_links=follow_links) self.samples = sum( - pool.map(function_partial, (os.path.join(directory, subdir) - for subdir in classes))) + pool.map(function_partial, + (os.path.join(directory, subdir) for subdir in classes))) print('Found %d images belonging to %d classes.' % (self.samples, self.num_classes)) @@ -1181,8 +1233,9 @@ class DirectoryIterator(Iterator): i = 0 for dirpath in (os.path.join(directory, subdir) for subdir in classes): results.append( - pool.apply_async(_list_valid_filenames_in_directory, ( - dirpath, white_list_formats, self.class_indices, follow_links))) + pool.apply_async( + _list_valid_filenames_in_directory, + (dirpath, white_list_formats, self.class_indices, follow_links))) for res in results: classes, filenames = res.get() self.classes[i:i + len(classes)] = classes @@ -1199,10 +1252,11 @@ class DirectoryIterator(Iterator): # build batch of image data for i, j in enumerate(index_array): fname = self.filenames[j] - img = load_img(os.path.join(self.directory, fname), - grayscale=grayscale, - target_size=self.target_size, - interpolation=self.interpolation) + img = load_img( + os.path.join(self.directory, fname), + grayscale=grayscale, + target_size=self.target_size, + interpolation=self.interpolation) x = img_to_array(img, data_format=self.data_format) x = self.image_data_generator.random_transform(x) x = self.image_data_generator.standardize(x) @@ -1212,7 +1266,9 @@ class DirectoryIterator(Iterator): for i, j in enumerate(index_array): img = array_to_img(batch_x[i], self.data_format, scale=True) fname = '{prefix}_{index}_{hash}.{format}'.format( - prefix=self.save_prefix, index=j, hash=np.random.randint(1e7), + prefix=self.save_prefix, + index=j, + hash=np.random.randint(1e7), format=self.save_format) img.save(os.path.join(self.save_to_dir, fname)) # build batch of labels @@ -1241,4 +1297,3 @@ class DirectoryIterator(Iterator): # The transformation of images is not under thread lock # so it can be done in parallel return self._get_batches_of_transformed_samples(index_array) - diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py index 642f4f2face5bd56cdc1ed7b4f6d6621c6d1b210..a423d96d3d8578df347b7ee36fb53dfd335e0d65 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Preprocessing utilities for sequence data. +"""Utilities for preprocessing sequence data. """ from __future__ import absolute_import from __future__ import division @@ -22,8 +22,10 @@ import random import numpy as np from six.moves import range # pylint: disable=redefined-builtin +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.preprocessing.sequence.pad_sequences') def pad_sequences(sequences, maxlen=None, dtype='int32', @@ -104,6 +106,7 @@ def pad_sequences(sequences, return x +@tf_export('keras.preprocessing.sequence.make_sampling_table') def make_sampling_table(size, sampling_factor=1e-5): """Generates a word rank-based probabilistic sampling table. @@ -129,7 +132,7 @@ def make_sampling_table(size, sampling_factor=1e-5): is the probability that a word of rank i should be sampled. """ gamma = 0.577 - rank = np.array(list(range(size))) + rank = np.arange(size) rank[0] = 1 inv_fq = rank * (np.log(rank) + gamma) + 0.5 - 1. / (12. * rank) f = sampling_factor * inv_fq @@ -137,6 +140,7 @@ def make_sampling_table(size, sampling_factor=1e-5): return np.minimum(1., f / np.sqrt(f)) +@tf_export('keras.preprocessing.sequence.skipgrams') def skipgrams(sequence, vocabulary_size, window_size=4, @@ -170,7 +174,7 @@ def skipgrams(sequence, if True labels will be categorical eg. [[1,0],[0,1],[0,1] .. ] sampling_table: 1D array of size `vocabulary_size` where the entry i encodes the probability to sample a word of rank i. - seed: Random seed. + seed: random seed. Returns: couples, labels: where `couples` are int pairs and @@ -224,3 +228,22 @@ def skipgrams(sequence, random.shuffle(labels) return couples, labels + + +def _remove_long_seq(maxlen, seq, label): + """Removes sequences that exceed the maximum length. + + Arguments: + maxlen: int, maximum length + seq: list of lists where each sublist is a sequence + label: list where each element is an integer + + Returns: + new_seq, new_label: shortened lists for `seq` and `label`. + """ + new_seq, new_label = [], [] + for x, y in zip(seq, label): + if len(x) < maxlen: + new_seq.append(x) + new_label.append(y) + return new_seq, new_label diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text.py b/tensorflow/python/keras/_impl/keras/preprocessing/text.py index 47e5aa064fd806196fc9457fc90bc1a26e55ebf3..1e3828ccf1e3bf9c443691e1c1da5697bedb4653 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/text.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/text.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """Utilities for text input preprocessing. - -May benefit from a fast Cython rewrite. """ from __future__ import absolute_import from __future__ import division @@ -29,12 +27,17 @@ import numpy as np from six.moves import range # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export + + if sys.version_info < (3,): maketrans = string.maketrans else: maketrans = str.maketrans +@tf_export('keras.preprocessing.text.text_to_word_sequence') def text_to_word_sequence(text, filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', lower=True, @@ -63,11 +66,27 @@ def text_to_word_sequence(text, return [i for i in seq if i] +@tf_export('keras.preprocessing.text.one_hot') def one_hot(text, n, filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', lower=True, split=' '): + """One-hot encodes a text into a list of word indexes of size n. + + This is a wrapper to the `hashing_trick` function using `hash` as the + hashing function; unicity of word to index mapping non-guaranteed. + + Arguments: + text: Input text (string). + n: Dimension of the hashing space. + filters: Sequence of characters to filter out. + lower: Whether to convert the input to lowercase. + split: Sentence split marker (string). + + Returns: + A list of integer word indices (unicity non-guaranteed). + """ return hashing_trick( text, n, hash_function=hash, filters=filters, lower=lower, split=split) @@ -99,6 +118,10 @@ def hashing_trick(text, Two or more words may be assigned to the same index, due to possible collisions by the hashing function. + The + probability + of a collision is in relation to the dimension of the hashing space and + the number of distinct objects. """ if hash_function is None: hash_function = hash @@ -109,6 +132,7 @@ def hashing_trick(text, return [(hash_function(w) % (n - 1) + 1) for w in seq] +@tf_export('keras.preprocessing.text.Tokenizer') class Tokenizer(object): """Text tokenization utility class. @@ -127,6 +151,8 @@ class Tokenizer(object): lower: boolean. Whether to convert the texts to lowercase. split: character or string to use for token splitting. char_level: if True, every character will be treated as a token. + oov_token: if given, it will be added to word_index and used to + replace out-of-vocabulary words during text_to_sequence calls By default, all punctuation is removed, turning the texts into space-separated sequences of words @@ -141,7 +167,17 @@ class Tokenizer(object): filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', lower=True, split=' ', - char_level=False): + char_level=False, + oov_token=None, + **kwargs): + # Legacy support + if 'nb_words' in kwargs: + logging.warning('The `nb_words` argument in `Tokenizer` ' + 'has been renamed `num_words`.') + num_words = kwargs.pop('nb_words') + if kwargs: + raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) + self.word_counts = OrderedDict() self.word_docs = {} self.filters = filters @@ -150,6 +186,7 @@ class Tokenizer(object): self.num_words = num_words self.document_count = 0 self.char_level = char_level + self.oov_token = oov_token def fit_on_texts(self, texts): """Updates internal vocabulary based on a list of texts. @@ -181,7 +218,13 @@ class Tokenizer(object): sorted_voc = [wc[0] for wc in wcounts] # note that index 0 is reserved, never assigned to an existing word self.word_index = dict( - list(zip(sorted_voc, list(range(1, len(sorted_voc) + 1))))) + list(zip(sorted_voc, list(range(1, + len(sorted_voc) + 1))))) + + if self.oov_token is not None: + i = self.word_index.get(self.oov_token) + if i is None: + self.word_index[self.oov_token] = len(self.word_index) + 1 self.index_docs = {} for w, c in list(self.word_docs.items()): @@ -248,6 +291,10 @@ class Tokenizer(object): continue else: vect.append(i) + elif self.oov_token is not None: + i = self.word_index.get(self.oov_token) + if i is not None: + vect.append(i) yield vect def texts_to_matrix(self, texts, mode='binary'): diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py index 17ab48ba3fc9dfd553f8f425579c0a37ff42eb84..a934e331c4a14d9bd170258b6b6183e6a15bb561 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py @@ -76,6 +76,22 @@ class TestText(test.TestCase): self.assertLessEqual(np.max(encoded), 4) self.assertGreaterEqual(np.min(encoded), 1) + def test_tokenizer_oov_flag(self): + x_train = ['This text has only known words'] + x_test = ['This text has some unknown words'] # 2 OOVs: some, unknown + + # Defalut, without OOV flag + tokenizer = keras.preprocessing.text.Tokenizer() + tokenizer.fit_on_texts(x_train) + x_test_seq = tokenizer.texts_to_sequences(x_test) + assert len(x_test_seq[0]) == 4 # discards 2 OOVs + + # With OOV feature + tokenizer = keras.preprocessing.text.Tokenizer(oov_token='') + tokenizer.fit_on_texts(x_train) + x_test_seq = tokenizer.texts_to_sequences(x_test) + assert len(x_test_seq[0]) == 6 # OOVs marked in place + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/regularizers.py b/tensorflow/python/keras/_impl/keras/regularizers.py index 161ff9bf5bf12b3521fe444f1d68bd62b6e8c71d..2c30844647acdb78d1ca31d052ec7e5ecc6dcc2a 100644 --- a/tensorflow/python/keras/_impl/keras/regularizers.py +++ b/tensorflow/python/keras/_impl/keras/regularizers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras built-in regularizers. +"""Built-in regularizers. """ from __future__ import absolute_import from __future__ import division @@ -23,8 +23,10 @@ import six from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.util.tf_export import tf_export +@tf_export('keras.regularizers.Regularizer') class Regularizer(object): """Regularizer base class. """ @@ -37,6 +39,7 @@ class Regularizer(object): return cls(**config) +@tf_export('keras.regularizers.L1L2') class L1L2(Regularizer): """Regularizer for L1 and L2 regularization. @@ -64,22 +67,27 @@ class L1L2(Regularizer): # Aliases. +@tf_export('keras.regularizers.l1') def l1(l=0.01): return L1L2(l1=l) +@tf_export('keras.regularizers.l2') def l2(l=0.01): return L1L2(l2=l) +@tf_export('keras.regularizers.l1_l2') def l1_l2(l1=0.01, l2=0.01): # pylint: disable=redefined-outer-name return L1L2(l1=l1, l2=l2) +@tf_export('keras.regularizers.serialize') def serialize(regularizer): return serialize_keras_object(regularizer) +@tf_export('keras.regularizers.deserialize') def deserialize(config, custom_objects=None): return deserialize_keras_object( config, @@ -88,6 +96,7 @@ def deserialize(config, custom_objects=None): printable_module_name='regularizer') +@tf_export('keras.regularizers.get') def get(identifier): if identifier is None: return None diff --git a/tensorflow/python/keras/_impl/keras/testing_utils.py b/tensorflow/python/keras/_impl/keras/testing_utils.py index b889e311b37d48732641205a90ca83af34ea4489..fa1ee2fa3da3fbc7650ee80960b00907013cc37c 100644 --- a/tensorflow/python/keras/_impl/keras/testing_utils.py +++ b/tensorflow/python/keras/_impl/keras/testing_utils.py @@ -105,8 +105,14 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, # test in functional API x = keras.layers.Input(shape=input_shape[1:], dtype=input_dtype) y = layer(x) - assert keras.backend.dtype(y) == expected_output_dtype - + if keras.backend.dtype(y) != expected_output_dtype: + raise AssertionError('When testing layer %s, for input %s, found output ' + 'dtype=%s but expected to find %s.\nFull kwargs: %s' % + (layer_cls.__name__, + x, + keras.backend.dtype(y), + expected_output_dtype, + kwargs)) # check shape inference model = keras.models.Model(x, y) expected_output_shape = tuple( @@ -117,7 +123,15 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, for expected_dim, actual_dim in zip(expected_output_shape, actual_output_shape): if expected_dim is not None: - assert expected_dim == actual_dim + if expected_dim != actual_dim: + raise AssertionError( + 'When testing layer %s, for input %s, found output_shape=' + '%s but expected to find %s.\nFull kwargs: %s' % + (layer_cls.__name__, + x, + actual_output_shape, + expected_output_shape, + kwargs)) if expected_output is not None: np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3) @@ -146,7 +160,15 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, for expected_dim, actual_dim in zip(expected_output_shape, actual_output_shape): if expected_dim is not None: - assert expected_dim == actual_dim + if expected_dim != actual_dim: + raise AssertionError( + 'When testing layer %s, for input %s, found output_shape=' + '%s but expected to find %s.\nFull kwargs: %s' % + (layer_cls.__name__, + x, + actual_output_shape, + expected_output_shape, + kwargs)) if expected_output is not None: np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3) diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils.py b/tensorflow/python/keras/_impl/keras/utils/data_utils.py index d9e8f37e36cff0723c02820e16cc502bb0aea294..e87c8f48ef0967d561db1ab841a669d783f9b1ec 100644 --- a/tensorflow/python/keras/_impl/keras/utils/data_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/data_utils.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# pylint: disable=g-import-not-at-top """Utilities for file download and caching.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from abc import abstractmethod +from contextlib import closing import hashlib import multiprocessing from multiprocessing.pool import ThreadPool @@ -40,10 +42,11 @@ from six.moves.urllib.request import urlopen from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar from tensorflow.python.util.tf_export import tf_export + try: - import queue # pylint:disable=g-import-not-at-top + import queue except ImportError: - import Queue as queue # pylint:disable=g-import-not-at-top + import Queue as queue if sys.version_info[0] == 2: @@ -87,7 +90,7 @@ if sys.version_info[0] == 2: for chunk in chunk_read(response, reporthook=reporthook): fd.write(chunk) else: - from six.moves.urllib.request import urlretrieve # pylint: disable=g-import-not-at-top + from six.moves.urllib.request import urlretrieve def _extract_archive(file_path, path='.', archive_format='auto'): @@ -188,7 +191,7 @@ def get_file(fname, Path to the downloaded file """ if cache_dir is None: - cache_dir = os.path.expanduser(os.path.join('~', '.keras')) + cache_dir = os.path.join(os.path.expanduser('~'), '.keras') if md5_hash is not None and file_hash is None: file_hash = md5_hash hash_algorithm = 'md5' @@ -323,31 +326,41 @@ class Sequence(object): Every `Sequence` must implements the `__getitem__` and the `__len__` methods. If you want to modify your dataset between epochs you may implement - `on_epoch_end`. The method `__getitem__` should return a complete batch. + `on_epoch_end`. + The method `__getitem__` should return a complete batch. + + # Notes - Notes: `Sequence` are a safer way to do multiprocessing. This structure guarantees - that the network will only train once on each sample per epoch which is not - the case with generators. + that the network will only train once + on each sample per epoch which is not the case with generators. + Examples: + ```python from skimage.io import imread from skimage.transform import resize import numpy as np import math + # Here, `x_set` is list of path to the images # and `y_set` are the associated classes. + class CIFAR10Sequence(Sequence): + def __init__(self, x_set, y_set, batch_size): self.x, self.y = x_set, y_set self.batch_size = batch_size + def __len__(self): return math.ceil(len(self.x) / self.batch_size) + def __getitem__(self, idx): batch_x = self.x[idx * self.batch_size:(idx + 1) * - self.batch_size] + self.batch_size] batch_y = self.y[idx * self.batch_size:(idx + 1) * - self.batch_size] + self.batch_size] + return np.array([ resize(imread(file_name), (200, 200)) for file_name in batch_x]), np.array(batch_y) @@ -375,7 +388,6 @@ class Sequence(object): """ raise NotImplementedError - @abstractmethod def on_epoch_end(self): """Method called at the end of every epoch. """ @@ -474,35 +486,36 @@ class OrderedEnqueuer(SequenceEnqueuer): Arguments: sequence: A `keras.utils.data_utils.Sequence` object. - use_multiprocessing: Use multiprocessing if True, otherwise threading - shuffle: Whether to shuffle the data at the beginning of each epoch + use_multiprocessing: use multiprocessing if True, otherwise threading + shuffle: whether to shuffle the data at the beginning of each epoch """ def __init__(self, sequence, use_multiprocessing=False, shuffle=False): self.sequence = sequence self.use_multiprocessing = use_multiprocessing - # Doing Multiprocessing.Value += x is not process-safe. global _SEQUENCE_COUNTER if _SEQUENCE_COUNTER is None: - if self.use_multiprocessing: + try: _SEQUENCE_COUNTER = multiprocessing.Value('i', 0) - else: + except OSError: + # In this case the OS does not allow us to use + # multiprocessing. We resort to an int + # for enqueuer indexing. _SEQUENCE_COUNTER = 0 - if self.use_multiprocessing: + if isinstance(_SEQUENCE_COUNTER, int): + self.uid = _SEQUENCE_COUNTER + _SEQUENCE_COUNTER += 1 + else: + # Doing Multiprocessing.Value += x is not process-safe. with _SEQUENCE_COUNTER.get_lock(): self.uid = _SEQUENCE_COUNTER.value _SEQUENCE_COUNTER.value += 1 - else: - self.uid = _SEQUENCE_COUNTER - if isinstance(_SEQUENCE_COUNTER, int): - _SEQUENCE_COUNTER += 1 - else: - _SEQUENCE_COUNTER.value += 1 + self.shuffle = shuffle self.workers = 0 - self.executor = None + self.executor_fn = None self.queue = None self.run_thread = None self.stop_signal = None @@ -519,9 +532,9 @@ class OrderedEnqueuer(SequenceEnqueuer): (when full, workers could block on `put()`) """ if self.use_multiprocessing: - self.executor = multiprocessing.Pool(workers) + self.executor_fn = lambda: multiprocessing.Pool(workers) else: - self.executor = ThreadPool(workers) + self.executor_fn = lambda: ThreadPool(workers) self.workers = workers self.queue = queue.Queue(max_queue_size) self.stop_signal = threading.Event() @@ -537,24 +550,26 @@ class OrderedEnqueuer(SequenceEnqueuer): return def _run(self): - """Function to submit request to the executor & queue `Future` objects.""" + """Submits request to the executor and queue the `Future` objects.""" sequence = list(range(len(self.sequence))) self._send_sequence() # Share the initial sequence while True: if self.shuffle: random.shuffle(sequence) - for i in sequence: - if self.stop_signal.is_set(): - return - self.queue.put( - self.executor.apply_async(get_index, (self.uid, i)), block=True) - # Done with the current epoch, waiting for the final batches - self._wait_queue() + with closing(self.executor_fn()) as executor: + for i in sequence: + if self.stop_signal.is_set(): + return + self.queue.put( + executor.apply_async(get_index, (self.uid, i)), block=True) - if self.stop_signal.is_set(): - # We're done - return + # Done with the current epoch, waiting for the final batches + self._wait_queue() + + if self.stop_signal.is_set(): + # We're done + return # Call the internal on epoch end. self.sequence.on_epoch_end() @@ -566,8 +581,9 @@ class OrderedEnqueuer(SequenceEnqueuer): Skip the data if it is `None`. Yields: - Tuples (inputs, targets) - or (inputs, targets, sample_weights) + The next element in the queue, i.e. a tuple + `(inputs, targets)` or + `(inputs, targets, sample_weights)`. """ try: while self.is_running(): @@ -581,14 +597,8 @@ class OrderedEnqueuer(SequenceEnqueuer): def _send_sequence(self): """Send current Sequence to all workers.""" - _SHARED_SEQUENCES[ - self.uid] = self.sequence # For new processes that may spawn - - self._close_pool() - if self.use_multiprocessing: - self.executor = multiprocessing.Pool(self.workers) - else: - self.executor = ThreadPool(self.workers) + # For new processes that may spawn + _SHARED_SEQUENCES[self.uid] = self.sequence def stop(self, timeout=None): """Stops running threads and wait for them to exit, if necessary. @@ -603,14 +613,9 @@ class OrderedEnqueuer(SequenceEnqueuer): self.queue.queue.clear() self.queue.unfinished_tasks = 0 self.queue.not_full.notify() - self._close_pool() self.run_thread.join(timeout) _SHARED_SEQUENCES[self.uid] = None - def _close_pool(self): - self.executor.close() - self.executor.join() - @tf_export('keras.utils.GeneratorEnqueuer') class GeneratorEnqueuer(SequenceEnqueuer): @@ -636,26 +641,53 @@ class GeneratorEnqueuer(SequenceEnqueuer): seed=None): self.wait_time = wait_time self._generator = generator - self._use_multiprocessing = use_multiprocessing + if os.name is 'nt' and use_multiprocessing is True: + # On Windows, avoid **SYSTEMATIC** error in `multiprocessing`: + # `TypeError: can't pickle generator objects` + # => Suggest multithreading instead of multiprocessing on Windows + raise ValueError('Using a generator with `use_multiprocessing=True`' + ' is not supported on Windows (no marshalling of' + ' generators across process boundaries). Instead,' + ' use single thread/process or multithreading.') + else: + self._use_multiprocessing = use_multiprocessing self._threads = [] self._stop_event = None self._manager = None self.queue = None self.seed = seed - def start(self, workers=1, max_queue_size=10): - """Kicks off threads which add data from the generator into the queue. - - Arguments: - workers: number of worker threads - max_queue_size: queue size - (when full, threads could block on `put()`) - """ - - def data_generator_task(): + def _data_generator_task(self): + if self._use_multiprocessing is False: + while not self._stop_event.is_set(): + with self.genlock: + try: + if (self.queue is not None and + self.queue.qsize() < self.max_queue_size): + # On all OSes, avoid **SYSTEMATIC** error + # in multithreading mode: + # `ValueError: generator already executing` + # => Serialize calls to + # infinite iterator/generator's next() function + generator_output = next(self._generator) + self.queue.put((True, generator_output)) + else: + time.sleep(self.wait_time) + except StopIteration: + break + except Exception as e: # pylint: disable=broad-except + # Can't pickle tracebacks. + # As a compromise, print the traceback and pickle None instead. + if not hasattr(e, '__traceback__'): + setattr(e, '__traceback__', sys.exc_info()[2]) + self.queue.put((False, e)) + self._stop_event.set() + break + else: while not self._stop_event.is_set(): try: - if self._use_multiprocessing or self.queue.qsize() < max_queue_size: + if (self.queue is not None and + self.queue.qsize() < self.max_queue_size): generator_output = next(self._generator) self.queue.put((True, generator_output)) else: @@ -663,24 +695,34 @@ class GeneratorEnqueuer(SequenceEnqueuer): except StopIteration: break except Exception as e: # pylint: disable=broad-except - # Can't pick tracebacks. + # Can't pickle tracebacks. # As a compromise, print the traceback and pickle None instead. - if self._use_multiprocessing: - traceback.print_exc() - setattr(e, '__traceback__', None) - elif not hasattr(e, '__traceback__'): - setattr(e, '__traceback__', sys.exc_info()[2]) + traceback.print_exc() + setattr(e, '__traceback__', None) self.queue.put((False, e)) self._stop_event.set() break + def start(self, workers=1, max_queue_size=10): + """Kicks off threads which add data from the generator into the queue. + + Arguments: + workers: number of worker threads + max_queue_size: queue size + (when full, threads could block on `put()`) + """ try: + self.max_queue_size = max_queue_size if self._use_multiprocessing: self._manager = multiprocessing.Manager() self.queue = self._manager.Queue(maxsize=max_queue_size) self._stop_event = multiprocessing.Event() else: - self.queue = queue.Queue() + # On all OSes, avoid **SYSTEMATIC** error in multithreading mode: + # `ValueError: generator already executing` + # => Serialize calls to infinite iterator/generator's next() function + self.genlock = threading.Lock() + self.queue = queue.Queue(maxsize=max_queue_size) self._stop_event = threading.Event() for _ in range(workers): @@ -688,12 +730,12 @@ class GeneratorEnqueuer(SequenceEnqueuer): # Reset random seed else all children processes # share the same seed np.random.seed(self.seed) - thread = multiprocessing.Process(target=data_generator_task) + thread = multiprocessing.Process(target=self._data_generator_task) thread.daemon = True if self.seed is not None: self.seed += 1 else: - thread = threading.Thread(target=data_generator_task) + thread = threading.Thread(target=self._data_generator_task) self._threads.append(thread) thread.start() except: @@ -715,11 +757,15 @@ class GeneratorEnqueuer(SequenceEnqueuer): self._stop_event.set() for thread in self._threads: - if thread.is_alive(): - if self._use_multiprocessing: + if self._use_multiprocessing: + if thread.is_alive(): thread.terminate() - else: - thread.join(timeout) + else: + # The thread.is_alive() test is subject to a race condition: + # the thread could terminate right after the test and before the + # join, rendering this test meaningless -> Call thread.join() + # always, which is ok no matter what the status of the thread. + thread.join(timeout) if self._manager: self._manager.shutdown() @@ -734,7 +780,9 @@ class GeneratorEnqueuer(SequenceEnqueuer): Skip the data if it is `None`. Yields: - Data arrays. + The next element in the queue, i.e. a tuple + `(inputs, targets)` or + `(inputs, targets, sample_weights)`. """ while self.is_running(): if not self.queue.empty(): @@ -752,7 +800,7 @@ class GeneratorEnqueuer(SequenceEnqueuer): else: time.sleep(self.wait_time) - # Make sure to rethrow the first exception in the queue, if any + # Make sure to rethrow the first exception in the queue, if any while not self.queue.empty(): success, value = self.queue.get() if not success: diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py index a805315c94628f263dd4ce7a8b0f751cdf685ca0..462d600bf827768b0f2e6265aebdaad48e70fcd9 100644 --- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import binascii import codecs import marshal import os @@ -255,7 +256,10 @@ def func_load(code, defaults=None, closure=None, globs=None): if closure is not None: closure = tuple(ensure_value_to_cell(_) for _ in closure) - raw_code = codecs.decode(code.encode('ascii'), 'base64') + try: + raw_code = codecs.decode(code.encode('ascii'), 'base64') + except (UnicodeEncodeError, binascii.Error): + raw_code = code.encode('raw_unicode_escape') code = marshal.loads(raw_code) if globs is None: globs = globals() @@ -287,55 +291,73 @@ class Progbar(object): Arguments: target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over time. Metrics in this list + will be displayed as-is. All others will be averaged + by the progbar before display. interval: Minimum visual progress update interval (in seconds). """ - def __init__(self, target, width=30, verbose=1, interval=0.05): - self.width = width - if target is None: - target = -1 + def __init__(self, target, width=30, verbose=1, interval=0.05, + stateful_metrics=None): self.target = target - self.sum_values = {} - self.unique_values = [] - self.start = time.time() - self.last_update = 0 - self.interval = interval - self.total_width = 0 - self.seen_so_far = 0 + self.width = width self.verbose = verbose + self.interval = interval + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()) or - 'ipykernel' in sys.modules) - - def update(self, current, values=None, force=False): + 'ipykernel' in sys.modules or + 'posix' in sys.modules) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + + def update(self, current, values=None): """Updates the progress bar. Arguments: current: Index of current step. - values: List of tuples (name, value_for_last_step). - The progress bar will display averages for these values. - force: Whether to force visual progress update. + values: List of tuples: + `(name, value_for_last_step)`. + If `name` is in `stateful_metrics`, + `value_for_last_step` will be displayed as-is. + Else, an average of the metric over time will be displayed. """ values = values or [] for k, v in values: - if k not in self.sum_values: - self.sum_values[k] = [ - v * (current - self.seen_so_far), current - self.seen_so_far - ] - self.unique_values.append(k) + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + if k not in self._values: + self._values[k] = [v * (current - self._seen_so_far), + current - self._seen_so_far] + else: + self._values[k][0] += v * (current - self._seen_so_far) + self._values[k][1] += (current - self._seen_so_far) else: - self.sum_values[k][0] += v * (current - self.seen_so_far) - self.sum_values[k][1] += (current - self.seen_so_far) - self.seen_so_far = current + self._values[k] = v + self._seen_so_far = current now = time.time() - info = ' - %.0fs' % (now - self.start) + info = ' - %.0fs' % (now - self._start) if self.verbose == 1: - if (not force and (now - self.last_update) < self.interval and - current < self.target): + if (now - self._last_update < self.interval and + self.target is not None and current < self.target): return - prev_total_width = self.total_width + prev_total_width = self._total_width if self._dynamic_display: sys.stdout.write('\b' * prev_total_width) sys.stdout.write('\r') @@ -356,22 +378,21 @@ class Progbar(object): bar += '=' bar += ('.' * (self.width - prog_width)) bar += ']' - sys.stdout.write(bar) - self.total_width = len(bar) else: bar = '%7d/Unknown' % current - self.total_width = len(bar) + self._total_width = len(bar) sys.stdout.write(bar) if current: - time_per_unit = (now - self.start) / current + time_per_unit = (now - self._start) / current else: time_per_unit = 0 if self.target is not None and current < self.target: eta = time_per_unit * (self.target - current) if eta > 3600: - eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, + eta_format = '%d:%02d:%02d' % (eta // 3600, + (eta % 3600) // 60, eta % 60) elif eta > 60: eta_format = '%d:%02d' % (eta // 60, eta % 60) @@ -387,35 +408,32 @@ class Progbar(object): else: info += ' %.0fus/step' % (time_per_unit * 1e6) - for k in self.unique_values: + for k in self._values_order: info += ' - %s:' % k - if isinstance(self.sum_values[k], list): - avg = np.mean(self.sum_values[k][0] / max(1, self.sum_values[k][1])) + if isinstance(self._values[k], list): + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) if abs(avg) > 1e-3: info += ' %.4f' % avg else: info += ' %.4e' % avg else: - info += ' %s' % self.sum_values[k] + info += ' %s' % self._values[k] + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += (' ' * (prev_total_width - self._total_width)) - self.total_width += len(info) - if prev_total_width > self.total_width: - info += (' ' * (prev_total_width - self.total_width)) if self.target is not None and current >= self.target: info += '\n' sys.stdout.write(info) sys.stdout.flush() - if current >= self.target: - sys.stdout.write('\n') - elif self.verbose == 2: if self.target is None or current >= self.target: - for k in self.unique_values: + for k in self._values_order: info += ' - %s:' % k - avg = np.mean( - self.sum_values[k][0] / max(1, self.sum_values[k][1])) + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) if avg > 1e-3: info += ' %.4f' % avg else: @@ -425,7 +443,69 @@ class Progbar(object): sys.stdout.write(info) sys.stdout.flush() - self.last_update = now + self._last_update = now def add(self, n, values=None): - self.update(self.seen_so_far + n, values) + self.update(self._seen_so_far + n, values) + + +def make_batches(size, batch_size): + """Returns a list of batch indices (tuples of indices). + + Arguments: + size: Integer, total size of the data to slice into batches. + batch_size: Integer, batch size. + + Returns: + A list of tuples of array indices. + """ + num_batches = int(np.ceil(size / float(batch_size))) + return [(i * batch_size, min(size, (i + 1) * batch_size)) + for i in range(0, num_batches)] + + +def slice_arrays(arrays, start=None, stop=None): + """Slice an array or list of arrays. + + This takes an array-like, or a list of + array-likes, and outputs: + - arrays[start:stop] if `arrays` is an array-like + - [x[start:stop] for x in arrays] if `arrays` is a list + + Can also work on list/array of indices: `slice_arrays(x, indices)` + + Arguments: + arrays: Single array or list of arrays. + start: can be an integer index (start index) + or a list/array of indices + stop: integer (stop index); should be None if + `start` was a list. + + Returns: + A slice of the array(s). + + Raises: + ValueError: If the value of start is a list and stop is not None. + """ + if arrays is None: + return [None] + if isinstance(start, list) and stop is not None: + raise ValueError('The stop argument has to be None if the value of start is' + 'a list.') + elif isinstance(arrays, list): + if hasattr(start, '__len__'): + # hdf5 datasets only support list objects as indices + if hasattr(start, 'shape'): + start = start.tolist() + return [None if x is None else x[start] for x in arrays] + else: + return [None if x is None else x[start:stop] for x in arrays] + else: + if hasattr(start, '__len__'): + if hasattr(start, 'shape'): + start = start.tolist() + return arrays[start] + elif hasattr(start, '__getitem__'): + return arrays[start:stop] + else: + return [None] diff --git a/tensorflow/python/keras/_impl/keras/utils/io_utils.py b/tensorflow/python/keras/_impl/keras/utils/io_utils.py index e123339f5a7cc629778e2247d985dbe4591da54a..bbf1d2a3d9c3948271780ec3fad3316b4e6d53c3 100644 --- a/tensorflow/python/keras/_impl/keras/utils/io_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/io_utils.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# pylint: disable=g-import-not-at-top """Utilities related to disk I/O.""" from __future__ import absolute_import from __future__ import division @@ -25,7 +26,7 @@ from tensorflow.python.util.tf_export import tf_export try: - import h5py # pylint:disable=g-import-not-at-top + import h5py except ImportError: h5py = None @@ -65,11 +66,11 @@ class HDF5Matrix(object): 'HDF5 and h5py installed.') if datapath not in list(self.refs.keys()): - self._f = h5py.File(datapath) - self.refs[datapath] = self._f + f = h5py.File(datapath) + self.refs[datapath] = f else: - self._f = self.refs[datapath] - self.data = self._f[dataset] + f = self.refs[datapath] + self.data = f[dataset] self.start = start if end is None: self.end = self.data.shape[0] @@ -80,9 +81,6 @@ class HDF5Matrix(object): def __len__(self): return self.end - self.start - def __del__(self): - self._f.close() - def __getitem__(self, key): if isinstance(key, slice): start, stop = key.start, key.stop diff --git a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py index 30af285cbfb8b8bc38e62d20f0698f9d3c121d10..4c8009dfd80e1aec457fa03687f2840c7fe4607b 100644 --- a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities related to Keras layers. +# pylint: disable=protected-access +"""Utilities related to layer/model functionality. """ from __future__ import absolute_import from __future__ import division @@ -29,10 +30,10 @@ def count_params(weights): """Count the total number of scalars composing the weights. Arguments: - weights: An iterable containing the weights on which to compute params + weights: An iterable containing the weights on which to compute params Returns: - The total number of scalars composing the weights + The total number of scalars composing the weights """ return int(np.sum([K.count_params(p) for p in set(weights)])) @@ -47,24 +48,30 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): terminal window sizes). positions: Relative or absolute positions of log elements in each line. If not provided, defaults to `[.33, .55, .67, 1.]`. - print_fn: Print function to use (defaults to `print`). + print_fn: Print function to use. It will be called on each line of the summary. You can set it to a custom function in order to capture the string summary. + It defaults to `print` (prints to stdout). """ if print_fn is None: print_fn = print if model.__class__.__name__ == 'Sequential': sequential_like = True + elif not model._is_graph_network: + # We treat subclassed models as a simple sequence of layers, for logging + # purposes. + sequential_like = True else: sequential_like = True - nodes_by_depth = model._nodes_by_depth.values() # pylint: disable=protected-access + nodes_by_depth = model._nodes_by_depth.values() nodes = [] for v in nodes_by_depth: if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1): - # If the model has multiple nodes or if the nodes have - # multiple inbound_layers, the model is no longer sequential. + # if the model has multiple nodes + # or if the nodes have multiple inbound_layers + # the model is no longer sequential sequential_like = False break nodes += v @@ -72,7 +79,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): # search for shared layers for layer in model.layers: flag = False - for node in layer.inbound_nodes: + for node in layer._inbound_nodes: if node in nodes: if flag: sequential_like = False @@ -97,7 +104,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): # header names for the different log elements to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to'] relevant_nodes = [] - for v in model._nodes_by_depth.values(): # pylint: disable=protected-access + for v in model._nodes_by_depth.values(): relevant_nodes += v def print_row(fields, positions): @@ -115,17 +122,24 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): print_fn('=' * line_length) def print_layer_summary(layer): + """Prints a summary for a single layer. + + Arguments: + layer: target layer. + """ try: output_shape = layer.output_shape except AttributeError: output_shape = 'multiple' + except RuntimeError: # output_shape unknown in Eager mode. + output_shape = '?' name = layer.name cls_name = layer.__class__.__name__ fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()] print_row(fields, positions) def print_layer_summary_with_connections(layer): - """Prints a summary for a single layer. + """Prints a summary for a single layer (including topological connections). Arguments: layer: target layer. @@ -135,7 +149,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): except AttributeError: output_shape = 'multiple' connections = [] - for node in layer._inbound_nodes: # pylint: disable=protected-access + for node in layer._inbound_nodes: if relevant_nodes and node not in relevant_nodes: # node is not part of the current network continue @@ -143,8 +157,8 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): inbound_layer = node.inbound_layers[i].name inbound_node_index = node.node_indices[i] inbound_tensor_index = node.tensor_indices[i] - connections.append(inbound_layer + '[' + str(inbound_node_index) + '][' - + str(inbound_tensor_index) + ']') + connections.append(inbound_layer + '[' + str(inbound_node_index) + + '][' + str(inbound_tensor_index) + ']') name = layer.name cls_name = layer.__class__.__name__ @@ -173,9 +187,9 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): else: print_fn('_' * line_length) - model._check_trainable_weights_consistency() # pylint: disable=protected-access + model._check_trainable_weights_consistency() if hasattr(model, '_collected_trainable_weights'): - trainable_count = count_params(model._collected_trainable_weights) # pylint: disable=protected-access + trainable_count = count_params(model._collected_trainable_weights) else: trainable_count = count_params(model.trainable_weights) diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils.py b/tensorflow/python/keras/_impl/keras/utils/np_utils.py index 3dddb99191c8a40adf8f39216679a0975d4e830c..a611be08aaed824ebb278b4b28ef52ea1872563b 100644 --- a/tensorflow/python/keras/_impl/keras/utils/np_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/np_utils.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py index 1ec8e3a2bf6d539655b4417cbd413a926978cee2..45c1b92075c50956fee004409e98898411e83d27 100644 --- a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,31 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# pylint: disable=protected-access +# pylint: disable=g-import-not-at-top """Utilities related to model visualization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os -import sys from tensorflow.python.util.tf_export import tf_export + try: # pydot-ng is a fork of pydot that is better maintained. - import pydot_ng as pydot # pylint: disable=g-import-not-at-top + import pydot_ng as pydot except ImportError: - # Fall back on pydot if necessary. - # Silence a `print` statement that occurs in case of import error, - # by temporarily replacing sys.stdout. - _stdout = sys.stdout - sys.stdout = sys.stderr + # pydotplus is an improved version of pydot try: - import pydot # pylint: disable=g-import-not-at-top + import pydotplus as pydot except ImportError: - pydot = None - finally: - # Restore sys.stdout. - sys.stdout = _stdout + # Fall back on pydot if necessary. + try: + import pydot + except ImportError: + pydot = None def _check_pydot(): @@ -66,8 +65,8 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): Returns: A `pydot.Dot` instance representing the Keras model. """ - from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper # pylint: disable=g-import-not-at-top - from tensorflow.python.keras._impl.keras.models import Sequential # pylint: disable=g-import-not-at-top + from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper + from tensorflow.python.keras._impl.keras.models import Sequential _check_pydot() dot = pydot.Dot() @@ -119,9 +118,9 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): # Connect nodes with edges. for layer in layers: layer_id = str(id(layer)) - for i, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access + for i, node in enumerate(layer._inbound_nodes): node_key = layer.name + '_ib-' + str(i) - if node_key in model._network_nodes: # pylint: disable=protected-access + if node_key in model._container_nodes: for inbound_layer in node.inbound_layers: inbound_layer_id = str(id(inbound_layer)) layer_id = str(id(layer)) diff --git a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py index bc788d874f663caefd46d56fbf715a802fe08ec1..2884dc84cc5d99511947e6f0f97b0bf8a505221f 100644 --- a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py +++ b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""API wrapper allowing to use certain Keras models with the Scikit-Learn API. +"""Wrapper for using the Scikit-Learn API with Keras models. """ from __future__ import absolute_import from __future__ import division @@ -24,8 +24,9 @@ import types import numpy as np from tensorflow.python.keras._impl.keras.models import Sequential +from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.util import tf_inspect +from tensorflow.python.util.tf_export import tf_export class BaseWrapper(object): @@ -75,7 +76,7 @@ class BaseWrapper(object): self.check_params(sk_params) def check_params(self, params): - """Checks for user typos in "params". + """Checks for user typos in `params`. Arguments: params: dictionary; the parameters to be checked @@ -95,13 +96,11 @@ class BaseWrapper(object): else: legal_params_fns.append(self.build_fn) - legal_params = [] - for fn in legal_params_fns: - legal_params += tf_inspect.getargspec(fn)[0] - legal_params = set(legal_params) - for params_name in params: - if params_name not in legal_params: + for fn in legal_params_fns: + if has_arg(fn, params_name): + break + else: if params_name != 'nb_epoch': raise ValueError('{} is not a legal parameter'.format(params_name)) @@ -136,10 +135,10 @@ class BaseWrapper(object): Arguments: x : array-like, shape `(n_samples, n_features)` - Training samples where n_samples in the number of samples - and n_features is the number of features. + Training samples where `n_samples` is the number of samples + and `n_features` is the number of features. y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` - True labels for X. + True labels for `x`. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.fit` @@ -170,26 +169,26 @@ class BaseWrapper(object): return history def filter_sk_params(self, fn, override=None): - """Filters `sk_params` and return those in `fn`'s arguments. + """Filters `sk_params` and returns those in `fn`'s arguments. Arguments: fn : arbitrary function - override: dictionary, values to override sk_params + override: dictionary, values to override `sk_params` Returns: - res : dictionary dictionary containing variables - in both sk_params and fn's arguments. + res : dictionary containing variables + in both `sk_params` and `fn`'s arguments. """ override = override or {} res = {} - fn_args = tf_inspect.getargspec(fn)[0] for name, value in self.sk_params.items(): - if name in fn_args: + if has_arg(fn, name): res.update({name: value}) res.update(override) return res +@tf_export('keras.wrappers.scikit_learn.KerasClassifier') class KerasClassifier(BaseWrapper): """Implementation of the scikit-learn classifier API for Keras. """ @@ -199,10 +198,10 @@ class KerasClassifier(BaseWrapper): Arguments: x : array-like, shape `(n_samples, n_features)` - Training samples where n_samples in the number of samples - and n_features is the number of features. + Training samples where `n_samples` is the number of samples + and `n_features` is the number of features. y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` - True labels for X. + True labels for `x`. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.fit` @@ -229,8 +228,8 @@ class KerasClassifier(BaseWrapper): Arguments: x: array-like, shape `(n_samples, n_features)` - Test samples where n_samples in the number of samples - and n_features is the number of features. + Test samples where `n_samples` is the number of samples + and `n_features` is the number of features. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.predict_classes`. @@ -248,8 +247,8 @@ class KerasClassifier(BaseWrapper): Arguments: x: array-like, shape `(n_samples, n_features)` - Test samples where n_samples in the number of samples - and n_features is the number of features. + Test samples where `n_samples` is the number of samples + and `n_features` is the number of features. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.predict_classes`. @@ -258,8 +257,8 @@ class KerasClassifier(BaseWrapper): proba: array-like, shape `(n_samples, n_outputs)` Class probability estimates. In the case of binary classification, - tp match the scikit-learn API, - will return an array of shape '(n_samples, 2)' + to match the scikit-learn API, + will return an array of shape `(n_samples, 2)` (instead of `(n_sample, 1)` as in Keras). """ kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs) @@ -276,16 +275,16 @@ class KerasClassifier(BaseWrapper): Arguments: x: array-like, shape `(n_samples, n_features)` - Test samples where n_samples in the number of samples - and n_features is the number of features. + Test samples where `n_samples` is the number of samples + and `n_features` is the number of features. y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` - True labels for x. + True labels for `x`. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.evaluate`. Returns: score: float - Mean accuracy of predictions on X wrt. y. + Mean accuracy of predictions on `x` wrt. `y`. Raises: ValueError: If the underlying model isn't configured to @@ -312,6 +311,7 @@ class KerasClassifier(BaseWrapper): 'the `model.compile()` method.') +@tf_export('keras.wrappers.scikit_learn.KerasRegressor') class KerasRegressor(BaseWrapper): """Implementation of the scikit-learn regressor API for Keras. """ @@ -321,8 +321,8 @@ class KerasRegressor(BaseWrapper): Arguments: x: array-like, shape `(n_samples, n_features)` - Test samples where n_samples in the number of samples - and n_features is the number of features. + Test samples where `n_samples` is the number of samples + and `n_features` is the number of features. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.predict`. @@ -338,16 +338,16 @@ class KerasRegressor(BaseWrapper): Arguments: x: array-like, shape `(n_samples, n_features)` - Test samples where n_samples in the number of samples - and n_features is the number of features. + Test samples where `n_samples` is the number of samples + and `n_features` is the number of features. y: array-like, shape `(n_samples,)` - True labels for X. + True labels for `x`. **kwargs: dictionary arguments Legal arguments are the arguments of `Sequential.evaluate`. Returns: score: float - Mean accuracy of predictions on X wrt. y. + Mean accuracy of predictions on `x` wrt. `y`. """ kwargs = self.filter_sk_params(Sequential.evaluate, kwargs) loss = self.model.evaluate(x, y, **kwargs) diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py index 34f1435ffb6b65ef0e1399fb6893c3b791616f79..fccedf919a7b261bb30f332172b1388db9da1939 100644 --- a/tensorflow/python/keras/applications/__init__.py +++ b/tensorflow/python/keras/applications/__init__.py @@ -18,16 +18,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.keras.applications import densenet from tensorflow.python.keras.applications import inception_resnet_v2 from tensorflow.python.keras.applications import inception_v3 from tensorflow.python.keras.applications import mobilenet +from tensorflow.python.keras.applications import nasnet from tensorflow.python.keras.applications import resnet50 from tensorflow.python.keras.applications import vgg16 from tensorflow.python.keras.applications import vgg19 from tensorflow.python.keras.applications import xception +from tensorflow.python.keras.applications.densenet import DenseNet121 +from tensorflow.python.keras.applications.densenet import DenseNet169 +from tensorflow.python.keras.applications.densenet import DenseNet201 from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2 from tensorflow.python.keras.applications.inception_v3 import InceptionV3 from tensorflow.python.keras.applications.mobilenet import MobileNet +from tensorflow.python.keras.applications.nasnet import NASNetLarge +from tensorflow.python.keras.applications.nasnet import NASNetMobile from tensorflow.python.keras.applications.resnet50 import ResNet50 from tensorflow.python.keras.applications.vgg16 import VGG16 from tensorflow.python.keras.applications.vgg19 import VGG19 diff --git a/tensorflow/python/keras/applications/densenet/__init__.py b/tensorflow/python/keras/applications/densenet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b8ea83920733a3a442171616ab460ffaf831521 --- /dev/null +++ b/tensorflow/python/keras/applications/densenet/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""DenseNet Keras applications.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras._impl.keras.applications.densenet import decode_predictions +from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121 +from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169 +from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201 +from tensorflow.python.keras._impl.keras.applications.densenet import preprocess_input + +del absolute_import +del division +del print_function diff --git a/tensorflow/python/keras/applications/nasnet/__init__.py b/tensorflow/python/keras/applications/nasnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94eb145b85b85b2e52ca37e7aebc681c1f054e16 --- /dev/null +++ b/tensorflow/python/keras/applications/nasnet/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""NASNet Keras applications.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras._impl.keras.applications.nasnet import decode_predictions +from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetLarge +from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetMobile +from tensorflow.python.keras._impl.keras.applications.nasnet import preprocess_input + +del absolute_import +del division +del print_function diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index b94bf8f0f67a7a8ddbb351d13cb17ccdbf283260..84ee5040dcd7b118a5c63b6532135913fe238797 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -30,6 +30,7 @@ from tensorflow.python.keras._impl.keras.layers.advanced_activations import Leak from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU +from tensorflow.python.keras._impl.keras.layers.advanced_activations import Softmax # Convolution layers. from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D @@ -37,6 +38,7 @@ from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose +from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv1D from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D # Convolution layer aliases. @@ -45,6 +47,7 @@ from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose +from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution1D from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D # Image processing layers. diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 8c1d16c2a8fc2ed1130d81c46aa233bf8416caf8..d4ceb2e489c8a20d26eaf9d89b12992d2b8673d7 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1043,6 +1043,7 @@ tf_py_test( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variables", + "//tensorflow/python/eager:function", ], ) @@ -1293,7 +1294,7 @@ cuda_py_test( cuda_py_test( name = "control_flow_ops_py_test", - # TOOD(b/70473603): change this back to "small" once the C API is + # TODO(b/70473603): change this back to "small" once the C API is # permanently enabled size = "medium", srcs = ["control_flow_ops_py_test.py"], @@ -1600,6 +1601,19 @@ cuda_py_test( ], ) +cuda_py_test( + name = "manip_ops_test", + size = "small", + srcs = ["manip_ops_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:manip_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + ], + tags = ["no_windows_gpu"], +) + cuda_py_test( name = "matmul_op_test", size = "small", @@ -2821,7 +2835,7 @@ tf_py_test( "//tensorflow/python:random_ops", "//tensorflow/python:variables", ], - shard_count = 3, + shard_count = 10, tags = ["no_windows_gpu"], ) diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index ec6184aacdb1ee6376944114ace3f1c1c1407aa9..365cf72108de5a1e5e1eb47891a6ad64151add22 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import time +import unittest import numpy as np @@ -82,7 +83,9 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase): matrix_ph = array_ops.placeholder(dtypes.int32) transposed = array_ops.matrix_transpose(matrix_ph) self.assertAllEqual( - expected_transposed, transposed.eval(feed_dict={matrix_ph: matrix})) + expected_transposed, transposed.eval(feed_dict={ + matrix_ph: matrix + })) def testBatchMatrixDynamicallyDefined(self): matrix_0 = [[1, 2, 3], [4, 5, 6]] @@ -96,7 +99,9 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase): transposed = array_ops.matrix_transpose(batch_matrix_ph) self.assertAllEqual( expected_transposed, - transposed.eval(feed_dict={batch_matrix_ph: batch_matrix})) + transposed.eval(feed_dict={ + batch_matrix_ph: batch_matrix + })) def testTensorWithStaticRankLessThanTwoRaisesBecauseNotAMatrix(self): vector = [1, 2, 3] @@ -203,8 +208,10 @@ class BooleanMaskTest(test_util.TensorFlowTestCase): masked_tensor = sess.run( array_ops.boolean_mask(ph_tensor, ph_mask), - feed_dict={ph_tensor: arr, - ph_mask: mask}) + feed_dict={ + ph_tensor: arr, + ph_mask: mask + }) np.testing.assert_allclose(masked_tensor, arr[mask]) def testMaskDimensionsSetToNoneRaises(self): @@ -280,7 +287,8 @@ class ReverseV2Test(test_util.TensorFlowTestCase): for axis_dtype in [dtypes.int32, dtypes.int64]: with self.test_session(use_gpu=use_gpu): x_tf = array_ops.reverse_v2(x_np, - constant_op.constant([0], dtype=axis_dtype)).eval() + constant_op.constant( + [0], dtype=axis_dtype)).eval() self.assertAllEqual(x_tf, np.asarray(x_np)[::-1]) def _reverse2DimAuto(self, np_dtype): @@ -290,16 +298,17 @@ class ReverseV2Test(test_util.TensorFlowTestCase): for use_gpu in [False, True]: for axis_dtype in [dtypes.int32, dtypes.int64]: with self.test_session(use_gpu=use_gpu): - x_tf_1 = reverse_f(x_np, - constant_op.constant([0], dtype=axis_dtype)).eval() - x_tf_2 = reverse_f(x_np, - constant_op.constant([-2], dtype=axis_dtype)).eval() - x_tf_3 = reverse_f(x_np, - constant_op.constant([1], dtype=axis_dtype)).eval() - x_tf_4 = reverse_f(x_np, - constant_op.constant([-1], dtype=axis_dtype)).eval() + x_tf_1 = reverse_f(x_np, constant_op.constant( + [0], dtype=axis_dtype)).eval() + x_tf_2 = reverse_f(x_np, constant_op.constant( + [-2], dtype=axis_dtype)).eval() + x_tf_3 = reverse_f(x_np, constant_op.constant( + [1], dtype=axis_dtype)).eval() + x_tf_4 = reverse_f(x_np, constant_op.constant( + [-1], dtype=axis_dtype)).eval() x_tf_5 = reverse_f(x_np, - constant_op.constant([1, 0], dtype=axis_dtype)).eval() + constant_op.constant([1, 0], + dtype=axis_dtype)).eval() self.assertAllEqual(x_tf_1, np.asarray(x_np)[::-1, :]) self.assertAllEqual(x_tf_2, np.asarray(x_np)[::-1, :]) self.assertAllEqual(x_tf_3, np.asarray(x_np)[:, ::-1]) @@ -324,18 +333,16 @@ class ReverseV2Test(test_util.TensorFlowTestCase): def testReverse1DimAuto(self): for dtype in [ - np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, - np.bool, np.float16, np.float32, - np.float64, np.complex64, np.complex128, + np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, np.bool, + np.float16, np.float32, np.float64, np.complex64, np.complex128, np.array(b"").dtype.type ]: self._reverse1DimAuto(dtype) def testReverse2DimAuto(self): for dtype in [ - np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, - np.bool, np.float16, np.float32, - np.float64, np.complex64, np.complex128, + np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, np.bool, + np.float16, np.float32, np.float64, np.complex64, np.complex128, np.array(b"").dtype.type ]: self._reverse2DimAuto(dtype) @@ -408,7 +415,7 @@ class MeshgridTest(test_util.TensorFlowTestCase): def _compareDiffType(self, n, np_dtype, use_gpu): inputs = [] for index in ("ij", "xy"): - for i in range(n): + for _ in range(n): x = np.linspace(-10, 10, 5).astype(np_dtype) if np_dtype in (np.complex64, np.complex128): x += 1j @@ -416,8 +423,8 @@ class MeshgridTest(test_util.TensorFlowTestCase): numpy_out = np.meshgrid(*inputs, indexing=index) with self.test_session(use_gpu=use_gpu): tf_out = array_ops.meshgrid(*inputs, indexing=index) - for X, _X in zip(numpy_out, tf_out): - self.assertAllEqual(X, _X.eval()) + for x_np, x_tf in zip(numpy_out, tf_out): + self.assertAllEqual(x_np, x_tf.eval()) def testCompare(self): for t in (np.float16, np.float32, np.float64, np.int32, np.int64, @@ -491,7 +498,7 @@ class StridedSliceTest(test_util.TensorFlowTestCase): def test_basic_slice(self): for tensor_type in STRIDED_SLICE_TYPES: - with self.test_session(use_gpu=True): + with self.test_session(use_gpu=not tensor_type.is_integer): checker = StridedSliceChecker( self, StridedSliceChecker.REF_TENSOR, tensor_type=tensor_type) _ = checker[:, :, :] @@ -711,8 +718,8 @@ class GradSliceChecker(object): slice_val_grad2, = gradients_impl.gradients( slice_val_grad, dy, grad_ys=self.var) self.sess.run(assign) - slice_val_grad_evaled, slice_val_grad2_evaled = (self.sess.run( - [slice_val_grad, slice_val_grad2])) + slice_val_grad_evaled, slice_val_grad2_evaled = ( + self.sess.run([slice_val_grad, slice_val_grad2])) analytic_grad2_evaled = analytic_grad2.eval() self.test.assertAllEqual(slice_val_grad2_evaled, analytic_grad2_evaled) @@ -877,7 +884,8 @@ class StridedSliceAssignChecker(object): if self.tensor_type.is_complex: value -= 1j * value - with self.test.test_session(use_gpu=True) as sess: + with self.test.test_session( + use_gpu=not self.tensor_type.is_integer) as sess: if self._use_resource: var = resource_variable_ops.ResourceVariable(self.x) else: @@ -946,6 +954,30 @@ class SliceAssignTest(test_util.TensorFlowTestCase): v = variables.Variable([1, 2]) sess.run(v[:].assign([1, 2])) + def testTypeError(self): + init_val = constant_op.constant([1, 2], dtype=dtypes.int32) + too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8) + too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64) + v = variables.Variable(init_val) + with self.assertRaises(TypeError): + v[:].assign(too_small_val) + with self.assertRaises(TypeError): + v[:].assign(too_large_val) + + def testTypeErrorResource(self): + init_val = constant_op.constant([1, 2], dtype=dtypes.int32) + too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8) + too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64) + v = resource_variable_ops.ResourceVariable(init_val) + with self.test_session() as sess: + sess.run(v.initializer) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "l-value dtype int32 does not match r-value dtype int64"): + sess.run(v[:].assign(too_large_val)) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(v[:].assign(too_small_val)) + class ShapeSizeRankTest(test_util.TensorFlowTestCase): @@ -983,32 +1015,49 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, "maxlen must be scalar"): array_ops.sequence_mask([10, 20], [10, 20]) - def testOneDimensional(self): + def testOneDimensionalWithMaxlen(self): with self.test_session(): res = array_ops.sequence_mask(constant_op.constant([1, 3, 2]), 5) self.assertAllEqual(res.get_shape(), [3, 5]) - self.assertAllEqual(res.eval(), [[True, False, False, False, False], - [True, True, True, False, False], - [True, True, False, False, False]]) + self.assertAllEqual( + res.eval(), + [[True, False, False, False, False], [True, True, True, False, False], + [True, True, False, False, False]]) + def testOneDimensionalDtypeWithoutMaxlen(self): + with self.test_session(): # test dtype and default maxlen: + res = array_ops.sequence_mask(constant_op.constant([0, 1, 4]), + dtype=dtypes.float32) + if ops._USE_C_API: + self.assertAllEqual(res.get_shape().as_list(), [3, 4]) + else: + self.assertAllEqual(res.get_shape().as_list(), [3, None]) + self.assertAllEqual( + res.eval(), + [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]) + + def testOneDimensionalWithoutMaxlen(self): + with self.test_session(): res = array_ops.sequence_mask( - constant_op.constant([0, 1, 4]), dtype=dtypes.float32) + constant_op.constant([0, 1, 4])) if ops._USE_C_API: self.assertAllEqual(res.get_shape().as_list(), [3, 4]) else: self.assertAllEqual(res.get_shape().as_list(), [3, None]) - self.assertAllEqual(res.eval(), [[0.0, 0.0, 0.0, - 0.0], [1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0]]) + self.assertAllEqual( + res.eval(), + [[False, False, False, False], + [True, False, False, False], + [True, True, True, True]]) def testTwoDimensional(self): with self.test_session(): res = array_ops.sequence_mask(constant_op.constant([[1, 3, 2]]), 5) self.assertAllEqual(res.get_shape(), [1, 3, 5]) - self.assertAllEqual(res.eval(), [[[True, False, False, False, False], - [True, True, True, False, False], - [True, True, False, False, False]]]) + self.assertAllEqual(res.eval(), [[[True, False, False, False, False], [ + True, True, True, False, False + ], [True, True, False, False, False]]]) # test dtype and default maxlen: res = array_ops.sequence_mask( @@ -1017,12 +1066,15 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): self.assertAllEqual(res.get_shape().as_list(), [2, 3, 4]) else: self.assertAllEqual(res.get_shape().as_list(), [2, 3, None]) - self.assertAllEqual(res.eval(), [[[0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0]], - [[1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 0.0]]]) + self.assertAllEqual( + res.eval(), + [[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]], + [[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]]]) + + def testUnknownShape(self): + lengths = array_ops.placeholder(dtype=dtypes.int32) + res = array_ops.sequence_mask(lengths) + self.assertEqual(res.shape, None) def testDtypes(self): @@ -1031,9 +1083,10 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): constant_op.constant([1, 3, 2], dtype=lengths_dtype), constant_op.constant(5, dtype=maxlen_dtype)) self.assertAllEqual(res.get_shape(), [3, 5]) - self.assertAllEqual(res.eval(), [[True, False, False, False, False], - [True, True, True, False, False], - [True, True, False, False, False]]) + self.assertAllEqual( + res.eval(), + [[True, False, False, False, False], [True, True, True, False, False], + [True, True, False, False, False]]) with self.test_session(): check_dtypes(dtypes.int32, dtypes.int32) @@ -1088,13 +1141,14 @@ class PadTest(test_util.TensorFlowTestCase): def testEager(self): with context.eager_mode(): t = constant_op.constant([[1, 2, 3], [4, 5, 6]]) - paddings = constant_op.constant([[1, 1,], [2, 2]]) + paddings = constant_op.constant([[ + 1, + 1, + ], [2, 2]]) padded = array_ops.pad(t, paddings, "CONSTANT") self.assertAllEqual(padded.numpy(), - [[0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 2, 3, 0, 0], - [0, 0, 4, 5, 6, 0, 0], - [0, 0, 0, 0, 0, 0, 0]]) + [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 0, 0], + [0, 0, 4, 5, 6, 0, 0], [0, 0, 0, 0, 0, 0, 0]]) class InvertPermutationTest(test_util.TensorFlowTestCase): @@ -1108,6 +1162,29 @@ class InvertPermutationTest(test_util.TensorFlowTestCase): self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1]) +class UnravelIndexTest(test_util.TensorFlowTestCase): + + # TODO(b/73086570): Reenable test. + @unittest.skip("Test does not pass internally.") + def testUnravelIndex(self): + with self.test_session(): + for dtype in [dtypes.int32, dtypes.int64]: + indices_1 = constant_op.constant(1621, dtype=dtype) + dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype) + out_1 = array_ops.unravel_index(indices_1, dims_1) + self.assertAllEqual(out_1.eval(), [3, 1, 4, 1]) + + indices_2 = constant_op.constant([1621], dtype=dtype) + dims_2 = constant_op.constant([6, 7, 8, 9], dtype=dtype) + out_2 = array_ops.unravel_index(indices_2, dims_2) + self.assertAllEqual(out_2.eval(), [[3], [1], [4], [1]]) + + indices_3 = constant_op.constant([22, 41, 37], dtype=dtype) + dims_3 = constant_op.constant([7, 6], dtype=dtype) + out_3 = array_ops.unravel_index(indices_3, dims_3) + self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]]) + + class GuaranteeConstOpTest(test_util.TensorFlowTestCase): def testSimple(self): diff --git a/tensorflow/python/kernel_tests/atrous_convolution_test.py b/tensorflow/python/kernel_tests/atrous_convolution_test.py index 04248fb2bab4333ed164f7871d2e9d5002dc52ad..2d1b3d9b7e836591646a2d0e59742bf6139446d1 100644 --- a/tensorflow/python/kernel_tests/atrous_convolution_test.py +++ b/tensorflow/python/kernel_tests/atrous_convolution_test.py @@ -81,6 +81,7 @@ class AtrousConvolutionTest(test.TestCase): otherwise, it's delayed after the context. """ checks = [] + def add_check(check, *args, **kwargs): if context.in_eager_mode(): args_val, kwargs_val = self.evaluate([args, kwargs]) @@ -96,12 +97,12 @@ class AtrousConvolutionTest(test.TestCase): def _test_atrous_convolution(self, add_check, input_shape, filter_shape, dilation_rate, **kwargs): - filters = np.arange(np.prod(filter_shape), - dtype=np.float32).reshape(filter_shape) + filters = np.arange( + np.prod(filter_shape), dtype=np.float32).reshape(filter_shape) filters_upsampled = upsample_filters(filters, dilation_rate) x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - y1 = nn_ops.convolution(input=x, filter=filters, - dilation_rate=dilation_rate, **kwargs) + y1 = nn_ops.convolution( + input=x, filter=filters, dilation_rate=dilation_rate, **kwargs) y2 = nn_ops.convolution(input=x, filter=filters_upsampled, **kwargs) def check(y1_eval, y2_eval): @@ -112,13 +113,15 @@ class AtrousConvolutionTest(test.TestCase): def test_unknown_spatial_dims_for_channel_last_format(self): x = array_ops.placeholder(dtypes.float32, [1, None, None, 10]) w = array_ops.zeros([3, 3, 10, 20]) - y = nn_ops.convolution(x, w, "VALID", dilation_rate=[2, 2], data_format="NHWC") + y = nn_ops.convolution( + x, w, "VALID", dilation_rate=[2, 2], data_format="NHWC") self.assertEqual(y.shape.as_list(), [1, None, None, 20]) def test_unknown_spatial_dims_for_channel_first_format(self): x = array_ops.placeholder(dtypes.float32, [1, 10, None, None]) w = array_ops.zeros([3, 3, 10, 20]) - y = nn_ops.convolution(x, w, "VALID", dilation_rate=[2, 2], data_format="NCHW") + y = nn_ops.convolution( + x, w, "VALID", dilation_rate=[2, 2], data_format="NCHW") self.assertEqual(y.shape.as_list(), [1, 20, None, None]) @test_util.run_in_graph_and_eager_modes() @@ -215,28 +218,35 @@ class AtrousConvolutionTest(test.TestCase): def combined_op(converted_input, num_spatial_dims, padding_arg): # pylint: disable=unused-argument # pylint: disable=cell-var-from-loop - result = nn_ops.convolution(input=converted_input, filter=f1, - padding=padding) - result = nn_ops.convolution(input=result, filter=f2, - padding=padding) + result = nn_ops.convolution( + input=converted_input, filter=f1, padding=padding) + result = nn_ops.convolution( + input=result, filter=f2, padding=padding) # pylint: enable=cell-var-from-loop return result for rate_height in range(2, 4): for rate_width in range(2, 4): dilation_rate = [rate_height, rate_width] - y1 = nn_ops.convolution(input=x, filter=f1, padding=padding, - dilation_rate=dilation_rate) - y1 = nn_ops.convolution(input=y1, filter=f2, - padding=padding, - dilation_rate=dilation_rate) + y1 = nn_ops.convolution( + input=x, + filter=f1, + padding=padding, + dilation_rate=dilation_rate) + y1 = nn_ops.convolution( + input=y1, + filter=f2, + padding=padding, + dilation_rate=dilation_rate) y2 = nn_ops.with_space_to_batch( - input=x, dilation_rate=dilation_rate, op=combined_op, + input=x, + dilation_rate=dilation_rate, + op=combined_op, padding="VALID") def check(y1_eval, y2_eval): - self.assertAllClose(y1_eval, y2_eval, rtol=1e-2, - atol=1e-2) + self.assertAllClose(y1_eval, y2_eval, rtol=1e-2, atol=1e-2) + add_check(check, y1, y2) def _test_gradient(self, x_shape, f_shape, dilation_rate, padding): diff --git a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py index 88b3f20469a6a8d8e8181e8d5a3876ae22fb9c06..28b3dc45e9c5fd9aee0b4b7f71a5dc1b93c057ed 100644 --- a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py +++ b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py @@ -80,7 +80,7 @@ class RangeSamplerOpsTest(test.TestCase): with self.test_session(): true_classes = constant_op.constant( [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64) - _, _, sampled_expected_count = candidate_sampling_ops.all_candidate_sampler( + _, _, sampled_expected_count = candidate_sampling_ops.all_candidate_sampler( # pylint: disable=line-too-long true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True) sampled_log_expected_count = math_ops.log(sampled_expected_count) result = sampled_log_expected_count.eval() diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py index a5fd3bc3345f41d9d3f07278dc7979c1103b597f..127bc6bb20ae6b415da94672de68cc4b8ceaa287 100644 --- a/tensorflow/python/kernel_tests/concat_op_test.py +++ b/tensorflow/python/kernel_tests/concat_op_test.py @@ -495,9 +495,9 @@ class ConcatOpTest(test.TestCase): p = [] shape = np.array([7, 13]) if test.is_gpu_available(): - num_tensors = 10000 + num_tensors = 5000 else: - num_tensors = 1000 + num_tensors = 500 for i in np.arange(num_tensors): input_shape = shape placeholder = array_ops.placeholder(dtypes.float32, shape=input_shape) diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 576bb68ba49cf5a5c7618131ad8a567672cb08d8..16e56349c45dd56a335f6f881826d975e24bd110 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -465,9 +465,8 @@ class ZerosLikeTest(test.TestCase): def testZerosLikeGPU(self): for dtype in [ dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64, - dtypes_lib.int32, dtypes_lib.int64, - dtypes_lib.complex64, dtypes_lib.complex128, - dtypes_lib.bool + dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.complex64, + dtypes_lib.complex128, dtypes_lib.bool ]: self._compareZeros(dtype, fully_defined_shape=False, use_gpu=True) self._compareZeros(dtype, fully_defined_shape=True, use_gpu=True) diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 6e18ed132cd6337378fdb8ec774f7946da8d61ed..15ff0ec09b65a8ba242473fb7b25ee00424e0926 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_logging_ops from tensorflow.python.ops import gen_state_ops @@ -143,7 +144,7 @@ class ControlFlowTest(test.TestCase): enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True) nine = constant_op.constant(9) - enter_nine = control_flow_ops.enter(nine, "foo_1") + enter_nine = gen_control_flow_ops._enter(nine, "foo_1") op = state_ops.assign(enter_v, enter_nine) v2 = control_flow_ops.with_dependencies([op], enter_v) v3 = control_flow_ops.exit(v2) @@ -163,9 +164,9 @@ class ControlFlowTest(test.TestCase): def testEnterMulExit(self): with self.test_session(): data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") - enter_data = control_flow_ops.enter(data, "foo_1", False) + enter_data = gen_control_flow_ops._enter(data, "foo_1", False) five = constant_op.constant(5) - enter_five = control_flow_ops.enter(five, "foo_1", False) + enter_five = gen_control_flow_ops._enter(five, "foo_1", False) mul_op = math_ops.multiply(enter_data, enter_five) exit_op = control_flow_ops.exit(mul_op) @@ -177,12 +178,13 @@ class ControlFlowTest(test.TestCase): v = variables.Variable([0.0, 0.0], dtype=dtypes.float32) # If is_constant=True, the shape information should be propagated. - enter_v_constant = control_flow_ops.enter(v, "frame1", is_constant=True) + enter_v_constant = gen_control_flow_ops._enter( + v, "frame1", is_constant=True) self.assertEqual(enter_v_constant.shape, [2]) # Otherwise, the shape should be unknown. - enter_v_non_constant = control_flow_ops.enter(v, "frame2", - is_constant=False) + enter_v_non_constant = gen_control_flow_ops._enter( + v, "frame2", is_constant=False) self.assertEqual(enter_v_non_constant.shape, None) def testSwitchMergeIndexedSlices(self): @@ -255,8 +257,8 @@ class ControlFlowTest(test.TestCase): false = ops.convert_to_tensor(False) n = constant_op.constant(10) - enter_false = control_flow_ops.enter(false, "foo_1", False) - enter_n = control_flow_ops.enter(n, "foo_1", False) + enter_false = gen_control_flow_ops._enter(false, "foo_1", False) + enter_n = gen_control_flow_ops._enter(n, "foo_1", False) merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0] switch_n = control_flow_ops.switch(merge_n, enter_false) @@ -273,9 +275,9 @@ class ControlFlowTest(test.TestCase): one = constant_op.constant(1) n = constant_op.constant(10) - enter_i = control_flow_ops.enter(zero, "foo", False) - enter_one = control_flow_ops.enter(one, "foo", True) - enter_n = control_flow_ops.enter(n, "foo", True) + enter_i = gen_control_flow_ops._enter(zero, "foo", False) + enter_one = gen_control_flow_ops._enter(one, "foo", True) + enter_n = gen_control_flow_ops._enter(n, "foo", True) with ops.device(test.gpu_device_name()): merge_i = control_flow_ops.merge([enter_i, enter_i])[0] @@ -299,9 +301,9 @@ class ControlFlowTest(test.TestCase): one = constant_op.constant(1) n = constant_op.constant(10) - enter_i = control_flow_ops.enter(zero, "foo", False) - enter_one = control_flow_ops.enter(one, "foo", True) - enter_n = control_flow_ops.enter(n, "foo", True) + enter_i = gen_control_flow_ops._enter(zero, "foo", False) + enter_one = gen_control_flow_ops._enter(one, "foo", True) + enter_n = gen_control_flow_ops._enter(n, "foo", True) merge_i = control_flow_ops.merge([enter_i, enter_i])[0] @@ -322,8 +324,8 @@ class ControlFlowTest(test.TestCase): def testDifferentFrame(self): with self.test_session(): data = array_ops.placeholder(dtypes.float32, shape=[]) - enter_1 = control_flow_ops.enter(data, "foo_1", False) - enter_2 = control_flow_ops.enter(data, "foo_2", False) + enter_1 = gen_control_flow_ops._enter(data, "foo_1", False) + enter_2 = gen_control_flow_ops._enter(data, "foo_2", False) res = math_ops.add(enter_1, enter_2) with self.assertRaisesOpError("has inputs from different frames"): res.eval(feed_dict={data: 1.0}) @@ -702,6 +704,36 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) self.assertEqual(10000, r.eval()) + def testWhileExternalControlDependencies(self): + with self.test_session(): + v = variables.Variable(0.0) + v.initializer.run() + increment = v.assign_add(1.0) + + def body_fn(i): + with ops.control_dependencies([increment]): + return i + i + + result = control_flow_ops.while_loop(cond=lambda i: i < 1, + body=body_fn, loop_vars=[1]) + result.eval() + self.assertAllEqual(v.eval(), 1.0) + + def testWhileExternalControlDependenciesNoInput(self): + with self.test_session(): + v = variables.Variable(0.0) + v.initializer.run() + increment = v.assign_add(1.0) + + def body_fn(unused_i): + with ops.control_dependencies([increment]): + return constant_op.constant(5, name="five") + + result = control_flow_ops.while_loop(cond=lambda i: i < 5, + body=body_fn, loop_vars=[0]) + result.eval() + self.assertAllEqual(v.eval(), 1.0) + def testWhileWithRefs_1(self): with self.test_session() as sess: x = variables.Variable(0)._ref() # pylint: disable=protected-access @@ -736,24 +768,21 @@ class ControlFlowTest(test.TestCase): with self.test_session(): s = constant_op.constant([1, 2, 3, 4, 5]) r = isum(s, maximum_iterations=3) - self.assertAllEqual([1+3, 2+3, 3+3, 4+3, 5+3], r.eval()) + self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval()) def testWhileWithMaximumIterationsAndSingleArgument(self): with self.test_session(): r = control_flow_ops.while_loop( - lambda i: i < 3, - lambda i: i + 1, - [0], - maximum_iterations=1) + lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1) self.assertEqual(1, r.eval()) def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self): v = constant_op.constant(1.0) + def training_loop_with_gradient(i): out = control_flow_ops.while_loop( lambda i_, _: i_ < 3, - lambda i_, j: [i_ + 1, j * v], - [0, 1.0], + lambda i_, j: [i_ + 1, j * v], [0, 1.0], maximum_iterations=i) g = gradients_impl.gradients(out, v) with ops.control_dependencies(g): @@ -763,8 +792,8 @@ class ControlFlowTest(test.TestCase): xla_context.Enter() # Create training loop, ensure we can call gradient() of # while_loop inside the training loop. - loop = control_flow_ops.while_loop( - lambda i: i < 3, training_loop_with_gradient, [0]) + loop = control_flow_ops.while_loop(lambda i: i < 3, + training_loop_with_gradient, [0]) xla_context.Exit() loop_execute = array_ops.identity(loop) # Because loop is not fetchable. @@ -774,17 +803,18 @@ class ControlFlowTest(test.TestCase): def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self): v = constant_op.constant(1.0) + def inner_body(i, x): out = control_flow_ops.while_loop( lambda i, _: i < 3, - lambda i, j: [i + 1, j * v], - [0, x], + lambda i, j: [i + 1, j * v], [0, x], maximum_iterations=i) return out def create_while_loop(maximum_iterations=None): return control_flow_ops.while_loop( - lambda i, _: i < 3, inner_body, [0, 1.0], + lambda i, _: i < 3, + inner_body, [0, 1.0], maximum_iterations=maximum_iterations) loop_no_xla = create_while_loop(maximum_iterations=5) @@ -819,14 +849,17 @@ class ControlFlowTest(test.TestCase): def create_while_loop(): max_iter_holder = [] + def create_mi(): max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=())) return 1.0 - _ = control_flow_ops.cond(constant_op.constant(True), - create_mi, create_mi) + + _ = control_flow_ops.cond( + constant_op.constant(True), create_mi, create_mi) return control_flow_ops.while_loop( - lambda i, _: i < 3, lambda i, x: (i + 1, v * x), (0, 1.0), + lambda i, _: i < 3, + lambda i, x: (i + 1, v * x), (0, 1.0), maximum_iterations=max_iter_holder[0]) xla_context = control_flow_ops.XLAControlFlowContext() @@ -849,28 +882,32 @@ class ControlFlowTest(test.TestCase): p = array_ops.placeholder(dtype=dtypes.int32) def mid_body_builder(iterations): + def mid_body(i, x): r = control_flow_ops.while_loop( lambda *_: True, - lambda i, x: (i + 1, v * x), - (0, x), - maximum_iterations=iterations, name="inner") + lambda i, x: (i + 1, v * x), (0, x), + maximum_iterations=iterations, + name="inner") return (i + 1, gradients_impl.gradients(x + r[1], v)[0]) + return mid_body def outer_body(i, x): iterations = array_ops.size(p, name="iterations") - return ( - i + 1, - x + control_flow_ops.while_loop( - lambda *_: True, mid_body_builder(iterations), (0, x), - maximum_iterations=iterations, name="mid")[1]) + return (i + 1, x + control_flow_ops.while_loop( + lambda *_: True, + mid_body_builder(iterations), (0, x), + maximum_iterations=iterations, + name="mid")[1]) def create_while_loop(): with ops.device("/cpu:0"): r = control_flow_ops.while_loop( - lambda *_: True, outer_body, (0, 1.0), - maximum_iterations=5, name="outer") + lambda *_: True, + outer_body, (0, 1.0), + maximum_iterations=5, + name="outer") return array_ops.identity(r[1]) xla_context = control_flow_ops.XLAControlFlowContext() @@ -881,18 +918,19 @@ class ControlFlowTest(test.TestCase): final_without_xla_context = create_while_loop() with self.test_session(use_gpu=False) as sess: - opts = config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE) + opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) run_metadata = config_pb2.RunMetadata() final_value_without_xla_context = sess.run( - final_without_xla_context, - feed_dict={p: [0, 0, 0]}) + final_without_xla_context, feed_dict={ + p: [0, 0, 0] + }) final_value_with_xla_context = sess.run( final_with_xla_context, feed_dict={p: [0, 0, 0]}, - options=opts, run_metadata=run_metadata) + options=opts, + run_metadata=run_metadata) node_stats = run_metadata.step_stats.dev_stats[0].node_stats stack_push_count = len( @@ -901,8 +939,8 @@ class ControlFlowTest(test.TestCase): # the last two "3"s comes from size(p), when p == [0, 0, 0]. self.assertEqual(stack_push_count, 5 * 3 * 3) - self.assertAllClose( - final_value_with_xla_context, final_value_without_xla_context) + self.assertAllClose(final_value_with_xla_context, + final_value_without_xla_context) # Have more than 10 parallel iterations and hence exercise k-bound # most of the time. @@ -951,8 +989,7 @@ class ControlFlowTest(test.TestCase): with self.test_session(): def compute(i, c, o): - c = array_ops.strided_slice(x, - array_ops.expand_dims(i, 0), + c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0), [1] + array_ops.expand_dims(i, 0)) o = array_ops.concat([o, c], 0) i = math_ops.add(i, 1) @@ -963,11 +1000,12 @@ class ControlFlowTest(test.TestCase): o = ops.convert_to_tensor([0]) x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6]) s = array_ops.size(x) - r = control_flow_ops.while_loop( - lambda i, c, o: math_ops.less(i, s), compute, [i, c, o], [ - i.get_shape(), tensor_shape.unknown_shape(), - tensor_shape.unknown_shape() - ]) + r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s), + compute, [i, c, o], [ + i.get_shape(), + tensor_shape.unknown_shape(), + tensor_shape.unknown_shape() + ]) result = r[2].eval() self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result) @@ -1033,7 +1071,8 @@ class ControlFlowTest(test.TestCase): return [new_i, new_j] r = control_flow_ops.while_loop( - c, _b, [i, m], [i.get_shape(), tensor_shape.unknown_shape()]) + c, _b, [i, m], + [i.get_shape(), tensor_shape.unknown_shape()]) r = r[1] * array_ops.ones([8, 8]) self.assertAllEqual(np.ones((8, 8)), r.eval()) @@ -1065,7 +1104,8 @@ class ControlFlowTest(test.TestCase): return [new_i, new_j] r = control_flow_ops.while_loop( - c, b, [i, m], [i.get_shape(), tensor_shape.TensorShape([None, 2])]) + c, b, [i, m], + [i.get_shape(), tensor_shape.TensorShape([None, 2])]) self.assertTrue(r[1].get_shape()[0].value is None) self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2)) @@ -1092,20 +1132,22 @@ class ControlFlowTest(test.TestCase): def b(i, x): return [ - i + 1, sparse_tensor.SparseTensor(x.indices, x.values * 2.0, - x.dense_shape) + i + 1, + sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape) ] _, r = control_flow_ops.while_loop(c, b, [i, x]) self.assertEqual(r.dense_shape.get_shape()[0].value, 1) _, r = control_flow_ops.while_loop( - c, b, [i, x], [i.get_shape(), tensor_shape.TensorShape([None])]) + c, b, [i, x], + [i.get_shape(), tensor_shape.TensorShape([None])]) self.assertTrue(r.dense_shape.get_shape()[0].value is None) with self.assertRaisesRegexp(ValueError, "is not compatible with"): _, r = control_flow_ops.while_loop( - c, b, [i, x], [i.get_shape(), tensor_shape.TensorShape([5])]) + c, b, [i, x], + [i.get_shape(), tensor_shape.TensorShape([5])]) def testWhileShapeInferenceIndexedSlices(self): with self.test_session(): @@ -1120,7 +1162,8 @@ class ControlFlowTest(test.TestCase): def b(i, x): return [ - i + 1, ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape) + i + 1, + ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape) ] _, r = control_flow_ops.while_loop(c, b, [i, x]) @@ -1128,14 +1171,16 @@ class ControlFlowTest(test.TestCase): self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2])) _, r = control_flow_ops.while_loop( - c, b, [i, x], [i.get_shape(), tensor_shape.TensorShape([None, 2])]) + c, b, [i, x], + [i.get_shape(), tensor_shape.TensorShape([None, 2])]) self.assertEqual(r.dense_shape.get_shape()[0].value, 2) self.assertTrue(r.values.get_shape()[0].value is None) self.assertEqual(r.values.get_shape()[1].value, 2) with self.assertRaisesRegexp(ValueError, "is not compatible with"): _, r = control_flow_ops.while_loop( - c, b, [i, x], [i.get_shape(), tensor_shape.TensorShape([None, 5])]) + c, b, [i, x], + [i.get_shape(), tensor_shape.TensorShape([None, 5])]) def _testNestedWhile_1(self, use_gpu): with self.test_session(use_gpu=use_gpu): @@ -1276,16 +1321,17 @@ class ControlFlowTest(test.TestCase): "v", [], initializer=init_ops.constant_initializer(2)) i0 = constant_op.constant(0) with ops.control_dependencies([i0]): + def loop_condition(i): return i < 4 def loop_body(i): some_cond = control_flow_ops.cond( constant_op.constant(True), - lambda: state_ops.assign(v, math_ops.square(v)), - lambda: v) + lambda: state_ops.assign(v, math_ops.square(v)), lambda: v) with ops.control_dependencies([some_cond]): return i + 1 + r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,)) variables.global_variables_initializer().run() self.assertEqual(4, r.eval()) @@ -1600,7 +1646,8 @@ class ControlFlowTest(test.TestCase): _, rx = control_flow_ops.while_loop( c1, - b1, [r, x], [r.get_shape(), tensor_shape.unknown_shape()], + b1, [r, x], + [r.get_shape(), tensor_shape.unknown_shape()], parallel_iterations=1) self.assertEqual(45, rx.eval()) @@ -1663,7 +1710,8 @@ class ControlFlowTest(test.TestCase): b = lambda i, v: [i + 1, math_ops.multiply(x, v)] r = control_flow_ops.while_loop( c, - b, [n, v], [n.get_shape(), tensor_shape.unknown_shape()], + b, [n, v], + [n.get_shape(), tensor_shape.unknown_shape()], parallel_iterations=1) r = gradients_impl.gradients(r[1], x)[0] @@ -1797,8 +1845,8 @@ class ControlFlowTest(test.TestCase): named = collections.namedtuple("named", ("a", "b")) loop_vars = [ named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)), - (constant_op.constant(2.0), - constant_op.constant(3.0)), constant_op.constant(4.0) + (constant_op.constant(2.0), constant_op.constant(3.0)), + constant_op.constant(4.0) ] c = lambda lv0, _1, _2: lv0.a < 100.0 @@ -1824,8 +1872,8 @@ class ControlFlowTest(test.TestCase): named = collections.namedtuple("named", ("a", "b")) loop_vars = [ named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)), - (constant_op.constant(2.0), - constant_op.constant(3.0)), constant_op.constant(4.0) + (constant_op.constant(2.0), constant_op.constant(3.0)), + constant_op.constant(4.0) ] c = lambda lv0, _1, _2: lv0.a < 100.0 @@ -2176,7 +2224,8 @@ class ControlFlowTest(test.TestCase): def b(i, x): return [ - i + 1, ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape) + i + 1, + ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape) ] _, r = control_flow_ops.while_loop(c, b, [i, x]) @@ -2197,8 +2246,8 @@ class ControlFlowTest(test.TestCase): def b(i, x): return [ - i + 1, sparse_tensor.SparseTensor(x.indices, x.values * 2.0, - x.dense_shape) + i + 1, + sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape) ] _, r = control_flow_ops.while_loop(c, b, [i, x]) @@ -2220,8 +2269,8 @@ class ControlFlowTest(test.TestCase): x1 = x + gradients_impl.gradients(data, params)[0] return i + 1, x1 - output_grad = control_flow_ops.while_loop(c, b, - [i0, constant_op.constant(0.0)]) + output_grad = control_flow_ops.while_loop( + c, b, [i0, constant_op.constant(0.0)]) self.assertAllClose(600.0, sess.run(output_grad)[1]) def testWhileAndTensorArray(self): @@ -2359,9 +2408,12 @@ class ControlFlowTest(test.TestCase): def testStopGradMultiFlows(self): with self.test_session(): + def body(i, y, r): x = variable_scope.get_variable( - "x", shape=(), dtype=dtypes.float32, + "x", + shape=(), + dtype=dtypes.float32, initializer=init_ops.ones_initializer()) y *= x return [i + 1, y, r + math_ops.reduce_sum(y)] @@ -2773,7 +2825,8 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.while_loop( lambda i, v: i < 2, lambda i, v: [i + 1, func(v)], [constant_op.constant(0), x], - [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()]) + [tensor_shape.unknown_shape(), + tensor_shape.unknown_shape()]) self.assertEqual(r[1].eval(), 65536.0) r = gradients_impl.gradients(r, x)[0] @@ -2800,12 +2853,14 @@ class ControlFlowContextCheckTest(test.TestCase): def _getCondTensor(self): cond_tensor = [] + def true_fn(): if not cond_tensor: cond_tensor.append(constant_op.constant(1)) return cond_tensor[0] - control_flow_ops.cond(math_ops.less(1, 2), true_fn, - lambda: constant_op.constant(0)) + + control_flow_ops.cond( + math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0)) return cond_tensor[0] def testInvalidContext(self): @@ -2821,14 +2876,13 @@ class ControlFlowContextCheckTest(test.TestCase): # Accessing a while loop tensor in cond is illegal. while_tensor = self._getWhileTensor() with self.assertRaisesRegexp( - ValueError, - "Cannot use 'while/Const_1' as input to 'cond/Add' because " + ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because " "'while/Const_1' is in a while loop. See info log for more details."): # TODO(skyewm): this passes if we return while_tensor directly instead # of using it as input to another op. - control_flow_ops.cond(math_ops.less(1, 2), - lambda: math_ops.add(1, while_tensor), - lambda: constant_op.constant(0)) + control_flow_ops.cond( + math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor), + lambda: constant_op.constant(0)) def testInvalidContextInWhile(self): # Accessing a while loop tensor in a different while loop is illegal. @@ -2856,6 +2910,7 @@ class ControlFlowContextCheckTest(test.TestCase): # Accessing a tensor from a cond context from the other branch's cond # context is OK (although dangerous). cond_tensor = [] + def branch_fn(): if not cond_tensor: cond_tensor.append(constant_op.constant(1)) @@ -2892,12 +2947,13 @@ class ControlFlowContextCheckTest(test.TestCase): while_tensor = self._getWhileTensor() return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + while_tensor, [0]) + with self.assertRaisesRegexp( ValueError, "Cannot use 'cond/while_1/add' as input to 'cond/while/Const_1' because" " they are in different while loops. See info log for more details."): - control_flow_ops.cond(math_ops.less(1, 2), true_fn, - lambda: constant_op.constant(0)) + control_flow_ops.cond( + math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0)) @test_util.with_c_api @@ -3005,11 +3061,13 @@ class AssertTest(test.TestCase): sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata) guarded_nodestat_names = [ n.node_name - for d in guarded_metadata.step_stats.dev_stats for n in d.node_stats + for d in guarded_metadata.step_stats.dev_stats + for n in d.node_stats ] unguarded_nodestat_names = [ n.node_name - for d in unguarded_metadata.step_stats.dev_stats for n in d.node_stats + for d in unguarded_metadata.step_stats.dev_stats + for n in d.node_stats ] guarded_memcpy_nodestat_names = [ n for n in guarded_nodestat_names if "MEMCPYDtoH" in n @@ -3066,6 +3124,7 @@ class WhileOpBenchmark(test.Benchmark): Returns: The duration of the run in seconds. """ + def loop_body(i, x): with ops.device("/gpu:0"): # Always put loop body on GPU. @@ -3107,7 +3166,7 @@ class WhileOpBenchmark(test.Benchmark): start_time = time.time() for _ in xrange(num_iters): sess.run(r) - return (time.time() - start_time)/num_iters + return (time.time() - start_time) / num_iters def benchmarkWhileOpCrossDevicePlacement(self): iters = 10 @@ -3154,23 +3213,20 @@ class EagerTest(test.TestCase): def testWhileLoop(self): with context.eager_mode(): tensor = constant_op.constant([1, 2, 3, 4, 5]) - self.assertAllEqual(isum(tensor).numpy(), - [46, 47, 48, 49, 50]) + self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50]) def testWhileLoopWithMaxIterations(self): with context.eager_mode(): tensor = constant_op.constant([1, 2, 3, 4, 5]) - self.assertAllEqual(isum(tensor, maximum_iterations=3).numpy(), - [1+3, 2+3, 3+3, 4+3, 5+3]) + self.assertAllEqual( + isum(tensor, maximum_iterations=3).numpy(), + [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3]) def testWhileWithMaximumIterationsAndSingleArgument(self): with context.eager_mode(): tensor = constant_op.constant(0) r = control_flow_ops.while_loop( - lambda i: i < 3, - lambda i: i + 1, - [tensor], - maximum_iterations=1) + lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1) self.assertEqual(1, r.numpy()) def testWithDependencies(self): @@ -3197,8 +3253,8 @@ class EagerTest(test.TestCase): f2 = lambda: constant_op.constant(23) f3 = lambda: constant_op.constant(-1) - r1 = control_flow_ops.case([(x < y, f1), (x > z, f2)], - default=f3, exclusive=True) + r1 = control_flow_ops.case( + [(x < y, f1), (x > z, f2)], default=f3, exclusive=True) self.assertAllEqual(r1.numpy(), 17) diff --git a/tensorflow/python/kernel_tests/control_flow_util_test.py b/tensorflow/python/kernel_tests/control_flow_util_test.py index 39e96f74b0461da0cf499e303b30a4a41aae4899..23185eaeece0d56fd83ecdf9e02c778712420465 100644 --- a/tensorflow/python/kernel_tests/control_flow_util_test.py +++ b/tensorflow/python/kernel_tests/control_flow_util_test.py @@ -41,17 +41,17 @@ class ControlFlowUtilTest(test.TestCase): self.assertFalse(control_flow_util.IsSwitch(test_ops.int_output().op)) def testIsLoopEnter(self): - enter = gen_control_flow_ops.enter(1, frame_name="name").op + enter = gen_control_flow_ops._enter(1, frame_name="name").op self.assertTrue(control_flow_util.IsLoopEnter(enter)) self.assertFalse(control_flow_util.IsLoopConstantEnter(enter)) - ref_enter = gen_control_flow_ops.ref_enter(test_ops.ref_output(), - frame_name="name").op + ref_enter = gen_control_flow_ops._ref_enter(test_ops.ref_output(), + frame_name="name").op self.assertTrue(control_flow_util.IsLoopEnter(ref_enter)) self.assertFalse(control_flow_util.IsLoopConstantEnter(ref_enter)) - const_enter = gen_control_flow_ops.enter(1, frame_name="name", - is_constant=True).op + const_enter = gen_control_flow_ops._enter(1, frame_name="name", + is_constant=True).op self.assertTrue(control_flow_util.IsLoopEnter(const_enter)) self.assertTrue(control_flow_util.IsLoopConstantEnter(const_enter)) diff --git a/tensorflow/python/kernel_tests/conv2d_transpose_test.py b/tensorflow/python/kernel_tests/conv2d_transpose_test.py index 7d0bc54b6993daff0298f9d76e9e67dfcbfa5711..b692d3da609fd97a55b8f5fce3334b8e9d97c827 100644 --- a/tensorflow/python/kernel_tests/conv2d_transpose_test.py +++ b/tensorflow/python/kernel_tests/conv2d_transpose_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.python.client import device_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -175,7 +174,7 @@ class Conv2DTransposeTest(test.TestCase): self.assertLess(err, err_tolerance) def testConv2DTransposeSingleStrideNCHW(self): - # `NCHW` data fomat is only supported for CUDA device. + # `NCHW` data format is only supported for CUDA device. if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): strides = [1, 1, 1, 1] @@ -210,7 +209,7 @@ class Conv2DTransposeTest(test.TestCase): self.assertAllClose(target, value[n, k, h, w]) def testConv2DTransposeSameNCHW(self): - # `NCHW` data fomat is only supported for CUDA device. + # `NCHW` data format is only supported for CUDA device. if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): strides = [1, 1, 2, 2] @@ -246,7 +245,7 @@ class Conv2DTransposeTest(test.TestCase): self.assertAllClose(target, value[n, k, h, w]) def testConv2DTransposeValidNCHW(self): - # `NCHW` data fomat is only supported for CUDA device. + # `NCHW` data format is only supported for CUDA device. if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): strides = [1, 1, 2, 2] diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 3e9bd3dade6d08835780362cd73f5f01368e83ac..27857989167ecd11c33286d7bb6cb068edd12831 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -24,6 +24,7 @@ import time import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import layers from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op @@ -301,25 +302,20 @@ class Conv2DTest(test.TestCase): padding, dilations): expected_results = [] computed_results = [] - default_dilations = (dilations[0] == 1 and dilations[1] == 1) for data_format, use_gpu in GetTestConfigs(): - # If any dilation rate is larger than 1, only do test on the GPU - # because we currently do not have a CPU implementation for arbitrary - # dilation rates. - if default_dilations or use_gpu: - expected, computed = self._ComputeReferenceDilatedConv( - tensor_in_sizes, filter_in_sizes, strides, dilations, padding, - data_format, use_gpu) - expected_results.append(expected) - computed_results.append(computed) - tolerance = 1e-2 if use_gpu else 1e-5 - expected_values = self.evaluate(expected_results) - computed_values = self.evaluate(computed_results) - for e_value, c_value in zip(expected_values, computed_values): - print("expected = ", e_value) - print("actual = ", c_value) - self.assertAllClose( - e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4) + expected, computed = self._ComputeReferenceDilatedConv( + tensor_in_sizes, filter_in_sizes, strides, dilations, padding, + data_format, use_gpu) + expected_results.append(expected) + computed_results.append(computed) + tolerance = 1e-2 if use_gpu else 1e-5 + expected_values = self.evaluate(expected_results) + computed_values = self.evaluate(computed_results) + for e_value, c_value in zip(expected_values, computed_values): + print("expected = ", e_value) + print("actual = ", c_value) + self.assertAllClose( + e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4) def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding, expected): @@ -364,13 +360,12 @@ class Conv2DTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Filter2x1Dilation(self): - if test.is_gpu_available(cuda_only=True): - self._VerifyDilatedConvValues( - tensor_in_sizes=[1, 4, 4, 1], - filter_in_sizes=[2, 2, 1, 1], - strides=[1, 1], - dilations=[2, 1], - padding="VALID") + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 4, 4, 1], + filter_in_sizes=[2, 2, 1, 1], + strides=[1, 1], + dilations=[2, 1], + padding="VALID") @test_util.run_in_graph_and_eager_modes() def testConv2DEmpty(self): @@ -384,13 +379,12 @@ class Conv2DTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testConv2DEmptyDilation(self): - if test.is_gpu_available(cuda_only=True): - self._VerifyDilatedConvValues( - tensor_in_sizes=[0, 2, 3, 3], - filter_in_sizes=[1, 1, 3, 3], - strides=[1, 1], - dilations=[2, 1], - padding="VALID") + self._VerifyDilatedConvValues( + tensor_in_sizes=[0, 2, 3, 3], + filter_in_sizes=[1, 1, 3, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID") @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Filter(self): @@ -405,13 +399,12 @@ class Conv2DTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testConv2D2x2FilterDilation(self): - if test.is_gpu_available(cuda_only=True): - self._VerifyDilatedConvValues( - tensor_in_sizes=[1, 2, 3, 3], - filter_in_sizes=[2, 2, 3, 3], - strides=[1, 1], - dilations=[1, 2], - padding="VALID") + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[2, 2, 3, 3], + strides=[1, 1], + dilations=[1, 2], + padding="VALID") @test_util.run_in_graph_and_eager_modes() def testConv2D1x2Filter(self): @@ -429,13 +422,12 @@ class Conv2DTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testConv2D1x2FilterDilation(self): - if test.is_gpu_available(cuda_only=True): - self._VerifyDilatedConvValues( - tensor_in_sizes=[1, 2, 3, 3], - filter_in_sizes=[1, 2, 3, 3], - strides=[1, 1], - dilations=[2, 1], - padding="VALID") + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[1, 2, 3, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID") @test_util.run_in_graph_and_eager_modes() def testConv2D2x2FilterStride2(self): @@ -511,15 +503,14 @@ class Conv2DTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testConv2DKernelSizeMatchesInputSizeDilation(self): - if test.is_gpu_available(cuda_only=True): - self._VerifyDilatedConvValues( - tensor_in_sizes=[1, 3, 3, 1], - filter_in_sizes=[2, 2, 1, 2], - strides=[1, 1], - dilations=[2, 2], - padding="VALID") + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 3, 3, 1], + filter_in_sizes=[2, 2, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID") - # TODO this currently fails. + # TODO(yzhwang): this currently fails. # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1], # filter_in_sizes=[2, 2, 1, 1], # strides=[4, 4], padding="SAME", @@ -1537,21 +1528,6 @@ class Conv2DTest(test.TestCase): use_gpu=False) self.evaluate(conv) - def testCPUConv2DDilatedUnimplemented(self): - with self.test_session(use_gpu=False): - with self.assertRaisesRegexp(errors_impl.UnimplementedError, - "dilated rate of 1 for now"): - conv = self._SetupValuesForDevice( - tensor_in_sizes=[1, 4, 4, 1], - filter_in_sizes=[2, 2, 1, 1], - dilations=[2, 1], - strides=[1, 1], - padding="VALID", - data_format="NHWC", - dtype=dtypes.float32, - use_gpu=False) - self.evaluate(conv) - class DepthwiseConv2DTest(test.TestCase): @@ -1886,7 +1862,7 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding, def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding): def Test(self): - if test.is_gpu_available(cuda_only=True) and stride == 1: + if stride == 1: tf_logging.info("Testing InceptionFwd with dilations %s", (input_size, filter_size, stride, padding)) self._VerifyDilatedConvValues( diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index a91917b27faf46710d3f494b76929f4c7b9e9eec..0d9b46c30dbbed20dd940e0427fbf6f6d5415106 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -71,6 +71,7 @@ def _sparsify(x, thresh=0.5, index_dtype=np.int64): return sparse_tensor.SparseTensor( indices=x_indices, values=x_values, dense_shape=x_shape), x_values + def _default_tolerance(dtype): """Returns a sensible default tolerance for comparing results of a given type""" @@ -81,7 +82,7 @@ def _default_tolerance(dtype): elif dtype in (np.float64, np.complex128): return 1e-5 else: - return None # Fail fast for unexpected types + return None # Fail fast for unexpected types class UnaryOpTest(test.TestCase): @@ -233,10 +234,10 @@ class UnaryOpTest(test.TestCase): self._compareBoth(k, np.arccos, math_ops.acos) self._compareBoth(x, np.arctan, math_ops.atan) self._compareBoth(x, np.tan, math_ops.tan) - self._compareBoth( - y, - np.vectorize(self._replace_domain_error_with_inf(math.lgamma)), - math_ops.lgamma) + self._compareBoth(y, + np.vectorize( + self._replace_domain_error_with_inf(math.lgamma)), + math_ops.lgamma) self._compareBoth(x, np.vectorize(math.erf), math_ops.erf) self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc) @@ -298,8 +299,8 @@ class UnaryOpTest(test.TestCase): w = x - x.min() + 1.02 # all greater than 1 y = (x + .5).astype(np.float64) # no zero z = (x + 15.5).astype(np.float64) # all positive - k = np.arange(-0.90, 0.90, 0.35).reshape(1, 3, 2).astype( - np.float64) # between -1 and 1 + k = np.arange(-0.90, 0.90, + 0.35).reshape(1, 3, 2).astype(np.float64) # between -1 and 1 self._compareBoth(x, np.abs, math_ops.abs) self._compareBoth(x, np.abs, _ABS) self._compareBoth(x, np.negative, math_ops.negative) @@ -322,10 +323,10 @@ class UnaryOpTest(test.TestCase): self._compareBoth(y, np.sign, math_ops.sign) self._compareBoth(x, np.sin, math_ops.sin) self._compareBoth(x, np.cos, math_ops.cos) - self._compareBoth( - y, - np.vectorize(self._replace_domain_error_with_inf(math.lgamma)), - math_ops.lgamma) + self._compareBoth(y, + np.vectorize( + self._replace_domain_error_with_inf(math.lgamma)), + math_ops.lgamma) self._compareBoth(x, np.vectorize(math.erf), math_ops.erf) self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc) self._compareBoth(x, np.arctan, math_ops.atan) @@ -362,10 +363,10 @@ class UnaryOpTest(test.TestCase): self._compareBoth(y, np.sign, math_ops.sign) self._compareBoth(x, np.sin, math_ops.sin) self._compareBoth(x, np.cos, math_ops.cos) - self._compareBoth( - y, - np.vectorize(self._replace_domain_error_with_inf(math.lgamma)), - math_ops.lgamma) + self._compareBoth(y, + np.vectorize( + self._replace_domain_error_with_inf(math.lgamma)), + math_ops.lgamma) self._compareBoth(x, np.vectorize(math.erf), math_ops.erf) self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc) @@ -406,8 +407,8 @@ class UnaryOpTest(test.TestCase): self._compareBothSparse(x, np.sign, math_ops.sign) def testComplex64Basic(self): - x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, - 2).astype(np.complex64) + x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype( + np.complex64) y = x + np.complex(0.5, 0.5) # no zeros self._compareBoth(x, np.abs, math_ops.abs) self._compareBoth(x, np.abs, _ABS) @@ -450,8 +451,8 @@ class UnaryOpTest(test.TestCase): self._compareBothSparse(y, complex_sign, math_ops.sign) def testComplex128Basic(self): - x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, - 2).astype(np.complex128) + x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype( + np.complex128) y = x + np.complex(0.5, 0.5) # no zeros self._compareBoth(x, np.abs, math_ops.abs) self._compareBoth(x, np.abs, _ABS) @@ -805,10 +806,10 @@ class BinaryOpTest(test.TestCase): self._compareBoth(x, y, np.mod, _MOD) def testComplex64Basic(self): - x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape( - 1, 3, 2).astype(np.complex64) - y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape( - 1, 3, 2).astype(np.complex64) + x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype( + np.complex64) + y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype( + np.complex64) self._compareBoth(x, y, np.add, math_ops.add) self._compareBoth(x, y, np.subtract, math_ops.subtract) self._compareBoth(x, y, np.multiply, math_ops.multiply) @@ -819,10 +820,10 @@ class BinaryOpTest(test.TestCase): self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV) def testComplex128Basic(self): - x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape( - 1, 3, 2).astype(np.complex128) - y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape( - 1, 3, 2).astype(np.complex128) + x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype( + np.complex128) + y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype( + np.complex128) self._compareBoth(x, y, np.add, math_ops.add) self._compareBoth(x, y, np.subtract, math_ops.subtract) self._compareBoth(x, y, np.multiply, math_ops.multiply) @@ -1127,8 +1128,8 @@ class BinaryOpTest(test.TestCase): def testMismatchedDimensions(self): for func in [ - math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.div, - _ADD, _SUB, _MUL, _TRUEDIV, _FLOORDIV + math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.div, _ADD, + _SUB, _MUL, _TRUEDIV, _FLOORDIV ]: with self.assertRaisesWithPredicateMatch( ValueError, lambda e: "Dimensions must" in str(e)): @@ -1161,8 +1162,8 @@ class BinaryOpTest(test.TestCase): (1.2345, float("inf")), (1.2345, -float("inf")), (-4.321, float("inf")), (-4.125, -float("inf")), (float("inf"), float("inf")), (float("inf"), -float("inf")), - (-float("inf"), float("inf")), (-float("inf"), - -float("inf"))) + (-float("inf"), float("inf")), + (-float("inf"), -float("inf"))) for dtype in np.float32, np.float64: x1 = np.array(x1l).astype(dtype) x2 = np.array(x2l).astype(dtype) @@ -1213,22 +1214,22 @@ class ComparisonOpTest(test.TestCase): for x in data: for y in data: self.assertEqual(self._compareScalar(math_ops.less, x, y, t), x < y) - self.assertEqual(self._compareScalar(math_ops.less_equal, x, y, t), - x <= y) - self.assertEqual(self._compareScalar(math_ops.greater, x, y, t), - x > y) + self.assertEqual( + self._compareScalar(math_ops.less_equal, x, y, t), x <= y) + self.assertEqual( + self._compareScalar(math_ops.greater, x, y, t), x > y) self.assertEqual( self._compareScalar(math_ops.greater_equal, x, y, t), x >= y) self.assertEqual(self._compareScalar(math_ops.equal, x, y, t), x == y) - self.assertEqual(self._compareScalar(math_ops.not_equal, x, y, t), - x != y) + self.assertEqual( + self._compareScalar(math_ops.not_equal, x, y, t), x != y) data = [-1, 0, 1, -1j, 1j, 1 + 1j, 1 - 1j] for t in [np.complex64, np.complex128]: for x in data: for y in data: self.assertEqual(self._compareScalar(math_ops.equal, x, y, t), x == y) - self.assertEqual(self._compareScalar(math_ops.not_equal, x, y, t), - x != y) + self.assertEqual( + self._compareScalar(math_ops.not_equal, x, y, t), x != y) def _compare(self, x, y, np_func, tf_func): np_ans = np_func(x, y) @@ -1311,8 +1312,8 @@ class ComparisonOpTest(test.TestCase): self._testBCastByFunc(np.equal, math_ops.equal, include_complex=True) def testBCastNotEqual(self): - self._testBCastByFunc(np.not_equal, math_ops.not_equal, - include_complex=True) + self._testBCastByFunc( + np.not_equal, math_ops.not_equal, include_complex=True) def testShapeMismatch(self): dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64] @@ -1771,9 +1772,8 @@ class MathOpsOverloadTest(test.TestCase): def _compareUnary(self, x, dtype, np_func, tf_func): np_ans = np_func(x).astype(dtype.as_numpy_dtype) with self.test_session(use_gpu=False): - self.assertAllClose( - np_ans, tf_func(ops.convert_to_tensor( - x, dtype=dtype)).eval()) + self.assertAllClose(np_ans, + tf_func(ops.convert_to_tensor(x, dtype=dtype)).eval()) def testOverload(self): dtypes = [ @@ -1795,8 +1795,8 @@ class MathOpsOverloadTest(test.TestCase): ] for dtype in dtypes: for np_func, tf_func in funcs: - if dtype in (dtypes_lib.complex64, dtypes_lib.complex128 - ) and tf_func == _FLOORDIV: + if dtype in (dtypes_lib.complex64, + dtypes_lib.complex128) and tf_func == _FLOORDIV: continue # floordiv makes no sense for complex self._compareBinary(10, 5, dtype, np_func, tf_func) # Mod only works for int32 and int64. @@ -2008,7 +2008,8 @@ class ComplexMakeRealImagTest(test.TestCase): # self._compareAngle(cplx, use_gpu=True) def testRealReal(self): - for dtype in dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.float32, dtypes_lib.float64: + for dtype in (dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.float32, + dtypes_lib.float64): x = array_ops.placeholder(dtype) y = math_ops.real(x) self.assertEqual(x, y) @@ -2037,15 +2038,16 @@ class ComplexMakeRealImagTest(test.TestCase): self._compareConj(cplx, use_gpu=True) def testConjReal(self): - for dtype in dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.float16, dtypes_lib.float32, dtypes_lib.float64: + for dtype in (dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.float16, + dtypes_lib.float32, dtypes_lib.float64): x = array_ops.placeholder(dtype) y = math_ops.conj(x) self.assertEqual(x, y) def testConjString(self): x = array_ops.placeholder(dtypes_lib.string) - with self.assertRaisesRegexp( - TypeError, r"Expected numeric or variant tensor"): + with self.assertRaisesRegexp(TypeError, + r"Expected numeric or variant tensor"): math_ops.conj(x) def _compareGradient(self, x): @@ -2060,8 +2062,9 @@ class ComplexMakeRealImagTest(test.TestCase): real, imag = array_ops.reshape(real, [-1]), array_ops.reshape(imag, [-1]) cplx = math_ops.complex(real, imag) cplx = math_ops.conj(cplx) - loss = math_ops.reduce_sum(math_ops.square(math_ops.real( - cplx))) + math_ops.reduce_sum(math_ops.square(math_ops.imag(cplx))) + loss = math_ops.reduce_sum(math_ops.square( + math_ops.real(cplx))) + math_ops.reduce_sum( + math_ops.square(math_ops.imag(cplx))) epsilon = 1e-3 jacob_t, jacob_n = gradient_checker.compute_gradient( inx, list(x.shape), loss, [1], x_init_value=x, delta=epsilon) @@ -2125,8 +2128,8 @@ class AccumulateTest(test.TestCase): np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20) ] random_tensors = [ - ops.convert_to_tensor( - x, dtype=dtypes_lib.float32) for x in random_arrays + ops.convert_to_tensor(x, dtype=dtypes_lib.float32) + for x in random_arrays ] tf_val = math_ops.accumulate_n(random_tensors) np_val = random_arrays[0] diff --git a/tensorflow/python/kernel_tests/decode_bmp_op_test.py b/tensorflow/python/kernel_tests/decode_bmp_op_test.py index c67c26b7be0777587eb6d7c49119ad6cd2e22953..35f8f76991a679e4164da4c63bacbe79fb5cd2c2 100644 --- a/tensorflow/python/kernel_tests/decode_bmp_op_test.py +++ b/tensorflow/python/kernel_tests/decode_bmp_op_test.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops from tensorflow.python.ops import image_ops from tensorflow.python.platform import test diff --git a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py index ead55cd03b656a18d622b9d35c1b94f9cf2f5107..510daf79dc4252c3e2943e2ba23c1012370bf456 100644 --- a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py +++ b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import time +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py index 0c7025f54e672bb09e601715a58864673a670d12..122a9ed46967fc9c02c59ea3047216cb73a72293 100644 --- a/tensorflow/python/kernel_tests/decode_raw_op_test.py +++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import numpy as np -import sys from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py index 6cfa9b37fe0e40f4f0e5e2ad2686819e5f6d4f12..0825d8fc6bea008532fd7428236dfb569f2a471e 100644 --- a/tensorflow/python/kernel_tests/diag_op_test.py +++ b/tensorflow/python/kernel_tests/diag_op_test.py @@ -84,11 +84,8 @@ class MatrixSetDiagTest(test.TestCase): def testSquare(self): with self.test_session(use_gpu=True): v = np.array([1.0, 2.0, 3.0]) - mat = np.array([[0.0, 1.0, 0.0], - [1.0, 0.0, 1.0], - [1.0, 1.0, 1.0]]) - mat_set_diag = np.array([[1.0, 1.0, 0.0], - [1.0, 2.0, 1.0], + mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]) + mat_set_diag = np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]]) output = array_ops.matrix_set_diag(mat, v) self.assertEqual((3, 3), output.get_shape()) @@ -135,19 +132,12 @@ class MatrixSetDiagTest(test.TestCase): def testRectangularBatch(self): with self.test_session(use_gpu=True): - v_batch = np.array([[-1.0, -2.0], - [-4.0, -5.0]]) - mat_batch = np.array( - [[[1.0, 0.0, 3.0], - [0.0, 2.0, 0.0]], - [[4.0, 0.0, 4.0], - [0.0, 5.0, 0.0]]]) - - mat_set_diag_batch = np.array( - [[[-1.0, 0.0, 3.0], - [0.0, -2.0, 0.0]], - [[-4.0, 0.0, 4.0], - [0.0, -5.0, 0.0]]]) + v_batch = np.array([[-1.0, -2.0], [-4.0, -5.0]]) + mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]], + [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]]) + + mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]], + [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]]) output = array_ops.matrix_set_diag(mat_batch, v_batch) self.assertEqual((2, 2, 3), output.get_shape()) self.assertAllEqual(mat_set_diag_batch, output.eval()) @@ -178,10 +168,14 @@ class MatrixSetDiagTest(test.TestCase): np.random.rand(*diag_shape), dtype=dtypes_lib.float32) y = array_ops.matrix_set_diag(x, x_diag) error_x = gradient_checker.compute_gradient_error( - x, x.get_shape().as_list(), y, y.get_shape().as_list()) + x, + x.get_shape().as_list(), y, + y.get_shape().as_list()) self.assertLess(error_x, 1e-4) error_x_diag = gradient_checker.compute_gradient_error( - x_diag, x_diag.get_shape().as_list(), y, y.get_shape().as_list()) + x_diag, + x_diag.get_shape().as_list(), y, + y.get_shape().as_list()) self.assertLess(error_x_diag, 1e-4) def testGradWithNoShapeInformation(self): @@ -192,12 +186,13 @@ class MatrixSetDiagTest(test.TestCase): output = array_ops.matrix_set_diag(mat, v) grads = gradients_impl.gradients(output, [mat, v], grad_ys=grad_input) grad_input_val = np.random.rand(3, 3).astype(np.float32) - grad_vals = sess.run(grads, - feed_dict={ - v: 2 * np.ones(3), - mat: np.ones((3, 3)), - grad_input: grad_input_val - }) + grad_vals = sess.run( + grads, + feed_dict={ + v: 2 * np.ones(3), + mat: np.ones((3, 3)), + grad_input: grad_input_val + }) self.assertAllEqual(np.diag(grad_input_val), grad_vals[1]) self.assertAllEqual(grad_input_val - np.diag(np.diag(grad_input_val)), grad_vals[0]) @@ -242,13 +237,9 @@ class MatrixDiagPartTest(test.TestCase): def testRectangularBatch(self): with self.test_session(use_gpu=True): - v_batch = np.array([[1.0, 2.0], - [4.0, 5.0]]) - mat_batch = np.array( - [[[1.0, 0.0, 0.0], - [0.0, 2.0, 0.0]], - [[4.0, 0.0, 0.0], - [0.0, 5.0, 0.0]]]) + v_batch = np.array([[1.0, 2.0], [4.0, 5.0]]) + mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]], + [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0]]]) self.assertEqual(mat_batch.shape, (2, 2, 3)) mat_batch_diag = array_ops.matrix_diag_part(mat_batch) self.assertEqual((2, 2), mat_batch_diag.get_shape()) @@ -301,19 +292,13 @@ class DiagTest(test.TestCase): def testRankOneIntTensor(self): x = np.array([1, 2, 3]) - expected_ans = np.array( - [[1, 0, 0], - [0, 2, 0], - [0, 0, 3]]) + expected_ans = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]]) self.diagOp(x, np.int32, expected_ans) self.diagOp(x, np.int64, expected_ans) def testRankOneFloatTensor(self): x = np.array([1.1, 2.2, 3.3]) - expected_ans = np.array( - [[1.1, 0, 0], - [0, 2.2, 0], - [0, 0, 3.3]]) + expected_ans = np.array([[1.1, 0, 0], [0, 2.2, 0], [0, 0, 3.3]]) self.diagOp(x, np.float32, expected_ans) self.diagOp(x, np.float64, expected_ans) @@ -321,123 +306,105 @@ class DiagTest(test.TestCase): for dtype in [np.complex64, np.complex128]: x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype=dtype) expected_ans = np.array( - [[1.1 + 1.1j, 0 + 0j, 0 + 0j], - [0 + 0j, 2.2 + 2.2j, 0 + 0j], - [0 + 0j, 0 + 0j, 3.3 + 3.3j]], dtype=dtype) + [[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 2.2 + 2.2j, 0 + 0j], + [0 + 0j, 0 + 0j, 3.3 + 3.3j]], + dtype=dtype) self.diagOp(x, dtype, expected_ans) def testRankTwoIntTensor(self): x = np.array([[1, 2, 3], [4, 5, 6]]) - expected_ans = np.array( - [[[[1, 0, 0], [0, 0, 0]], - [[0, 2, 0], [0, 0, 0]], - [[0, 0, 3], [0, 0, 0]]], - [[[0, 0, 0], [4, 0, 0]], - [[0, 0, 0], [0, 5, 0]], - [[0, 0, 0], [0, 0, 6]]]]) + expected_ans = np.array([[[[1, 0, 0], [0, 0, 0]], [[0, 2, 0], [0, 0, 0]], + [[0, 0, 3], [0, 0, 0]]], + [[[0, 0, 0], [4, 0, 0]], [[0, 0, 0], [0, 5, 0]], + [[0, 0, 0], [0, 0, 6]]]]) self.diagOp(x, np.int32, expected_ans) self.diagOp(x, np.int64, expected_ans) def testRankTwoFloatTensor(self): x = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]]) expected_ans = np.array( - [[[[1.1, 0, 0], [0, 0, 0]], - [[0, 2.2, 0], [0, 0, 0]], - [[0, 0, 3.3], [0, 0, 0]]], - [[[0, 0, 0], [4.4, 0, 0]], - [[0, 0, 0], [0, 5.5, 0]], - [[0, 0, 0], [0, 0, 6.6]]]]) + [[[[1.1, 0, 0], [0, 0, 0]], [[0, 2.2, 0], [0, 0, 0]], + [[0, 0, 3.3], [0, 0, 0]]], [[[0, 0, 0], [4.4, 0, 0]], + [[0, 0, 0], [0, 5.5, 0]], [[0, 0, 0], + [0, 0, 6.6]]]]) self.diagOp(x, np.float32, expected_ans) self.diagOp(x, np.float64, expected_ans) def testRankTwoComplexTensor(self): for dtype in [np.complex64, np.complex128]: - x = np.array([[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], - [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]], dtype=dtype) + x = np.array( + [[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], + [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]], + dtype=dtype) expected_ans = np.array( - [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], - [[0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], - [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]], - [[[0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]], - [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]], - [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]], - dtype=dtype) + [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], [ + [0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j] + ], [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]], [[ + [0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j] + ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j] + ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]], + dtype=dtype) self.diagOp(x, dtype, expected_ans) def testRankThreeFloatTensor(self): - x = np.array([[[1.1, 2.2], [3.3, 4.4]], - [[5.5, 6.6], [7.7, 8.8]]]) - expected_ans = np.array( - [[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]], - [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]], - [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]], - [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]], - [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]], - [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]], - [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]], - [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]]) + x = np.array([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]]) + expected_ans = np.array([[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]], + [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]], + [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]], + [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]], + [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]], + [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]], + [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]], + [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]]) self.diagOp(x, np.float32, expected_ans) self.diagOp(x, np.float64, expected_ans) def testRankThreeComplexTensor(self): for dtype in [np.complex64, np.complex128]: - x = np.array([[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]], - [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]], - dtype=dtype) + x = np.array( + [[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]], + [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]], + dtype=dtype) expected_ans = np.array( - [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]], - [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]], - [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]], - [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]], - [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]], - [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]], - [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]], - [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]]], - [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], - [[5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]]], - [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], - [[0 + 0j, 6.6 + 6.6j], [0 + 0j, 0 + 0j]]]], - [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], - [[0 + 0j, 0 + 0j], [7.7 + 7.7j, 0 + 0j]]], - [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], - [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]], + [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [ + 0 + 0j, 0 + 0j + ]]], [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [ + 0 + 0j, 0 + 0j + ]]]], [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]], [[0 + 0j, 0 + 0j], [ + 0 + 0j, 0 + 0j + ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]], [[0 + 0j, 0 + 0j], [ + 0 + 0j, 0 + 0j + ]]]]], [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [ + [5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j] + ]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 6.6 + 6.6j], [ + 0 + 0j, 0 + 0j + ]]]], [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [ + 7.7 + 7.7j, 0 + 0j + ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], + [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]], dtype=dtype) self.diagOp(x, dtype, expected_ans) def testRankFourNumberTensor(self): for dtype in [np.float32, np.float64, np.int64, np.int32]: # Input with shape [2, 1, 2, 3] - x = np.array([[[[ 1, 2, 3], - [ 4, 5, 6]]], - [[[ 7, 8, 9], - [10, 11, 12]]]], dtype=dtype) + x = np.array( + [[[[1, 2, 3], [4, 5, 6]]], [[[7, 8, 9], [10, 11, 12]]]], dtype=dtype) # Output with shape [2, 1, 2, 3, 2, 1, 2, 3] expected_ans = np.array( - [[[[[[[[1, 0, 0], [0, 0, 0]]], - [[[0, 0, 0], [0, 0, 0]]]], - [[[[0, 2, 0], [0, 0, 0]]], - [[[0, 0, 0], [0, 0, 0]]]], - [[[[0, 0, 3], [0, 0, 0]]], - [[[0, 0, 0], [0, 0, 0]]]]], - [[[[[0, 0, 0], [4, 0, 0]]], - [[[0, 0, 0], [0, 0, 0]]]], - [[[[0, 0, 0], [0, 5, 0]]], - [[[0, 0, 0], [0, 0, 0]]]], - [[[[0, 0, 0], [0, 0, 6]]], - [[[0, 0, 0], [0, 0, 0]]]]]]], - - [[[[[[[0, 0, 0], [0, 0, 0]]], - [[[7, 0, 0], [0, 0, 0]]]], - [[[[0, 0, 0], [0, 0, 0]]], - [[[0, 8, 0], [0, 0, 0]]]], - [[[[0, 0, 0], [0, 0, 0]]], - [[[0, 0, 9], [0, 0, 0]]]]], - [[[[[0, 0, 0], [0, 0, 0]]], - [[[0, 0, 0], [10, 0, 0]]]], - [[[[0, 0, 0], [0, 0, 0]]], - [[[0, 0, 0], [0, 11, 0]]]], - [[[[0, 0, 0], [0, 0, 0]]], - [[[0, 0, 0], [0, 0, 12]]]]]]]], dtype=dtype) + [[[[[[[[1, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [ + [[[0, 2, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]] + ], [[[[0, 0, 3], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]]], [[ + [[[0, 0, 0], [4, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]] + ], [[[[0, 0, 0], [0, 5, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [ + [[[0, 0, 0], [0, 0, 6]]], [[[0, 0, 0], [0, 0, 0]]] + ]]]], [[[[[[[0, 0, 0], [0, 0, 0]]], [[[7, 0, 0], [0, 0, 0]]]], [ + [[[0, 0, 0], [0, 0, 0]]], [[[0, 8, 0], [0, 0, 0]]] + ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 9], [0, 0, 0]]]]], [[ + [[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [10, 0, 0]]] + ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 11, 0]]] + ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 12]]]]]]]], + dtype=dtype) self.diagOp(x, dtype, expected_ans) def testInvalidRank(self): @@ -537,7 +504,9 @@ class DiagGradOpTest(test.TestCase): x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype) y = array_ops.diag(x1) error = gradient_checker.compute_gradient_error( - x1, x1.get_shape().as_list(), y, y.get_shape().as_list()) + x1, + x1.get_shape().as_list(), y, + y.get_shape().as_list()) tf_logging.info("error = %f", error) self.assertLess(error, 1e-4) @@ -555,7 +524,9 @@ class DiagGradPartOpTest(test.TestCase): x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype) y = array_ops.diag_part(x1) error = gradient_checker.compute_gradient_error( - x1, x1.get_shape().as_list(), y, y.get_shape().as_list()) + x1, + x1.get_shape().as_list(), y, + y.get_shape().as_list()) tf_logging.info("error = %f", error) self.assertLess(error, 1e-4) diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py index a269d722737866fa5e6ae9feee919be0db71bcf1..09812db8166567403dc966ac9cb4304be0740e50 100644 --- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py +++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py @@ -25,7 +25,6 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bernoulli from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.platform import test @@ -291,12 +290,6 @@ class BernoulliTest(test.TestCase): [np.sqrt(var(0.5)), np.sqrt(var(0.4))]], dtype=np.float32)) - def testBernoulliWithSigmoidProbs(self): - p = np.array([8.3, 4.2]) - dist = bernoulli.BernoulliWithSigmoidProbs(logits=p) - with self.test_session(): - self.assertAllClose(math_ops.sigmoid(p).eval(), dist.probs.eval()) - def testBernoulliBernoulliKL(self): with self.test_session() as sess: batch_size = 6 diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py index 91a451f033ffbb01d54c3dacce952b406564b7b4..ab5041a6eb477ce231acbd1e6041c354ee17409b 100644 --- a/tensorflow/python/kernel_tests/distributions/beta_test.py +++ b/tensorflow/python/kernel_tests/distributions/beta_test.py @@ -107,8 +107,10 @@ class BetaTest(test.TestCase): dist.prob([-1., 0.1, 0.5]).eval() with self.assertRaisesOpError("sample must be positive"): dist.prob([0., 0.1, 0.5]).eval() - with self.assertRaisesOpError("sample must be no larger than `1`"): + with self.assertRaisesOpError("sample must be less than `1`"): dist.prob([.1, .2, 1.2]).eval() + with self.assertRaisesOpError("sample must be less than `1`"): + dist.prob([.1, .2, 1.0]).eval() def testPdfTwoBatches(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py index cf723f5eec3c31c93d67fd6a34a21c8377b74c84..a4b30e4319527c6f3354ac83bf0e3a5114eb45e8 100644 --- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py @@ -48,8 +48,10 @@ class DynamicStitchTestBase(object): def testShapeInferenceForScalarWithNonConstantIndices(self): with self.test_session(use_gpu=True): - indices = [array_ops.placeholder(dtype=dtypes.int32), - constant_op.constant(1)] + indices = [ + array_ops.placeholder(dtype=dtypes.int32), + constant_op.constant(1) + ] data = [constant_op.constant(40), constant_op.constant(60)] for step in -1, 1: stitched_t = self.stitch_op(indices[::step], data) @@ -61,7 +63,8 @@ class DynamicStitchTestBase(object): def testSimpleOneDimensional(self): with self.test_session(use_gpu=True): indices = [ - constant_op.constant([0, 4, 7]), constant_op.constant([1, 6, 2, 3, 5]) + constant_op.constant([0, 4, 7]), + constant_op.constant([1, 6, 2, 3, 5]) ] data = [ constant_op.constant([0, 40, 70]), @@ -86,7 +89,8 @@ class DynamicStitchTestBase(object): def testSimpleTwoDimensional(self): with self.test_session(use_gpu=True): indices = [ - constant_op.constant([0, 4, 7]), constant_op.constant([1, 6]), + constant_op.constant([0, 4, 7]), + constant_op.constant([1, 6]), constant_op.constant([2, 3, 5]) ] data = [ @@ -104,7 +108,8 @@ class DynamicStitchTestBase(object): def testHigherRank(self): with self.test_session(use_gpu=True) as sess: indices = [ - constant_op.constant(6), constant_op.constant([4, 1]), + constant_op.constant(6), + constant_op.constant([4, 1]), constant_op.constant([[5, 2], [0, 3]]) ] data = [ @@ -127,7 +132,8 @@ class DynamicStitchTestBase(object): def testErrorIndicesMultiDimensional(self): indices = [ - constant_op.constant([0, 4, 7]), constant_op.constant([[1, 6, 2, 3, 5]]) + constant_op.constant([0, 4, 7]), + constant_op.constant([[1, 6, 2, 3, 5]]) ] data = [ constant_op.constant([[0, 40, 70]]), @@ -138,7 +144,8 @@ class DynamicStitchTestBase(object): def testErrorDataNumDimsMismatch(self): indices = [ - constant_op.constant([0, 4, 7]), constant_op.constant([1, 6, 2, 3, 5]) + constant_op.constant([0, 4, 7]), + constant_op.constant([1, 6, 2, 3, 5]) ] data = [ constant_op.constant([0, 40, 70]), @@ -149,7 +156,8 @@ class DynamicStitchTestBase(object): def testErrorDataDimSizeMismatch(self): indices = [ - constant_op.constant([0, 4, 5]), constant_op.constant([1, 6, 2, 3]) + constant_op.constant([0, 4, 5]), + constant_op.constant([1, 6, 2, 3]) ] data = [ constant_op.constant([[0], [40], [70]]), @@ -160,7 +168,8 @@ class DynamicStitchTestBase(object): def testErrorDataAndIndicesSizeMismatch(self): indices = [ - constant_op.constant([0, 4, 7]), constant_op.constant([1, 6, 2, 3, 5]) + constant_op.constant([0, 4, 7]), + constant_op.constant([1, 6, 2, 3, 5]) ] data = [ constant_op.constant([0, 40, 70]), @@ -235,13 +244,15 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase): def testHigherRankGPU(self): with self.test_session() as sess: indices = [ - constant_op.constant(6), constant_op.constant([4, 1]), + constant_op.constant(6), + constant_op.constant([4, 1]), constant_op.constant([[5, 2], [0, 3]]) ] data = [ constant_op.constant([61, 62], dtype=dtypes.float32), constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32), - constant_op.constant([[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32) + constant_op.constant( + [[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32) ] stitched_t = data_flow_ops.dynamic_stitch(indices, data) stitched_val = stitched_t.eval() diff --git a/tensorflow/python/kernel_tests/extract_image_patches_op_test.py b/tensorflow/python/kernel_tests/extract_image_patches_op_test.py index 5c7624f1f6be4da91ca74d4ef2ed81a21890b35c..6ea9f1badc3b8fac06fe6328f95714b93de97c0e 100644 --- a/tensorflow/python/kernel_tests/extract_image_patches_op_test.py +++ b/tensorflow/python/kernel_tests/extract_image_patches_op_test.py @@ -84,7 +84,7 @@ class ExtractImagePatches(test.TestCase): patches=patches) def testKsize2x2Stride1x1Rate1x1Valid(self): - """Test for 1x1 kernel .""" + """Test for 2x2 kernel with VALID padding.""" # [1, 2, 2, 1] image = [[[[1], [2]], [[3], [4]]]] # [1, 1, 1, 4] @@ -98,7 +98,7 @@ class ExtractImagePatches(test.TestCase): patches=patches) def testKsize2x2Stride1x1Rate1x1Same(self): - """Test for 1x1 kernel .""" + """Test for 2x2 kernel with SAME padding.""" # [1, 2, 2, 1] image = [[[[1], [2]], [[3], [4]]]] # [1, 2, 2, 4] @@ -111,6 +111,20 @@ class ExtractImagePatches(test.TestCase): padding="SAME", patches=patches) + def testKsize2x2Stride1x1Rate2x2Valid(self): + """Test for 2x2 kernel with 2x2 dilation.""" + # [1, 2, 2, 1] + image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32) + # [1, 2, 2, 4] + patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]], + [[4, 6, 12, 14], [5, 7, 13, 15]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[2, 2], + padding="VALID", + patches=patches) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index 748135440ec5e8ad387f910e1433f638abf2260a..ce73e7ad3e5f822363c697609dfa163b6f13751a 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import random -import re import time import numpy as np diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py index f91875c6f0c1a7bfa388ec1b1a58f06b65889c3e..61944f7e3197844d00cbc001459e48b50c9003b4 100644 --- a/tensorflow/python/kernel_tests/io_ops_test.py +++ b/tensorflow/python/kernel_tests/io_ops_test.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 4e18eaa4e8281c799e4669b2d6083c00bc1e2863..fd1b5bab6f5aa072c8821eb053bd8d39391be4d4 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -39,6 +39,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + shard_count = 5, tags = ["noasan"], # times out b/63678675 ) @@ -57,6 +58,7 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], + shard_count = 5, ) cuda_py_test( @@ -73,6 +75,7 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], + shard_count = 5, ) cuda_py_test( @@ -88,6 +91,7 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], + shard_count = 5, ) cuda_py_test( @@ -134,6 +138,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + shard_count = 5, ) filegroup( diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 81af3a0887d09a7736a145a5b3c99c9391691724..197dbf44afaea2cfaf5a1ffebb6ac0a6be09d165 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -953,14 +953,14 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): # Compute the expected loss 'manually'. total = np.zeros((batch_size,)) for b in range(batch_size): - for i in range(dims): - for j in range(dims): + for i in range(dims - 1): + for j in range(i + 1, dims): x = self._predictions[b, i].item() - self._predictions[b, j].item() y = self._labels[b, i].item() - self._labels[b, j].item() diff = (x - y) total[b] += (diff * diff) - self._expected_losses = np.divide(total, 9.0) + self._expected_losses = np.divide(total, 3.0) def testValueErrorThrownWhenWeightIsNone(self): with self.test_session(): @@ -1059,8 +1059,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): [[4, 8, 12], [1, 2, 3], [4, 5, 6]], [[8, 1, 3], [7, 8, 9], [10, 11, 12]], ]) - self._test_valid_weights( - labels, predictions, expected_loss=122.22222) + self._test_valid_weights(labels, predictions, expected_loss=137.5) def test3dWeightedScalar(self): labels = np.array([ @@ -1073,8 +1072,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): ]) weight = 3.0 self._test_valid_weights( - labels, predictions, expected_loss=weight * 122.22222, - weights=weight) + labels, predictions, expected_loss=weight * 137.5, weights=weight) def _test_invalid_weights( self, labels, predictions, weights=1.0): @@ -1124,7 +1122,9 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): ]) self._test_valid_weights( # TODO(ptucker): This doesn't look right. - labels, predictions, expected_loss=9 * 122.22222, + labels, + predictions, + expected_loss=9 * 137.5, weights=np.ones((2, 3, 3))) def testLossWithAllZeroBatchSpecificWeights(self): @@ -1345,6 +1345,34 @@ class ComputeWeightedLossTest(test.TestCase): self.assertAllClose( np.mean(self._raw_losses), unweighted_loss.eval()) + def testUnweightedFromPlaceholder(self): + for reduction in losses.Reduction.all(): + with ops.Graph().as_default() as g: + self.assertEqual(0, len(util.get_losses())) + raw_losses = array_ops.placeholder(dtype=dtypes.float32) + feed_dict = {raw_losses: self._raw_losses} + unweighted_losses = ( + losses.compute_weighted_loss(raw_losses, reduction=reduction), + losses.compute_weighted_loss( + raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction), + losses.compute_weighted_loss( + raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction), + ) + self.assertEqual(3, len(util.get_losses())) + with self.test_session(g): + for unweighted_loss in unweighted_losses: + if reduction == losses.Reduction.NONE: + self.assertAllClose( + self._raw_losses, unweighted_loss.eval(feed_dict)) + elif reduction == losses.Reduction.SUM: + self.assertAllClose( + np.sum(self._raw_losses), unweighted_loss.eval(feed_dict)) + else: + # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS, + # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE. + self.assertAllClose( + np.mean(self._raw_losses), unweighted_loss.eval(feed_dict)) + def testScalarWeight(self): with ops.Graph().as_default(): self.assertEqual(0, len(util.get_losses())) diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b8200ac0cb1e4315a56181779c70da1126d8fc15 --- /dev/null +++ b/tensorflow/python/kernel_tests/manip_ops_test.py @@ -0,0 +1,138 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for manip_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import manip_ops +from tensorflow.python.platform import test as test_lib + +# pylint: disable=g-import-not-at-top +try: + from distutils.version import StrictVersion as Version + # numpy.roll for multiple shifts was introduced in numpy version 1.12.0 + NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version("1.12.0") +except ImportError: + NP_ROLL_CAN_MULTISHIFT = False +# pylint: enable=g-import-not-at-top + + +class RollTest(test_util.TensorFlowTestCase): + + def _testRoll(self, np_input, shift, axis): + expected_roll = np.roll(np_input, shift, axis) + with self.test_session(): + roll = manip_ops.roll(np_input, shift, axis) + self.assertAllEqual(roll.eval(), expected_roll) + + def _testGradient(self, np_input, shift, axis): + with self.test_session(): + inx = constant_op.constant(np_input.tolist()) + xs = list(np_input.shape) + y = manip_ops.roll(inx, shift, axis) + # Expected y's shape to be the same + ys = xs + jacob_t, jacob_n = gradient_checker.compute_gradient( + inx, xs, y, ys, x_init_value=np_input) + self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5) + + def _testAll(self, np_input, shift, axis): + self._testRoll(np_input, shift, axis) + if np_input.dtype == np.float32: + self._testGradient(np_input, shift, axis) + + def testIntTypes(self): + for t in [np.int32, np.int64]: + self._testAll(np.random.randint(-100, 100, (5)).astype(t), 3, 0) + if NP_ROLL_CAN_MULTISHIFT: + self._testAll( + np.random.randint(-100, 100, (4, 4, 3)).astype(t), [1, -2, 3], + [0, 1, 2]) + self._testAll( + np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), [0, 1, -2], + [1, 2, 3]) + + def testFloatTypes(self): + for t in [np.float32, np.float64]: + self._testAll(np.random.rand(5).astype(t), 2, 0) + if NP_ROLL_CAN_MULTISHIFT: + self._testAll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0]) + self._testAll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2]) + + def testComplexTypes(self): + for t in [np.complex64, np.complex128]: + x = np.random.rand(4, 4).astype(t) + self._testAll(x + 1j * x, 2, 0) + if NP_ROLL_CAN_MULTISHIFT: + x = np.random.rand(2, 5).astype(t) + self._testAll(x + 1j * x, [1, 2], [1, 0]) + x = np.random.rand(3, 2, 1, 1).astype(t) + self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2]) + + def testRollInputMustVectorHigherRaises(self): + tensor = 7 + shift = 1 + axis = 0 + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "input must be 1-D or higher"): + manip_ops.roll(tensor, shift, axis).eval() + + def testRollAxisMustBeScalarOrVectorRaises(self): + tensor = [[1, 2], [3, 4]] + shift = 1 + axis = [[0, 1]] + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "axis must be a scalar or a 1-D vector"): + manip_ops.roll(tensor, shift, axis).eval() + + def testRollShiftMustBeScalarOrVectorRaises(self): + tensor = [[1, 2], [3, 4]] + shift = [[0, 1]] + axis = 1 + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "shift must be a scalar or a 1-D vector"): + manip_ops.roll(tensor, shift, axis).eval() + + def testRollShiftAndAxisMustBeSameSizeRaises(self): + tensor = [[1, 2], [3, 4]] + shift = [1] + axis = [0, 1] + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "shift and axis must have the same size"): + manip_ops.roll(tensor, shift, axis).eval() + + def testRollAxisOutOfRangeRaises(self): + tensor = [1, 2] + shift = 1 + axis = 1 + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "is out of range"): + manip_ops.roll(tensor, shift, axis).eval() + + +if __name__ == "__main__": + test_lib.main() diff --git a/tensorflow/python/kernel_tests/map_stage_op_test.py b/tensorflow/python/kernel_tests/map_stage_op_test.py index 8b669450590f1fce0f14a9e5d64e1055dbe23f4e..acfafde9e0f74d4e3ad6f2ee8ada9da3df94f5b9 100644 --- a/tensorflow/python/kernel_tests/map_stage_op_test.py +++ b/tensorflow/python/kernel_tests/map_stage_op_test.py @@ -26,6 +26,7 @@ from tensorflow.python.platform import test TIMEOUT = 1 + class MapStageTest(test.TestCase): def testSimple(self): @@ -83,7 +84,7 @@ class MapStageTest(test.TestCase): [dtypes.float32, dtypes.float32], shapes=[[], [128, 128]], names=['x', 'v']) - stage = stager.put(pi,{'x': x, 'v': v}) + stage = stager.put(pi, {'x': x, 'v': v}) key, ret = stager.get(gi) z = ret['x'] y = ret['v'] @@ -128,8 +129,11 @@ class MapStageTest(test.TestCase): gi = array_ops.placeholder(dtypes.int64) p = array_ops.placeholder(dtypes.int32, name='p') with ops.device(test.gpu_device_name()): - stager = data_flow_ops.MapStagingArea([dtypes.int32, ], shapes=[[]]) - stage = stager.put(pi,[x], [0]) + stager = data_flow_ops.MapStagingArea( + [ + dtypes.int32, + ], shapes=[[]]) + stage = stager.put(pi, [x], [0]) peek = stager.peek(gi) size = stager.size() @@ -158,7 +162,7 @@ class MapStageTest(test.TestCase): [dtypes.float32, dtypes.float32], shapes=[[], [128, 128]], names=['x', 'v']) - stage = stager.put(pi,{'x': x, 'v': v}) + stage = stager.put(pi, {'x': x, 'v': v}) size = stager.size() clear = stager.clear() @@ -172,7 +176,6 @@ class MapStageTest(test.TestCase): sess.run(clear) self.assertEqual(sess.run(size), 0) - def testCapacity(self): capacity = 3 @@ -182,8 +185,10 @@ class MapStageTest(test.TestCase): pi = array_ops.placeholder(dtypes.int64, name='pi') gi = array_ops.placeholder(dtypes.int64, name='gi') with ops.device(test.gpu_device_name()): - stager = data_flow_ops.MapStagingArea([dtypes.int32, ], - capacity=capacity, shapes=[[]]) + stager = data_flow_ops.MapStagingArea( + [ + dtypes.int32, + ], capacity=capacity, shapes=[[]]) stage = stager.put(pi, [x], [0]) get = stager.get() @@ -222,9 +227,8 @@ class MapStageTest(test.TestCase): self.fail("Expected to timeout on iteration '{}' " "but instead timed out on iteration '{}' " "Staging Area size is '{}' and configured " - "capacity is '{}'.".format(capacity, i, - sess.run(size), - capacity)) + "capacity is '{}'.".format(capacity, i, sess.run(size), + capacity)) # Should have capacity elements in the staging area self.assertTrue(sess.run(size) == capacity) @@ -236,8 +240,8 @@ class MapStageTest(test.TestCase): self.assertTrue(sess.run(size) == 0) def testMemoryLimit(self): - memory_limit = 512*1024 # 512K - chunk = 200*1024 # 256K + memory_limit = 512 * 1024 # 512K + chunk = 200 * 1024 # 256K capacity = memory_limit // chunk with ops.Graph().as_default() as G: @@ -246,8 +250,8 @@ class MapStageTest(test.TestCase): pi = array_ops.placeholder(dtypes.int64, name='pi') gi = array_ops.placeholder(dtypes.int64, name='gi') with ops.device(test.gpu_device_name()): - stager = data_flow_ops.MapStagingArea([dtypes.uint8], - memory_limit=memory_limit, shapes=[[]]) + stager = data_flow_ops.MapStagingArea( + [dtypes.uint8], memory_limit=memory_limit, shapes=[[]]) stage = stager.put(pi, [x], [0]) get = stager.get() size = stager.size() @@ -287,9 +291,8 @@ class MapStageTest(test.TestCase): self.fail("Expected to timeout on iteration '{}' " "but instead timed out on iteration '{}' " "Staging Area size is '{}' and configured " - "capacity is '{}'.".format(capacity, i, - sess.run(size), - capacity)) + "capacity is '{}'.".format(capacity, i, sess.run(size), + capacity)) # Should have capacity elements in the staging area self.assertTrue(sess.run(size) == capacity) @@ -310,8 +313,10 @@ class MapStageTest(test.TestCase): pi = array_ops.placeholder(dtypes.int64, name='pi') gi = array_ops.placeholder(dtypes.int64, name='gi') with ops.device(test.gpu_device_name()): - stager = data_flow_ops.MapStagingArea([dtypes.int32, ], - shapes=[[]], ordered=True) + stager = data_flow_ops.MapStagingArea( + [ + dtypes.int32, + ], shapes=[[]], ordered=True) stage = stager.put(pi, [x], [0]) get = stager.get() size = stager.size() @@ -349,7 +354,7 @@ class MapStageTest(test.TestCase): stager = data_flow_ops.MapStagingArea( [dtypes.float32, dtypes.float32, dtypes.float32], names=['x', 'v', 'f']) - stage_xf = stager.put(pi,{'x': x, 'f': f}) + stage_xf = stager.put(pi, {'x': x, 'f': f}) stage_v = stager.put(pi, {'v': v}) key, ret = stager.get(gi) size = stager.size() @@ -373,12 +378,13 @@ class MapStageTest(test.TestCase): self.assertTrue(sess.run([size, isize]) == [1, 1]) # We can now obtain tuple associated with key 0 self.assertTrue( - sess.run([key, ret], - feed_dict={gi: 0}) == [0, { - 'x': 1, - 'f': 2, - 'v': 1 - }]) + sess.run([key, ret], feed_dict={ + gi: 0 + }) == [0, { + 'x': 1, + 'f': 2, + 'v': 1 + }]) # 0 complete and 1 incomplete entry self.assertTrue(sess.run([size, isize]) == [0, 1]) @@ -386,12 +392,13 @@ class MapStageTest(test.TestCase): sess.run(stage_v, feed_dict={pi: 1, v: 3}) # We can now obtain tuple associated with key 1 self.assertTrue( - sess.run([key, ret], - feed_dict={gi: 1}) == [1, { - 'x': 1, - 'f': 2, - 'v': 3 - }]) + sess.run([key, ret], feed_dict={ + gi: 1 + }) == [1, { + 'x': 1, + 'f': 2, + 'v': 3 + }]) def testPartialIndexInsert(self): with ops.Graph().as_default() as G: @@ -450,7 +457,7 @@ class MapStageTest(test.TestCase): stager = data_flow_ops.MapStagingArea( [dtypes.float32, dtypes.float32, dtypes.float32], names=['x', 'v', 'f']) - stage_xf = stager.put(pi,{'x': x, 'f': f}) + stage_xf = stager.put(pi, {'x': x, 'f': f}) stage_v = stager.put(pi, {'v': v}) peek_xf = stager.peek(pei, ['x', 'f']) peek_v = stager.peek(pei, ['v']) @@ -487,11 +494,12 @@ class MapStageTest(test.TestCase): # We can now obtain 'x' and 'f' values associated with key 0 self.assertTrue( - sess.run([key_xf, get_xf], - feed_dict={gi: 0}) == [0, { - 'x': 1, - 'f': 2 - }]) + sess.run([key_xf, get_xf], feed_dict={ + gi: 0 + }) == [0, { + 'x': 1, + 'f': 2 + }]) # Still have 1 complete and 1 incomplete entry self.assertTrue(sess.run([size, isize]) == [1, 1]) @@ -499,14 +507,15 @@ class MapStageTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError) as cm: sess.run([key_xf, get_xf], feed_dict={gi: 0}) - exc_str = ("Tensor at index '0' for key '0' " - "has already been removed.") + exc_str = ("Tensor at index '0' for key '0' " 'has already been removed.') self.assertTrue(exc_str in cm.exception.message) # Obtain 'v' value associated with key 0 self.assertTrue( - sess.run([key_v, get_v], feed_dict={gi: 0}) == [0, { + sess.run([key_v, get_v], feed_dict={ + gi: 0 + }) == [0, { 'v': 1 }]) # 0 complete and 1 incomplete entry @@ -523,7 +532,9 @@ class MapStageTest(test.TestCase): self.assertTrue(sess.run([size, isize]) == [1, 0]) # We can now obtain 'x' and 'f' values associated with key 1 self.assertTrue( - sess.run([pop_key_v, pop_v], feed_dict={pi: 1}) == [1, { + sess.run([pop_key_v, pop_v], feed_dict={ + pi: 1 + }) == [1, { 'v': 1 }]) # Nothing is left @@ -557,18 +568,20 @@ class MapStageTest(test.TestCase): self.assertTrue(sess.run([size, isize]) == [1, 0]) # Partial get using indices - self.assertTrue(sess.run([key_xf, get_xf], - feed_dict={gi: 0}) == [0, [1, 2]]) + self.assertTrue( + sess.run([key_xf, get_xf], feed_dict={ + gi: 0 + }) == [0, [1, 2]]) # Still some of key 0 left self.assertTrue(sess.run([size, isize]) == [1, 0]) # Partial get of remaining index - self.assertTrue(sess.run([key_v, get_v], - feed_dict={gi: 0}) == [0, [3]]) + self.assertTrue(sess.run([key_v, get_v], feed_dict={gi: 0}) == [0, [3]]) # All gone self.assertTrue(sess.run([size, isize]) == [0, 0]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py index 317b8dc05beac7642c384bf89e6d154be50f6992..68d626de2c5cdd91ee332247c05ddce2a558a35e 100644 --- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py @@ -21,6 +21,7 @@ import numpy as np from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -54,9 +55,13 @@ def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_): band_np = np.tril(band_np, upper) if batch_shape_ is not (): band_np = np.tile(band_np, batch_shape_ + (1, 1)) - with self.test_session(use_gpu=False): - band = array_ops.matrix_band_part(batch_mat, lower, upper) - self.assertAllEqual(band_np, band.eval()) + for index_dtype in [dtypes_lib.int32, dtypes_lib.int64]: + with self.test_session(use_gpu=False): + band = array_ops.matrix_band_part( + batch_mat, + constant_op.constant(lower, index_dtype), + constant_op.constant(upper, index_dtype)) + self.assertAllEqual(band_np, band.eval()) return Test diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py index e0e752147cdf8690d22fa782aca2561b2935fa8e..fd78c026c273da1ffecf9e1dfe8c9e6042a4be69 100644 --- a/tensorflow/python/kernel_tests/metrics_test.py +++ b/tensorflow/python/kernel_tests/metrics_test.py @@ -1105,9 +1105,9 @@ class AUCTest(test.TestCase): auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.54166, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.54166, auc.eval(), delta=1e-3) def testAnotherAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1119,9 +1119,9 @@ class AUCTest(test.TestCase): auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3) def testThirdAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1133,9 +1133,26 @@ class AUCTest(test.TestCase): auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3) + + def testFourthAUCPRSpecialCase(self): + # Create the labels and data. + labels = np.array([ + 0, 0, 0, 0, 0, 0, 0, 1, 0, 1]) + predictions = np.array([ + 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35]) + + with self.test_session() as sess: + auc, _ = metrics.auc( + labels, predictions, curve='PR', num_thresholds=11) + + sess.run(variables.local_variables_initializer()) + # Since this is only approximate, we can't expect a 6 digits match. + # Although with higher number of samples/thresholds we should see the + # accuracy improving + self.assertAlmostEqual(0.0, auc.eval(), delta=0.001) def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -1161,16 +1178,16 @@ class AUCTest(test.TestCase): self.assertAlmostEqual(1, auc.eval(), 6) - def testRecallOneAndPrecisionOneGivesOnePRAUC(self): + def testRecallOneAndPrecisionOne(self): with self.test_session() as sess: predictions = array_ops.ones([4], dtype=dtypes_lib.float32) labels = array_ops.ones([4]) auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(1, sess.run(update_op), 6) + self.assertAlmostEqual(0.5, sess.run(update_op), 6) - self.assertAlmostEqual(1, auc.eval(), 6) + self.assertAlmostEqual(0.5, auc.eval(), 6) def np_auc(self, predictions, labels, weights): """Computes the AUC explicitly using Numpy. diff --git a/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py index 30795eed8a063076a69ec2ec7851788775fe4dc6..d8ce9fffbd2bc0d18033339a02e0ad84f8f4c952 100644 --- a/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py +++ b/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py @@ -148,7 +148,7 @@ class DepthwiseConv2DTest(test.TestCase): print("depthwise conv_2d: ", tensor_in_sizes, "*", filter_in_sizes, ", stride:", stride, ", padding: ", padding, ", max diff: ", np.amax(np.absolute(native_result - interface_result))) - self.assertArrayNear( + self.assertAllClose( np.ravel(native_result), np.ravel(interface_result), 1e-5) self.assertShapeEqual(native_result, conv_native) self.assertShapeEqual(native_result, conv_interface) @@ -213,7 +213,7 @@ class DepthwiseConv2DTest(test.TestCase): t1, t2, strides=[1, stride, stride, 1], padding=padding) value = sess.run(conv) print("value = ", value) - self.assertArrayNear(expected, np.ravel(value), 1e-5) + self.assertAllClose(expected, np.ravel(value), 1e-5) self.assertShapeEqual(value, conv) def testConv2D2x2Filter(self): diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py index 56a07cb012f08dec750c5ee18cc73b3b127ef5dd..f5c6255c346961fec7245889229ea1c4b89fa388 100644 --- a/tensorflow/python/kernel_tests/partitioned_variables_test.py +++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py @@ -50,8 +50,7 @@ class PartitionerCreatorsTest(test.TestCase): with self.test_session(): partitioner = partitioned_variables.fixed_size_partitioner(4, axis=0) with variable_scope.variable_scope("root", partitioner=partitioner): - v0 = variable_scope.get_variable( - "v0", dtype=dtypes.int64, shape=[20]) + v0 = variable_scope.get_variable("v0", dtype=dtypes.int64, shape=[20]) v0_list = v0._get_variable_list() self.assertEqual(len(v0_list), 4) @@ -169,8 +168,10 @@ class PartitionerCreatorsTest(test.TestCase): max_shards=2) # Use the partitioner with strings - partitioner_axis3_str = partitioned_variables.variable_axis_size_partitioner( - axis=3, max_shard_bytes=32768, bytes_per_string_element=8) + partitioner_axis3_str = partitioned_variables.variable_axis_size_partitioner( # pylint: disable=line-too-long + axis=3, + max_shard_bytes=32768, + bytes_per_string_element=8) with variable_scope.variable_scope( "root", partitioner=partitioner_axis3_str): @@ -423,8 +424,7 @@ class PartitionedVariablesTestCase(test.TestCase): def testRandomInitUnevenPartitions(self): with self.test_session(): rnd = variables.Variable( - random_ops.random_uniform( - [20, 43], dtype=dtypes.float64)) + random_ops.random_uniform([20, 43], dtype=dtypes.float64)) var_lists = [ partitioned_variables.create_partitioned_variables( rnd.get_shape(), [1, i], rnd.initialized_value()) diff --git a/tensorflow/python/kernel_tests/pool_test.py b/tensorflow/python/kernel_tests/pool_test.py index 63848976336f5487cf2a44f7cf62ea316c40d7c8..6ede654aadc7d0d78bc18f13c2d4b3d47fef0402 100644 --- a/tensorflow/python/kernel_tests/pool_test.py +++ b/tensorflow/python/kernel_tests/pool_test.py @@ -96,7 +96,7 @@ def pool_direct_single_axis( def pool_direct( - input, + input, # pylint: disable=redefined-builtin window_shape, pooling_type, padding, # pylint: disable=redefined-builtin diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 5c0ea8ec8edbd1a1f523630f61afbe28adf77a19..4466beeec96509b3761e34d885276e1510c62d10 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -159,8 +159,10 @@ class PoolingTest(test.TestCase): elif data_format == "NCHW": t = test_util.NCHWToNHWC(t) if v2: - actual = t.eval(feed_dict={ksize_placeholder: ksize, - strides_placeholder: strides}) + actual = t.eval(feed_dict={ + ksize_placeholder: ksize, + strides_placeholder: strides + }) else: actual = t.eval() self.assertShapeEqual(actual, t) @@ -195,8 +197,15 @@ class PoolingTest(test.TestCase): self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, data_format, dtypes.float16, expected, use_gpu, v2) - def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding, - expected, use_gpu, v2=False): + def _VerifyValues(self, + pool_func, + input_sizes, + ksize, + strides, + padding, + expected, + use_gpu, + v2=False): """Verifies the output values of the pooling function. Args: @@ -1148,16 +1157,16 @@ class PoolingTest(test.TestCase): def _testMaxPoolGradSamePadding3_1(self, data_format, use_gpu): for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]: self._ConstructAndTestGradient( - pool_func, - input_sizes=[1, 7, 7, 1], - output_sizes=[1, 7, 7, 1], - window_rows=3, - window_cols=3, - row_stride=1, - col_stride=1, - padding="SAME", - data_format=data_format, - use_gpu=use_gpu) + pool_func, + input_sizes=[1, 7, 7, 1], + output_sizes=[1, 7, 7, 1], + window_rows=3, + window_cols=3, + row_stride=1, + col_stride=1, + padding="SAME", + data_format=data_format, + use_gpu=use_gpu) def testMaxPoolGrad(self): for (data_format, use_gpu) in GetTestConfigs(): @@ -1202,17 +1211,14 @@ class PoolingTest(test.TestCase): pool_func = gen_nn_ops._max_pool_v2 if v2 else nn_ops.max_pool with self.test_session(use_gpu=use_gpu): input_tensor = constant_op.constant(input_data, shape=input_sizes) - output_tensor = pool_func(input_tensor, - [1, window_rows, window_cols, 1], + output_tensor = pool_func(input_tensor, [1, window_rows, window_cols, 1], [1, row_stride, col_stride, 1], padding) output_backprop_tensor = constant_op.constant( output_backprop, shape=output_sizes) - input_backprop_tensor = self._MaxPoolGrad(input_tensor, output_tensor, - output_backprop_tensor, - window_rows, window_cols, - row_stride, col_stride, - padding, v2) + input_backprop_tensor = self._MaxPoolGrad( + input_tensor, output_tensor, output_backprop_tensor, window_rows, + window_cols, row_stride, col_stride, padding, v2) actual_input_backprop = input_backprop_tensor.eval() self.assertShapeEqual(actual_input_backprop, input_backprop_tensor) @@ -1414,13 +1420,15 @@ class PoolingTest(test.TestCase): def _testMaxPoolGradDirectWithNans2_2(self): input_data = [float("nan")] * 16 output_backprop = [ - float("nan"), 12.0, 13.0, 15.0, float("nan"), 17.0, 19.0, 20.0, + float("nan"), 12.0, 13.0, 15.0, + float("nan"), 17.0, 19.0, 20.0, float("nan") ] # Test the CPU implementation, which propagates diffs in case of NaN expected_input_backprop_tf_cpu = [ - float("nan"), 12.0, 13.0, 0.0, 15.0, float("nan"), 17.0, 0.0, 19.0, - 20.0, float("nan"), 0.0, 0.0, 0.0, 0.0, 0.0 + float("nan"), 12.0, 13.0, 0.0, 15.0, + float("nan"), 17.0, 0.0, 19.0, 20.0, + float("nan"), 0.0, 0.0, 0.0, 0.0, 0.0 ] for v2 in [True, False]: self._testMaxPoolGradDirect( @@ -1636,10 +1644,9 @@ class PoolingTest(test.TestCase): Returns: A Tensor. """ - return gen_nn_ops._max_pool_grad_grad(orig_input, orig_output, grad, - [1, window_rows, window_cols, - 1], [1, row_stride, col_stride, - 1], padding) + return gen_nn_ops._max_pool_grad_grad( + orig_input, orig_output, grad, [1, window_rows, window_cols, 1], + [1, row_stride, col_stride, 1], padding) def testAvgPoolGrad(self): for (data_format, use_gpu) in GetTestConfigs(): @@ -1793,8 +1800,7 @@ class PoolingTest(test.TestCase): ]: with self.assertRaises(ValueError): pool_func( - array_ops.placeholder( - dtypes.float32, shape=[1, 3]), + array_ops.placeholder(dtypes.float32, shape=[1, 3]), ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1], padding="SAME") @@ -1805,30 +1811,29 @@ class PoolingTest(test.TestCase): if test.is_gpu_available(): pool_funcs.append(nn_ops.max_pool_with_argmax) for pool_func in pool_funcs: - # Illegal strides. - with self.assertRaisesRegexp( - errors_impl.UnimplementedError, - "Pooling is not yet supported on the batch"): - sess.run( - pool_func( - array_ops.placeholder(dtypes.float32), - ksize=[1, 1, 1, 1], - strides=[2, 1, 1, 1], - padding="SAME")) + if pool_func != nn_ops.max_pool: + # Illegal strides. + with self.assertRaisesRegexp( + errors_impl.UnimplementedError, + "Pooling is not yet supported on the batch"): + sess.run( + pool_func( + array_ops.placeholder(dtypes.float32), + ksize=[1, 1, 1, 1], + strides=[2, 1, 1, 1], + padding="SAME")) # Filter too large. with self.assertRaisesRegexp(ValueError, "Negative dimension size"): sess.run( pool_func( - array_ops.placeholder( - dtypes.float32, shape=[32, 20, 20, 3]), + array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]), ksize=[1, 20, 21, 1], strides=[1, 1, 1, 1], padding="VALID")) with self.assertRaisesRegexp(ValueError, "Negative dimension size"): pool_func( - array_ops.placeholder( - dtypes.float32, shape=[32, 20, 20, 3]), + array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]), ksize=[1, 21, 20, 1], strides=[1, 1, 1, 1], padding="VALID") diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 92fb68820e04c3db1385296d91d956134b8ff2d4..61fb3f12e45ea5ae3bc4f0a26c2116b54c003624 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -212,6 +213,16 @@ class PyFuncTest(test.TestCase): value.op.run() self.assertAllEqual(np_array, [1.0, 2.0]) + def testReturnUnicodeString(self): + with self.test_session(): + correct = u"你好 世界" + + def unicode_string(): + return correct + + z, = script_ops.py_func(unicode_string, [], [dtypes.string]) + self.assertEqual(z.eval(), correct.encode("utf8")) + def testBadNumpyReturnType(self): with self.test_session(): @@ -396,66 +407,66 @@ class PyFuncTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testEagerSingleOutputFloat32(self): - a = array_ops.ones((3, 3), dtype=dtypes.float32) - x = array_ops.ones((3, 1), dtype=dtypes.float32) - output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) - with self.test_session(): + with test_util.device(use_gpu=True): + a = array_ops.ones((3, 3), dtype=dtypes.float32) + x = array_ops.ones((3, 1), dtype=dtypes.float32) + output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) ret = self.evaluate(output) self.assertAllClose(ret, [[3.0], [3.0], [3.0]]) @test_util.run_in_graph_and_eager_modes() def testEagerArrayOutput(self): - a = array_ops.ones((3, 3), dtype=dtypes.int32) - x = array_ops.ones((3, 1), dtype=dtypes.int32) - output = script_ops.eager_py_func( - lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.int32]) - - with self.test_session(): + with test_util.device(use_gpu=True): + a = array_ops.ones((3, 3), dtype=dtypes.float32) + x = array_ops.ones((3, 1), dtype=dtypes.float32) + output = script_ops.eager_py_func( + lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.float32]) ret = self.evaluate(output) - self.assertAllEqual(ret, [[[3], [3], [3]]]) + self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]]) @test_util.run_in_graph_and_eager_modes() def testEagerReturnNone(self): + with test_util.device(use_gpu=True): + def no_return_value(): + return - def no_return_value(): - return - - output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[]) - ret = self.evaluate(output) - if context.in_eager_mode(): - self.assertEquals(len(ret), 0) - else: - self.assertIsNone(ret) + output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[]) + ret = self.evaluate(output) + if context.in_eager_mode(): + self.assertEquals(len(ret), 0) + else: + self.assertIsNone(ret) @test_util.run_in_graph_and_eager_modes() def testEagerPyFuncInDefun(self): + with test_util.device(use_gpu=True): + def wrapper(): + a = array_ops.ones((3, 3), dtype=dtypes.float32) + x = array_ops.ones((3, 1), dtype=dtypes.float32) + return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) - def wrapper(): - a = array_ops.ones((3, 3), dtype=dtypes.int32) - x = array_ops.ones((3, 1), dtype=dtypes.int32) - return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32) - - wrapped = function.defun(wrapper) - ret = self.evaluate(wrapped()) - self.assertAllEqual(ret, [[3], [3], [3]]) + wrapped = function.defun(wrapper) + ret = self.evaluate(wrapped()) + self.assertAllEqual(ret, [[3.0], [3.0], [3.0]]) @test_util.run_in_graph_and_eager_modes() def testEagerExceptionHandling(self): - self._testExceptionHandling( - ValueError, errors.InvalidArgumentError, eager=True) - self._testExceptionHandling( - TypeError, errors.InvalidArgumentError, eager=True) - self._testExceptionHandling( - StopIteration, errors.OutOfRangeError, eager=True) - self._testExceptionHandling( - MemoryError, errors.ResourceExhaustedError, eager=True) - self._testExceptionHandling( - NotImplementedError, errors.UnimplementedError, eager=True) - - class WeirdError(Exception): - pass - - self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) + with test_util.device(use_gpu=True): + self._testExceptionHandling( + ValueError, errors.InvalidArgumentError, eager=True) + self._testExceptionHandling( + TypeError, errors.InvalidArgumentError, eager=True) + self._testExceptionHandling( + StopIteration, errors.OutOfRangeError, eager=True) + self._testExceptionHandling( + MemoryError, errors.ResourceExhaustedError, eager=True) + self._testExceptionHandling( + NotImplementedError, errors.UnimplementedError, eager=True) + + class WeirdError(Exception): + pass + + self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py index 5a2903a4234202c828168b6538baf320b961c776..df37dd98ece57ae7c3835ab63b720b29fc19c975 100644 --- a/tensorflow/python/kernel_tests/random/random_ops_test.py +++ b/tensorflow/python/kernel_tests/random/random_ops_test.py @@ -203,7 +203,8 @@ class RandomUniformTest(test.TestCase): return func def testRange(self): - for dt in dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64: + for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64): sampler = self._Sampler(1000, minv=-2, maxv=8, dtype=dt, use_gpu=True) x = sampler() self.assertTrue(-2 <= np.min(x)) @@ -213,7 +214,8 @@ class RandomUniformTest(test.TestCase): # to see the same sequence of values. Will catch buggy # implementations which uses the same random number seed. def testDistinct(self): - for dt in dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64: + for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64): maxv = 1.0 if dt.is_floating else 1 << 30 sampler = self._Sampler(1000, minv=0, maxv=maxv, dtype=dt, use_gpu=True) x = sampler() @@ -251,7 +253,8 @@ class RandomUniformTest(test.TestCase): # Checks that the CPU and GPU implementation returns the same results, # given the same random seed def testCPUGPUMatch(self): - for dt in dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64: + for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64): maxv = 1.0 if dt.is_floating else 17 results = {} for use_gpu in False, True: @@ -261,7 +264,8 @@ class RandomUniformTest(test.TestCase): self.assertAllEqual(results[False], results[True]) def testSeed(self): - for dt in dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64: + for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64): for seed in [345, 2**100, -2**100]: sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=True, seed=seed) sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=True, seed=seed) @@ -285,8 +289,7 @@ class RandomShapeTest(test.TestCase): self.assertEqual([1, 2, 3], rnd1.get_shape()) # Partially known shape. rnd2 = random_ops.truncated_normal( - array_ops.placeholder( - dtypes.int32, shape=(3,))) + array_ops.placeholder(dtypes.int32, shape=(3,))) self.assertEqual([None, None, None], rnd2.get_shape().as_list()) # Unknown shape. rnd3 = random_ops.truncated_normal(array_ops.placeholder(dtypes.int32)) @@ -298,8 +301,7 @@ class RandomShapeTest(test.TestCase): self.assertEqual([1, 2, 3], rnd1.get_shape()) # Partially known shape. rnd2 = random_ops.random_normal( - array_ops.placeholder( - dtypes.int32, shape=(3,))) + array_ops.placeholder(dtypes.int32, shape=(3,))) self.assertEqual([None, None, None], rnd2.get_shape().as_list()) # Unknown shape. rnd3 = random_ops.random_normal(array_ops.placeholder(dtypes.int32)) @@ -311,8 +313,7 @@ class RandomShapeTest(test.TestCase): self.assertEqual([1, 2, 3], rnd1.get_shape()) # Partially known shape. rnd2 = random_ops.random_uniform( - array_ops.placeholder( - dtypes.int32, shape=(3,))) + array_ops.placeholder(dtypes.int32, shape=(3,))) self.assertEqual([None, None, None], rnd2.get_shape().as_list()) # Unknown shape. rnd3 = random_ops.random_uniform(array_ops.placeholder(dtypes.int32)) diff --git a/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py b/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py index c4e16ff6280cc7ce121955474fe8ec45acd57f95..b7a79f239cee04b191b78affd002f687b7de851a 100644 --- a/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py +++ b/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import random -import re import time import numpy as np diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index 223a4b2c8726d957f014e65ea9f87c0fb61e65bb..82a27eebeef16c9dacaf1b900f0398a56533cd2d 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -428,7 +428,7 @@ class FixedLengthRecordReaderTest(test.TestCase): for i in range(self._num_files): fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) filenames.append(fn) - with open(fn+".tmp", "wb") as f: + with open(fn + ".tmp", "wb") as f: f.write(b"H" * self._header_bytes) if num_records > 0: f.write(self._Record(i, 0)) @@ -437,7 +437,7 @@ class FixedLengthRecordReaderTest(test.TestCase): f.write(b"G" * gap_bytes) f.write(self._Record(i, j)) f.write(b"F" * self._footer_bytes) - with open(fn+".tmp", "rb") as f: + with open(fn + ".tmp", "rb") as f: cdata = zlib.compress(f.read()) with open(fn, "wb") as zf: zf.write(cdata) @@ -455,7 +455,7 @@ class FixedLengthRecordReaderTest(test.TestCase): all_records_str = "".join([ str(i)[0] for i in range(self._record_bytes + self._hop_bytes * - (num_overlapped_records - 1)) + (num_overlapped_records - 1)) ]) f.write(compat.as_bytes(all_records_str)) f.write(b"F" * self._footer_bytes) @@ -467,7 +467,7 @@ class FixedLengthRecordReaderTest(test.TestCase): fn = os.path.join(self.get_temp_dir(), "fixed_length_overlapped_record.%d.txt" % i) filenames.append(fn) - with open(fn+".tmp", "wb") as f: + with open(fn + ".tmp", "wb") as f: f.write(b"H" * self._header_bytes) if num_overlapped_records > 0: all_records_str = "".join([ @@ -477,7 +477,7 @@ class FixedLengthRecordReaderTest(test.TestCase): ]) f.write(compat.as_bytes(all_records_str)) f.write(b"F" * self._footer_bytes) - with open(fn+".tmp", "rb") as f: + with open(fn + ".tmp", "rb") as f: cdata = zlib.compress(f.read()) with open(fn, "wb") as zf: zf.write(cdata) @@ -509,7 +509,10 @@ class FixedLengthRecordReaderTest(test.TestCase): "\\(requested 1, current size 0\\)"): k, v = sess.run([key, value]) - def _TestOneEpochWithHopBytes(self, files, num_overlapped_records, encoding=None): + def _TestOneEpochWithHopBytes(self, + files, + num_overlapped_records, + encoding=None): with self.test_session() as sess: reader = io_ops.FixedLengthRecordReader( header_bytes=self._header_bytes, @@ -565,13 +568,15 @@ class FixedLengthRecordReaderTest(test.TestCase): def testGzipOneEpochWithHopBytes(self): for num_overlapped_records in [0, 2]: - files = self._CreateGzipOverlappedRecordFiles(num_overlapped_records, ) - self._TestOneEpochWithHopBytes(files, num_overlapped_records, encoding="GZIP") + files = self._CreateGzipOverlappedRecordFiles(num_overlapped_records,) + self._TestOneEpochWithHopBytes( + files, num_overlapped_records, encoding="GZIP") def testZlibOneEpochWithHopBytes(self): for num_overlapped_records in [0, 2]: files = self._CreateZlibOverlappedRecordFiles(num_overlapped_records) - self._TestOneEpochWithHopBytes(files, num_overlapped_records, encoding="ZLIB") + self._TestOneEpochWithHopBytes( + files, num_overlapped_records, encoding="ZLIB") class TFRecordReaderTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index 4231a79b2dcef951048ca54e8c8df2f42b44b1a1..531478162971575739bbe37abfc57ca427ae22ae 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -110,10 +110,10 @@ class ReductionUnknownShape(test.TestCase): class BaseReductionTest(test.TestCase): - def _tf_reduce(self, x, reduction_axes, keep_dims): + def _tf_reduce(self, x, reduction_axes, keepdims): raise NotImplementedError() - def _np_reduce(self, x, reduction_axes, keep_dims): + def _np_reduce(self, x, reduction_axes, keepdims): raise NotImplementedError() def _makeIncremental(self, shape, dtype): @@ -128,10 +128,10 @@ class BaseReductionTest(test.TestCase): data -= 2j * data return data - def _compare(self, x, reduction_axes, keep_dims, feed_dict=None): - np_ans = self._np_reduce(x, reduction_axes, keep_dims) + def _compare(self, x, reduction_axes, keepdims, feed_dict=None): + np_ans = self._np_reduce(x, reduction_axes, keepdims) with self.test_session(use_gpu=True) as sess: - tf_ans = self._tf_reduce(x, reduction_axes, keep_dims) + tf_ans = self._tf_reduce(x, reduction_axes, keepdims) out = sess.run(tf_ans, feed_dict) self.assertAllClose(np_ans, out) self.assertShapeEqual(np_ans, tf_ans) @@ -140,8 +140,8 @@ class BaseReductionTest(test.TestCase): if reduction_axes is not None and np.shape(reduction_axes) == (1,): # Test scalar reduction_axes argument self._compareAll(x, reduction_axes[0]) - self._compare(x, reduction_axes, keep_dims=False, feed_dict=feed_dict) - self._compare(x, reduction_axes, keep_dims=True, feed_dict=feed_dict) + self._compare(x, reduction_axes, keepdims=False, feed_dict=feed_dict) + self._compare(x, reduction_axes, keepdims=True, feed_dict=feed_dict) def _compareAllAxes(self, x, feed_dict=None): self._compareAll(x, None) @@ -171,14 +171,14 @@ class BaseReductionTest(test.TestCase): class SumReductionTest(BaseReductionTest): - def _tf_reduce(self, x, reduction_axes, keep_dims): - return math_ops.reduce_sum(x, reduction_axes, keep_dims) + def _tf_reduce(self, x, reduction_axes, keepdims): + return math_ops.reduce_sum(x, reduction_axes, keepdims) - def _np_reduce(self, x, reduction_axes, keep_dims): + def _np_reduce(self, x, reduction_axes, keepdims): if isinstance(reduction_axes, list) or isinstance(reduction_axes, np.ndarray): reduction_axes = tuple(reduction_axes) - return np.sum(x, axis=reduction_axes, keepdims=keep_dims) + return np.sum(x, axis=reduction_axes, keepdims=keepdims) def testAxesType(self): for dtype in [dtypes.int64, dtypes.int32]: @@ -298,7 +298,7 @@ class SumReductionTest(BaseReductionTest): c_known_rank = array_ops.placeholder(dtypes.float32) c_known_rank.set_shape(tensor_shape.unknown_shape(ndims=3)) s_known_rank = math_ops.reduce_sum( - c_known_rank, reduction_axes, keep_dims=True) + c_known_rank, reduction_axes, keepdims=True) self.assertEqual(3, s_known_rank.get_shape().ndims) np_input = np.random.randn(3, 3, 3) @@ -308,11 +308,11 @@ class SumReductionTest(BaseReductionTest): unknown_indices = array_ops.placeholder(dtypes.int32) c_unknown_indices = constant_op.constant([[10.0], [20.0]]) s_unknown_indices = math_ops.reduce_sum( - c_unknown_indices, unknown_indices, keep_dims=False) + c_unknown_indices, unknown_indices, keepdims=False) self.assertEqual(tensor_shape.unknown_shape(), s_unknown_indices.get_shape()) s_unknown_indices_keep = math_ops.reduce_sum( - c_unknown_indices, unknown_indices, keep_dims=True) + c_unknown_indices, unknown_indices, keepdims=True) self.assertEqual(2, s_unknown_indices_keep.get_shape().ndims) def testWrongShapeForReductionIndices(self): @@ -372,10 +372,10 @@ class SumReductionTest(BaseReductionTest): class MeanReductionTest(BaseReductionTest): - def _tf_reduce(self, x, reduction_axes, keep_dims): - return math_ops.reduce_mean(x, reduction_axes, keep_dims) + def _tf_reduce(self, x, reduction_axes, keepdims): + return math_ops.reduce_mean(x, reduction_axes, keepdims) - def _np_reduce(self, x, reduction_axes, keep_dims): + def _np_reduce(self, x, reduction_axes, keepdims): if isinstance(reduction_axes, list) or isinstance(reduction_axes, np.ndarray): reduction_axes = tuple(reduction_axes) @@ -389,7 +389,7 @@ class MeanReductionTest(BaseReductionTest): # np.mean automatically converts integer inputs to float, while TensorFlow's # reduce_mean does not. For integer inputs, we emulate TensorFlow's behavior # using np.sum and truncating division. - np_sum = np.sum(x, axis=reduction_axes, keepdims=keep_dims) + np_sum = np.sum(x, axis=reduction_axes, keepdims=keepdims) if np.issubdtype(x.dtype, np.integer): return np_sum // count return np_sum / count @@ -458,14 +458,14 @@ class MeanReductionTest(BaseReductionTest): class ProdReductionTest(BaseReductionTest): - def _tf_reduce(self, x, reduction_axes, keep_dims): - return math_ops.reduce_prod(x, reduction_axes, keep_dims) + def _tf_reduce(self, x, reduction_axes, keepdims): + return math_ops.reduce_prod(x, reduction_axes, keepdims) - def _np_reduce(self, x, reduction_axes, keep_dims): + def _np_reduce(self, x, reduction_axes, keepdims): if isinstance(reduction_axes, list) or isinstance(reduction_axes, np.ndarray): reduction_axes = tuple(reduction_axes) - return np.prod(x, axis=reduction_axes, keepdims=keep_dims) + return np.prod(x, axis=reduction_axes, keepdims=keepdims) def testAxesType(self): for dtype in [dtypes.int64, dtypes.int32]: @@ -549,17 +549,17 @@ class ProdReductionTest(BaseReductionTest): class MinReductionTest(test.TestCase): - def _compare(self, x, reduction_axes, keep_dims, use_gpu=False): + def _compare(self, x, reduction_axes, keepdims, use_gpu=False): np_ans = x if reduction_axes is None: - np_ans = np.amin(np_ans, keepdims=keep_dims) + np_ans = np.amin(np_ans, keepdims=keepdims) else: for ra in reduction_axes[::-1]: - np_ans = np.amin(np_ans, axis=ra, keepdims=keep_dims) + np_ans = np.amin(np_ans, axis=ra, keepdims=keepdims) with self.test_session(use_gpu=use_gpu): if reduction_axes is not None: reduction_axes = np.array(reduction_axes).astype(np.int32) - tf_ans = math_ops.reduce_min(x, reduction_axes, keep_dims) + tf_ans = math_ops.reduce_min(x, reduction_axes, keepdims) out = tf_ans.eval() self.assertAllClose(np_ans, out) self.assertShapeEqual(np_ans, tf_ans) @@ -662,17 +662,17 @@ class MinReductionTest(test.TestCase): class MaxReductionTest(test.TestCase): - def _compare(self, x, reduction_axes, keep_dims, use_gpu=False): + def _compare(self, x, reduction_axes, keepdims, use_gpu=False): np_ans = x if reduction_axes is None: - np_ans = np.amax(np_ans, keepdims=keep_dims) + np_ans = np.amax(np_ans, keepdims=keepdims) else: for ra in reduction_axes[::-1]: - np_ans = np.amax(np_ans, axis=ra, keepdims=keep_dims) + np_ans = np.amax(np_ans, axis=ra, keepdims=keepdims) with self.test_session(use_gpu=use_gpu): if reduction_axes is not None: reduction_axes = np.array(reduction_axes).astype(np.int32) - tf_ans = math_ops.reduce_max(x, reduction_axes, keep_dims) + tf_ans = math_ops.reduce_max(x, reduction_axes, keepdims) out = tf_ans.eval() self.assertAllClose(np_ans, out) self.assertShapeEqual(np_ans, tf_ans) @@ -789,17 +789,17 @@ class MaxReductionTest(test.TestCase): class AllReductionTest(test.TestCase): - def _compare(self, x, reduction_axes, keep_dims, use_gpu=False): + def _compare(self, x, reduction_axes, keepdims, use_gpu=False): np_ans = x if reduction_axes is None: - np_ans = np.all(np_ans, keepdims=keep_dims) + np_ans = np.all(np_ans, keepdims=keepdims) else: for ra in reduction_axes[::-1]: - np_ans = np.all(np_ans, axis=ra, keepdims=keep_dims) + np_ans = np.all(np_ans, axis=ra, keepdims=keepdims) with self.test_session(use_gpu=use_gpu): if reduction_axes is not None: reduction_axes = np.array(reduction_axes).astype(np.int32) - tf_ans = math_ops.reduce_all(x, reduction_axes, keep_dims) + tf_ans = math_ops.reduce_all(x, reduction_axes, keepdims) out = tf_ans.eval() self.assertAllEqual(np_ans, out) self.assertShapeEqual(np_ans, tf_ans) @@ -838,17 +838,17 @@ class AllReductionTest(test.TestCase): class AnyReductionTest(test.TestCase): - def _compare(self, x, reduction_axes, keep_dims, use_gpu=False): + def _compare(self, x, reduction_axes, keepdims, use_gpu=False): np_ans = x if reduction_axes is None: - np_ans = np.any(np_ans, keepdims=keep_dims) + np_ans = np.any(np_ans, keepdims=keepdims) else: for ra in reduction_axes[::-1]: - np_ans = np.any(np_ans, axis=ra, keepdims=keep_dims) + np_ans = np.any(np_ans, axis=ra, keepdims=keepdims) with self.test_session(use_gpu=use_gpu): if reduction_axes is not None: reduction_axes = np.array(reduction_axes).astype(np.int32) - tf_ans = math_ops.reduce_any(x, reduction_axes, keep_dims) + tf_ans = math_ops.reduce_any(x, reduction_axes, keepdims) out = tf_ans.eval() self.assertAllEqual(np_ans, out) self.assertShapeEqual(np_ans, tf_ans) @@ -890,18 +890,18 @@ class CountNonzeroReductionTest(test.TestCase): def _compare(self, x, reduction_axes, - keep_dims, + keepdims, use_gpu=False, feed_dict=None): np_ans = (x != 0).astype(np.int32) if reduction_axes is None: - np_ans = np.sum(np_ans, keepdims=keep_dims) + np_ans = np.sum(np_ans, keepdims=keepdims) else: reduction_axes = np.array(reduction_axes).astype(np.int32) for ra in reduction_axes.ravel()[::-1]: - np_ans = np.sum(np_ans, axis=ra, keepdims=keep_dims) + np_ans = np.sum(np_ans, axis=ra, keepdims=keepdims) with self.test_session(use_gpu=use_gpu) as sess: - tf_ans = math_ops.count_nonzero(x, reduction_axes, keep_dims) + tf_ans = math_ops.count_nonzero(x, reduction_axes, keepdims) out = sess.run(tf_ans, feed_dict) self.assertAllClose(np_ans, out) self.assertShapeEqual(np_ans, tf_ans) diff --git a/tensorflow/python/kernel_tests/reduction_ops_test_big.py b/tensorflow/python/kernel_tests/reduction_ops_test_big.py index 0959adb026e3934713442e6f3487b30a0b252943..d70360775a03caa32eab995371d54786c3c0a0d9 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test_big.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test_big.py @@ -27,24 +27,24 @@ from tensorflow.python.platform import test class BaseReductionTest(test.TestCase): - def _tf_reduce(self, x, reduction_axes, keep_dims): + def _tf_reduce(self, x, reduction_axes, keepdims): raise NotImplementedError() class BigReductionTest(BaseReductionTest): """Test reductions for sum and boolean all over a wide range of shapes.""" - def _tf_reduce_max(self, x, reduction_axes, keep_dims): - return math_ops.reduce_max(x, reduction_axes, keep_dims) + def _tf_reduce_max(self, x, reduction_axes, keepdims): + return math_ops.reduce_max(x, reduction_axes, keepdims) - def _tf_reduce_all(self, x, reduction_axes, keep_dims): - return math_ops.reduce_all(x, reduction_axes, keep_dims) + def _tf_reduce_all(self, x, reduction_axes, keepdims): + return math_ops.reduce_all(x, reduction_axes, keepdims) - def _tf_reduce_mean(self, x, reduction_axes, keep_dims): - return math_ops.reduce_mean(x, reduction_axes, keep_dims) + def _tf_reduce_mean(self, x, reduction_axes, keepdims): + return math_ops.reduce_mean(x, reduction_axes, keepdims) - def _tf_reduce_sum(self, x, reduction_axes, keep_dims): - return math_ops.reduce_sum(x, reduction_axes, keep_dims) + def _tf_reduce_sum(self, x, reduction_axes, keepdims): + return math_ops.reduce_sum(x, reduction_axes, keepdims) def testFloat32Sum(self): # make sure we test all possible kernel invocations diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index dd11ba700d518ab230c1160d17f4cc0833a79198..6b4091ae5d3c6e469a9cd5237b978eae4c75485f 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -48,8 +48,8 @@ class ReluTest(test.TestCase): self.assertAllClose( np.array([[0.0, 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]), self._npRelu( - np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9] - ]))) + np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, + 0.9]]))) def _testRelu(self, np_features, use_gpu=False): np_relu = self._npRelu(np_features) @@ -163,8 +163,8 @@ class Relu6Test(test.TestCase): self.assertAllClose( np.array([[0.0, 0.7, 0.0, 0.3, 6.0], [0.1, 0.0, 6.0, 0.0, 0.9]]), self._npRelu6( - np.array([[-0.9, 0.7, -0.5, 0.3, 6.0], [0.1, -0.3, 6.5, -0.7, 0.9] - ]))) + np.array([[-0.9, 0.7, -0.5, 0.3, 6.0], [0.1, -0.3, 6.5, -0.7, + 0.9]]))) def _testRelu6(self, np_features, use_gpu=False): np_relu6 = self._npRelu6(np_features) @@ -231,8 +231,8 @@ class EluTest(test.TestCase): np.array([[-0.59343034025, 0.7, -0.39346934028, 0.3, -0.09516258196], [0.1, -0.25918177931, 0.5, -0.5034146962, 0.9]]), self._npElu( - np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9] - ]))) + np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, + 0.9]]))) def _testElu(self, np_features, use_gpu=False): np_elu = self._npElu(np_features) @@ -330,11 +330,11 @@ class SeluTest(test.TestCase): def testNpSelu(self): self.assertAllClose( - np.array([[-1.0433095, 0.73549069, -0.6917582, 0.3152103 , -0.16730527], - [0.1050701 , -0.45566732, 0.5253505, -0.88505305, 0.9456309]]), + np.array([[-1.0433095, 0.73549069, -0.6917582, 0.3152103, -0.16730527], + [0.1050701, -0.45566732, 0.5253505, -0.88505305, 0.9456309]]), self._npSelu( - np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9] - ]))) + np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, + 0.9]]))) def _testSelu(self, np_features, use_gpu=False): np_selu = self._npSelu(np_features) diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index b4b555591d054226210eb6af20036967b240928f..8503f3e0310125bb714942b32bbbf46596f9bddb 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.util import compat @test_util.with_c_api @@ -63,6 +64,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): 0, dtype=dtypes.int32)).run() + def testGPUInt64(self): + if not context.context().num_gpus(): + return + with context.eager_mode(), context.device("gpu:0"): + v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int64) + self.assertAllEqual(1, v.numpy()) + def testEagerNameNotIdentity(self): with context.eager_mode(): v0 = resource_variable_ops.ResourceVariable(1.0, name="a") @@ -161,14 +169,26 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testScatterAdd(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate(resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + self.evaluate(resource_variable_ops.resource_scatter_add( + handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) + + def testScatterUpdateString(self): handle = resource_variable_ops.var_handle_op( - dtype=dtypes.int32, shape=[1, 1]) + dtype=dtypes.string, shape=[1, 1]) self.evaluate(resource_variable_ops.assign_variable_op( - handle, constant_op.constant([[1]], dtype=dtypes.int32))) - self.evaluate(resource_variable_ops.resource_scatter_add( - handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) - read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(self.evaluate(read), [[3]]) + handle, constant_op.constant([["a"]], dtype=dtypes.string))) + self.evaluate(resource_variable_ops.resource_scatter_update( + handle, [0], constant_op.constant([["b"]], dtype=dtypes.string))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string) + self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]), + compat.as_bytes("b")) # TODO(alive): get this to work in Eager mode. def testGPU(self): @@ -264,6 +284,32 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): v.load(2.0) self.assertEqual(2.0, self.evaluate(v.value())) + def testVariableDefInitializedInstances(self): + with ops.Graph().as_default(), self.test_session() as sess: + v_def = resource_variable_ops.ResourceVariable( + initial_value=constant_op.constant(3.0)).to_proto() + + with ops.Graph().as_default(), self.test_session() as sess: + # v describes a VariableDef-based variable without an initial value. + v = resource_variable_ops.ResourceVariable(variable_def=v_def) + self.assertEqual(3.0, sess.run(v.initialized_value())) + + # initialized_value should not rerun the initializer_op if the variable + # has already been initialized elsewhere. + sess.run(v.assign(1.0)) + self.assertEqual(1.0, v.initialized_value().eval()) + + v_def.ClearField("initial_value_name") + with ops.Graph().as_default(), self.test_session() as sess: + # Restoring a legacy VariableDef proto that does not have + # initial_value_name set should still work. + v = resource_variable_ops.ResourceVariable(variable_def=v_def) + # We should also be able to re-export the variable to a new meta graph. + self.assertProtoEquals(v_def, v.to_proto()) + # But attempts to use initialized_value will result in errors. + with self.assertRaises(ValueError): + sess.run(v.initialized_value()) + @test_util.run_in_graph_and_eager_modes() def testSparseRead(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index 0c77d1db921566000c2a52e6ddb9d3dddd9b193c..daa42938e6af205425d7e423ce162294b9002be4 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -23,6 +23,7 @@ import timeit import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import rnn as contrib_rnn from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session diff --git a/tensorflow/python/kernel_tests/scalar_test.py b/tensorflow/python/kernel_tests/scalar_test.py index b34426cc21590d585bf7ef7b24b778adbf0cd084..e65241981eac2d42207c1de7a261f7936e588f2a 100644 --- a/tensorflow/python/kernel_tests/scalar_test.py +++ b/tensorflow/python/kernel_tests/scalar_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import math_ops @@ -30,6 +31,7 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test +@test_util.with_c_api class ScalarTest(test.TestCase): def check(self, op, args, error, correct=None): @@ -51,7 +53,7 @@ class ScalarTest(test.TestCase): # Test various GraphDef versions for version in strict + lenient: with ops.Graph().as_default() as g: - g.graph_def_versions.producer = version + test_util.set_producer_version(g, version) with self.test_session(graph=g) as sess: feed = {} xs = placeholders(args, feed) diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index 5a54f448d092093db668570d055801f9f9cd0f9f..bbce6b7d47325b8209815230426672ec6894147f 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -46,7 +46,8 @@ class SegmentReductionHelper(test.TestCase): return constant_op.constant( np_values, shape=input_shape, dtype=dtype), np_values - def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None): + def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None, + initial_value=0): if not x.size: return np.array([]) indices = np.asarray(indices) @@ -64,13 +65,8 @@ class SegmentReductionHelper(test.TestCase): else: output[index] = x_flat[i] # zero initialize values that are still uncalcuated. - # output = [o if o is not None else np.zeros(slice_shape) for o in output] - if not op1 == np.max: - output = [o if o is not None else np.zeros(slice_shape) for o in output] - else: - zeroslice = np.zeros(slice_shape) - zeroslice.fill(dtype.min) - output = [o if o is not None else zeroslice for o in output] + initial_value_slice = np.ones(slice_shape) * initial_value + output = [o if o is not None else initial_value_slice for o in output] if op2 is not None: output = [op2(o) for o in output] output = [o.reshape(slice_shape) for o in output] @@ -82,6 +78,9 @@ class SegmentReductionHelper(test.TestCase): def _mean_reduce_op(self, x): return x[0] / x[1] if isinstance(x, tuple) else x + def _sqrt_n_reduce_op(self, x): + return x[0] / np.sqrt(x[1]) if isinstance(x, tuple) else x + class SegmentReductionOpTest(SegmentReductionHelper): @@ -244,27 +243,61 @@ class SegmentReductionOpTest(SegmentReductionHelper): self.assertAllClose(jacob_t, jacob_n) -class UnsortedSegmentSumTest(SegmentReductionHelper): +class UnsortedSegmentTest(SegmentReductionHelper): + + def __init__(self, methodName='runTest'): + # Each item is np_op1, np_op2, tf_op, initial_value functor + self.ops_list = [(np.add, None, + math_ops.unsorted_segment_sum, lambda t: 0), + (self._mean_cum_op, self._mean_reduce_op, + math_ops.unsorted_segment_mean, lambda t: 0), + (self._mean_cum_op, self._sqrt_n_reduce_op, + math_ops.unsorted_segment_sqrt_n, lambda t: 0), + (np.ndarray.__mul__, None, + math_ops.unsorted_segment_prod, lambda t: 1), + (np.minimum, None, + math_ops.unsorted_segment_min, lambda t: t.max), + (np.maximum, None, + math_ops.unsorted_segment_max, lambda t: t.min)] + + # A subset of ops has been enabled for complex numbers + self.complex_ops_list = [(np.add, None, + math_ops.unsorted_segment_sum, lambda t: 0)] + self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32, + dtypes_lib.float64] + self.all_dtypes = (self.differentiable_dtypes + + [dtypes_lib.bfloat16, + dtypes_lib.int64, dtypes_lib.int32, + dtypes_lib.complex64, dtypes_lib.complex128]) + super(UnsortedSegmentTest, self).__init__(methodName=methodName) def testValues(self): - dtypes = [ - dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64, - dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128 - ] indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) num_segments = 12 for indices in indices_flat, indices_flat.reshape(5, 2): shape = indices.shape + (2,) - for dtype in dtypes: - with self.test_session(use_gpu=True): - tf_x, np_x = self._input(shape, dtype=dtype) - np_ans = self._segmentReduce( - indices, np_x, np.add, op2=None, num_segments=num_segments) - s = math_ops.unsorted_segment_sum( - data=tf_x, segment_ids=indices, num_segments=num_segments) - tf_ans = s.eval() - self.assertAllClose(np_ans, tf_ans) - self.assertShapeEqual(np_ans, s) + for dtype in self.all_dtypes: + ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list + tf_x, np_x = self._input(shape, dtype=dtype) + for use_gpu in [True, False]: + with self.test_session(use_gpu=True): + for np_op1, np_op2, tf_op, init_op in ops_list: + # sqrt_n doesn't support integers + if (np_op2 == self._sqrt_n_reduce_op and dtype.is_integer): + continue + # todo(philjd): enable this test once real_div supports bfloat16 + if (np_op2 in [self._sqrt_n_reduce_op, self._mean_reduce_op] and + dtype == dtypes_lib.bfloat16): + continue + np_ans = self._segmentReduce( + indices, np_x, np_op1, np_op2, num_segments=num_segments, + initial_value=init_op(dtype)) + s = tf_op(tf_x, segment_ids=indices, num_segments=num_segments) + tf_ans = s.eval() + if dtype is dtypes_lib.bfloat16: + tf_ans = tf_ans.astype(np.float32) + self.assertAllClose(np_ans, tf_ans) + self.assertShapeEqual(np_ans, s) def testNumSegmentsTypes(self): dtypes = [dtypes_lib.int32, dtypes_lib.int64] @@ -287,25 +320,51 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): self.assertAllClose(np_ans, tf_ans) self.assertShapeEqual(np_ans, s) - def testGradientSegmentSum(self): + def testGradients(self): num_cols = 2 - indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) + indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3]) num_segments = max(indices_flat) + 3 - for dtype in [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, - dtypes_lib.complex128]: + for dtype in self.differentiable_dtypes: + ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list for indices in indices_flat, indices_flat.reshape(5, 2): shape = indices.shape + (num_cols,) - with self.test_session(use_gpu=True): - tf_x, np_x = self._input(shape, dtype=dtype) - s = math_ops.unsorted_segment_sum( - data=tf_x, segment_ids=indices, num_segments=num_segments) + # test CPU and GPU as tf.gather behaves differently on each device + for use_gpu in [False, True]: + with self.test_session(use_gpu=use_gpu): + for _, _, tf_op, _ in ops_list: + tf_x, np_x = self._input(shape, dtype=dtype) + s = tf_op(tf_x, indices, num_segments) + jacob_t, jacob_n = gradient_checker.compute_gradient( + tf_x, + shape, + s, [num_segments, num_cols], + x_init_value=np_x, + delta=1) + self.assertAllClose(jacob_t, jacob_n) + + def testProdGrad(self): + # additional test for the prod gradient to ensure correct handling of zeros + values = np.array([0, 0, 1, 0, 2, 2, 3, 3, 3], dtype=np.float32) + indices = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=np.int32) + indices_neg = np.array([-1, 0, 0, -1, 1, 1, -1, 2, 2], dtype=np.int32) + values_tf = constant_op.constant(values) + # ground truth partial derivatives + gradients_indices = np.zeros((9, 3), dtype=np.float32) + gradients_indices_neg = np.zeros((9, 3), dtype=np.float32) + # the derivative w.r.t. to the other segments is zero, so here we only + # explicitly set the grad values for the corresponding segment + gradients_indices[range(9), indices] = [0, 0, 0, 4, 0, 0, 9, 9, 9] + gradients_indices_neg[range(9), indices_neg] = [0, 1, 0, 0, 2, 2, 0, 3, 3] + for use_gpu in [False, True]: + with self.test_session(use_gpu=use_gpu): + for ind, grad_gt in [(indices, gradients_indices), + (indices_neg, gradients_indices_neg)]: + s = math_ops.unsorted_segment_prod(values_tf, + constant_op.constant(ind), 3) jacob_t, jacob_n = gradient_checker.compute_gradient( - tf_x, - shape, - s, [num_segments, num_cols], - x_init_value=np_x, - delta=1) - self.assertAllClose(jacob_t, jacob_n) + values_tf, (9,), s, (3,), x_init_value=values, delta=1) + self.assertAllClose(jacob_t, jacob_n) + self.assertAllClose(jacob_t, grad_gt) def testGradientMatchesSegmentSum(self): # Strategy: compute the gradient for UnsortedSegmentSum and SegmentSum @@ -318,8 +377,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): num_cols = 2 shape = [n, num_cols] num_segments = max(indices) + 1 - for dtype in [dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, - dtypes_lib.complex128]: + for dtype in self.differentiable_dtypes: with self.test_session(use_gpu=True): tf_x, np_x = self._input(shape, dtype=dtype) # Results from UnsortedSegmentSum @@ -353,9 +411,8 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): unsorted.eval() def testEmptySecondDimension(self): - dtypes = [ - np.float32, np.float64, np.int64, np.int32, np.complex64, np.complex128 - ] + dtypes = [np.float16, np.float32, np.float64, np.int64, np.int32, + np.complex64, np.complex128] with self.test_session(use_gpu=True): for dtype in dtypes: for itype in (np.int32, np.int64): @@ -364,36 +421,14 @@ class UnsortedSegmentSumTest(SegmentReductionHelper): unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2) self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype)) - def testGradientSegmentMax(self): - num_cols = 2 - indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) - num_segments = max(indices_flat) + 3 - for indices in indices_flat, indices_flat.reshape(5, 2): - shape = indices.shape + (num_cols,) - with self.test_session(use_gpu=True): - tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64) - s = math_ops.unsorted_segment_max( - data=tf_x, segment_ids=indices, num_segments=num_segments) - jacob_t, jacob_n = gradient_checker.compute_gradient( - tf_x, - shape, - s, - [num_segments, num_cols], - x_init_value=np_x.astype(np.double), delta=1) - self.assertAllClose(jacob_t, jacob_n) - def testDropNegatives(self): # Note: the test is done by replacing segment_ids with 8 to -1 # for index and replace values generated by numpy with 0. - dtypes = [ - dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64, - dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128 - ] indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3]) num_segments = 12 for indices in indices_flat, indices_flat.reshape(5, 2): shape = indices.shape + (2,) - for dtype in dtypes: + for dtype in self.all_dtypes: with self.test_session(use_gpu=True): tf_x, np_x = self._input(shape, dtype=dtype) np_ans = self._segmentReduce( diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py index be72c1940723ea9f1e22a3b81d2b34ad67a57f4f..4d89831aae9a5e95210a8defb180e09c9d38f4d6 100644 --- a/tensorflow/python/kernel_tests/softmax_op_test.py +++ b/tensorflow/python/kernel_tests/softmax_op_test.py @@ -18,18 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys - import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test +@test_util.with_c_api class SoftmaxTest(test.TestCase): def _npSoftmax(self, features, dim=-1, log=False): @@ -99,10 +99,10 @@ class SoftmaxTest(test.TestCase): def _testOverflow(self, use_gpu=False): if use_gpu: - type = np.float32 + type = np.float32 # pylint: disable=redefined-builtin else: - type = np.float64 - max = np.finfo(type).max + type = np.float64 # pylint: disable=redefined-builtin + max = np.finfo(type).max # pylint: disable=redefined-builtin features = np.array([[1., 1., 1., 1.], [max, 1., 2., 3.]]).astype(type) with self.test_session(use_gpu=use_gpu): tf_log_softmax = nn_ops.log_softmax(features) @@ -174,8 +174,11 @@ class SoftmaxTest(test.TestCase): def testDimTooLarge(self): with self.test_session(): + # Use placeholder to make sure we get runtime error instead of shape + # inference error. + dim = array_ops.placeholder_with_default(100, shape=[]) with self.assertRaises(errors_impl.InvalidArgumentError): - nn_ops.softmax([1., 2., 3., 4.], dim=100).eval() + nn_ops.softmax([1., 2., 3., 4.], dim=dim).eval() def testLargeDims(self): # Make sure that we properly handle large inputs. See diff --git a/tensorflow/python/kernel_tests/sparse_slice_op_test.py b/tensorflow/python/kernel_tests/sparse_slice_op_test.py index 762e400447c7e6e89ca4c0b480662aa91e287c26..da116601f833cc6b471e383e030c5fbe93b52ac5 100644 --- a/tensorflow/python/kernel_tests/sparse_slice_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_slice_op_test.py @@ -32,11 +32,12 @@ class SparseSliceOpTest(test.TestCase): # [ |11| |13|14| ] # [20| | |23| |25] # [30| |32|33| |35] - ind = np.array([[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4], - [2, 0], [2, 3], [2, 5], [3, 0], [3, 2], [3, 3], - [3, 5]]).astype(np.int64) - val = np.array( - [0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype(np.int64) + ind = np.array([[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, + 4], [2, 0], + [2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype( + np.int64) + val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype( + np.int64) shape = np.array([4, 6]).astype(np.int64) return sparse_tensor.SparseTensor(ind, val, shape) @@ -65,50 +66,49 @@ class SparseSliceOpTest(test.TestCase): # [ |'c1'| |'d1'] # [ | |'e1'| ] ind = np.array([[0, 0, 0], [0, 0, 1], [0, 2, 0], [0, 2, 1], [1, 1, 0], - [1, 1, 1], [1, 3, 0], [1, 3, 1], [2, 2, 0], - [2, 2, 1]]).astype(np.int64) + [1, 1, 1], [1, 3, 0], [1, 3, 1], [2, 2, 0], [2, 2, + 1]]).astype( + np.int64) val = np.array(['a0', 'a1', 'b0', 'b1', 'c0', 'c1', 'd0', 'd1', 'e0', 'e1']) shape = np.array([3, 4, 2]).astype(np.int64) return sparse_tensor.SparseTensorValue(ind, val, shape) def _SparseTensor_3x4x2(self): - return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x4x2( - )) + return sparse_tensor.SparseTensor.from_value( + self._SparseTensorValue_3x4x2()) def testSliceMatrixRows(self): with self.test_session(use_gpu=False): - sp_input=self._SparseTensor_4x6() + sp_input = self._SparseTensor_4x6() sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [2, 6]) sp_tensor1 = sparse_ops.sparse_slice(sp_input, [2, 0], [3, 7]) - self.assertAllEqual(sp_tensor0.indices.eval(), [[0, 0], [0, 2], [0, 4], - [0, 5], [1, 1], [1, 3], - [1, 4]]) + self.assertAllEqual( + sp_tensor0.indices.eval(), + [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4]]) self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 4, 5, 11, 13, 14]) self.assertAllEqual(sp_tensor0.dense_shape.eval(), [2, 6]) - self.assertAllEqual(sp_tensor1.indices.eval(), [[0, 0], [0, 3], [0, 5], - [1, 0], [1, 2], [1, 3], - [1, 5]]) + self.assertAllEqual( + sp_tensor1.indices.eval(), + [[0, 0], [0, 3], [0, 5], [1, 0], [1, 2], [1, 3], [1, 5]]) self.assertAllEqual(sp_tensor1.values.eval(), [20, 23, 25, 30, 32, 33, 35]) self.assertAllEqual(sp_tensor1.dense_shape.eval(), [2, 6]) def testSliceMatrixUnevenCols(self): with self.test_session(use_gpu=False): - sp_input=self._SparseTensor_5x7() + sp_input = self._SparseTensor_5x7() sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [5, 3]) sp_tensor1 = sparse_ops.sparse_slice(sp_input, [0, 3], [5, 2]) sp_tensor2 = sparse_ops.sparse_slice(sp_input, [0, 5], [5, 2]) - self.assertAllEqual(sp_tensor0.indices.eval(), - [[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2], - [4, 1]]) - self.assertAllEqual(sp_tensor0.values.eval(), - [0, 2, 11, 20, 30, 32, 41]) + self.assertAllEqual( + sp_tensor0.indices.eval(), + [[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2], [4, 1]]) + self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 11, 20, 30, 32, 41]) self.assertAllEqual(sp_tensor0.dense_shape.eval(), [5, 3]) self.assertAllEqual(sp_tensor1.indices.eval(), [[0, 1], [1, 0], [1, 1], [2, 0], [3, 0], [4, 1]]) - self.assertAllEqual(sp_tensor1.values.eval(), - [4, 13, 14, 23, 33, 44]) + self.assertAllEqual(sp_tensor1.values.eval(), [4, 13, 14, 23, 33, 44]) self.assertAllEqual(sp_tensor1.dense_shape.eval(), [5, 2]) self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 0], [1, 1], [2, 0], [3, 0], [4, 1]]) @@ -137,7 +137,7 @@ class SparseSliceOpTest(test.TestCase): def testSliceMatrixUnevenRows(self): with self.test_session(use_gpu=False): - sp_input=self._SparseTensor_5x7() + sp_input = self._SparseTensor_5x7() sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [3, 7]) sp_tensor1 = sparse_ops.sparse_slice(sp_input, [3, 0], [3, 7]) self.assertAllEqual(sp_tensor0.indices.eval(), @@ -146,9 +146,9 @@ class SparseSliceOpTest(test.TestCase): self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 4, 5, 11, 13, 14, 16, 20, 23, 25]) self.assertAllEqual(sp_tensor0.dense_shape.eval(), [3, 7]) - self.assertAllEqual(sp_tensor1.indices.eval(), - [[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4], - [1, 6]]) + self.assertAllEqual( + sp_tensor1.indices.eval(), + [[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4], [1, 6]]) self.assertAllEqual(sp_tensor1.values.eval(), [30, 32, 33, 35, 41, 44, 46]) self.assertAllEqual(sp_tensor1.dense_shape.eval(), [2, 7]) @@ -156,9 +156,9 @@ class SparseSliceOpTest(test.TestCase): sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [2, 7]) sp_tensor1 = sparse_ops.sparse_slice(sp_input, [2, 0], [2, 7]) sp_tensor2 = sparse_ops.sparse_slice(sp_input, [4, 0], [2, 7]) - self.assertAllEqual(sp_tensor0.indices.eval(), - [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], - [1, 4], [1, 6]]) + self.assertAllEqual( + sp_tensor0.indices.eval(), + [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4], [1, 6]]) self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 4, 5, 11, 13, 14, 16]) self.assertAllEqual(sp_tensor0.dense_shape.eval(), [2, 7]) @@ -166,45 +166,42 @@ class SparseSliceOpTest(test.TestCase): self.assertAllEqual(sp_tensor1.values.eval(), [20, 23, 25, 30, 32, 33, 35]) self.assertAllEqual(sp_tensor1.dense_shape.eval(), [2, 7]) - self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 1], [0, 4], - [0, 6]]) + self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 1], [0, 4], [0, 6]]) self.assertAllEqual(sp_tensor2.values.eval(), [41, 44, 46]) self.assertAllEqual(sp_tensor2.dense_shape.eval(), [1, 7]) return def testSliceAllRows(self): with self.test_session(use_gpu=False): - sp_input=self._SparseTensor_4x6() + sp_input = self._SparseTensor_4x6() sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [1, 6]) sp_tensor1 = sparse_ops.sparse_slice(sp_input, [1, 0], [1, 6]) sp_tensor2 = sparse_ops.sparse_slice(sp_input, [2, 0], [1, 7]) sp_tensor3 = sparse_ops.sparse_slice(sp_input, [3, 0], [2, 7]) - self.assertAllEqual(sp_tensor0.indices.eval(), [[0, 0], [0, 2], [0, 4], - [0, 5]]) + self.assertAllEqual(sp_tensor0.indices.eval(), + [[0, 0], [0, 2], [0, 4], [0, 5]]) self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 4, 5]) self.assertAllEqual(sp_tensor0.dense_shape.eval(), [1, 6]) - self.assertAllEqual(sp_tensor1.indices.eval(), [[0, 1], [0, 3], [0, - 4]]) + self.assertAllEqual(sp_tensor1.indices.eval(), [[0, 1], [0, 3], [0, 4]]) self.assertAllEqual(sp_tensor1.values.eval(), [11, 13, 14]) self.assertAllEqual(sp_tensor1.dense_shape.eval(), [1, 6]) - self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 0], [0, 3], [0, - 5]]) + self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 0], [0, 3], [0, 5]]) self.assertAllEqual(sp_tensor2.values.eval(), [20, 23, 25]) self.assertAllEqual(sp_tensor2.dense_shape.eval(), [1, 6]) - self.assertAllEqual(sp_tensor3.indices.eval(), [[0, 0], [0, 2], [0, 3], - [0, 5]]) + self.assertAllEqual(sp_tensor3.indices.eval(), + [[0, 0], [0, 2], [0, 3], [0, 5]]) self.assertAllEqual(sp_tensor3.values.eval(), [30, 32, 33, 35]) self.assertAllEqual(sp_tensor3.dense_shape.eval(), [1, 6]) def testSliceColumns(self): with self.test_session(use_gpu=False): - sp_input=self._SparseTensor_4x6() + sp_input = self._SparseTensor_4x6() sparse_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [4, 2]) sparse_tensor1 = sparse_ops.sparse_slice(sp_input, [0, 2], [5, 2]) sparse_tensor2 = sparse_ops.sparse_slice(sp_input, [0, 4], [5, 3]) - self.assertAllEqual(sparse_tensor0.indices.eval(), [[0, 0], [1, 1], - [2, 0], [3, 0]]) + self.assertAllEqual(sparse_tensor0.indices.eval(), + [[0, 0], [1, 1], [2, 0], [3, 0]]) self.assertAllEqual(sparse_tensor0.values.eval(), [0, 11, 20, 30]) self.assertAllEqual(sparse_tensor0.dense_shape.eval(), [4, 2]) self.assertAllEqual(sparse_tensor1.indices.eval(), @@ -218,15 +215,15 @@ class SparseSliceOpTest(test.TestCase): def testSliceAllColumns(self): with self.test_session(use_gpu=False): - sp_input=self._SparseTensor_4x6() + sp_input = self._SparseTensor_4x6() sparse_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [4, 1]) sparse_tensor1 = sparse_ops.sparse_slice(sp_input, [0, 1], [4, 1]) sparse_tensor2 = sparse_ops.sparse_slice(sp_input, [0, 2], [4, 1]) sparse_tensor3 = sparse_ops.sparse_slice(sp_input, [0, 3], [4, 1]) sparse_tensor4 = sparse_ops.sparse_slice(sp_input, [0, 4], [5, 1]) sparse_tensor5 = sparse_ops.sparse_slice(sp_input, [0, 5], [6, 3]) - self.assertAllEqual(sparse_tensor0.indices.eval(), [[0, 0], [2, 0], - [3, 0]]) + self.assertAllEqual(sparse_tensor0.indices.eval(), + [[0, 0], [2, 0], [3, 0]]) self.assertAllEqual(sparse_tensor0.values.eval(), [0, 20, 30]) self.assertAllEqual(sparse_tensor0.dense_shape.eval(), [4, 1]) self.assertAllEqual(sparse_tensor1.indices.eval(), [[1, 0]]) @@ -235,17 +232,18 @@ class SparseSliceOpTest(test.TestCase): self.assertAllEqual(sparse_tensor2.indices.eval(), [[0, 0], [3, 0]]) self.assertAllEqual(sparse_tensor2.values.eval(), [2, 32]) self.assertAllEqual(sparse_tensor2.dense_shape.eval(), [4, 1]) - self.assertAllEqual(sparse_tensor3.indices.eval(), [[1, 0], [2, 0], - [3, 0]]) + self.assertAllEqual(sparse_tensor3.indices.eval(), + [[1, 0], [2, 0], [3, 0]]) self.assertAllEqual(sparse_tensor3.dense_shape.eval(), [4, 1]) self.assertAllEqual(sparse_tensor3.values.eval(), [13, 23, 33]) self.assertAllEqual(sparse_tensor4.indices.eval(), [[0, 0], [1, 0]]) self.assertAllEqual(sparse_tensor4.values.eval(), [4, 14]) self.assertAllEqual(sparse_tensor4.dense_shape.eval(), [4, 1]) - self.assertAllEqual(sparse_tensor5.indices.eval(), [[0, 0], [2, 0], - [3, 0]]) + self.assertAllEqual(sparse_tensor5.indices.eval(), + [[0, 0], [2, 0], [3, 0]]) self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35]) self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/stack_op_test.py b/tensorflow/python/kernel_tests/stack_op_test.py index 347baf81148e9b747a9be4849912d154b220a084..2f27d1839b2218d0cc33d7278116186548ad3420 100644 --- a/tensorflow/python/kernel_tests/stack_op_test.py +++ b/tensorflow/python/kernel_tests/stack_op_test.py @@ -50,7 +50,7 @@ class StackOpTest(test.TestCase): # Convert [data[0], data[1], ...] separately to tensorflow # TODO(irving): Remove list() once we handle maps correctly xs = list(map(constant_op.constant, data)) - # Pack back into a single tensorflow tensor + # Stack back into a single tensorflow tensor c = array_ops.stack(xs) self.assertAllEqual(c.eval(), data) @@ -78,7 +78,7 @@ class StackOpTest(test.TestCase): for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): for dtype in [np.bool, np.float32, np.int32, np.int64]: data = np.random.randn(*shape).astype(dtype) - # Pack back into a single tensorflow tensor directly using np array + # Stack back into a single tensorflow tensor directly using np array c = array_ops.stack(data) # This is implemented via a Const: self.assertEqual(c.op.type, "Const") @@ -223,7 +223,7 @@ class StackOpTest(test.TestCase): array_ops.stack(t, axis=-3) -class AutomaticPackingTest(test.TestCase): +class AutomaticStackingTest(test.TestCase): def testSimple(self): with self.test_session(use_gpu=True): diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py index 64b3388c5c0fd16436fa77ac5d8d0e8f9a859c32..dd06d303912813733886b9cf20590513760e67f1 100644 --- a/tensorflow/python/kernel_tests/stage_op_test.py +++ b/tensorflow/python/kernel_tests/stage_op_test.py @@ -25,8 +25,8 @@ from tensorflow.python.platform import test TIMEOUT = 1 -class StageTest(test.TestCase): +class StageTest(test.TestCase): def testSimple(self): with ops.Graph().as_default() as G: @@ -116,7 +116,10 @@ class StageTest(test.TestCase): x = array_ops.placeholder(dtypes.int32, name='x') p = array_ops.placeholder(dtypes.int32, name='p') with ops.device(test.gpu_device_name()): - stager = data_flow_ops.StagingArea([dtypes.int32, ], shapes=[[]]) + stager = data_flow_ops.StagingArea( + [ + dtypes.int32, + ], shapes=[[]]) stage = stager.put([x]) peek = stager.peek(p) ret = stager.get() @@ -162,8 +165,10 @@ class StageTest(test.TestCase): with ops.device('/cpu:0'): x = array_ops.placeholder(dtypes.int32, name='x') with ops.device(test.gpu_device_name()): - stager = data_flow_ops.StagingArea([dtypes.int32, ], - capacity=capacity, shapes=[[]]) + stager = data_flow_ops.StagingArea( + [ + dtypes.int32, + ], capacity=capacity, shapes=[[]]) stage = stager.put([x]) ret = stager.get() size = stager.size() @@ -201,9 +206,8 @@ class StageTest(test.TestCase): self.fail("Expected to timeout on iteration '{}' " "but instead timed out on iteration '{}' " "Staging Area size is '{}' and configured " - "capacity is '{}'.".format(capacity, i, - sess.run(size), - capacity)) + "capacity is '{}'.".format(capacity, i, sess.run(size), + capacity)) # Should have capacity elements in the staging area self.assertTrue(sess.run(size) == capacity) @@ -216,16 +220,18 @@ class StageTest(test.TestCase): self.assertTrue(sess.run(size) == 0) def testMemoryLimit(self): - memory_limit = 512*1024 # 512K - chunk = 200*1024 # 256K + memory_limit = 512 * 1024 # 512K + chunk = 200 * 1024 # 256K capacity = memory_limit // chunk with ops.Graph().as_default() as G: with ops.device('/cpu:0'): x = array_ops.placeholder(dtypes.uint8, name='x') with ops.device(test.gpu_device_name()): - stager = data_flow_ops.StagingArea([dtypes.uint8, ], - memory_limit=memory_limit, shapes=[[]]) + stager = data_flow_ops.StagingArea( + [ + dtypes.uint8, + ], memory_limit=memory_limit, shapes=[[]]) stage = stager.put([x]) ret = stager.get() size = stager.size() @@ -264,9 +270,8 @@ class StageTest(test.TestCase): self.fail("Expected to timeout on iteration '{}' " "but instead timed out on iteration '{}' " "Staging Area size is '{}' and configured " - "capacity is '{}'.".format(capacity, i, - sess.run(size), - capacity)) + "capacity is '{}'.".format(capacity, i, sess.run(size), + capacity)) # Should have capacity elements in the staging area self.assertTrue(sess.run(size) == capacity) @@ -277,5 +282,6 @@ class StageTest(test.TestCase): self.assertTrue(sess.run(size) == 0) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index 8792ab41a07aac2dc8c9fcb956c378054a309a41..a519b69b22cf51ab4f4173b215c21a71d83e9f99 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -17,10 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import traceback from tensorflow.python.client import session from tensorflow.python.eager import context +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -412,6 +414,17 @@ class TemplateTest(test.TestCase): self.assertEqual("s1_1/nested/dummy:0", v3[0].name) self.assertEqual("s1_1/nested_1/dummy:0", v3[1].name) + def test_graph_function_no_name(self): + with context.eager_mode(): + + def f(_, y): + return y + 1 + + partial = functools.partial(f, 1.0) + tmpl = template.make_template_internal( + "a", partial, create_graph_function_=True) + self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0) + @test_util.run_in_graph_and_eager_modes() def test_immediate_scope_creation(self): # Create templates in scope a then call in scope b. make_template should diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index 38205518b528b44313b1de83d06707b4498f061d..8ad29afd0a0f2e7fbaaf2bde956326e578466b1d 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -56,9 +56,11 @@ class TensordotTest(test_lib.TestCase): axes_ph = array_ops.placeholder(dtypes.int32) output = math_ops.tensordot(a_ph, b_ph, axes_ph) _ = sess.run( - [output], feed_dict={a_ph: a, - b_ph: b, - axes_ph: (a_axes, b_axes)}) + [output], feed_dict={ + a_ph: a, + b_ph: b, + axes_ph: (a_axes, b_axes) + }) def test_invalid_axes(self): a = [[1, 2], [3, 4]] @@ -81,28 +83,29 @@ class TensordotTest(test_lib.TestCase): with self.test_session() as sess: with self.assertRaises(errors_impl.InvalidArgumentError): _ = sess.run( - [output], feed_dict={a_ph: a, - b_ph: b, - axes_ph: axes_value}) + [output], feed_dict={ + a_ph: a, + b_ph: b, + axes_ph: axes_value + }) # Test case for 11950 def test_valid_axis(self): for axes_value in [1, 2], [[1], [2]], [[], []], 0: with self.test_session() as sess: - np_a = np.ones((3,3)) + np_a = np.ones((3, 3)) np_b = np.array([2, 3, 1])[None, None] np_ans = np.tensordot(np_a, np_b, axes_value) - tf_a = array_ops.ones((3,3), dtype=dtypes.float32) + tf_a = array_ops.ones((3, 3), dtype=dtypes.float32) tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None] tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value).eval() self.assertAllEqual(tf_ans.shape, np_ans.shape) self.assertAllEqual(tf_ans, np_ans) - def test_partial_shape_inference(self): - for axes in ([1],[0]), 1: + for axes in ([1], [0]), 1: a = array_ops.placeholder(dtypes.float32) b = array_ops.placeholder(dtypes.float32) output = math_ops.tensordot(a, b, axes) @@ -169,9 +172,11 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): axes = array_ops.placeholder(dtypes.int32) c = math_ops.tensordot(a, b, axes) tf_ans = sess.run( - c, feed_dict={a: a_np, - b: b_np, - axes: (a_dims_np, b_dims_np)}) + c, feed_dict={ + a: a_np, + b: b_np, + axes: (a_dims_np, b_dims_np) + }) else: tf_ans = math_ops.tensordot(a_np, b_np, (a_dims_np, b_dims_np)).eval() self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol) diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py index efb5b9f3641ceaebf1fd5285486b4a9bb93615cf..6ab931fdb97a8945ab610fda27a036693f0291e5 100644 --- a/tensorflow/python/kernel_tests/topk_op_test.py +++ b/tensorflow/python/kernel_tests/topk_op_test.py @@ -58,7 +58,7 @@ class TopKTest(test.TestCase): # Do some special casing of equality of indices: if indices # are not the same, but values are floating type, ensure that # the values are within epsilon of each other. - if not np.issubdtype(np_expected_values.dtype, np.float): + if not np.issubdtype(np_expected_values.dtype, np.floating): # Values are not floating point type; check indices exactly self.assertAllEqual(np_expected_indices, indices) else: diff --git a/tensorflow/python/kernel_tests/unstack_op_test.py b/tensorflow/python/kernel_tests/unstack_op_test.py index 84818755766a435c873f30e96dc0080af4f78b84..1ee6e0866a6b1c7a9b641a95403d45213f5dc0b4 100644 --- a/tensorflow/python/kernel_tests/unstack_op_test.py +++ b/tensorflow/python/kernel_tests/unstack_op_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functional tests for Unpack Op.""" +"""Functional tests for Unstack Op.""" from __future__ import absolute_import from __future__ import division @@ -49,7 +49,7 @@ class UnstackOpTest(test.TestCase): data = np.random.randn(*shape).astype(dtype) # Convert data to a single tensorflow tensor x = constant_op.constant(data) - # Unpack into a list of tensors + # Unstack into a list of tensors cs = array_ops.unstack(x, num=shape[0]) self.assertEqual(type(cs), list) self.assertEqual(len(cs), shape[0]) @@ -66,7 +66,7 @@ class UnstackOpTest(test.TestCase): data = np.random.randn(*shape).astype(dtype) # Convert data to a single tensorflow tensor x = constant_op.constant(data) - # Unpack into a list of tensors + # Unstack into a list of tensors cs = array_ops.unstack(x, num=shape[0]) self.assertEqual(type(cs), list) self.assertEqual(len(cs), shape[0]) diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index f60ebf58f6fe81bf75fa4db166449843e5595c7d..b16c8c002c98a0351d1fc55fce061695327a18c9 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -22,6 +22,7 @@ import operator import numpy as np +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -509,6 +510,15 @@ class VariablesTestCase(test.TestCase): "", repr(var)) + def testVariableNamesPreserveNameScopesWithDefun(self): + @function.defun + def create_variable(): + with ops.name_scope("foo"): + v = variables.Variable(0.0, name="bar") + self.assertEqual(v.name, "foo/bar:0") + with ops.get_default_graph().as_default(): + create_variable() + class IsInitializedTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py index c6c7c4e26cb5e4eff22d1bb9d3e32c227c1c838f..e152f02d8e983364603053dc5c8d14b5dfaf3605 100644 --- a/tensorflow/python/kernel_tests/xent_op_test.py +++ b/tensorflow/python/kernel_tests/xent_op_test.py @@ -38,9 +38,8 @@ class XentTest(test.TestCase): dim = len(features.shape) - 1 one_only_on_dim = list(features.shape) one_only_on_dim[dim] = 1 - e = np.exp(features - np.reshape( - np.amax( - features, axis=dim), one_only_on_dim)) + e = np.exp( + features - np.reshape(np.amax(features, axis=dim), one_only_on_dim)) probs = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim) bp = (probs - labels) l = -np.sum(labels * np.log(probs + 1.0e-20), axis=dim) @@ -85,10 +84,10 @@ class XentTest(test.TestCase): def testRankTooLarge(self): for dtype in np.float16, np.float32: - np_features = np.array( - [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(dtype) - np_labels = np.array( - [[[0., 0., 0., 1.]], [[0., .5, .5, 0.]]]).astype(dtype) + np_features = np.array([[[1., 1., 1., 1.]], [[1., 2., 3., + 4.]]]).astype(dtype) + np_labels = np.array([[[0., 0., 0., 1.]], [[0., .5, .5, + 0.]]]).astype(dtype) self.assertRaisesRegexp(ValueError, "must be rank 2", gen_nn_ops._softmax_cross_entropy_with_logits, np_features, np_labels) @@ -121,8 +120,8 @@ class XentTest(test.TestCase): # = [1.3862, 1.9401] np_loss, np_backprop = self._npXent(np.array(features), np.array(labels)) self.assertAllClose( - np.array([[0.25, 0.25, 0.25, -0.75], - [0.0321, -0.4129, -0.2632, 0.6439]]), + np.array([[0.25, 0.25, 0.25, -0.75], [0.0321, -0.4129, -0.2632, + 0.6439]]), np_backprop, rtol=1.e-3, atol=1.e-3) @@ -168,15 +167,17 @@ class XentTest(test.TestCase): shape=[3, 4], dtype=dtypes.float64, name="f") - x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f, - name="xent") + x = nn_ops.softmax_cross_entropy_with_logits( + labels=l, logits=f, name="xent") err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3]) # Check that no extra computation performed. When only first derivative is requested, # second derivative must not be computed. So when there is no second derivative, # there is no `BatchMatMul` op in the graph. - op_names = [op.op_def.name for op in sess.graph.get_operations() if op.op_def] - self.assertNotIn('BatchMatMul', op_names) + op_names = [ + op.op_def.name for op in sess.graph.get_operations() if op.op_def + ] + self.assertNotIn("BatchMatMul", op_names) print("cross entropy gradient err = ", err) self.assertLess(err, 5e-8) @@ -193,24 +194,29 @@ class XentTest(test.TestCase): shape=[3, 4], dtype=dtypes.float64, name="f") - x = nn_ops.softmax_cross_entropy_with_logits_v2(labels=l, logits=f, - name="xent") + x = nn_ops.softmax_cross_entropy_with_logits_v2( + labels=l, logits=f, name="xent") err = gradient_checker.compute_gradient_error(l, [3, 4], x, [3]) self.assertLess(err, 5e-8) def testSecondGradient(self): with self.test_session() as sess: - l = constant_op.constant([0.0, 0.0, 1.0/3, 0.0, - 1.0/3, 0.0, 0.0, 0.0, - 0.0, 0.5/3, 0.0, 0.5/3], shape=[12], - dtype=dtypes.float64, name="l") - f = constant_op.constant([0.1, 0.2, 0.3, 0.4, - 0.1, 0.4, 0.9, 1.6, - 0.1, 0.8, 2.7, 6.4], shape=[12], - dtype=dtypes.float64, name="f") - x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f, - name="xent") + l = constant_op.constant( + [ + 0.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, 0.0, 0.0, 0.0, 0.0, 0.5 / 3, 0.0, + 0.5 / 3 + ], + shape=[12], + dtype=dtypes.float64, + name="l") + f = constant_op.constant( + [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4], + shape=[12], + dtype=dtypes.float64, + name="f") + x = nn_ops.softmax_cross_entropy_with_logits( + labels=l, logits=f, name="xent") loss = math_ops.reduce_sum(x) gradients = gradients_impl.gradients(loss, [f])[0] @@ -219,20 +225,23 @@ class XentTest(test.TestCase): # Check that second derivative is calculated. # (it is equivalent to being `BatchMatMul` op in the graph because of implementation of xentropy grad) - op_names = [op.op_def.name for op in sess.graph.get_operations() if op.op_def] - self.assertIn('BatchMatMul', op_names) + op_names = [ + op.op_def.name for op in sess.graph.get_operations() if op.op_def + ] + self.assertIn("BatchMatMul", op_names) print("cross entropy hessian err = ", err) self.assertLess(err, 5e-8) def testWrapper(self): - features = np.array( - [[[1., 1., 1., 1.], [1., 2., 3., 4.]], - [[2., 3., 4., 5.], [6., 7., 8., 9.]], - [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32) + features = np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]], + [[2., 3., 4., 5.], [6., 7., 8., 9.]], + [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype( + np.float32) labels = np.array([[[0., 0., 0., 1.], [0., 1., 0., 0.]], [[0., 0.5, 0.5, 0.], [0.5, 0.5, 0., 0.]], - [[0., 1., 0., 0.], [0., 0., 1., 0.]]]).astype(np.float32) + [[0., 1., 0., 0.], [0., 0., 1., 0.]]]).astype( + np.float32) self._testXentWrapper(features, labels, dim=0, use_gpu=False) self._testXentWrapper(features, labels, dim=0, use_gpu=True) self._testXentWrapper(features, labels, dim=1, use_gpu=False) diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 00faf3faa1004ddbb310137500dbec0db4a52196..8314c4aa87a5b54effc44c371703267517ffa07d 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -31,13 +31,16 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import utils as layers_util +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export +@tf_export('layers.Layer') class Layer(object): """Base layer class. @@ -99,8 +102,16 @@ class Layer(object): raise TypeError('Keyword argument not understood:', kwarg) # Mutable properties + # Indicates whether the layer's weights are updated during training + # and whether the layer's updates are run during training self.trainable = trainable + # A stateful layer is a layer whose updates are run during inference too, + # for instance stateful RNNs. + self.stateful = False + # Indicates whether `build` needs to be called upon layer call, to create + # the layer's weights. self.built = False + # Provides information about which inputs are compatible with the layer. self.input_spec = None if activity_regularizer and context.in_eager_mode(): @@ -116,8 +127,6 @@ class Layer(object): self._losses = [] self._reuse = kwargs.get('_reuse') self._graph = ops.get_default_graph() - self._per_input_losses = {} - self._per_input_updates = {} self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name call_fn_args = estimator_util.fn_args(self.call) self._compute_previous_mask = ('mask' in call_fn_args or @@ -131,9 +140,6 @@ class Layer(object): self._init_set_name(name) - # Holds functions for creating regularizer ops. - self._regularizer_factories = [] - # Determine variable scope. scope = kwargs.get('_scope') if scope: @@ -223,6 +229,8 @@ class Layer(object): def updates(self): if context.in_eager_mode(): raise RuntimeError('Layer.updates not supported in Eager mode.') + if not self.trainable and not self.stateful: + return [] return self._updates def add_update(self, updates, inputs=None): @@ -242,39 +250,34 @@ class Layer(object): Arguments: updates: Update op, or list/tuple of update ops. - inputs: Optional input tensor(s) that the update(s) depend on. Must - match the `inputs` argument passed to the `__call__` method at the time - the updates are created. If `None` is passed, the updates are assumed - to be unconditional, and will apply across all dataflows of the layer. + inputs: If anything other than None is passed, it signals the updates + are conditional on some of the layer's inputs, + and thus they should only be run where these inputs are available. + This is the case for BatchNormalization updates, for instance. + If None, the updates will be taken into account unconditionally, + and you are responsible for making sure that any dependency they might + have is available at runtime. + A step counter might fall into this category. """ if context.in_eager_mode(): return # Updates already applied when in eager mode. + updates = _to_list(updates) - if not updates: - return + updates = [x if isinstance(x, ops.Operation) + else ops.convert_to_tensor(x) for x in updates] self._updates += updates - if inputs is not None: - inputs = nest.flatten(inputs) - if not inputs: - inputs = None - if inputs is not None: - # We compute an ID that uniquely identifies the list of tensors. - # This ID is order-sensitive. - inputs_hash = layers_util.object_list_uid(inputs) + if inputs is None: + for u in updates: + u._unconditional_update = True # pylint: disable=protected-access else: - inputs_hash = None - if inputs_hash not in self._per_input_updates: - self._per_input_updates[inputs_hash] = [] - self._per_input_updates[inputs_hash] += updates + for u in updates: + u._unconditional_update = False # pylint: disable=protected-access def get_updates_for(self, inputs): """Retrieves updates relevant to a specific set of inputs. Arguments: inputs: Input tensor or list/tuple of input tensors. - Must match the `inputs` argument passed to the `__call__` method - at the time the updates were created. - If you pass `inputs=None`, unconditional updates are returned. Returns: List of update ops of the layer that depend on `inputs`. @@ -283,32 +286,24 @@ class Layer(object): RuntimeError: If called in Eager mode. """ if context.in_eager_mode(): - raise RuntimeError('Layer.get_updates_for not supported in Eager mode.') - if inputs is not None: - inputs = nest.flatten(inputs) - if not inputs: - inputs = None - if inputs is not None: - inputs_hash = layers_util.object_list_uid(inputs) - else: - inputs_hash = None - return self._per_input_updates.get(inputs_hash, []) - - def _get_regularizer_factories(self): - try: - # Some subclasses of Layer do not use its constructor. - return self._regularizer_factories - except AttributeError: - self._regularizer_factories = [] - return self._regularizer_factories - - def _maybe_create_variable_regularizers(self): - """Creates added but uninstantiated regularizers.""" - factories = self._get_regularizer_factories() - if factories: - for factory in factories: - factory() - factories[:] = [] + raise RuntimeError('`get_updates_for()` not supported in Eager mode.') + + # Updates disabled if layer is not trainable and not explicitly stateful. + if not self.trainable and not self.stateful: + return [] + + if inputs is None: + # Requesting unconditional updates. + return [x for x in self.updates if x._unconditional_update] # pylint: disable=protected-access + + # Requesting input-conditional updates. + inputs = nest.flatten(inputs) + reachable = layers_util.get_reachable_from_inputs(inputs, self.updates) + updates = [] + for update in self.updates: + if update in reachable: + updates.append(update) + return updates @property def losses(self): @@ -321,7 +316,6 @@ class Layer(object): Returns: A list of tensors. """ - self._maybe_create_variable_regularizers() if context.in_eager_mode(): # _losses may only contain variable regularization losses when executing # eagerly, and they have been saved as lambdas to be executed when @@ -349,9 +343,11 @@ class Layer(object): Arguments: losses: Loss tensor, or list/tuple of tensors. - inputs: Optional input tensor(s) that the loss(es) depend on. Must - match the `inputs` argument passed to the `__call__` method at the time - the losses are created. If `None` is passed, the losses are assumed + inputs: If anything other than None is passed, it signals the losses + are conditional on some of the layer's inputs, + and thus they should only be run where these inputs are available. + This is the case for activity regularization losses, for instance. + If `None` is passed, the losses are assumed to be unconditional, and will apply across all dataflows of the layer (e.g. weight regularization losses). @@ -359,24 +355,25 @@ class Layer(object): RuntimeError: If called in Eager mode. """ if context.in_eager_mode(): + # TODO(fchollet): it should be possible (and highly desirable) to support + # `add_loss` in eager mode. This allows great convenience and flexibility + # in defining custom losses on the fly (e.g. in VAEs). + # Simply appending the loss value to `self._losses` + # is the correct behavior. + # The only caveat is that we need to force the user to only call + # `add_loss` from inside a model or Layer's `call` method + # (otherwise the loss computation cannot be backproped through). raise RuntimeError('Layer.add_loss not supported in Eager mode.') + losses = _to_list(losses) - if not losses: - return self._losses += losses - if inputs is not None: - inputs = nest.flatten(inputs) - if not inputs: - inputs = None - if inputs is not None: - # We compute an ID that uniquely identifies the list of tensors. - # This ID is order-sensitive. - inputs_hash = layers_util.object_list_uid(inputs) + if inputs is None: + for loss in losses: + loss._unconditional_loss = True # pylint: disable=protected-access else: - inputs_hash = None - if inputs_hash not in self._per_input_losses: - self._per_input_losses[inputs_hash] = [] - self._per_input_losses[inputs_hash] += losses + for loss in losses: + loss._unconditional_loss = False # pylint: disable=protected-access + # TODO(fchollet): deprecate collection below. _add_elements_to_collection(losses, ops.GraphKeys.REGULARIZATION_LOSSES) def get_losses_for(self, inputs): @@ -384,10 +381,6 @@ class Layer(object): Arguments: inputs: Input tensor or list/tuple of input tensors. - Must match the `inputs` argument passed to the `__call__` - method at the time the losses were created. - If you pass `inputs=None`, unconditional losses are returned, - such as weight regularization losses. Returns: List of loss tensors of the layer that depend on `inputs`. @@ -397,16 +390,23 @@ class Layer(object): """ if context.in_eager_mode(): raise RuntimeError('Layer.get_losses_for not supported in Eager mode.') - if inputs is not None: - inputs = nest.flatten(inputs) - if not inputs: - inputs = None - if inputs is not None: - inputs_hash = layers_util.object_list_uid(inputs) - else: - inputs_hash = None - self._maybe_create_variable_regularizers() - return self._per_input_losses.get(inputs_hash, []) + + if inputs is None: + # Requesting unconditional losses. + return [x for x in self.losses if x._unconditional_loss] # pylint: disable=protected-access + + # Requesting input-conditional losses. + inputs = nest.flatten(inputs) + # Retrieve the set of tensors in the TF graph that depend on `inputs`. + # The losses we want to return will be part of this set. + # To avoid unnecessary work, we stop the search in case all of + # `self.losses` have been retrieved. + reachable = layers_util.get_reachable_from_inputs(inputs, self.losses) + losses = [] + for loss in self.losses: + if loss in reachable: + losses.append(loss) + return losses def build(self, _): """Creates the variables of the layer.""" @@ -500,13 +500,30 @@ class Layer(object): instance is returned. Raises: - RuntimeError: If called in Eager mode with partioned variable - regularization. + RuntimeError: If called with partioned variable regularization and + eager execution is enabled. """ - in_graph_mode = context.in_graph_mode() - if in_graph_mode: - existing_variables = set(tf_variables.global_variables()) + # `init_graph` should point to the graph in which variable initialization + # will occur; it should be None if and only if initialization will take + # place in the eager context. + init_graph = None + if context.in_graph_mode(): + default_graph = ops.get_default_graph() + if default_graph.building_function: + with ops.init_scope(): + # Retrieve the variables from the graph into which variables + # will be lifted; if initialization ops will be lifted into + # the eager context, then there is nothing to retrieve, since variable + # collections are not supported when eager execution is enabled. + if context.in_graph_mode(): + init_graph = ops.get_default_graph() + existing_variables = set(tf_variables.global_variables()) + else: + # Initialization ops will not be lifted out of the default graph. + init_graph = default_graph + existing_variables = set(tf_variables.global_variables()) + if dtype is None: dtype = self.dtype or dtypes.float32 @@ -523,54 +540,51 @@ class Layer(object): trainable=trainable and self.trainable, partitioner=partitioner) - if in_graph_mode: - if (trainable and self.trainable - and variable not in tf_variables.trainable_variables()): - # A custom getter / variable scope overrode the trainable flag. - trainable = False + if init_graph is not None: # pylint: disable=protected-access + # The variable was created and initialized in a graph. + if variable in existing_variables: # To match the behavior of tf.get_variable(), we only apply # regularization if the variable is newly created. return variable - if regularizer: - def regularizer_factory(): - if context.in_graph_mode(): - with vs.variable_scope(scope, reuse=reuse, - auxiliary_name_scope=False): - with ops.name_scope(self._name_scope_name(scope)): - if isinstance(variable, tf_variables.PartitionedVariable): - for v in variable: - with ops.colocate_with(v.op): - with ops.name_scope(name + '/Regularizer'): - regularization = regularizer(v) - if regularization is not None: - self.add_loss(regularization) - else: - with ops.colocate_with(variable.op): - with ops.name_scope(name + '/Regularizer'): - regularization = regularizer(variable) - if regularization is not None: - self.add_loss(regularization) + with init_graph.as_default(): + trainable_variables = tf_variables.trainable_variables() + if (trainable and self.trainable and + variable not in trainable_variables): + # A custom getter / variable scope overrode the trainable flag. + trainable = False + + if regularizer: + if isinstance(variable, tf_variables.PartitionedVariable): + for v in variable: + with ops.colocate_with(v.op): + with ops.name_scope(name + '/Regularizer'): + regularization = regularizer(v) + if regularization is not None: + self.add_loss(regularization) else: - if isinstance(variable, tf_variables.PartitionedVariable): - raise RuntimeError( - 'Partitioned variable regularization is not yet ' - 'supported when executing eagerly. File a feature request' - 'if this is important to you.') - # Save a zero-argument lambda which runs the regularizer on the - # variable, to be executed when `Layer.losses` is requested. - # This makes losses responsive to variable updates when - # executing eagerly. - self._losses.append(lambda: regularizer(variable)) - - if hasattr(self, '_defer_regularizers') and self._defer_regularizers: - # _defer_regularizers exists and is set to True if `build` was - # invoked in `__call__`: deferring regularizer construction - # prevents the regularizer from being created in an `init_scope`. - self._get_regularizer_factories().append(regularizer_factory) - else: - regularizer_factory() + with ops.colocate_with(variable.op): + with ops.name_scope(name + '/Regularizer'): + regularization = regularizer(variable) + if regularization is not None: + self.add_loss(regularization) + elif regularizer: # and initialization took place in an eager context + if isinstance(variable, tf_variables.PartitionedVariable): + raise RuntimeError( + 'Partitioned variable regularization is not yet ' + 'supported when executing eagerly. File a feature request' + 'if this is important to you.') + # Save a zero-argument lambda which runs the regularizer on the + # variable, to be executed when `Layer.losses` is requested. + # This makes losses responsive to variable updates when executing + # eagerly. + # + # TODO(akshayka): Do the same for graphs as well, so that losses + # collected in a while_loop can be run outside its control flow + # context and so that losses won't be swallowed up by graph functions + # (i.e., `.losses()` should always create regularizers). + self._losses.append(lambda: regularizer(variable)) if trainable: self._trainable_weights.append(variable) @@ -644,6 +658,7 @@ class Layer(object): else: scope_context_manager = vs.variable_scope( self._scope, reuse=self._reuse, auxiliary_name_scope=False) + input_shapes = None with scope_context_manager as scope: with ops.name_scope(self._name_scope_name(scope)): if not self.built: @@ -670,15 +685,7 @@ class Layer(object): except AttributeError: pass input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs) - - # Signal to `add_variable` that regularizer construction should be - # deferred. - self._defer_regularizers = True - with ops.init_scope(): - self.build(input_shapes) - # Create any regularizers added by `build`. - self._maybe_create_variable_regularizers() - self._defer_regularizers = False + self.build(input_shapes) try: # Note: not all sub-classes of Layer call Layer.__init__ (especially # the ones under tensorflow/python/keras). Hence we recompute this @@ -701,6 +708,9 @@ class Layer(object): else: # Deferred mode behavior: use `compute_output_shape` to # infer the number of outputs of the layer and their shapes. + if input_shapes is None: + input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs) + output_shapes = self.compute_output_shape(input_shapes) output_shapes = nest.flatten(output_shapes) outputs = [ @@ -722,12 +732,10 @@ class Layer(object): activity_regularization = self._activity_regularizer(output) self.add_loss(activity_regularization, inputs=inputs) - if not in_deferred_mode: - # TODO(fchollet): consider how masking will work with deferred mode. - # Handle mask computation and propagation to the next layer. + # TODO(fchollet): consider enabling masking for Eager mode. if hasattr(self, 'compute_mask'): output_mask = self.compute_mask(inputs, previous_mask) - if isinstance(outputs, list): + if isinstance(outputs, (list, tuple)): if output_mask is None: output_mask = [None for _ in range(len(outputs))] for x, m in zip(outputs, output_mask): @@ -1226,6 +1234,7 @@ class Layer(object): ', found shape=' + str(shape)) +@tf_export('keras.layers.InputSpec', 'layers.InputSpec') class InputSpec(object): """Specifies the ndim, dtype and shape of every input to a layer. @@ -1263,6 +1272,15 @@ class InputSpec(object): self.min_ndim = min_ndim self.axes = axes or {} + def __repr__(self): + spec = [('dtype=' + str(self.dtype)) if self.dtype else '', + ('shape=' + str(self.shape)) if self.shape else '', + ('ndim=' + str(self.ndim)) if self.ndim else '', + ('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '', + ('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '', + ('axes=' + str(self.axes)) if self.axes else ''] + return 'InputSpec(%s)' % ', '.join(x for x in spec if x) + class Node(object): """A `Node` describes the connectivity between two layers. @@ -1387,7 +1405,10 @@ class _DeferredTensor(object): def __init__(self, shape, dtype, name=None): self.shape = tensor_shape.TensorShape(shape) - self.dtype = dtypes.as_dtype(dtype) + if dtype is None: + self.dtype = dtypes.as_dtype(np.float32) + else: + self.dtype = dtypes.as_dtype(dtype) self.name = name def get_shape(self): diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index 06ba214c0fc60202c773f8f231b17c3b728f5c52..91b8988d31c1f04be8134733e5e919c738ccb74f 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -555,6 +556,93 @@ class BaseLayerTest(test.TestCase): self.assertEqual(len(layer.trainable_variables), 1) self.assertEqual(layer.variables[0].graph, outer_graph) + def testGetUpdateFor(self): + + class MyLayer(base_layers.Layer): + + def build(self, input_shape): + self.a = self.add_variable('a', + (), + dtypes.float32, + trainable=False) + self.b = self.add_variable('b', + (), + dtypes.float32, + trainable=False) + self.add_update(state_ops.assign_add(self.a, 1., name='b_update')) + self.built = True + + def call(self, inputs): + self.add_update(state_ops.assign_add(self.a, inputs, name='a_update'), + inputs=True) + return inputs + 1 + + layer = MyLayer() + inputs = array_ops.placeholder(dtypes.float32, (), 'inputs') + intermediate_inputs = inputs + 1 + outputs = layer.apply(intermediate_inputs) + + self.assertEqual(len(layer.updates), 2) + self.assertEqual(len(layer.get_updates_for(None)), 1) + self.assertEqual(len(layer.get_updates_for([inputs])), 1) + self.assertEqual(len(layer.get_updates_for([intermediate_inputs])), 1) + self.assertEqual(len(layer.get_updates_for([outputs])), 0) + + # Call same layer on new input, creating one more conditional update + inputs = array_ops.placeholder(dtypes.float32, (), 'inputs') + intermediate_inputs = inputs + 1 + outputs = layer.apply(intermediate_inputs) + + self.assertEqual(len(layer.updates), 3) + self.assertEqual(len(layer.get_updates_for(None)), 1) + # Check that we are successfully filtering out irrelevant updates + self.assertEqual(len(layer.get_updates_for([inputs])), 1) + self.assertEqual(len(layer.get_updates_for([intermediate_inputs])), 1) + self.assertEqual(len(layer.get_updates_for([outputs])), 0) + + def testGetLossesFor(self): + + class MyLayer(base_layers.Layer): + + def build(self, input_shape): + self.a = self.add_variable('a', + (), + dtypes.float32, + trainable=False) + self.b = self.add_variable('b', + (), + dtypes.float32, + trainable=False) + self.add_loss(self.a) + self.built = True + + def call(self, inputs): + self.add_loss(inputs, inputs=True) + return inputs + 1 + + layer = MyLayer() + inputs = array_ops.placeholder(dtypes.float32, (), 'inputs') + intermediate_inputs = inputs + 1 + outputs = layer.apply(intermediate_inputs) + + self.assertEqual(len(layer.losses), 2) + self.assertEqual(len(layer.get_losses_for(None)), 1) + self.assertEqual(len(layer.get_losses_for([inputs])), 1) + self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1) + self.assertEqual(len(layer.get_losses_for([outputs])), 0) + + # Call same layer on new input, creating one more conditional loss + inputs = array_ops.placeholder(dtypes.float32, (), 'inputs') + intermediate_inputs = inputs + 1 + outputs = layer.apply(intermediate_inputs) + + self.assertEqual(len(layer.losses), 3) + self.assertEqual(len(layer.get_losses_for(None)), 1) + # Check that we are successfully filtering out irrelevant losses + self.assertEqual(len(layer.get_losses_for([inputs])), 1) + self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1) + self.assertEqual(len(layer.get_losses_for([outputs])), 0) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index e8dba3cea321a415b84e1ec89fd7b021e2b272d0..bb10fe5e8bfd26e4877fb6aef73980a30f62bb5d 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops +from tensorflow.python.util.tf_export import tf_export class _Conv(base.Layer): @@ -222,6 +223,7 @@ class _Conv(base.Layer): new_space) +@tf_export('layers.Conv1D') class Conv1D(_Conv): """1D convolution layer (e.g. temporal convolution). @@ -311,6 +313,7 @@ class Conv1D(_Conv): name=name, **kwargs) +@tf_export('layers.conv1d') def conv1d(inputs, filters, kernel_size, @@ -411,6 +414,7 @@ def conv1d(inputs, return layer.apply(inputs) +@tf_export('layers.Conv2D') class Conv2D(_Conv): """2D convolution layer (e.g. spatial convolution over images). @@ -507,6 +511,7 @@ class Conv2D(_Conv): name=name, **kwargs) +@tf_export('layers.conv2d') def conv2d(inputs, filters, kernel_size, @@ -614,6 +619,7 @@ def conv2d(inputs, return layer.apply(inputs) +@tf_export('layers.Conv3D') class Conv3D(_Conv): """3D convolution layer (e.g. spatial convolution over volumes). @@ -711,6 +717,7 @@ class Conv3D(_Conv): name=name, **kwargs) +@tf_export('layers.conv3d') def conv3d(inputs, filters, kernel_size, @@ -980,6 +987,7 @@ class _SeparableConv(_Conv): raise NotImplementedError +@tf_export('layers.SeparableConv1D') class SeparableConv1D(_SeparableConv): """Depthwise separable 1D convolution. @@ -1088,10 +1096,10 @@ class SeparableConv1D(_SeparableConv): def call(self, inputs): if self.data_format == 'channels_last': - strides = (1, 1) + self.strides + (1,) + strides = (1,) + self.strides * 2 + (1,) spatial_start_dim = 1 else: - strides = (1, 1, 1) + self.strides + strides = (1, 1) + self.strides * 2 spatial_start_dim = 2 # Explicitly broadcast inputs and kernels to 4D. @@ -1123,6 +1131,7 @@ class SeparableConv1D(_SeparableConv): return outputs +@tf_export('layers.SeparableConv2D') class SeparableConv2D(_SeparableConv): """Depthwise separable 2D convolution. @@ -1260,6 +1269,7 @@ class SeparableConv2D(_SeparableConv): return outputs +@tf_export('layers.separable_conv1d') def separable_conv1d(inputs, filters, kernel_size, @@ -1376,6 +1386,7 @@ def separable_conv1d(inputs, return layer.apply(inputs) +@tf_export('layers.separable_conv2d') def separable_conv2d(inputs, filters, kernel_size, @@ -1497,6 +1508,7 @@ def separable_conv2d(inputs, return layer.apply(inputs) +@tf_export('layers.Conv2DTranspose') class Conv2DTranspose(Conv2D): """Transposed 2D convolution layer (sometimes called 2D Deconvolution). @@ -1695,6 +1707,7 @@ class Conv2DTranspose(Conv2D): return tensor_shape.TensorShape(output_shape) +@tf_export('layers.conv2d_transpose') def conv2d_transpose(inputs, filters, kernel_size, @@ -1790,6 +1803,7 @@ def conv2d_transpose(inputs, return layer.apply(inputs) +@tf_export('layers.Conv3DTranspose') class Conv3DTranspose(Conv3D): """Transposed 3D convolution layer (sometimes called 3D Deconvolution). @@ -2018,6 +2032,7 @@ class Conv3DTranspose(Conv3D): return tensor_shape.TensorShape(output_shape) +@tf_export('layers.conv3d_transpose') def conv3d_transpose(inputs, filters, kernel_size, diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index e5b93a54f79bef68d96ab7efccc883033e7001c7..6970bf9234f5a31ee8093069ac1c933bcdb6f103 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -36,9 +36,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import standard_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export('layers.Dense') class Dense(base.Layer): """Densely-connected layer class. @@ -49,9 +52,6 @@ class Dense(base.Layer): and `bias` is a bias vector created by the layer (only if `use_bias` is `True`). - Note: if the input to the layer has a rank greater than 2, then it is - flattened prior to the initial matrix multiply by `kernel`. - Arguments: units: Integer or Long, dimensionality of the output space. activation: Activation function (callable). Set it to None to maintain a @@ -176,6 +176,7 @@ class Dense(base.Layer): return input_shape[:-1].concatenate(self.units) +@tf_export('layers.dense') def dense( inputs, units, activation=None, @@ -199,9 +200,6 @@ def dense( and `bias` is a bias vector created by the layer (only if `use_bias` is `True`). - Note: if the `inputs` tensor has a rank greater than 2, then it is - flattened prior to the initial matrix multiply by `kernel`. - Arguments: inputs: Tensor input. units: Integer or Long, dimensionality of the output space. @@ -230,7 +228,8 @@ def dense( by the same name. Returns: - Output tensor. + Output tensor the same shape as `inputs` except the last dimension is of + size `units`. Raises: ValueError: if eager execution is enabled. @@ -253,6 +252,7 @@ def dense( return layer.apply(inputs) +@tf_export('layers.Dropout') class Dropout(base.Layer): """Applies Dropout to the input. @@ -292,13 +292,7 @@ class Dropout(base.Layer): # shapes with dynamically sized inputs. if self.noise_shape is None: return self.noise_shape - - symbolic_shape = array_ops.shape(inputs) - noise_shape = [ - symbolic_shape[axis] if shape is None else shape - for axis, shape in enumerate(self.noise_shape) - ] - return noise_shape + return nn_ops._get_noise_shape(inputs, self.noise_shape) def call(self, inputs, training=False): @@ -314,6 +308,7 @@ class Dropout(base.Layer): return input_shape +@tf_export('layers.dropout') def dropout(inputs, rate=0.5, noise_shape=None, @@ -355,6 +350,7 @@ def dropout(inputs, return layer.apply(inputs, training=training) +@tf_export('layers.Flatten') class Flatten(base.Layer): """Flattens an input tensor while preserving the batch axis (axis 0). @@ -391,6 +387,7 @@ class Flatten(base.Layer): return tensor_shape.TensorShape(output_shape) +@tf_export('layers.flatten') def flatten(inputs, name=None): """Flattens an input tensor while preserving the batch axis (axis 0). diff --git a/tensorflow/python/layers/maxout.py b/tensorflow/python/layers/maxout.py index ed048845a0b88344b357836a838231677cbf40ce..765a1c4fdafdfdc5d3ea6629d4d9290d8b658902 100644 --- a/tensorflow/python/layers/maxout.py +++ b/tensorflow/python/layers/maxout.py @@ -31,15 +31,18 @@ from tensorflow.python.layers import base def maxout(inputs, num_units, axis=-1, name=None): """Adds a maxout op from https://arxiv.org/abs/1302.4389 - "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, + "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron + Courville, Yoshua Bengio - Usually the operation is performed in the filter/channel dimension. This can also be + Usually the operation is performed in the filter/channel dimension. This can + also be used after fully-connected layers to reduce number of features. Arguments: inputs: Tensor input - num_units: Specifies how many features will remain after maxout in the `axis` dimension + num_units: Specifies how many features will remain after maxout in the `axis` + dimension (usually channel). This must be multiple of number of `axis`. axis: The dimension where max pooling will be performed. Default is the last dimension. @@ -57,15 +60,18 @@ def maxout(inputs, num_units, axis=-1, name=None): class MaxOut(base.Layer): """Adds a maxout op from https://arxiv.org/abs/1302.4389 - "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, Yoshua + "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron + Courville, Yoshua Bengio - Usually the operation is performed in the filter/channel dimension. This can also be + Usually the operation is performed in the filter/channel dimension. This can + also be used after fully-connected layers to reduce number of features. Arguments: inputs: Tensor input - num_units: Specifies how many features will remain after maxout in the `axis` dimension + num_units: Specifies how many features will remain after maxout in the + `axis` dimension (usually channel). This must be multiple of number of `axis`. axis: The dimension where max pooling will be performed. Default is the @@ -79,13 +85,8 @@ class MaxOut(base.Layer): ValueError: if num_units is not multiple of number of features. """ - def __init__(self, - num_units, - axis=-1, - name=None, - **kwargs): - super(MaxOut, self).__init__( - name=name, trainable=False, **kwargs) + def __init__(self, num_units, axis=-1, name=None, **kwargs): + super(MaxOut, self).__init__(name=name, trainable=False, **kwargs) self.axis = axis self.num_units = num_units @@ -95,8 +96,8 @@ class MaxOut(base.Layer): num_channels = shape[self.axis] if num_channels % self.num_units: raise ValueError('number of features({}) is not ' - 'a multiple of num_units({})' - .format(num_channels, self.num_units)) + 'a multiple of num_units({})'.format( + num_channels, self.num_units)) shape[self.axis] = -1 shape += [num_channels // self.num_units] @@ -104,6 +105,7 @@ class MaxOut(base.Layer): for i in range(len(shape)): if shape[i] is None: shape[i] = gen_array_ops.shape(inputs)[i] - outputs = math_ops.reduce_max(gen_array_ops.reshape(inputs, shape), -1, keep_dims=False) + outputs = math_ops.reduce_max( + gen_array_ops.reshape(inputs, shape), -1, keepdims=False) return outputs diff --git a/tensorflow/python/layers/network.py b/tensorflow/python/layers/network.py index ade57da411d67241e027e0bb559e49bc3c077e6d..9f16559687c52a1149b78a1ccca796cadd8208d0 100644 --- a/tensorflow/python/layers/network.py +++ b/tensorflow/python/layers/network.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export class InputLayer(base.Layer): @@ -117,6 +118,7 @@ class InputLayer(base.Layer): output_tensors=[input_tensor]) +@tf_export('layers.Input') def Input( # pylint: disable=invalid-name shape=None, batch_size=None, @@ -221,13 +223,36 @@ class GraphNetwork(base.Layer): - get_layer: retrieves a child layer by name or index in the graph. Raises: - RuntimeError: If created in Eager mode. + TypeError: If created when eager execution is enabled, with inputs that + don't come from a call to `Input` or outputs that don't come from layers. """ def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called + if isinstance(inputs, (list, tuple)): + self.inputs = list(inputs) # Tensor or list of tensors. + else: + self.inputs = [inputs] + if isinstance(outputs, (list, tuple)): + self.outputs = list(outputs) + else: + self.outputs = [outputs] + if context.in_eager_mode(): - # TODO(fchollet): check that all inputs and outputs are DeferredTensors. - pass + # Check that all inputs/outputs are DeferredTensors. + for tensor in self.inputs: + if not isinstance(tensor, base._DeferredTensor): # pylint: disable=protected-access + raise TypeError('When eager execution is enabled, ' + 'inputs must come from a call to ' + '`tf.keras.Input` (called after ' + 'tfe.enable_eager_execution()). ' + 'Received invalid input: ' + str(tensor)) + for tensor in self.outputs: + if not isinstance(tensor, base._DeferredTensor): # pylint: disable=protected-access + raise TypeError('When eager execution is enabled, ' + 'outputs must come from a call to ' + 'a layer (called after ' + 'tfe.enable_eager_execution()). ' + 'Received invalid output: ' + str(tensor)) self._init_set_name(name) self._activity_regularizer = None @@ -248,32 +273,22 @@ class GraphNetwork(base.Layer): self.built = True # A GraphNetwork does not create weights of its own, thus has no dtype. self._dtype = None + self._is_graph_network = True # The following are implemented as property functions: # self.trainable_weights # self.non_trainable_weights # self.input_spec # Private attributes to implement compatibility with Layer. - self._per_input_losses = {} - self._per_input_updates = {} self._updates = [] self._losses = [] self._scope = None self._reuse = None self._graph = ops.get_default_graph() - # GraphNetwork-specific properties. - if isinstance(inputs, (list, tuple)): - self.inputs = list(inputs) # Tensor or list of tensors. - else: - self.inputs = [inputs] - if isinstance(outputs, (list, tuple)): - self.outputs = list(outputs) - else: - self.outputs = [outputs] # All layers in order of horizontal graph traversal. # Entries are unique. Includes input and output layers. - self.layers = [] + self._layers = [] # Check for redundancy in inputs. if len(set(self.inputs)) != len(self.inputs): @@ -483,7 +498,7 @@ class GraphNetwork(base.Layer): # here we order them by traversal order. layers_for_depth.sort(key=lambda x: layer_indices[x]) layers.extend(layers_for_depth) - self.layers = layers + self._layers = layers self._layers_by_depth = layers_by_depth # Get sorted list of node depths. @@ -542,6 +557,10 @@ class GraphNetwork(base.Layer): input_tensors=self.inputs, output_tensors=self.outputs) + @property + def layers(self): + return self._layers + def get_layer(self, name=None, index=None): """Retrieves a layer based on either its name (unique) or index. @@ -574,32 +593,86 @@ class GraphNetwork(base.Layer): return layer raise ValueError('No such layer: ' + name) + @property + def stateful(self): + return any([(hasattr(layer, 'stateful') and layer.stateful) + for layer in self.layers]) + @property def updates(self): """Retrieve the network's updates. Will only include updates that are either unconditional, or conditional on inputs to this model - (e.g. will not include updates that depend on tensors - that aren't inputs to this model). + (e.g. will not include updates that were created by layers of this model + outside of the model). + + Effectively, `network.updates` behaves like `layer.updates`. + + Concrete example: + + ```python + bn = keras.layers.BatchNormalization() + x1 = keras.layers.Input(shape=(10,)) + _ = bn(x1) # This creates 2 updates. + + x2 = keras.layers.Input(shape=(10,)) + y2 = bn(x2) # This creates 2 more updates. + + # The BN layer has now 4 updates. + self.assertEqual(len(bn.updates), 4) + + # Let's create a model from x2 to y2. + model = keras.models.Model(x2, y2) + + # The model does not list all updates from its underlying layers, + # but only the updates that are relevant to it. Updates created by layers + # outside of the model are discarded. + self.assertEqual(len(model.updates), 2) + + # If you keep calling the model, you append to its updates, just like + # what happens for a layer. + x3 = keras.layers.Input(shape=(10,)) + y3 = model(x3) + self.assertEqual(len(model.updates), 4) + + # But if you call the inner BN layer independently, you don't affect + # the model's updates. + x4 = keras.layers.Input(shape=(10,)) + _ = bn(x4) + self.assertEqual(len(model.updates), 4) + ``` Returns: A list of update ops. """ + if context.in_eager_mode(): + return [] + + if not self.trainable and not self.stateful: + return [] + updates = [] for layer in self.layers: - if hasattr(layer, 'updates'): - # Collect updates that are dependent on inputs - # that are part of the model. - for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access - node_key = _make_node_key(layer.name, node_index) - if node_key in self._network_nodes: - # The model owns this layer node. - inputs = node.input_tensors - updates += layer.get_updates_for(inputs) - # Collect unconditional updates. - updates += layer.get_updates_for(None) - return updates + updates += layer.updates + + # `updates` might contain irrelevant updates, so it needs to be filtered + # with respect to inputs the model has been called on. + relevant_inputs = self.inputs or [] + for i in range(1, len(self._inbound_nodes)): + inputs = self.get_input_at(i) + if isinstance(inputs, list): + relevant_inputs += inputs + else: + relevant_inputs.append(inputs) + reachable = layers_util.get_reachable_from_inputs(relevant_inputs, updates) + relevant_conditional_updates = [x for x in updates if x in reachable] + unconditional_updates = [ + x for x in updates if x._unconditional_update] # pylint: disable=protected-access + # A layer could be used multiple times in a nested structure, + # so the updates list must be de-duped. + return list(set( + relevant_conditional_updates + unconditional_updates + self._updates)) @property def losses(self): @@ -614,22 +687,24 @@ class GraphNetwork(base.Layer): A list of loss tensors. """ losses = [] - # Retrieve losses for all internal layers. for layer in self.layers: - if hasattr(layer, 'losses'): - # Collect losses that are dependent on inputs - # that are part of the model. - for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access - node_key = _make_node_key(layer.name, node_index) - if node_key in self._network_nodes: - # The model owns this layer node. - inputs = node.input_tensors - losses += layer.get_losses_for(inputs) - # Collect unconditional losses. - losses += layer.get_losses_for(None) - # Add any potential unconditional model-level loss. - losses += self.get_losses_for(None) - return losses + losses += layer.losses + if context.in_eager_mode(): + return losses + + relevant_inputs = self.inputs or [] + for i in range(1, len(self._inbound_nodes)): + inputs = self.get_input_at(i) + if isinstance(inputs, list): + relevant_inputs += inputs + else: + relevant_inputs.append(inputs) + reachable = layers_util.get_reachable_from_inputs(relevant_inputs, losses) + relevant_conditional_losses = [x for x in losses if x in reachable] + unconditional_losses = [ + x for x in losses if x._unconditional_loss] # pylint: disable=protected-access + return list(set( + relevant_conditional_losses + unconditional_losses + self._losses)) @property def trainable_weights(self): @@ -660,6 +735,10 @@ class GraphNetwork(base.Layer): A list of `InputSpec` instances (one per input to the model) or a single instance if the model has only one input. """ + # If not a graph network, can't assume anything. + if not self._is_graph_network: + return None + specs = [] for layer in self._input_layers: if layer.input_spec is None: @@ -710,6 +789,9 @@ class GraphNetwork(base.Layer): return outputs def compute_output_shape(self, input_shape): + if not self._is_graph_network: + raise NotImplementedError + if isinstance(input_shape, list): input_shapes = [] for shape in input_shape: @@ -791,7 +873,6 @@ class GraphNetwork(base.Layer): layer, node_index, tensor_index = self._output_coordinates[i] shape_key = layer.name + '_%s_%s' % (node_index, tensor_index) output_shapes.append(layers_to_output_shapes[shape_key]) - # Store in cache. self._output_shape_cache[cache_key] = output_shapes else: @@ -846,7 +927,6 @@ class GraphNetwork(base.Layer): for node in nodes: # This is always a single layer, never a list. layer = node.outbound_layer - reference_input_tensors = node.input_tensors reference_output_tensors = node.output_tensors @@ -894,26 +974,13 @@ class GraphNetwork(base.Layer): else: output_masks = [None for _ in range(len(output_tensors))] - # Apply activity regularizer if any: - if layer.activity_regularizer is not None: - regularization_losses = [ - layer.activity_regularizer(x) for x in computed_tensors - ] - layer.add_loss(regularization_losses, computed_tensors) - - if context.in_graph_mode(): - # Update model updates and losses: - # Keep track of updates that depend on the inputs - # (e.g. BN updates). - self.add_update(layer.get_updates_for(computed_tensors), inputs) - # Keep track of unconditional updates (e.g. a counter). - self.add_update(layer.get_updates_for(None), None) - # Keep track of losses that depend on the inputs - # (e.g. activity regularizers). - self.add_loss(layer.get_losses_for(computed_tensors), inputs) - # Keep track of unconditional losses - # (e.g. weight regularizers). - self.add_loss(layer.get_losses_for(None), None) + if context.in_graph_mode(): + if layer.activity_regularizer is not None: + regularization_losses = [ + layer.activity_regularizer(x) for x in output_tensors + ] + # Apply activity regularizer if any: + layer.add_loss(regularization_losses, computed_tensors) # Update tensor_map. for x, y, mask in zip(reference_output_tensors, output_tensors, @@ -943,8 +1010,8 @@ class GraphNetwork(base.Layer): cache_key = (layers_util.object_list_uid(inputs) + '_' + layers_util.object_list_uid(masks)) self._output_tensor_cache[cache_key] = output_tensors - if output_masks is not None: - self._output_mask_cache[cache_key] = output_masks + self._output_mask_cache[cache_key] = output_masks + if output_shapes is not None: input_shapes = [layers_util.static_shape(x) for x in inputs] cache_key = layers_util.object_list_uid(input_shapes) diff --git a/tensorflow/python/layers/network_test.py b/tensorflow/python/layers/network_test.py index 7a2c7fb3fc782f6e59b8b483ec43c4abddf4d023..cc6e8ca9f41cd1f6aa0a3f64d7ce11ac24c04967 100644 --- a/tensorflow/python/layers/network_test.py +++ b/tensorflow/python/layers/network_test.py @@ -27,29 +27,137 @@ from tensorflow.python.layers import base as base_layers from tensorflow.python.layers import core as core_layers from tensorflow.python.layers import network as network_layers from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import test class BaseLayerCompatibilityTest(test.TestCase): - def test_get_updates_for(self): - a = network_layers.Input(shape=(2,)) - dense_layer = core_layers.Dense(1) - dense_layer.add_update(0, inputs=a) - dense_layer.add_update(1, inputs=None) + def test_get_updates(self): - self.assertEqual(dense_layer.get_updates_for(a), [0]) - self.assertEqual(dense_layer.get_updates_for(None), [1]) + class MyLayer(base_layers.Layer): - def test_get_losses_for(self): - a = network_layers.Input(shape=(2,)) - dense_layer = core_layers.Dense(1) - dense_layer.add_loss(0, inputs=a) - dense_layer.add_loss(1, inputs=None) + def build(self, input_shape): + self.a = self.add_variable('a', + (1, 1), + 'float32', + trainable=False) + self.b = self.add_variable('b', + (1, 1), + 'float32', + trainable=False) + self.add_update(state_ops.assign_add(self.a, [[1.]])) + self.built = True - self.assertEqual(dense_layer.get_losses_for(a), [0]) - self.assertEqual(dense_layer.get_losses_for(None), [1]) + def call(self, inputs): + self.add_update(state_ops.assign_add(self.a, inputs), + inputs=True) + return inputs + 1 + + x1 = network_layers.Input(shape=(1,)) + layer = MyLayer() + _ = layer.apply(x1) + + self.assertEqual(len(layer.updates), 2) + self.assertEqual(len(layer.get_updates_for(x1)), 1) + self.assertEqual(len(layer.get_updates_for(None)), 1) + + x2 = network_layers.Input(shape=(1,)) + y2 = layer.apply(x2) + + self.assertEqual(len(layer.updates), 3) + self.assertEqual(len(layer.get_updates_for(x1)), 1) + self.assertEqual(len(layer.get_updates_for(x2)), 1) + self.assertEqual(len(layer.get_updates_for(None)), 1) + + network = network_layers.GraphNetwork(x2, y2) + self.assertEqual(len(network.updates), 2) + self.assertEqual(len(network.get_updates_for(x1)), 0) + self.assertEqual(len(network.get_updates_for(x2)), 1) + self.assertEqual(len(network.get_updates_for(None)), 1) + + x3 = network_layers.Input(shape=(1,)) + _ = layer.apply(x3) + self.assertEqual(len(network.updates), 2) + + x4 = network_layers.Input(shape=(1,)) + _ = network(x4) + self.assertEqual(len(network.updates), 3) + self.assertEqual(len(network.get_updates_for(x2)), 1) + self.assertEqual(len(network.get_updates_for(x4)), 1) + self.assertEqual(len(network.get_updates_for(None)), 1) + + network.add_update(state_ops.assign_add(layer.a, [[1]])) + self.assertEqual(len(network.updates), 4) + self.assertEqual(len(network.get_updates_for(None)), 2) + + network.add_update(state_ops.assign_add(layer.a, x4), inputs=True) + self.assertEqual(len(network.updates), 5) + self.assertEqual(len(network.get_updates_for(x4)), 2) + + def test_get_losses(self): + + class MyLayer(base_layers.Layer): + + def build(self, input_shape): + self.a = self.add_variable('a', + (1, 1), + 'float32', + trainable=False) + self.b = self.add_variable('b', + (1, 1), + 'float32', + trainable=False) + self.add_loss(math_ops.reduce_sum(self.a)) + self.built = True + + def call(self, inputs): + self.add_loss(math_ops.reduce_sum(inputs), + inputs=True) + return inputs + 1 + + x1 = network_layers.Input(shape=(1,)) + layer = MyLayer() + _ = layer.apply(x1) + + self.assertEqual(len(layer.losses), 2) + self.assertEqual(len(layer.get_losses_for(x1)), 1) + self.assertEqual(len(layer.get_losses_for(None)), 1) + + x2 = network_layers.Input(shape=(1,)) + y2 = layer.apply(x2) + + self.assertEqual(len(layer.losses), 3) + self.assertEqual(len(layer.get_losses_for(x1)), 1) + self.assertEqual(len(layer.get_losses_for(x2)), 1) + self.assertEqual(len(layer.get_losses_for(None)), 1) + + network = network_layers.GraphNetwork(x2, y2) + self.assertEqual(len(network.losses), 2) + self.assertEqual(len(network.get_losses_for(x1)), 0) + self.assertEqual(len(network.get_losses_for(x2)), 1) + self.assertEqual(len(network.get_losses_for(None)), 1) + + x3 = network_layers.Input(shape=(1,)) + _ = layer.apply(x3) + self.assertEqual(len(network.losses), 2) + + x4 = network_layers.Input(shape=(1,)) + _ = network(x4) + self.assertEqual(len(network.losses), 3) + self.assertEqual(len(network.get_losses_for(x2)), 1) + self.assertEqual(len(network.get_losses_for(x4)), 1) + self.assertEqual(len(network.get_losses_for(None)), 1) + + network.add_loss(math_ops.reduce_sum(layer.a)) + self.assertEqual(len(network.losses), 4) + self.assertEqual(len(network.get_losses_for(None)), 2) + + network.add_loss(math_ops.reduce_sum(x4), inputs=True) + self.assertEqual(len(network.losses), 5) + self.assertEqual(len(network.get_losses_for(x4)), 2) def testTopologicalAttributes(self): # test layer attributes / methods related to cross-layer connectivity. @@ -299,9 +407,10 @@ class NetworkTest(test.TestCase): def testNetworkAttributes(self): x = network_layers.Input(shape=(32,)) - z = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2))(x) + layer = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2)) + z = layer(x) dense = core_layers.Dense(2, name='dense') - dense.add_update(1) + dense.add_update(state_ops.assign_add(layer.kernel, layer.kernel * 2.)) y = dense(z) net = network_layers.GraphNetwork(x, y) @@ -421,7 +530,6 @@ class NetworkTest(test.TestCase): self.assertEqual(len(network.layers), 2) self.assertEqual(network.layers[0].sparse, True) - @test_util.run_in_graph_and_eager_modes() def testMaskingSingleInput(self): class MaskedLayer(base_layers.Layer): diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 890c12f6e00daabe7e64c00814fcb3ff8f04ae3a..d83292b80963d942023b5d086a089af53008efe0 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -39,8 +39,10 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import moving_averages +from tensorflow.python.util.tf_export import tf_export +@tf_export('layers.BatchNormalization') class BatchNormalization(base.Layer): """Batch Normalization layer from http://arxiv.org/abs/1502.03167. @@ -92,8 +94,8 @@ class BatchNormalization(base.Layer): and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `momentum` is still applied to get the means and variances for inference. - fused: if `True`, use a faster, fused implementation if possible. - If `None`, use the system recommended implementation. + fused: if `None` or `True`, use a faster, fused implementation if possible. + If `False`, use the system recommended implementation. trainable: Boolean, if `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`, @@ -491,6 +493,7 @@ class BatchNormalization(base.Layer): return (r, d, new_mean, new_variance) def call(self, inputs, training=False): + in_eager_mode = context.in_eager_mode() if self.virtual_batch_size is not None: # Virtual batches (aka ghost batches) can be simulated by reshaping the # Tensor and reusing the existing batch norm implementation @@ -593,6 +596,9 @@ class BatchNormalization(base.Layer): axis=1, keep_dims=True) def _do_update(var, value): + if in_eager_mode and not self.trainable: + return + return moving_averages.assign_moving_average( var, value, self.momentum, zero_debias=False) @@ -629,6 +635,7 @@ class BatchNormalization(base.Layer): return input_shape +@tf_export('layers.batch_normalization') def batch_normalization(inputs, axis=-1, momentum=0.99, @@ -722,8 +729,8 @@ def batch_normalization(inputs, and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `momentum` is still applied to get the means and variances for inference. - fused: if `True`, use a faster, fused implementation if possible. - If `None`, use the system recommended implementation. + fused: if `None` or `True`, use a faster, fused implementation if possible. + If `False`, use the system recommended implementation. virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`, which means batch normalization is performed across the whole batch. When `virtual_batch_size` is not `None`, instead perform "Ghost Batch diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py index ab06a3a40826e7d41c040066fd41c56c1ed84ad2..50503ce093fbc251b11c4d5cbccb2a2683d92e7a 100644 --- a/tensorflow/python/layers/pooling.py +++ b/tensorflow/python/layers/pooling.py @@ -26,6 +26,7 @@ from tensorflow.python.layers import base from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn +from tensorflow.python.util.tf_export import tf_export class _Pooling1D(base.Layer): @@ -96,6 +97,7 @@ class _Pooling1D(base.Layer): return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]]) +@tf_export('layers.AveragePooling1D') class AveragePooling1D(_Pooling1D): """Average Pooling layer for 1D inputs. @@ -127,6 +129,7 @@ class AveragePooling1D(_Pooling1D): **kwargs) +@tf_export('layers.average_pooling1d') def average_pooling1d(inputs, pool_size, strides, padding='valid', data_format='channels_last', name=None): @@ -161,6 +164,7 @@ def average_pooling1d(inputs, pool_size, strides, return layer.apply(inputs) +@tf_export('layers.MaxPooling1D') class MaxPooling1D(_Pooling1D): """Max Pooling layer for 1D inputs. @@ -192,6 +196,7 @@ class MaxPooling1D(_Pooling1D): **kwargs) +@tf_export('layers.max_pooling1d') def max_pooling1d(inputs, pool_size, strides, padding='valid', data_format='channels_last', name=None): @@ -297,6 +302,7 @@ class _Pooling2D(base.Layer): [input_shape[0], rows, cols, input_shape[3]]) +@tf_export('layers.AveragePooling2D') class AveragePooling2D(_Pooling2D): """Average pooling layer for 2D inputs (e.g. images). @@ -328,6 +334,7 @@ class AveragePooling2D(_Pooling2D): padding=padding, data_format=data_format, name=name, **kwargs) +@tf_export('layers.average_pooling2d') def average_pooling2d(inputs, pool_size, strides, padding='valid', data_format='channels_last', @@ -365,6 +372,7 @@ def average_pooling2d(inputs, return layer.apply(inputs) +@tf_export('layers.MaxPooling2D') class MaxPooling2D(_Pooling2D): """Max pooling layer for 2D inputs (e.g. images). @@ -396,6 +404,7 @@ class MaxPooling2D(_Pooling2D): padding=padding, data_format=data_format, name=name, **kwargs) +@tf_export('layers.max_pooling2d') def max_pooling2d(inputs, pool_size, strides, padding='valid', data_format='channels_last', @@ -515,6 +524,7 @@ class _Pooling3D(base.Layer): [input_shape[0], len_dim1, len_dim2, len_dim3, input_shape[4]]) +@tf_export('layers.AveragePooling3D') class AveragePooling3D(_Pooling3D): """Average pooling layer for 3D inputs (e.g. volumes). @@ -548,6 +558,7 @@ class AveragePooling3D(_Pooling3D): padding=padding, data_format=data_format, name=name, **kwargs) +@tf_export('layers.average_pooling3d') def average_pooling3d(inputs, pool_size, strides, padding='valid', data_format='channels_last', @@ -587,6 +598,7 @@ def average_pooling3d(inputs, return layer.apply(inputs) +@tf_export('layers.MaxPooling3D') class MaxPooling3D(_Pooling3D): """Max pooling layer for 3D inputs (e.g. volumes). @@ -620,6 +632,7 @@ class MaxPooling3D(_Pooling3D): padding=padding, data_format=data_format, name=name, **kwargs) +@tf_export('layers.max_pooling3d') def max_pooling3d(inputs, pool_size, strides, padding='valid', data_format='channels_last', diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index e8be347799acf2e92e79ca76f44f25d573489940..79529e86c383ce68b05015010be01df3355df691 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -20,6 +20,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.ops import variables from tensorflow.python.ops import control_flow_ops from tensorflow.python.framework import ops @@ -81,7 +82,7 @@ def normalize_tuple(value, n, name): for single_value in value_tuple: try: int(single_value) - except ValueError: + except (ValueError, TypeError): raise ValueError('The `' + name + '` argument must be a tuple of ' + str(n) + ' integers. Received: ' + str(value) + ' ' 'including element ' + str(single_value) + ' of type' + @@ -178,67 +179,56 @@ def deconv_output_length(input_length, filter_size, padding, stride): return input_length -def smart_cond(pred, fn1, fn2, name=None): - """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`. +def smart_cond(pred, true_fn=None, false_fn=None, name=None): + """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. - If `pred` is a bool or has a constant value, we return either `fn1()` - or `fn2()`, otherwise we use `tf.cond` to dynamically route to both. + If `pred` is a bool or has a constant value, we return either `true_fn()` + or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. Arguments: - pred: A scalar determining whether to return the result of `fn1` or `fn2`. - fn1: The callable to be performed if pred is true. - fn2: The callable to be performed if pred is false. + pred: A scalar determining whether to return the result of `true_fn` or + `false_fn`. + true_fn: The callable to be performed if pred is true. + false_fn: The callable to be performed if pred is false. name: Optional name prefix when using `tf.cond`. Returns: - Tensors returned by the call to either `fn1` or `fn2`. + Tensors returned by the call to either `true_fn` or `false_fn`. Raises: - TypeError: If `fn1` or `fn2` is not callable. + TypeError: If `true_fn` or `false_fn` is not callable. """ - if not callable(fn1): - raise TypeError('`fn1` must be callable.') - if not callable(fn2): - raise TypeError('`fn2` must be callable.') - - pred_value = constant_value(pred) - if pred_value is not None: - if pred_value: - return fn1() - else: - return fn2() - else: - return control_flow_ops.cond(pred, true_fn=fn1, false_fn=fn2, name=name) + if isinstance(pred, variables.Variable): + return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn, + name=name) + return control_flow_ops.smart_cond(pred, true_fn=true_fn, + false_fn=false_fn, name=name) def constant_value(pred): """Return the bool value for `pred`, or None if `pred` had a dynamic value. - Arguments: - pred: A scalar, either a Python bool or a TensorFlow boolean variable - or tensor, or the Python integer 1 or 0. + Arguments: + pred: A scalar, either a Python bool or a TensorFlow boolean variable + or tensor, or the Python integer 1 or 0. - Returns: - True or False if `pred` has a constant boolean value, None otherwise. + Returns: + True or False if `pred` has a constant boolean value, None otherwise. - Raises: - TypeError: If `pred` is not a Variable, Tensor or bool. - """ + Raises: + TypeError: If `pred` is not a Variable, Tensor or bool, or Python + interger 1 or 0. + """ # Allow integer booleans. - if pred == 0: - pred = False - elif pred == 1: - pred = True - - if isinstance(pred, bool): - pred_value = pred - elif isinstance(pred, variables.Variable): - pred_value = None - elif isinstance(pred, ops.Tensor): - pred_value = tensor_util.constant_value(pred) - else: - raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.') - return pred_value + if isinstance(pred, int): + if pred == 1: + pred = True + elif pred == 0: + pred = False + + if isinstance(pred, variables.Variable): + return None + return control_flow_ops.smart_constant_value(pred) def object_list_uid(object_list): @@ -255,3 +245,45 @@ def static_shape(x): return tuple(x.get_shape().as_list()) except ValueError: return None + + +def get_reachable_from_inputs(inputs, targets=None): + """Returns the set of tensors reachable from `inputs`. + + Stops if all targets have been found (target is optional). + + Only valid in Symbolic mode, not Eager mode. + + Args: + inputs: List of tensors. + targets: List of tensors. + + Returns: + A set of tensors reachable from the inputs (includes the inputs themselves). + """ + reachable = set(inputs) + if targets: + targets = set(targets) + queue = inputs[:] + + while queue: + x = queue.pop() + outputs = [] + try: + consumers = x.consumers() + except AttributeError: + # Case where x is a variable type + consumers = [x.op] + for z in consumers: + consumer_outputs = z.outputs + if consumer_outputs: # May be None + outputs += consumer_outputs + + for y in outputs: + if y not in reachable: + reachable.add(y) + queue.insert(0, y) + + if targets and targets.issubset(reachable): + return reachable + return reachable diff --git a/tensorflow/python/layers/utils_test.py b/tensorflow/python/layers/utils_test.py index a560f6b6d21efc0c1070d5a9296a7a8e914e2eb9..c941aad7bc63dbb891fbe78cd2a47dd6805bf231 100644 --- a/tensorflow/python/layers/utils_test.py +++ b/tensorflow/python/layers/utils_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.layers import utils +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -87,5 +88,34 @@ class ConvUtilsTest(test.TestCase): self.assertEqual(3, utils.deconv_output_length(4, 2, 'full', 1)) self.assertEqual(6, utils.deconv_output_length(4, 2, 'full', 2)) + +class GraphUtilsTest(test.TestCase): + + def testGetReachableFromInputs(self): + + with self.test_session(): + pl_1 = array_ops.placeholder(shape=None, dtype='float32') + pl_2 = array_ops.placeholder(shape=None, dtype='float32') + pl_3 = array_ops.placeholder(shape=None, dtype='float32') + x_1 = pl_1 + pl_2 + x_2 = pl_2 * 2 + x_3 = pl_3 + 1 + x_4 = x_1 + x_2 + x_5 = x_3 * pl_1 + + self.assertEqual( + utils.get_reachable_from_inputs([pl_1]), + {pl_1, x_1, x_4, x_5}) + self.assertEqual( + utils.get_reachable_from_inputs([pl_1, pl_2]), + {pl_1, pl_2, x_1, x_2, x_4, x_5}) + self.assertEqual( + utils.get_reachable_from_inputs([pl_3]), + {pl_3, x_3, x_5}) + self.assertEqual( + utils.get_reachable_from_inputs([x_3]), + {x_3, x_5}) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index d3bfa0ee337d1f606e5e994406969685a2986ab4..e0422ef80add42307268be2743e668eb8c8acb68 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -19,6 +19,7 @@ limitations under the License. #include "numpy/arrayobject.h" #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/op_kernel.h" @@ -53,6 +54,12 @@ struct PyCall { // with this "token". string token; + // The device on which Tensors are stored; only used for EagerPyFunc. + Device* device; + + // True if and only if the op has been placed on a GPU. + bool gpu; + // True if the call is associated with an EagerPyFunc. bool eager; @@ -71,7 +78,12 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) { PyObject* arg = nullptr; const Tensor& t = call->ins[i]; if (call->eager) { - arg = EagerTensorFromHandle(TFE_NewTensorHandle(t)); + if (call->gpu) { + arg = EagerTensorFromHandle(new TFE_TensorHandle(t, call->device)); + } else { + // TFE_TensorHandle assumes that CPU is identified by `nullptr`. + arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr)); + } if (arg == nullptr) { return errors::Internal("Unable to procure EagerTensor from Tensor."); } @@ -84,7 +96,8 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) { } PyList_SetItem(lst, i, arg); } - *tuple = Py_BuildValue("(sN)", call->token.c_str(), lst); + *tuple = Py_BuildValue("(sON)", call->token.c_str(), + call->gpu ? Py_True : Py_False, lst); CHECK(*tuple); return Status::OK(); } @@ -150,15 +163,9 @@ bool IsSingleNone(PyObject* obj) { } // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`. -Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, - Tensor* output_tensor, - TF_Status* tf_status) { - // TODO(akshayka): Lift the restriction requiring output tensors to - // lie in host memory; EagerPyFunc should be able to dispatch ops on GPU - // tensors, so we should eventually implement a GPU kernel for EagerPyFunc. - *output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory( - EagerTensor_Handle(eager_tensor), tf_status); - return StatusFromTF_Status(tf_status); +void ExtractTensorFromEagerTensor(const PyObject* eager_tensor, + Tensor* output_tensor) { + *output_tensor = EagerTensor_Handle(eager_tensor)->t; } // Calls the registered py function through the trampoline. @@ -201,15 +208,23 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { } // Process the return values and convert them to TF Tensors. - Status s; + Status s = Status::OK(); if (PyList_Check(result)) { + // `result` is a Python list; if this operation is an `EagerPyFunc`, then + // every item in the list must be an `EagerTensor`; otherwise, every element + // must be a NumPy array. call->out.clear(); for (int i = 0; i < PyList_Size(result); ++i) { Tensor t; if (call->eager) { - auto tf_status = tensorflow::make_safe(TF_NewStatus()); - s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t, - tf_status.get()); + const PyObject* item = PyList_GetItem(result, i); + if (EagerTensor_CheckExact(item)) { + ExtractTensorFromEagerTensor(item, &t); + } else { + s = errors::FailedPrecondition( + "Expected EagerTensor, found PyObject of type: ", + Py_TYPE(item)->tp_name); + } } else { s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t); } @@ -220,16 +235,15 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { call->out.push_back(t); } } else if (EagerTensor_CheckExact(result) || result == Py_None) { + // result is an `EagerTensor` or `None`. DCHECK(call->eager); Tensor t; if (result != Py_None) { - auto tf_status = tensorflow::make_safe(TF_NewStatus()); - s = ExtractTensorFromEagerTensor(result, &t, tf_status.get()); - if (s.ok()) { - call->out.push_back(t); - } + ExtractTensorFromEagerTensor(result, &t); + call->out.push_back(t); } } else if (PyArray_Check(result)) { + // `result` is a NumPy array. DCHECK(!call->eager); if (!IsSingleNone(result)) { Tensor t; @@ -239,7 +253,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { } } } else { - s = errors::Internal("Unexpected pyobject is returned: ", + s = errors::Internal("Unexpected PyObject was returned: ", Py_TYPE(result)->tp_name); } Py_DECREF(result); @@ -429,12 +443,24 @@ class PyFuncOp : public OpKernel { explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_)); eager_ = type_string() == "EagerPyFunc"; + gpu_ = ctx->device_type().type_string() == DEVICE_GPU; } void Compute(OpKernelContext* ctx) override { PyCall call; call.token = token_; + call.gpu = gpu_; call.eager = eager_; + if (call.eager) { + // Eager's C API uses `Device`, whereas `OpKernelContext` stores a + // `DeviceBase`; attempt to downcast. + call.device = dynamic_cast(ctx->device()); + if (call.device == nullptr) { + ctx->CtxFailureWithWarning( + errors::Internal("Unrecognized device class")); + } + } + for (int i = 0; i < ctx->num_inputs(); ++i) { call.ins.push_back(ctx->input(i)); } @@ -476,6 +502,9 @@ class PyFuncOp : public OpKernel { private: string token_; + // True if and only if this op has been placed on a GPU. + bool gpu_; + // True if and only if this op should execute the python function eagerly, // i.e., if and only if the eager attribute is set. bool eager_; @@ -486,5 +515,6 @@ class PyFuncOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp); REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp); REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp); +REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_GPU), PyFuncOp); } // end namespace tensorflow diff --git a/tensorflow/python/lib/io/file_io.i b/tensorflow/python/lib/io/file_io.i index c0c4e035fc3d6a50334acb9228c13c702ef426c0..891a7b0fd0dc177f5ee439707c9e2c99148e177c 100644 --- a/tensorflow/python/lib/io/file_io.i +++ b/tensorflow/python/lib/io/file_io.i @@ -110,21 +110,15 @@ void RecursivelyCreateDir(const string& dirname, TF_Status* out_status) { } } -void CopyFile(const string& oldpath, const string& newpath, bool overwrite, +void CopyFile(const string& src, const string& target, bool overwrite, TF_Status* out_status) { - // If overwrite is false and the newpath file exists then it's an error. - if (!overwrite && tensorflow::Env::Default()->FileExists(newpath).ok()) { + // If overwrite is false and the target file exists then its an error. + if (!overwrite && tensorflow::Env::Default()->FileExists(target).ok()) { TF_SetStatus(out_status, TF_ALREADY_EXISTS, "file already exists"); return; } - string file_content; - tensorflow::Status status = ReadFileToString(tensorflow::Env::Default(), - oldpath, &file_content); - if (!status.ok()) { - Set_TF_Status_from_Status(out_status, status); - return; - } - status = WriteStringToFile(tensorflow::Env::Default(), newpath, file_content); + tensorflow::Status status = + tensorflow::Env::Default()->CopyFile(src, target); if (!status.ok()) { Set_TF_Status_from_Status(out_status, status); } diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py index 4e3071d8513a28b02b70b290c4987bec92b3c32e..59f5075f177ef5335115cb4f24182d28a9b547c8 100644 --- a/tensorflow/python/lib/io/file_io.py +++ b/tensorflow/python/lib/io/file_io.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import c_api_util from tensorflow.python.framework import errors from tensorflow.python.util import compat from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export class FileIO(object): @@ -235,6 +236,7 @@ class FileIO(object): self._writable_file = None +@tf_export("gfile.Exists") def file_exists(filename): """Determines whether a path exists or not. @@ -256,6 +258,7 @@ def file_exists(filename): return True +@tf_export("gfile.Remove") def delete_file(filename): """Deletes the file located at 'filename'. @@ -306,6 +309,7 @@ def write_string_to_file(filename, file_content): f.write(file_content) +@tf_export("gfile.Glob") def get_matching_files(filename): """Returns a list of files that match the given pattern(s). @@ -336,6 +340,7 @@ def get_matching_files(filename): ] +@tf_export("gfile.MkDir") def create_dir(dirname): """Creates a directory with the name 'dirname'. @@ -353,6 +358,7 @@ def create_dir(dirname): pywrap_tensorflow.CreateDir(compat.as_bytes(dirname), status) +@tf_export("gfile.MakeDirs") def recursive_create_dir(dirname): """Creates a directory and all parent/intermediate directories. @@ -368,6 +374,7 @@ def recursive_create_dir(dirname): pywrap_tensorflow.RecursivelyCreateDir(compat.as_bytes(dirname), status) +@tf_export("gfile.Copy") def copy(oldpath, newpath, overwrite=False): """Copies data from oldpath to newpath. @@ -385,6 +392,7 @@ def copy(oldpath, newpath, overwrite=False): compat.as_bytes(oldpath), compat.as_bytes(newpath), overwrite, status) +@tf_export("gfile.Rename") def rename(oldname, newname, overwrite=False): """Rename or move a file / directory. @@ -426,6 +434,7 @@ def atomic_write_string_to_file(filename, contents, overwrite=True): raise +@tf_export("gfile.DeleteRecursively") def delete_recursively(dirname): """Deletes everything under dirname recursively. @@ -439,6 +448,7 @@ def delete_recursively(dirname): pywrap_tensorflow.DeleteRecursively(compat.as_bytes(dirname), status) +@tf_export("gfile.IsDirectory") def is_directory(dirname): """Returns whether the path is a directory or not. @@ -452,6 +462,7 @@ def is_directory(dirname): return pywrap_tensorflow.IsDirectory(compat.as_bytes(dirname), status) +@tf_export("gfile.ListDirectory") def list_directory(dirname): """Returns a list of entries contained within a directory. @@ -479,6 +490,7 @@ def list_directory(dirname): ] +@tf_export("gfile.Walk") def walk(top, in_order=True): """Recursive directory tree generator for directories. @@ -522,6 +534,7 @@ def walk(top, in_order=True): yield here +@tf_export("gfile.Stat") def stat(filename): """Returns file statistics for a given path. diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index df190100689bd864de78f5a2cf52b1ade081a789..48ea107a146c2714f7b59f53abbcd8b60dbf2fd4 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -22,8 +22,10 @@ from __future__ import print_function from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import errors from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export +@tf_export("python_io.TFRecordCompressionType") class TFRecordCompressionType(object): """The type of compression for the record.""" NONE = 0 @@ -33,6 +35,7 @@ class TFRecordCompressionType(object): # NOTE(vrv): This will eventually be converted into a proto. to match # the interface used by the C++ RecordWriter. +@tf_export("python_io.TFRecordOptions") class TFRecordOptions(object): """Options used for manipulating TFRecord files.""" compression_type_map = { @@ -51,6 +54,7 @@ class TFRecordOptions(object): return cls.compression_type_map[options.compression_type] +@tf_export("python_io.tf_record_iterator") def tf_record_iterator(path, options=None): """An iterator that read the records from a TFRecords file. @@ -81,6 +85,7 @@ def tf_record_iterator(path, options=None): reader.Close() +@tf_export("python_io.TFRecordWriter") class TFRecordWriter(object): """A class to write records to a TFRecords file. diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 55cae0bcbfca8a9cacfe525fe3b69c7fb232acd3..9745d38dc23dba806a2d0dd2ef588a5a950aa05c 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Gradients for operators defined in array_ops.py.""" from __future__ import absolute_import @@ -28,6 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops @@ -116,6 +116,19 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): non_neg_concat_dim) out_grads = array_ops.split(grad, sizes, non_neg_concat_dim) else: + if constant_op.is_constant(concat_dim): + # If concat_dim is a constant defined in a different context, + # then we duplicate it in the current context to avoid passing it + # through an Enter node. + # This is a small optimization in general, but it is required when + # compiling with XLA, as XLA needs the concat input to be folded into a + # constant. + grad_context = control_flow_util.GetOutputContext(grad.op) + dim_context = control_flow_util.GetOutputContext(concat_dim.op) + if dim_context != grad_context: + value = tensor_util.constant_value(concat_dim) + concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype) + # Using mod here for convenience since concat_dim is already verified # in concat implementation to be within the allowed [-rank, rank) range. non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]) @@ -131,8 +144,8 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): # extract the size of each input along the concat dimension sizes = array_ops.squeeze( array_ops.slice( - array_ops.stack( - sizes, axis=1), [non_neg_concat_dim, 0], [1, -1])) + array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0], + [1, -1])) out_grads = array_ops.split(grad, sizes, non_neg_concat_dim) else: offset = gen_array_ops._concat_offset(non_neg_concat_dim, sizes) @@ -167,8 +180,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): new_values = array_ops.slice( grad.values, begin, array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0)) - out_grads.append( - ops.IndexedSlices(new_values, grad.indices, size)) + out_grads.append(ops.IndexedSlices(new_values, grad.indices, size)) # Lint complains begin = begin + ... begin = math_ops.add(begin, size * mask) else: @@ -178,30 +190,33 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): for size in sizes: size_concat_dim = array_ops.gather(size, non_neg_concat_dim) if size_concat_dim.dtype != grad.indices.dtype: - size_concat_dim = math_ops.cast(size_concat_dim, - dtype=grad.indices.dtype) + size_concat_dim = math_ops.cast( + size_concat_dim, dtype=grad.indices.dtype) end = start + size_concat_dim # Compute the 1-D Tensor of indices relevant for this input. indices_to_select = array_ops.squeeze( - array_ops.where(math_ops.logical_and(grad.indices >= start, - grad.indices < end)), + array_ops.where( + math_ops.logical_and(grad.indices >= start, + grad.indices < end)), squeeze_dims=[1]) new_indices = array_ops.gather(grad.indices, indices_to_select) - start new_values = array_ops.gather(grad.values, indices_to_select) - out_grads.append( - ops.IndexedSlices(new_values, new_indices, size)) + out_grads.append(ops.IndexedSlices(new_values, new_indices, size)) start = end else: raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad)) - return (out_grads + [None] if end_value_index <= dim_index - else [None] + out_grads) + return (out_grads + [None] + if end_value_index <= dim_index else [None] + out_grads) @ops.RegisterGradient("Concat") def _ConcatGrad(op, grad): return _ConcatGradHelper( - op, grad, start_value_index=1, end_value_index=len(op.inputs), + op, + grad, + start_value_index=1, + end_value_index=len(op.inputs), dim_index=0) @@ -287,9 +302,13 @@ def _SplitGrad(op, *grads): @ops.RegisterGradient("SplitV") def _SplitVGrad(op, *grads): returnval = array_ops.concat(list(grads), op.inputs[2]) - returnval = [returnval] + [None,] * (len(op.inputs) - 1) + returnval = [returnval] + [ + None, + ] * ( + len(op.inputs) - 1) return returnval + ops.NotDifferentiable("Const") @@ -334,9 +353,9 @@ def _MatrixSetDiagGrad(op, grad): matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2]) min_dim = math_ops.reduce_min(matrix_shape) diag_shape = array_ops.concat([batch_shape, [min_dim]], 0) - grad_input = array_ops.matrix_set_diag( - grad, array_ops.zeros( - diag_shape, dtype=grad.dtype)) + grad_input = array_ops.matrix_set_diag(grad, + array_ops.zeros( + diag_shape, dtype=grad.dtype)) grad_diag = array_ops.matrix_diag_part(grad) return (grad_input, grad_diag) @@ -444,8 +463,8 @@ def _GatherV2Grad(op, grad): values_transpose = array_ops.transpose(values, transpose_dims) num_segments = params_shape[axis] - params_grad = math_ops.unsorted_segment_sum( - values_transpose, indices, num_segments) + params_grad = math_ops.unsorted_segment_sum(values_transpose, indices, + num_segments) # Inverts the above transpose by moving dimension 0 back to its original # position. @@ -536,13 +555,10 @@ def _ConjugateTransposeGrad(op, grad): ops.NotDifferentiable("Shape") - ops.NotDifferentiable("ShapeN") - ops.NotDifferentiable("Rank") - ops.NotDifferentiable("Size") @@ -590,6 +606,7 @@ def _PadGrad(op, grad): else: return x_grad, None + ops.RegisterGradient("Pad")(_PadGrad) ops.RegisterGradient("PadV2")(_PadGrad) @@ -625,30 +642,34 @@ def _ReverseV2Grad(op, grad): def _SpaceToBatchGrad(op, grad): # Its gradient is the opposite op: BatchToSpace. block_size = op.get_attr("block_size") - return [array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), - None] + return [ + array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None + ] @ops.RegisterGradient("SpaceToBatchND") def _SpaceToBatchNDGrad(op, grad): # Its gradient is the opposite op: BatchToSpaceND. - return [array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), - None, None] + return [ + array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None + ] @ops.RegisterGradient("BatchToSpace") def _BatchToSpaceGrad(op, grad): # Its gradient is the opposite op: SpaceToBatch. block_size = op.get_attr("block_size") - return [array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), - None] + return [ + array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None + ] @ops.RegisterGradient("BatchToSpaceND") def _BatchToSpaceNDGrad(op, grad): # Its gradient is the opposite op: SpaceToBatchND. - return [array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), - None, None] + return [ + array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None + ] @ops.RegisterGradient("SpaceToDepth") @@ -712,30 +733,28 @@ def _QuantizeAndDequantizeV3Grad(_, grad): def _ExtractImagePatchesGrad(op, grad): batch_size, rows_in, cols_in, channels = [ - dim.value for dim in op.inputs[0].get_shape() + dim.value for dim in op.inputs[0].get_shape() ] input_bhwc = array_ops.shape(op.inputs[0]) batch_size = input_bhwc[0] channels = input_bhwc[3] - _, rows_out, cols_out, _ = [ - dim.value for dim in op.outputs[0].get_shape() - ] - _, ksize_r, ksize_c, _ = op.get_attr('ksizes') - _, stride_r, stride_h, _ = op.get_attr('strides') - _, rate_r, rate_c, _ = op.get_attr('rates') - padding = op.get_attr('padding') + _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].get_shape()] + _, ksize_r, ksize_c, _ = op.get_attr("ksizes") + _, stride_r, stride_h, _ = op.get_attr("strides") + _, rate_r, rate_c, _ = op.get_attr("rates") + padding = op.get_attr("padding") ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1) ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1) - if padding == b'SAME': + if padding == b"SAME": rows_out = int(ceil(rows_in / stride_r)) cols_out = int(ceil(cols_in / stride_h)) pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2 pad_cols = ((cols_out - 1) * stride_h + ksize_c_eff - cols_in) // 2 - elif padding == b'VALID': + elif padding == b"VALID": rows_out = int(ceil((rows_in - ksize_r_eff + 1) / stride_r)) cols_out = int(ceil((cols_in - ksize_c_eff + 1) / stride_h)) pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in @@ -744,10 +763,9 @@ def _ExtractImagePatchesGrad(op, grad): pad_rows, pad_cols = max(0, pad_rows), max(0, pad_cols) grad_expanded = array_ops.transpose( - array_ops.reshape(grad, (batch_size, rows_out, - cols_out, ksize_r, ksize_c, channels)), - (1, 2, 3, 4, 0, 5) - ) + array_ops.reshape( + grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)), + (1, 2, 3, 4, 0, 5)) grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels)) row_steps = range(0, rows_out * stride_r, stride_r) @@ -759,29 +777,21 @@ def _ExtractImagePatchesGrad(op, grad): r_low, c_low = row_steps[i] - pad_rows, col_steps[j] - pad_cols r_high, c_high = r_low + ksize_r_eff, c_low + ksize_c_eff - idx.extend([(r * (cols_in) + c, - i * (cols_out * ksize_r * ksize_c) + - j * (ksize_r * ksize_c) + - ri * (ksize_c) + ci) + idx.extend([(r * (cols_in) + c, i * (cols_out * ksize_r * ksize_c) + j * + (ksize_r * ksize_c) + ri * (ksize_c) + ci) for (ri, r) in enumerate(range(r_low, r_high, rate_r)) for (ci, c) in enumerate(range(c_low, c_high, rate_c)) - if 0 <= r and r < rows_in and 0 <= c and c < cols_in - ]) + if 0 <= r and r < rows_in and 0 <= c and c < cols_in]) - sp_shape = (rows_in * cols_in, - rows_out * cols_out * ksize_r * ksize_c) + sp_shape = (rows_in * cols_in, rows_out * cols_out * ksize_r * ksize_c) sp_mat = sparse_tensor.SparseTensor( - array_ops.constant(idx, dtype=ops.dtypes.int64), - array_ops.ones((len(idx),), dtype=ops.dtypes.float32), - sp_shape - ) + array_ops.constant(idx, dtype=ops.dtypes.int64), + array_ops.ones((len(idx),), dtype=ops.dtypes.float32), sp_shape) jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat) - grad_out = array_ops.reshape( - jac, (rows_in, cols_in, batch_size, channels) - ) + grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels)) grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3)) return [grad_out] diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 24a0c186198c7389af9add64ec6466b1f3d2afbd..d63a9ea0ddef619c8a9bd77ac3d66e03b5ad7ec3 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# Tests for this file live in python/kernel_tests/array_ops_test.py """Support for manipulating tensors. See the @{$python/array_ops} guide. @@ -34,6 +35,7 @@ See the @{$python/array_ops} guide. @@reshape @@squeeze @@expand_dims +@@unravel_index @@meshgrid @@slice @@strided_slice @@ -603,7 +605,7 @@ def slice(input_, begin, size, name=None): Note that @{tf.Tensor.__getitem__} is typically a more pythonic way to perform slices, as it allows you to write `foo[3:7, :-2]` instead of - `tf.slice([3, 0], [4, foo.get_shape()[1]-2])`. + `tf.slice(foo, [3, 0], [4, foo.get_shape()[1]-2])`. `begin` is zero-based; `size` is one-based. If `size[i]` is -1, all remaining elements in dimension i are included in the @@ -2450,8 +2452,8 @@ def _all_dimensions(x): r = x.dense_shape.get_shape()[0].value # sparse.dense_shape is 1-D. return constant_op.constant(np.arange(r), dtype=dtypes.int32) - # Otherwise, we rely on Range and Rank to do the right thing at run-time. - return range(0, rank(x)) + # Otherwise, we rely on `range` and `rank` to do the right thing at runtime. + return gen_math_ops._range(0, rank(x), 1) @tf_export("sequence_mask") @@ -2496,7 +2498,7 @@ def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None): maxlen = gen_math_ops._max(lengths, _all_dimensions(lengths)) else: maxlen = ops.convert_to_tensor(maxlen) - if maxlen.get_shape().ndims != 0: + if maxlen.get_shape().ndims is not None and maxlen.get_shape().ndims != 0: raise ValueError("maxlen must be scalar for sequence_mask") # The basic idea is to compare a range row vector of size maxlen: diff --git a/tensorflow/python/ops/bitwise_ops_test.py b/tensorflow/python/ops/bitwise_ops_test.py index f9b025b787e4f49e1dcde6c589f66c59d779fcef..c4cfc0da197edcfd143cfee79fd3c3f9b7a2858b 100644 --- a/tensorflow/python/ops/bitwise_ops_test.py +++ b/tensorflow/python/ops/bitwise_ops_test.py @@ -71,7 +71,7 @@ class BitwiseOpTest(test_util.TensorFlowTestCase): def testInvertOp(self): dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, - dtypes.uint8, dtypes.uint16] + dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64] inputs = [0, 5, 3, 14] with self.test_session(use_gpu=True) as sess: for dtype in dtype_list: diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py index 20445c78a290a4fe67cad668dd714dd2c61c5f3d..220ef1754d2e1a2d54a8962148b47806df48e98f 100644 --- a/tensorflow/python/ops/candidate_sampling_ops.py +++ b/tensorflow/python/ops/candidate_sampling_ops.py @@ -20,9 +20,9 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops +from tensorflow.python.ops import array_ops # pylint: disable=unused-import from tensorflow.python.ops import gen_candidate_sampling_ops -from tensorflow.python.ops import math_ops +from tensorflow.python.ops import math_ops # pylint: disable=unused-import from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 0fd6e29a49c8e4e31e244bfbbfca525d72e4d811..64567ac54ae43acf6f8b674c46525db7a6c4fab7 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -334,9 +334,9 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): @compatibility{eager} returns None Raises: - InvalidArgumentError if the check can be performed immediately and - `x == y` is False. The check can be performed immediately during - eager execution or if `x` and `y` are statically known. + InvalidArgumentError: if the check can be performed immediately and + `x == y` is False. The check can be performed immediately during eager + execution or if `x` and `y` are statically known. """ message = message or '' with ops.name_scope(name, 'assert_equal', [x, y, data]): diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index dd8c33247c2436413ee8c9a3ceeca4d8a493bb4e..49f8c665313562cb20dbe4494103ded16646c741 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -110,7 +110,7 @@ def clip_by_norm(t, clip_norm, axes=None, name=None): t = ops.convert_to_tensor(t, name="t") # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm - l2norm = math_ops.sqrt(math_ops.reduce_sum(t * t, axes, keep_dims=True)) + l2norm = math_ops.sqrt(math_ops.reduce_sum(t * t, axes, keepdims=True)) intermediate = t * clip_norm # Assert that the shape is compatible with the initial shape, # to prevent unintentional broadcasting. diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py index 50690cd891f73df1e345817b834ce6c361bff9e8..b9a93c3bedfff1f398e3b42cedf02a2f0a3ddd5c 100644 --- a/tensorflow/python/ops/confusion_matrix.py +++ b/tensorflow/python/ops/confusion_matrix.py @@ -99,19 +99,16 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, name=None, weights=None): """Computes the confusion matrix from predictions and labels. - Calculate the Confusion Matrix for a pair of prediction and - label 1-D int arrays. - The matrix columns represent the prediction labels and the rows represent the real labels. The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid labels for a given classification task. Both prediction and labels must be 1-D arrays of the same shape in order for this function to work. - If `num_classes` is None, then `num_classes` will be set to the one plus - the maximum value in either predictions or labels. - Class labels are expected to start at 0. E.g., if `num_classes` was - three, then the possible labels would be `[0, 1, 2]`. + If `num_classes` is `None`, then `num_classes` will be set to one plus the + maximum value in either predictions or labels. Class labels are expected to + start at 0. For example, if `num_classes` is 3, then the possible labels + would be `[0, 1, 2]`. If `weights` is not `None`, then each prediction contributes its corresponding weight to the total value of the confusion matrix cell. @@ -119,7 +116,7 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, For example: ```python - tf.contrib.metrics.confusion_matrix([1, 2, 4], [2, 2, 4]) ==> + tf.confusion_matrix([1, 2, 4], [2, 2, 4]) ==> [[0 0 0 0 0] [0 0 1 0 0] [0 0 1 0 0] @@ -141,8 +138,9 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, weights: An optional `Tensor` whose shape matches `predictions`. Returns: - A k X k matrix representing the confusion matrix, where k is the number of - possible labels in the classification task. + A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion + matrix, where `n` is the number of possible labels in the classification + task. Raises: ValueError: If both predictions and labels are not 1-D vectors and have @@ -188,7 +186,7 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, weights = math_ops.cast(weights, dtype) shape = array_ops.stack([num_classes, num_classes]) - indices = array_ops.transpose(array_ops.stack([labels, predictions])) + indices = array_ops.stack([labels, predictions], axis=1) values = (array_ops.ones_like(predictions, dtype) if weights is None else weights) cm_sparse = sparse_tensor.SparseTensor( diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index d379eccc20dcd63255ee8c2dbe3fbd3e6a9077af..4f4a2e35db8de61b5f587f92a833b3094a530ba7 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Control Flow Operations. See the @{$python/control_flow_ops} guide. @@ -24,6 +23,7 @@ See the @{$python/control_flow_ops} guide. @@no_op @@count_up_to @@cond +@@smart_cond @@case @@while_loop @@logical_and @@ -51,12 +51,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import functools import six -from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import control_flow_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -79,12 +80,12 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.gen_control_flow_ops import * # pylint: enable=wildcard-import from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use from tensorflow.python.util.tf_export import tf_export - # We override the 'tuple' for a control flow op, so we keep python's # existing 'tuple' for later use in this module. _basetuple = tuple @@ -156,9 +157,10 @@ def Assert(condition, data, summarize=None, name=None): xs = ops.convert_n_to_tensor(data) data_str = [_summarize_eager(x, summarize) for x in xs] raise errors.InvalidArgumentError( - node_def=None, op=None, - message="Expected '%s' to be true. Summarized data: %s" % ( - condition, "\n".join(data_str))) + node_def=None, + op=None, + message="Expected '%s' to be true. Summarized data: %s" % + (condition, "\n".join(data_str))) return with ops.name_scope(name, "Assert", [condition, data]) as name: @@ -167,17 +169,15 @@ def Assert(condition, data, summarize=None, name=None): # As a simple heuristic, we assume that string and int32 are # on host to avoid the need to use cond. If it is not case, # we will pay the price copying the tensor to host memory. - return gen_logging_ops._assert( - condition, data, summarize, name="Assert") + return gen_logging_ops._assert(condition, data, summarize, name="Assert") else: condition = ops.convert_to_tensor(condition, name="Condition") + def true_assert(): return gen_logging_ops._assert( condition, data, summarize, name="Assert") - guarded_assert = cond( - condition, no_op, true_assert, name="AssertGuard") - if context.in_eager_mode(): - return + + guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard") return guarded_assert.op @@ -215,7 +215,7 @@ def _Identity(data, name=None): def _NextIteration(data, name=None): data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True) if isinstance(data, ops.Tensor): - if data.dtype._is_ref_dtype: # pylint: disable=protected-access + if data.dtype._is_ref_dtype: # pylint: disable=protected-access return ref_next_iteration(data, name=name) else: return next_iteration(data, name=name) @@ -234,8 +234,13 @@ def _NextIteration(data, name=None): return sparse_tensor.SparseTensor(indices, values, dense_shape) -def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, - use_ref=True, use_input_shape=True, name=None): +def _Enter(data, + frame_name, + is_constant=False, + parallel_iterations=10, + use_ref=True, + use_input_shape=True, + name=None): """Creates or finds a child frame, and makes `data` available to it. The unique `frame_name` is used by the `Executor` to identify frames. If @@ -257,41 +262,57 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True) if isinstance(data, ops.Tensor): if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access - result = ref_enter(data, frame_name, is_constant, parallel_iterations, - name=name) + result = gen_control_flow_ops._ref_enter( + data, frame_name, is_constant, parallel_iterations, name=name) else: - result = enter(data, frame_name, is_constant, parallel_iterations, - name=name) + result = gen_control_flow_ops._enter( + data, frame_name, is_constant, parallel_iterations, name=name) if use_input_shape: result.set_shape(data.get_shape()) return result else: if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)): raise TypeError("Type %s not supported" % type(data)) - values = _Enter(data.values, frame_name, is_constant, - parallel_iterations=parallel_iterations, - use_input_shape=use_input_shape, name=name) - indices = enter(data.indices, frame_name, is_constant, - parallel_iterations, name="indices") + values = _Enter( + data.values, + frame_name, + is_constant, + parallel_iterations=parallel_iterations, + use_input_shape=use_input_shape, + name=name) + indices = gen_control_flow_ops._enter( + data.indices, + frame_name, + is_constant, + parallel_iterations, + name="indices") if use_input_shape: indices.set_shape(data.indices.get_shape()) if isinstance(data, ops.IndexedSlices): dense_shape = data.dense_shape if dense_shape is not None: - dense_shape = enter(dense_shape, frame_name, is_constant, - parallel_iterations, name="dense_shape") + dense_shape = gen_control_flow_ops._enter( + dense_shape, + frame_name, + is_constant, + parallel_iterations, + name="dense_shape") if use_input_shape: dense_shape.set_shape(data.dense_shape.get_shape()) return ops.IndexedSlices(values, indices, dense_shape) else: - dense_shape = enter(data.dense_shape, frame_name, is_constant, - parallel_iterations, name="dense_shape") + dense_shape = gen_control_flow_ops._enter( + data.dense_shape, + frame_name, + is_constant, + parallel_iterations, + name="dense_shape") if use_input_shape: dense_shape.set_shape(data.dense_shape.get_shape()) return sparse_tensor.SparseTensor(indices, values, dense_shape) -def exit(data, name=None): +def exit(data, name=None): # pylint: disable=redefined-builtin """Exits the current frame to its parent frame. Exit makes its input `data` available to the parent frame. @@ -444,8 +465,10 @@ def merge(inputs, name=None): if any([inp is None for inp in inputs]): raise ValueError("At least one of the merge inputs is None: %s" % inputs) with ops.name_scope(name, "Merge", inputs) as name: - inputs = [ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True) - for inp in inputs] + inputs = [ + ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True) + for inp in inputs + ] if all([isinstance(v, ops.Tensor) for v in inputs]): if all([v.dtype._is_ref_dtype for v in inputs]): # pylint: disable=protected-access return gen_control_flow_ops._ref_merge(inputs, name) @@ -475,6 +498,8 @@ def merge(inputs, name=None): else: dense_shape = None return ops.IndexedSlices(values, indices, dense_shape), chosen_index + + # pylint: enable=protected-access @@ -488,7 +513,9 @@ def _convert_tensorarray_to_flow(tensor_or_tensor_array): def _make_tensor_array(ta, t_or_flow): # pylint: disable=protected-access new_ta = tensor_array_ops.TensorArray( - dtype=ta.dtype, handle=ta.handle, flow=t_or_flow, + dtype=ta.dtype, + handle=ta.handle, + flow=t_or_flow, infer_shape=ta._infer_shape, colocate_with_first_write_call=ta._colocate_with_first_write_call) new_ta._colocate_with = ta._colocate_with @@ -500,13 +527,13 @@ def _make_tensor_array(ta, t_or_flow): def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows): if len(tensors_or_tensorarrays) != len(tensors_or_flows): raise ValueError( - "Lengths of original Tensor list and new list do not match: %d vs. %d" - % (len(tensors_or_tensorarrays), len(tensors_or_flows))) + "Lengths of original Tensor list and new list do not match: %d vs. %d" % + (len(tensors_or_tensorarrays), len(tensors_or_flows))) return [ _make_tensor_array(ta, t_or_flow) - if isinstance(ta, tensor_array_ops.TensorArray) - else t_or_flow - for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)] + if isinstance(ta, tensor_array_ops.TensorArray) else t_or_flow + for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows) + ] def _ShapeLessThanOrEqual(shape1, shape2): @@ -545,8 +572,8 @@ def _SetShapeInvariants(input_vars, enter_vars, shapes): raise ValueError( "The shape invariant specified for %s is not compatible with " "the initial shape of the loop variable. It enters the loop " - "with shape %s, but the specified shape invariant is %s." - % (inp.name, inp.get_shape(), shape)) + "with shape %s, but the specified shape invariant is %s." % + (inp.name, inp.get_shape(), shape)) var.set_shape(shape) else: if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)): @@ -557,8 +584,8 @@ def _SetShapeInvariants(input_vars, enter_vars, shapes): "The shape invariant specified for %s is not compatible with " "the initial shape of the values tensor of this IndexedSlices. " "It enters the loop with shape %s, but the specified shape " - "invariant is %s." - % (inp.values.name, inp.values.get_shape(), shape)) + "invariant is %s." % (inp.values.name, inp.values.get_shape(), + shape)) var.values.set_shape(shape) var.indices.set_shape(tensor_shape.TensorShape([shape[0]])) if var.dense_shape is not None: @@ -569,8 +596,8 @@ def _SetShapeInvariants(input_vars, enter_vars, shapes): "The shape invariant specified for %s is not compatible with " "the initial shape of the shape tensor of this SparseTensor. " "It enters the loop with shape %s, but the specified shape " - "invariant is %s." - % (inp.dense_shape.name, inp.dense_shape.get_shape(), shape)) + "invariant is %s." % (inp.dense_shape.name, + inp.dense_shape.get_shape(), shape)) var.values.set_shape(tensor_shape.TensorShape([None])) var.indices.set_shape(tensor_shape.TensorShape([None, shape.ndims])) var.dense_shape.set_shape(shape) @@ -599,8 +626,8 @@ def _EnforceShapeInvariant(merge_var, next_var): "The shape for %s is not an invariant for the loop. It enters " "the loop with shape %s, but has shape %s after one iteration. " "Provide shape invariants using either the `shape_invariants` " - "argument of tf.while_loop or set_shape() on the loop variables." - % (merge_var.name, m_shape, n_shape)) + "argument of tf.while_loop or set_shape() on the loop variables." % + (merge_var.name, m_shape, n_shape)) else: if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)): raise TypeError("Type %s not supported" % type(var)) @@ -623,9 +650,9 @@ def _EnforceShapeInvariant(merge_var, next_var): "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) " "after one iteration. Provide shape invariants using either the " "`shape_invariants` argument of tf.while_loop or set_shape() " - "on the loop variables." - % (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape, - n_values_shape, n_indices_shape, n_shape_shape)) + "on the loop variables." % + (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape, + n_values_shape, n_indices_shape, n_shape_shape)) else: m_values_shape = merge_var.values.get_shape() m_indices_shape = merge_var.indices.get_shape() @@ -637,12 +664,12 @@ def _EnforceShapeInvariant(merge_var, next_var): not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape) or not _ShapeLessThanOrEqual(n_shape_shape, m_shape_shape)): raise ValueError( - "The shape for %s is not an invariant for the loop. It enters " - "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) " - "after one iteration. Provide shape invariants using either " - "the `shape_invariants` argument of tf.while_loop or set_shape() " - "on the loop variables." - % (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape, + "The shape for %s is not an invariant for the loop. It enters " + "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) " + "after one iteration. Provide shape invariants using either " + "the `shape_invariants` argument of tf.while_loop or set_shape() " + "on the loop variables." % + (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape, n_values_shape, n_indices_shape, n_shape_shape)) @@ -657,7 +684,7 @@ def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True): # the types don't match. # TODO(skyewm): call this for other cases below (needs testing) _EnforceShapeInvariant(m, v) - m.op._update_input(1, v) # pylint: disable=protected-access + m.op._update_input(1, v) # pylint: disable=protected-access elif isinstance(m, ops.IndexedSlices): # pylint: disable=protected-access v = math_ops._as_indexed_slices(v, optimize=False) @@ -720,8 +747,7 @@ def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt): raise ValueError( "Cannot create a gradient accumulator for tensor '%s' inside " "XLA while_loop because maximum_iterations was not passed to " - "the tf.while_loop call ('%s')." - % (value_name, while_ctxt.name)) + "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name)) # pylint: disable=protected-access max_iter_ctxt = max_iter.op._get_control_flow_context() @@ -742,9 +768,9 @@ def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt): "while_loop. maximum_iterations tensor '%s' for while_loop context " "'%s' must be statically known (e.g. a constant value or known " "shape dimension), or be defined at or outside the while loop " - "context '%s' (currently defined in '%s')." % ( - value_name, max_iter.name, while_ctxt.name, - curr_ctxt_name, max_iter_ctxt.name)) + "context '%s' (currently defined in '%s')." % + (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name, + max_iter_ctxt.name)) max_size *= const_max_iter # Find the next outer WhileContext (or stop if we reach the @@ -808,9 +834,11 @@ class GradLoopState(object): outer_forward_ctxt = forward_ctxt.outer_context # Add the forward loop counter. - if outer_forward_ctxt: outer_forward_ctxt.Enter() + if outer_forward_ctxt: + outer_forward_ctxt.Enter() cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) - if outer_forward_ctxt: outer_forward_ctxt.Exit() + if outer_forward_ctxt: + outer_forward_ctxt.Exit() self._forward_context = forward_ctxt self._forward_index = forward_index @@ -835,7 +863,8 @@ class GradLoopState(object): real_cnt, outer_grad_state) outer_grad_ctxt.Exit() else: - if outer_forward_ctxt: outer_forward_ctxt.Enter() + if outer_forward_ctxt: + outer_forward_ctxt.Enter() self._grad_context = WhileContext( maximum_iterations=forward_ctxt.maximum_iterations, parallel_iterations=forward_ctxt.parallel_iterations, @@ -845,7 +874,8 @@ class GradLoopState(object): grad_state=self) self._grad_index = self._grad_context.AddBackpropLoopCounter( cnt, outer_grad_state) - if outer_forward_ctxt: outer_forward_ctxt.Exit() + if outer_forward_ctxt: + outer_forward_ctxt.Exit() @property def outer_grad_state(self): @@ -973,7 +1003,8 @@ class GradLoopState(object): # curr_ctxt is the context that tf.gradients was called in. curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access with ops.control_dependencies(None): - if curr_ctxt: curr_ctxt.Enter() + if curr_ctxt: + curr_ctxt.Enter() with ops.colocate_with(value): # We only need to pass maximum_iterations to the stack if # we're inside an XLA context. @@ -984,11 +1015,10 @@ class GradLoopState(object): value, self.forward_context) # pylint: disable=protected-access acc = gen_data_flow_ops._stack_v2( - max_size=max_size, - elem_type=value.dtype.base_dtype, - name="f_acc") + max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") # pylint: enable=protected-access - if curr_ctxt: curr_ctxt.Exit() + if curr_ctxt: + curr_ctxt.Exit() # Make acc available in the forward context. enter_acc = self.forward_context.AddValue(acc) @@ -1009,8 +1039,7 @@ class GradLoopState(object): else: # value is in a cond context within the forward context. if not isinstance(value_ctxt, CondContext): - raise TypeError( - "value_ctxt is not a CondContext: %s" % value_ctxt) + raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) if dead_branch: # The special case for creating a zero tensor for a dead # branch of a switch. See ControlFlowState.ZerosLike(). @@ -1134,8 +1163,8 @@ class GradLoopState(object): if real_value is None: # Add the stack pop op in the grad context. - real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value, - cur_value) + real_value = cur_grad_state.AddBackpropAccumulatedValue( + history_value, cur_value) if cur_grad_state != self: real_value = self._grad_context.AddValue(real_value) self._history_map[value.name] = real_value @@ -1154,7 +1183,7 @@ class ControlFlowState(object): """Maintain the mapping from the loops to their grad states.""" def __init__(self): - self._map = {} # maps forward loop context to GradLoopState + self._map = {} # maps forward loop context to GradLoopState def GetGradState(self, op, before): """Return the grad state for this op if it's in a forward loop context.""" @@ -1318,7 +1347,8 @@ class ControlFlowState(object): Returns: A zero tensor of the same shape of op.outputs[index]. """ - if util.IsLoopSwitch(op): return None + if util.IsLoopSwitch(op): + return None dead_branch = util.IsSwitch(op) forward_ctxt = _GetWhileContext(op) grad_state = self._map.get(forward_ctxt) @@ -1361,8 +1391,8 @@ class ControlFlowState(object): grad_state.grad_context.Enter() # Create a zero tensor with the right shape. - shape = grad_state.AddBackpropAccumulatedValue( - history_zeros_shape, zeros_shape, dead_branch) + shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape, + zeros_shape, dead_branch) result = array_ops.zeros(shape, val.dtype) return result @@ -1393,12 +1423,14 @@ class ControlFlowState(object): else: # Create a zeros in the outer grad context. outer_grad_ctxt = grad_state.grad_context.outer_context - if outer_grad_ctxt: outer_grad_ctxt.Enter() + if outer_grad_ctxt: + outer_grad_ctxt.Enter() enter_grad_op = b_merge.op.inputs[0].op enter_grad = enter_grad_op.inputs[0] grad_shape = array_ops.shape_internal(enter_grad, optimize=False) grad_val = array_ops.zeros(grad_shape) - if outer_grad_ctxt: outer_grad_ctxt.Exit() + if outer_grad_ctxt: + outer_grad_ctxt.Exit() # Use the zeros for iterations > 0. grad_state.grad_context.Enter() next_grad_val = _NextIteration(grad_val) @@ -1467,11 +1499,13 @@ class ControlFlowContext(object): """ def __init__(self, values_def=None, import_scope=None): + self._nested_contexts = [] self._outer_context = ops.get_default_graph()._get_control_flow_context() + if self._outer_context: + self._outer_context._nested_contexts.append(self) # pylint: disable=protected-access self._context_stack = [] if values_def: - self._init_values_from_proto(values_def, - import_scope=import_scope) + self._init_values_from_proto(values_def, import_scope=import_scope) else: # Values that have been already seen in this context. self._values = set() @@ -1521,7 +1555,17 @@ class ControlFlowContext(object): def back_prop(self): raise NotImplementedError("Abstract method") - def _to_proto(self, export_scope=None): + @abc.abstractmethod + def to_control_flow_context_def(self, context_def, export_scope=None): + """Serializes this into `context_def`. + + Args: + context_def: a `ControlFlowContextDef` protocol buffer. + export_scope: Optional `string`. Name scope to remove. + """ + raise NotImplementedError("Abstract method") + + def _to_values_def(self, export_scope=None): """Converts the values to a `ValuesDef` protocol buffer. Args: @@ -1532,20 +1576,12 @@ class ControlFlowContext(object): """ values_def = control_flow_pb2.ValuesDef() values_def.values.extend( - [ops.strip_name_scope(v, export_scope) - for v in sorted(self._values)]) + [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)]) for k, v in self._external_values.items(): k = ops.strip_name_scope(k, export_scope) - values_def.external_values[k] = ops.strip_name_scope( - v.name, export_scope) + values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope) return values_def - @staticmethod - def _from_proto(values_def, import_scope=None): - """Returns a `ControlFlowContext` created from `values_def`.""" - return ControlFlowContext(values_def=values_def, - import_scope=import_scope) - def AddName(self, name): self._values.add(name) @@ -1595,10 +1631,14 @@ class ControlFlowContext(object): ctxt = util.GetOutputContext(x) if ctxt is not None and ctxt.GetWhileContext() == while_ctxt: internal_control_inputs.append(x) + external_control_inputs = [] if len(internal_control_inputs) != len(op.control_inputs): + external_control_inputs = list(set(op.control_inputs) + - set(internal_control_inputs)) op._remove_all_control_inputs() op._add_control_inputs(internal_control_inputs) - return internal_control_inputs + return internal_control_inputs, external_control_inputs + # pylint: enable=protected-access def AddInnerOp(self, op): @@ -1626,8 +1666,13 @@ class ControlFlowContext(object): class CondContext(ControlFlowContext): """The context for the conditional construct.""" - def __init__(self, pred=None, pivot=None, branch=None, - name="cond_text", context_def=None, import_scope=None): + def __init__(self, + pred=None, + pivot=None, + branch=None, + name="cond_text", + context_def=None, + import_scope=None): """Creates a `CondContext`. Args: @@ -1647,9 +1692,9 @@ class CondContext(ControlFlowContext): else: # Initializes the default fields. ControlFlowContext.__init__(self) - self._pred = pred # The boolean tensor for the cond predicate - self._pivot = pivot # The predicate tensor in this branch - self._branch = branch # 0 or 1 representing this branch + self._pred = pred # The boolean tensor for the cond predicate + self._pivot = pivot # The predicate tensor in this branch + self._branch = branch # 0 or 1 representing this branch # Values considered to have been already seen in this context. self._values.add(pred.name) @@ -1665,15 +1710,14 @@ class CondContext(ControlFlowContext): assert isinstance(context_def, control_flow_pb2.CondContextDef) # Create from context_def. g = ops.get_default_graph() - self._name = ops.prepend_name_scope( - context_def.context_name, import_scope) - self._pred = g.as_graph_element(ops.prepend_name_scope( - context_def.pred_name, import_scope)) - self._pivot = g.as_graph_element(ops.prepend_name_scope( - context_def.pivot_name, import_scope)) + self._name = ops.prepend_name_scope(context_def.context_name, import_scope) + self._pred = g.as_graph_element( + ops.prepend_name_scope(context_def.pred_name, import_scope)) + self._pivot = g.as_graph_element( + ops.prepend_name_scope(context_def.pivot_name, import_scope)) self._branch = context_def.branch - super(CondContext, self).__init__(values_def=context_def.values_def, - import_scope=import_scope) + super(CondContext, self).__init__( + values_def=context_def.values_def, import_scope=import_scope) @property def pred(self): @@ -1711,18 +1755,23 @@ class CondContext(ControlFlowContext): Returns: A `CondContextDef` protocol buffer. """ - if (export_scope is None or - self.name.startswith(export_scope)): + if (export_scope is None or self.name.startswith(export_scope)): context_def = control_flow_pb2.CondContextDef() - context_def.context_name = ops.strip_name_scope( - self.name, export_scope) - context_def.pred_name = ops.strip_name_scope( - self._pred.name, export_scope) - context_def.pivot_name = ops.strip_name_scope( - self._pivot.name, export_scope) + context_def.context_name = ops.strip_name_scope(self.name, export_scope) + context_def.pred_name = ops.strip_name_scope(self._pred.name, + export_scope) + context_def.pivot_name = ops.strip_name_scope(self._pivot.name, + export_scope) context_def.branch = self._branch - context_def.values_def.MergeFrom(super(CondContext, self)._to_proto( + context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def( export_scope)) + # TODO(b/72868227): enable this once the corresponding control_flow.proto + # changes have been checked in (they aren't checked in and this is + # disabled for now to ensure forwards compatibility). + if False: # pylint: disable=using-constant-test + for nested in self._nested_contexts: + nested_def = context_def.nested_contexts.add() + nested.to_control_flow_context_def(nested_def) return context_def else: @@ -1731,8 +1780,21 @@ class CondContext(ControlFlowContext): @staticmethod def from_proto(context_def, import_scope=None): """Returns a `CondContext` object created from `context_def`.""" - return CondContext(context_def=context_def, - import_scope=import_scope) + ret = CondContext(context_def=context_def, + import_scope=import_scope) + + # TODO(b/72868227): remove "if hasattr(...)" once the corresponding + # control_flow.proto changes have been checked in (they aren't checked in + # and this is here for now to ensure forwards compatibility). + if hasattr(context_def, "nested_contexts"): + ret.Enter() + for nested_def in context_def.nested_contexts: + from_control_flow_context_def(nested_def) + ret.Exit() + return ret + + def to_control_flow_context_def(self, context_def, export_scope=None): + context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) def AddValue(self, val): """Add `val` to the current context and its outer context recursively.""" @@ -1846,8 +1908,8 @@ class CondContext(ControlFlowContext): if original_result is None: return no_op(), None else: - original_result = nest.map_structure( - array_ops.identity, original_result) + original_result = nest.map_structure(array_ops.identity, + original_result) if original_result is None: return None, None @@ -1871,11 +1933,15 @@ def _UnpackIfSingleton(res): # pylint: disable=g-doc-args @tf_export("cond") @deprecation.deprecated_args( - None, - "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", + None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", "fn1", "fn2") -def cond(pred, true_fn=None, false_fn=None, strict=False, name=None, - fn1=None, fn2=None): +def cond(pred, + true_fn=None, + false_fn=None, + strict=False, + name=None, + fn1=None, + fn2=None): """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and @@ -2034,9 +2100,15 @@ def cond(pred, true_fn=None, false_fn=None, strict=False, name=None, merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)] merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges) - # Add to collections - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) - ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) + # Only add non-nested conds to the collection. Any nested control flow will + # be encapsulated in the root context. + assert context_t.outer_context == context_f.outer_context + # TODO(b/72868227): remove "if True..." once the corresponding + # control_flow.proto changes have been checked in (they aren't checked in + # and this is disabled for now to ensure forwards compatibility). + if True or context_t.outer_context is None: + ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) + ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges) @@ -2044,10 +2116,67 @@ def cond(pred, true_fn=None, false_fn=None, strict=False, name=None, if not strict: merges = _UnpackIfSingleton(merges) return merges + + # pylint: enable=g-doc-args # pylint: enable=redefined-outer-name +def smart_cond(pred, true_fn=None, false_fn=None, name=None): + """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. + + If `pred` is a bool or has a constant value, we return either `true_fn()` + or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. + + Arguments: + pred: A scalar determining whether to return the result of `true_fn` or + `false_fn`. + true_fn: The callable to be performed if pred is true. + false_fn: The callable to be performed if pred is false. + name: Optional name prefix when using `tf.cond`. + + Returns: + Tensors returned by the call to either `true_fn` or `false_fn`. + + Raises: + TypeError: If `true_fn` or `false_fn` is not callable. + """ + if not callable(true_fn): + raise TypeError('`true_fn` must be callable.') + if not callable(false_fn): + raise TypeError('`false_fn` must be callable.') + + pred_value = smart_constant_value(pred) + if pred_value is not None: + if pred_value: + return true_fn() + else: + return false_fn() + else: + return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name) + + +def smart_constant_value(pred): + """Return the bool value for `pred`, or None if `pred` had a dynamic value. + + Arguments: + pred: A scalar, either a Python bool or tensor. + + Returns: + True or False if `pred` has a constant boolean value, None otherwise. + + Raises: + TypeError: If `pred` is not a Tensor or bool. + """ + if isinstance(pred, bool): + pred_value = pred + elif isinstance(pred, ops.Tensor): + pred_value = tensor_util.constant_value(pred) + else: + raise TypeError('`pred` must be a Tensor or a Python bool.') + return pred_value + + def _resource_safe_shape(t): """Returns the shape of t or the variable it points to.""" if t.dtype == dtypes.resource: @@ -2139,8 +2268,7 @@ class WhileContext(ControlFlowContext): assert isinstance(context_def, control_flow_pb2.WhileContextDef) # Create from context_def. g = ops.get_default_graph() - self._name = ops.prepend_name_scope( - context_def.context_name, import_scope) + self._name = ops.prepend_name_scope(context_def.context_name, import_scope) if context_def.maximum_iterations_name: self._maximum_iterations = g.as_graph_element( ops.prepend_name_scope(context_def.maximum_iterations_name, @@ -2150,25 +2278,38 @@ class WhileContext(ControlFlowContext): self._parallel_iterations = context_def.parallel_iterations self._back_prop = context_def.back_prop self._swap_memory = context_def.swap_memory - self._pivot_for_pred = g.as_graph_element(ops.prepend_name_scope( - context_def.pivot_for_pred_name, import_scope)) + self._pivot_for_pred = g.as_graph_element( + ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope)) # We use this node to control constants created by the body lambda. - self._pivot_for_body = g.as_graph_element(ops.prepend_name_scope( - context_def.pivot_for_body_name, import_scope)) + self._pivot_for_body = g.as_graph_element( + ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope)) # The boolean tensor for loop termination condition. Used in code # generation for gradient computation. self._pivot = g.as_graph_element( ops.prepend_name_scope(context_def.pivot_name, import_scope)) # The list of exit tensors for loop variables. - self._loop_exits = [g.as_graph_element( - ops.prepend_name_scope(exit_name, import_scope)) - for exit_name in context_def.loop_exit_names] + self._loop_exits = [ + g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) + for exit_name in context_def.loop_exit_names + ] # The list of enter tensors for loop variables. - self._loop_enters = [g.as_graph_element( - ops.prepend_name_scope(enter_name, import_scope)) - for enter_name in context_def.loop_enter_names] - super(WhileContext, self).__init__(values_def=context_def.values_def, - import_scope=import_scope) + self._loop_enters = [ + g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) + for enter_name in context_def.loop_enter_names + ] + super(WhileContext, self).__init__( + values_def=context_def.values_def, import_scope=import_scope) + + # import_scope causes self.name to be different from the original serialized + # context's name. Rewrite "frame_name" attrs with the new name. + if import_scope: + for tensor_name in self._values: + op = g.as_graph_element(tensor_name).op + if util.IsLoopEnter(op): + # pylint: disable=protected-access + op._set_attr("frame_name", + attr_value_pb2.AttrValue(s=compat.as_bytes(self.name))) + # pylint: enable=protected-access @property def maximum_iterations(self): @@ -2219,11 +2360,9 @@ class WhileContext(ControlFlowContext): Returns: A `WhileContextDef` protocol buffer. """ - if (export_scope is None or - self.name.startswith(export_scope)): + if (export_scope is None or self.name.startswith(export_scope)): context_def = control_flow_pb2.WhileContextDef() - context_def.context_name = ops.strip_name_scope( - self.name, export_scope) + context_def.context_name = ops.strip_name_scope(self.name, export_scope) context_def.parallel_iterations = self._parallel_iterations if self._maximum_iterations is not None: context_def.maximum_iterations_name = ops.strip_name_scope( @@ -2234,22 +2373,32 @@ class WhileContext(ControlFlowContext): self._pivot_for_pred.name, export_scope) context_def.pivot_for_body_name = ops.strip_name_scope( self._pivot_for_body.name, export_scope) - context_def.pivot_name = ops.strip_name_scope( - self._pivot.name, export_scope) - context_def.loop_exit_names.extend( - [ops.strip_name_scope(l.name, export_scope) - for l in self._loop_exits]) - context_def.loop_enter_names.extend( - [ops.strip_name_scope(l.name, export_scope) - for l in self._loop_enters]) + context_def.pivot_name = ops.strip_name_scope(self._pivot.name, + export_scope) + context_def.loop_exit_names.extend([ + ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits + ]) + context_def.loop_enter_names.extend([ + ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters + ]) context_def.values_def.MergeFrom( - super(WhileContext, self)._to_proto( + super(WhileContext, self)._to_values_def( export_scope=export_scope)) + # TODO(b/72868227): remove "if True..." once the corresponding + # control_flow.proto changes have been checked in (they aren't checked in + # and this is disabled for now to ensure forwards compatibility). + if False: # pylint: disable=using-constant-test + for nested in self._nested_contexts: + nested_def = context_def.nested_contexts.add() + nested.to_control_flow_context_def(nested_def) return context_def else: return None + def to_control_flow_context_def(self, context_def, export_scope=None): + context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) + @staticmethod def from_proto(context_def, import_scope=None): """Returns a `WhileContext` object created from `context_def`. @@ -2261,8 +2410,17 @@ class WhileContext(ControlFlowContext): Returns: A `WhileContext` Python object. """ - return WhileContext(context_def=context_def, - import_scope=import_scope) + ret = WhileContext(context_def=context_def, + import_scope=import_scope) + # TODO(b/72868227): remove "if hasattr(...)" once the corresponding + # control_flow.proto changes have been checked in (they aren't checked in + # and this is disabled for now to ensure forwards compatibility). + if hasattr(context_def, "nested_contexts"): + ret.Enter() + for nested_def in context_def.nested_contexts: + from_control_flow_context_def(nested_def, import_scope=import_scope) + ret.Exit() + return ret def GetWhileContext(self): return self @@ -2299,8 +2457,11 @@ class WhileContext(ControlFlowContext): result = self._outer_context.AddValue(val) # Create an Enter to make `result` known to this loop context. with ops.control_dependencies(None): - enter = _Enter(result, self._name, is_constant=True, - parallel_iterations=self._parallel_iterations) + enter = _Enter( + result, + self._name, + is_constant=True, + parallel_iterations=self._parallel_iterations) enter.graph.prevent_feeding(enter) if self._outer_context: self._outer_context.AddInnerOp(enter.op) @@ -2340,14 +2501,12 @@ class WhileContext(ControlFlowContext): def _AddOpInternal(self, op): """Add `op` to the current context. - In the case that op has only external data inputs, we remove all of its - external control inputs so all its inputs are in the same while loop - context. This is valid because op now has an Enter input that has all - the right control dependency. + We move any external control dependencies of the op to the loop pivot, to + ensure they get executed. """ if not op.inputs: # Remove any external control dependency on this op - control_inputs = self._RemoveExternalControlEdges(op) + control_inputs, external_inputs = self._RemoveExternalControlEdges(op) # Add a control edge from the control pivot to this op. if not control_inputs: # pylint: disable=protected-access @@ -2360,14 +2519,23 @@ class WhileContext(ControlFlowContext): x = op.inputs[index] real_x = self.AddValue(x) if real_x != x: - op._update_input(index, real_x) + op._update_input(index, real_x) # pylint: disable=protected-access # Remove any external control dependency on this op. - self._RemoveExternalControlEdges(op) + _, external_inputs = self._RemoveExternalControlEdges(op) # Add a control dependency to prevent loop invariants from # enabling ops that should not be executed. self._MaybeAddControlDependency(op) for x in op.outputs: self._values.add(x.name) + if external_inputs: + # Use an identity to pull control inputs as data inputs. Note that we + # ignore ops which don't have outputs. TODO(apassos): fix that + with ops.control_dependencies(None): + self.Enter() + external_inputs = [array_ops.identity(x.outputs[0]).op + for x in external_inputs if x.outputs] + self.Exit() + op._add_control_inputs(external_inputs) # pylint: disable=protected-access if self._outer_context or not util.IsLoopExit(op): op.graph.prevent_fetching(op) for x in op.outputs: @@ -2378,6 +2546,7 @@ class WhileContext(ControlFlowContext): def _MaybeAddControlDependency(self, op): """Add a control input to the op if it only depends on loop invariants.""" + def _IsOpFree(op): """Determines if `op` needs a control dependency.""" if op.control_inputs: @@ -2390,6 +2559,7 @@ class WhileContext(ControlFlowContext): if not util.IsLoopConstantEnter(x.op): return False return True + if _IsOpFree(op): # pylint: disable=protected-access op._add_control_input(self.GetControlPivot().op) @@ -2423,9 +2593,12 @@ class WhileContext(ControlFlowContext): self.Enter() self.AddName(n.name) - enter_n = _Enter(n, self._name, is_constant=False, - parallel_iterations=self._parallel_iterations, - name="f_count") + enter_n = _Enter( + n, + self._name, + is_constant=False, + parallel_iterations=self._parallel_iterations, + name="f_count") self.loop_enters.append(enter_n) merge_n = merge([enter_n, enter_n])[0] @@ -2465,9 +2638,12 @@ class WhileContext(ControlFlowContext): self.Enter() self.AddName(count.name) - enter_count = _Enter(count, self._name, is_constant=False, - parallel_iterations=self._parallel_iterations, - name="b_count") + enter_count = _Enter( + count, + self._name, + is_constant=False, + parallel_iterations=self._parallel_iterations, + name="b_count") self.loop_enters.append(enter_count) merge_count = merge([enter_count, enter_count])[0] @@ -2525,9 +2701,11 @@ class WhileContext(ControlFlowContext): # without running any iterations. shape = grad.get_shape() if shape.is_fully_defined(): - if self.outer_context: self.outer_context.Enter() + if self.outer_context: + self.outer_context.Enter() acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc") - if self.outer_context: self.outer_context.Exit() + if self.outer_context: + self.outer_context.Exit() else: value = op.inputs[0] if (isinstance(self.outer_context, WhileContext) and @@ -2546,16 +2724,21 @@ class WhileContext(ControlFlowContext): acc = array_ops.zeros(real_shape, grad.dtype) self.outer_context.Exit() else: - if self.outer_context: self.outer_context.Enter() + if self.outer_context: + self.outer_context.Enter() zeros_shape = array_ops.shape_internal(value, optimize=False) acc = array_ops.zeros(zeros_shape, grad.dtype) - if self.outer_context: self.outer_context.Exit() + if self.outer_context: + self.outer_context.Exit() self.Enter() self.AddName(acc.name) - enter_acc = _Enter(acc, self._name, is_constant=False, - parallel_iterations=self._parallel_iterations, - name="b_acc") + enter_acc = _Enter( + acc, + self._name, + is_constant=False, + parallel_iterations=self._parallel_iterations, + name="b_acc") self.loop_enters.append(enter_acc) merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0] @@ -2588,14 +2771,17 @@ class WhileContext(ControlFlowContext): dense_shape = grad.dense_shape self.Exit() - if self.outer_context: self.outer_context.Enter() + if self.outer_context: + self.outer_context.Enter() if values.get_shape().is_fully_defined(): values_shape = tensor_shape.TensorShape( [tensor_shape.Dimension(1)] + values.get_shape().dims[1:]) - if self.outer_context: self.outer_context.Enter() - values_acc = constant_op.constant(0, values.dtype, shape=values_shape, - name="b_acc") - if self.outer_context: self.outer_context.Exit() + if self.outer_context: + self.outer_context.Enter() + values_acc = constant_op.constant( + 0, values.dtype, shape=values_shape, name="b_acc") + if self.outer_context: + self.outer_context.Exit() else: values_shape = _resource_safe_shape(op.inputs[0])[1:] values_shape = array_ops.concat([[1], values_shape], 0) @@ -2604,16 +2790,19 @@ class WhileContext(ControlFlowContext): shape_acc = None if dense_shape is not None: if dense_shape.get_shape().is_fully_defined(): - if self.outer_context: self.outer_context.Enter() - shape_acc = constant_op.constant(0, dense_shape.dtype, - shape=dense_shape.get_shape()) - if self.outer_context: self.outer_context.Exit() + if self.outer_context: + self.outer_context.Enter() + shape_acc = constant_op.constant( + 0, dense_shape.dtype, shape=dense_shape.get_shape()) + if self.outer_context: + self.outer_context.Exit() else: shape_acc = array_ops.zeros_like( array_ops.shape_internal(op.inputs[0], optimize=False), optimize=False) - if self.outer_context: self.outer_context.Exit() + if self.outer_context: + self.outer_context.Exit() self.Enter() self.AddName(values_acc.name) @@ -2626,9 +2815,15 @@ class WhileContext(ControlFlowContext): # Set use_input_shape=False since the accumulator tensors will grow in # size. If use_input_shape=True, the _update_input call below will result in # incompatible shapes. - enter_acc = [_Enter(x, self._name, is_constant=False, - parallel_iterations=self._parallel_iterations, - use_input_shape=False, name="b_acc") for x in init_acc] + enter_acc = [ + _Enter( + x, + self._name, + is_constant=False, + parallel_iterations=self._parallel_iterations, + use_input_shape=False, + name="b_acc") for x in init_acc + ] # Manually set appropriate partial shapes. enter_acc[0].set_shape([None]) if values_acc.shape.dims is not None: @@ -2645,8 +2840,7 @@ class WhileContext(ControlFlowContext): ] if shape_acc is not None: # For the shape we just keep the maximum - acc_indexed_slices.append( - math_ops.maximum(dense_shape, switch_acc[2][1])) + acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1])) next_acc = [_NextIteration(x) for x in acc_indexed_slices] for xm, xn in zip(merge_acc, next_acc): @@ -2657,7 +2851,8 @@ class WhileContext(ControlFlowContext): self.ExitResult(exit_acc) return ops.IndexedSlices( - indices=exit_acc[0], values=exit_acc[1], + indices=exit_acc[0], + values=exit_acc[1], dense_shape=exit_acc[2] if shape_acc is not None else None) def _InitializeValues(self, values): @@ -2690,10 +2885,14 @@ class WhileContext(ControlFlowContext): if self._outer_context: real_vars = [self._outer_context.AddValue(x) for x in loop_vars] with ops.control_dependencies(None): - enter_vars = [_Enter(x, self._name, is_constant=False, - parallel_iterations=self._parallel_iterations, - use_input_shape=(shape_invariants is None)) - for x in real_vars] + enter_vars = [ + _Enter( + x, + self._name, + is_constant=False, + parallel_iterations=self._parallel_iterations, + use_input_shape=(shape_invariants is None)) for x in real_vars + ] for x in enter_vars: x.graph.prevent_feeding(x) if self._outer_context: @@ -2754,11 +2953,13 @@ class WhileContext(ControlFlowContext): summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access summary_ref[:] = pre_summaries with ops.control_dependencies(new_summaries): + def map_fn(x): # TODO(apassos) figure out how to trigger with tensor arrays as well if isinstance(x, tensor_array_ops.TensorArray): return x return array_ops.identity(x) + body_result = nest.map_structure(map_fn, body_result) # Compare the structure types of input and output of body. @@ -2815,8 +3016,7 @@ class WhileContext(ControlFlowContext): packed_exit_vars = nest.pack_sequence_as( structure=original_body_result, flat_sequence=exit_vars_with_tensor_arrays) - return (packed_exit_vars[0] if len(exit_vars) == 1 - else packed_exit_vars) + return (packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars) def _FixControlInputsAndContext(self, enters): graph = ops.get_default_graph() @@ -2834,8 +3034,9 @@ class WhileContext(ControlFlowContext): for x in xs: inp_op = x.op.inputs[0].op control_inputs = graph._control_dependencies_for_inputs([inp_op]) - outer_control_inputs = [op for op in control_inputs - if self._IsInOuterContext(op)] + outer_control_inputs = [ + op for op in control_inputs if self._IsInOuterContext(op) + ] x.op._set_control_flow_context(self) x.op._add_control_inputs(outer_control_inputs) graph._record_op_seen_by_control_dependencies(x.op) @@ -2847,9 +3048,15 @@ class WhileContext(ControlFlowContext): # pylint: disable=redefined-outer-name @tf_export("while_loop") -def while_loop(cond, body, loop_vars, shape_invariants=None, - parallel_iterations=10, back_prop=True, swap_memory=False, - name=None, maximum_iterations=None): +def while_loop(cond, + body, + loop_vars, + shape_invariants=None, + parallel_iterations=10, + back_prop=True, + swap_memory=False, + name=None, + maximum_iterations=None): """Repeat `body` while the condition `cond` is true. `cond` is a callable returning a boolean scalar tensor. `body` is a callable @@ -2966,6 +3173,43 @@ def while_loop(cond, body, loop_vars, shape_invariants=None, c, b, loop_vars=[i0, m0], shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) ``` + + Example which demonstrates non-strict semantics: In the following + example, the final value of the counter `i` does not depend on `x`. So + the `while_loop` can increment the counter parallel to updates of `x`. + However, because the loop counter at one loop iteration depends + on the value at the previous iteration, the loop counter itself cannot + be incremented in parallel. Hence if we just want the final value of the + counter (which we print on the line `print(sess.run(i))`), then + `x` will never be incremented, but the counter will be updated on a + single thread. Conversely, if we want the value of the output (which we + print on the line `print(sess.run(out).shape)`), then the counter may be + incremented on its own thread, while `x` can be incremented in + parallel on a separate thread. In the extreme case, it is conceivable + that the thread incrementing the counter runs until completion before + `x` is incremented even a single time. The only thing that can never + happen is that the thread updating `x` can never get ahead of the + counter thread because the thread incrementing `x` depends on the value + of the counter. + ```python + import tensorflow as tf + + n = 10000 + x = tf.constant(list(range(n))) + c = lambda i, x: i < n + b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:")) + i, out = tf.while_loop(c, b, (0, x)) + with tf.Session() as sess: + print(sess.run(i)) # prints [0] ... [9999] + + # The following line may increment the counter and x in parallel. + # The counter thread may get ahead of the other thread, but not the + # other way around. So you may see things like + # [9996] x:[9987] + # meaning that the counter thread is on iteration 9996, + # while the other thread is on iteration 9987 + print(sess.run(out).shape) + ``` """ with ops.name_scope(name, "while", loop_vars): @@ -3018,12 +3262,20 @@ def while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=parallel_iterations, back_prop=back_prop, swap_memory=swap_memory) - ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) + # Only add non-nested loops to the collection. Any nested control flow will + # be encapsulated in the root context. + # TODO(b/72868227): enable condition once the corresponding + # control_flow.proto changes have been checked in (they aren't checked in + # and this is disabled for now to ensure forwards compatibility). + if True or loop_context.outer_context is None: + ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants) if maximum_iterations is not None: return result[1] else: return result + + # pylint: enable=redefined-outer-name @@ -3051,8 +3303,9 @@ def _AsTensorList(x, p): if isinstance(v, ops.Tensor): l.append(array_ops.identity(v)) else: - l.append(ops.IndexedSlices(array_ops.identity(v.values), - array_ops.identity(v.indices))) + l.append( + ops.IndexedSlices( + array_ops.identity(v.values), array_ops.identity(v.indices))) return l @@ -3062,8 +3315,7 @@ def _CheckResults(a, b): for x, y in zip(a, b): assert x.dtype == y.dtype, ( "Values returned by a() [%s] and b() [%s] must have " - "the same type: %s, %s." % - (x.name, y.name, x.dtype.name, y.dtype.name)) + "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name)) def with_dependencies(dependencies, output_tensor, name=None): @@ -3099,9 +3351,9 @@ def with_dependencies(dependencies, output_tensor, name=None): if isinstance(output_tensor, ops.Tensor): return _Identity(output_tensor, name=name) else: - return ops.IndexedSlices(_Identity(output_tensor.values, name=name), - output_tensor.indices, - output_tensor.dense_shape) + return ops.IndexedSlices( + _Identity(output_tensor.values, name=name), output_tensor.indices, + output_tensor.dense_shape) def _GroupControlDeps(dev, deps, name=None): @@ -3173,6 +3425,7 @@ def group(*inputs, **kwargs): def device_key(dev): """A sort key that allows None to be compared to strings.""" return "" if dev is None else dev + for dev in sorted(six.iterkeys(ops_on_device), key=device_key): deps.append(_GroupControlDeps(dev, ops_on_device[dev])) @@ -3181,7 +3434,7 @@ def group(*inputs, **kwargs): @tf_export("tuple") -def tuple(tensors, name=None, control_inputs=None): +def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin """Group tensors together. This creates a tuple of tensors with the same values as the `tensors` @@ -3463,12 +3716,34 @@ class XLAControlFlowContext(ControlFlowContext): return x -ops.register_proto_function(ops.GraphKeys.COND_CONTEXT, - proto_type=control_flow_pb2.CondContextDef, - to_proto=CondContext.to_proto, - from_proto=CondContext.from_proto) +def from_control_flow_context_def(context_def, import_scope=None): + """Deserializes `context_def` into the appropriate ControlFlowContext. + + Args: + context_def: ControlFlowContextDef proto + import_scope: Optional `string`. Name scope to add. + + Returns: + A ControlFlowContext subclass + """ + if context_def.HasField("cond_ctxt"): + return CondContext.from_proto(context_def.cond_ctxt, + import_scope=import_scope) + if context_def.HasField("while_ctxt"): + return WhileContext.from_proto(context_def.while_ctxt, + import_scope=import_scope) + raise NotImplementedError("Unknown ControlFlowContextDef field: %s" + % context_def.WhichOneof("ctxt")) + + +ops.register_proto_function( + ops.GraphKeys.COND_CONTEXT, + proto_type=control_flow_pb2.CondContextDef, + to_proto=CondContext.to_proto, + from_proto=CondContext.from_proto) -ops.register_proto_function(ops.GraphKeys.WHILE_CONTEXT, - proto_type=control_flow_pb2.WhileContextDef, - to_proto=WhileContext.to_proto, - from_proto=WhileContext.from_proto) +ops.register_proto_function( + ops.GraphKeys.WHILE_CONTEXT, + proto_type=control_flow_pb2.WhileContextDef, + to_proto=WhileContext.to_proto, + from_proto=WhileContext.from_proto) diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index cc5a42bf3ddd4b37d037f8d28a2fe6af79f79ba1..7775133348ffb915995970b2942a3b3cf27b8809 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -189,7 +189,7 @@ class SwitchTestCase(test_util.TensorFlowTestCase): zero = constant_op.constant(0) one = constant_op.constant(1) less_op = math_ops.less(zero, one) - switch_false, switch_true = control_flow_ops.switch(data, less_op) + _, switch_true = control_flow_ops.switch(data, less_op) self.assertAllEqual([1, 2, 3], switch_true.values.eval()) self.assertAllEqual([0, 1], switch_true.indices.eval()) @@ -199,16 +199,17 @@ class SwitchTestCase(test_util.TensorFlowTestCase): "embedding_matrix", [5, 5], initializer=init_ops.random_normal_initializer()) - def Cond(it, _): + def cond(it, _): return it < 5 - def Body(it, cost): + def body(it, cost): embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0]) cost += math_ops.reduce_sum(embedding) return it + 1, cost _, cost = control_flow_ops.while_loop( - Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)]) + cond, body, [constant_op.constant(0), + constant_op.constant(0.0)]) optimizer = momentum.MomentumOptimizer(0.1, 0.9) train_op = optimizer.minimize(cost) with self.test_session() as sess: @@ -223,16 +224,17 @@ class SwitchTestCase(test_util.TensorFlowTestCase): initializer=[[2.0], [3.0]], use_resource=True) - def Cond(it, _): + def cond(it, _): return it < 5 - def Body(it, cost): + def body(it, cost): embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) cost += math_ops.reduce_sum(embedding) return it + 1, cost _, cost = control_flow_ops.while_loop( - Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)]) + cond, body, [constant_op.constant(0), + constant_op.constant(0.0)]) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) self.assertAllEqual(10.0, cost.eval()) @@ -244,10 +246,10 @@ class SwitchTestCase(test_util.TensorFlowTestCase): initializer=init_ops.random_normal_initializer(), use_resource=use_resource) - def Cond(it, _): + def cond(it, _): return it < 5 - def Body(it, cost): + def body(it, cost): embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) cost = control_flow_ops.cond( math_ops.equal(it, 3), lambda: math_ops.square(cost), @@ -255,7 +257,8 @@ class SwitchTestCase(test_util.TensorFlowTestCase): return it + 1, cost _, cost = control_flow_ops.while_loop( - Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)]) + cond, body, [constant_op.constant(0), + constant_op.constant(0.0)]) dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0] dynamic_grads = math_ops.segment_sum(dynamic_grads.values, @@ -289,15 +292,15 @@ class SwitchTestCase(test_util.TensorFlowTestCase): dtype=dtype, size=num_steps) initial_i = constant_op.constant(0, dtype=dtypes.int32) - def Cond(i, _): + def cond(i, _): return i < num_steps # pylint: disable=cell-var-from-loop - def Body(i, outputs): + def body(i, outputs): x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop outputs = outputs.write(i, x) return i + 1, outputs - _, outputs = control_flow_ops.while_loop(Cond, Body, + _, outputs = control_flow_ops.while_loop(cond, body, [initial_i, initial_outputs]) outputs = math_ops.reduce_sum(outputs.stack()) @@ -316,15 +319,15 @@ class SwitchTestCase(test_util.TensorFlowTestCase): dtype=dtype, dynamic_size=True, size=1) initial_i = constant_op.constant(0, dtype=dtypes.int32) - def Cond(i, _): + def cond(i, _): return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop - def Body(i, outputs): + def body(i, outputs): x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop outputs = outputs.write(i, x) return i + 1, outputs - _, outputs = control_flow_ops.while_loop(Cond, Body, + _, outputs = control_flow_ops.while_loop(cond, body, [initial_i, initial_outputs]) outputs = math_ops.reduce_sum(outputs.stack()) @@ -346,6 +349,44 @@ class SwitchTestCase(test_util.TensorFlowTestCase): self.assertEquals(grad_x_false.eval(), 0.) +@test_util.with_c_api +class SmartCondTest(test_util.TensorFlowTestCase): + + def testSmartCondTrue(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(2) + y = constant_op.constant(5) + z = control_flow_ops.smart_cond( + True, lambda: math_ops.multiply(x, 16), + lambda: math_ops.multiply(y, 5)) + self.assertEqual(z.eval(), 32) + + def testSmartCondFalse(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(4) + y = constant_op.constant(3) + z = control_flow_ops.smart_cond( + False, lambda: math_ops.multiply(x, 16), + lambda: math_ops.multiply(y, 3)) + self.assertEqual(z.eval(), 9) + + def testSmartCondMissingArg1(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.smart_cond(True, false_fn=lambda: x) + + def testSmartCondMissingArg2(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.smart_cond(True, lambda: x) + + @test_util.with_c_api class CondTest(test_util.TensorFlowTestCase): @@ -460,11 +501,12 @@ class ContextTest(test_util.TensorFlowTestCase): control_flow_ops.while_loop( c, b, [i], maximum_iterations=maximum_iterations) for op in sess.graph.get_operations(): - context = op._get_control_flow_context() - if context: - self.assertProtoEquals(context.to_proto(), - control_flow_ops.WhileContext.from_proto( - context.to_proto()).to_proto()) + control_flow_context = op._get_control_flow_context() + if control_flow_context: + self.assertProtoEquals( + control_flow_context.to_proto(), + control_flow_ops.WhileContext.from_proto( + control_flow_context.to_proto()).to_proto()) def testWhileContext(self): self._testWhileContextHelper() @@ -483,8 +525,8 @@ class ContextTest(test_util.TensorFlowTestCase): c._values = ["a", "b"] c._external_values = {"a": b1} - c_with_scope = control_flow_ops.ControlFlowContext._from_proto( - c._to_proto(), import_scope="test_scope") + c_with_scope = control_flow_ops.ControlFlowContext( + values_def=c._to_values_def(), import_scope="test_scope") # _values and _external_values should be have scope prepended. self.assertEquals( @@ -494,12 +536,13 @@ class ContextTest(test_util.TensorFlowTestCase): # Calling _to_proto() with export_scope should remove "test_scope". self.assertProtoEquals( - c._to_proto(), - c_with_scope._to_proto(export_scope="test_scope")) + c._to_values_def(), + c_with_scope._to_values_def(export_scope="test_scope")) + +def _get_nested_shape(nested): -def _GetNestedShape(nested): - def _GetShape(tensor): + def _get_shape(tensor): if isinstance(tensor, tensor_array_ops.TensorArray): return tensor_array_ops.TensorArray elif isinstance(tensor, ops.IndexedSlices): @@ -507,10 +550,10 @@ def _GetNestedShape(nested): else: return tensor.get_shape() - return nest.map_structure(_GetShape, nested) + return nest.map_structure(_get_shape, nested) -def _CreateTensorArray(size, shape): +def _create_tensor_array(size, shape): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size, clear_after_read=False) for i in range(size): @@ -518,13 +561,15 @@ def _CreateTensorArray(size, shape): return ta -def _RawNestedShape(nested_shape): - def _RawShape(shape): +def _raw_nested_shape(nested_shape): + + def _raw_shape(shape): if isinstance(shape, tensor_shape.TensorShape) and shape.ndims is not None: return [x.value for x in shape] else: return None - return nest.map_structure(_RawShape, nested_shape) + + return nest.map_structure(_raw_shape, nested_shape) # TODO(yori): Add tests for indexed slices. @@ -543,13 +588,15 @@ class DataTypesTest(test_util.TensorFlowTestCase): condition = array_ops.placeholder(dtypes.bool) output_cond = control_flow_ops.cond(condition, fn_true, fn_false, strict=strict) - self.assertEqual(_RawNestedShape(_GetNestedShape(output_cond)), - _RawNestedShape(expected_shape)) + self.assertEqual( + _raw_nested_shape(_get_nested_shape(output_cond)), + _raw_nested_shape(expected_shape)) output_case = control_flow_ops.case([(condition, fn_true)], fn_false, strict=strict) - self.assertEqual(_RawNestedShape(_GetNestedShape(output_case)), - _RawNestedShape(expected_shape)) + self.assertEqual( + _raw_nested_shape(_get_nested_shape(output_case)), + _raw_nested_shape(expected_shape)) def _testReturnValues(self, fn_true, fn_false, expected_value_true, expected_value_false, strict=False, @@ -626,45 +673,55 @@ class DataTypesTest(test_util.TensorFlowTestCase): control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none) def test_tensors(self): - def _BuildTrueBranch(dtype): - def _Build(): + + def _build_true_branch(dtype): + + def _build(): return (array_ops.zeros([2, 2], dtype=dtype), array_ops.ones([3, 3], dtype=dtype)) - return _Build - def _BuildFalseBranch(dtype): - def _Build(): + return _build + + def _build_false_branch(dtype): + + def _build(): return (array_ops.ones([2, 2], dtype=dtype), array_ops.zeros([3, 3], dtype=dtype)) - return _Build + + return _build for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): shape = (tensor_shape.TensorShape([2, 2]), tensor_shape.TensorShape([3, 3])) - fn_true = _BuildTrueBranch(dtype) - fn_false = _BuildFalseBranch(dtype) + fn_true = _build_true_branch(dtype) + fn_false = _build_false_branch(dtype) self._testShape(fn_true, fn_false, shape) self._testReturnValues(fn_true, fn_false, (np.zeros([2, 2]), np.ones([3, 3])), (np.ones([2, 2]), np.zeros([3, 3]))) def test_tensors_unknown_shape(self): - def _BuildTrueBranch(dtype): + + def _build_true_branch(dtype): tensor = array_ops.placeholder(dtype=dtype, shape=None) - def _Build(): + + def _build(): return tensor - return _Build, tensor - def _BuildFalseBranch(dtype): + return _build, tensor + + def _build_false_branch(dtype): tensor = array_ops.placeholder(dtype=dtype, shape=None) - def _Build(): + + def _build(): return tensor - return _Build, tensor + + return _build, tensor for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): shape = tensor_shape.TensorShape(None) - fn_true, true_tensor = _BuildTrueBranch(dtype) - fn_false, false_tensor = _BuildFalseBranch(dtype) + fn_true, true_tensor = _build_true_branch(dtype) + fn_false, false_tensor = _build_false_branch(dtype) self._testShape(fn_true, fn_false, shape) self._testReturnValues(fn_true, fn_false, np.zeros([2, 2]), np.ones([2, 2]), @@ -674,11 +731,11 @@ class DataTypesTest(test_util.TensorFlowTestCase): def test_sparse_tensors(self): shape = tensor_shape.TensorShape([None, None]) - def FnTrue(): + def true_fn(): return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])] - def FnFalse(): + def false_fn(): return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]], values=[3, 4], dense_shape=[3, 4])] @@ -686,26 +743,29 @@ class DataTypesTest(test_util.TensorFlowTestCase): values=[1, 2], dense_shape=[3, 4]) value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]], values=[3, 4], dense_shape=[3, 4]) - self._testShape(FnTrue, FnFalse, shape) - self._testReturnValues(FnTrue, FnFalse, value1, value2) - self._testShape(FnTrue, FnFalse, [shape], strict=True) - self._testReturnValues(FnTrue, FnFalse, [value1], [value2], strict=True) + self._testShape(true_fn, false_fn, shape) + self._testReturnValues(true_fn, false_fn, value1, value2) + self._testShape(true_fn, false_fn, [shape], strict=True) + self._testReturnValues(true_fn, false_fn, [value1], [value2], strict=True) def test_tensors_with_partially_specified_shapes(self): - def _BuildBranch(dtype, shape): + + def _build_branch(dtype, shape): a = array_ops.placeholder(dtype=dtype, shape=shape[0]) b = array_ops.placeholder(dtype=dtype, shape=shape[1]) c = array_ops.placeholder(dtype=dtype, shape=shape[2]) - def _Build(): + + def _build(): return a, b, c - return _Build, (a, b, c) + + return _build, (a, b, c) for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): shape = (tensor_shape.TensorShape([None, 2]), tensor_shape.TensorShape([None]), tensor_shape.TensorShape([3, None])) - fn_true, true_tensors = _BuildBranch(dtype, shape) - fn_false, false_tensors = _BuildBranch(dtype, shape) + fn_true, true_tensors = _build_branch(dtype, shape) + fn_false, false_tensors = _build_branch(dtype, shape) self._testShape(fn_true, fn_false, shape) self._testReturnValues(fn_true, fn_false, (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])), @@ -719,8 +779,8 @@ class DataTypesTest(test_util.TensorFlowTestCase): def test_tensor_arrays(self): element_shape = tensor_shape.TensorShape([2]) - ta1 = _CreateTensorArray(4, element_shape) - ta2 = _CreateTensorArray(4, element_shape) + ta1 = _create_tensor_array(4, element_shape) + ta2 = _create_tensor_array(4, element_shape) shape = tensor_array_ops.TensorArray fn_true = lambda: ta1 fn_false = lambda: ta2 @@ -728,7 +788,7 @@ class DataTypesTest(test_util.TensorFlowTestCase): def test_tensor_array_reads(self): shape = tensor_shape.TensorShape([2]) - ta = _CreateTensorArray(4, shape) + ta = _create_tensor_array(4, shape) fn_true = lambda: ta.read(0) fn_false = lambda: ta.read(1) self._testShape(fn_true, fn_false, shape) @@ -827,23 +887,26 @@ class DataTypesTest(test_util.TensorFlowTestCase): tensor_shape.TensorShape([5, 5]), tensor_shape.TensorShape([])] - def FnTrue(): + def true_fn(): return [constant_op.constant(1), TestTuple(constant_op.constant(2), [3, 4]), array_ops.zeros([5, 5]), 6] - def FnFalse(): + def false_fn(): return [constant_op.constant(11), TestTuple(constant_op.constant(12), [13, 14]), array_ops.ones([5, 5]), 16] - self._testShape(FnTrue, FnFalse, shape) - self._testReturnValues(FnTrue, FnFalse, - [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6], - [11, TestTuple(12, [13, 14]), np.ones([5, 5]), 16]) + self._testShape(true_fn, false_fn, shape) + self._testReturnValues( + true_fn, false_fn, + [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6], + [11, TestTuple(12, [13, 14]), + np.ones([5, 5]), 16]) def test_cond_inside_while_loop(self): - def Body(i, matrix): + + def body(i, matrix): result_tuple, unused_matrix = control_flow_ops.cond( constant_op.constant(True), lambda: (TestTuple(matrix * 2, matrix * 4), matrix), @@ -852,8 +915,9 @@ class DataTypesTest(test_util.TensorFlowTestCase): iteration, matrix = control_flow_ops.while_loop( lambda i, matrix: i < 10, - Body, - loop_vars=[constant_op.constant(0), array_ops.ones([2, 2])]) + body, + loop_vars=[constant_op.constant(0), + array_ops.ones([2, 2])]) self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([])) self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2])) diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 34f0bf7b78a75533cb89ed549afad90f3c066b94..03ed537cfcf27151a0200d7a17f63b1a2bc7ba1a 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. #============================================================================== - """Data Flow Operations.""" # pylint: disable=g-bad-name from __future__ import absolute_import @@ -40,6 +39,7 @@ from tensorflow.python.ops import math_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_data_flow_ops import * from tensorflow.python.util.tf_export import tf_export + # pylint: enable=wildcard-import @@ -54,17 +54,19 @@ def _as_type_list(dtypes): return list(dtypes) -def _as_shape_list(shapes, dtypes, unknown_dim_allowed=False, +def _as_shape_list(shapes, + dtypes, + unknown_dim_allowed=False, unknown_rank_allowed=False): """Convert shapes to a list of tuples of int (or None).""" del dtypes if unknown_dim_allowed: - if (not isinstance(shapes, collections.Sequence) - or not shapes - or any(shape is None or isinstance(shape, int) for shape in shapes)): + if (not isinstance(shapes, collections.Sequence) or not shapes or + any(shape is None or isinstance(shape, int) for shape in shapes)): raise ValueError( "When providing partial shapes, a list of shapes must be provided.") - if shapes is None: return None + if shapes is None: + return None if isinstance(shapes, tensor_shape.TensorShape): shapes = [shapes] if not isinstance(shapes, (tuple, list)): @@ -103,7 +105,8 @@ def _shape_common(s1, s2): return tensor_shape.unknown_shape() d = [ d1 if d1 is not None and d1 == d2 else None - for (d1, d2) in zip(s1.as_list(), s2.as_list())] + for (d1, d2) in zip(s1.as_list(), s2.as_list()) + ] return tensor_shape.TensorShape(d) @@ -195,8 +198,7 @@ class QueueBase(object): TypeError: When `queues` is not a list of `QueueBase` objects, or when the data types of `queues` are not all the same. """ - if ((not queues) or - (not isinstance(queues, list)) or + if ((not queues) or (not isinstance(queues, list)) or (not all(isinstance(x, QueueBase) for x in queues))): raise TypeError("A list of queues expected") @@ -210,12 +212,16 @@ class QueueBase(object): queue_shapes = [q.shapes for q in queues] reduced_shapes = [ - six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)] + six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes) + ] queue_refs = array_ops.stack([x.queue_ref for x in queues]) selected_queue = array_ops.gather(queue_refs, index) - return QueueBase(dtypes=dtypes, shapes=reduced_shapes, names=names, - queue_ref=selected_queue) + return QueueBase( + dtypes=dtypes, + shapes=reduced_shapes, + names=names, + queue_ref=selected_queue) @property def queue_ref(self): @@ -282,8 +288,8 @@ class QueueBase(object): tensors = [] for i, (val, dtype) in enumerate(zip(vals, self._dtypes)): - tensors.append(ops.convert_to_tensor(val, dtype=dtype, - name="component_%d" % i)) + tensors.append( + ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i)) return tensors @@ -468,7 +474,7 @@ class QueueBase(object): name: A name for the operation (optional). Returns: - The tuple of concatenated tensors that was dequeued. + The list of concatenated tensors that was dequeued. """ if name is None: name = "%s_DequeueMany" % self._name @@ -555,11 +561,13 @@ class QueueBase(object): name = "%s_Close" % self._name if self._queue_ref.dtype == _dtypes.resource: return gen_data_flow_ops._queue_close_v2( - self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues, + self._queue_ref, + cancel_pending_enqueues=cancel_pending_enqueues, name=name) else: return gen_data_flow_ops._queue_close( - self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues, + self._queue_ref, + cancel_pending_enqueues=cancel_pending_enqueues, name=name) def is_closed(self, name=None): @@ -577,9 +585,9 @@ class QueueBase(object): if name is None: name = "%s_Is_Closed" % self._name if self._queue_ref.dtype == _dtypes.resource: - return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref,name=name) + return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name) else: - return gen_data_flow_ops.queue_is_closed_(self._queue_ref,name=name) + return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name) def size(self, name=None): """Compute the number of elements in this queue. @@ -611,8 +619,14 @@ class RandomShuffleQueue(QueueBase): @end_compatibility """ - def __init__(self, capacity, min_after_dequeue, dtypes, shapes=None, - names=None, seed=None, shared_name=None, + def __init__(self, + capacity, + min_after_dequeue, + dtypes, + shapes=None, + names=None, + seed=None, + shared_name=None, name="random_shuffle_queue"): """Create a queue that dequeues elements in a random order. @@ -670,9 +684,14 @@ class RandomShuffleQueue(QueueBase): string = (str(seed1) + shared_name).encode("utf-8") seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF queue_ref = gen_data_flow_ops._random_shuffle_queue_v2( - component_types=dtypes, shapes=shapes, capacity=capacity, - min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2, - shared_name=shared_name, name=name) + component_types=dtypes, + shapes=shapes, + capacity=capacity, + min_after_dequeue=min_after_dequeue, + seed=seed1, + seed2=seed2, + shared_name=shared_name, + name=name) super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -690,8 +709,13 @@ class FIFOQueue(QueueBase): @end_compatibility """ - def __init__(self, capacity, dtypes, shapes=None, names=None, - shared_name=None, name="fifo_queue"): + def __init__(self, + capacity, + dtypes, + shapes=None, + names=None, + shared_name=None, + name="fifo_queue"): """Creates a queue that dequeues elements in a first-in first-out order. A `FIFOQueue` has bounded capacity; supports multiple concurrent @@ -725,8 +749,11 @@ class FIFOQueue(QueueBase): shapes = _as_shape_list(shapes, dtypes) names = _as_name_list(names, dtypes) queue_ref = gen_data_flow_ops._fifo_queue_v2( - component_types=dtypes, shapes=shapes, capacity=capacity, - shared_name=shared_name, name=name) + component_types=dtypes, + shapes=shapes, + capacity=capacity, + shared_name=shared_name, + name=name) super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -747,7 +774,12 @@ class PaddingFIFOQueue(QueueBase): @end_compatibility """ - def __init__(self, capacity, dtypes, shapes, names=None, shared_name=None, + def __init__(self, + capacity, + dtypes, + shapes, + names=None, + shared_name=None, name="padding_fifo_queue"): """Creates a queue that dequeues elements in a first-in first-out order. @@ -792,12 +824,15 @@ class PaddingFIFOQueue(QueueBase): names = _as_name_list(names, dtypes) if len(dtypes) != len(shapes): raise ValueError("Shapes must be provided for all components, " - "but received %d dtypes and %d shapes." - % (len(dtypes), len(shapes))) + "but received %d dtypes and %d shapes." % (len(dtypes), + len(shapes))) queue_ref = gen_data_flow_ops._padding_fifo_queue_v2( - component_types=dtypes, shapes=shapes, capacity=capacity, - shared_name=shared_name, name=name) + component_types=dtypes, + shapes=shapes, + capacity=capacity, + shared_name=shared_name, + name=name) super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -815,7 +850,12 @@ class PriorityQueue(QueueBase): @end_compatibility """ - def __init__(self, capacity, types, shapes=None, names=None, shared_name=None, + def __init__(self, + capacity, + types, + shapes=None, + names=None, + shared_name=None, name="priority_queue"): """Creates a queue that dequeues elements in a first-in first-out order. @@ -856,14 +896,17 @@ class PriorityQueue(QueueBase): shapes = _as_shape_list(shapes, types) queue_ref = gen_data_flow_ops._priority_queue_v2( - component_types=types, shapes=shapes, capacity=capacity, - shared_name=shared_name, name=name) + component_types=types, + shapes=shapes, + capacity=capacity, + shared_name=shared_name, + name=name) priority_dtypes = [_dtypes.int64] + types priority_shapes = [()] + shapes if shapes else shapes - super(PriorityQueue, self).__init__( - priority_dtypes, priority_shapes, names, queue_ref) + super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names, + queue_ref) # TODO(josh11b): class BatchQueue(QueueBase): @@ -943,8 +986,10 @@ class Barrier(object): self._shapes = [tensor_shape.unknown_shape() for _ in self._types] self._barrier_ref = gen_data_flow_ops._barrier( - component_types=self._types, shapes=self._shapes, - shared_name=shared_name, name=name) + component_types=self._types, + shapes=self._shapes, + shared_name=shared_name, + name=name) if context.in_graph_mode(): self._name = self._barrier_ref.op.name.split("/")[-1] else: @@ -1028,12 +1073,13 @@ class Barrier(object): """ if name is None: name = "%s_BarrierTakeMany" % self._name - ret = gen_data_flow_ops._barrier_take_many(self._barrier_ref, - num_elements, - self._types, - allow_small_batch, - timeout, - name=name) + ret = gen_data_flow_ops._barrier_take_many( + self._barrier_ref, + num_elements, + self._types, + allow_small_batch, + timeout, + name=name) # NOTE(mrry): Not using a shape function because we need access to # the Barrier object. @@ -1048,8 +1094,7 @@ class Barrier(object): op.outputs[1].set_shape(tensor_shape.vector(batch_dim)) # keys for output, shape in zip(op.outputs[2:], self._shapes): # value_list output.set_shape( - tensor_shape.TensorShape([batch_dim]).concatenate( - shape)) + tensor_shape.TensorShape([batch_dim]).concatenate(shape)) return ret @@ -1298,8 +1343,8 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase): name="sparse_conditional_accumulator"): accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator( dtype=dtype, shape=shape, shared_name=shared_name, name=name) - super(SparseConditionalAccumulator, - self).__init__(dtype, shape, accumulator_ref) + super(SparseConditionalAccumulator, self).__init__(dtype, shape, + accumulator_ref) def apply_indexed_slices_grad(self, grad, local_step=0, name=None): """Attempts to apply a gradient to the accumulator. @@ -1368,8 +1413,8 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase): local_step=local_step, gradient_indices=math_ops.to_int64(grad_indices), gradient_values=grad_values, - gradient_shape=math_ops.to_int64([] if grad_shape is None else - grad_shape), + gradient_shape=math_ops.to_int64([] + if grad_shape is None else grad_shape), has_known_shape=(grad_shape is not None), name=name) @@ -1431,11 +1476,16 @@ class BaseStagingArea(object): _identifier = 0 _lock = threading.Lock() - def __init__(self, dtypes, shapes=None, names=None, shared_name=None, - capacity=0, memory_limit=0): + def __init__(self, + dtypes, + shapes=None, + names=None, + shared_name=None, + capacity=0, + memory_limit=0): if shared_name is None: - self._name = (ops.get_default_graph() - .unique_name(self.__class__.__name__)) + self._name = ( + ops.get_default_graph().unique_name(self.__class__.__name__)) elif isinstance(shared_name, six.string_types): self._name = shared_name else: @@ -1532,8 +1582,9 @@ class BaseStagingArea(object): (sorted(vals.keys()), sorted(self._names))) # The order of values in `self._names` indicates the order in which the # tensors in the dictionary `vals` must be listed. - vals, indices, n = zip(*[(vals[k], i, k) for i, k in enumerate(self._names) - if k in vals]) + vals, indices, n = zip(*[(vals[k], i, k) + for i, k in enumerate(self._names) + if k in vals]) else: if self._names: raise ValueError("You must enqueue a dictionary in a staging area " @@ -1541,7 +1592,7 @@ class BaseStagingArea(object): if indices is None: raise ValueError("Indices must be supplied when inserting a list " - "of tensors") + "of tensors") if len(indices) != len(vals): raise ValueError("Number of indices '%s' doesn't match " @@ -1553,8 +1604,8 @@ class BaseStagingArea(object): # Sanity check number of values if not len(vals) <= len(self._dtypes): - raise ValueError("Unexpected number of inputs '%s' vs '%s'" % ( - len(vals), len(self._dtypes))) + raise ValueError("Unexpected number of inputs '%s' vs '%s'" % + (len(vals), len(self._dtypes))) tensors = [] @@ -1562,14 +1613,14 @@ class BaseStagingArea(object): dtype, shape = self._dtypes[i], self._shapes[i] # Check dtype if not val.dtype == dtype: - raise ValueError("Datatypes do not match. '%s' != '%s'" %( - str(val.dtype), str(dtype))) + raise ValueError("Datatypes do not match. '%s' != '%s'" % + (str(val.dtype), str(dtype))) # Check shape val.get_shape().assert_is_compatible_with(shape) - tensors.append(ops.convert_to_tensor(val, dtype=dtype, - name="component_%d" % i)) + tensors.append( + ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i)) return tensors, indices @@ -1632,6 +1683,7 @@ class BaseStagingArea(object): else: return [vals] + class StagingArea(BaseStagingArea): """Class for staging inputs. No ordering guarantees. @@ -1666,8 +1718,13 @@ class StagingArea(BaseStagingArea): """ - def __init__(self, dtypes, shapes=None, names=None, shared_name=None, - capacity=0, memory_limit=0): + def __init__(self, + dtypes, + shapes=None, + names=None, + shared_name=None, + capacity=0, + memory_limit=0): """Constructs a staging area object. The two optional lists, `shapes` and `names`, must be of the same length @@ -1702,9 +1759,8 @@ class StagingArea(BaseStagingArea): ValueError: If one of the arguments is invalid. """ - super(StagingArea, self).__init__(dtypes, shapes, - names, shared_name, - capacity, memory_limit) + super(StagingArea, self).__init__(dtypes, shapes, names, shared_name, + capacity, memory_limit) def put(self, values, name=None): """Create an op that places a value into the staging area. @@ -1726,14 +1782,18 @@ class StagingArea(BaseStagingArea): self._scope_vals(values)) as scope: # Hard-code indices for this staging area - indices = (list(six.moves.range(len(values))) - if isinstance(values, (list, tuple)) else None) + indices = ( + list(six.moves.range(len(values))) + if isinstance(values, (list, tuple)) else None) vals, _ = self._check_put_dtypes(values, indices) with ops.colocate_with(self._coloc_op): - op = gen_data_flow_ops.stage(values=vals, shared_name=self._name, - name=scope, capacity=self._capacity, - memory_limit=self._memory_limit) + op = gen_data_flow_ops.stage( + values=vals, + shared_name=self._name, + name=scope, + capacity=self._capacity, + memory_limit=self._memory_limit) return op @@ -1741,7 +1801,7 @@ class StagingArea(BaseStagingArea): with ops.colocate_with(self._coloc_op): ret = get_fn() - indices = list(six.moves.range(len(self._dtypes))) # Hard coded + indices = list(six.moves.range(len(self._dtypes))) # Hard coded return self._get_return_value(ret, indices) def get(self, name=None): @@ -1769,10 +1829,12 @@ class StagingArea(BaseStagingArea): if name is None: name = "%s_get" % self._name + # pylint: disable=bad-continuation fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes, shared_name=self._name, name=name, capacity=self._capacity, memory_limit=self._memory_limit) + # pylint: enable=bad-continuation return self.__internal_get(fn, name) @@ -1797,10 +1859,12 @@ class StagingArea(BaseStagingArea): if name is None: name = "%s_peek" % self._name + # pylint: disable=bad-continuation fn = lambda: gen_data_flow_ops.stage_peek(index, dtypes=self._dtypes, shared_name=self._name, name=name, capacity=self._capacity, memory_limit=self._memory_limit) + # pylint: enable=bad-continuation return self.__internal_get(fn, name) @@ -1816,9 +1880,12 @@ class StagingArea(BaseStagingArea): if name is None: name = "%s_size" % self._name - return gen_data_flow_ops.stage_size(name=name, shared_name=self._name, - dtypes=self._dtypes, capacity=self._capacity, - memory_limit=self._memory_limit) + return gen_data_flow_ops.stage_size( + name=name, + shared_name=self._name, + dtypes=self._dtypes, + capacity=self._capacity, + memory_limit=self._memory_limit) def clear(self, name=None): """Clears the staging area. @@ -1832,14 +1899,16 @@ class StagingArea(BaseStagingArea): if name is None: name = "%s_clear" % self._name - return gen_data_flow_ops.stage_clear(name=name, shared_name=self._name, - dtypes=self._dtypes, capacity=self._capacity, - memory_limit=self._memory_limit) + return gen_data_flow_ops.stage_clear( + name=name, + shared_name=self._name, + dtypes=self._dtypes, + capacity=self._capacity, + memory_limit=self._memory_limit) + class MapStagingArea(BaseStagingArea): - """ - A `MapStagingArea` is a TensorFlow data structure that stores tensors across - multiple steps, and exposes operations that can put and get tensors. + """A `MapStagingArea` is a TensorFlow data structure that stores tensors across multiple steps, and exposes operations that can put and get tensors. Each `MapStagingArea` element is a (key, value) pair. Only int64 keys are supported, other types should be @@ -1852,7 +1921,8 @@ class MapStagingArea(BaseStagingArea): It supports multiple concurrent producers and consumers; and provides exactly-once delivery. - Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors whose + Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors + whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. @@ -1896,10 +1966,16 @@ class MapStagingArea(BaseStagingArea): associated with it are removed. """ - def __init__(self, dtypes, shapes=None, names=None, shared_name=None, - ordered=False, capacity=0, memory_limit=0): - """ - Args: + def __init__(self, + dtypes, + shapes=None, + names=None, + shared_name=None, + ordered=False, + capacity=0, + memory_limit=0): + """Args: + dtypes: A list of types. The length of dtypes must equal the number of tensors in each element. capacity: (Optional.) Maximum number of elements. @@ -1925,9 +2001,8 @@ class MapStagingArea(BaseStagingArea): """ - super(MapStagingArea, self).__init__(dtypes, shapes, - names, shared_name, - capacity, memory_limit) + super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name, + capacity, memory_limit) # Defer to different methods depending if the map is ordered self._ordered = ordered @@ -1950,8 +2025,7 @@ class MapStagingArea(BaseStagingArea): self._clear_fn = gen_data_flow_ops.map_clear def put(self, key, vals, indices=None, name=None): - """ - Create an op that stores the (key, vals) pair in the staging area. + """Create an op that stores the (key, vals) pair in the staging area. Incomplete puts are possible, preferably using a dictionary for vals as the appropriate dtypes and shapes can be inferred from the value names @@ -1973,7 +2047,8 @@ class MapStagingArea(BaseStagingArea): The created op Raises: - ValueError: If the number or type of inputs don't match the staging area. + ValueError: If the number or type of inputs don't match the staging + area. """ with ops.name_scope(name, "%s_put" % self._name, @@ -1982,10 +2057,15 @@ class MapStagingArea(BaseStagingArea): vals, indices = self._check_put_dtypes(vals, indices) with ops.colocate_with(self._coloc_op): - op = self._put_fn(key, indices, vals, dtypes=self._dtypes, - shared_name=self._name, name=scope, - capacity=self._capacity, - memory_limit=self._memory_limit) + op = self._put_fn( + key, + indices, + vals, + dtypes=self._dtypes, + shared_name=self._name, + name=scope, + capacity=self._capacity, + memory_limit=self._memory_limit) return op def _get_indices_and_dtypes(self, indices=None): @@ -2001,13 +2081,13 @@ class MapStagingArea(BaseStagingArea): if all(isinstance(i, str) for i in indices): if self._names is None: raise ValueError("String indices provided '%s', but this Staging Area " - "was not created with names." % indices) + "was not created with names." % indices) try: indices = [self._names.index(n) for n in indices] except ValueError: raise ValueError("Named index '%s' not in " - "Staging Area names '%s'" % (n, self._names)) + "Staging Area names '%s'" % (n, self._names)) elif all(isinstance(i, int) for i in indices): pass else: @@ -2018,10 +2098,8 @@ class MapStagingArea(BaseStagingArea): return indices, dtypes - def peek(self, key, indices=None, name=None): - """ - Peeks at staging area data associated with the key. + """Peeks at staging area data associated with the key. If the key is not in the staging area, it will block until the associated (key, value) is inserted. @@ -2044,22 +2122,22 @@ class MapStagingArea(BaseStagingArea): indices, dtypes = self._get_indices_and_dtypes(indices) with ops.colocate_with(self._coloc_op): - result = self._peek_fn(key, shared_name=self._name, - indices=indices, - dtypes=dtypes, - name=name, - capacity=self._capacity, - memory_limit=self._memory_limit) + result = self._peek_fn( + key, + shared_name=self._name, + indices=indices, + dtypes=dtypes, + name=name, + capacity=self._capacity, + memory_limit=self._memory_limit) return self._get_return_value(result, indices) def get(self, key=None, indices=None, name=None): - """ - If the key is provided, the associated (key, value) - is returned from the staging area. If the key is not - in the staging area, this method will block until - the associated (key, value) is inserted. + """If the key is provided, the associated (key, value) is returned from the staging area. + If the key is not in the staging area, this method will block until + the associated (key, value) is inserted. If no key is provided and the staging area is ordered, the (key, value) with the smallest key will be returned. Otherwise, a random (key, value) will be returned. @@ -2084,12 +2162,10 @@ class MapStagingArea(BaseStagingArea): return self._pop(key, indices=indices, name=name) def _pop(self, key, indices=None, name=None): - """ - Remove and return the associated (key, value) - is returned from the staging area. If the key is not - in the staging area, this method will block until - the associated (key, value) is inserted. + """Remove and return the associated (key, value) is returned from the staging area. + If the key is not in the staging area, this method will block until + the associated (key, value) is inserted. Args: key: Key associated with the required data indices: Partial list of tensors to retrieve (optional). @@ -2107,21 +2183,21 @@ class MapStagingArea(BaseStagingArea): indices, dtypes = self._get_indices_and_dtypes(indices) with ops.colocate_with(self._coloc_op): - result = self._pop_fn(key, shared_name=self._name, - indices=indices, - dtypes=dtypes, - name=name, - capacity=self._capacity, - memory_limit=self._memory_limit) + result = self._pop_fn( + key, + shared_name=self._name, + indices=indices, + dtypes=dtypes, + name=name, + capacity=self._capacity, + memory_limit=self._memory_limit) return key, self._get_return_value(result, indices) def _popitem(self, indices=None, name=None): - """ - If the staging area is ordered, - the (key, value) with the smallest key will be returned. - Otherwise, a random (key, value) will be returned. + """If the staging area is ordered, the (key, value) with the smallest key will be returned. + Otherwise, a random (key, value) will be returned. If the staging area is empty when this operation executes, it will block until there is an element to dequeue. @@ -2142,12 +2218,13 @@ class MapStagingArea(BaseStagingArea): indices, dtypes = self._get_indices_and_dtypes(indices) with ops.colocate_with(self._coloc_op): - key, result = self._popitem_fn(shared_name=self._name, - indices=indices, - dtypes=dtypes, - name=name, - capacity=self._capacity, - memory_limit=self._memory_limit) + key, result = self._popitem_fn( + shared_name=self._name, + indices=indices, + dtypes=dtypes, + name=name, + capacity=self._capacity, + memory_limit=self._memory_limit) # Separate keys and results out from # underlying namedtuple @@ -2157,8 +2234,7 @@ class MapStagingArea(BaseStagingArea): return key, result def size(self, name=None): - """ - Returns the number of elements in the staging area. + """Returns the number of elements in the staging area. Args: name: A name for the operation (optional) @@ -2169,14 +2245,15 @@ class MapStagingArea(BaseStagingArea): if name is None: name = "%s_size" % self._name - return self._size_fn(shared_name=self._name, - name=name, dtypes=self._dtypes, - capacity=self._capacity, - memory_limit=self._memory_limit) + return self._size_fn( + shared_name=self._name, + name=name, + dtypes=self._dtypes, + capacity=self._capacity, + memory_limit=self._memory_limit) def incomplete_size(self, name=None): - """ - Returns the number of incomplete elements in the staging area. + """Returns the number of incomplete elements in the staging area. Args: name: A name for the operation (optional) @@ -2187,16 +2264,15 @@ class MapStagingArea(BaseStagingArea): if name is None: name = "%s_incomplete_size" % self._name - return self._incomplete_size_fn(shared_name=self._name, - name=name, dtypes=self._dtypes, - capacity=self._capacity, - memory_limit=self._memory_limit) - - + return self._incomplete_size_fn( + shared_name=self._name, + name=name, + dtypes=self._dtypes, + capacity=self._capacity, + memory_limit=self._memory_limit) def clear(self, name=None): - """ - Clears the staging area. + """Clears the staging area. Args: name: A name for the operation (optional) @@ -2207,10 +2283,12 @@ class MapStagingArea(BaseStagingArea): if name is None: name = "%s_clear" % self._name - return self._clear_fn(shared_name=self._name, - name=name, dtypes=self._dtypes, - capacity=self._capacity, - memory_limit=self._memory_limit) + return self._clear_fn( + shared_name=self._name, + name=name, + dtypes=self._dtypes, + capacity=self._capacity, + memory_limit=self._memory_limit) class RecordInput(object): diff --git a/tensorflow/python/ops/distributions/bernoulli.py b/tensorflow/python/ops/distributions/bernoulli.py index 1f300b7147be505a316c38ae57cadeae2bd7ea10..68aaf3815e7e2b21c9550562aa49195569c8ea43 100644 --- a/tensorflow/python/ops/distributions/bernoulli.py +++ b/tensorflow/python/ops/distributions/bernoulli.py @@ -22,7 +22,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops @@ -137,21 +136,12 @@ class Bernoulli(distribution.Distribution): return (array_ops.ones_like(event) * logits, array_ops.ones_like(logits) * event) - # First check static shape. - if (event.get_shape().is_fully_defined() and - logits.get_shape().is_fully_defined()): - if event.get_shape() != logits.get_shape(): - logits, event = _broadcast(logits, event) - else: - logits, event = control_flow_ops.cond( - distribution_util.same_dynamic_shape(logits, event), - lambda: (logits, event), - lambda: _broadcast(logits, event)) + if not (event.get_shape().is_fully_defined() and + logits.get_shape().is_fully_defined() and + event.get_shape() == logits.get_shape()): + logits, event = _broadcast(logits, event) return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits) - def _prob(self, event): - return math_ops.exp(self._log_prob(event)) - def _entropy(self): return (-self.logits * (math_ops.sigmoid(self.logits) - 1) + nn.softplus(-self.logits)) @@ -167,26 +157,6 @@ class Bernoulli(distribution.Distribution): return math_ops.cast(self.probs > 0.5, self.dtype) -class BernoulliWithSigmoidProbs(Bernoulli): - """Bernoulli with `probs = nn.sigmoid(logits)`.""" - - def __init__(self, - logits=None, - dtype=dtypes.int32, - validate_args=False, - allow_nan_stats=True, - name="BernoulliWithSigmoidProbs"): - parameters = locals() - with ops.name_scope(name): - super(BernoulliWithSigmoidProbs, self).__init__( - probs=nn.sigmoid(logits, name="sigmoid_probs"), - dtype=dtype, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, - name=name) - self._parameters = parameters - - @kullback_leibler.RegisterKL(Bernoulli, Bernoulli) def _kl_bernoulli_bernoulli(a, b, name=None): """Calculate the batched KL divergence KL(a || b) with a and b Bernoulli. diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py index 6d6b40b04557a4483f60d8c06c35f937d38a24b9..469bcadb8ea3a0ec2a85d3a72c0ca5ba08796856 100644 --- a/tensorflow/python/ops/distributions/beta.py +++ b/tensorflow/python/ops/distributions/beta.py @@ -304,12 +304,11 @@ class Beta(distribution.Distribution): if not self.validate_args: return x return control_flow_ops.with_dependencies([ - check_ops.assert_positive( - x, - message="sample must be positive"), + check_ops.assert_positive(x, message="sample must be positive"), check_ops.assert_less( - x, array_ops.ones([], self.dtype), - message="sample must be no larger than `1`."), + x, + array_ops.ones([], self.dtype), + message="sample must be less than `1`."), ], x) diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py index 44d64070ce48c0c115ea7edb1237124bc6698e90..ed435557fde7a2e8a0a4f7eef4e240daef0565e7 100644 --- a/tensorflow/python/ops/distributions/bijector_impl.py +++ b/tensorflow/python/ops/distributions/bijector_impl.py @@ -114,7 +114,7 @@ class _Mapping(collections.namedtuple( @six.add_metaclass(abc.ABCMeta) @tf_export("distributions.bijectors.Bijector") class Bijector(object): - """Interface for transformations of a `Distribution` sample. + r"""Interface for transformations of a `Distribution` sample. Bijectors can be used to represent any differentiable and injective (one to one) function defined on an open subset of `R^n`. Some non-injective @@ -122,27 +122,24 @@ class Bijector(object): #### Mathematical Details - A `Bijector` implements a - [diffeomorphism](https://en.wikipedia.org/wiki/Diffeomorphism), i.e., a - bijective, differentiable function. A `Bijector` is used by - `TransformedDistribution` but can be generally used for transforming a - `Distribution` generated `Tensor`. A `Bijector` is characterized by three - operations: - - 1. Forward Evaluation + A `Bijector` implements a [smooth covering map]( + https://en.wikipedia.org/wiki/Local_diffeomorphism), i.e., a local + diffeomorphism such that every point in the target has a neighborhood evenly + covered by a map ([see also]( + https://en.wikipedia.org/wiki/Covering_space#Covering_of_a_manifold)). + A `Bijector` is used by `TransformedDistribution` but can be generally used + for transforming a `Distribution` generated `Tensor`. A `Bijector` is + characterized by three operations: + 1. Forward\ Useful for turning one random outcome into another random outcome from a different distribution. - - 2. Inverse Evaluation - + 2. Inverse\ Useful for "reversing" a transformation to compute one probability in terms of another. - - 3. (log o det o Jacobian o inverse)(x) - + 3. `(log o det o Jacobian o inverse)(x)`\ "The log of the determinant of the matrix of all first-order partial - derivatives of the inverse function." + derivatives of the inverse function."\ Useful for inverting a transformation to compute one probability in terms of another. Geometrically, the det(Jacobian) is the volume of the transformation and is used to scale the probability. diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py index 26b5c5aef98fc11b07a8c8357e7ec37819587da9..4ae67a009b0a4052f6e23e2e42262bb7c42f1c14 100644 --- a/tensorflow/python/ops/distributions/multinomial.py +++ b/tensorflow/python/ops/distributions/multinomial.py @@ -238,7 +238,7 @@ class Multinomial(distribution.Distribution): n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32) k = self.event_shape_tensor()[0] - # boardcast the total_count and logits to same shape + # broadcast the total_count and logits to same shape n_draws = array_ops.ones_like( self.logits[..., 0], dtype=n_draws.dtype) * n_draws logits = array_ops.ones_like( diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 5bc25128a8d6f77895fc4decc98a8978ae8400f3..0a3000ef5ca0decf8aba641e704406b0cf8780af 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -1041,14 +1041,14 @@ def reduce_weighted_logsumexp( with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]): logx = ops.convert_to_tensor(logx, name="logx") if w is None: - lswe = math_ops.reduce_logsumexp(logx, axis=axis, keep_dims=keep_dims) + lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims) if return_sign: sgn = array_ops.ones_like(lswe) return lswe, sgn return lswe w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w") log_absw_x = logx + math_ops.log(math_ops.abs(w)) - max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keep_dims=True) + max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True) # If the largest element is `-inf` or `inf` then we don't bother subtracting # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That # this is ok follows from the fact that we're actually free to subtract any @@ -1062,7 +1062,7 @@ def reduce_weighted_logsumexp( sum_wx_over_max_absw_x = math_ops.reduce_sum( wx_over_max_absw_x, axis=axis, - keep_dims=keep_dims) + keepdims=keep_dims) if not keep_dims: max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis) sgn = math_ops.sign(sum_wx_over_max_absw_x) @@ -1180,7 +1180,7 @@ def process_quadrature_grid_and_probs( grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype) probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype) - probs /= linalg_ops.norm(probs, ord=1, axis=-1, keep_dims=True, + probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs") def _static_event_size(x): diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 7dbccf1caf1486bb247a1bef0ac37c36adbcc53e..ac03d30fcd2e65f032937d9259bc8fff18626619 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -458,7 +458,7 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, For example, if `elems` is `(t1, [t2, t3])` and `initializer` is `[i1, i2]` then an appropriate signature for `fn` in `python2` is: - `fn = lambda (acc_p1, acc_p2), (t1 [t2, t3]):` and `fn` must return a list, + `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list, `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the one that works in `python3`, is: `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples. diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 5d4b9ecd8bee31c5092b04535e97b036eec9f1be..1418c0b10fb60601e7c3024891b89aadb53e6873 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -35,7 +35,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops +from tensorflow.python.ops import check_ops # pylint: disable=unused-import from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util @@ -44,6 +44,7 @@ from tensorflow.python.ops import image_grad # pylint: disable=unused-import from tensorflow.python.ops import linalg_grad # pylint: disable=unused-import from tensorflow.python.ops import linalg_ops # pylint: disable=unused-import from tensorflow.python.ops import logging_ops # pylint: disable=unused-import +from tensorflow.python.ops import manip_grad # pylint: disable=unused-import from tensorflow.python.ops import math_grad # pylint: disable=unused-import from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -52,7 +53,6 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export - # Warn the user if we convert a sparse representation to dense with at # least this number of elements. _LARGE_SPARSE_NUM_ELEMENTS = 100000000 @@ -235,9 +235,10 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops): raise TypeError( "Gradients of complex tensors must set grad_ys (y.dtype = %r)" % y.dtype) - new_grad_ys.append(array_ops.fill( - array_ops.shape(y), constant_op.constant( - 1, dtype=y.dtype, name="grad_ys_%d" % i))) + new_grad_ys.append( + array_ops.fill( + array_ops.shape(y), + constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i))) continue if y.dtype.is_floating or y.dtype.is_integer: if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer: @@ -492,11 +493,12 @@ def gradients(ys, name, "gradients", list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope: ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y") - xs = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable) - else x - for x in xs] - xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name="x", - as_ref=True) + xs = [ + x.handle if resource_variable_ops.is_resource_variable(x) else x + for x in xs + ] + xs = ops.internal_convert_n_to_tensor_or_indexed_slices( + xs, name="x", as_ref=True) grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops) # The approach we take here is as follows: Create a list of all ops in the @@ -513,9 +515,8 @@ def gradients(ys, to_ops = [t.op for t in ys] from_ops = [t.op for t in xs] stop_gradient_ops = [t.op for t in stop_gradients] - pending_count, loop_state = _PendingCount(ops.get_default_graph(), to_ops, - from_ops, - colocate_gradients_with_ops) + pending_count, loop_state = _PendingCount( + ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops) # Iterate over the collected ops. # @@ -588,9 +589,8 @@ def gradients(ys, # output, it means that the cost does not depend on output[i], # therefore dC/doutput[i] is 0. for i, out_grad in enumerate(out_grads): - if (not isinstance(out_grad, ops.Tensor) and - not out_grad) and ((not grad_fn and is_func_call) or - _IsTrainable(op.outputs[i])): + if (not isinstance(out_grad, ops.Tensor) and not out_grad) and ( + (not grad_fn and is_func_call) or _IsTrainable(op.outputs[i])): # Only trainable outputs or outputs for a function call that # will use SymbolicGradient get a zero gradient. Gradient # functions should ignore the gradient for other outputs. @@ -607,17 +607,17 @@ def gradients(ys, if grad_fn: # If grad_fn was found, do not use SymbolicGradient even for # functions. - in_grads = _MaybeCompile( - grad_scope, op, func_call, lambda: grad_fn(op, *out_grads)) + in_grads = _MaybeCompile(grad_scope, op, func_call, + lambda: grad_fn(op, *out_grads)) else: # For function call ops, we add a 'SymbolicGradient' # node to the graph to compute gradients. - in_grads = _MaybeCompile( - grad_scope, op, func_call, lambda: _SymGrad(op, out_grads)) + in_grads = _MaybeCompile(grad_scope, op, func_call, + lambda: _SymGrad(op, out_grads)) in_grads = _AsList(in_grads) _VerifyGeneratedGradients(in_grads, op) - if gate_gradients and len( - [x for x in in_grads if x is not None]) > 1: + if gate_gradients and len([x for x in in_grads + if x is not None]) > 1: with ops.device(None): with ops.colocate_with(None, ignore_existing=True): in_grads = control_flow_ops.tuple(in_grads) @@ -637,8 +637,8 @@ def gradients(ys, "Incompatible shapes between op input and calculated " "input gradient. Forward operation: %s. Input index: %d. " "Original input shape: %s. " - "Calculated input gradient shape: %s" - % (op.name, i, t_in.shape, in_grad.shape)) + "Calculated input gradient shape: %s" % + (op.name, i, t_in.shape, in_grad.shape)) _SetGrad(grads, t_in, in_grad) if loop_state: loop_state.ExitGradWhileContext(op, before=False) @@ -670,8 +670,8 @@ def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state): pending_count[x.op._id] -= 1 ready = (pending_count[x.op._id] == 0) if loop_state and not ready: - ready = (pending_count[x.op._id] > 0 and - control_flow_util.IsLoopSwitch(x.op)) + ready = ( + pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op)) # pylint: enable=protected-access if ready: if control_flow_util.IsLoopExit(x.op): @@ -725,8 +725,8 @@ def _GetGrad(grads, t): if not op_grads: return None t_grad = op_grads[t.value_index] - assert not isinstance(t_grad, list), ( - "gradients list should have been aggregated by now.") + assert not isinstance( + t_grad, list), ("gradients list should have been aggregated by now.") return t_grad @@ -745,9 +745,8 @@ def _HandleNestedIndexedSlices(grad): else: assert isinstance(grad.values, ops.IndexedSlices) g = _HandleNestedIndexedSlices(grad.values) - return ops.IndexedSlices(g.values, - array_ops.gather(grad.indices, g.indices), - g.dense_shape) + return ops.IndexedSlices(g.values, array_ops.gather( + grad.indices, g.indices), g.dense_shape) def _AccumulatorShape(inputs): @@ -849,8 +848,8 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None): AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE, AggregationMethod.EXPERIMENTAL_ACCUMULATE_N ]: - raise ValueError("Invalid aggregation_method specified %s." % - aggregation_method) + raise ValueError( + "Invalid aggregation_method specified %s." % aggregation_method) out_grads = _GetGrads(grads, op) for i, out_grad in enumerate(out_grads): if loop_state: @@ -859,7 +858,8 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None): continue # Grads have to be Tensors or IndexedSlices if (isinstance(out_grad, collections.Sequence) and not all([ - isinstance(g, (ops.Tensor, ops.IndexedSlices)) for g in out_grad + isinstance(g, (ops.Tensor, ops.IndexedSlices)) + for g in out_grad if g is not None ])): raise TypeError("gradients have to be either all Tensors " @@ -903,8 +903,8 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None): else: used = "add_n" out_grads[i] = _MultiDeviceAddN(out_grad) - logging.vlog(2, " _AggregatedGrads %d x %s using %s", - len(out_grad), tensor_shape, used) + logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad), + tensor_shape, used) else: out_grad = math_ops._as_indexed_slices_list( [g for g in out_grad if g is not None]) @@ -967,7 +967,8 @@ def _hessian_vector_product(ys, xs, v): assert len(grads) == length elemwise_products = [ math_ops.multiply(grad_elem, array_ops.stop_gradient(v_elem)) - for grad_elem, v_elem in zip(grads, v) if grad_elem is not None + for grad_elem, v_elem in zip(grads, v) + if grad_elem is not None ] # Second backprop @@ -975,8 +976,12 @@ def _hessian_vector_product(ys, xs, v): @tf_export("hessians") -def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False, - gate_gradients=False, aggregation_method=None): +def hessians(ys, + xs, + name="hessians", + colocate_gradients_with_ops=False, + gate_gradients=False, + aggregation_method=None): """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`. `hessians()` adds ops to the graph to output the Hessian matrix of `ys` @@ -1004,9 +1009,9 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False, """ xs = _AsList(xs) kwargs = { - 'colocate_gradients_with_ops': colocate_gradients_with_ops, - 'gate_gradients': gate_gradients, - 'aggregation_method': aggregation_method + "colocate_gradients_with_ops": colocate_gradients_with_ops, + "gate_gradients": gate_gradients, + "aggregation_method": aggregation_method } # Compute first-order derivatives and iterate for each x in xs. hessians = [] @@ -1031,8 +1036,7 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False, ) _shape = array_ops.shape(x) - _reshaped_hessian = array_ops.reshape( - hessian.stack(), array_ops.concat((_shape, _shape), 0) - ) + _reshaped_hessian = array_ops.reshape(hessian.stack(), + array_ops.concat((_shape, _shape), 0)) hessians.append(_reshaped_hessian) return hessians diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py index f079e56b10ed484225d8f09c6eaf7cf85a02d12a..6a975160b0698270dfc9ce9140e8b3ff633cdb9e 100644 --- a/tensorflow/python/ops/histogram_ops.py +++ b/tensorflow/python/ops/histogram_ops.py @@ -32,8 +32,10 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import tf_export +@tf_export('histogram_fixed_width_bins') def histogram_fixed_width_bins(values, value_range, nbins=100, diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py index d17f1a87d9759d5e83393f40e9e027dee8c15979..093843cd5bc0b7c2281a0c9ddf52d93ea3faede3 100644 --- a/tensorflow/python/ops/image_grad.py +++ b/tensorflow/python/ops/image_grad.py @@ -61,15 +61,10 @@ def _ResizeBilinearGrad(op, grad): Returns: The gradients w.r.t. the input. """ - allowed_types = [dtypes.float32, dtypes.float64] - grad0 = None - if op.inputs[0].dtype in allowed_types: - # pylint: disable=protected-access - grad0 = gen_image_ops._resize_bilinear_grad( - grad, - op.inputs[0], - align_corners=op.get_attr("align_corners")) - # pylint: enable=protected-access + # pylint: disable=protected-access + grad0 = gen_image_ops._resize_bilinear_grad( + grad, op.inputs[0], align_corners=op.get_attr("align_corners")) + # pylint: enable=protected-access return [grad0, None] diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py index 05e8fa1d72851caee522bba470bb40f430152464..75d00c8ed17c26c2c1acb4d92961a2206d959ebb 100644 --- a/tensorflow/python/ops/image_grad_test.py +++ b/tensorflow/python/ops/image_grad_test.py @@ -142,18 +142,6 @@ class ResizeBilinearOpTest(test.TestCase): input_tensor, in_shape, resize_out, out_shape, x_init_value=x) self.assertLess(err, 1e-3) - def testGradOnUnsupportedType(self): - in_shape = [1, 4, 6, 1] - out_shape = [1, 2, 3, 1] - - x = np.arange(0, 24).reshape(in_shape).astype(np.uint8) - - with self.test_session(): - input_tensor = constant_op.constant(x, shape=in_shape) - resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3]) - grad = gradients_impl.gradients(input_tensor, [resize_out]) - self.assertEqual([None], grad) - def testCompareGpuVsCpu(self): in_shape = [2, 4, 6, 3] out_shape = [2, 8, 16, 3] @@ -172,6 +160,26 @@ class ResizeBilinearOpTest(test.TestCase): self.assertAllClose(grad[False], grad[True], rtol=1e-4, atol=1e-4) + def testTypes(self): + in_shape = [1, 4, 6, 1] + out_shape = [1, 2, 3, 1] + x = np.arange(0, 24).reshape(in_shape) + + with self.test_session() as sess: + for dtype in [np.float16, np.float32, np.float64]: + input_tensor = constant_op.constant(x.astype(dtype), shape=in_shape) + resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3]) + grad = sess.run(gradients_impl.gradients(resize_out, input_tensor))[0] + self.assertAllEqual(in_shape, grad.shape) + # Not using gradient_checker.compute_gradient as I didn't work out + # the changes required to compensate for the lower precision of + # float16 when computing the numeric jacobian. + # Instead, we just test the theoretical jacobian. + self.assertAllEqual([[[[1.], [0.], [1.], [0.], [1.], [0.]], [[0.], [ + 0. + ], [0.], [0.], [0.], [0.]], [[1.], [0.], [1.], [0.], [1.], [0.]], + [[0.], [0.], [0.], [0.], [0.], [0.]]]], grad) + class ResizeBicubicOpTest(test.TestCase): diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index de12c5f63f4357e0982dd2e16999caf2de0b30f8..ae52d32fea1c872e588c4122f5e73198e4dfe9ad 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -26,6 +26,7 @@ See the @{$python/image} guide. @@extract_jpeg_shape @@decode_png @@encode_png +@@is_jpeg @@decode_image @@resize_images @@resize_area diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 721efcf78656a8832763a473668c108454bde915..0c0e92d5b00b36f2fbd800afc046faa1fc77b95c 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -166,6 +166,26 @@ def _Assert3DImage(image): return control_flow_ops.with_dependencies( _Check3DImage(image, require_static=False), image) +def _AssertAtLeast3DImage(image): + """Assert that we are working with a properly shaped image. + + Performs the check statically if possible (i.e. if the shape + is statically known). Otherwise adds a control dependency + to an assert op that checks the dynamic shape. + + Args: + image: >= 3-D Tensor of size [*, height, width, depth] + + Raises: + ValueError: if image.shape is not a [>= 3] vector. + + Returns: + If the shape of `image` could be verified statically, `image` is + returned unchanged, otherwise there will be a control dependency + added that asserts the correct dynamic shape. + """ + return control_flow_ops.with_dependencies( + _CheckAtLeast3DImage(image, require_static=False), image) def _CheckAtLeast3DImage(image, require_static=True): """Assert that we are working with properly shaped image. @@ -292,108 +312,184 @@ def random_flip_left_right(image, seed=None): def flip_left_right(image): """Flip an image horizontally (left to right). - Outputs the contents of `image` flipped along the second dimension, which is - `width`. + Outputs the contents of `image` flipped along the width dimension. See also `reverse()`. Args: - image: A 3-D tensor of shape `[height, width, channels].` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. Returns: - A 3-D tensor of the same type and shape as `image`. + A tensor of the same type and shape as `image`. Raises: ValueError: if the shape of `image` not supported. """ with ops.name_scope(None, 'flip_left_right', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = _Assert3DImage(image) - return fix_image_flip_shape(image, array_ops.reverse( - image, [1], name=scope)) + image = _AssertAtLeast3DImage(image) + shape = image.get_shape() + if shape.ndims == 3 or shape.ndims is None: + return fix_image_flip_shape(image, array_ops.reverse(image, [1])) + elif shape.ndims == 4: + return array_ops.reverse(image, [2]) + else: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') @tf_export('image.flip_up_down') def flip_up_down(image): """Flip an image vertically (upside down). - Outputs the contents of `image` flipped along the first dimension, which is - `height`. + Outputs the contents of `image` flipped along the height dimension. See also `reverse()`. Args: - image: A 3-D tensor of shape `[height, width, channels].` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. Returns: - A 3-D tensor of the same type and shape as `image`. + A tensor of the same type and shape as `image`. Raises: ValueError: if the shape of `image` not supported. """ with ops.name_scope(None, 'flip_up_down', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = _Assert3DImage(image) - return fix_image_flip_shape(image, array_ops.reverse( - image, [0], name=scope)) + image = _AssertAtLeast3DImage(image) + shape = image.get_shape() + if shape.ndims == 3 or shape.ndims is None: + return fix_image_flip_shape(image, array_ops.reverse(image, [0])) + elif shape.ndims == 4: + return array_ops.reverse(image, [1]) + else: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') @tf_export('image.rot90') def rot90(image, k=1, name=None): - """Rotate an image counter-clockwise by 90 degrees. + """Rotate image(s) counter-clockwise by 90 degrees. Args: - image: A 3-D tensor of shape `[height, width, channels]`. + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. k: A scalar integer. The number of times the image is rotated by 90 degrees. name: A name for this operation (optional). Returns: - A rotated 3-D tensor of the same type and shape as `image`. + A rotated tensor of the same type and shape as `image`. + + Raises: + ValueError: if the shape of `image` not supported. """ with ops.name_scope(name, 'rot90', [image, k]) as scope: image = ops.convert_to_tensor(image, name='image') - image = _Assert3DImage(image) + image = _AssertAtLeast3DImage(image) k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k') k.get_shape().assert_has_rank(0) k = math_ops.mod(k, 4) - def _rot90(): - return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2]) + shape = image.get_shape() + if shape.ndims == 3 or shape.ndims is None: + return _rot90_3D(image, k, scope) + elif shape.ndims == 4: + return _rot90_4D(image, k, scope) + else: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') + + +def _rot90_3D(image, k, name_scope): + """Rotate image counter-clockwise by 90 degrees `k` times. + + Args: + image: 3-D Tensor of shape `[height, width, channels]`. + k: A scalar integer. The number of times the image is rotated by 90 degrees. + name_scope: A valid TensorFlow name scope. + + Returns: + A 3-D tensor of the same type and shape as `image`. - def _rot180(): - return array_ops.reverse_v2(image, [0, 1]) + """ + def _rot90(): + return array_ops.transpose(array_ops.reverse_v2(image, [1]), + [1, 0, 2]) + def _rot180(): + return array_ops.reverse_v2(image, [0, 1]) + def _rot270(): + return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), + [1]) + cases = [(math_ops.equal(k, 1), _rot90), + (math_ops.equal(k, 2), _rot180), + (math_ops.equal(k, 3), _rot270)] + + result = control_flow_ops.case(cases, default=lambda: image, exclusive=True, + name=name_scope) + result.set_shape([None, None, image.get_shape()[2]]) + return result - def _rot270(): - return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1]) +def _rot90_4D(images, k, name_scope): + """Rotate batch of images counter-clockwise by 90 degrees `k` times. - cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180), - (math_ops.equal(k, 3), _rot270)] + Args: + images: 4-D Tensor of shape `[height, width, channels]`. + k: A scalar integer. The number of times the images are rotated by 90 + degrees. + name_scope: A valid TensorFlow name scope. - ret = control_flow_ops.case( - cases, default=lambda: image, exclusive=True, name=scope) - ret.set_shape([None, None, image.get_shape()[2]]) - return ret + Returns: + A 4-D tensor of the same type and shape as `images`. + """ + def _rot90(): + return array_ops.transpose(array_ops.reverse_v2(images, [2]), + [0, 2, 1, 3]) + def _rot180(): + return array_ops.reverse_v2(images, [1, 2]) + def _rot270(): + return array_ops.reverse_v2(array_ops.transpose(images, [0, 2, 1, 3]), + [2]) + + cases = [(math_ops.equal(k, 1), _rot90), + (math_ops.equal(k, 2), _rot180), + (math_ops.equal(k, 3), _rot270)] + + result = control_flow_ops.case(cases, default=lambda: images, exclusive=True, + name=name_scope) + shape = result.get_shape() + result.set_shape([shape[0], None, None, shape[3]]) + return result @tf_export('image.transpose_image') def transpose_image(image): - """Transpose an image by swapping the first and second dimension. + """Transpose image(s) by swapping the height and width dimension. See also `transpose()`. Args: - image: 3-D tensor of shape `[height, width, channels]` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. Returns: - A 3-D tensor of shape `[width, height, channels]` + If `image` was 4-D, a 4-D float Tensor of shape + `[batch, width, height, channels]` + If `image` was 3-D, a 3-D float Tensor of shape + `[width, height, channels]` Raises: ValueError: if the shape of `image` not supported. """ with ops.name_scope(None, 'transpose_image', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = _Assert3DImage(image) - return array_ops.transpose(image, [1, 0, 2], name=scope) + image = _AssertAtLeast3DImage(image) + shape = image.get_shape() + if shape.ndims == 3 or shape.ndims is None: + return array_ops.transpose(image, [1, 0, 2], name='transpose_image') + elif shape.ndims == 4: + return array_ops.transpose(image, [0, 2, 1, 3], name='transpose_image') + else: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') @tf_export('image.central_crop') @@ -770,8 +866,9 @@ def resize_images(images, size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The new size for the images. method: ResizeMethod. Defaults to `ResizeMethod.BILINEAR`. - align_corners: bool. If true, exactly align all 4 corners of the input and - output. Defaults to `false`. + align_corners: bool. If True, the centers of the 4 corner pixels of the + input and output tensors are aligned, preserving the values at the + corner pixels. Defaults to `False`. Raises: ValueError: if the shape of `images` is incompatible with the @@ -1025,9 +1122,9 @@ def adjust_contrast(images, contrast_factor): def adjust_gamma(image, gamma=1, gain=1): """Performs Gamma Correction on the input image. - Also known as Power Law Transform. This function transforms the - input image pixelwise according to the equation Out = In**gamma - after scaling each pixel to the range 0 to 1. + Also known as Power Law Transform. This function transforms the + input image pixelwise according to the equation `Out = In**gamma` + after scaling each pixel to the range 0 to 1. Args: image : A Tensor. @@ -1338,6 +1435,26 @@ def adjust_saturation(image, saturation_factor, name=None): orig_dtype) +@tf_export('image.is_jpeg') +def is_jpeg(contents, name=None): + r"""Convenience function to check if the 'contents' encodes a JPEG image. + + Args: + contents: 0-D `string`. The encoded image bytes. + name: A name for the operation (optional) + + Returns: + A scalar boolean tensor indicating if 'contents' may be a JPEG image. + is_jpeg is susceptible to false positives. + """ + # Normal JPEGs start with \xff\xd8\xff\xe0 + # JPEG with EXIF stats with \xff\xd8\xff\xe1 + # Use \xff\xd8\xff to cover both. + with ops.name_scope(name, 'is_jpeg'): + substr = string_ops.substr(contents, 0, 3) + return math_ops.equal(substr, b'\xff\xd8\xff', name=name) + + @tf_export('image.decode_image') def decode_image(contents, channels=None, name=None): """Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`, @@ -1426,8 +1543,8 @@ def decode_image(contents, channels=None, name=None): # Decode normal JPEG images (start with \xff\xd8\xff\xe0) # as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1). - is_jpeg = math_ops.equal(substr, b'\xff\xd8\xff', name='is_jpeg') - return control_flow_ops.cond(is_jpeg, _jpeg, check_png, name='cond_jpeg') + return control_flow_ops.cond( + is_jpeg(contents), _jpeg, check_png, name='cond_jpeg') @tf_export('image.total_variation') @@ -1670,8 +1787,8 @@ def non_max_suppression(boxes, # pylint: enable=protected-access -_rgb_to_yiq_kernel = [[0.299, 0.59590059, 0.2115], - [0.587, -0.27455667, -0.52273617], +_rgb_to_yiq_kernel = [[0.299, 0.59590059, + 0.2115], [0.587, -0.27455667, -0.52273617], [0.114, -0.32134392, 0.31119955]] @@ -1690,13 +1807,13 @@ def rgb_to_yiq(images): images: tensor with the same shape as `images`. """ images = ops.convert_to_tensor(images, name='images') - kernel = ops.convert_to_tensor(_rgb_to_yiq_kernel, dtype=images.dtype, name='kernel') + kernel = ops.convert_to_tensor( + _rgb_to_yiq_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims - return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) -_yiq_to_rgb_kernel = [[1, 1, 1], - [0.95598634, -0.27201283, -1.10674021], +_yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021], [0.6208248, -0.64720424, 1.70423049]] @@ -1716,13 +1833,14 @@ def yiq_to_rgb(images): images: tensor with the same shape as `images`. """ images = ops.convert_to_tensor(images, name='images') - kernel = ops.convert_to_tensor(_yiq_to_rgb_kernel, dtype=images.dtype, name='kernel') + kernel = ops.convert_to_tensor( + _yiq_to_rgb_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims - return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) -_rgb_to_yuv_kernel = [[0.299, -0.14714119, 0.61497538], - [0.587, -0.28886916, -0.51496512], +_rgb_to_yuv_kernel = [[0.299, -0.14714119, + 0.61497538], [0.587, -0.28886916, -0.51496512], [0.114, 0.43601035, -0.10001026]] @@ -1741,13 +1859,13 @@ def rgb_to_yuv(images): images: tensor with the same shape as `images`. """ images = ops.convert_to_tensor(images, name='images') - kernel = ops.convert_to_tensor(_rgb_to_yuv_kernel, dtype=images.dtype, name='kernel') + kernel = ops.convert_to_tensor( + _rgb_to_yuv_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims - return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) -_yuv_to_rgb_kernel = [[1, 1, 1], - [0, -0.394642334, 2.03206185], +_yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185], [1.13988303, -0.58062185, 0]] @@ -1767,7 +1885,7 @@ def yuv_to_rgb(images): images: tensor with the same shape as `images`. """ images = ops.convert_to_tensor(images, name='images') - kernel = ops.convert_to_tensor(_yuv_to_rgb_kernel, dtype=images.dtype, name='kernel') + kernel = ops.convert_to_tensor( + _yuv_to_rgb_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims - return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) - + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 9834384634261e5d99cac6a4d09b0417b9b2f883..dc3b581b223edca5f119d726b9fb73e760807818 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -252,11 +252,11 @@ class AdjustGamma(test_util.TensorFlowTestCase): with self.test_session(): x_data = np.random.uniform(0, 255, (8, 8)) x_np = np.array(x_data, dtype=np.float32) - + x = constant_op.constant(x_np, shape=x_np.shape) - err_msg = 'Gamma should be a non-negative real number.' - + err_msg = "Gamma should be a non-negative real number." + try: image_ops.adjust_gamma(x, gamma=-1) except Exception as e: @@ -270,13 +270,13 @@ class AdjustGamma(test_util.TensorFlowTestCase): with self.test_session(): x_data = np.random.uniform(0, 255, (8, 8)) x_np = np.array(x_data, dtype=np.float32) - + x = constant_op.constant(x_np, shape=x_np.shape) y = constant_op.constant(-1.0, dtype=dtypes.float32) - + image = image_ops.adjust_gamma(x, gamma=y) - - err_msg = 'Gamma should be a non-negative real number.' + + err_msg = "Gamma should be a non-negative real number." try: image.eval() except Exception as e: @@ -284,7 +284,7 @@ class AdjustGamma(test_util.TensorFlowTestCase): raise else: raise AssertionError("Exception not raised: %s" % err_msg) - + def test_adjust_gamma_zero(self): """White image should be returned for gamma equal to zero""" with self.test_session(): @@ -311,13 +311,13 @@ class AdjustGamma(test_util.TensorFlowTestCase): y_tf = np.trunc(y.eval()) y_np = np.array( - [[0, 31, 45, 55, 63, 71, 78, 84], - [90, 95, 100, 105, 110, 115, 119, 123], - [127, 131, 135, 139, 142, 146, 149, 153], - [156, 159, 162, 165, 168, 171, 174, 177], - [180, 183, 186, 188, 191, 194, 196, 199], - [201, 204, 206, 209, 211, 214, 216, 218], - [221, 223, 225, 228, 230, 232, 234, 236], + [[0, 31, 45, 55, 63, 71, 78, 84], [ + 90, 95, 100, 105, 110, 115, 119, 123 + ], [127, 131, 135, 139, 142, 146, 149, 153], [ + 156, 159, 162, 165, 168, 171, 174, 177 + ], [180, 183, 186, 188, 191, 194, 196, 199], [ + 201, 204, 206, 209, 211, 214, 216, 218 + ], [221, 223, 225, 228, 230, 232, 234, 236], [238, 241, 243, 245, 247, 249, 251, 253]], dtype=np.float32) @@ -332,14 +332,12 @@ class AdjustGamma(test_util.TensorFlowTestCase): y_tf = np.trunc(y.eval()) y_np = np.array( - [[0, 0, 0, 0, 1, 1, 2, 3], - [4, 5, 6, 7, 9, 10, 12, 14], - [16, 18, 20, 22, 25, 27, 30, 33], - [36, 39, 42, 45, 49, 52, 56, 60], - [64, 68, 72, 76, 81, 85, 90, 95], - [100, 105, 110, 116, 121, 127, 132, 138], - [144, 150, 156, 163, 169, 176, 182, 189], - [196, 203, 211, 218, 225, 233, 241, 249]], + [[0, 0, 0, 0, 1, 1, 2, 3], [4, 5, 6, 7, 9, 10, 12, 14], [ + 16, 18, 20, 22, 25, 27, 30, 33 + ], [36, 39, 42, 45, 49, 52, 56, 60], [64, 68, 72, 76, 81, 85, 90, 95], + [100, 105, 110, 116, 121, 127, 132, 138], [ + 144, 150, 156, 163, 169, 176, 182, 189 + ], [196, 203, 211, 218, 225, 233, 241, 249]], dtype=np.float32) self.assertAllClose(y_tf, y_np, 1e-6) @@ -483,8 +481,7 @@ class FlipImageBenchmark(test.Benchmark): with session.Session("", graph=ops.Graph(), config=config) as sess: with ops.device(device): inputs = variables.Variable( - random_ops.random_uniform( - image_shape, dtype=dtypes.float32) * 255, + random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255, trainable=False, dtype=dtypes.float32) run_op = image_ops.flip_left_right(inputs) @@ -514,8 +511,7 @@ class FlipImageBenchmark(test.Benchmark): with session.Session("", graph=ops.Graph(), config=config) as sess: with ops.device(device): inputs = variables.Variable( - random_ops.random_uniform( - image_shape, dtype=dtypes.float32) * 255, + random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255, trainable=False, dtype=dtypes.float32) run_op = image_ops.random_flip_left_right(inputs) @@ -566,8 +562,7 @@ class AdjustHueBenchmark(test.Benchmark): with session.Session("", graph=ops.Graph(), config=config) as sess: with ops.device(device): inputs = variables.Variable( - random_ops.random_uniform( - image_shape, dtype=dtypes.float32) * 255, + random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255, trainable=False, dtype=dtypes.float32) delta = constant_op.constant(0.1, dtype=dtypes.float32) @@ -611,8 +606,7 @@ class AdjustSaturationBenchmark(test.Benchmark): with session.Session("", graph=ops.Graph(), config=config) as sess: with ops.device(device): inputs = variables.Variable( - random_ops.random_uniform( - image_shape, dtype=dtypes.float32) * 255, + random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255, trainable=False, dtype=dtypes.float32) delta = constant_op.constant(0.1, dtype=dtypes.float32) @@ -667,10 +661,11 @@ class ResizeBilinearBenchmark(test.Benchmark): results = self.run_op_benchmark( sess, benchmark_op, - name=("resize_bilinear_%s_%s_%s" % - (image_size[0], image_size[1], num_channels))) - print("%s : %.2f ms/img" % (results["name"], 1000 * results["wall_time"] - / (batch_size * num_ops))) + name=("resize_bilinear_%s_%s_%s" % (image_size[0], image_size[1], + num_channels))) + print("%s : %.2f ms/img" % + (results["name"], + 1000 * results["wall_time"] / (batch_size * num_ops))) def benchmarkSimilar3Channel(self): self._benchmarkResize((183, 229), 3) @@ -717,8 +712,9 @@ class ResizeBicubicBenchmark(test.Benchmark): min_iters=20, name=("resize_bicubic_%s_%s_%s" % (image_size[0], image_size[1], num_channels))) - print("%s : %.2f ms/img" % (results["name"], 1000 * results["wall_time"] - / (batch_size * num_ops))) + print("%s : %.2f ms/img" % + (results["name"], + 1000 * results["wall_time"] / (batch_size * num_ops))) def benchmarkSimilar3Channel(self): self._benchmarkResize((183, 229), 3) @@ -754,8 +750,8 @@ class ResizeAreaBenchmark(test.Benchmark): batch_size = 1 num_ops = 1000 img = variables.Variable( - random_ops.random_normal([batch_size, image_size[0], - image_size[1], num_channels]), + random_ops.random_normal( + [batch_size, image_size[0], image_size[1], num_channels]), name="img") deps = [] @@ -768,12 +764,13 @@ class ResizeAreaBenchmark(test.Benchmark): with session.Session() as sess: sess.run(variables.global_variables_initializer()) results = self.run_op_benchmark( - sess, benchmark_op, - name=("resize_area_%s_%s_%s" % - (image_size[0], image_size[1], num_channels))) - print("%s : %.2f ms/img" % ( - results["name"], - 1000*results["wall_time"] / (batch_size * num_ops))) + sess, + benchmark_op, + name=("resize_area_%s_%s_%s" % (image_size[0], image_size[1], + num_channels))) + print("%s : %.2f ms/img" % + (results["name"], + 1000 * results["wall_time"] / (batch_size * num_ops))) def benchmarkSimilar3Channel(self): self._benchmarkResize((183, 229), 3) @@ -847,8 +844,7 @@ class AdjustSaturationTest(test_util.TensorFlowTestCase): flt_image = image_ops.convert_image_dtype(image, dtypes.float32) saturation_adjusted_image = gen_image_ops.adjust_saturation( flt_image, saturation_factor) - return image_ops.convert_image_dtype(saturation_adjusted_image, - orig_dtype) + return image_ops.convert_image_dtype(saturation_adjusted_image, orig_dtype) def testHalfSaturationFused(self): x_shape = [2, 2, 3] @@ -938,7 +934,7 @@ class AdjustSaturationTest(test_util.TensorFlowTestCase): class FlipTransposeRotateTest(test_util.TensorFlowTestCase): - def testIdempotentLeftRight(self): + def testInvolutionLeftRight(self): x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1]) with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) @@ -946,6 +942,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_tf = y.eval() self.assertAllEqual(y_tf, x_np) + def testInvolutionLeftRightWithBatch(self): + x_np = np.array([[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.flip_left_right(image_ops.flip_left_right(x_tf)) + y_tf = y.eval() + self.assertAllEqual(y_tf, x_np) + def testLeftRight(self): x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1]) y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1]) @@ -953,22 +958,36 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) y = image_ops.flip_left_right(x_tf) - self.assertTrue(y.op.name.startswith('flip_left_right')) + self.assertTrue(y.op.name.startswith("flip_left_right")) + y_tf = y.eval() + self.assertAllEqual(y_tf, y_np) + + def testLeftRightWithBatch(self): + x_np = np.array([[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + y_np = np.array([[[3, 2, 1], [3, 2, 1]], [[3, 2, 1], [3, 2, 1]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.flip_left_right(x_tf) y_tf = y.eval() self.assertAllEqual(y_tf, y_np) + def testRandomFlipLeftRight(self): x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1]) y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1]) + seed = 42 with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) y = image_ops.random_flip_left_right(x_tf) - self.assertTrue(y.op.name.startswith('random_flip_left_right')) + self.assertTrue(y.op.name.startswith("random_flip_left_right")) count_flipped = 0 count_unflipped = 0 - for _ in range(50): + for _ in range(100): y_tf = y.eval() if y_tf[0][0] == 1: self.assertAllEqual(y_tf, x_np) @@ -976,10 +995,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): else: self.assertAllEqual(y_tf, y_np) count_flipped += 1 - self.assertGreaterEqual(count_flipped, 1) - self.assertGreaterEqual(count_unflipped, 1) - def testIdempotentUpDown(self): + # 100 trials + # Mean: 50 + # Std Dev: ~5 + # Six Sigma: 50 - (5 * 6) = 20 + self.assertGreaterEqual(count_flipped, 20) + self.assertGreaterEqual(count_unflipped, 20) + + def testInvolutionUpDown(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) with self.test_session(use_gpu=True): @@ -988,6 +1012,16 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_tf = y.eval() self.assertAllEqual(y_tf, x_np) + def testInvolutionUpDownWithBatch(self): + x_np = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.flip_up_down(image_ops.flip_up_down(x_tf)) + y_tf = y.eval() + self.assertAllEqual(y_tf, x_np) + def testUpDown(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1]) @@ -995,7 +1029,19 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) y = image_ops.flip_up_down(x_tf) - self.assertTrue(y.op.name.startswith('flip_up_down')) + self.assertTrue(y.op.name.startswith("flip_up_down")) + y_tf = y.eval() + self.assertAllEqual(y_tf, y_np) + + def testUpDownWithBatch(self): + x_np = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + y_np = np.array([[[4, 5, 6], [1, 2, 3]], [[10, 11, 12], [7, 8, 9]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.flip_up_down(x_tf) y_tf = y.eval() self.assertAllEqual(y_tf, y_np) @@ -1005,11 +1051,11 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) - y = image_ops.random_flip_up_down(x_tf) - self.assertTrue(y.op.name.startswith('random_flip_up_down')) + y = image_ops.random_flip_up_down(x_tf, seed=42) + self.assertTrue(y.op.name.startswith("random_flip_up_down")) count_flipped = 0 count_unflipped = 0 - for _ in range(50): + for _ in range(100): y_tf = y.eval() if y_tf[0][0] == 1: self.assertAllEqual(y_tf, x_np) @@ -1017,10 +1063,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): else: self.assertAllEqual(y_tf, y_np) count_flipped += 1 - self.assertGreaterEqual(count_flipped, 1) - self.assertGreaterEqual(count_unflipped, 1) - def testIdempotentTranspose(self): + # 100 trials + # Mean: 50 + # Std Dev: ~5 + # Six Sigma: 50 - (5 * 6) = 20 + self.assertGreaterEqual(count_flipped, 20) + self.assertGreaterEqual(count_unflipped, 20) + + def testInvolutionTranspose(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) with self.test_session(use_gpu=True): @@ -1029,6 +1080,16 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_tf = y.eval() self.assertAllEqual(y_tf, x_np) + def testInvolutionTransposeWithBatch(self): + x_np = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.transpose_image(image_ops.transpose_image(x_tf)) + y_tf = y.eval() + self.assertAllEqual(y_tf, x_np) + def testTranspose(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) y_np = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.uint8).reshape([3, 2, 1]) @@ -1036,19 +1097,36 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) y = image_ops.transpose_image(x_tf) - self.assertTrue(y.op.name.startswith('transpose_image')) + self.assertTrue(y.op.name.startswith("transpose_image")) + y_tf = y.eval() + self.assertAllEqual(y_tf, y_np) + + def testTransposeWithBatch(self): + x_np = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + + y_np = np.array([[[1, 4], [2, 5], [3, 6]], [[7, 10], [8, 11], [9, 12]]], + dtype=np.uint8).reshape([2, 3, 2, 1]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.transpose_image(x_tf) y_tf = y.eval() self.assertAllEqual(y_tf, y_np) def testPartialShapes(self): p_unknown_rank = array_ops.placeholder(dtypes.uint8) - p_unknown_dims = array_ops.placeholder( + p_unknown_dims_3 = array_ops.placeholder( dtypes.uint8, shape=[None, None, None]) + p_unknown_dims_4 = array_ops.placeholder( + dtypes.uint8, shape=[None, None, None, None]) p_unknown_width = array_ops.placeholder(dtypes.uint8, shape=[64, None, 3]) - + p_unknown_batch = array_ops.placeholder(dtypes.uint8, + shape=[None, 64, 64, 3]) p_wrong_rank = array_ops.placeholder(dtypes.uint8, shape=[None, None]) p_zero_dim = array_ops.placeholder(dtypes.uint8, shape=[64, 0, 3]) + #Ops that support 3D input for op in [ image_ops.flip_left_right, image_ops.flip_up_down, image_ops.random_flip_left_right, image_ops.random_flip_up_down, @@ -1056,16 +1134,34 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): ]: transformed_unknown_rank = op(p_unknown_rank) self.assertEqual(3, transformed_unknown_rank.get_shape().ndims) - transformed_unknown_dims = op(p_unknown_dims) - self.assertEqual(3, transformed_unknown_dims.get_shape().ndims) + transformed_unknown_dims_3 = op(p_unknown_dims_3) + self.assertEqual(3, transformed_unknown_dims_3.get_shape().ndims) transformed_unknown_width = op(p_unknown_width) self.assertEqual(3, transformed_unknown_width.get_shape().ndims) - with self.assertRaisesRegexp(ValueError, "must be three-dimensional"): - op(p_wrong_rank) with self.assertRaisesRegexp(ValueError, "must be > 0"): op(p_zero_dim) + #Ops that support 4D input + for op in [ + image_ops.flip_left_right, image_ops.flip_up_down, + image_ops.transpose_image, image_ops.rot90 + ]: + transformed_unknown_dims_4 = op(p_unknown_dims_4) + self.assertEqual(4, transformed_unknown_dims_4.get_shape().ndims) + transformed_unknown_batch = op(p_unknown_batch) + self.assertEqual(4, transformed_unknown_batch.get_shape().ndims) + with self.assertRaisesRegexp(ValueError, + "must be at least three-dimensional"): + op(p_wrong_rank) + + for op in [ + image_ops.random_flip_left_right, image_ops.random_flip_up_down, + ]: + with self.assertRaisesRegexp(ValueError, "must be three-dimensional"): + op(p_wrong_rank) + + def testRot90GroupOrder(self): image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3]) with self.test_session(use_gpu=True): @@ -1074,6 +1170,14 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): rotated = image_ops.rot90(rotated) self.assertAllEqual(image, rotated.eval()) + def testRot90GroupOrderWithBatch(self): + image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3]) + with self.test_session(use_gpu=True): + rotated = image + for _ in xrange(4): + rotated = image_ops.rot90(rotated) + self.assertAllEqual(image, rotated.eval()) + def testRot90NumpyEquivalence(self): image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3]) with self.test_session(use_gpu=True): @@ -1083,6 +1187,14 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_np = np.rot90(image, k=k) self.assertAllEqual(y_np, y_tf.eval({k_placeholder: k})) + def testRot90NumpyEquivalenceWithBatch(self): + image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3]) + with self.test_session(use_gpu=True): + k_placeholder = array_ops.placeholder(dtypes.int32, shape=[]) + y_tf = image_ops.rot90(image, k_placeholder) + for k in xrange(4): + y_np = np.rot90(image, k=k, axes=(1, 2)) + self.assertAllEqual(y_np, y_tf.eval({k_placeholder: k})) class RandomFlipTest(test_util.TensorFlowTestCase): @@ -1261,7 +1373,7 @@ class PerImageWhiteningTest(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True): x = constant_op.constant(x_np, shape=x_shape) y = image_ops.per_image_standardization(x) - self.assertTrue(y.op.name.startswith('per_image_standardization')) + self.assertTrue(y.op.name.startswith("per_image_standardization")) y_tf = y.eval() self.assertAllClose(y_tf, y_np, atol=1e-4) @@ -1433,9 +1545,10 @@ class CropToBoundingBoxTest(test_util.TensorFlowTestCase): # Each line is a test configuration: # (offset_height, offset_width, target_height, target_width), err_msg - test_config = (([-1, 0, 3, 3], "offset_height must be >= 0"), - ([0, -1, 3, 3], "offset_width must be >= 0"), - ([0, 0, 0, 3], "target_height must be > 0"), + test_config = (([-1, 0, 3, 3], "offset_height must be >= 0"), ([ + 0, -1, 3, 3 + ], "offset_width must be >= 0"), ([0, 0, 0, 3], + "target_height must be > 0"), ([0, 0, 3, 0], "target_width must be > 0"), ([2, 0, 3, 3], "height must be >= target + offset"), ([0, 2, 3, 3], "width must be >= target + offset")) @@ -1446,7 +1559,7 @@ class CropToBoundingBoxTest(test_util.TensorFlowTestCase): def testNameScope(self): image = array_ops.placeholder(dtypes.float32, shape=[55, 66, 3]) y = image_ops.crop_to_bounding_box(image, 0, 0, 55, 66) - self.assertTrue(y.name.startswith('crop_to_bounding_box')) + self.assertTrue(y.name.startswith("crop_to_bounding_box")) class CentralCropTest(test_util.TensorFlowTestCase): @@ -1471,9 +1584,10 @@ class CentralCropTest(test_util.TensorFlowTestCase): def testCropping(self): x_shape = [4, 8, 1] - x_np = np.array([[1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8], - [1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8]], - dtype=np.int32).reshape(x_shape) + x_np = np.array( + [[1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8], + [1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8]], + dtype=np.int32).reshape(x_shape) y_np = np.array([[3, 4, 5, 6], [3, 4, 5, 6]]).reshape([2, 4, 1]) with self.test_session(use_gpu=True): x = constant_op.constant(x_np, shape=x_shape) @@ -1490,7 +1604,7 @@ class CentralCropTest(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True): x = array_ops.placeholder(shape=x_shape, dtype=dtypes.int32) y = image_ops.central_crop(x, 0.33) - y_tf = y.eval(feed_dict={x:x_np}) + y_tf = y.eval(feed_dict={x: x_np}) self.assertAllEqual(y_tf, y_np) self.assertAllEqual(y_tf.shape, y_np.shape) @@ -1529,7 +1643,7 @@ class CentralCropTest(test_util.TensorFlowTestCase): x_np = np.ones(x_shape, dtype=np.float32) with self.test_session(use_gpu=True): y = image_ops.central_crop(x_np, 1.0) - self.assertTrue(y.op.name.startswith('central_crop')) + self.assertTrue(y.op.name.startswith("central_crop")) class PadToBoundingBoxTest(test_util.TensorFlowTestCase): @@ -1602,15 +1716,10 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase): self.assertEqual(y.get_shape().as_list(), post_shape) def testInt64(self): - x = [1, 2, 3, - 4, 5, 6, - 7, 8, 9] + x = [1, 2, 3, 4, 5, 6, 7, 8, 9] x_shape = [3, 3, 1] - y = [0, 0, 0, - 1, 2, 3, - 4, 5, 6, - 7, 8, 9] + y = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] y_shape = [4, 3, 1] x = np.array(x).reshape(x_shape) y = np.array(y).reshape(y_shape) @@ -1627,38 +1736,26 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase): self._assertReturns(x, x_shape, offset_height, offset_width, x, x_shape) def testPadding(self): - x = [1, 2, 3, - 4, 5, 6, - 7, 8, 9] + x = [1, 2, 3, 4, 5, 6, 7, 8, 9] x_shape = [3, 3, 1] offset_height, offset_width = [1, 0] - y = [0, 0, 0, - 1, 2, 3, - 4, 5, 6, - 7, 8, 9] + y = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] y_shape = [4, 3, 1] self._assertReturns(x, x_shape, offset_height, offset_width, y, y_shape) offset_height, offset_width = [0, 1] - y = [0, 1, 2, 3, - 0, 4, 5, 6, - 0, 7, 8, 9] + y = [0, 1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 9] y_shape = [3, 4, 1] self._assertReturns(x, x_shape, offset_height, offset_width, y, y_shape) offset_height, offset_width = [0, 0] - y = [1, 2, 3, - 4, 5, 6, - 7, 8, 9, - 0, 0, 0] + y = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0] y_shape = [4, 3, 1] self._assertReturns(x, x_shape, offset_height, offset_width, y, y_shape) offset_height, offset_width = [0, 0] - y = [1, 2, 3, 0, - 4, 5, 6, 0, - 7, 8, 9, 0] + y = [1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 9, 0] y_shape = [3, 4, 1] self._assertReturns(x, x_shape, offset_height, offset_width, y, y_shape) @@ -1690,9 +1787,7 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase): # Input image has 0-length dimension(s). # Each line is a test configuration: # x_shape, target_height, target_width - test_config = (([0, 2, 2], 2, 2), - ([2, 0, 2], 2, 2), - ([2, 2, 0], 2, 2)) + test_config = (([0, 2, 2], 2, 2), ([2, 0, 2], 2, 2), ([2, 2, 0], 2, 2)) offset_height, offset_width = [0, 0] x = [] @@ -1737,7 +1832,7 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase): def testNameScope(self): image = array_ops.placeholder(dtypes.float32, shape=[55, 66, 3]) y = image_ops.pad_to_bounding_box(image, 0, 0, 55, 66) - self.assertTrue(y.op.name.startswith('pad_to_bounding_box')) + self.assertTrue(y.op.name.startswith("pad_to_bounding_box")) class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): @@ -1750,8 +1845,8 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): (bounding_box[2] - bounding_box[0])) image_size_np = np.array(image.shape, dtype=np.int32) - bounding_box_np = (np.array( - bounding_box, dtype=np.float32).reshape([1, 1, 4])) + bounding_box_np = ( + np.array(bounding_box, dtype=np.float32).reshape([1, 1, 4])) aspect_ratios = [] area_ratios = [] @@ -1796,7 +1891,9 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): y = array_ops.strided_slice(image_tf, begin, begin + size) for _ in xrange(num_iter): - y_tf = y.eval(feed_dict={min_object_covered_placeholder: min_object_covered}) + y_tf = y.eval(feed_dict={ + min_object_covered_placeholder: min_object_covered + }) crop_height = y_tf.shape[0] crop_width = y_tf.shape[1] aspect_ratio = float(crop_width) / float(crop_height) @@ -1888,9 +1985,10 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): image_size = constant_op.constant( [40, 50, 1], shape=[3], dtype=dtypes.int32) bounding_box = constant_op.constant( - [0.0, 0.0, 1.0, 1.0], - shape=[4], - dtype=dtypes.float32,) + [[[0.0, 0.0, 1.0, 1.0]]], + shape=[1, 1, 4], + dtype=dtypes.float32, + ) begin, end, bbox_for_drawing = image_ops.sample_distorted_bounding_box( image_size=image_size, bounding_boxes=bounding_box, @@ -1902,6 +2000,10 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): self.assertAllEqual([3], begin.get_shape().as_list()) self.assertAllEqual([3], end.get_shape().as_list()) self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list()) + # Actual run to make sure shape is correct inside Compute(). + begin = begin.eval() + end = end.eval() + bbox_for_drawing = bbox_for_drawing.eval() begin, end, bbox_for_drawing = image_ops.sample_distorted_bounding_box( image_size=image_size, @@ -1921,9 +2023,10 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): image_size = constant_op.constant( [40, 50, 1], shape=[3], dtype=dtypes.int32) bounding_box = constant_op.constant( - [0.0, 0.0, 1.0, 1.0], - shape=[4], - dtype=dtypes.float32,) + [[[0.0, 0.0, 1.0, 1.0]]], + shape=[1, 1, 4], + dtype=dtypes.float32, + ) begin, end, bbox_for_drawing = image_ops.sample_distorted_bounding_box( image_size=image_size, bounding_boxes=bounding_box, @@ -1933,17 +2036,23 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): self.assertAllEqual([3], begin.get_shape().as_list()) self.assertAllEqual([3], end.get_shape().as_list()) self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list()) + # Actual run to make sure shape is correct inside Compute(). + begin = begin.eval() + end = end.eval() + bbox_for_drawing = bbox_for_drawing.eval() class ResizeImagesTest(test_util.TensorFlowTestCase): - OPTIONS = [image_ops.ResizeMethod.BILINEAR, - image_ops.ResizeMethod.NEAREST_NEIGHBOR, - image_ops.ResizeMethod.BICUBIC, - image_ops.ResizeMethod.AREA] + OPTIONS = [ + image_ops.ResizeMethod.BILINEAR, image_ops.ResizeMethod.NEAREST_NEIGHBOR, + image_ops.ResizeMethod.BICUBIC, image_ops.ResizeMethod.AREA + ] - TYPES = [np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, - np.float16, np.float32, np.float64] + TYPES = [ + np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, np.float16, + np.float32, np.float64 + ] def _assertShapeInference(self, pre_shape, size, post_shape): # Try single image resize @@ -1971,12 +2080,10 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): single_shape = [6, 4, 1] # This test is also conducted with int8, so 127 is the maximum # value that can be used. - data = [127, 127, 64, 64, - 127, 127, 64, 64, - 64, 64, 127, 127, - 64, 64, 127, 127, - 50, 50, 100, 100, - 50, 50, 100, 100] + data = [ + 127, 127, 64, 64, 127, 127, 64, 64, 64, 64, 127, 127, 64, 64, 127, 127, + 50, 50, 100, 100, 50, 50, 100, 100 + ] target_height = 6 target_width = 4 @@ -2007,12 +2114,10 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): single_shape = [6, 4, 1] # This test is also conducted with int8, so 127 is the maximum # value that can be used. - data = [127, 127, 64, 64, - 127, 127, 64, 64, - 64, 64, 127, 127, - 64, 64, 127, 127, - 50, 50, 100, 100, - 50, 50, 100, 100] + data = [ + 127, 127, 64, 64, 127, 127, 64, 64, 64, 64, 127, 127, 64, 64, 127, 127, + 50, 50, 100, 100, 50, 50, 100, 100 + ] new_size = array_ops.placeholder(dtypes.int32, shape=(2)) img_np = np.array(data, dtype=np.uint8).reshape(img_shape) @@ -2066,8 +2171,10 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): image_ops.ResizeMethod.BILINEAR) def testReturnDtype(self): - target_shapes = [[6, 4], [3, 2], [array_ops.placeholder(dtypes.int32), - array_ops.placeholder(dtypes.int32)]] + target_shapes = [[6, 4], [3, 2], [ + array_ops.placeholder(dtypes.int32), + array_ops.placeholder(dtypes.int32) + ]] for nptype in self.TYPES: image = array_ops.placeholder(nptype, shape=[1, 6, 4, 1]) for opt in self.OPTIONS: @@ -2084,12 +2191,10 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): img_shape = [1, 6, 4, 1] # This test is also conducted with int8, so 127 is the maximum # value that can be used. - data = [127, 127, 64, 64, - 127, 127, 64, 64, - 64, 64, 127, 127, - 64, 64, 127, 127, - 50, 50, 100, 100, - 50, 50, 100, 100] + data = [ + 127, 127, 64, 64, 127, 127, 64, 64, 64, 64, 127, 127, 64, 64, 127, 127, + 50, 50, 100, 100, 50, 50, 100, 100 + ] # Test size where width is specified as a tensor which is a sum # of two tensors. width_1 = constant_op.constant(1) @@ -2111,15 +2216,11 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): def testResizeDown(self): # This test is also conducted with int8, so 127 is the maximum # value that can be used. - data = [127, 127, 64, 64, - 127, 127, 64, 64, - 64, 64, 127, 127, - 64, 64, 127, 127, - 50, 50, 100, 100, - 50, 50, 100, 100] - expected_data = [127, 64, - 64, 127, - 50, 100] + data = [ + 127, 127, 64, 64, 127, 127, 64, 64, 64, 64, 127, 127, 64, 64, 127, 127, + 50, 50, 100, 100, 50, 50, 100, 100 + ] + expected_data = [127, 64, 64, 127, 50, 100] target_height = 3 target_width = 2 @@ -2145,39 +2246,31 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): def testResizeUpAlignCornersFalse(self): img_shape = [1, 3, 2, 1] - data = [64, 32, - 32, 64, - 50, 100] + data = [64, 32, 32, 64, 50, 100] target_height = 6 target_width = 4 expected_data = {} expected_data[image_ops.ResizeMethod.BILINEAR] = [ - 64.0, 48.0, 32.0, 32.0, - 48.0, 48.0, 48.0, 48.0, - 32.0, 48.0, 64.0, 64.0, - 41.0, 61.5, 82.0, 82.0, - 50.0, 75.0, 100.0, 100.0, - 50.0, 75.0, 100.0, 100.0] + 64.0, 48.0, 32.0, 32.0, 48.0, 48.0, 48.0, 48.0, 32.0, 48.0, 64.0, 64.0, + 41.0, 61.5, 82.0, 82.0, 50.0, 75.0, 100.0, 100.0, 50.0, 75.0, 100.0, + 100.0 + ] expected_data[image_ops.ResizeMethod.NEAREST_NEIGHBOR] = [ - 64.0, 64.0, 32.0, 32.0, - 64.0, 64.0, 32.0, 32.0, - 32.0, 32.0, 64.0, 64.0, - 32.0, 32.0, 64.0, 64.0, - 50.0, 50.0, 100.0, 100.0, - 50.0, 50.0, 100.0, 100.0] + 64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 32.0, 32.0, 64.0, 64.0, + 32.0, 32.0, 64.0, 64.0, 50.0, 50.0, 100.0, 100.0, 50.0, 50.0, 100.0, + 100.0 + ] expected_data[image_ops.ResizeMethod.AREA] = [ - 64.0, 64.0, 32.0, 32.0, - 64.0, 64.0, 32.0, 32.0, - 32.0, 32.0, 64.0, 64.0, - 32.0, 32.0, 64.0, 64.0, - 50.0, 50.0, 100.0, 100.0, - 50.0, 50.0, 100.0, 100.0] + 64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 32.0, 32.0, 64.0, 64.0, + 32.0, 32.0, 64.0, 64.0, 50.0, 50.0, 100.0, 100.0, 50.0, 50.0, 100.0, + 100.0 + ] for nptype in self.TYPES: for opt in [ image_ops.ResizeMethod.BILINEAR, - image_ops.ResizeMethod.NEAREST_NEIGHBOR, - image_ops.ResizeMethod.AREA]: + image_ops.ResizeMethod.NEAREST_NEIGHBOR, image_ops.ResizeMethod.AREA + ]: with self.test_session(use_gpu=True): img_np = np.array(data, dtype=nptype).reshape(img_shape) image = constant_op.constant(img_np, shape=img_shape) @@ -2190,41 +2283,29 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): def testResizeUpAlignCornersTrue(self): img_shape = [1, 3, 2, 1] - data = [6, 3, - 3, 6, - 6, 9] + data = [6, 3, 3, 6, 6, 9] target_height = 5 target_width = 4 expected_data = {} expected_data[image_ops.ResizeMethod.BILINEAR] = [ - 6.0, 5.0, 4.0, 3.0, - 4.5, 4.5, 4.5, 4.5, - 3.0, 4.0, 5.0, 6.0, - 4.5, 5.5, 6.5, 7.5, - 6.0, 7.0, 8.0, 9.0 + 6.0, 5.0, 4.0, 3.0, 4.5, 4.5, 4.5, 4.5, 3.0, 4.0, 5.0, 6.0, 4.5, 5.5, + 6.5, 7.5, 6.0, 7.0, 8.0, 9.0 ] expected_data[image_ops.ResizeMethod.NEAREST_NEIGHBOR] = [ - 6.0, 6.0, 3.0, 3.0, - 3.0, 3.0, 6.0, 6.0, - 3.0, 3.0, 6.0, 6.0, - 6.0, 6.0, 9.0, 9.0, - 6.0, 6.0, 9.0, 9.0 + 6.0, 6.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, + 9.0, 9.0, 6.0, 6.0, 9.0, 9.0 ] # TODO(b/37749740): Improve alignment of ResizeMethod.AREA when # align_corners=True. expected_data[image_ops.ResizeMethod.AREA] = [ - 6.0, 6.0, 6.0, 3.0, - 6.0, 6.0, 6.0, 3.0, - 3.0, 3.0, 3.0, 6.0, - 3.0, 3.0, 3.0, 6.0, - 6.0, 6.0, 6.0, 9.0 + 6.0, 6.0, 6.0, 3.0, 6.0, 6.0, 6.0, 3.0, 3.0, 3.0, 3.0, 6.0, 3.0, 3.0, + 3.0, 6.0, 6.0, 6.0, 6.0, 9.0 ] for nptype in self.TYPES: for opt in [ image_ops.ResizeMethod.BILINEAR, - image_ops.ResizeMethod.NEAREST_NEIGHBOR, - image_ops.ResizeMethod.AREA + image_ops.ResizeMethod.NEAREST_NEIGHBOR, image_ops.ResizeMethod.AREA ]: with self.test_session(use_gpu=True): img_np = np.array(data, dtype=nptype).reshape(img_shape) @@ -2238,23 +2319,21 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): def testResizeUpBicubic(self): img_shape = [1, 6, 6, 1] - data = [128, 128, 64, 64, 128, 128, 64, 64, - 64, 64, 128, 128, 64, 64, 128, 128, - 50, 50, 100, 100, 50, 50, 100, 100, - 50, 50, 100, 100, 50, 50, 100, 100, - 50, 50, 100, 100] + data = [ + 128, 128, 64, 64, 128, 128, 64, 64, 64, 64, 128, 128, 64, 64, 128, 128, + 50, 50, 100, 100, 50, 50, 100, 100, 50, 50, 100, 100, 50, 50, 100, 100, + 50, 50, 100, 100 + ] img_np = np.array(data, dtype=np.uint8).reshape(img_shape) target_height = 8 target_width = 8 - expected_data = [128, 135, 96, 55, 64, 114, 134, 128, - 78, 81, 68, 52, 57, 118, 144, 136, - 55, 49, 79, 109, 103, 89, 83, 84, - 74, 70, 95, 122, 115, 69, 49, 55, - 100, 105, 75, 43, 50, 89, 105, 100, - 57, 54, 74, 96, 91, 65, 55, 58, - 70, 69, 75, 81, 80, 72, 69, 70, - 105, 112, 75, 36, 45, 92, 111, 105] + expected_data = [ + 128, 135, 96, 55, 64, 114, 134, 128, 78, 81, 68, 52, 57, 118, 144, 136, + 55, 49, 79, 109, 103, 89, 83, 84, 74, 70, 95, 122, 115, 69, 49, 55, 100, + 105, 75, 43, 50, 89, 105, 100, 57, 54, 74, 96, 91, 65, 55, 58, 70, 69, + 75, 81, 80, 72, 69, 70, 105, 112, 75, 36, 45, 92, 111, 105 + ] with self.test_session(use_gpu=True): image = constant_op.constant(img_np, shape=img_shape) @@ -2267,20 +2346,17 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): def testResizeDownArea(self): img_shape = [1, 6, 6, 1] - data = [128, 64, 32, 16, 8, 4, - 4, 8, 16, 32, 64, 128, - 128, 64, 32, 16, 8, 4, - 5, 10, 15, 20, 25, 30, - 30, 25, 20, 15, 10, 5, - 5, 10, 15, 20, 25, 30] + data = [ + 128, 64, 32, 16, 8, 4, 4, 8, 16, 32, 64, 128, 128, 64, 32, 16, 8, 4, 5, + 10, 15, 20, 25, 30, 30, 25, 20, 15, 10, 5, 5, 10, 15, 20, 25, 30 + ] img_np = np.array(data, dtype=np.uint8).reshape(img_shape) target_height = 4 target_width = 4 - expected_data = [73, 33, 23, 39, - 73, 33, 23, 39, - 14, 16, 19, 21, - 14, 16, 19, 21] + expected_data = [ + 73, 33, 23, 39, 73, 33, 23, 39, 14, 16, 19, 21, 14, 16, 19, 21 + ] with self.test_session(use_gpu=True): image = constant_op.constant(img_np, shape=img_shape) @@ -2367,7 +2443,7 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True): single_image = array_ops.placeholder(dtypes.float32, shape=[50, 60, 3]) y = image_ops.resize_images(single_image, [55, 66]) - self.assertTrue(y.op.name.startswith('resize_images')) + self.assertTrue(y.op.name.startswith("resize_images")) class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase): @@ -2440,133 +2516,93 @@ class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase): def testPad(self): # Pad even along col. - x = [1, 2, 3, 4, - 5, 6, 7, 8] + x = [1, 2, 3, 4, 5, 6, 7, 8] x_shape = [2, 4, 1] - y = [0, 1, 2, 3, 4, 0, - 0, 5, 6, 7, 8, 0] + y = [0, 1, 2, 3, 4, 0, 0, 5, 6, 7, 8, 0] y_shape = [2, 6, 1] self._assertReturns(x, x_shape, y, y_shape) # Pad odd along col. - x = [1, 2, 3, 4, - 5, 6, 7, 8] + x = [1, 2, 3, 4, 5, 6, 7, 8] x_shape = [2, 4, 1] - y = [0, 1, 2, 3, 4, 0, 0, - 0, 5, 6, 7, 8, 0, 0] + y = [0, 1, 2, 3, 4, 0, 0, 0, 5, 6, 7, 8, 0, 0] y_shape = [2, 7, 1] self._assertReturns(x, x_shape, y, y_shape) # Pad even along row. - x = [1, 2, 3, 4, - 5, 6, 7, 8] + x = [1, 2, 3, 4, 5, 6, 7, 8] x_shape = [2, 4, 1] - y = [0, 0, 0, 0, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0] + y = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0] y_shape = [4, 4, 1] self._assertReturns(x, x_shape, y, y_shape) # Pad odd along row. - x = [1, 2, 3, 4, - 5, 6, 7, 8] + x = [1, 2, 3, 4, 5, 6, 7, 8] x_shape = [2, 4, 1] - y = [0, 0, 0, 0, - 1, 2, 3, 4, - 5, 6, 7, 8, - 0, 0, 0, 0, - 0, 0, 0, 0] + y = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0] y_shape = [5, 4, 1] self._assertReturns(x, x_shape, y, y_shape) def testCrop(self): # Crop even along col. - x = [1, 2, 3, 4, - 5, 6, 7, 8] + x = [1, 2, 3, 4, 5, 6, 7, 8] x_shape = [2, 4, 1] - y = [2, 3, - 6, 7] + y = [2, 3, 6, 7] y_shape = [2, 2, 1] self._assertReturns(x, x_shape, y, y_shape) # Crop odd along col. - x = [1, 2, 3, 4, 5, 6, - 7, 8, 9, 10, 11, 12] + x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] x_shape = [2, 6, 1] - y = [2, 3, 4, - 8, 9, 10] + y = [2, 3, 4, 8, 9, 10] y_shape = [2, 3, 1] self._assertReturns(x, x_shape, y, y_shape) # Crop even along row. - x = [1, 2, - 3, 4, - 5, 6, - 7, 8] + x = [1, 2, 3, 4, 5, 6, 7, 8] x_shape = [4, 2, 1] - y = [3, 4, - 5, 6] + y = [3, 4, 5, 6] y_shape = [2, 2, 1] self._assertReturns(x, x_shape, y, y_shape) # Crop odd along row. - x = [1, 2, - 3, 4, - 5, 6, - 7, 8, - 9, 10, - 11, 12, - 13, 14, - 15, 16] + x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] x_shape = [8, 2, 1] - y = [3, 4, - 5, 6, - 7, 8, - 9, 10, - 11, 12] + y = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12] y_shape = [5, 2, 1] self._assertReturns(x, x_shape, y, y_shape) def testCropAndPad(self): # Pad along row but crop along col. - x = [1, 2, 3, 4, - 5, 6, 7, 8] + x = [1, 2, 3, 4, 5, 6, 7, 8] x_shape = [2, 4, 1] - y = [0, 0, - 2, 3, - 6, 7, - 0, 0] + y = [0, 0, 2, 3, 6, 7, 0, 0] y_shape = [4, 2, 1] self._assertReturns(x, x_shape, y, y_shape) # Crop along row but pad along col. - x = [1, 2, - 3, 4, - 5, 6, - 7, 8] + x = [1, 2, 3, 4, 5, 6, 7, 8] x_shape = [4, 2, 1] - y = [0, 3, 4, 0, - 0, 5, 6, 0] + y = [0, 3, 4, 0, 0, 5, 6, 0] y_shape = [2, 4, 1] self._assertReturns(x, x_shape, y, y_shape) @@ -2647,7 +2683,7 @@ class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase): def testNameScope(self): image = array_ops.placeholder(dtypes.float32, shape=[50, 60, 3]) y = image_ops.resize_image_with_crop_or_pad(image, 55, 66) - self.assertTrue(y.op.name.startswith('resize_image_with_crop_or_pad')) + self.assertTrue(y.op.name.startswith("resize_image_with_crop_or_pad")) def _SimpleColorRamp(): @@ -2910,20 +2946,9 @@ class PngTest(test_util.TensorFlowTestCase): class GifTest(test_util.TensorFlowTestCase): - def testOptimizedGifErrorString(self): - filename = "tensorflow/core/lib/gif/testdata/optimized.gif" - - with self.test_session(use_gpu=True) as sess: - gif = io_ops.read_file(filename) - image = image_ops.decode_gif(gif) - with self.assertRaisesRegexp( - errors.InvalidArgumentError, "can't process optimized gif"): - gif, image = sess.run([gif, image]) - - def testValid(self): + def _testValid(self, filename): # Read some real GIFs prefix = "tensorflow/core/lib/gif/testdata/" - filename = "scan.gif" WIDTH = 20 HEIGHT = 40 STRIDE = 5 @@ -2950,16 +2975,9 @@ class GifTest(test_util.TensorFlowTestCase): self.assertAllClose(frame, gt) - def testInValid(self): - # Read some real GIFs - prefix = "tensorflow/core/lib/gif/testdata/" - filename = "optimized.gif" - - with self.test_session(use_gpu=True) as sess: - gif0 = io_ops.read_file(prefix + filename) - image0 = image_ops.decode_gif(gif0) - with self.assertRaises(errors.InvalidArgumentError): - gif0, image0 = sess.run([gif0, image0]) + def testValid(self): + self._testValid("scan.gif") + self._testValid("optimized.gif") def testShape(self): with self.test_session(use_gpu=True) as sess: @@ -2979,8 +2997,9 @@ class ConvertImageTest(test_util.TensorFlowTestCase): y = image_ops.convert_image_dtype(image, output_dtype) self.assertTrue(y.dtype == output_dtype) self.assertAllClose(y.eval(), y_np, atol=1e-5) - if output_dtype in [dtypes.float32, dtypes.float64, - dtypes.int32, dtypes.int64]: + if output_dtype in [ + dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64 + ]: y_saturate = image_ops.convert_image_dtype( image, output_dtype, saturate=True) self.assertTrue(y_saturate.dtype == output_dtype) @@ -3000,8 +3019,8 @@ class ConvertImageTest(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True): self._convert([0, 255], dtypes.uint8, dtypes.int16, [0, 255 * 128]) self._convert([0, 32767], dtypes.int16, dtypes.uint8, [0, 255]) - self._convert([0, 2 ** 32], dtypes.int64, dtypes.int32, [0, 1]) - self._convert([0, 1], dtypes.int32, dtypes.int64, [0, 2 ** 32]) + self._convert([0, 2**32], dtypes.int64, dtypes.int32, [0, 1]) + self._convert([0, 1], dtypes.int32, dtypes.int64, [0, 2**32]) def testConvertBetweenFloat(self): # Make sure converting to between float types does nothing interesting @@ -3022,20 +3041,14 @@ class ConvertImageTest(test_util.TensorFlowTestCase): def testConvertBetweenInt16AndInt8(self): with self.test_session(use_gpu=True): # uint8, uint16 - self._convert([0, 255 * 256], dtypes.uint16, dtypes.uint8, - [0, 255]) - self._convert([0, 255], dtypes.uint8, dtypes.uint16, - [0, 255 * 256]) + self._convert([0, 255 * 256], dtypes.uint16, dtypes.uint8, [0, 255]) + self._convert([0, 255], dtypes.uint8, dtypes.uint16, [0, 255 * 256]) # int8, uint16 - self._convert([0, 127 * 2 * 256], dtypes.uint16, dtypes.int8, - [0, 127]) - self._convert([0, 127], dtypes.int8, dtypes.uint16, - [0, 127 * 2 * 256]) + self._convert([0, 127 * 2 * 256], dtypes.uint16, dtypes.int8, [0, 127]) + self._convert([0, 127], dtypes.int8, dtypes.uint16, [0, 127 * 2 * 256]) # int16, uint16 - self._convert([0, 255 * 256], dtypes.uint16, dtypes.int16, - [0, 255 * 128]) - self._convert([0, 255 * 128], dtypes.int16, dtypes.uint16, - [0, 255 * 256]) + self._convert([0, 255 * 256], dtypes.uint16, dtypes.int16, [0, 255 * 128]) + self._convert([0, 255 * 128], dtypes.int16, dtypes.uint16, [0, 255 * 256]) class TotalVariationTest(test_util.TensorFlowTestCase): @@ -3168,20 +3181,17 @@ class TotalVariationTest(test_util.TensorFlowTestCase): # The following are the sum of absolute differences between the pixels. # sum row dif = (4-1) + (7-2) = 3 + 5 = 8 # sum col dif = (2-1) + (7-4) = 1 + 3 = 4 - r = [[1, 2], - [4, 7]] + r = [[1, 2], [4, 7]] # Blue color channel. # sum row dif = 18 + 29 = 47 # sum col dif = 7 + 18 = 25 - g = [[11, 18], - [29, 47]] + g = [[11, 18], [29, 47]] # Green color channel. # sum row dif = 120 + 193 = 313 # sum col dif = 47 + 120 = 167 - b = [[73, 120], - [193, 313]] + b = [[73, 120], [193, 313]] # Combine the 3 color channels into a single 3-dim array. # The shape is (2, 2, 3) corresponding to (height, width and color). @@ -3210,9 +3220,7 @@ class TotalVariationTest(test_util.TensorFlowTestCase): # Combine these 3 images into a single array of shape (3, 2, 2, 3) # where the first dimension is for the image-number. - multi = np.vstack((a[np.newaxis, :], - b[np.newaxis, :], - c[np.newaxis, :])) + multi = np.vstack((a[np.newaxis, :], b[np.newaxis, :], c[np.newaxis, :])) # Check that TensorFlow correctly calculates the total variation # for each image individually and returns the correct array. @@ -3268,6 +3276,49 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase): boxes, scores, max_output_size, iou_threshold).eval() self.assertAllClose(selected_indices, [3, 0, 5]) + def testInvalidShape(self): + # The boxes should be 2D of shape [num_boxes, 4]. + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 2 but is rank 1"): + boxes = constant_op.constant([0.0, 0.0, 1.0, 1.0]) + scores = constant_op.constant([0.9]) + image_ops.non_max_suppression(boxes, scores, 3, 0.5) + + with self.assertRaisesRegexp(ValueError, "Dimension must be 4 but is 3"): + boxes = constant_op.constant([[0.0, 0.0, 1.0]]) + scores = constant_op.constant([0.9]) + image_ops.non_max_suppression(boxes, scores, 3, 0.5) + + # The boxes is of shape [num_boxes, 4], and the scores is + # of shape [num_boxes]. So an error will thrown. + with self.assertRaisesRegexp( + ValueError, 'Dimensions must be equal, but are 1 and 2'): + boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) + scores = constant_op.constant([0.9, 0.75]) + selected_indices = image_ops.non_max_suppression( + boxes, scores, 3, 0.5) + + # The scores should be 1D of shape [num_boxes]. + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 1 but is rank 2"): + boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) + scores = constant_op.constant([[0.9]]) + image_ops.non_max_suppression(boxes, scores, 3, 0.5) + + # The max_output_size should be a scaler (0-D). + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 0 but is rank 1"): + boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) + scores = constant_op.constant([0.9]) + image_ops.non_max_suppression(boxes, scores, [3], 0.5) + + # The iou_threshold should be a scaler (0-D). + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 0 but is rank 2"): + boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) + scores = constant_op.constant([0.9]) + image_ops.non_max_suppression(boxes, scores, 3, [[0.5]]) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/linalg/linalg.py b/tensorflow/python/ops/linalg/linalg.py index 5369007a56c89ef8601f8144c2fe18717e2e78fe..14319025ff275944cf34e30128df96254d06072b 100644 --- a/tensorflow/python/ops/linalg/linalg.py +++ b/tensorflow/python/ops/linalg/linalg.py @@ -41,4 +41,5 @@ del gen_linalg_ops del linalg_ops del math_ops del special_math_ops +del tf_export # pylint: enable=undefined-variable diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index db33a08137e1d2508314c2d28bdbbb001198e6c1..d5bd916f80d8a03e5423c43d1ca039bc4dceff5e 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import special_math_ops +from tensorflow.python.util.tf_export import tf_export # Linear algebra ops. band_part = array_ops.matrix_band_part @@ -54,6 +55,7 @@ transpose = array_ops.matrix_transpose triangular_solve = linalg_ops.matrix_triangular_solve +@tf_export('linalg.logdet') def logdet(matrix, name=None): """Computes log of the determinant of a hermitian positive definite matrix. @@ -65,8 +67,8 @@ def logdet(matrix, name=None): ``` Args: - matrix: A `Tensor`. Must be `float32`, `float64`, `complex64`, or - `complex128` with shape `[..., M, M]`. + matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, + or `complex128` with shape `[..., M, M]`. name: A name to give this `Op`. Defaults to `logdet`. Returns: @@ -86,6 +88,7 @@ def logdet(matrix, name=None): reduction_indices=[-1]) +@tf_export('linalg.adjoint') def adjoint(matrix, name=None): """Transposes the last two dimensions of and conjugates tensor `matrix`. @@ -99,8 +102,8 @@ def adjoint(matrix, name=None): # [3 - 3j, 6 - 6j]] Args: - matrix: A `Tensor`. Must be `float32`, `float64`, `complex64`, or - `complex128` with shape `[..., M, M]`. + matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, + or `complex128` with shape `[..., M, M]`. name: A name to give this `Op` (optional). Returns: diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 27e0f17020afa0fd44ec11c49b7a77d4426933dd..957a7959181efe3bbc319e62582053329b763dc3 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -32,11 +32,13 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linalg_impl as linalg from tensorflow.python.ops.linalg import linear_operator_util from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export __all__ = ["LinearOperator"] # TODO(langmore) Use matrix_solve_ls for singular or non-square matrices. +@tf_export("linalg.LinearOperator") class LinearOperator(object): """Base class defining a [batch of] linear operator[s]. @@ -478,7 +480,6 @@ class LinearOperator(object): cond, self._max_condition_number_to_be_non_singular(), message="Singular matrix up to precision epsilon.") - raise NotImplementedError("assert_non_singular is not implemented.") def _max_condition_number_to_be_non_singular(self): """Return the maximum condition number that we consider nonsingular.""" diff --git a/tensorflow/python/ops/linalg/linear_operator_composition.py b/tensorflow/python/ops/linalg/linear_operator_composition.py index 14411291d4fddeb2242e243d9a611e9c2fcd171a..ecd30e4d7e4dd7cfd4b109ad6e60aacb172700f6 100644 --- a/tensorflow/python/ops/linalg/linear_operator_composition.py +++ b/tensorflow/python/ops/linalg/linear_operator_composition.py @@ -25,10 +25,12 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.util.tf_export import tf_export __all__ = ["LinearOperatorComposition"] +@tf_export("linalg.LinearOperatorComposition") class LinearOperatorComposition(linear_operator.LinearOperator): """Composes one or more `LinearOperators`. diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py index a4724d030f388230cf85cc68bf60b6553b409c17..b3ec3d5b7cf45ac0b2672eea9a4586b2c3295897 100644 --- a/tensorflow/python/ops/linalg/linear_operator_diag.py +++ b/tensorflow/python/ops/linalg/linear_operator_diag.py @@ -26,10 +26,12 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linalg_impl as linalg from tensorflow.python.ops.linalg import linear_operator from tensorflow.python.ops.linalg import linear_operator_util +from tensorflow.python.util.tf_export import tf_export __all__ = ["LinearOperatorDiag",] +@tf_export("linalg.LinearOperatorDiag") class LinearOperatorDiag(linear_operator.LinearOperator): """`LinearOperator` acting like a [batch] square diagonal matrix. @@ -121,8 +123,8 @@ class LinearOperatorDiag(linear_operator.LinearOperator): Args: diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`. - The diagonal of the operator. Allowed dtypes: `float32`, `float64`, - `complex64`, `complex128`. + The diagonal of the operator. Allowed dtypes: `float16`, `float32`, + `float64`, `complex64`, `complex128`. is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian transpose. If `diag.dtype` is real, this is auto-set to `True`. @@ -167,7 +169,12 @@ class LinearOperatorDiag(linear_operator.LinearOperator): def _check_diag(self, diag): """Static check of diag.""" allowed_dtypes = [ - dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.complex64, + dtypes.complex128, + ] dtype = diag.dtype if dtype not in allowed_dtypes: diff --git a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py index dd4c7cb0413013f3f54f6085a7adcb523755a603..f979fb37d6c69a2683af08a1f6722b98da0b6650 100644 --- a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py +++ b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py @@ -23,10 +23,12 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.util.tf_export import tf_export __all__ = ["LinearOperatorFullMatrix"] +@tf_export("linalg.LinearOperatorFullMatrix") class LinearOperatorFullMatrix(linear_operator.LinearOperator): """`LinearOperator` that wraps a [batch] matrix. @@ -114,7 +116,8 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): Args: matrix: Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`. - Allowed dtypes: `float32`, `float64`, `complex64`, `complex128`. + Allowed dtypes: `float16`, `float32`, `float64`, `complex64`, + `complex128`. is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian transpose. @@ -147,7 +150,12 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): def _check_matrix(self, matrix): """Static check of the `matrix` argument.""" allowed_dtypes = [ - dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.complex64, + dtypes.complex128, + ] matrix = ops.convert_to_tensor(matrix, name="matrix") diff --git a/tensorflow/python/ops/linalg/linear_operator_identity.py b/tensorflow/python/ops/linalg/linear_operator_identity.py index 740c6c811f2d98f62c200cda7242c6ad00de499d..50f3d407e85e4cca22ad6326931b5a2a736819a8 100644 --- a/tensorflow/python/ops/linalg/linear_operator_identity.py +++ b/tensorflow/python/ops/linalg/linear_operator_identity.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linalg_impl as linalg from tensorflow.python.ops.linalg import linear_operator from tensorflow.python.ops.linalg import linear_operator_util +from tensorflow.python.util.tf_export import tf_export __all__ = [ "LinearOperatorIdentity", @@ -97,6 +98,7 @@ class BaseLinearOperatorIdentity(linear_operator.LinearOperator): return array_ops.ones(shape=d_shape, dtype=self.dtype) +@tf_export("linalg.LinearOperatorIdentity") class LinearOperatorIdentity(BaseLinearOperatorIdentity): """`LinearOperator` acting like a [batch] square identity matrix. @@ -460,6 +462,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): "%s" % self._batch_shape_static) +@tf_export("linalg.LinearOperatorScaledIdentity") class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): """`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`. diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py index ad3bb2efa94bfa9751c31ff0c704aad8faa58ba7..be911029095920d424ac90b406e7b85b73884b3b 100644 --- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py +++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py @@ -27,12 +27,14 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linear_operator from tensorflow.python.ops.linalg import linear_operator_diag from tensorflow.python.ops.linalg import linear_operator_identity +from tensorflow.python.util.tf_export import tf_export __all__ = [ "LinearOperatorLowRankUpdate", ] +@tf_export("linalg.LinearOperatorLowRankUpdate") class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): """Perturb a `LinearOperator` with a rank `K` update. @@ -150,8 +152,8 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): `is_X` matrix property hints, which will trigger the appropriate code path. Args: - base_operator: Shape `[B1,...,Bb, M, N]` real `float32` or `float64` - `LinearOperator`. This is `L` above. + base_operator: Shape `[B1,...,Bb, M, N]` real `float16`, `float32` or + `float64` `LinearOperator`. This is `L` above. u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`. This is `U` above. diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype` @@ -188,7 +190,11 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): # because if diag has non-zero imaginary part, it will not be # self-adjoint positive definite. dtype = base_operator.dtype - allowed_dtypes = [dtypes.float32, dtypes.float64] + allowed_dtypes = [ + dtypes.float16, + dtypes.float32, + dtypes.float64, + ] if dtype not in allowed_dtypes: raise TypeError( "Argument matrix must have dtype in %s. Found: %s" diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py index 6ea55f0367bd55379b280f81f22df2c3a0dcfb1e..a5130188b681813e1ccd4818dabdffeeb663e20a 100644 --- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py +++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py @@ -26,12 +26,14 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linalg_impl as linalg from tensorflow.python.ops.linalg import linear_operator from tensorflow.python.ops.linalg import linear_operator_util +from tensorflow.python.util.tf_export import tf_export __all__ = [ "LinearOperatorLowerTriangular", ] +@tf_export("linalg.LinearOperatorLowerTriangular") class LinearOperatorLowerTriangular(linear_operator.LinearOperator): """`LinearOperator` acting like a [batch] square lower triangular matrix. @@ -118,7 +120,8 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): Args: tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`. The lower triangular part of `tril` defines this operator. The strictly - upper triangle is ignored. Allowed dtypes: `float32`, `float64`. + upper triangle is ignored. Allowed dtypes: `float16`, `float32`, + `float64`. is_non_singular: Expect that this operator is non-singular. This operator is non-singular if and only if its diagonal elements are all non-zero. @@ -164,7 +167,11 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): """Static check of the `tril` argument.""" # TODO(langmore) Add complex types once matrix_triangular_solve works for # them. - allowed_dtypes = [dtypes.float32, dtypes.float64] + allowed_dtypes = [ + dtypes.float16, + dtypes.float32, + dtypes.float64, + ] dtype = tril.dtype if dtype not in allowed_dtypes: raise TypeError( diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 72508eb4350f57bb06b3829890f92554677c98d5..aa74e117640d971e38968efadb3c34d5ce3a6f97 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -28,8 +28,10 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops.losses import util from tensorflow.python.util.deprecation import deprecated_args +from tensorflow.python.util.tf_export import tf_export +@tf_export("losses.Reduction") class Reduction(object): """Types of loss reduction. @@ -142,16 +144,17 @@ def _num_present(losses, weights, per_batch=False): if per_batch: return math_ops.reduce_sum( present, axis=math_ops.range(1, array_ops.rank(present)), - keep_dims=True, name=scope) + keepdims=True, name=scope) return math_ops.reduce_sum(present, name=scope) def _num_elements(losses): """Computes the number of elements in `losses` tensor.""" with ops.name_scope(None, "num_elements", values=[losses]) as scope: - return array_ops.size(losses, name=scope, out_type=losses.dtype) + return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype) +@tf_export("losses.compute_weighted_loss") def compute_weighted_loss( losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -211,6 +214,7 @@ def compute_weighted_loss( return loss +@tf_export("losses.absolute_difference") def absolute_difference( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -258,6 +262,7 @@ def absolute_difference( losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.cosine_distance") @deprecated_args(None, "dim is deprecated, use axis instead", "dim") def cosine_distance( labels, predictions, axis=None, weights=1.0, scope=None, @@ -306,11 +311,12 @@ def cosine_distance( predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) - losses = 1 - math_ops.reduce_sum(radial_diffs, axis=(axis,), keep_dims=True) + losses = 1 - math_ops.reduce_sum(radial_diffs, axis=(axis,), keepdims=True) return compute_weighted_loss( losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.hinge_loss") def hinge_loss(labels, logits, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -352,6 +358,7 @@ def hinge_loss(labels, logits, weights=1.0, scope=None, losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.huber_loss") def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -370,7 +377,7 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, `weights` acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If `weights` is a tensor of size - [batch_size], then the total loss for each sample of the batch is rescaled + `[batch_size]`, then the total loss for each sample of the batch is rescaled by the corresponding element in the `weights` vector. If the shape of `weights` matches the shape of `predictions`, then the loss of each measurable element of `predictions` is scaled by the corresponding value of @@ -415,11 +422,12 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, # This is necessary to avoid doubling the gradient, since there is already a # nonzero contribution to the gradient from the quadratic term. linear = (abs_error - quadratic) - losses = 0.5 * quadratic**2 + delta * linear + losses = 0.5 * quadratic * quadratic + delta * linear return compute_weighted_loss( losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.log_loss") def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -427,7 +435,7 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, `weights` acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If `weights` is a tensor of size - [batch_size], then the total loss for each sample of the batch is rescaled + `[batch_size]`, then the total loss for each sample of the batch is rescaled by the corresponding element in the `weights` vector. If the shape of `weights` matches the shape of `predictions`, then the loss of each measurable element of `predictions` is scaled by the corresponding value of @@ -471,6 +479,7 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, # TODO(b/37208492): Add reduction arg. +@tf_export("losses.mean_pairwise_squared_error") def mean_pairwise_squared_error( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES): @@ -493,7 +502,7 @@ def mean_pairwise_squared_error( `weights` acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If `weights` is a tensor of size - [batch_size], then the total loss for each sample of the batch is rescaled + `[batch_size]`, then the total loss for each sample of the batch is rescaled by the corresponding element in the `weights` vector. Args: @@ -534,16 +543,17 @@ def mean_pairwise_squared_error( sum_squares_diff_per_batch = math_ops.reduce_sum( math_ops.square(diffs), reduction_indices=reduction_indices, - keep_dims=True) + keepdims=True) num_present_per_batch = _num_present(diffs, weights, per_batch=True) term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, - num_present_per_batch) + num_present_per_batch - 1) sum_diff = math_ops.reduce_sum( - diffs, reduction_indices=reduction_indices, keep_dims=True) - term2 = 2.0 * _safe_div(math_ops.square(sum_diff), - math_ops.square(num_present_per_batch)) + diffs, reduction_indices=reduction_indices, keepdims=True) + term2 = 2.0 * _safe_div( + math_ops.square(sum_diff), + math_ops.multiply(num_present_per_batch, num_present_per_batch - 1)) weighted_losses = math_ops.multiply(term1 - term2, weights) loss = math_ops.reduce_sum(weighted_losses) @@ -557,6 +567,7 @@ def mean_pairwise_squared_error( return mean_loss +@tf_export("losses.mean_squared_error") def mean_squared_error( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -565,7 +576,7 @@ def mean_squared_error( `weights` acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If `weights` is a tensor of size - [batch_size], then the total loss for each sample of the batch is rescaled + `[batch_size]`, then the total loss for each sample of the batch is rescaled by the corresponding element in the `weights` vector. If the shape of `weights` matches the shape of `predictions`, then the loss of each measurable element of `predictions` is scaled by the corresponding value of @@ -604,6 +615,7 @@ def mean_squared_error( losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.sigmoid_cross_entropy") def sigmoid_cross_entropy( multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -662,6 +674,7 @@ def sigmoid_cross_entropy( losses, weights, scope, loss_collection, reduction=reduction) +@tf_export("losses.softmax_cross_entropy") def softmax_cross_entropy( onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -713,9 +726,11 @@ def softmax_cross_entropy( smooth_negatives = label_smoothing / num_classes onehot_labels = onehot_labels * smooth_positives + smooth_negatives - losses = nn.softmax_cross_entropy_with_logits(labels=onehot_labels, - logits=logits, - name="xentropy") + onehot_labels = array_ops.stop_gradient( + onehot_labels, name="labels_stop_gradient") + losses = nn.softmax_cross_entropy_with_logits_v2( + labels=onehot_labels, logits=logits, name="xentropy") + return compute_weighted_loss( losses, weights, scope, loss_collection, reduction=reduction) @@ -771,6 +786,7 @@ def _remove_squeezable_dimensions( return labels, predictions, weights +@tf_export("losses.sparse_softmax_cross_entropy") def sparse_softmax_cross_entropy( labels, logits, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -779,7 +795,7 @@ def sparse_softmax_cross_entropy( `weights` acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If `weights` is a - tensor of shape [`batch_size`], then the loss weights apply to each + tensor of shape `[batch_size]`, then the loss weights apply to each corresponding sample. Args: diff --git a/tensorflow/python/ops/losses/util.py b/tensorflow/python/ops/losses/util.py index 3718c481c26afdd9f007ffc22a9e6ec44a1eb10e..b835d963869704f053de6c2f8a75ae1fa72e6a5d 100644 --- a/tensorflow/python/ops/losses/util.py +++ b/tensorflow/python/ops/losses/util.py @@ -30,8 +30,10 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("losses.add_loss") def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES): """Adds a externally defined loss to the collection of losses. @@ -43,6 +45,7 @@ def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES): ops.add_to_collection(loss_collection, loss) +@tf_export("losses.get_losses") def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES): """Gets the list of losses from the loss_collection. @@ -56,6 +59,7 @@ def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES): return ops.get_collection(loss_collection, scope) +@tf_export("losses.get_regularization_losses") def get_regularization_losses(scope=None): """Gets the list of regularization losses. @@ -68,6 +72,7 @@ def get_regularization_losses(scope=None): return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope) +@tf_export("losses.get_regularization_loss") def get_regularization_loss(scope=None, name="total_regularization_loss"): """Gets the total regularization loss. @@ -85,6 +90,7 @@ def get_regularization_loss(scope=None, name="total_regularization_loss"): return constant_op.constant(0.0) +@tf_export("losses.get_total_loss") def get_total_loss(add_regularization_losses=True, name="total_loss"): """Returns a tensor whose value represents the total loss. diff --git a/tensorflow/contrib/quantize/python/copy_graph.py b/tensorflow/python/ops/manip_grad.py similarity index 63% rename from tensorflow/contrib/quantize/python/copy_graph.py rename to tensorflow/python/ops/manip_grad.py index 0376fcba82b99feabdba3b683f9db9a32db51efb..bb2069359dd6fbe4874e228e6f2f58ea8444744d 100644 --- a/tensorflow/contrib/quantize/python/copy_graph.py +++ b/tensorflow/python/ops/manip_grad.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,21 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utility to copy a tf.Graph.""" +"""Gradients for operators defined in manip_ops.py.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops -from tensorflow.python.training import saver as saver_lib +from tensorflow.python.ops import manip_ops -def CopyGraph(graph): - """Return a copy of graph.""" - meta_graph = saver_lib.export_meta_graph( - graph=graph, collection_list=graph.get_all_collection_keys()) - graph_copy = ops.Graph() - with graph_copy.as_default(): - _ = saver_lib.import_meta_graph(meta_graph) - return graph_copy +@ops.RegisterGradient("Roll") +def _RollGrad(op, grad): + # The gradient is just the roll reversed + shift = op.inputs[1] + axis = op.inputs[2] + roll_grad = manip_ops.roll(grad, -shift, axis) + return roll_grad, None, None diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..91e15b47b9400f29425af2f186c7c44ee6a5a622 --- /dev/null +++ b/tensorflow/python/ops/manip_ops.py @@ -0,0 +1,38 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators for manipulating tensors. + +@@roll +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops +from tensorflow.python.util.all_util import remove_undocumented + + +# pylint: disable=protected-access +def roll(input, shift, axis): # pylint: disable=redefined-builtin + return _gen_manip_ops.roll(input, shift, axis) + + +roll.__doc__ = _gen_manip_ops.roll.__doc__ +# pylint: enable=protected-access + +_allowed_symbols = ['roll'] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index bca4c665d27f2513ed0029ae0c674f46a060567f..9e7f37d80fdd71e84516ab450d145d79519ae47a 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -40,15 +40,16 @@ def _SumGrad(op, grad): """Gradient for Sum.""" # Fast path for when reducing to a scalar and ndims is known: adds only # Reshape and Tile ops (and possibly a Shape). - if op.inputs[0].get_shape().ndims is not None: + input_0_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access + if input_0_shape is not None: axes = tensor_util.constant_value(op.inputs[1]) if axes is not None: - rank = op.inputs[0].get_shape().ndims + rank = len(input_0_shape) if np.array_equal(axes, np.arange(rank)): # Reduce all dims. grad = array_ops.reshape(grad, [1] * rank) # If shape is not fully defined (but rank is), we use Shape. - if op.inputs[0].get_shape().is_fully_defined(): - input_shape = op.inputs[0].get_shape().as_list() + if None not in input_0_shape: + input_shape = input_0_shape else: input_shape = array_ops.shape(op.inputs[0]) return [array_ops.tile(grad, input_shape), None] @@ -96,9 +97,12 @@ def _MinGrad(op, grad): def _MeanGrad(op, grad): """Gradient for Mean.""" sum_grad = _SumGrad(op, grad)[0] - input_size = op.inputs[0].get_shape().num_elements() - output_size = op.outputs[0].get_shape().num_elements() - if input_size is not None and output_size is not None: + input_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access + output_shape = op.outputs[0]._shape_tuple() # pylint: disable=protected-access + if (input_shape is not None and output_shape is not None and + None not in input_shape and None not in output_shape): + input_size = np.prod(input_shape) + output_size = np.prod(output_shape) factor = input_size // max(output_size, 1) factor = constant_op.constant(factor, dtype=sum_grad.dtype) else: @@ -106,7 +110,7 @@ def _MeanGrad(op, grad): output_shape = array_ops.shape(op.outputs[0]) factor = _safe_shape_div( math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape)) - return sum_grad / math_ops.cast(factor, sum_grad.dtype), None + return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None @ops.RegisterGradient("Prod") @@ -169,8 +173,7 @@ def _SegmentMeanGrad(op, grad): array_ops.shape(op.inputs[1]), array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1) ], 0) - ones = array_ops.fill(ones_shape, - constant_op.constant(1, dtype=grad.dtype)) + ones = array_ops.fill(ones_shape, constant_op.constant(1, dtype=grad.dtype)) scaled_grad = math_ops.div(grad, math_ops.segment_sum(ones, op.inputs[1])) return array_ops.gather(scaled_grad, op.inputs[1]), None @@ -225,53 +228,142 @@ def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad): dim0), None, None, None) -def _SegmentMinOrMaxGrad(op, grad, is_sorted): - """Gradient for SegmentMin and (unsorted) SegmentMax. They share similar code.""" - zeros = array_ops.zeros(array_ops.shape(op.inputs[0]), - dtype=op.inputs[0].dtype) - +def _SegmentMinOrMaxGrad(op, grad): + """ Gradient for SegmentMin and SegmentMax. """ + zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype) # Get the number of selected (minimum or maximum) elements in each segment. gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1]) is_selected = math_ops.equal(op.inputs[0], gathered_outputs) - if is_sorted: - num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype), - op.inputs[1]) - else: - num_selected = math_ops.unsorted_segment_sum( - math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2]) - + num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype), + op.inputs[1]) # Compute the gradient for each segment. The gradient for the ith segment is # divided evenly among the selected elements in that segment. weighted_grads = math_ops.div(grad, num_selected) gathered_grads = array_ops.gather(weighted_grads, op.inputs[1]) - - if is_sorted: - return array_ops.where(is_selected, gathered_grads, zeros), None - else: - return array_ops.where(is_selected, gathered_grads, zeros), None, None + return array_ops.where(is_selected, gathered_grads, zeros), None @ops.RegisterGradient("SegmentMin") def _SegmentMinGrad(op, grad): """Gradient for SegmentMin.""" - return _SegmentMinOrMaxGrad(op, grad, True) + return _SegmentMinOrMaxGrad(op, grad) @ops.RegisterGradient("SegmentMax") def _SegmentMaxGrad(op, grad): """Gradient for SegmentMax.""" - return _SegmentMinOrMaxGrad(op, grad, True) + return _SegmentMinOrMaxGrad(op, grad) + + +def _GatherDropNegatives(params, ids, zero_clipped_indices=None, + is_positive=None): + """ Helper function for unsorted segment ops. Gathers params for + positive segment ids and gathers 0 for inputs with negative segment id. + Also returns the clipped indices and a boolean mask with the same shape + as ids where a positive id is masked as true. With this, the latter two + can be passed as arguments to this function to reuse them. + """ + if zero_clipped_indices is None: + zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids)) + gathered = array_ops.gather(params, zero_clipped_indices) + if is_positive is None: + is_positive = math_ops.greater_equal(ids, 0) + # tf.where(condition, x, y) requires condition to have the same shape as x + # and y. + # todo(philjd): remove this if tf.where supports broadcasting (#9284) + for _ in range(gathered.shape.ndims - is_positive.shape.ndims): + is_positive = array_ops.expand_dims(is_positive, -1) + is_positive = (is_positive & + array_ops.ones_like(gathered, dtype=dtypes.bool)) + # replace gathered params of negative indices with 0 + zero_slice = array_ops.zeros_like(gathered) + return (array_ops.where(is_positive, gathered, zero_slice), + zero_clipped_indices, is_positive) + + +def _UnsortedSegmentMinOrMaxGrad(op, grad): + """ Gradient for UnsortedSegmentMin and UnsortedSegmentMax. """ + # Get the number of selected (minimum or maximum) elements in each segment. + gathered_outputs, zero_clipped_indices, is_positive = \ + _GatherDropNegatives(op.outputs[0], op.inputs[1]) + is_selected = math_ops.equal(op.inputs[0], gathered_outputs) + is_selected = math_ops.logical_and(is_selected, is_positive) + num_selected = math_ops.unsorted_segment_sum( + math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2]) + # Compute the gradient for each segment. The gradient for the ith segment is + # divided evenly among the selected elements in that segment. + weighted_grads = math_ops.div(grad, num_selected) + gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None, + zero_clipped_indices, + is_positive) + zeros = array_ops.zeros_like(gathered_grads) + return array_ops.where(is_selected, gathered_grads, zeros), None, None @ops.RegisterGradient("UnsortedSegmentSum") def _UnsortedSegmentSumGrad(op, grad): - """Gradient for SegmentSum.""" - return array_ops.gather(grad, op.inputs[1]), None, None + """Gradient for UnsortedSegmentSum.""" + return _GatherDropNegatives(grad, op.inputs[1])[0], None, None @ops.RegisterGradient("UnsortedSegmentMax") def _UnsortedSegmentMaxGrad(op, grad): - return _SegmentMinOrMaxGrad(op, grad, False) + """ Gradient for UnsortedSegmentMax. """ + return _UnsortedSegmentMinOrMaxGrad(op, grad) + + +@ops.RegisterGradient("UnsortedSegmentMin") +def _UnsortedSegmentMinGrad(op, grad): + """ Gradient for UnsortedSegmentMin. """ + return _UnsortedSegmentMinOrMaxGrad(op, grad) + + +@ops.RegisterGradient("UnsortedSegmentProd") +def _UnsortedSegmentProdGrad(op, grad): + """ Gradient for UnsortedSegmentProd. + The gradient can be expressed for each segment by dividing the segment's + product by each element of the segment input tensor, but this approach can't + deal with zeros in the input. + Unlike reduce_prod we can't use cumsum here as individual segments may have + a different number of elements. Therefore we consider three cases: + 1) A segment input contains no zeros and we can safely divide by the input + tensor. + 2) A segment contains exactly one zero. Then the gradient of each input of + the segment is zero except for the 0-input, there the gradient is + the product of the remaining segment entries. + 3) A segment contains at least two zeros. The gradient is zero for all + segment inputs. + """ + # Note that unsorted_segment_sum will filter out the negative indices, + # so we don't need to do a logical_and with is_positive here + is_zero = math_ops.equal(op.inputs[0], 0) + num_zeros = gen_math_ops.unsorted_segment_sum( + math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2]) + # handle case 3 and set the gradient to 0 for segments with more than one + # 0 as input + grad = array_ops.where(math_ops.greater(num_zeros, 1), + array_ops.zeros_like(grad), grad) + # replace all zeros with ones and compute the unsorted_segment_prod + non_zero_data = array_ops.where(is_zero, array_ops.ones_like(op.inputs[0]), + op.inputs[0]) + non_zero_prod = gen_math_ops.unsorted_segment_prod( + non_zero_data, op.inputs[1], op.inputs[2]) + # clip the indices for gather to be positive + zero_clipped_indices = math_ops.maximum(op.inputs[1], + array_ops.zeros_like(op.inputs[1])) + gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices) + gathered_non_zero_prod = array_ops.gather(non_zero_prod, + zero_clipped_indices) + prod_divided_by_el = gathered_prod / op.inputs[0] # May contain nan/inf. + # Now fetch the individual results for segments containing 0 and those that + # don't. is_zero will also fetch results for entries with negative index + # but the following gather_drop_negatives sets the corresponding entry in + # grad to 0 for these + partial_derivative = array_ops.where(is_zero, gathered_non_zero_prod, + prod_divided_by_el) + gathered_grad = _GatherDropNegatives(grad, op.inputs[1], + zero_clipped_indices)[0] + return gathered_grad * partial_derivative, None, None @ops.RegisterGradient("Abs") @@ -330,7 +422,7 @@ def _SquareGrad(op, grad): # Added control dependencies to prevent 2*x from being computed too early. with ops.control_dependencies([grad]): x = math_ops.conj(x) - return grad * (2.0 * x) + return math_ops.multiply(grad, math_ops.multiply(x, 2.0)) @ops.RegisterGradient("Sqrt") @@ -532,8 +624,8 @@ def _IgammaGrad(op, grad): # and Gamma'(a) can grow large. partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a)) # TODO(b/36815900): Mark None return values as NotImplemented - return (None, - array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) + return (None, array_ops.reshape( + math_ops.reduce_sum(partial_x * grad, rx), sx)) @ops.RegisterGradient("Igammac") @@ -559,15 +651,17 @@ def _BetaincGrad(op, grad): # Perform operations in log space before summing, because terms # can grow large. - log_beta = (gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) - - gen_math_ops.lgamma(a + b)) - partial_x = math_ops.exp( - (b - 1) * math_ops.log(1 - x) + (a - 1) * math_ops.log(x) - log_beta) + log_beta = ( + gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) - + gen_math_ops.lgamma(a + b)) + partial_x = math_ops.exp((b - 1) * math_ops.log(1 - x) + + (a - 1) * math_ops.log(x) - log_beta) # TODO(b/36815900): Mark None return values as NotImplemented - return (None, # da - None, # db - array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) + return ( + None, # da + None, # db + array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) @ops.RegisterGradient("Zeta") @@ -731,10 +825,8 @@ def _ShapesFullySpecifiedAndEqual(x, y, grad): y_shape = y._shape_tuple() grad_shape = grad._shape_tuple() # pylint: enable=protected-access - return (x_shape == y_shape and - x_shape == grad_shape and - x_shape is not None and - None not in x_shape) + return (x_shape == y_shape and x_shape == grad_shape and + x_shape is not None and None not in x_shape) @ops.RegisterGradient("Add") @@ -756,8 +848,12 @@ def _AddGrad(op, grad): @ops.RegisterGradient("Sub") def _SubGrad(op, grad): + """Gradient for Sub.""" x = op.inputs[0] y = op.inputs[1] + if (isinstance(grad, ops.Tensor) and + _ShapesFullySpecifiedAndEqual(x, y, grad)): + return grad, -grad sx = array_ops.shape(x) sy = array_ops.shape(y) # pylint: disable=protected-access @@ -781,11 +877,13 @@ def _MulGrad(op, grad): sx = array_ops.shape(x) sy = array_ops.shape(y) rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - # pylint: enable=protected-access x = math_ops.conj(x) y = math_ops.conj(y) - return (array_ops.reshape(math_ops.reduce_sum(grad * y, rx), sx), - array_ops.reshape(math_ops.reduce_sum(x * grad, ry), sy)) + return (array_ops.reshape( + math_ops.reduce_sum(gen_math_ops._mul(grad, y), rx), sx), + array_ops.reshape( + math_ops.reduce_sum(gen_math_ops._mul(x, grad), ry), sy)) + # pylint: enable=protected-access @ops.RegisterGradient("Div") @@ -848,10 +946,10 @@ def _RealDivGrad(op, grad): x = math_ops.conj(x) y = math_ops.conj(y) return (array_ops.reshape( - math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), - sx), array_ops.reshape( - math_ops.reduce_sum(grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), - ry), sy)) + math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx), + array_ops.reshape( + math_ops.reduce_sum( + grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy)) @ops.RegisterGradient("Pow") @@ -946,8 +1044,8 @@ def _SelectGrad(op, grad): c = op.inputs[0] x = op.inputs[1] zeros = array_ops.zeros_like(x) - return (None, array_ops.where(c, grad, zeros), - array_ops.where(c, zeros, grad)) + return (None, array_ops.where(c, grad, zeros), array_ops.where( + c, zeros, grad)) @ops.RegisterGradient("MatMul") @@ -958,18 +1056,20 @@ def _MatMulGrad(op, grad): t_b = op.get_attr("transpose_b") a = math_ops.conj(op.inputs[0]) b = math_ops.conj(op.inputs[1]) + # pylint: disable=protected-access if not t_a and not t_b: - grad_a = math_ops.matmul(grad, b, transpose_b=True) - grad_b = math_ops.matmul(a, grad, transpose_a=True) + grad_a = gen_math_ops._mat_mul(grad, b, transpose_b=True) + grad_b = gen_math_ops._mat_mul(a, grad, transpose_a=True) elif not t_a and t_b: - grad_a = math_ops.matmul(grad, b) - grad_b = math_ops.matmul(grad, a, transpose_a=True) + grad_a = gen_math_ops._mat_mul(grad, b) + grad_b = gen_math_ops._mat_mul(grad, a, transpose_a=True) elif t_a and not t_b: - grad_a = math_ops.matmul(b, grad, transpose_b=True) - grad_b = math_ops.matmul(a, grad) + grad_a = gen_math_ops._mat_mul(b, grad, transpose_b=True) + grad_b = gen_math_ops._mat_mul(a, grad) elif t_a and t_b: - grad_a = math_ops.matmul(b, grad, transpose_a=True, transpose_b=True) - grad_b = math_ops.matmul(grad, a, transpose_a=True, transpose_b=True) + grad_a = gen_math_ops._mat_mul(b, grad, transpose_a=True, transpose_b=True) + grad_b = gen_math_ops._mat_mul(grad, a, transpose_a=True, transpose_b=True) + # pylint: enable=protected-access return grad_a, grad_b @@ -1009,21 +1109,20 @@ def _SparseMatMulGrad(op, grad): dtype_a = op.inputs[0].dtype dtype_b = op.inputs[1].dtype if not t_a and not t_b: - return (_SparseMatMul( - grad, op.inputs[1], dtype_a, transpose_b=True), _SparseMatMul( - op.inputs[0], grad, dtype_b, transpose_a=True)) + return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True), + _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True)) elif not t_a and t_b: - return (_SparseMatMul(grad, op.inputs[1], dtype_a), _SparseMatMul( - grad, op.inputs[0], dtype_b, transpose_a=True)) + return (_SparseMatMul(grad, op.inputs[1], dtype_a), + _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True)) elif t_a and not t_b: - return (_SparseMatMul( - op.inputs[1], grad, dtype_a, transpose_b=True), + return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True), _SparseMatMul(op.inputs[0], grad, dtype_b)) elif t_a and t_b: return (_SparseMatMul( - op.inputs[1], grad, dtype_a, transpose_a=True, - transpose_b=True), _SparseMatMul( - grad, op.inputs[0], dtype_b, transpose_a=True, transpose_b=True)) + op.inputs[1], grad, dtype_a, transpose_a=True, transpose_b=True), + _SparseMatMul( + grad, op.inputs[0], dtype_b, transpose_a=True, + transpose_b=True)) @ops.RegisterGradient("Floor") @@ -1127,8 +1226,8 @@ def _ComplexAbsGrad(op, grad): """Returns the gradient of ComplexAbs.""" # TODO(b/27786104): The cast to complex could be removed once arithmetic # supports mixtures of complex64 and real values. - return (math_ops.complex(grad, array_ops.zeros_like(grad)) * - math_ops.sign(op.inputs[0])) + return (math_ops.complex(grad, array_ops.zeros_like(grad)) * math_ops.sign( + op.inputs[0])) @ops.RegisterGradient("Cast") @@ -1158,8 +1257,8 @@ def _CumsumGrad(op, grad): exclusive = op.get_attr("exclusive") reverse = op.get_attr("reverse") return [ - math_ops.cumsum( - grad, axis, exclusive=exclusive, reverse=not reverse), None + math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=not reverse), + None ] diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index b8e8207bb24ad64d9e07a4585501a10741f5c9ab..da9957aa2a5463a37bba155597600a340ee4f1e6 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -131,6 +131,9 @@ See the @{$python/math_ops} guide. @@segment_mean @@unsorted_segment_sum @@unsorted_segment_max +@@unsorted_segment_min +@@unsorted_segment_prod +@@unsorted_segment_sqrt_n @@sparse_segment_sum @@sparse_segment_mean @@sparse_segment_sqrt_n @@ -237,7 +240,7 @@ def argmin(input, # pylint: disable=anomalous-backslash-in-string,protected-access # pylint: disable=g-docstring-has-escape @tf_export("abs") -def abs(x, name=None): +def abs(x, name=None): # pylint: disable=redefined-builtin r"""Computes the absolute value of a tensor. Given a tensor `x` of complex numbers, this operation returns a tensor of type @@ -542,7 +545,7 @@ def scalar_mul(scalar, x): @tf_export("pow") -def pow(x, y, name=None): +def pow(x, y, name=None): # pylint: disable=redefined-builtin r"""Computes the power of one value to another. Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for @@ -712,7 +715,7 @@ def angle(input, name=None): @tf_export("round") -def round(x, name=None): +def round(x, name=None): # pylint: disable=redefined-builtin """Rounds the values of a tensor to the nearest integer, element-wise. Rounds half to even. Also known as bankers rounding. If you want to round @@ -1207,7 +1210,7 @@ ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal) @tf_export("range") -def range(start, limit=None, delta=1, dtype=None, name="range"): +def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disable=redefined-builtin """Creates a sequence of numbers. Creates a sequence of numbers that begins at `start` and extends by @@ -1841,12 +1844,11 @@ def reduce_logsumexp(input_tensor, reduce_sum( gen_math_ops.exp(input_tensor - my_max), axis, - keepdims=True, - reduction_indices=reduction_indices)) + my_max + keepdims=keepdims, + reduction_indices=reduction_indices)) if not keepdims: - if isinstance(axis, int): - axis = [axis] - result = array_ops.squeeze(result, axis) + my_max = array_ops.reshape(my_max, array_ops.shape(result)) + result += my_max return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result) @@ -2553,12 +2555,93 @@ def reduced_shape(input_shape, axes): ]) # [1, 1] +def _unsorted_segment_N(data, segment_ids, num_segments): + """ Helper function for unsorted_segment_mean/_sqrtN. Computes the number + of segment entries with 0-entries set to 1 to allow division by N. + """ + # bincount doesn't support negative indices so we use unsorted_segment_sum + ones_tensor = array_ops.ones(segment_ids.shape, dtype=data.dtype) + N = gen_math_ops.unsorted_segment_sum(ones_tensor, segment_ids, num_segments) + # add dimensions for all non-reduced axes + ndims_output = data.shape.ndims - segment_ids.shape.ndims + broadcast_shape = [num_segments] + [1] * ndims_output + N = array_ops.reshape(N, broadcast_shape) + return gen_math_ops.maximum(N, 1) + + +@tf_export("unsorted_segment_mean") +def unsorted_segment_mean(data, segment_ids, num_segments, name=None): + r""" Computes the mean along segments of a tensor. + + Read @{$math_ops#segmentation$the section on segmentation} for an explanation + of segments. + + This operator is similar to the unsorted segment sum operator found + [here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). + Instead of computing the sum over segments, it computes the mean of all + entries belonging to a segment such that: + + \\(output_i = 1/N_i \sum data_j\\) where the sum is over `j` such + that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences + of id \\i\\. + + If there is no entry for a given segment ID `i`, it outputs 0. + + segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s + first dimension. + + output: Has same shape as data, except for dimension 0 which + has size `num_segments`. + """ + with ops.name_scope(name, "UnsortedSegmentMean"): + data = ops.convert_to_tensor(data) + segment_ids = ops.convert_to_tensor(segment_ids) + N = _unsorted_segment_N(data, segment_ids, num_segments) + summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments) + return summed / N + + +@tf_export("unsorted_segment_sqrt_n") +def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None): + r"""Computes the sum along segments of a tensor divided by the sqrt(N). + + Read @{$math_ops#segmentation$the section on segmentation} for an explanation + of segments. + + This operator is similar to the unsorted segment sum operator found + [here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). + Additionally to computing the sum over segments, it divides the results by + sqrt(N). + + \\(output_i = 1/sqrt(N_i) \sum data_j\\) where the sum is over `j` such + that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences + of id \\i\\. + + If there is no entry for a given segment ID `i`, it outputs 0. + + Note that this op only supports floating point and complex dtypes, + due to tf.sqrt only supporting these types. + + segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s + first dimension. + + output: Has same shape as data, except for dimension 0 which + has size `num_segments`. + """ + with ops.name_scope(name, "UnsortedSegmentSqrtN"): + data = ops.convert_to_tensor(data) + segment_ids = ops.convert_to_tensor(segment_ids) + N = _unsorted_segment_N(data, segment_ids, num_segments) + summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments) + return summed / gen_math_ops.sqrt(N) + + @tf_export("sparse_segment_sum") def sparse_segment_sum(data, indices, segment_ids, name=None, num_segments=None): r"""Computes the sum along sparse segments of a tensor. - Read @{$math_ops#segmentation$the section on segmentation} for an explanation + Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of segments. Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first @@ -2633,7 +2716,7 @@ def sparse_segment_mean(data, indices, segment_ids, name=None, num_segments=None): r"""Computes the mean along sparse segments of a tensor. - Read @{$math_ops#segmentation$the section on segmentation} for an explanation + Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of segments. Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index bd26ff66961c858865c8a61469abac0b783ed645..d314124ccd9bc8b7676e6926830a8eb1e0315f5f 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -105,7 +105,7 @@ class LogSumExpTest(test_util.TensorFlowTestCase): for dtype in [np.float16, np.float32, np.double]: x_np = np.random.rand(5, 5).astype(dtype) with self.test_session(use_gpu=True): - y_tf_np = math_ops.reduce_logsumexp(x_np, keep_dims=True).eval() + y_tf_np = math_ops.reduce_logsumexp(x_np, keepdims=True).eval() self.assertEqual(y_tf_np.ndim, x_np.ndim) y_np = log(np.sum(exp(x_np), keepdims=True)) self.assertAllClose(y_tf_np, y_np) diff --git a/tensorflow/python/ops/matmul_benchmark.py b/tensorflow/python/ops/matmul_benchmark.py index f95cf08de1aaa47550fa344dc9f964c4f812cd68..6e5fe74290a219d07945998be2677176ca693cd9 100644 --- a/tensorflow/python/ops/matmul_benchmark.py +++ b/tensorflow/python/ops/matmul_benchmark.py @@ -95,8 +95,8 @@ class MatmulBenchmark(test.Benchmark): num_items = n * m * k * 2 throughput = num_items * num_iters / duration / 1e9 print('%s %s input_info:%s %d %.4fsec, %.4fGitems/s.' % - (device, str(dtype), str(n) + 'x' + str(m) + 'x' + str(k) + ',ta:' - + str(transpose_a) + '.tb:' + str(transpose_b), num_iters, + (device, str(dtype), str(n) + 'x' + str(m) + 'x' + str(k) + + ',ta:' + str(transpose_a) + '.tb:' + str(transpose_b), num_iters, duration, throughput)) name_template = ('matmul_{device}_{dtype}_input_info_{inputinfo}') @@ -112,7 +112,8 @@ class MatmulBenchmark(test.Benchmark): return duration def run_test_gpu(self, n, m, k, transpose_a, transpose_b, dtype, num_iters): - self.run_graph(test.gpu_device_name(), n, m, k, transpose_a, transpose_b, num_iters, dtype) + self.run_graph(test.gpu_device_name(), n, m, k, transpose_a, transpose_b, + num_iters, dtype) def test_round(self, num_iters): dtypes = [np.float32, np.float64] @@ -124,8 +125,8 @@ class MatmulBenchmark(test.Benchmark): self.run_test_gpu(n, m, k, transpose_a, transpose_b, dtype, num_iters) for n, m, k, (transpose_a, transpose_b) in itertools.product( - [200], [1, 8, 20], [10000], [(False, False), (True, False), (False, - True)]): + [200], [1, 8, 20], [10000], [(False, False), (True, False), + (False, True)]): self.run_test_gpu(n, m, k, transpose_a, transpose_b, dtype, num_iters) for (n, m, k), (transpose_a, transpose_b) in itertools.product( diff --git a/tensorflow/python/ops/matmul_benchmark_test.py b/tensorflow/python/ops/matmul_benchmark_test.py index 5a9c0a7a4951bbbc1d201f6fbc557e9a996a3655..3df0c66ef9c50909dd8c03b75654d6cf0fd7d709 100644 --- a/tensorflow/python/ops/matmul_benchmark_test.py +++ b/tensorflow/python/ops/matmul_benchmark_test.py @@ -33,11 +33,11 @@ def BuildGraphTest(n, m, k, transpose_a, transpose_b, dtype): def Test(self): if not googletest.is_gpu_available(): - tf_logging.info("Skipping BuildGraphTest %s", (n, m, k, transpose_a, - transpose_b)) + tf_logging.info("Skipping BuildGraphTest %s", + (n, m, k, transpose_a, transpose_b)) return - tf_logging.info("Testing BuildGraphTest %s", (n, m, k, transpose_a, - transpose_b)) + tf_logging.info("Testing BuildGraphTest %s", + (n, m, k, transpose_a, transpose_b)) self._VerifyBuildGraph(n, m, k, transpose_a, transpose_b, dtype) return Test @@ -47,11 +47,11 @@ def RunGraphTest(n, m, k, transpose_a, transpose_b, dtype): def Test(self): if not googletest.is_gpu_available(): - tf_logging.info("Skipping RunGraphTest %s", (n, m, k, transpose_a, - transpose_b)) + tf_logging.info("Skipping RunGraphTest %s", + (n, m, k, transpose_a, transpose_b)) return - tf_logging.info("Testing RunGraphTest %s", (n, m, k, transpose_a, - transpose_b)) + tf_logging.info("Testing RunGraphTest %s", + (n, m, k, transpose_a, transpose_b)) self._VerifyRunGraph(n, m, k, transpose_a, transpose_b, dtype) return Test @@ -71,40 +71,41 @@ class MatmulBenchmarkTest(googletest.TestCase): def _VerifyBuildGraph(self, n, m, k, transpose_a, transpose_b, dtype): graph = ops.Graph() with graph.as_default(): - matmul_benchmark.build_graph(googletest.gpu_device_name(), n, m, k, transpose_a, transpose_b, - dtype) + matmul_benchmark.build_graph(googletest.gpu_device_name(), n, m, k, + transpose_a, transpose_b, dtype) gd = graph.as_graph_def() - dev=googletest.gpu_device_name() + dev = googletest.gpu_device_name() proto_expected = """ - node { name: "random_uniform/shape" op: "Const" device: \""""+ dev +"""\" } - node { name: "random_uniform/min" op: "Const" device: \""""+ dev +"""\" } - node { name: "random_uniform/max" op: "Const" device: \""""+ dev +"""\" } - node { name: "random_uniform/RandomUniform" op: "RandomUniform" input: "random_uniform/shape" device: \""""+ dev +"""\" } - node { name: "random_uniform/sub" op: "Sub" input: "random_uniform/max" input: "random_uniform/min" device: \""""+ dev +"""\" } - node { name: "random_uniform/mul" op: "Mul" input: "random_uniform/RandomUniform" input: "random_uniform/sub" device: \""""+ dev +"""\" } - node { name: "random_uniform" op: "Add" input: "random_uniform/mul" input: "random_uniform/min" device: \""""+ dev +"""\" } - node { name: "Variable" op: "VariableV2" device: \""""+ dev +"""\" } - node { name: "Variable/Assign" op: "Assign" input: "Variable" input: "random_uniform" device: \""""+ dev +"""\" } - node { name: "Variable/read" op: "Identity" input: "Variable" device: \""""+ dev +"""\" } - node { name: "random_uniform_1/shape" op: "Const" device: \""""+ dev +"""\" } - node { name: "random_uniform_1/min" op: "Const" device: \""""+ dev +"""\" } - node { name: "random_uniform_1/max" op: "Const" device: \""""+ dev +"""\" } - node { name: "random_uniform_1/RandomUniform" op: "RandomUniform" input: "random_uniform_1/shape" device: \""""+ dev +"""\" } - node { name: "random_uniform_1/sub" op: "Sub" input: "random_uniform_1/max" input: "random_uniform_1/min" device: \""""+ dev +"""\" } - node { name: "random_uniform_1/mul" op: "Mul" input: "random_uniform_1/RandomUniform" input: "random_uniform_1/sub" device: \""""+ dev +"""\" } - node { name: "random_uniform_1" op: "Add" input: "random_uniform_1/mul" input: "random_uniform_1/min" device: \""""+ dev +"""\" } - node { name: "Variable_1" op: "VariableV2" device: \""""+ dev +"""\" } - node { name: "Variable_1/Assign" op: "Assign" input: "Variable_1" input: "random_uniform_1" device: \""""+ dev +"""\" } - node { name: "Variable_1/read" op: "Identity" input: "Variable_1" device: \""""+ dev +"""\" } - node { name: "MatMul" op: "MatMul" input: "Variable/read" input: "Variable_1/read" device: \""""+ dev +"""\" } - node { name: "group_deps" op: "NoOp" input: "^MatMul" device: \""""+ dev +"""\" } + node { name: "random_uniform/shape" op: "Const" device: \"""" + dev + """\" } + node { name: "random_uniform/min" op: "Const" device: \"""" + dev + """\" } + node { name: "random_uniform/max" op: "Const" device: \"""" + dev + """\" } + node { name: "random_uniform/RandomUniform" op: "RandomUniform" input: "random_uniform/shape" device: \"""" + dev + """\" } + node { name: "random_uniform/sub" op: "Sub" input: "random_uniform/max" input: "random_uniform/min" device: \"""" + dev + """\" } + node { name: "random_uniform/mul" op: "Mul" input: "random_uniform/RandomUniform" input: "random_uniform/sub" device: \"""" + dev + """\" } + node { name: "random_uniform" op: "Add" input: "random_uniform/mul" input: "random_uniform/min" device: \"""" + dev + """\" } + node { name: "Variable" op: "VariableV2" device: \"""" + dev + """\" } + node { name: "Variable/Assign" op: "Assign" input: "Variable" input: "random_uniform" device: \"""" + dev + """\" } + node { name: "Variable/read" op: "Identity" input: "Variable" device: \"""" + dev + """\" } + node { name: "random_uniform_1/shape" op: "Const" device: \"""" + dev + """\" } + node { name: "random_uniform_1/min" op: "Const" device: \"""" + dev + """\" } + node { name: "random_uniform_1/max" op: "Const" device: \"""" + dev + """\" } + node { name: "random_uniform_1/RandomUniform" op: "RandomUniform" input: "random_uniform_1/shape" device: \"""" + dev + """\" } + node { name: "random_uniform_1/sub" op: "Sub" input: "random_uniform_1/max" input: "random_uniform_1/min" device: \"""" + dev + """\" } + node { name: "random_uniform_1/mul" op: "Mul" input: "random_uniform_1/RandomUniform" input: "random_uniform_1/sub" device: \"""" + dev + """\" } + node { name: "random_uniform_1" op: "Add" input: "random_uniform_1/mul" input: "random_uniform_1/min" device: \"""" + dev + """\" } + node { name: "Variable_1" op: "VariableV2" device: \"""" + dev + """\" } + node { name: "Variable_1/Assign" op: "Assign" input: "Variable_1" input: "random_uniform_1" device: \"""" + dev + """\" } + node { name: "Variable_1/read" op: "Identity" input: "Variable_1" device: \"""" + dev + """\" } + node { name: "MatMul" op: "MatMul" input: "Variable/read" input: "Variable_1/read" device: \"""" + dev + """\" } + node { name: "group_deps" op: "NoOp" input: "^MatMul" device: \"""" + dev + """\" } """ self.assertProtoEquals(str(proto_expected), self._StripGraph(gd)) def _VerifyRunGraph(self, n, m, k, transpose_a, transpose_b, dtype): benchmark_instance = matmul_benchmark.MatmulBenchmark() - duration = benchmark_instance.run_graph(googletest.gpu_device_name(), n, m, k, transpose_a, - transpose_b, 1, dtype) + duration = benchmark_instance.run_graph(googletest.gpu_device_name(), n, m, + k, transpose_a, transpose_b, 1, + dtype) self.assertTrue(duration > 1e-6) @@ -113,8 +114,8 @@ if __name__ == "__main__": index = 0 for _dtype in dtypes: for _n, _m, (_transpose_a, _transpose_b) in itertools.product( - [512, 1024], [1, 8, 16, 128], [(False, False), (True, False), (False, - True)]): + [512, 1024], [1, 8, 16, 128], [(False, False), (True, False), + (False, True)]): _k = _n setattr(MatmulBenchmarkTest, "testBuildGraph_" + str(index), BuildGraphTest(_n, _m, _k, _transpose_a, _transpose_b, _dtype)) diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 7776ff08c4f55c43947010f313d8167596b15db7..44c2f304cf9245539e42da2ce54260990de980e0 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -672,7 +672,7 @@ def auc(labels, x = fp_rate y = rec else: # curve == 'PR'. - prec = math_ops.div(tp + epsilon, tp + fp + epsilon) + prec = math_ops.div(tp, tp + fp + epsilon) x = rec y = prec if summation_method == 'trapezoidal': @@ -923,8 +923,8 @@ def mean_per_class_accuracy(labels, weights = array_ops.reshape(weights, [-1]) weights = math_ops.to_float(weights) - is_correct = is_correct * weights - ones = ones * weights + is_correct *= weights + ones *= weights update_total_op = state_ops.scatter_add(total, labels, ones) update_count_op = state_ops.scatter_add(count, labels, is_correct) diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py index fc013b565b764f0d22df29f99e78cb97498c5ced..eebfb17085a568f48769f6df7dddd3ae2f799efc 100644 --- a/tensorflow/python/ops/nn_batchnorm_test.py +++ b/tensorflow/python/ops/nn_batchnorm_test.py @@ -21,10 +21,8 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.core.framework import graph_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -40,15 +38,6 @@ from tensorflow.python.platform import test @test_util.with_c_api class BatchNormalizationTest(test.TestCase): - def SetProducerVersion(self, graph, producer_version): - # The C API doesn't expose altering GraphDefVersions. We can indirectly set - # it via import_graph_def though. - graph_def = graph_pb2.GraphDef() - graph_def.versions.producer = producer_version - with graph.as_default(): - importer.import_graph_def(graph_def) - assert graph.graph_def_versions.producer, producer_version - def _npBatchNorm(self, x, m, v, beta, gamma, epsilon, scale_after_normalization, shift_after_normalization): y = (x - m) / np.sqrt(v + epsilon) @@ -65,7 +54,7 @@ class BatchNormalizationTest(test.TestCase): def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon, scale_after_normalization): """Original implementation.""" - self.SetProducerVersion(ops.get_default_graph(), 8) + test_util.set_producer_version(ops.get_default_graph(), 8) return gen_nn_ops._batch_norm_with_global_normalization( x, m, v, beta, gamma, epsilon, scale_after_normalization) # pylint: enable=protected-access @@ -233,7 +222,7 @@ class BatchNormalizationTest(test.TestCase): epsilon = 0.001 for scale_after_normalization in [True, False]: # _batch_norm_with_global_normalization_grad is deprecated in v9 - self.SetProducerVersion(ops.get_default_graph(), 8) + test_util.set_producer_version(ops.get_default_graph(), 8) grad = gen_nn_ops._batch_norm_with_global_normalization_grad( x, m, v, gamma, backprop, epsilon, scale_after_normalization) dx, dm, dv, db, dg = grad diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py index 0593ed2cfa64eca59ca02904ca71b4fd4936af1b..a08b836025d12178ab7acfbd70fcc7a47bc99532 100644 --- a/tensorflow/python/ops/nn_fused_batchnorm_test.py +++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py @@ -278,7 +278,8 @@ class BatchNormalizationTest(test.TestCase): epsilon = y.op.get_attr('epsilon') data_format = y.op.get_attr('data_format') grad_vals = sess.run([grad_x, grad_scale, grad_offset]) - grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale, pop_mean, pop_var, epsilon, data_format) + grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale, pop_mean, + pop_var, epsilon, data_format) grad_internal_vals = sess.run(list(grad_internal)) for grad_val, grad_internal_val in zip(grad_vals, grad_internal_vals): self.assertAllClose(grad_val, grad_internal_val, atol=err_tolerance) diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index cfff73774b5e585ed702369b9a74ff34e0a5febb..dc24b821a5580e3581f153f3cbf63ad2868b8a18 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -89,52 +89,63 @@ def _Conv2DBackpropFilterGrad(op, grad): @ops.RegisterGradient("Conv3D") def _Conv3DGrad(op, grad): data_format = op.get_attr("data_format") - return [nn_ops.conv3d_backprop_input_v2(array_ops.shape(op.inputs[0]), - op.inputs[1], - grad, - strides=op.get_attr("strides"), - padding=op.get_attr("padding"), - data_format=data_format), - nn_ops.conv3d_backprop_filter_v2(op.inputs[0], - array_ops.shape(op.inputs[1]), - grad, - strides=op.get_attr("strides"), - padding=op.get_attr("padding"), - data_format=data_format)] + return [ + nn_ops.conv3d_backprop_input_v2( + array_ops.shape(op.inputs[0]), + op.inputs[1], + grad, + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=data_format), + nn_ops.conv3d_backprop_filter_v2( + op.inputs[0], + array_ops.shape(op.inputs[1]), + grad, + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=data_format) + ] @ops.RegisterGradient("Conv3DBackpropInputV2") def _Conv3DBackpropInputGrad(op, grad): data_format = op.get_attr("data_format") - return [None, - nn_ops.conv3d_backprop_filter_v2(grad, - array_ops.shape(op.inputs[1]), - op.inputs[2], - strides=op.get_attr("strides"), - padding=op.get_attr("padding"), - data_format=data_format), - nn_ops.conv3d(grad, - op.inputs[1], - strides=op.get_attr("strides"), - padding=op.get_attr("padding"), - data_format=data_format)] + return [ + None, + nn_ops.conv3d_backprop_filter_v2( + grad, + array_ops.shape(op.inputs[1]), + op.inputs[2], + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=data_format), + nn_ops.conv3d( + grad, + op.inputs[1], + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=data_format) + ] @ops.RegisterGradient("Conv3DBackpropFilterV2") def _Conv3DBackpropFilterGrad(op, grad): data_format = op.get_attr("data_format") - return [nn_ops.conv3d_backprop_input_v2(array_ops.shape(op.inputs[0]), - grad, - op.inputs[2], - strides=op.get_attr("strides"), - padding=op.get_attr("padding"), - data_format=data_format), - None, - nn_ops.conv3d(op.inputs[0], - grad, - strides=op.get_attr("strides"), - padding=op.get_attr("padding"), - data_format=data_format)] + return [ + nn_ops.conv3d_backprop_input_v2( + array_ops.shape(op.inputs[0]), + grad, + op.inputs[2], + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=data_format), None, + nn_ops.conv3d( + op.inputs[0], + grad, + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=data_format) + ] @ops.RegisterGradient("AvgPool3D") @@ -150,12 +161,13 @@ def _AvgPool3DGrad(op, grad): @ops.RegisterGradient("AvgPool3DGrad") def _AvgPool3DGradGrad(op, grad): - return (array_ops.stop_gradient(op.inputs[0]), gen_nn_ops.avg_pool3d( - grad, - op.get_attr("ksize"), - op.get_attr("strides"), - op.get_attr("padding"), - data_format=op.get_attr("data_format"))) + return (array_ops.stop_gradient(op.inputs[0]), + gen_nn_ops.avg_pool3d( + grad, + op.get_attr("ksize"), + op.get_attr("strides"), + op.get_attr("padding"), + data_format=op.get_attr("data_format"))) @ops.RegisterGradient("MaxPool3D") @@ -173,9 +185,9 @@ def _MaxPool3DGrad(op, grad): @ops.RegisterGradient("MaxPool3DGrad") def _MaxPool3DGradGrad(op, grad): return (array_ops.zeros( - shape=array_ops.shape(op.inputs[0]), - dtype=op.inputs[0].dtype), array_ops.zeros( - shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), + array_ops.zeros( + shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), gen_nn_ops._max_pool3d_grad_grad( op.inputs[0], op.inputs[1], @@ -189,9 +201,9 @@ def _MaxPool3DGradGrad(op, grad): @ops.RegisterGradient("MaxPool3DGradGrad") def _MaxPool3DGradGradGrad(op, grad): return (array_ops.zeros( - shape=array_ops.shape(op.inputs[0]), - dtype=op.inputs[0].dtype), array_ops.zeros( - shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), + array_ops.zeros( + shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), gen_nn_ops._max_pool3d_grad( op.inputs[0], op.inputs[1], @@ -272,8 +284,9 @@ def _BiasAddGrad(op, received_grad): data_format = op.get_attr("data_format") except ValueError: data_format = None - return (received_grad, gen_nn_ops.bias_add_grad(out_backprop=received_grad, - data_format=data_format)) + return (received_grad, + gen_nn_ops.bias_add_grad( + out_backprop=received_grad, data_format=data_format)) @ops.RegisterGradient("BiasAddGrad") @@ -346,10 +359,9 @@ def _ReluGrad(op, grad): def _EluGradGrad(op, grad): elu_x = op.inputs[1] return (gen_nn_ops._elu_grad(grad, op.outputs[0]), - array_ops.where(elu_x < 0, - grad * op.inputs[0], - array_ops.zeros(shape=array_ops.shape(elu_x), - dtype=elu_x.dtype))) + array_ops.where(elu_x < 0, grad * op.inputs[0], + array_ops.zeros( + shape=array_ops.shape(elu_x), dtype=elu_x.dtype))) @ops.RegisterGradient("SeluGrad") @@ -357,9 +369,11 @@ def _SeluGradGrad(op, grad): x = op.inputs[1] scale_alpha = 1.7580993408473768599402175208123 return (gen_nn_ops._elu_grad(grad, op.outputs[0]), - array_ops.where( - x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + scale_alpha), - array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))) + array_ops.where(x < 0., + gen_nn_ops._elu_grad(grad, + op.outputs[0] + scale_alpha), + array_ops.zeros( + shape=array_ops.shape(x), dtype=x.dtype))) @ops.RegisterGradient("Relu6") @@ -370,8 +384,8 @@ def _Relu6Grad(op, grad): @ops.RegisterGradient("Relu6Grad") def _Relu6GradGrad(op, grad): x = op.inputs[1] - return (gen_nn_ops._relu6_grad(grad, x), array_ops.zeros( - shape=array_ops.shape(x), dtype=x.dtype)) + return (gen_nn_ops._relu6_grad(grad, x), + array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) @ops.RegisterGradient("Elu") @@ -410,8 +424,8 @@ def _SoftsignGrad(op, grad): @ops.RegisterGradient("ReluGrad") def _ReluGradGrad(op, grad): x = op.inputs[1] - return (gen_nn_ops._relu_grad(grad, x), array_ops.zeros( - shape=array_ops.shape(x), dtype=x.dtype)) + return (gen_nn_ops._relu_grad(grad, x), + array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) def _BroadcastMul(vec, mat): @@ -455,8 +469,8 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad): softmax = nn_ops.softmax(logits) grad += ((grad_grad - array_ops.squeeze( - math_ops.matmul(grad_grad[:, None, :], - softmax[:, :, None]), axis=1)) * softmax) + math_ops.matmul(grad_grad[:, None, :], softmax[:, :, None]), axis=1)) * + softmax) return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits)) @@ -473,7 +487,8 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _): # so we make sure we prevent silently incorrect results by raising # an error if the second derivative is requested via prevent_gradient. sparse_softmax_grad_without_gradient = array_ops.prevent_gradient( - op.outputs[1], message="Currently there is no way to take the second " + op.outputs[1], + message="Currently there is no way to take the second " "derivative of sparse_softmax_cross_entropy_with_logits due to the fused " "implementation's interaction with tf.gradients()") return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None @@ -531,14 +546,16 @@ def _DepthwiseConv2dNativeGrad(op, grad): @ops.RegisterGradient("Dilation2D") def _Dilation2DGrad(op, grad): - return [nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad, - op.get_attr("strides"), - op.get_attr("rates"), - op.get_attr("padding")), - nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad, - op.get_attr("strides"), - op.get_attr("rates"), - op.get_attr("padding"))] + return [ + nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad, + op.get_attr("strides"), + op.get_attr("rates"), + op.get_attr("padding")), + nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad, + op.get_attr("strides"), + op.get_attr("rates"), + op.get_attr("padding")) + ] @ops.RegisterGradient("LRN") @@ -547,8 +564,10 @@ def _LRNGrad(op, grad): bias = op.get_attr("bias") alpha = op.get_attr("alpha") beta = op.get_attr("beta") - return [gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, - bias, alpha, beta)] + return [ + gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, + bias, alpha, beta) + ] @ops.RegisterGradient("AvgPool") @@ -564,54 +583,58 @@ def _AvgPoolGrad(op, grad): @ops.RegisterGradient("AvgPoolGrad") def _AvgPoolGradGrad(op, grad): - return (array_ops.stop_gradient(op.inputs[0]), gen_nn_ops._avg_pool( - grad, - op.get_attr("ksize"), - op.get_attr("strides"), - op.get_attr("padding"), - data_format=op.get_attr("data_format"))) + return (array_ops.stop_gradient(op.inputs[0]), + gen_nn_ops._avg_pool( + grad, + op.get_attr("ksize"), + op.get_attr("strides"), + op.get_attr("padding"), + data_format=op.get_attr("data_format"))) @ops.RegisterGradient("MaxPool") def _MaxPoolGrad(op, grad): - return gen_nn_ops._max_pool_grad(op.inputs[0], - op.outputs[0], - grad, - op.get_attr("ksize"), - op.get_attr("strides"), - padding=op.get_attr("padding"), - data_format=op.get_attr("data_format")) + return gen_nn_ops._max_pool_grad( + op.inputs[0], + op.outputs[0], + grad, + op.get_attr("ksize"), + op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=op.get_attr("data_format")) @ops.RegisterGradient("MaxPoolV2") def _MaxPoolGradV2(op, grad): ksize = op.inputs[1] strides = op.inputs[2] - return gen_nn_ops.max_pool_grad_v2(op.inputs[0], - op.outputs[0], - grad, - ksize, - strides, - padding=op.get_attr("padding"), - data_format=op.get_attr("data_format")), None, None + return gen_nn_ops.max_pool_grad_v2( + op.inputs[0], + op.outputs[0], + grad, + ksize, + strides, + padding=op.get_attr("padding"), + data_format=op.get_attr("data_format")), None, None @ops.RegisterGradient("MaxPoolWithArgmax") def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad): - return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0], - grad, - op.outputs[1], - op.get_attr("ksize"), - op.get_attr("strides"), - padding=op.get_attr("padding")) + return gen_nn_ops._max_pool_grad_with_argmax( + op.inputs[0], + grad, + op.outputs[1], + op.get_attr("ksize"), + op.get_attr("strides"), + padding=op.get_attr("padding")) @ops.RegisterGradient("MaxPoolGrad") def _MaxPoolGradGrad(op, grad): return (array_ops.zeros( - shape=array_ops.shape(op.inputs[0]), - dtype=op.inputs[0].dtype), array_ops.zeros( - shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), + array_ops.zeros( + shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), gen_nn_ops._max_pool_grad_grad( op.inputs[0], op.inputs[1], @@ -627,9 +650,9 @@ def _MaxPoolGradGradV2(op, grad): ksize = op.inputs[3] strides = op.inputs[4] return (array_ops.zeros( - shape=array_ops.shape(op.inputs[0]), - dtype=op.inputs[0].dtype), array_ops.zeros( - shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), + array_ops.zeros( + shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), gen_nn_ops.max_pool_grad_grad_v2( op.inputs[0], op.inputs[1], @@ -643,9 +666,9 @@ def _MaxPoolGradGradV2(op, grad): @ops.RegisterGradient("MaxPoolGradGrad") def _MaxPoolGradGradGrad(op, grad): return (array_ops.zeros( - shape=array_ops.shape(op.inputs[0]), - dtype=op.inputs[0].dtype), array_ops.zeros( - shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype), + array_ops.zeros( + shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), gen_nn_ops._max_pool_grad( op.inputs[0], op.inputs[1], @@ -674,10 +697,9 @@ def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2): Input backprop for FractionalMaxPool op. """ # pylint: disable=protected-access - return gen_nn_ops._fractional_max_pool_grad(op.inputs[0], op.outputs[0], - grad_0, op.outputs[1], - op.outputs[2], - op.get_attr("overlapping")) + return gen_nn_ops._fractional_max_pool_grad( + op.inputs[0], op.outputs[0], grad_0, op.outputs[1], op.outputs[2], + op.get_attr("overlapping")) @ops.RegisterGradient("FractionalAvgPool") @@ -761,8 +783,9 @@ def _BaseFusedBatchNormGrad(op, use_v2, *grad): epsilon = op.get_attr("epsilon") data_format = op.get_attr("data_format") is_training = op.get_attr("is_training") - grad_fun = (gen_nn_ops.fused_batch_norm_grad_v2 if use_v2 - else gen_nn_ops.fused_batch_norm_grad) + grad_fun = ( + gen_nn_ops.fused_batch_norm_grad_v2 + if use_v2 else gen_nn_ops.fused_batch_norm_grad) if is_training: return grad_fun( grad_y, @@ -786,7 +809,7 @@ def _BaseFusedBatchNormGrad(op, use_v2, *grad): pop_mean, pop_var, epsilon=epsilon, - data_format='NHWC', + data_format="NHWC", is_training=is_training) if data_format == b"NCHW": dx = array_ops.transpose(dx, [0, 3, 1, 2]) @@ -803,18 +826,28 @@ def _FusedBatchNormV2Grad(op, *grad): return _BaseFusedBatchNormGrad(op, True, *grad) -def _BatchNormGrad(grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training=True): +def _BatchNormGrad(grad_y, + x, + scale, + pop_mean, + pop_var, + epsilon, + data_format, + is_training=True): """Returns the gradients for the 3 inputs of BatchNorm. Args: grad_y: A `Tensor` of 4 dimensions for gradient for y. x: A `Tensor` of 4 dimensions for x. scale: A `Tensor` of 1 dimension for scaling. - pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when is_training=False. - pop_var: A `Tensor` of 1 dimension for the population variance. Only used when is_training=False. + pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when + is_training=False. + pop_var: A `Tensor` of 1 dimension for the population variance. Only used + when is_training=False. epsilon: A small float number added to the variance of x. data_format: The data format for input. Either b"NHWC" or b"NCHW". - is_training: A bool value to indicate the operation is for training (default) + is_training: A bool value to indicate the operation is for training + (default) or inference. Returns: @@ -830,27 +863,27 @@ def _BatchNormGrad(grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is grad_y = math_ops.cast(grad_y, dtypes.float32) if is_training: if data_format == b"NHWC": - keep_dims = False + keepdims = False reduce_axis = [0, 1, 2] else: - keep_dims = True + keepdims = True reduce_axis = [0, 2, 3] shape = [1, array_ops.size(scale), 1, 1] scale = array_ops.reshape(scale, shape) - mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keep_dims=keep_dims) - mean_x = math_ops.reduce_mean(x, reduce_axis, keep_dims=keep_dims) + mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims) + mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims) var_x = math_ops.reduce_mean( math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)), reduce_axis, - keep_dims=keep_dims) + keepdims=keepdims) grad_y_offset = grad_y - mean_grad_y x_offset = x - mean_x mean = math_ops.reduce_mean( - grad_y * x_offset, axis=reduce_axis, keep_dims=keep_dims) + grad_y * x_offset, axis=reduce_axis, keepdims=keepdims) grad_x = scale * math_ops.rsqrt(var_x + epsilon) * ( grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset) grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum( - grad_y * x_offset, axis=reduce_axis, keep_dims=keep_dims) + grad_y * x_offset, axis=reduce_axis, keepdims=keepdims) if data_format == b"NCHW": grad_scale = array_ops.squeeze(grad_scale) grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis) @@ -900,7 +933,7 @@ def _FusedBatchNormGradGrad(op, *grad): grad_grad_scale = grad[1] grad_grad_offset = grad[2] grad_x, grad_scale, grad_offset = _BatchNormGrad( - grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training) + grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training) grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset] grad_grad_y, grad_x, grad_scale = gradients_impl.gradients( [grad_x, grad_scale, grad_offset], [grad_y, x, scale], grad_initial) @@ -954,14 +987,15 @@ def _TopKGrad(op, grad, _): # Substitute grad to appropriate locations and fill the rest with zeros, # finally reshaping it to the original input shape. - return [array_ops.reshape( - sparse_ops.sparse_to_dense(ind, - array_ops.reshape( - math_ops.reduce_prod(in_shape), [1]), - array_ops.reshape(grad, [-1]), - validate_indices=False), - in_shape), array_ops.zeros( - [], dtype=dtypes.int32)] + return [ + array_ops.reshape( + sparse_ops.sparse_to_dense( + ind, + array_ops.reshape(math_ops.reduce_prod(in_shape), [1]), + array_ops.reshape(grad, [-1]), + validate_indices=False), in_shape), + array_ops.zeros([], dtype=dtypes.int32) + ] @ops.RegisterGradient("NthElement") @@ -976,18 +1010,16 @@ def _NthElementGrad(op, grad): A list of two tensors, the first being the gradient w.r.t. the input, the second being the gradient w.r.t. the N (None). """ - input = op.inputs[0] + input = op.inputs[0] # pylint: disable=redefined-builtin output = op.outputs[0] # Compute the number of elements which equal to output in each reduction # dimension. If there are multiple elements then the gradient will be # divided between them. indicators = math_ops.cast( - math_ops.equal(array_ops.expand_dims(output, -1), input), - grad.dtype) + math_ops.equal(array_ops.expand_dims(output, -1), input), grad.dtype) grad = array_ops.expand_dims(grad, -1) - num_selected = array_ops.expand_dims( - math_ops.reduce_sum(indicators, -1), -1) + num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1) return [math_ops.div(indicators, num_selected) * grad, None] diff --git a/tensorflow/python/ops/nn_grad_test.py b/tensorflow/python/ops/nn_grad_test.py index f7541c0e892819beaf27ad97d7d41b8f963a4ab9..49d54beb20073162279576e1e1011e10392378e0 100644 --- a/tensorflow/python/ops/nn_grad_test.py +++ b/tensorflow/python/ops/nn_grad_test.py @@ -24,23 +24,26 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import nn_grad +from tensorflow.python.ops import nn_grad # pylint: disable=unused-import from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test class Relu6OpTest(test.TestCase): + def testRelu6GradGrad(self): - inputs = constant_op.constant([[-2, -1, 1, 3], [5, 7, 8, 9]], - dtype=dtypes.float32) + inputs = constant_op.constant( + [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32) x_init_value = np.array([[-3.5, -1.5, 2, 4], [4.5, 7.5, 8.5, 11]]) r = nn_ops.relu6(inputs) r_g = gradients_impl.gradients(r, inputs)[0] with self.test_session(): error = gradient_checker.compute_gradient_error( - inputs, inputs.get_shape().as_list(), - r_g, r_g.get_shape().as_list(), - x_init_value=x_init_value) + inputs, + inputs.get_shape().as_list(), + r_g, + r_g.get_shape().as_list(), + x_init_value=x_init_value) self.assertLess(error, 1e-4) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 55fcd176d62009b9c29afb763dc20daf78cdb5d9..5fa5708114fd5cda6afbca78fa0debf68f0252cc 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import candidate_sampling_ops from tensorflow.python.ops import embedding_ops -from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_array_ops # pylint: disable=unused-import from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 32b14f86b567ce26334c1594e9ac6f00afd5b9d1..6ab839a503adb228b4c79cce4dc34f58ae017fad 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -41,15 +41,19 @@ from tensorflow.python.ops.gen_nn_ops import * from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export - # Aliases for some automatically-generated names. local_response_normalization = gen_nn_ops.lrn # pylint: disable=protected-access -def _non_atrous_convolution(input, filter, padding, data_format=None, # pylint: disable=redefined-builtin - strides=None, name=None): +def _non_atrous_convolution( + input, # pylint: disable=redefined-builtin + filter, # pylint: disable=redefined-builtin + padding, + data_format=None, # pylint: disable=redefined-builtin + strides=None, + name=None): """Computes sums of N-D convolutions (actually cross correlation). It is required that 1 <= N <= 3. @@ -90,16 +94,17 @@ def _non_atrous_convolution(input, filter, padding, data_format=None, # pylint: """ with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope: - input = ops.convert_to_tensor(input, name="input") + input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin input_shape = input.get_shape() - filter = ops.convert_to_tensor(filter, name="filter") + filter = ops.convert_to_tensor(filter, name="filter") # pylint: disable=redefined-builtin filter_shape = filter.get_shape() - op = _NonAtrousConvolution(input_shape, - filter_shape=filter_shape, - padding=padding, - data_format=data_format, - strides=strides, - name=scope) + op = _NonAtrousConvolution( + input_shape, + filter_shape=filter_shape, + padding=padding, + data_format=data_format, + strides=strides, + name=scope) return op(input, filter) @@ -119,11 +124,14 @@ class _NonAtrousConvolution(object): name: see _non_atrous_convolution. """ - def __init__(self, - input_shape, - filter_shape, # pylint: disable=redefined-builtin - padding, data_format=None, - strides=None, name=None): + def __init__( + self, + input_shape, + filter_shape, # pylint: disable=redefined-builtin + padding, + data_format=None, + strides=None, + name=None): filter_shape = filter_shape.with_rank(input_shape.ndims) self.padding = padding self.name = name @@ -137,8 +145,8 @@ class _NonAtrousConvolution(object): if strides is None: strides = [1] * conv_dims elif len(strides) != conv_dims: - raise ValueError("len(strides)=%d, but should be %d" % - (len(strides), conv_dims)) + raise ValueError("len(strides)=%d, but should be %d" % (len(strides), + conv_dims)) if conv_dims == 1: # conv1d uses the 2-d data format names if data_format is None or data_format == "NWC": @@ -177,8 +185,14 @@ class _NonAtrousConvolution(object): # those for gen_nn_ops.conv2d and gen_nn_ops.conv3d. # pylint: disable=redefined-builtin def _conv1d(self, input, filter, strides, padding, data_format, name): - return conv1d(value=input, filters=filter, stride=strides, padding=padding, - data_format=data_format, name=name) + return conv1d( + value=input, + filters=filter, + stride=strides, + padding=padding, + data_format=data_format, + name=name) + # pylint: enable=redefined-builtin def __call__(self, inp, filter): # pylint: disable=redefined-builtin @@ -334,19 +348,20 @@ def with_space_to_batch( ValueError: if `spatial_dims` are invalid. """ - input = ops.convert_to_tensor(input, name="input") + input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin input_shape = input.get_shape() def build_op(num_spatial_dims, padding): return lambda inp, _: op(inp, num_spatial_dims, padding) - new_op = _WithSpaceToBatch(input_shape, - dilation_rate, - padding, - build_op, - filter_shape=filter_shape, - spatial_dims=spatial_dims, - data_format=data_format) + new_op = _WithSpaceToBatch( + input_shape, + dilation_rate, + padding, + build_op, + filter_shape=filter_shape, + spatial_dims=spatial_dims, + data_format=data_format) return new_op(input, None) @@ -377,9 +392,8 @@ class _WithSpaceToBatch(object): spatial_dims=None, data_format=None): """Helper class for _with_space_to_batch.""" - dilation_rate = ops.convert_to_tensor(dilation_rate, - dtypes.int32, - name="dilation_rate") + dilation_rate = ops.convert_to_tensor( + dilation_rate, dtypes.int32, name="dilation_rate") try: rate_shape = dilation_rate.get_shape().with_rank(1) except ValueError: @@ -439,9 +453,7 @@ class _WithSpaceToBatch(object): if const_filter_shape is not None: filter_shape = const_filter_shape self.base_paddings = _with_space_to_batch_base_paddings( - const_filter_shape, - num_spatial_dims, - rate_or_const_rate) + const_filter_shape, num_spatial_dims, rate_or_const_rate) else: self.num_spatial_dims = num_spatial_dims self.rate_or_const_rate = rate_or_const_rate @@ -478,9 +490,7 @@ class _WithSpaceToBatch(object): # shape was not fully defined. filter_shape = array_ops.shape(filter) base_paddings = _with_space_to_batch_base_paddings( - filter_shape, - self.num_spatial_dims, - self.rate_or_const_rate) + filter_shape, self.num_spatial_dims, self.rate_or_const_rate) paddings, crops = array_ops.required_space_to_batch_paddings( input_shape=input_spatial_shape, base_paddings=base_paddings, @@ -491,9 +501,7 @@ class _WithSpaceToBatch(object): paddings = _with_space_to_batch_adjust(paddings, 0, spatial_dims) crops = _with_space_to_batch_adjust(crops, 0, spatial_dims) input_converted = array_ops.space_to_batch_nd( - input=inp, - block_shape=dilation_rate, - paddings=paddings) + input=inp, block_shape=dilation_rate, paddings=paddings) result = self.op(input_converted, filter) @@ -519,17 +527,17 @@ def _with_space_to_batch_base_paddings(filter_shape, num_spatial_dims, # Spatial dimensions of the filters and the upsampled filters in which we # introduce (rate - 1) zeros between consecutive filter values. filter_spatial_shape = filter_shape[:num_spatial_dims] - dilated_filter_spatial_shape = (filter_spatial_shape + - (filter_spatial_shape - 1) * - (rate_or_const_rate - 1)) + dilated_filter_spatial_shape = ( + filter_spatial_shape + (filter_spatial_shape - 1) * + (rate_or_const_rate - 1)) pad_extra_shape = dilated_filter_spatial_shape - 1 # When full_padding_shape is odd, we pad more at end, following the same # convention as conv2d. pad_extra_start = pad_extra_shape // 2 pad_extra_end = pad_extra_shape - pad_extra_start - base_paddings = array_ops.stack([[pad_extra_start[i], pad_extra_end[i]] - for i in range(num_spatial_dims)]) + base_paddings = array_ops.stack( + [[pad_extra_start[i], pad_extra_end[i]] for i in range(num_spatial_dims)]) return base_paddings @@ -623,8 +631,8 @@ def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate): if strides is None: strides = [1] * num_spatial_dims elif len(strides) != num_spatial_dims: - raise ValueError("len(strides)=%d but should be %d" % - (len(strides), num_spatial_dims)) + raise ValueError("len(strides)=%d but should be %d" % (len(strides), + num_spatial_dims)) strides = np.array(strides, dtype=np.int32) if np.any(strides < 1): raise ValueError("all values of strides must be positive") @@ -636,9 +644,14 @@ def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate): @tf_export("nn.convolution") -def convolution(input, filter, # pylint: disable=redefined-builtin - padding, strides=None, dilation_rate=None, - name=None, data_format=None): +def convolution( + input, # pylint: disable=redefined-builtin + filter, # pylint: disable=redefined-builtin + padding, + strides=None, + dilation_rate=None, + name=None, + data_format=None): # pylint: disable=line-too-long """Computes sums of N-D convolutions (actually cross-correlation). @@ -753,16 +766,18 @@ def convolution(input, filter, # pylint: disable=redefined-builtin """ # pylint: enable=line-too-long with ops.name_scope(name, "convolution", [input, filter]) as name: - input = ops.convert_to_tensor(input, name="input") + input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin input_shape = input.get_shape() - filter = ops.convert_to_tensor(filter, name="filter") + filter = ops.convert_to_tensor(filter, name="filter") # pylint: disable=redefined-builtin filter_shape = filter.get_shape() - op = Convolution(input_shape, - filter_shape, - padding, - strides=strides, - dilation_rate=dilation_rate, - name=name, data_format=data_format) + op = Convolution( + input_shape, + filter_shape, + padding, + strides=strides, + dilation_rate=dilation_rate, + name=name, + data_format=data_format) return op(input, filter) @@ -786,8 +801,11 @@ class Convolution(object): def __init__(self, input_shape, filter_shape, - padding, strides=None, dilation_rate=None, - name=None, data_format=None): + padding, + strides=None, + dilation_rate=None, + name=None, + data_format=None): """Helper function for convolution.""" num_total_dims = filter_shape.ndims if num_total_dims is None: @@ -809,17 +827,17 @@ class Convolution(object): if data_format is None or not data_format.startswith("NC"): input_channels_dim = input_shape[num_spatial_dims + 1] - spatial_dims = range(1, num_spatial_dims+1) + spatial_dims = range(1, num_spatial_dims + 1) else: input_channels_dim = input_shape[1] - spatial_dims = range(2, num_spatial_dims+2) + spatial_dims = range(2, num_spatial_dims + 2) - if not input_channels_dim.is_compatible_with(filter_shape[ - num_spatial_dims]): + if not input_channels_dim.is_compatible_with( + filter_shape[num_spatial_dims]): raise ValueError( "number of input channels does not match corresponding dimension of " - "filter, {} != {}".format(input_channels_dim, filter_shape[ - num_spatial_dims])) + "filter, {} != {}".format(input_channels_dim, + filter_shape[num_spatial_dims])) strides, dilation_rate = _get_strides_and_dilation_rate( num_spatial_dims, strides, dilation_rate) @@ -852,14 +870,15 @@ class Convolution(object): @tf_export("nn.pool") -def pool(input, # pylint: disable=redefined-builtin - window_shape, - pooling_type, - padding, - dilation_rate=None, - strides=None, - name=None, - data_format=None): +def pool( + input, # pylint: disable=redefined-builtin + window_shape, + pooling_type, + padding, + dilation_rate=None, + strides=None, + name=None, + data_format=None): # pylint: disable=line-too-long """Performs an N-D pooling operation. @@ -941,9 +960,9 @@ def pool(input, # pylint: disable=redefined-builtin """ # pylint: enable=line-too-long - with ops.name_scope(name, "%s_pool" % - (pooling_type.lower()), [input]) as scope: - input = ops.convert_to_tensor(input, name="input") + with ops.name_scope(name, "%s_pool" % (pooling_type.lower()), + [input]) as scope: + input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin num_spatial_dims = len(window_shape) if num_spatial_dims < 1 or num_spatial_dims > 3: @@ -963,17 +982,18 @@ def pool(input, # pylint: disable=redefined-builtin "strides > window_shape not supported due to inconsistency between " "CPU and GPU implementations") - pooling_ops = {("MAX", 1): max_pool, - ("MAX", 2): max_pool, - ("MAX", 3): max_pool3d, # pylint: disable=undefined-variable - ("AVG", 1): avg_pool, - ("AVG", 2): avg_pool, - ("AVG", 3): avg_pool3d, # pylint: disable=undefined-variable - } + pooling_ops = { + ("MAX", 1): max_pool, + ("MAX", 2): max_pool, + ("MAX", 3): max_pool3d, # pylint: disable=undefined-variable + ("AVG", 1): avg_pool, + ("AVG", 2): avg_pool, + ("AVG", 3): avg_pool3d, # pylint: disable=undefined-variable + } op_key = (pooling_type, num_spatial_dims) if op_key not in pooling_ops: - raise ValueError("%d-D %s pooling is not supported." % - (op_key[1], op_key[0])) + raise ValueError("%d-D %s pooling is not supported." % (op_key[1], + op_key[0])) if data_format is None or not data_format.startswith("NC"): adjusted_window_shape = [1] + list(window_shape) + [1] @@ -1000,12 +1020,13 @@ def pool(input, # pylint: disable=redefined-builtin if num_spatial_dims == 1: converted_input = array_ops.expand_dims(converted_input, spatial_dims[0]) - result = pooling_ops[op_key](converted_input, - adjusted_window_shape, - adjusted_strides, - converted_padding, - name=scope, - **data_format_kwargs) + result = pooling_ops[op_key]( + converted_input, + adjusted_window_shape, + adjusted_strides, + converted_padding, + name=scope, + **data_format_kwargs) if num_spatial_dims == 1: result = array_ops.squeeze(result, [spatial_dims[0]]) return result @@ -1021,7 +1042,9 @@ def pool(input, # pylint: disable=redefined-builtin @tf_export("nn.atrous_conv2d") def atrous_conv2d(value, filters, rate, padding, name=None): - """Atrous convolution (a.k.a. convolution with holes or dilated convolution). + """Atrous convolution (a.k.a. + + convolution with holes or dilated convolution). This function is a simpler wrapper around the more general @{tf.nn.convolution}, and exists only for backwards compatibility. You can @@ -1065,7 +1088,8 @@ def atrous_conv2d(value, filters, rate, padding, name=None): that effectively use atrous convolution in different ways are, among others, [OverFeat: Integrated Recognition, Localization and Detection using Convolutional Networks](http://arxiv.org/abs/1312.6229) and [Fast Image - Scanning with Deep Max-Pooling Convolutional Neural Networks](http://arxiv.org/abs/1302.1700). + Scanning with Deep Max-Pooling Convolutional Neural + Networks](http://arxiv.org/abs/1302.1700). Atrous convolution is also closely related to the so-called noble identities in multi-rate signal processing. @@ -1156,13 +1180,14 @@ def atrous_conv2d(value, filters, rate, padding, name=None): @tf_export("nn.conv2d_transpose") -def conv2d_transpose(value, - filter, # pylint: disable=redefined-builtin - output_shape, - strides, - padding="SAME", - data_format="NHWC", - name=None): +def conv2d_transpose( + value, + filter, # pylint: disable=redefined-builtin + output_shape, + strides, + padding="SAME", + data_format="NHWC", + name=None): """The transpose of `conv2d`. This operation is sometimes called "deconvolution" after [Deconvolutional @@ -1198,7 +1223,7 @@ def conv2d_transpose(value, if data_format not in ("NCHW", "NHWC"): raise ValueError("data_format has to be either NCHW or NHWC.") value = ops.convert_to_tensor(value, name="value") - filter = ops.convert_to_tensor(filter, name="filter") + filter = ops.convert_to_tensor(filter, name="filter") # pylint: disable=redefined-builtin axis = 3 if data_format == "NHWC" else 1 if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[3]): raise ValueError("input channels does not match filter's input channels, " @@ -1207,15 +1232,16 @@ def conv2d_transpose(value, output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape") if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(4)): - raise ValueError("output_shape must have shape (4,), got {}" - .format(output_shape_.get_shape())) + raise ValueError("output_shape must have shape (4,), got {}".format( + output_shape_.get_shape())) if isinstance(output_shape, (list, np.ndarray)): # output_shape's shape should be == [4] if reached this point. if not filter.get_shape()[2].is_compatible_with(output_shape[axis]): raise ValueError( "output_shape does not match filter's output channels, " - "{} != {}".format(output_shape[axis], filter.get_shape()[2])) + "{} != {}".format(output_shape[axis], + filter.get_shape()[2])) if padding != "VALID" and padding != "SAME": raise ValueError("padding must be either VALID or SAME:" @@ -1281,29 +1307,32 @@ def atrous_conv2d_transpose(value, if not value.get_shape()[3].is_compatible_with(filters.get_shape()[3]): raise ValueError( "value's input channels does not match filters' input channels, " - "{} != {}".format(value.get_shape()[3], filters.get_shape()[3])) + "{} != {}".format(value.get_shape()[3], + filters.get_shape()[3])) if rate < 1: raise ValueError("rate {} cannot be less than one".format(rate)) if rate == 1: - return conv2d_transpose(value, - filters, - output_shape, - strides=[1, 1, 1, 1], - padding=padding, - data_format="NHWC") + return conv2d_transpose( + value, + filters, + output_shape, + strides=[1, 1, 1, 1], + padding=padding, + data_format="NHWC") output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape") if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(4)): - raise ValueError("output_shape must have shape (4,), got {}" - .format(output_shape_.get_shape())) + raise ValueError("output_shape must have shape (4,), got {}".format( + output_shape_.get_shape())) if isinstance(output_shape, (list, np.ndarray)): # output_shape's shape should be == [4] if reached this point. if not filters.get_shape()[2].is_compatible_with(output_shape[3]): raise ValueError( "output_shape does not match filter's output channels, " - "{} != {}".format(output_shape[3], filters.get_shape()[2])) + "{} != {}".format(output_shape[3], + filters.get_shape()[2])) # We have two padding contributions. The first is used for converting "SAME" # to "VALID". The second is required so that the height and width of the @@ -1352,14 +1381,13 @@ def atrous_conv2d_transpose(value, # component. space_to_batch_pad = [[0, pad_bottom_extra], [0, pad_right_extra]] - value = array_ops.space_to_batch(input=value, - paddings=space_to_batch_pad, - block_size=rate) + value = array_ops.space_to_batch( + input=value, paddings=space_to_batch_pad, block_size=rate) - input_sizes = [rate * rate * output_shape[0], - (in_height + pad_bottom_extra) // rate, - (in_width + pad_right_extra) // rate, - output_shape[3]] + input_sizes = [ + rate * rate * output_shape[0], (in_height + pad_bottom_extra) // rate, + (in_width + pad_right_extra) // rate, output_shape[3] + ] value = gen_nn_ops.conv2d_backprop_input( input_sizes=input_sizes, @@ -1373,19 +1401,19 @@ def atrous_conv2d_transpose(value, batch_to_space_crop = [[pad_top, pad_bottom + pad_bottom_extra], [pad_left, pad_right + pad_right_extra]] - return array_ops.batch_to_space(input=value, - crops=batch_to_space_crop, - block_size=rate) + return array_ops.batch_to_space( + input=value, crops=batch_to_space_crop, block_size=rate) @tf_export("nn.conv3d_transpose") -def conv3d_transpose(value, - filter, # pylint: disable=redefined-builtin - output_shape, - strides, - padding="SAME", - data_format="NDHWC", - name=None): +def conv3d_transpose( + value, + filter, # pylint: disable=redefined-builtin + output_shape, + strides, + padding="SAME", + data_format="NDHWC", + name=None): """The transpose of `conv3d`. This operation is sometimes called "deconvolution" after [Deconvolutional @@ -1419,7 +1447,7 @@ def conv3d_transpose(value, with ops.name_scope(name, "conv3d_transpose", [value, filter, output_shape]) as name: value = ops.convert_to_tensor(value, name="value") - filter = ops.convert_to_tensor(filter, name="filter") + filter = ops.convert_to_tensor(filter, name="filter") # pylint: disable=redefined-builtin axis = 1 if data_format == "NCDHW" else 4 if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]): raise ValueError("input channels does not match filter's input channels, " @@ -1428,27 +1456,29 @@ def conv3d_transpose(value, output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape") if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(5)): - raise ValueError("output_shape must have shape (5,), got {}" - .format(output_shape_.get_shape())) + raise ValueError("output_shape must have shape (5,), got {}".format( + output_shape_.get_shape())) if isinstance(output_shape, (list, np.ndarray)): # output_shape's shape should be == [5] if reached this point. if not filter.get_shape()[3].is_compatible_with(output_shape[4]): raise ValueError( "output_shape does not match filter's output channels, " - "{} != {}".format(output_shape[4], filter.get_shape()[3])) + "{} != {}".format(output_shape[4], + filter.get_shape()[3])) if padding != "VALID" and padding != "SAME": raise ValueError("padding must be either VALID or SAME:" " {}".format(padding)) - return gen_nn_ops.conv3d_backprop_input_v2(input_sizes=output_shape_, - filter=filter, - out_backprop=value, - strides=strides, - padding=padding, - data_format=data_format, - name=name) + return gen_nn_ops.conv3d_backprop_input_v2( + input_sizes=output_shape_, + filter=filter, + out_backprop=value, + strides=strides, + padding=padding, + data_format=data_format, + name=name) # pylint: disable=protected-access @@ -1514,7 +1544,9 @@ def crelu(features, name=None, axis=-1): Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the *negative* part of the activation. Note that as a result this non-linearity doubles the depth of the activations. - Source: [Understanding and Improving Convolutional Neural Networks via Concatenated Rectified Linear Units. W. Shang, et al.](https://arxiv.org/abs/1603.05201) + Source: [Understanding and Improving Convolutional Neural Networks via + Concatenated Rectified Linear Units. W. Shang, et + al.](https://arxiv.org/abs/1603.05201) Args: features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`, @@ -1534,7 +1566,9 @@ def crelu(features, name=None, axis=-1): @tf_export("nn.relu6") def relu6(features, name=None): """Computes Rectified Linear 6: `min(max(features, 0), 6)`. - Source: [Convolutional Deep Belief Networks on CIFAR-10. A. Krizhevsky](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf) + + Source: [Convolutional Deep Belief Networks on CIFAR-10. A. + Krizhevsky](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf) Args: features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`, @@ -1622,14 +1656,16 @@ def _softmax(logits, compute_op, dim=-1, name=None): InvalidArgumentError: if `logits` is empty or `dim` is beyond the last dimension of `logits`. """ + def _swap_axis(logits, dim_index, last_index, name=None): """Swaps logits's dim_index and last_index.""" - return array_ops.transpose(logits, - array_ops.concat([ - math_ops.range(dim_index), [last_index], - math_ops.range(dim_index + 1, last_index), - [dim_index] - ], 0), name=name) + return array_ops.transpose( + logits, + array_ops.concat([ + math_ops.range(dim_index), [last_index], + math_ops.range(dim_index + 1, last_index), [dim_index] + ], 0), + name=name) logits = ops.convert_to_tensor(logits) @@ -1746,9 +1782,12 @@ def _ensure_xent_args(name, sentinel, labels, logits): @tf_export("nn.softmax_cross_entropy_with_logits_v2") -def softmax_cross_entropy_with_logits_v2(_sentinel=None, # pylint: disable=invalid-name - labels=None, logits=None, - dim=-1, name=None): +def softmax_cross_entropy_with_logits_v2( + _sentinel=None, # pylint: disable=invalid-name + labels=None, + logits=None, + dim=-1, + name=None): """Computes softmax cross entropy between `logits` and `labels`. Measures the probability error in discrete classification tasks in which the @@ -1790,19 +1829,19 @@ def softmax_cross_entropy_with_logits_v2(_sentinel=None, # pylint: disable=inva A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the softmax cross entropy loss. """ - _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, - labels, logits) + _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, labels, + logits) # TODO(pcmurray) Raise an error when the labels do not sum to 1. Note: This # could break users who call this with bad labels, but disregard the bad # results. - with ops.name_scope( - name, "softmax_cross_entropy_with_logits", [logits, labels]) as name: + with ops.name_scope(name, "softmax_cross_entropy_with_logits", + [logits, labels]) as name: logits = ops.convert_to_tensor(logits, name="logits") labels = ops.convert_to_tensor(labels, name="labels") - precise_logits = math_ops.cast(logits, dtypes.float32) if ( - logits.dtype == dtypes.float16) else logits + precise_logits = math_ops.cast( + logits, dtypes.float32) if (logits.dtype == dtypes.float16) else logits # labels and logits must be of the same type labels = math_ops.cast(labels, precise_logits.dtype) input_rank = array_ops.rank(precise_logits) @@ -1811,13 +1850,14 @@ def softmax_cross_entropy_with_logits_v2(_sentinel=None, # pylint: disable=inva # Move the dim to the end if dim is not the last dimension. if dim is not -1: + def _move_dim_to_end(tensor, dim_index, rank): - return array_ops.transpose(tensor, - array_ops.concat([ - math_ops.range(dim_index), - math_ops.range(dim_index + 1, rank), - [dim_index] - ], 0)) + return array_ops.transpose( + tensor, + array_ops.concat([ + math_ops.range(dim_index), + math_ops.range(dim_index + 1, rank), [dim_index] + ], 0)) precise_logits = _move_dim_to_end(precise_logits, dim, input_rank) labels = _move_dim_to_end(labels, dim, input_rank) @@ -1862,9 +1902,12 @@ See tf.nn.softmax_cross_entropy_with_logits_v2. @tf_export("nn.softmax_cross_entropy_with_logits") @deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION) -def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name - labels=None, logits=None, - dim=-1, name=None): +def softmax_cross_entropy_with_logits( + _sentinel=None, # pylint: disable=invalid-name + labels=None, + logits=None, + dim=-1, + name=None): """Computes softmax cross entropy between `logits` and `labels`. Measures the probability error in discrete classification tasks in which the @@ -1906,11 +1949,11 @@ def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the softmax cross entropy loss. """ - _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, - labels, logits) + _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, labels, + logits) - with ops.name_scope( - name, "softmax_cross_entropy_with_logits_sg", [logits, labels]) as name: + with ops.name_scope(name, "softmax_cross_entropy_with_logits_sg", + [logits, labels]) as name: labels = array_ops.stop_gradient(labels, name="labels_stop_gradient") return softmax_cross_entropy_with_logits_v2( @@ -1918,9 +1961,11 @@ def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid @tf_export("nn.sparse_softmax_cross_entropy_with_logits") -def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name - labels=None, logits=None, - name=None): +def sparse_softmax_cross_entropy_with_logits( + _sentinel=None, # pylint: disable=invalid-name + labels=None, + logits=None, + name=None): """Computes sparse softmax cross entropy between `logits` and `labels`. Measures the probability error in discrete classification tasks in which the @@ -1976,15 +2021,15 @@ def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable= [labels, logits]): labels = ops.convert_to_tensor(labels) logits = ops.convert_to_tensor(logits) - precise_logits = math_ops.cast(logits, dtypes.float32) if ( - dtypes.as_dtype(logits.dtype) == dtypes.float16) else logits + precise_logits = math_ops.cast(logits, dtypes.float32) if (dtypes.as_dtype( + logits.dtype) == dtypes.float16) else logits # Store label shape for result later. labels_static_shape = labels.get_shape() labels_shape = array_ops.shape(labels) if logits.get_shape().ndims is not None and logits.get_shape().ndims == 0: - raise ValueError("Logits cannot be scalars - received shape %s." % - logits.get_shape()) + raise ValueError( + "Logits cannot be scalars - received shape %s." % logits.get_shape()) if logits.get_shape().ndims is not None and ( labels_static_shape.ndims is not None and labels_static_shape.ndims != logits.get_shape().ndims - 1): @@ -2041,12 +2086,13 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None): """ with ops.name_scope(name, "AvgPool", [value]) as name: value = ops.convert_to_tensor(value, name="input") - return gen_nn_ops._avg_pool(value, - ksize=ksize, - strides=strides, - padding=padding, - data_format=data_format, - name=name) + return gen_nn_ops._avg_pool( + value, + ksize=ksize, + strides=strides, + padding=padding, + data_format=data_format, + name=name) @tf_export("nn.max_pool") @@ -2083,8 +2129,8 @@ def _calc_conv_flops(graph, node): """Calculates the compute resources needed for Conv2D.""" input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) input_shape.assert_is_fully_defined() - filter_shape = graph_util.tensor_shape_from_node_def_name(graph, - node.input[1]) + filter_shape = graph_util.tensor_shape_from_node_def_name( + graph, node.input[1]) filter_shape.assert_is_fully_defined() output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) output_shape.assert_is_fully_defined() @@ -2092,8 +2138,9 @@ def _calc_conv_flops(graph, node): filter_width = int(filter_shape[1]) filter_in_depth = int(filter_shape[2]) output_count = np.prod(output_shape.as_list()) - return ops.OpStats("flops", (output_count * filter_in_depth * filter_height * - filter_width * 2)) + return ops.OpStats( + "flops", + (output_count * filter_in_depth * filter_height * filter_width * 2)) @ops.RegisterStatistics("DepthwiseConv2dNative", "flops") @@ -2101,8 +2148,8 @@ def _calc_depthwise_conv_flops(graph, node): """Calculates the compute resources needed for DepthwiseConv2dNative.""" input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) input_shape.assert_is_fully_defined() - filter_shape = graph_util.tensor_shape_from_node_def_name(graph, - node.input[1]) + filter_shape = graph_util.tensor_shape_from_node_def_name( + graph, node.input[1]) filter_shape.assert_is_fully_defined() output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) output_shape.assert_is_fully_defined() @@ -2167,6 +2214,30 @@ def xw_plus_b_v1(x, weights, biases, name=None): # pylint: disable=invalid-name mm = math_ops.matmul(x, weights) return bias_add_v1(mm, biases, name=name) +def _get_noise_shape(x, noise_shape): + # If noise_shape is none return immediately. + if noise_shape is None: + return array_ops.shape(x) + + try: + # Best effort to figure out the intended shape. + # If not possible, let the op to handle it. + # In eager mode exception will show up. + noise_shape_ = tensor_shape.as_shape(noise_shape) + except (TypeError, ValueError): + return noise_shape + + if (x.shape.dims is not None and + len(x.shape.dims) == len(noise_shape_.dims)): + new_dims = [] + for i, dim in enumerate(x.shape.dims): + if noise_shape_.dims[i].value is None and dim.value is not None: + new_dims.append(dim.value) + else: + new_dims.append(noise_shape_.dims[i].value) + return tensor_shape.TensorShape(new_dims) + + return noise_shape @tf_export("nn.dropout") def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name @@ -2210,21 +2281,20 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1: raise ValueError("keep_prob must be a scalar tensor or a float in the " "range (0, 1], got %g" % keep_prob) - keep_prob = ops.convert_to_tensor(keep_prob, - dtype=x.dtype, - name="keep_prob") + keep_prob = ops.convert_to_tensor( + keep_prob, dtype=x.dtype, name="keep_prob") keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) # Do nothing if we know keep_prob == 1 if tensor_util.constant_value(keep_prob) == 1: return x - noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x) + noise_shape = _get_noise_shape(x, noise_shape) + # uniform [keep_prob, 1.0 + keep_prob) random_tensor = keep_prob - random_tensor += random_ops.random_uniform(noise_shape, - seed=seed, - dtype=x.dtype) + random_tensor += random_ops.random_uniform( + noise_shape, seed=seed, dtype=x.dtype) # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) binary_tensor = math_ops.floor(random_tensor) ret = math_ops.div(x, keep_prob) * binary_tensor @@ -2234,7 +2304,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di @tf_export("nn.top_k") -def top_k(input, k=1, sorted=True, name=None): +def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-builtin """Finds values and indices of the `k` largest entries for the last dimension. If the input is a vector (rank=1), finds the `k` largest entries in the vector @@ -2263,7 +2333,7 @@ def top_k(input, k=1, sorted=True, name=None): return gen_nn_ops._top_kv2(input, k=k, sorted=sorted, name=name) -def nth_element(input, n, reverse=False, name=None): +def nth_element(input, n, reverse=False, name=None): # pylint: disable=redefined-builtin r"""Finds values of the `n`-th order statistic for the last dmension. If the input is a vector (rank-1), finds the entries which is the nth-smallest @@ -2293,13 +2363,21 @@ def nth_element(input, n, reverse=False, name=None): @tf_export("nn.conv1d") @deprecation.deprecated_arg_values( - None, "`NCHW` for data_format is deprecated, use `NCW` instead", - warn_once=True, data_format="NCHW") + None, + "`NCHW` for data_format is deprecated, use `NCW` instead", + warn_once=True, + data_format="NCHW") @deprecation.deprecated_arg_values( - None, "`NHWC` for data_format is deprecated, use `NWC` instead", - warn_once=True, data_format="NHWC") -def conv1d(value, filters, stride, padding, - use_cudnn_on_gpu=None, data_format=None, + None, + "`NHWC` for data_format is deprecated, use `NWC` instead", + warn_once=True, + data_format="NHWC") +def conv1d(value, + filters, + stride, + padding, + use_cudnn_on_gpu=None, + data_format=None, name=None): r"""Computes a 1-D convolution given 3-D input and filter tensors. @@ -2327,7 +2405,7 @@ def conv1d(value, filters, stride, padding, Args: value: A 3D `Tensor`. Must be of type `float16` or `float32`. - filters: A 3D `Tensor`. Must have the same type as `input`. + filters: A 3D `Tensor`. Must have the same type as `value`. stride: An `integer`. The number of entries by which the filter is moved right at each step. padding: 'SAME' or 'VALID' @@ -2358,9 +2436,13 @@ def conv1d(value, filters, stride, padding, raise ValueError("data_format must be \"NWC\" or \"NCW\".") value = array_ops.expand_dims(value, spatial_start_dim) filters = array_ops.expand_dims(filters, 0) - result = gen_nn_ops.conv2d(value, filters, strides, padding, - use_cudnn_on_gpu=use_cudnn_on_gpu, - data_format=data_format) + result = gen_nn_ops.conv2d( + value, + filters, + strides, + padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + data_format=data_format) return array_ops.squeeze(result, [spatial_start_dim]) @@ -2448,7 +2530,7 @@ def conv1d_transpose( spatial_start_dim = 2 strides = [1, 1, 1, stride] value = array_ops.expand_dims(value, spatial_start_dim) - filter = array_ops.expand_dims(filter, 0) + filter = array_ops.expand_dims(filter, 0) # pylint: disable=redefined-builtin result = gen_nn_ops.conv2d_backprop_input( input_sizes=output_shape_, @@ -2466,8 +2548,8 @@ def _calc_dilation2d_flops(graph, node): """Calculates the compute resources needed for Dilation2D.""" input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) input_shape.assert_is_fully_defined() - filter_shape = graph_util.tensor_shape_from_node_def_name(graph, - node.input[1]) + filter_shape = graph_util.tensor_shape_from_node_def_name( + graph, node.input[1]) filter_shape.assert_is_fully_defined() output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) output_shape.assert_is_fully_defined() @@ -2527,12 +2609,13 @@ def erosion2d(value, kernel, strides, rates, padding, name=None): with ops.name_scope(name, "erosion2d", [value, kernel]) as name: # Reduce erosion to dilation by duality. return math_ops.negative( - gen_nn_ops.dilation2d(input=math_ops.negative(value), - filter=array_ops.reverse_v2(kernel, [0, 1]), - strides=strides, - rates=rates, - padding=padding, - name=name)) + gen_nn_ops.dilation2d( + input=math_ops.negative(value), + filter=array_ops.reverse_v2(kernel, [0, 1]), + strides=strides, + rates=rates, + padding=padding, + name=name)) @tf_export("nn.in_top_k") @@ -2565,5 +2648,5 @@ def in_top_k(predictions, targets, k, name=None): Returns: A `Tensor` of type `bool`. Computed Precision at `k` as a `bool Tensor`. """ - with ops.name_scope(name, 'in_top_k'): + with ops.name_scope(name, "in_top_k"): return gen_nn_ops._in_top_kv2(predictions, targets, k, name=name) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 5a45bdc1e5e1d38a34176ed9443fcd1713f38e1e..21eea3db25af0d1bcfbc7496665f5535c3f660ea 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -383,6 +383,31 @@ class DropoutTest(test_lib.TestCase): x, keep_prob, noise_shape=array_ops.placeholder(dtypes.int32)) self.assertEqual(x.get_shape(), dropout_x.get_shape()) + def testPartialShapedDropout(self): + x_dim = 40 * 30 + y_dim = 3 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + with self.test_session(): + t = constant_op.constant( + 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + # Set noise_shape=[None, 1] which means [x_dim, 1]. + dropout = nn_ops.dropout(t, keep_prob, noise_shape=[None, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + final_count = 0 + for _ in xrange(0, num_iter): + value = dropout.eval() + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + def testInvalidKeepProb(self): x_dim = 40 y_dim = 30 diff --git a/tensorflow/python/ops/quantized_conv_ops_test.py b/tensorflow/python/ops/quantized_conv_ops_test.py index 5e9e71002705293403de83276fb70099d8864907..4ac2a8f634bb201c9aaecb74432f2e6e78ee840f 100644 --- a/tensorflow/python/ops/quantized_conv_ops_test.py +++ b/tensorflow/python/ops/quantized_conv_ops_test.py @@ -93,7 +93,8 @@ class Conv2DTest(test.TestCase): quantized_range = ((quantized_max - quantized_min) * range_adjust) range_scale = (quantized_range / number_of_steps) lowest_quantized = -(1 << (number_of_bits - 1)) - result = np.array([(quantized_min + ((float(x) - lowest_quantized) * range_scale)) + result = np.array([(quantized_min + + ((float(x) - lowest_quantized) * range_scale)) for x in quantized.flatten()]) return result diff --git a/tensorflow/python/ops/quantized_ops_test.py b/tensorflow/python/ops/quantized_ops_test.py index 4bf3b35e13879069e40162fc50180520a5f855f6..d590bc4be6d520cbaa000d9802b84cbfbf8e90b9 100644 --- a/tensorflow/python/ops/quantized_ops_test.py +++ b/tensorflow/python/ops/quantized_ops_test.py @@ -34,7 +34,10 @@ class QuantizedOpsTest(test.TestCase): def testQuantizeOp(self): expected_output = [1, 1, 2, 127, 255, 255] with self.test_session(use_gpu=False) as sess: - x = constant_op.constant([1.0, 1.25, 1.75, 127.0, 255.0, 500.0], shape=[6], dtype=dtypes.float32) + x = constant_op.constant( + [1.0, 1.25, 1.75, 127.0, 255.0, 500.0], + shape=[6], + dtype=dtypes.float32) x_min = 0.0 x_max = 255.0 op = array_ops.quantize(x, x_min, x_max, dtypes.quint8, mode="MIN_FIRST") diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 879c206313e476088b388f39a9a112f5cc449152..09d349fc2db61a09649a801a5d4784522b969d38 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import variables # pylint: disable=wildcard-import from tensorflow.python.ops.gen_resource_variable_ops import * # pylint: enable=wildcard-import +from tensorflow.python.training import checkpointable from tensorflow.python.util import compat @@ -107,6 +108,10 @@ class EagerResourceDeleter(object): """ def __init__(self, handle, handle_device): + if not isinstance(handle, ops.Tensor): + raise ValueError( + ("Passed handle=%s to EagerResourceDeleter. Was expecting a handle " + "Tensor." % (handle,))) self._handle = handle self._handle_device = handle_device @@ -344,15 +349,20 @@ class ResourceVariable(variables.Variable): if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") + if isinstance(initial_value, checkpointable.CheckpointInitialValue): + self._maybe_initialize_checkpointable() + self._update_uid = initial_value.checkpoint_position.restore_uid + initial_value = initial_value.wrapped_value + self._trainable = trainable if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] self._save_slice_info = None - self._in_graph_mode = context.in_graph_mode() - # Save the graph's container prefix for error checking. Reading the value of - # the ResourceVariable from another Graph in Eager mode is an error. - self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access - with ops.control_dependencies(None): + # Store the graph key so optimizers know how to only retrieve variables from + # this graph. + self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + with ops.init_scope(): + self._in_graph_mode = context.in_graph_mode() with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: # pylint: disable=protected-access @@ -503,6 +513,14 @@ class ResourceVariable(variables.Variable): self._initializer_op = g.as_graph_element( ops.prepend_name_scope( variable_def.initializer_name, import_scope=import_scope)) + # Check whether initial_value_name exists for backwards compatibility. + if (hasattr(variable_def, "initial_value_name") and + variable_def.initial_value_name): + self._initial_value = g.as_graph_element( + ops.prepend_name_scope(variable_def.initial_value_name, + import_scope=import_scope)) + else: + self._initial_value = None if variable_def.snapshot_name: self._cached_value = g.as_graph_element( ops.prepend_name_scope( @@ -662,15 +680,7 @@ class ResourceVariable(variables.Variable): Returns: the read operation. - Raises: - ValueError: if the ResourceVariable was created in another isolation - environment or graph. """ - if (not self._in_graph_mode and - self._container_prefix != ops.get_default_graph()._container_prefix): # pylint: disable=protected-access - raise ValueError( - "Attempted to read a variable from another isolation environment" - " or Graph") with ops.name_scope("Read"): # Ensure we read the variable in the same device as the handle. with ops.device(self._handle_device): @@ -707,6 +717,12 @@ class ResourceVariable(variables.Variable): var_def = variable_pb2.VariableDef() var_def.variable_name = ops.strip_name_scope(self.handle.name, export_scope) + if self._initial_value is not None: + # This is inside an if-statement for backwards compatibility, since + # self._initial_value might be None for variables constructed from old + # protos. + var_def.initial_value_name = ops.strip_name_scope( + self._initial_value.name, export_scope) var_def.initializer_name = ops.strip_name_scope(self.initializer.name, export_scope) if self._cached_value is not None: @@ -777,38 +793,38 @@ class ResourceVariable(variables.Variable): # TODO(apassos): this here and below is not atomic. Consider making it # atomic if there's a way to do so without a performance cost for those who # don't need it. - with ops.control_dependencies([ - gen_resource_variable_ops.assign_sub_variable_op( - self.handle, - ops.convert_to_tensor(delta, dtype=self.dtype), - name=name) - ]): - return self.read_value() + return self._lazy_read(gen_resource_variable_ops.assign_sub_variable_op( + self.handle, + ops.convert_to_tensor(delta, dtype=self.dtype), + name=name)) def assign_add(self, delta, use_locking=None, name=None): - with ops.control_dependencies([ - gen_resource_variable_ops.assign_add_variable_op( - self.handle, - ops.convert_to_tensor(delta, dtype=self.dtype), - name=name) - ]): - return self.read_value() + return self._lazy_read(gen_resource_variable_ops.assign_add_variable_op( + self.handle, + ops.convert_to_tensor(delta, dtype=self.dtype), + name=name)) + + def _lazy_read(self, op): + if hasattr(self, "_trainable") and self._trainable: + tape.watch_variable(self) + return _UnreadVariable( + self._handle, self.dtype, self._handle_device, self._shape, + self._in_graph_mode, + self._handle_deleter if not self._in_graph_mode else None, op) def assign(self, value, use_locking=None, name=None): value_tensor = ops.convert_to_tensor(value, dtype=self.dtype) self._shape.assert_is_compatible_with(value_tensor.shape) - with ops.control_dependencies([ + return self._lazy_read( gen_resource_variable_ops.assign_variable_op( self.handle, value_tensor, - name=name) - ]): - return self.read_value() + name=name)) def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask): - with ops.control_dependencies([ + return self._lazy_read( gen_array_ops.resource_strided_slice_assign( ref=self.handle, begin=begin, @@ -820,9 +836,12 @@ class ResourceVariable(variables.Variable): end_mask=end_mask, ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask, - shrink_axis_mask=shrink_axis_mask) - ]): - return self.value() + shrink_axis_mask=shrink_axis_mask)) + + def __int__(self): + if self.dtype != dtypes.int32 and self.dtype != dtypes.int64: + raise TypeError("Non-integer variable can't be converted to integer.") + return int(self.value().numpy()) def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): del name @@ -835,31 +854,106 @@ class ResourceVariable(variables.Variable): return self.value() def __iadd__(self, unused_other): - raise RuntimeError("Variable += value not supported.") + raise RuntimeError("Variable += value not supported. Use " + "variable.assign_add(value) to modify the variable " + "value and variable = variable + value to get a new " + "Tensor object.") def __isub__(self, unused_other): - raise RuntimeError("Variable -= value not supported.") + raise RuntimeError("Variable -= value not supported. Use " + "variable.assign_sub(value) to modify the variable " + "value and variable = variable - value to get a new " + "Tensor object.") def __imul__(self, unused_other): - raise RuntimeError("Variable *= value not supported.") + raise RuntimeError("Variable *= value not supported. Use " + "variable.assign_mul(value) to modify the variable " + "value and variable = variable * value to get a new " + "Tensor object.") def __idiv__(self, unused_other): - raise RuntimeError("Variable /= value not supported.") + raise RuntimeError("Variable /= value not supported. Use " + "variable.assign_div(value) to modify the variable " + "value and variable = variable / value to get a new " + "Tensor object.") def __itruediv__(self, unused_other): - raise RuntimeError("Variable /= value not supported.") + raise RuntimeError("Variable /= value not supported. Use " + "variable.assign_div(value) to modify the variable " + "value and variable = variable / value to get a new " + "Tensor object.") def __irealdiv__(self, unused_other): - raise RuntimeError("Variable /= value not supported.") + raise RuntimeError("Variable /= value not supported. Use " + "variable.assign_div(value) to modify the variable " + "value and variable = variable / value to get a new " + "Tensor object.") def __ipow__(self, unused_other): - raise RuntimeError("Variable **= value not supported.") + raise RuntimeError("Variable **= value not supported. Use " + "value and variable = variable ** value to get a new " + "Tensor object.") def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False): return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access +class _UnreadVariable(ResourceVariable): + """Represents a future for a read of a variable. + + Pretends to be the tensor if anyone looks. + """ + + def __init__(self, handle, dtype, handle_device, # pylint: disable=super-init-not-called + shape, in_graph_mode, deleter, parent_op): + # We do not call super init on purpose. + self._trainable = False + self._save_slice_info = None + self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + self._in_graph_mode = in_graph_mode + self._handle = handle + self._handle_device = handle_device + self._shape = shape + self._initial_value = None + if isinstance(self._handle, ops.EagerTensor): + self._handle_name = "" + else: + self._handle_name = self._handle.name + self._dtype = dtype + self._constraint = None + self._cached_value = None + self._is_initialized_op = None + self._initializer_op = None + self._parent_op = parent_op + if context.in_graph_mode(): + self._graph_element = self.read_value() + else: + self._graph_element = None + self._handle_deleter = deleter + + def value(self): + return self._read_variable_op() + + def read_value(self): + return self._read_variable_op() + + def _read_variable_op(self): + with ops.control_dependencies([self._parent_op]): + return gen_resource_variable_ops.read_variable_op(self._handle, + self._dtype) + + def set_shape(self, shape): + self._shape = shape + + @property + def op(self): + """The op for this variable.""" + return self._parent_op + +ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor) +ops.register_dense_tensor_like_type(_UnreadVariable) + # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. @@ -931,3 +1025,9 @@ ops.register_proto_function( proto_type=variable_pb2.VariableDef, to_proto=_to_proto_fn, from_proto=_from_proto_fn) + + +def is_resource_variable(var): + """"Returns True if `var` is to be considered a ResourceVariable.""" + return isinstance(var, ResourceVariable) or hasattr( + var, "_should_act_as_resource_variable") diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index a1008f1c834f7c01af0ff8b3a0a648f499ce1f8a..aa8d4327d2f0e93768728744d5cce3fed385393f 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -83,8 +83,9 @@ def _best_effort_input_batch_size(flat_input): """Get static input batch size if available, with fallback to the dynamic one. Args: - flat_input: An iterable of time major input Tensors of shape [max_time, - batch_size, ...]. All inputs should have compatible batch sizes. + flat_input: An iterable of time major input Tensors of shape + `[max_time, batch_size, ...]`. + All inputs should have compatible batch sizes. Returns: The batch size in Python integer if available, or a scalar Tensor otherwise. @@ -171,11 +172,11 @@ def _rnn_step( return (final_output, final_state) Args: - time: Python int, the current time step - sequence_length: int32 `Tensor` vector of size [batch_size] - min_sequence_length: int32 `Tensor` scalar, min of sequence_length - max_sequence_length: int32 `Tensor` scalar, max of sequence_length - zero_output: `Tensor` vector of shape [output_size] + time: int32 `Tensor` scalar. + sequence_length: int32 `Tensor` vector of size [batch_size]. + min_sequence_length: int32 `Tensor` scalar, min of sequence_length. + max_sequence_length: int32 `Tensor` scalar, max of sequence_length. + zero_output: `Tensor` vector of shape [output_size]. state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`, or a list/tuple of such tensors. call_cell: lambda returning tuple of (new_output, new_state) where @@ -202,6 +203,9 @@ def _rnn_step( flat_state = nest.flatten(state) flat_zero_output = nest.flatten(zero_output) + # Vector describing which batch entries are finished. + copy_cond = time >= sequence_length + def _copy_one_through(output, new_output): # TensorArray and scalar get passed through. if isinstance(output, tensor_array_ops.TensorArray): @@ -209,7 +213,6 @@ def _rnn_step( if output.shape.ndims == 0: return new_output # Otherwise propagate the old or the new value. - copy_cond = (time >= sequence_length) with ops.colocate_with(new_output): return array_ops.where(copy_cond, output, new_output) @@ -812,7 +815,10 @@ def _dynamic_rnn_loop(cell, return (time + 1, output_ta_t, new_state) if in_graph_mode: - loop_bound = max_sequence_length + # Make sure that we run at least 1 step, if necessary, to ensure + # the TensorArrays pick up the dynamic shape. + loop_bound = math_ops.minimum( + time_steps, math_ops.maximum(1, max_sequence_length)) else: # Using max_sequence_length isn't currently supported in the Eager branch. loop_bound = time_steps @@ -1122,6 +1128,12 @@ def raw_rnn(cell, loop_fn, def _copy_some_through(current, candidate): """Copy some tensors through via array_ops.where.""" def copy_fn(cur_i, cand_i): + # TensorArray and scalar get passed through. + if isinstance(cur_i, tensor_array_ops.TensorArray): + return cand_i + if cur_i.shape.ndims == 0: + return cand_i + # Otherwise propagate the old or the new value. with ops.colocate_with(cand_i): return array_ops.where(elements_finished, cur_i, cand_i) return nest.map_structure(copy_fn, current, candidate) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index f1ac3e9bafa09e4647b4a4263e74fad29b643fd5..923348ea44e18a87e09fe1c0424f0323eb967e3d 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -255,7 +255,7 @@ class RNNCell(base_layer.Layer): return output -class _LayerRNNCell(RNNCell): +class LayerRNNCell(RNNCell): """Subclass of RNNCells that act like proper `tf.Layer` objects. For backwards compatibility purposes, most `RNNCell` instances allow their @@ -297,7 +297,7 @@ class _LayerRNNCell(RNNCell): @tf_export("nn.rnn_cell.BasicRNNCell") -class BasicRNNCell(_LayerRNNCell): +class BasicRNNCell(LayerRNNCell): """The most basic RNN cell. Args: @@ -355,7 +355,7 @@ class BasicRNNCell(_LayerRNNCell): @tf_export("nn.rnn_cell.GRUCell") -class GRUCell(_LayerRNNCell): +class GRUCell(LayerRNNCell): """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078). Args: @@ -473,7 +473,7 @@ class LSTMStateTuple(_LSTMStateTuple): @tf_export("nn.rnn_cell.BasicLSTMCell") -class BasicLSTMCell(_LayerRNNCell): +class BasicLSTMCell(LayerRNNCell): """Basic LSTM recurrent network cell. The implementation is based on: http://arxiv.org/abs/1409.2329. @@ -598,7 +598,7 @@ class BasicLSTMCell(_LayerRNNCell): @tf_export("nn.rnn_cell.LSTMCell") -class LSTMCell(_LayerRNNCell): +class LSTMCell(LayerRNNCell): """Long short-term memory unit (LSTM) recurrent network cell. The default non-peephole implementation is based on: diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 4b5072fd6799ae289d3c1a1b2a40878e36604bf4..6fe2f61016775b410045fefcc8764907b8ea39f3 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -33,6 +33,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_script_ops +from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -50,19 +51,21 @@ class EagerFunc(object): self._func = func self._out_dtypes = Tout - def __call__(self, *args, **kwargs): - """Passes args, kwargs to `self._func`, which is executed eagerly.""" + def __call__(self, on_gpu, args): + """Passes `args` to `self._func`, which is executed eagerly.""" with context.eager_mode(): - ret = self._func(*args, **kwargs) + ret = self._func(*args) + maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu() if isinstance(ret, (tuple, list)): return [ - ops.convert_to_tensor(x, dtype=dtype) + maybe_copy_to_gpu(ops.convert_to_tensor(x, dtype=dtype)) for (x, dtype) in zip(ret, self._out_dtypes) ] elif ret is None: return ret else: - return ops.convert_to_tensor(ret, dtype=self._out_dtypes[0]) + return maybe_copy_to_gpu( + ops.convert_to_tensor(ret, dtype=self._out_dtypes[0])) class FuncRegistry(object): @@ -95,7 +98,7 @@ class FuncRegistry(object): components of a tensor have different lengths. This is bad: ignoring the padding is wrong for text data, and removing the padding is wrong for binary data. To avoid this bug, we redo the conversion using an object dtype. - Additionally, we convert unicode strings to (byte-)strings for Python3 + Additionally, we convert unicode strings to (byte-)strings for compatibility. Args: @@ -109,23 +112,36 @@ class FuncRegistry(object): if result.dtype.char == "S" and result is not value: return np.asarray(value, order="C", dtype=object) elif result.dtype.char == "U" and result is not value: - value = np.vectorize(lambda x: x.encode())(value) + value = np.vectorize(lambda x: x.encode("utf8"))(value) return np.asarray(value, order="C", dtype=object) elif result.dtype.char == "U": return result.astype(np.bytes_) else: return result - def __call__(self, token, args): - """Calls the registered function for `token` with args.""" + def __call__(self, token, on_gpu, args): + """Calls the registered function for `token` with args. + + Args: + token: A key into this `FuncRegistry` identifying which function to call. + on_gpu: A boolean indicating whether or not `token`'s corresponding + operation was placed on GPU; only used if the function registered for + `token` is an `EagerPyFunc`. + args: The arguments to pass to the function registered for `token`. + + Returns: + The output of the function registered for `token`. + + Raises: + ValueError: if no function is registered for `token`. + """ func = self._funcs[token] if func is None: raise ValueError("callback %s is not found" % token) - ret = func(*args) - if isinstance(func, EagerFunc): - return ret + return func(on_gpu, args) else: + ret = func(*args) # Strings seem to lead to a memory leak here if they're not wrapped in a # list. if isinstance(ret, six.binary_type): @@ -161,7 +177,10 @@ class CleanupFunc(object): self._token = token def __del__(self): - _py_funcs.remove(self._token) + if _py_funcs is not None: + # If _py_funcs is None, the program is most likely in shutdown, and the + # _py_funcs object has been destroyed already. + _py_funcs.remove(self._token) def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): @@ -249,7 +268,7 @@ def py_func(func, inp, Tout, stateful=True, name=None): """Wraps a python function and uses it as a TensorFlow op. Given a python function `func`, which takes numpy arrays as its - inputs and returns numpy arrays as its outputs, wrap this function as an + arguments and returns numpy arrays as its outputs, wrap this function as an operation in a TensorFlow graph. The following snippet constructs a simple TensorFlow graph that invokes the `np.sinh()` NumPy function as a operation in the graph: @@ -258,8 +277,8 @@ def py_func(func, inp, Tout, stateful=True, name=None): def my_func(x): # x will be a numpy array with the contents of the placeholder below return np.sinh(x) - inp = tf.placeholder(tf.float32) - y = tf.py_func(my_func, [inp], tf.float32) + input = tf.placeholder(tf.float32) + y = tf.py_func(my_func, [input], tf.float32) ``` **N.B.** The `tf.py_func()` operation has the following known limitations: @@ -275,10 +294,12 @@ def py_func(func, inp, Tout, stateful=True, name=None): server (e.g. using `with tf.device():`). Args: - func: A Python function, which accepts a list of NumPy `ndarray` objects - having element types that match the corresponding `tf.Tensor` objects - in `inp`, and returns a list of `ndarray` objects (or a single `ndarray`) - having element types that match the corresponding values in `Tout`. + func: A Python function, which accepts `ndarray` objects as arguments and + returns a list of `ndarray` objects (or a single `ndarray`). This function + must accept as many arguments as there are tensors in `inp`, and these + argument types will match the corresponding `tf.Tensor` objects + in `inp`. The returns `ndarray`s must match the number and types defined + `Tout`. Important Note: Input and output numpy `ndarray`s of `func` are not guaranteed to be copies. In some cases their underlying memory will be shared with the corresponding TensorFlow tensors. @@ -298,12 +319,15 @@ def py_func(func, inp, Tout, stateful=True, name=None): Returns: A list of `Tensor` or a single `Tensor` which `func` computes. """ + if context.in_eager_mode(): + result = func(*[x.numpy() for x in inp]) + result = nest.flatten(result) + + return [x if x is None else ops.convert_to_tensor(x) for x in result] + return _internal_py_func( func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name) -# TODO(akshayka): PyFuncs where the 'eager' attribute is set to True should be -# differentiable, i.e., the gradient of PyFunc should propagate Nones if the -# eager attribute is not set, and otherwise, it should return the gradient. ops.NotDifferentiable("PyFunc") ops.NotDifferentiable("PyFuncStateless") diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 3224856d7be0674a2cc064a226bf1a38abb6bc2b..0fbbf5a805f1439d85ad53f02bdb665c04248606 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -227,13 +227,14 @@ def sparse_concat(axis, [array_ops.reshape(shape, [1, -1]) for shape in shapes], 0), 0) shapes = [ array_ops.concat([ - max_shape[:axis], shape[-1:] if axis == -1 else - shape[axis:axis + 1], [] if axis == -1 else max_shape[axis + 1:] + max_shape[:axis], shape[-1:] + if axis == -1 else shape[axis:axis + 1], [] + if axis == -1 else max_shape[axis + 1:] ], 0) for shape in shapes ] - output_ind, output_val, output_shape = (gen_sparse_ops._sparse_concat( - inds, vals, shapes, axis, name=name)) + output_ind, output_val, output_shape = ( + gen_sparse_ops._sparse_concat(inds, vals, shapes, axis, name=name)) return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) @@ -300,15 +301,14 @@ def sparse_add(a, b, thresh=0): b = _convert_to_sparse_tensor(b) thresh = ops.convert_to_tensor( thresh, dtype=a.values.dtype.real_dtype.base_dtype, name="thresh") - output_ind, output_val, output_shape = (gen_sparse_ops._sparse_add( - a.indices, a.values, a.dense_shape, - b.indices, b.values, b.dense_shape, - thresh)) + output_ind, output_val, output_shape = ( + gen_sparse_ops._sparse_add(a.indices, a.values, a.dense_shape, + b.indices, b.values, b.dense_shape, thresh)) # Attempt to get output_shape statically. a.get_shape().assert_is_compatible_with(b.get_shape()) - static_shape = array_ops.broadcast_static_shape( - a.get_shape(), b.get_shape()) + static_shape = array_ops.broadcast_static_shape(a.get_shape(), + b.get_shape()) if static_shape.is_fully_defined(): output_shape = static_shape.as_list() @@ -317,8 +317,8 @@ def sparse_add(a, b, thresh=0): # swap to make `a` the SparseTensor. if isinstance(b, sparse_classes): a, b = b, a - return gen_sparse_ops._sparse_tensor_dense_add( - a.indices, a.values, a.dense_shape, b) + return gen_sparse_ops._sparse_tensor_dense_add(a.indices, a.values, + a.dense_shape, b) def _sparse_cross(inputs, name=None): @@ -397,19 +397,25 @@ def _sparse_cross_hashed(inputs, num_buckets=0, hash_key=None, name=None): _DEFAULT_HASH_KEY = 0xDECAFCAFFE -def _sparse_cross_internal( - inputs, hashed_output=False, num_buckets=0, hash_key=None, name=None): +def _sparse_cross_internal(inputs, + hashed_output=False, + num_buckets=0, + hash_key=None, + name=None): """See gen_sparse_ops._sparse_cross.""" if not isinstance(inputs, list): raise TypeError("Inputs must be a list") - if not all(isinstance(i, sparse_tensor.SparseTensor) or - isinstance(i, ops.Tensor) for i in inputs): + if not all( + isinstance(i, sparse_tensor.SparseTensor) or isinstance(i, ops.Tensor) + for i in inputs): raise TypeError("All inputs must be SparseTensors") - sparse_inputs = [i for i in inputs - if isinstance(i, sparse_tensor.SparseTensor)] - dense_inputs = [i for i in inputs - if not isinstance(i, sparse_tensor.SparseTensor)] + sparse_inputs = [ + i for i in inputs if isinstance(i, sparse_tensor.SparseTensor) + ] + dense_inputs = [ + i for i in inputs if not isinstance(i, sparse_tensor.SparseTensor) + ] indices = [sp_input.indices for sp_input in sparse_inputs] values = [sp_input.values for sp_input in sparse_inputs] @@ -504,8 +510,9 @@ def sparse_reorder(sp_input, name=None): """ sp_input = _convert_to_sparse_tensor(sp_input) - reordered_ind, reordered_val = (gen_sparse_ops._sparse_reorder( - sp_input.indices, sp_input.values, sp_input.dense_shape, name=name)) + reordered_ind, reordered_val = ( + gen_sparse_ops._sparse_reorder( + sp_input.indices, sp_input.values, sp_input.dense_shape, name=name)) if sp_input.get_shape().is_fully_defined(): dense_shape = sp_input.get_shape().as_list() @@ -572,8 +579,8 @@ def sparse_reshape(sp_input, shape, name=None): sp_input.indices, sp_input.dense_shape, shape, name=name) reshaped_shape_const = tensor_util.constant_value(shape) - if (reshaped_shape_const is not None - and sp_input.get_shape().is_fully_defined()): + if (reshaped_shape_const is not None and + sp_input.get_shape().is_fully_defined()): num_implied = sum((dim == -1) for dim in reshaped_shape_const) if num_implied > 1: raise ValueError("At most one dimension can be inferred (-1). Found: %s" @@ -589,15 +596,15 @@ def sparse_reshape(sp_input, shape, name=None): in_shape_size // np.prod(non_implied_idx)) reshaped_size = np.prod(reshaped_shape_const) if reshaped_size != in_shape_size: - raise ValueError( - "Cannot reshape a tensor with %d elements to shape %s " - "(%d elements)." - % (in_shape_size, original_reshaped_shape, reshaped_size)) + raise ValueError("Cannot reshape a tensor with %d elements to shape %s " + "(%d elements)." % + (in_shape_size, original_reshaped_shape, + reshaped_size)) reshaped_shape = reshaped_shape_const - return sparse_tensor.SparseTensor( - reshaped_ind, array_ops.identity(sp_input.values), - reshaped_shape) + return sparse_tensor.SparseTensor(reshaped_ind, + array_ops.identity(sp_input.values), + reshaped_shape) # TODO(aselle): Remove keyword required once for 1.0 final @@ -610,8 +617,11 @@ class KeywordRequired(object): @tf_export("sparse_split") def sparse_split(keyword_required=KeywordRequired(), - sp_input=None, num_split=None, axis=None, - name=None, split_dim=None): + sp_input=None, + num_split=None, + axis=None, + name=None, + split_dim=None): """Split a `SparseTensor` into `num_split` tensors along `axis`. If the `sp_input.dense_shape[axis]` is not an integer multiple of `num_split` @@ -660,18 +670,19 @@ def sparse_split(keyword_required=KeywordRequired(), split_dim) sp_input = _convert_to_sparse_tensor(sp_input) - output_inds, output_vals, output_shapes = (gen_sparse_ops._sparse_split( - axis, - sp_input.indices, - sp_input.values, - sp_input.dense_shape, - num_split, - name=name)) + output_inds, output_vals, output_shapes = ( + gen_sparse_ops._sparse_split( + axis, + sp_input.indices, + sp_input.values, + sp_input.dense_shape, + num_split, + name=name)) sparse_tensors = [] for i in range(0, num_split): sparse_tensors.append( - sparse_tensor.SparseTensor( - output_inds[i], output_vals[i], output_shapes[i])) + sparse_tensor.SparseTensor(output_inds[i], output_vals[i], + output_shapes[i])) return sparse_tensors @@ -713,12 +724,15 @@ def sparse_slice(sp_input, start, size, name=None): with ops.name_scope(name, "SparseSlice", [sp_input]) as name: output_indices, output_values, output_shape = gen_sparse_ops.sparse_slice( - sp_input.indices, sp_input.values, sp_input.dense_shape, start, size, name=name) + sp_input.indices, + sp_input.values, + sp_input.dense_shape, + start, + size, + name=name) - return sparse_tensor.SparseTensor( - output_indices, - output_values, - output_shape) + return sparse_tensor.SparseTensor(output_indices, output_values, + output_shape) @tf_export("sparse_to_dense") @@ -819,14 +833,14 @@ def sparse_reduce_max(sp_input, axis=None, keep_dims=False, The reduced Tensor. """ return gen_sparse_ops.sparse_reduce_max( - sp_input.indices, sp_input.values, - sp_input.dense_shape, - math_ops._ReductionDims(sp_input, axis, reduction_axes), - keep_dims) + sp_input.indices, sp_input.values, sp_input.dense_shape, + math_ops._ReductionDims(sp_input, axis, reduction_axes), keep_dims) @tf_export("sparse_reduce_max_sparse") -def sparse_reduce_max_sparse(sp_input, axis=None, keep_dims=False, +def sparse_reduce_max_sparse(sp_input, + axis=None, + keep_dims=False, reduction_axes=None): """Computes the max of elements across dimensions of a SparseTensor. @@ -855,10 +869,8 @@ def sparse_reduce_max_sparse(sp_input, axis=None, keep_dims=False, """ output_ind, output_val, output_shape = ( gen_sparse_ops.sparse_reduce_max_sparse( - sp_input.indices, sp_input.values, - sp_input.dense_shape, math_ops._ReductionDims(sp_input, axis, - reduction_axes), - keep_dims)) + sp_input.indices, sp_input.values, sp_input.dense_shape, + math_ops._ReductionDims(sp_input, axis, reduction_axes), keep_dims)) return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) @@ -905,14 +917,14 @@ def sparse_reduce_sum(sp_input, axis=None, keep_dims=False, The reduced Tensor. """ return gen_sparse_ops.sparse_reduce_sum( - sp_input.indices, sp_input.values, - sp_input.dense_shape, - math_ops._ReductionDims(sp_input, axis, reduction_axes), - keep_dims) + sp_input.indices, sp_input.values, sp_input.dense_shape, + math_ops._ReductionDims(sp_input, axis, reduction_axes), keep_dims) @tf_export("sparse_reduce_sum_sparse") -def sparse_reduce_sum_sparse(sp_input, axis=None, keep_dims=False, +def sparse_reduce_sum_sparse(sp_input, + axis=None, + keep_dims=False, reduction_axes=None): """Computes the sum of elements across dimensions of a SparseTensor. @@ -941,10 +953,8 @@ def sparse_reduce_sum_sparse(sp_input, axis=None, keep_dims=False, """ output_ind, output_val, output_shape = ( gen_sparse_ops.sparse_reduce_sum_sparse( - sp_input.indices, sp_input.values, - sp_input.dense_shape, math_ops._ReductionDims(sp_input, axis, - reduction_axes), - keep_dims)) + sp_input.indices, sp_input.values, sp_input.dense_shape, + math_ops._ReductionDims(sp_input, axis, reduction_axes), keep_dims)) return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) @@ -1053,8 +1063,8 @@ def sparse_to_indicator(sp_input, vocab_size, name=None): with ops.name_scope(name, "SparseToIndicator", [sp_input]) as name: num_entries = array_ops.shape(sp_input.indices)[0] new_values = array_ops.fill(array_ops.expand_dims(num_entries, 0), True) - sp_values = sparse_tensor.SparseTensor( - sp_input.indices, new_values, sp_input.dense_shape) + sp_values = sparse_tensor.SparseTensor(sp_input.indices, new_values, + sp_input.dense_shape) sp_new = sparse_merge(sp_input, sp_values, vocab_size, name) @@ -1174,8 +1184,7 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None, raise TypeError("vocab_size has to be a list of Tensors or Python ints. " "Found %s" % type(vocab_size)) for dim in vocab_size: - if not (isinstance(dim, ops.Tensor) or - isinstance(dim, numbers.Integral)): + if not (isinstance(dim, ops.Tensor) or isinstance(dim, numbers.Integral)): raise TypeError( "vocab_size has to be a list of Tensors or Python ints. Found %s" % type(dim)) @@ -1326,24 +1335,23 @@ def sparse_reset_shape(sp_input, new_shape=None): # error before the sparse_tensor.SparseTensor catches it. output_shape_tensor.get_shape()[0].merge_with(in_shape.get_shape()[0]) - output_shape_tensor_const = tensor_util.constant_value( - output_shape_tensor) + output_shape_tensor_const = tensor_util.constant_value(output_shape_tensor) # For cases where all shapes are known during graph construction - if (output_shape_tensor_const is not None - and sp_input.get_shape().is_fully_defined()): + if (output_shape_tensor_const is not None and + sp_input.get_shape().is_fully_defined()): in_shape_const = np.array(sp_input.get_shape().as_list()) if not np.all(in_shape_const <= output_shape_tensor_const): raise ValueError( "Requested new_shape should have dimension sizes >= sp_input.shape." - " Found new_shape (%s), sp_input.shape (%s)." - % (in_shape_const, output_shape_tensor_const)) + " Found new_shape (%s), sp_input.shape (%s)." % + (in_shape_const, output_shape_tensor_const)) output_shape_tensor = output_shape_tensor_const else: # For cases where shape is not known during graph construction. - output_shape_tensor = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - array_ops.shape(in_shape), array_ops.shape(output_shape_tensor))], - output_shape_tensor) + output_shape_tensor = control_flow_ops.with_dependencies([ + check_ops.assert_equal( + array_ops.shape(in_shape), array_ops.shape(output_shape_tensor)) + ], output_shape_tensor) output_shape_tensor = control_flow_ops.with_dependencies( [check_ops.assert_less_equal(in_shape, output_shape_tensor)], output_shape_tensor) @@ -1409,10 +1417,10 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None): values=sp_input.values, dense_shape=sp_input.dense_shape, default_value=default_value) - return (sparse_tensor.SparseTensor(indices=output_indices, - values=output_values, - dense_shape=sp_input.dense_shape), - empty_row_indicator) + return (sparse_tensor.SparseTensor( + indices=output_indices, + values=output_values, + dense_shape=sp_input.dense_shape), empty_row_indicator) @tf_export("serialize_sparse") @@ -1880,8 +1888,8 @@ def sparse_softmax(sp_input, name=None): [sp_input.indices, sp_input.values]) as name: out_vals = gen_sparse_ops.sparse_softmax(sp_input.indices, sp_input.values, sp_input.dense_shape) - return sparse_tensor.SparseTensor( - sp_input.indices, out_vals, sp_input.dense_shape) + return sparse_tensor.SparseTensor(sp_input.indices, out_vals, + sp_input.dense_shape) @tf_export("sparse_maximum") @@ -1907,9 +1915,9 @@ def sparse_maximum(sp_a, sp_b, name=None): Returns: output: the output SparseTensor. """ - with ops.name_scope(name, "SparseSparseMaximum", [sp_a.indices, sp_a.values, - sp_b.indices, - sp_b.values]) as name: + with ops.name_scope( + name, "SparseSparseMaximum", + [sp_a.indices, sp_a.values, sp_b.indices, sp_b.values]) as name: out_indices, out_values = gen_sparse_ops.sparse_sparse_maximum( sp_a.indices, sp_a.values, @@ -1944,9 +1952,9 @@ def sparse_minimum(sp_a, sp_b, name=None): Returns: output: the output SparseTensor. """ - with ops.name_scope(name, "SparseSparseMinimum", [sp_a.indices, sp_a.values, - sp_b.indices, - sp_b.values]) as name: + with ops.name_scope( + name, "SparseSparseMinimum", + [sp_a.indices, sp_a.values, sp_b.indices, sp_b.values]) as name: out_indices, out_values = gen_sparse_ops.sparse_sparse_minimum( sp_a.indices, sp_a.values, @@ -2010,14 +2018,15 @@ def sparse_transpose(sp_input, perm=None, name=None): dense_shape = sp_input.dense_shape transposed_dense_shape = array_ops.gather(dense_shape, perm) transposed_st = sparse_tensor.SparseTensor( - transposed_indices, sp_input.values, - transposed_dense_shape) + transposed_indices, sp_input.values, transposed_dense_shape) transposed_st = sparse_reorder(transposed_st) return transposed_st -def _add_sparse_to_tensors_map(sp_input, container=None, - shared_name=None, name=None): +def _add_sparse_to_tensors_map(sp_input, + container=None, + shared_name=None, + name=None): """Add a `SparseTensor` to a `SparseTensorsMap` and return its handle. Args: @@ -2038,12 +2047,18 @@ def _add_sparse_to_tensors_map(sp_input, container=None, sp_input = _convert_to_sparse_tensor(sp_input) return gen_sparse_ops._add_sparse_to_tensors_map( - sp_input.indices, sp_input.values, sp_input.dense_shape, - container=container, shared_name=shared_name, name=name) + sp_input.indices, + sp_input.values, + sp_input.dense_shape, + container=container, + shared_name=shared_name, + name=name) -def _add_many_sparse_to_tensors_map(sp_input, container=None, - shared_name=None, name=None): +def _add_many_sparse_to_tensors_map(sp_input, + container=None, + shared_name=None, + name=None): """Add a minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles. The `SparseTensor` must have rank `R` greater than 1, and the first dimension @@ -2072,12 +2087,18 @@ def _add_many_sparse_to_tensors_map(sp_input, container=None, sp_input = _convert_to_sparse_tensor(sp_input) return gen_sparse_ops._add_many_sparse_to_tensors_map( - sp_input.indices, sp_input.values, sp_input.dense_shape, - container=container, shared_name=shared_name, name=name) + sp_input.indices, + sp_input.values, + sp_input.dense_shape, + container=container, + shared_name=shared_name, + name=name) -def _take_many_sparse_from_tensors_map( - sparse_map_op, sparse_handles, rank=None, name=None): +def _take_many_sparse_from_tensors_map(sparse_map_op, + sparse_handles, + rank=None, + name=None): """Read `SparseTensors` from a `SparseTensorsMap` and concatenate them. The input `sparse_handles` must be a string matrix of shape `[N, 1]` where @@ -2140,16 +2161,18 @@ def _take_many_sparse_from_tensors_map( raise TypeError("sparse_map_op be an Operation") if sparse_map_op.type not in ("AddSparseToTensorsMap", "AddManySparseToTensorsMap"): - raise TypeError("sparse_map_op must be one of AddSparseToTensorsMap or " - "AddSparseToTensorsMap. Instead, found `%s`." % - sparse_map_op.type) + raise TypeError( + "sparse_map_op must be one of AddSparseToTensorsMap or " + "AddSparseToTensorsMap. Instead, found `%s`." % sparse_map_op.type) with ops.colocate_with(sparse_map_op): shared_name = sparse_map_op.get_attr("shared_name") or sparse_map_op.name output_indices, output_values, output_shape = ( gen_sparse_ops._take_many_sparse_from_tensors_map( - sparse_handles, dtype=sparse_map_op.get_attr("T"), + sparse_handles, + dtype=sparse_map_op.get_attr("T"), container=sparse_map_op.get_attr("container"), - shared_name=shared_name, name=name)) + shared_name=shared_name, + name=name)) # Feed rank data back in, if available output_indices.set_shape([None, rank]) diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 19900870725f5f01c4ba12979265a5533297d4c3..6d7eaababcd94d687ff20dddc35c68a98320a19b 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -155,27 +155,24 @@ def einsum(equation, *inputs, **kwargs): indices in its subscript, or - the input shapes are inconsistent along a particular axis. """ - name = kwargs.pop("name", None) + name = kwargs.pop('name', None) if kwargs: - raise TypeError("invalid keyword arguments for this function: " + - ", ".join([format(key) - for key in sorted(list(kwargs.keys()))])) - with ops.name_scope(name, "einsum", [equation, inputs]) as name: + raise TypeError('invalid keyword arguments for this function: ' + ', '.join( + [format(key) for key in sorted(list(kwargs.keys()))])) + with ops.name_scope(name, 'einsum', [equation, inputs]) as name: if '...' in equation: raise ValueError('Subscripts with ellipses are not yet supported.') match = re.match('([a-z,]+)(->[a-z]*)?', equation) if not match: - raise ValueError( - 'Indices have incorrect format: %s' % equation - ) + raise ValueError('Indices have incorrect format: %s' % equation) inputs = list(inputs) input_axis_labels = match.group(1).split(',') if len(inputs) != len(input_axis_labels): - raise ValueError('Got %d arguments for equation "%s", expecting %d' % ( - len(inputs), equation, len(input_axis_labels))) + raise ValueError('Got %d arguments for equation "%s", expecting %d' % + (len(inputs), equation, len(input_axis_labels))) axis_labels = set(''.join(input_axis_labels)) if match.group(2): @@ -188,37 +185,36 @@ def einsum(equation, *inputs, **kwargs): for ax in axes_: counts[ax] += 1 - output_axis_labels = ''.join(sorted( - ax for ax in indices - if counts[ax] == 1 - )) + output_axis_labels = ''.join( + sorted(ax for ax in indices if counts[ax] == 1)) for a in axis_labels: input_count = sum(1 for s in input_axis_labels if a in s) if input_count > 2 and a not in output_axis_labels: logging.warn( - 'Falling back to exponential-space implementation of einsum() because' - ' index "%s" is summed over more than two inputs.', a) + 'Falling back to exponential-space implementation of einsum()' + ' because index "%s" is summed over more than two inputs.', a) return _exponential_space_einsum(equation, *inputs) temp = inputs[0] temp_axis_labels = input_axis_labels[0] - for i in xrange(len(inputs)-1): - axes_to_sum = (set(temp_axis_labels) & set(input_axis_labels[i+1]) - - set(output_axis_labels)) - temp, temp_axis_labels = _einsum_reduction(temp, - temp_axis_labels, - inputs[i+1], - input_axis_labels[i+1], - axes_to_sum) + for i in xrange(len(inputs) - 1): + axes_to_sum = ( + set(temp_axis_labels) & + set(input_axis_labels[i + 1]) - set(output_axis_labels)) + temp, temp_axis_labels = _einsum_reduction( + temp, temp_axis_labels, inputs[i + 1], input_axis_labels[i + 1], + axes_to_sum) missing_indices = set(temp_axis_labels) - set(output_axis_labels) if missing_indices: - reduction_indices = [i for i, a in enumerate(temp_axis_labels) - if a not in output_axis_labels] + reduction_indices = [ + i for i, a in enumerate(temp_axis_labels) + if a not in output_axis_labels + ] temp = math_ops.reduce_sum(temp, reduction_indices=reduction_indices) - temp_axis_labels = ''.join(a for a in temp_axis_labels - if a in output_axis_labels) + temp_axis_labels = ''.join( + a for a in temp_axis_labels if a in output_axis_labels) if sorted(temp_axis_labels) != sorted(output_axis_labels): raise ValueError('Invalid equation: %s' % equation) @@ -296,8 +292,10 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum): return (1, a) axis_labels = [t0_axis_labels, t1_axis_labels] - sorted_axes = [sorted(sym_list, key=lambda a: sort_key(i, a)) - for i, sym_list in enumerate(axis_labels)] + sorted_axes = [ + sorted(sym_list, key=lambda a: sort_key(i, a)) + for i, sym_list in enumerate(axis_labels) + ] inputs = [t0, t1] for i, axes_str in enumerate(axis_labels): perm = [axes_str.find(a) for a in sorted_axes[i]] @@ -325,30 +323,30 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum): num_broadcast_elements_t0 = _total_size( t0_shape[len(preserved_axes):-len(axes_to_sum)]) num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):]) - new_shape = (t0_shape[:len(preserved_axes)] - + [num_broadcast_elements_t0, num_summed_elements]) + new_shape = ( + t0_shape[:len(preserved_axes)] + + [num_broadcast_elements_t0, num_summed_elements]) t0 = _reshape_if_necessary(t0, new_shape) t1_shape = _get_shape(t1) num_broadcast_elements_t1 = _total_size( - t1_shape[len(preserved_axes)+len(axes_to_sum):]) - new_shape = (t1_shape[:len(preserved_axes)] - + [num_summed_elements, num_broadcast_elements_t1]) + t1_shape[len(preserved_axes) + len(axes_to_sum):]) + new_shape = ( + t1_shape[:len(preserved_axes)] + + [num_summed_elements, num_broadcast_elements_t1]) t1 = _reshape_if_necessary(t1, new_shape) product = math_ops.matmul(t0, t1) # Undo compaction of broadcast axes uncompacted_shape = ( - t0_shape[:len(preserved_axes)+len(broadcast_axes[0])] - + t1_shape[len(t1_shape)-len(broadcast_axes[1]):] - ) + t0_shape[:len(preserved_axes) + len(broadcast_axes[0])] + + t1_shape[len(t1_shape) - len(broadcast_axes[1]):]) product = _reshape_if_necessary(product, uncompacted_shape) product_axes = ( - sorted_axes[0][:len(preserved_axes)+len(broadcast_axes[0])] + - sorted_axes[1][len(sorted_axes[1])-len(broadcast_axes[1]):] - ) + sorted_axes[0][:len(preserved_axes) + len(broadcast_axes[0])] + + sorted_axes[1][len(sorted_axes[1]) - len(broadcast_axes[1]):]) return product, ''.join(product_axes) @@ -402,13 +400,11 @@ def _total_size(shape_values): def _exponential_space_einsum(equation, *inputs): """Fallback implementation that supports summing an index over > 2 inputs.""" if '...' in equation: - raise ValueError("Subscripts with ellipses are not yet supported.") + raise ValueError('Subscripts with ellipses are not yet supported.') match = re.match('([a-z,]+)(->[a-z]*)?', equation) if not match: - raise ValueError( - 'Indices have incorrect format: %s' % equation - ) + raise ValueError('Indices have incorrect format: %s' % equation) inputs = list(inputs) idx_in = match.group(1).split(',') @@ -425,21 +421,15 @@ def _exponential_space_einsum(equation, *inputs): for ax in axes_: counts[ax] += 1 - idx_out = ''.join(sorted( - ax for ax in indices - if counts[ax] == 1 - )) + idx_out = ''.join(sorted(ax for ax in indices if counts[ax] == 1)) if len(idx_in) != len(inputs): - raise ValueError( - 'Expected %d inputs but got %d' % (len(idx_in), len(inputs)) - ) + raise ValueError('Expected %d inputs but got %d' % (len(idx_in), + len(inputs))) missing_idx = set(idx_out).difference(idx_all) if missing_idx: - raise ValueError( - 'Unknown output axes: %s' % missing_idx - ) + raise ValueError('Unknown output axes: %s' % missing_idx) axis_order = {} for ax in indices: @@ -452,18 +442,17 @@ def _exponential_space_einsum(equation, *inputs): for i, (input_, axes_) in enumerate(zip(inputs, idx_in)): if input_.get_shape().ndims != len(axes_): raise ValueError( - 'Input %d with axes %s has incorrect' \ - ' number of dimensions (expected %d, got %d)' % ( - i, axes_, len(axes_), input_.get_shape().ndims - ) + 'Input %d with axes %s has incorrect' \ + ' number of dimensions (expected %d, got %d)' % ( + i, axes_, len(axes_), input_.get_shape().ndims + ) ) sorted_idx = sorted(axes_, key=axis_order.get) if len(set(axes_)) != len(axes_): raise ValueError( - 'Subscript not supported: an axis appears more than once: %s' % axes_ - ) + 'Subscript not supported: an axis appears more than once: %s' % axes_) if list(axes_) != sorted_idx: permuted = [axes_.find(ax) for ax in sorted_idx] @@ -487,16 +476,15 @@ def _exponential_space_einsum(equation, *inputs): dims.append(dim) if len(set(dims)) > 1: - raise ValueError( - 'Dimension mismatch on axis: %s' % ax - ) + raise ValueError('Dimension mismatch on axis: %s' % ax) if ax not in idx_out: reduction_idx.append(j) # reshape, multiply - expanded_inputs = [array_ops.reshape(input_, shape) - for input_, shape in zip(inputs, shapes)] + expanded_inputs = [ + array_ops.reshape(input_, shape) for input_, shape in zip(inputs, shapes) + ] expanded_output = 1 for input_ in expanded_inputs: expanded_output *= input_ diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index c1a66717d86dd8278dbe676f1714d226351c245f..2c212f45483eacfd3fd27eecb8d7b2c846b5fe96 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -39,8 +39,9 @@ class LBetaTest(test.TestCase): x_one_half = [2, 1.] with self.test_session(use_gpu=True): self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_one)).eval()) - self.assertAllClose( - 0.5, math_ops.exp(special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose(0.5, + math_ops.exp( + special_math_ops.lbeta(x_one_half)).eval()) self.assertEqual([], special_math_ops.lbeta(x_one).get_shape()) def test_one_dimensional_arg_dynamic(self): @@ -70,8 +71,9 @@ class LBetaTest(test.TestCase): # Should evaluate to 1/2. x_one_half = [[2, 1.], [2, 1.]] with self.test_session(use_gpu=True): - self.assertAllClose( - [0.5, 0.5], math_ops.exp(special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose([0.5, 0.5], + math_ops.exp( + special_math_ops.lbeta(x_one_half)).eval()) self.assertEqual((2,), special_math_ops.lbeta(x_one_half).get_shape()) def test_two_dimensional_arg_dynamic(self): @@ -86,10 +88,12 @@ class LBetaTest(test.TestCase): # Should evaluate to 1/2. x_one_half = [[2, 1.], [2, 1.]] with self.test_session(use_gpu=True): - self.assertAllClose( - [0.5, 0.5], math_ops.exp(special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose([0.5, 0.5], + math_ops.exp( + special_math_ops.lbeta(x_one_half)).eval()) self.assertEqual( - (2,), array_ops.shape(special_math_ops.lbeta(x_one_half)).eval()) + (2,), + array_ops.shape(special_math_ops.lbeta(x_one_half)).eval()) self.assertEqual( tensor_shape.TensorShape([2]), special_math_ops.lbeta(x_one_half).get_shape()) @@ -97,8 +101,8 @@ class LBetaTest(test.TestCase): def test_complicated_shape(self): with self.test_session(use_gpu=True): x = ops.convert_to_tensor(np.random.rand(3, 2, 2)) - self.assertAllEqual( - (3, 2), array_ops.shape(special_math_ops.lbeta(x)).eval()) + self.assertAllEqual((3, 2), + array_ops.shape(special_math_ops.lbeta(x)).eval()) self.assertEqual( tensor_shape.TensorShape([3, 2]), special_math_ops.lbeta(x).get_shape()) @@ -155,7 +159,6 @@ class EinsumTest(test.TestCase): 'ijk->i', 'ijk->kji', 'ji,kj->ik', - 'ikl,kji->kl', 'klj,lki->ij', 'ijk,ilj->kli', @@ -164,7 +167,6 @@ class EinsumTest(test.TestCase): 'i,ijk,j->k', 'ij,ij,jk,kl->il', 'ij,kj,il,jm->ml', - 'a,ab,abc->abc', 'a,b,ab->ab', 'ab,ab,c->', @@ -173,25 +175,21 @@ class EinsumTest(test.TestCase): 'ab,ab,cd,cd->ac', 'ab,ab,cd,cd->cd', 'ab,ab,cd,cd,ef,ef->', - 'ab,cd,ef->abcdef', 'ab,cd,ef->acdf', 'ab,cd,de->abcde', 'ab,cd,de->be', 'ab,bcd,cd->abcd', 'ab,bcd,cd->abd', - 'eb,cb,fb->cef', 'abcd,ad', 'bd,db,eac->ace', 'ba,ac,da->bcd', - 'ab,ab', 'ab,ba', 'abc,abc', 'abc,bac', 'abc,cba', - 'dba,ead,cad->bce', 'aef,fbc,dca->bde', ] @@ -234,10 +232,8 @@ class EinsumTest(test.TestCase): def test_invalid(self): for axes in self.invalid_cases: inputs = [ - array_ops.placeholder( - dtypes.float32, shape=(3, 4)), - array_ops.placeholder( - dtypes.float32, shape=(3, 4)), + array_ops.placeholder(dtypes.float32, shape=(3, 4)), + array_ops.placeholder(dtypes.float32, shape=(3, 4)), ] with self.assertRaises(ValueError): _ = special_math_ops.einsum(axes, *inputs) @@ -245,16 +241,22 @@ class EinsumTest(test.TestCase): def test_invalid_keyword_arguments(self): m0 = array_ops.placeholder(dtypes.int32, shape=(1, None)) m1 = array_ops.placeholder(dtypes.int32, shape=(None, 1)) - with self.assertRaisesRegexp(TypeError, + with self.assertRaisesRegexp( + TypeError, 'invalid keyword arguments for this function: invalid1, invalid2'): - _ = special_math_ops.einsum('ij,jk->ik', m0, m1, name="name", - invalid1="value1", invalid2="value2") + _ = special_math_ops.einsum( + 'ij,jk->ik', + m0, + m1, + name='name', + invalid1='value1', + invalid2='value2') def test_dim_mismatch(self): for axes, input_shapes in self.dim_mismatch_cases: inputs = [ - array_ops.placeholder( - dtypes.float32, shape=shape) for shape in input_shapes + array_ops.placeholder(dtypes.float32, shape=shape) + for shape in input_shapes ] with self.assertRaises(ValueError): _ = special_math_ops.einsum(axes, *inputs) @@ -291,8 +293,8 @@ class EinsumTest(test.TestCase): m0: [[1, 2, 3]], m1: [[2], [1], [1]], } - np.testing.assert_almost_equal( - [[7]], sess.run(out, feed_dict=feed_dict)) + np.testing.assert_almost_equal([[7]], sess.run( + out, feed_dict=feed_dict)) with ops.Graph().as_default(): m0 = array_ops.placeholder(dtypes.int32, shape=(None, 3)) @@ -312,11 +314,11 @@ class EinsumTest(test.TestCase): out = special_math_ops.einsum('ijk,kl->ijl', m0, m1) with session.Session() as sess: feed_dict = { - m0: [[[1,2]]], + m0: [[[1, 2]]], m1: [[3], [2]], } - np.testing.assert_almost_equal( - [[[7]]], sess.run(out, feed_dict=feed_dict)) + np.testing.assert_almost_equal([[[7]]], + sess.run(out, feed_dict=feed_dict)) with ops.Graph().as_default(): m0 = array_ops.placeholder(dtypes.int32, shape=(2, 1)) @@ -325,10 +327,10 @@ class EinsumTest(test.TestCase): with session.Session() as sess: feed_dict = { m0: [[3], [2]], - m1: [[[1,2]]], + m1: [[[1, 2]]], } - np.testing.assert_almost_equal( - [[[7]]], sess.run(out, feed_dict=feed_dict)) + np.testing.assert_almost_equal([[[7]]], + sess.run(out, feed_dict=feed_dict)) with ops.Graph().as_default(): m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2)) @@ -339,8 +341,8 @@ class EinsumTest(test.TestCase): m0: [[[1, 2]]], m1: [3, 2], } - np.testing.assert_almost_equal( - [[7]], sess.run(out, feed_dict=feed_dict)) + np.testing.assert_almost_equal([[7]], sess.run( + out, feed_dict=feed_dict)) with ops.Graph().as_default(): m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2)) @@ -351,8 +353,8 @@ class EinsumTest(test.TestCase): m0: [[[[1, 2]], [[2, 1]]]], m1: [[3, 2]], } - np.testing.assert_almost_equal( - [[[7, 8]]], sess.run(out, feed_dict=feed_dict)) + np.testing.assert_almost_equal([[[7, 8]]], + sess.run(out, feed_dict=feed_dict)) if __name__ == '__main__': diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 30bf4e4ef1b96ea68e9020621f37551ac619a3c2..f6d9111009dc4f6a58ac81e7071ed7fe406600fa 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -25,7 +25,9 @@ import sys as _sys # Imports the following modules so that @RegisterGradient get executed. from tensorflow.python.ops import array_grad from tensorflow.python.ops import data_flow_grad +from tensorflow.python.ops import manip_grad from tensorflow.python.ops import math_grad +from tensorflow.python.ops import manip_grad from tensorflow.python.ops import sparse_grad from tensorflow.python.ops import spectral_grad from tensorflow.python.ops import state_grad @@ -42,11 +44,12 @@ from tensorflow.python.ops.special_math_ops import * # TODO(vrv): Switch to import * once we're okay with exposing the module. from tensorflow.python.ops.confusion_matrix import confusion_matrix from tensorflow.python.ops.control_flow_ops import Assert +from tensorflow.python.ops.control_flow_ops import case +from tensorflow.python.ops.control_flow_ops import cond from tensorflow.python.ops.control_flow_ops import group from tensorflow.python.ops.control_flow_ops import no_op -from tensorflow.python.ops.control_flow_ops import tuple -from tensorflow.python.ops.control_flow_ops import cond -from tensorflow.python.ops.control_flow_ops import case +from tensorflow.python.ops.control_flow_ops import tuple # pylint: disable=redefined-builtin +# pylint: enable=redefined-builtin from tensorflow.python.ops.control_flow_ops import while_loop from tensorflow.python.ops.data_flow_ops import * from tensorflow.python.ops.functional_ops import * @@ -59,6 +62,7 @@ from tensorflow.python.ops.logging_ops import Print from tensorflow.python.ops.logging_ops import get_summary_op from tensorflow.python.ops.lookup_ops import initialize_all_tables from tensorflow.python.ops.lookup_ops import tables_initializer +from tensorflow.python.ops.manip_ops import * from tensorflow.python.ops.math_ops import * from tensorflow.python.ops.numerics import * from tensorflow.python.ops.parsing_ops import * @@ -105,6 +109,7 @@ from tensorflow.python.ops import init_ops as _init_ops from tensorflow.python.ops import io_ops as _io_ops from tensorflow.python.ops import linalg_ops as _linalg_ops from tensorflow.python.ops import logging_ops as _logging_ops +from tensorflow.python.ops import manip_ops as _manip_ops from tensorflow.python.ops import math_ops as _math_ops from tensorflow.python.ops import numerics as _numerics from tensorflow.python.ops import parsing_ops as _parsing_ops @@ -264,34 +269,36 @@ _allowed_symbols = (_allowed_symbols_array_ops + _allowed_symbols_misc + _allowed_symbols_partitioned_variables) -remove_undocumented(__name__, _allowed_symbols, - [_sys.modules[__name__], - _array_ops, - _check_ops, - _clip_ops, - _confusion_matrix, - _control_flow_ops, - _constant_op, - _data_flow_ops, - _functional_ops, - _gradients, - _histogram_ops, - _init_ops, - _io_ops, - _linalg_ops, - _logging_ops, - _math_ops, - _numerics, - _parsing_ops, - _partitioned_variables, - _random_ops, - _script_ops, - _session_ops, - _sparse_ops, - _special_math_ops, - _state_ops, - _string_ops, - _template, - _tensor_array_ops, - _variable_scope, - _variables,]) +remove_undocumented(__name__, _allowed_symbols, [ + _sys.modules[__name__], + _array_ops, + _check_ops, + _clip_ops, + _confusion_matrix, + _control_flow_ops, + _constant_op, + _data_flow_ops, + _functional_ops, + _gradients, + _histogram_ops, + _init_ops, + _io_ops, + _linalg_ops, + _logging_ops, + _manip_ops, + _math_ops, + _numerics, + _parsing_ops, + _partitioned_variables, + _random_ops, + _script_ops, + _session_ops, + _sparse_ops, + _special_math_ops, + _state_ops, + _string_ops, + _template, + _tensor_array_ops, + _variable_scope, + _variables, +]) diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 3cc76fdbf34ff6de47d98400cd826d671c9178eb..6c0a090d16bb328de40f02edf9865a0e0a62d385 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -278,7 +278,7 @@ def assign(ref, value, validate_shape=None, use_locking=None, name=None): return gen_state_ops.assign( ref, value, use_locking=use_locking, name=name, validate_shape=validate_shape) - return ref.assign(value) + return ref.assign(value, name=name) @tf_export("count_up_to") @@ -353,11 +353,9 @@ def scatter_update(ref, indices, updates, use_locking=True, name=None): if ref.dtype._is_ref_dtype: return gen_state_ops.scatter_update(ref, indices, updates, use_locking=use_locking, name=name) - with ops.control_dependencies( - [gen_resource_variable_ops.resource_scatter_update( - ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), - name=name)]): - return ref.read_value() + return ref._lazy_read(gen_resource_variable_ops.resource_scatter_update( # pylint: disable=protected-access + ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), + name=name)) @tf_export("scatter_nd_update") diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 84449e00beb4d2901f57c7cd41a4e755fe343c8c..424582b348d87d8a5b043ec9b771d8f2768a5994 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -140,7 +140,7 @@ def make_template(name_, func_, create_scope_now_=False, unique_name_=None, re-enter the scope and reuse those variables. Raises: - ValueError: if the name is None. + ValueError: if `name_` is None. """ return make_template_internal( name_, @@ -176,16 +176,14 @@ def make_template_internal(name_, custom_getter_: Optional custom getter for variables used in `func_`. See the @{tf.get_variable} `custom_getter` documentation for more information. - create_graph_function_: When True, the first invocation of the template will - execute `func_` as is, to allow for variable creation; however, the second - invocation and every invocation thereafter will execute func as a graph - function. In particular, this implies that `func_` must satisfy the - properties that `function.defun` requires of functions: See the - documentation of `function.defun` for details. When executing eagerly, - setting this flag to True can improve performance. Regardless of whether - eager execution is enabled, enabling this flag gives the caller access to - graph-function semantics, i.e., accesses to variables are totally ordered - and side-effecting ops are not pruned. + create_graph_function_: When True, `func_` will be executed as a graph + function. This implies that `func_` must satisfy the properties that + `function.defun` requires of functions: See the documentation of + `function.defun` for details. When executing eagerly, setting this flag to + True can improve performance. Regardless of whether eager execution is + enabled, enabling this flag gives the caller access to graph-function + semantics, i.e., accesses to variables are totally ordered and + side-effecting ops are not pruned. **kwargs: Keyword arguments to apply to `func_`. Returns: @@ -198,8 +196,8 @@ def make_template_internal(name_, re-enter the scope and reuse those variables. Raises: - ValueError: if the name is None. - ValueError: if unique_name_ is not None and eager execution is enabled. + ValueError: if `name_` is None. + ValueError: if `unique_name_` is not None and eager execution is enabled. """ if kwargs: @@ -266,18 +264,18 @@ class Template(object): template of the same scope/unique_name already exists and reuse is false, an error is raised. Defaults to None. custom_getter: optional custom getter to pass to `variable_scope()` - create_graph_function: When True, the first invocation of the template - will execute `func` as is, to allow for variable creation; however, the - second invocation and every invocation thereafter will execute `func` as - a graph function. Enabling this flag gives the caller access to - graph-function semantics, i.e., accesses to variables are totally - ordered and side-effecting ops are not pruned. - + create_graph_function: When True, `func` will be executed as a graph + function. Enabling this flag gives the caller access to graph-function + semantics, i.e., accesses to variables are totally ordered and + side-effecting ops are not pruned. Raises: - ValueError: if the name is None. + ValueError: if `name` is None. """ - self._func = func + if create_graph_function: + self._func = function.defun(func) + else: + self._func = func self._stacktrace = traceback.format_stack()[:-2] self._name = name self._unique_name = unique_name @@ -295,19 +293,13 @@ class Template(object): # This variable keeps track of whether the template has been called yet, # which is not the same as whether the scope has been created. self._variables_created = False - self._create_graph_function = create_graph_function def _call_func(self, args, kwargs): try: vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) trainable_at_start = len( ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) - result = self._func(*args, **kwargs) - if self._create_graph_function and not self._variables_created: - # Only execute self._func as a graph function once variables are - # created. - self._func = function.defun(self._func) if self._variables_created: # Variables were previously created, implying this is not the first @@ -542,14 +534,11 @@ class EagerTemplate(Template): names of all created Tensors. If set to False, the scope will be created at the first call location. custom_getter: optional custom getter to pass to `variable_scope()` - create_graph_function: When True, the first invocation of the template - will execute `func` as is, to allow for variable creation; however, the - second invocation and every invocation thereafter will execute `func` as - a graph function. Enabling this flag allows the caller to reap the - performance benefits associated with executing graphs, at the cost of - sacrificing debuggability; however, not all functions can be compiled - into graph functions. See the documentation for `function.defun` for - details. + create_graph_function: When True, `func` will be executed as a graph + function. Enabling this flag allows the caller to reap the performance + benefits associated with executing graphs, at the cost of sacrificing + debuggability; however, not all Python functions can be compiled into + graph functions. See the documentation for `function.defun` for details. Raises: RuntimeError: if eager execution is not enabled. @@ -568,17 +557,13 @@ class EagerTemplate(Template): # is created in __call__. variable_scope_name = None self._template_store = _EagerTemplateVariableStore(variable_scope_name) + self._variable_scope_context_manager = None def _call_func(self, args, kwargs): try: vars_at_start = self._template_store.variables() trainable_at_start = self._template_store.trainable_variables() - result = self._func(*args, **kwargs) - if self._create_graph_function and not self._variables_created: - # Only execute self._func as a graph function once variables are - # created. - self._func = function.defun(self._func) if self._variables_created: # Variables were previously created, implying this is not the first @@ -627,8 +612,12 @@ class EagerTemplate(Template): # the variable scope is opened in order to ensure that templates nested at # the same level correctly uniquify lower variable scope names. if self._variable_scope: - with variable_scope.variable_scope( - self._variable_scope, reuse=variable_scope.AUTO_REUSE): + # Create a cache for the variable scope context manager the first time + # around so that we don't have to keep recreating it. + if not self._variable_scope_context_manager: + self._variable_scope_context_manager = variable_scope.variable_scope( + self._variable_scope, reuse=variable_scope.AUTO_REUSE) + with self._variable_scope_context_manager: with self._template_store.as_default(): result = self._call_func(args, kwargs) return result diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 5cdf03509e3c427deec7e26345059211001e2131..3c08870146e447d84d4a5f620cbead633d94751f 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -653,7 +653,7 @@ class _EagerTensorArray(object): if len(tensors) > len(self._tensor_array) and not self._dynamic_size: raise ValueError( "Cannot unstack %d tensors into a TensorArray of static size %d" % - (len(tensors), len(self._tensors))) + (len(tensors), len(self._tensor_array))) ta = self._identity_without_array() ta._implementation._tensor_array = tensors # pylint: disable=protected-access return ta diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index db594ac6a0bd3c5380ec4dc368a091dbc48980eb..81565a63774da49628d100ef071b02f6311f6af2 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -771,8 +771,8 @@ class _VariableStore(object): if initializer is None: initializer, initializing_from_value = self._get_default_initializer( name=name, shape=shape, dtype=dtype) - # Clear control dependencies while creating the initializer. - with ops.control_dependencies(None): + # Enter an init scope when creating the initializer. + with ops.init_scope(): if initializing_from_value: init_val = initializer variable_dtype = None diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 7d7fa646c08523c5f572f8f4593c1d8fe8615c67..b785d0ede7433ab3105a3c865b998a8d19b6ea78 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -28,6 +28,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpointable from tensorflow.python.util import compat from tensorflow.python.util import tf_should_use from tensorflow.python.util.deprecation import deprecated @@ -35,7 +37,7 @@ from tensorflow.python.util.tf_export import tf_export @tf_export("Variable") -class Variable(object): +class Variable(checkpointable.CheckpointableBase): """See the @{$variables$Variables How To} for a high level overview. A variable maintains state in the graph across calls to `run()`. You add a @@ -211,6 +213,7 @@ class Variable(object): if not context.in_graph_mode(): raise RuntimeError("tf.Variable not supported in Eager mode. " "Please use tfe.Variable instead") + self._in_graph_mode = context.in_graph_mode() if variable_def: # If variable_def is provided, recreates the variable from its fields. if initial_value: @@ -304,9 +307,14 @@ class Variable(object): if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") + if isinstance(initial_value, checkpointable.CheckpointInitialValue): + self._maybe_initialize_checkpointable() + self._update_uid = initial_value.checkpoint_position.restore_uid + initial_value = initial_value.wrapped_value + if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] - with ops.control_dependencies(None): + with ops.init_scope(): with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: @@ -377,8 +385,8 @@ class Variable(object): else: with ops.colocate_with(self._variable.op): self._snapshot = array_ops.identity(self._variable, name="read") + ops.add_to_collections(collections, self) - ops.add_to_collections(collections, self) self._caching_device = caching_device self._save_slice_info = None self._constraint = constraint @@ -552,7 +560,7 @@ class Variable(object): A `Tensor` holding the value of this variable after its initializer has run. """ - with ops.control_dependencies(None): + with ops.init_scope(): return control_flow_ops.cond(is_variable_initialized(self), self.read_value, lambda: self.initial_value) @@ -784,6 +792,20 @@ class Variable(object): setattr(Variable, operator, _run_op) + def _scatter_tensors_from_checkpoint(self, attributes): + """For implementing `Checkpointable`. Return an assignment op to run.""" + if (len(attributes) != 1 + or checkpointable.VARIABLE_VALUE_KEY not in attributes): + raise ValueError( + ("The variable %s was restored with unexpected values (expected one " + "with key %s, got %s)") % ( + self, checkpointable.VARIABLE_VALUE_KEY, attributes)) + return self.assign(attributes[checkpointable.VARIABLE_VALUE_KEY]) + + def _gather_tensors_for_checkpoint(self): + """For implementing `Checkpointable`. This object is saveable on its own.""" + return {checkpointable.VARIABLE_VALUE_KEY: self} + def _try_guard_against_uninitialized_dependencies(self, initial_value): """Attempt to guard against dependencies on uninitialized variables. @@ -1021,6 +1043,61 @@ class Variable(object): return Variable(variable_def=variable_def, import_scope=import_scope) + def __iadd__(self, other): + logging.log_first_n( + logging.WARN, + "Variable += will be deprecated. Use variable.assign_add" + " if you want assignment to the variable value or 'x = x + y'" + " if you want a new python Tensor object.", 1) + return self + other + + def __isub__(self, other): + logging.log_first_n( + logging.WARN, + "Variable -= will be deprecated. Use variable.assign_sub" + " if you want assignment to the variable value or 'x = x - y'" + " if you want a new python Tensor object.", 1) + return self - other + + def __imul__(self, other): + logging.log_first_n( + logging.WARN, + "Variable *= will be deprecated. Use variable.assign_mul" + " if you want assignment to the variable value or 'x = x * y'" + " if you want a new python Tensor object.", 1) + return self * other + + def __idiv__(self, other): + logging.log_first_n( + logging.WARN, + "Variable /= will be deprecated. Use variable.assign_div" + " if you want assignment to the variable value or 'x = x / y'" + " if you want a new python Tensor object.", 1) + return self / other + + def __itruediv__(self, other): + logging.log_first_n( + logging.WARN, + "Variable /= will be deprecated. Use variable.assign_div" + " if you want assignment to the variable value or 'x = x / y'" + " if you want a new python Tensor object.", 1) + return self / other + + def __irealdiv__(self, other): + logging.log_first_n( + logging.WARN, + "Variable /= will be deprecated. Use variable.assign_div" + " if you want assignment to the variable value or 'x = x / y'" + " if you want a new python Tensor object.", 1) + return self / other + + def __ipow__(self, other): + logging.log_first_n( + logging.WARN, + "Variable **= will be deprecated. Use 'x = x ** y'" + " if you want a new python Tensor object.", 1) + return self ** other + class SaveSliceInfo(object): """Information on how to save this Variable as a slice. diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py index 9b92d9a18005ca5e6be3820427e3a3ba60a8ec2d..cce64c0ccafc29a9d0d0b51b4c97c5673264657b 100644 --- a/tensorflow/python/platform/app.py +++ b/tensorflow/python/platform/app.py @@ -23,6 +23,7 @@ import sys as _sys from tensorflow.python.platform import flags from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export def _usage(shorthelp): @@ -108,6 +109,7 @@ def _define_help_flags(): _define_help_flags_called = True +@tf_export('app.run') def run(main=None, argv=None): """Runs the program with an optional 'main' function and 'argv' list.""" diff --git a/tensorflow/python/platform/resource_loader.py b/tensorflow/python/platform/resource_loader.py index 2455acb4c0c469acbb928c4ec44571e50e06de1f..8f7b12e2b2b92d9b2bfe397d0e7cba59e11bc1f6 100644 --- a/tensorflow/python/platform/resource_loader.py +++ b/tensorflow/python/platform/resource_loader.py @@ -29,8 +29,10 @@ import sys as _sys from tensorflow.python.util import tf_inspect as _inspect from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export +@tf_export('resource_loader.load_resource') def load_resource(path): """Load the resource at given path, where path is relative to tensorflow/. @@ -52,6 +54,7 @@ def load_resource(path): # pylint: disable=protected-access +@tf_export('resource_loader.get_data_files_path') def get_data_files_path(): """Get a direct path to the data files colocated with the script. @@ -62,6 +65,7 @@ def get_data_files_path(): return _os.path.dirname(_inspect.getfile(_sys._getframe(1))) +@tf_export('resource_loader.get_root_dir_with_all_resources') def get_root_dir_with_all_resources(): """Get a root directory containing all the data attributes in the build rule. @@ -101,6 +105,7 @@ def get_root_dir_with_all_resources(): return data_files_dir or script_dir +@tf_export('resource_loader.get_path_to_datafile') def get_path_to_datafile(path): """Get the path to the specified file in the data dependencies. @@ -120,6 +125,7 @@ def get_path_to_datafile(path): return _os.path.join(data_files_path, path) +@tf_export('resource_loader.readahead_file_path') def readahead_file_path(path, readahead='128M'): # pylint: disable=unused-argument """Readahead files not implemented; simply returns given path.""" return path diff --git a/tensorflow/python/platform/stacktrace_handler_test.py b/tensorflow/python/platform/stacktrace_handler_test.py index 3f0e534f4cbd97ecbd7db1fae3b48af72310c24f..f2071f9d54ceb99831999ec08ab71d63862f1c36 100644 --- a/tensorflow/python/platform/stacktrace_handler_test.py +++ b/tensorflow/python/platform/stacktrace_handler_test.py @@ -57,7 +57,8 @@ class StacktraceHandlerTest(test.TestCase): # Capture its output. capture both stdout and stderr and append them. # We are not worried about timing or order of messages in this test. - child_output = child_process.stdout.read() + child_process.stderr.read() + child_stdout, child_stderr = child_process.communicate() + child_output = child_stdout + child_stderr # Make sure the child process is dead before we proceed. child_process.wait() diff --git a/tensorflow/python/platform/tf_logging.py b/tensorflow/python/platform/tf_logging.py index 85ed4f071c7022801f20db75d538e5917b8eea66..22aabfd7121ac9b2eebeae2693f174e044d504ef 100644 --- a/tensorflow/python/platform/tf_logging.py +++ b/tensorflow/python/platform/tf_logging.py @@ -35,6 +35,7 @@ import threading import six from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export # Don't use this directly. Use _get_logger() instead. @@ -90,30 +91,37 @@ def _get_logger(): _logger_lock.release() +@tf_export('logging.log') def log(level, msg, *args, **kwargs): _get_logger().log(level, msg, *args, **kwargs) +@tf_export('logging.debug') def debug(msg, *args, **kwargs): _get_logger().debug(msg, *args, **kwargs) +@tf_export('logging.error') def error(msg, *args, **kwargs): _get_logger().error(msg, *args, **kwargs) +@tf_export('logging.fatal') def fatal(msg, *args, **kwargs): _get_logger().fatal(msg, *args, **kwargs) +@tf_export('logging.info') def info(msg, *args, **kwargs): _get_logger().info(msg, *args, **kwargs) +@tf_export('logging.warn') def warn(msg, *args, **kwargs): _get_logger().warn(msg, *args, **kwargs) +@tf_export('logging.warning') def warning(msg, *args, **kwargs): _get_logger().warning(msg, *args, **kwargs) @@ -136,15 +144,18 @@ _log_prefix = None # later set to google2_log_prefix _log_counter_per_token = {} +@tf_export('logging.TaskLevelStatusMessage') def TaskLevelStatusMessage(msg): error(msg) +@tf_export('logging.flush') def flush(): raise NotImplementedError() # Code below is taken from pyglib/logging +@tf_export('logging.vlog') def vlog(level, msg, *args, **kwargs): _get_logger().log(level, msg, *args, **kwargs) @@ -164,6 +175,7 @@ def _GetNextLogCountPerToken(token): return _log_counter_per_token[token] +@tf_export('logging.log_every_n') def log_every_n(level, msg, n, *args): """Log 'msg % args' at level 'level' once per 'n' times. @@ -180,6 +192,7 @@ def log_every_n(level, msg, n, *args): log_if(level, msg, not (count % n), *args) +@tf_export('logging.log_first_n') def log_first_n(level, msg, n, *args): # pylint: disable=g-bad-name """Log 'msg % args' at level 'level' only first 'n' times. @@ -195,6 +208,7 @@ def log_first_n(level, msg, n, *args): # pylint: disable=g-bad-name log_if(level, msg, count < n, *args) +@tf_export('logging.log_if') def log_if(level, msg, condition, *args): """Log 'msg % args' at level 'level' only if condition is fulfilled.""" if condition: @@ -251,11 +265,13 @@ def google2_log_prefix(level, timestamp=None, file_and_line=None): return s +@tf_export('logging.get_verbosity') def get_verbosity(): """Return how much logging output will be produced.""" return _get_logger().getEffectiveLevel() +@tf_export('logging.set_verbosity') def set_verbosity(v): """Sets the threshold for what messages will be logged.""" _get_logger().setLevel(v) @@ -296,4 +312,10 @@ _allowed_symbols = [ 'warning', ] +tf_export('logging.DEBUG').export_constant(__name__, 'DEBUG') +tf_export('logging.ERROR').export_constant(__name__, 'ERROR') +tf_export('logging.FATAL').export_constant(__name__, 'FATAL') +tf_export('logging.INFO').export_constant(__name__, 'INFO') +tf_export('logging.WARN').export_constant(__name__, 'WARN') + remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py index 8f780545607f7ba2337c83ad2c3740f542b802f6..0e20ca35bba606079ed5b0f225dd3029772b5af3 100644 --- a/tensorflow/python/profiler/model_analyzer.py +++ b/tensorflow/python/profiler/model_analyzer.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.profiler import option_builder from tensorflow.python.profiler import tfprof_logger +from tensorflow.python.util.tf_export import tf_export _DEFAULT_PROFILE_OPTIONS = 0 _DEFAULT_ADVISE_OPTIONS = 0 @@ -121,6 +122,7 @@ def _build_advisor_options(options): return opts +@tf_export('profiler.Profiler') class Profiler(object): """TensorFlow multi-step profiler. @@ -304,6 +306,7 @@ class Profiler(object): print_mdl.WriteProfile(filename) +@tf_export('profiler.profile') def profile(graph=None, run_meta=None, op_log=None, @@ -378,6 +381,7 @@ def profile(graph=None, return tfprof_node +@tf_export('profiler.advise') def advise(graph=None, run_meta=None, options=_DEFAULT_ADVISE_OPTIONS): """Auto profile and advise. diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py index 915385558889e64277611bd71251f8f937a18159..04ba28c219e276e1ca79bd4e20e7d1b6ee700db5 100644 --- a/tensorflow/python/profiler/model_analyzer_test.py +++ b/tensorflow/python/profiler/model_analyzer_test.py @@ -224,15 +224,15 @@ class PrintModelAnalysisTest(test.TestCase): # pylint: disable=line-too-long with gfile.Open(outfile, 'r') as f: lines = f.read().split('\n') + self.assertGreater(len(lines), 5) result = '\n'.join([l[:min(len(l), 80)] for l in lines]) - self.assertEqual( - compat.as_bytes( - 'node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/168.86k flops)\n model_analyzer_testlib.py:63:BuildFullModel (0/1.80k params, 0/45.37k flops)\n model_analyzer_testlib.py:40:BuildSmallModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (0/4 params, 0/8 flops)\n model_analyzer_testlib.py:48:BuildSmallModel (0/648 params, 0/1.30k flops)\n model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:53:BuildSmallModel (0/1.15k params, 0/2.30k flops)\n model_analyzer_testlib.py:54:BuildSmallModel (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:63:BuildFullModel (gradient) (0/0 params, 0/67.39k f\n model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/46.66\n model_analyzer_testlib.py:54:BuildSmallModel (gradient) (0/0 params, 0/20.74\n model_analyzer_testlib.py:67:BuildFullModel (0/1.04k params, 0/18.58k flops)\n model_analyzer_testlib.py:67:BuildFullModel (gradient) (0/0 params, 0/37.00k f\n model_analyzer_testlib.py:69:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (0/0 params, 0/258 flops)\n model_analyzer_testlib.py:70:BuildFullModel (gradient) (0/0 params, 0/129 flop\n model_analyzer_testlib.py:72:BuildFullModel (0/0 params, 0/141 flops)\n' - ), compat.as_bytes(lib.CheckAndRemoveDoc(result))) + self.assertTrue( + compat.as_text(lib.CheckAndRemoveDoc(result)) + .startswith('node name | # parameters | # float_ops')) self.assertLess(0, tfprof_node.total_exec_micros) self.assertEqual(2844, tfprof_node.total_parameters) - self.assertEqual(168863, tfprof_node.total_float_ops) + self.assertLess(168800, tfprof_node.total_float_ops) self.assertEqual(8, len(tfprof_node.children)) self.assertEqual('_TFProfRoot', tfprof_node.name) self.assertEqual( diff --git a/tensorflow/python/profiler/option_builder.py b/tensorflow/python/profiler/option_builder.py index 13942ad6a2adc1f1d1cad778ebd280d358f64a59..2ad7adf76933df65ca795dca361397f436adb995 100644 --- a/tensorflow/python/profiler/option_builder.py +++ b/tensorflow/python/profiler/option_builder.py @@ -20,8 +20,10 @@ from __future__ import print_function import copy from tensorflow.python.profiler import tfprof_logger +from tensorflow.python.util.tf_export import tf_export +@tf_export('profiler.ProfileOptionBuilder') class ProfileOptionBuilder(object): # pylint: disable=line-too-long """Option Builder for Profiling API. @@ -298,7 +300,7 @@ class ProfileOptionBuilder(object): # pylint: disable=line-too-long """Only show profiler nodes consuming no less than 'min_float_ops'. - Please see https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/profilerg3doc/profile_model_architecture.md + Please see https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/profiler/g3doc/profile_model_architecture.md on the caveats of calculating float operations. Args: diff --git a/tensorflow/python/profiler/profiler.py b/tensorflow/python/profiler/profiler.py index 130dcb5134d6f7e6eb43aebea803b366a5ce27d8..fa7f30b236997cecd6d5df98c334aa7f5cc571e4 100644 --- a/tensorflow/python/profiler/profiler.py +++ b/tensorflow/python/profiler/profiler.py @@ -31,6 +31,7 @@ from tensorflow.python.profiler.option_builder import ProfileOptionBuilder from tensorflow.python.profiler.tfprof_logger import write_op_log from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export _allowed_symbols = [ @@ -48,6 +49,12 @@ _allowed_symbols.extend([ 'OpLogProto', ]) +# Export protos +tf_export('profiler.GraphNodeProto')(GraphNodeProto) +tf_export('profiler.MultiGraphNodeProto')(MultiGraphNodeProto) +tf_export('profiler.AdviceProto')(AdviceProto) +tf_export('profiler.OpLogProto')(OpLogProto) + remove_undocumented(__name__, _allowed_symbols, [ Profiler, profile, diff --git a/tensorflow/python/profiler/tfprof_logger.py b/tensorflow/python/profiler/tfprof_logger.py index ffda7ddad759ce68bf718bcfa6e568cfadd59b53..8d121064967f2f87cd0aefaa361bfd6f387a3e6e 100644 --- a/tensorflow/python/profiler/tfprof_logger.py +++ b/tensorflow/python/profiler/tfprof_logger.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import gfile from tensorflow.python.profiler.internal import flops_registry # pylint: disable=unused-import +from tensorflow.python.util.tf_export import tf_export TRAINABLE_VARIABLES = '_trainable_variables' REGISTERED_FLOP_STATS = 'flops' @@ -187,6 +188,7 @@ def merge_default_with_oplog(graph, op_log=None, run_meta=None, return tmp_op_log +@tf_export('profiler.write_op_log') def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True): """Log provided 'op_log', and add additional model information below. diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 3f25311a8361d11fbc583413708e148648d95906..50f481d29e9d39bd12741b5f9e02b7201336134d 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -29,6 +29,7 @@ limitations under the License. %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; %rename("%s") TFE_Py_RegisterExceptionClass; +%rename("%s") TFE_Py_RegisterFallbackExceptionClass; %rename("%s") TFE_Py_Execute; %rename("%s") TFE_Py_FastPathExecute; %rename("%s") TFE_Py_UID; diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index e34aa7cc2ca41ecdd7c9ff52ab8f3d552f26fe69..30e0a099d8b2e30cff36b69164ba9f1789dd8916 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -148,6 +148,7 @@ py_test( "//tensorflow/python:math_ops", "//tensorflow/python:saver_test_utils", "//tensorflow/python:state_ops", + "//tensorflow/python:test_ops", "//tensorflow/python:util", "//tensorflow/python:variables", ], diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 62ee53b816c2a38327fa116d2924446e6bf24a1e..7347da75364818b95d3f2ad7dfa74a8c3614b161 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -34,8 +34,10 @@ from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants from tensorflow.python.training import saver as tf_saver from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export +@tf_export("saved_model.builder.SavedModelBuilder") class SavedModelBuilder(object): """Builds the `SavedModel` protocol buffer and saves variables and assets. diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py index 7e3e8df47fb0e024eae8add6a788d632709740af..ec49a0539ff52f6cc69bb24483ede657b698ab8d 100644 --- a/tensorflow/python/saved_model/constants.py +++ b/tensorflow/python/saved_model/constants.py @@ -20,33 +20,52 @@ from __future__ import division from __future__ import print_function from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export # Subdirectory name containing the asset files. ASSETS_DIRECTORY = "assets" +tf_export("saved_model.constants.ASSETS_DIRECTORY").export_constant( + __name__, "ASSETS_DIRECTORY") # CollectionDef key containing SavedModel assets. ASSETS_KEY = "saved_model_assets" +tf_export("saved_model.constants.ASSETS_KEY").export_constant( + __name__, "ASSETS_KEY") # CollectionDef key for the legacy init op. LEGACY_INIT_OP_KEY = "legacy_init_op" +tf_export("saved_model.constants.LEGACY_INIT_OP_KEY").export_constant( + __name__, "LEGACY_INIT_OP_KEY") # CollectionDef key for the SavedModel main op. MAIN_OP_KEY = "saved_model_main_op" +tf_export("saved_model.constants.MAIN_OP_KEY").export_constant( + __name__, "MAIN_OP_KEY") # Schema version for SavedModel. SAVED_MODEL_SCHEMA_VERSION = 1 +tf_export("saved_model.constants.SAVED_MODEL_SCHEMA_VERSION").export_constant( + __name__, "SAVED_MODEL_SCHEMA_VERSION") # File name for SavedModel protocol buffer. SAVED_MODEL_FILENAME_PB = "saved_model.pb" +tf_export("saved_model.constants.SAVED_MODEL_FILENAME_PB").export_constant( + __name__, "SAVED_MODEL_FILENAME_PB") # File name for text version of SavedModel protocol buffer. SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt" +tf_export("saved_model.constants.SAVED_MODEL_FILENAME_PBTXT").export_constant( + __name__, "SAVED_MODEL_FILENAME_PBTXT") # Subdirectory name containing the variables/checkpoint files. VARIABLES_DIRECTORY = "variables" +tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant( + __name__, "VARIABLES_DIRECTORY") # File name used for variables. VARIABLES_FILENAME = "variables" +tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant( + __name__, "VARIABLES_FILENAME") _allowed_symbols = [ diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 5ff954fd9f83989565e007cad3f0f66913e0a4dd..bebf1d5e0d3cc6ac0e431230577704365d37a437 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -32,6 +32,7 @@ from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants from tensorflow.python.training import saver as tf_saver from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export def _parse_saved_model(export_dir): @@ -156,6 +157,7 @@ def _get_legacy_init_op_tensor(meta_graph_def_to_load): return legacy_init_op_tensor +@tf_export("saved_model.loader.maybe_saved_model_directory") def maybe_saved_model_directory(export_dir): """Checks whether the provided export directory could contain a SavedModel. @@ -176,6 +178,7 @@ def maybe_saved_model_directory(export_dir): return file_io.file_exists(txt_path) or file_io.file_exists(pb_path) +@tf_export("saved_model.loader.load") def load(sess, tags, export_dir, **saver_kwargs): """Loads the model from a SavedModel as specified by tags. @@ -232,13 +235,10 @@ def load(sess, tags, export_dir, **saver_kwargs): asset_tensors_dictionary = _get_asset_tensors(export_dir, meta_graph_def_to_load) - main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load) + main_op_tensor = ( + _get_main_op_tensor(meta_graph_def_to_load) or + (_get_legacy_init_op_tensor(meta_graph_def_to_load))) if main_op_tensor is not None: sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) - else: - legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load) - if legacy_init_op_tensor is not None: - sess.run( - fetches=[legacy_init_op_tensor], feed_dict=asset_tensors_dictionary) return meta_graph_def_to_load diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py index 355fd57bf1d2166f58a5fdc95d04695ea05b56b3..631ee63729513d24c2ddae71b771f7cf1695358f 100644 --- a/tensorflow/python/saved_model/main_op_impl.py +++ b/tensorflow/python/saved_model/main_op_impl.py @@ -22,8 +22,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables +from tensorflow.python.util.tf_export import tf_export +@tf_export('saved_model.main_op.main_op') def main_op(): """Returns a main op to init variables and tables. @@ -40,6 +42,7 @@ def main_op(): # TODO(sukritiramesh): Integrate with Saver for complete restore functionality. +@tf_export('saved_model.main_op.main_op_with_restore') def main_op_with_restore(restore_op_name): """Returns a main op to init variables, tables and restore the graph. diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 1ea619ff55dea00f8ee09024ab45dcd324a2ddce..d9d316882584470769c14cf0c5f265b58e37ab43 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import os -from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 @@ -28,8 +27,8 @@ from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops +from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io from tensorflow.python.ops import control_flow_ops @@ -54,8 +53,14 @@ def tearDownModule(): file_io.delete_recursively(test.get_temp_dir()) +@test_util.with_c_api class SavedModelTest(test.TestCase): + def _get_export_dir(self, label): + if ops._USE_C_API: + label += "_c_api" + return os.path.join(test.get_temp_dir(), label) + def _init_and_validate_variable(self, sess, variable_name, variable_value): v = variables.Variable(variable_value, name=variable_name) sess.run(variables.global_variables_initializer()) @@ -123,8 +128,7 @@ class SavedModelTest(test.TestCase): self.assertFalse(loader.maybe_saved_model_directory(base_path)) def testBadSavedModelFileFormat(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_bad_saved_model_file_format") + export_dir = self._get_export_dir("test_bad_saved_model_file_format") # Attempt to load a SavedModel from an export directory that does not exist. with self.test_session(graph=ops.Graph()) as sess: with self.assertRaisesRegexp(IOError, @@ -157,8 +161,7 @@ class SavedModelTest(test.TestCase): loader.load(sess, ["foo"], export_dir) def testVerifySessionGraphUsage(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_verify_session_graph_usage") + export_dir = self._get_export_dir("test_verify_session_graph_usage") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -178,7 +181,7 @@ class SavedModelTest(test.TestCase): 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) def testSequence(self): - export_dir = os.path.join(test.get_temp_dir(), "test_sequence") + export_dir = self._get_export_dir("test_sequence") builder = saved_model_builder.SavedModelBuilder(export_dir) # Expect an assertion error since add_meta_graph_and_variables() should be @@ -195,7 +198,7 @@ class SavedModelTest(test.TestCase): sess, ["baz"]) def testTags(self): - export_dir = os.path.join(test.get_temp_dir(), "test_tags") + export_dir = self._get_export_dir("test_tags") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with a single variable. SavedModel invoked to: @@ -284,7 +287,7 @@ class SavedModelTest(test.TestCase): export_dir) def testVariables(self): - export_dir = os.path.join(test.get_temp_dir(), "test_variables") + export_dir = self._get_export_dir("test_variables") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with two variables. SavedModel invoked to: @@ -336,7 +339,7 @@ class SavedModelTest(test.TestCase): export_dir) def testGraphWithoutVariables(self): - export_dir = os.path.join(test.get_temp_dir(), "test_graph_has_variables") + export_dir = self._get_export_dir("test_graph_has_variables") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with no variables. @@ -371,7 +374,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(30.0, sess.run(c)) def testNoOverwrite(self): - export_dir = os.path.join(test.get_temp_dir(), "test_no_overwrite") + export_dir = self._get_export_dir("test_no_overwrite") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with a single variable. SavedModel invoked to: @@ -395,7 +398,7 @@ class SavedModelTest(test.TestCase): export_dir) def testSaveAsText(self): - export_dir = os.path.join(test.get_temp_dir(), "test_astext") + export_dir = self._get_export_dir("test_astext") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with a single variable. SavedModel invoked to: @@ -426,7 +429,7 @@ class SavedModelTest(test.TestCase): 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) def testCollections(self): - export_dir = os.path.join(test.get_temp_dir(), "test_collections") + export_dir = self._get_export_dir("test_collections") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with a single variable added to a collection. SavedModel invoked to: @@ -476,7 +479,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(len(ops.get_collection("foo_vars")), 0) def testSignatureDefs(self): - export_dir = os.path.join(test.get_temp_dir(), "test_signature_defs") + export_dir = self._get_export_dir("test_signature_defs") builder = saved_model_builder.SavedModelBuilder(export_dir) # Graph with a single variable and a single entry in the signature def map. @@ -536,8 +539,7 @@ class SavedModelTest(test.TestCase): self.assertEqual("foo_new", bar_signature["foo_key"].method_name) def testSignatureDefValidation(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_signature_def_validation") + export_dir = self._get_export_dir("test_signature_def_validation") builder = saved_model_builder.SavedModelBuilder(export_dir) tensor_without_name = meta_graph_pb2.TensorInfo() @@ -555,7 +557,7 @@ class SavedModelTest(test.TestCase): self._validate_outputs_tensor_info(builder, tensor_empty) def testAssets(self): - export_dir = os.path.join(test.get_temp_dir(), "test_assets") + export_dir = self._get_export_dir("test_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -588,7 +590,7 @@ class SavedModelTest(test.TestCase): self.assertFalse(file_io.file_exists(ignored_asset_path)) def testCustomMainOp(self): - export_dir = os.path.join(test.get_temp_dir(), "test_main_op") + export_dir = self._get_export_dir("test_main_op") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -623,7 +625,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(3, ops.get_collection("v")[2].eval()) def testLegacyInitOp(self): - export_dir = os.path.join(test.get_temp_dir(), "test_legacy_init_op") + export_dir = self._get_export_dir("test_legacy_init_op") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -657,8 +659,8 @@ class SavedModelTest(test.TestCase): self.assertEqual(3, ops.get_collection("v")[2].eval()) def testLegacyInitOpWithNonEmptyCollection(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_legacy_init_op_with_non_empty_collection") + export_dir = self._get_export_dir( + "test_legacy_init_op_with_non_empty_collection") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -685,7 +687,7 @@ class SavedModelTest(test.TestCase): sess, ["foo"], legacy_init_op=legacy_init_op) def testMultipleAssets(self): - export_dir = os.path.join(test.get_temp_dir(), "test_multiple_assets") + export_dir = self._get_export_dir("test_multiple_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -727,7 +729,7 @@ class SavedModelTest(test.TestCase): "asset_file_tensor:0") def testDuplicateAssets(self): - export_dir = os.path.join(test.get_temp_dir(), "test_duplicate_assets") + export_dir = self._get_export_dir("test_duplicate_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) with self.test_session(graph=ops.Graph()) as sess: @@ -775,7 +777,7 @@ class SavedModelTest(test.TestCase): "asset_file_tensor:0") def testOp(self): - export_dir = os.path.join(test.get_temp_dir(), "test_op") + export_dir = self._get_export_dir("test_op") builder = saved_model_builder.SavedModelBuilder(export_dir) with session.Session( @@ -818,7 +820,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(3, ops.get_collection("v")[2].eval()) def testCustomSaveable(self): - export_dir = os.path.join(test.get_temp_dir(), "custom_saveable") + export_dir = self._get_export_dir("custom_saveable") builder = saved_model_builder.SavedModelBuilder(export_dir) with session.Session( @@ -847,7 +849,7 @@ class SavedModelTest(test.TestCase): self.assertEqual(3.0, v1.values().eval()) def testClearDevices(self): - export_dir = os.path.join(test.get_temp_dir(), "test_clear_devices") + export_dir = self._get_export_dir("test_clear_devices") builder = saved_model_builder.SavedModelBuilder(export_dir) # Specify a device and save a variable. @@ -871,7 +873,7 @@ class SavedModelTest(test.TestCase): 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) def testStripDefaultAttrs(self): - export_dir = os.path.join(test.get_temp_dir(), "test_strip_default_attrs") + export_dir = self._get_export_dir("test_strip_default_attrs") builder = saved_model_builder.SavedModelBuilder(export_dir) # Add a graph with two float32 variables and a Complex Op composing them @@ -940,59 +942,75 @@ class SavedModelTest(test.TestCase): self.assertIn("T", node_def.attr) self.assertIn("Tout", node_def.attr) - def testStripDefaultAttrsInconsistentConsumerDefaults(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_strip_default_attrs_no_consumer_defaults") + # Tests the behavior of loading SavedModels that having missing attrs or attrs + # with incorrect types. + def testInconsistentConsumerDefaultAttrs(self): + export_dir = self._get_export_dir( + "test_strip_default_attrs_no_consumer_defaults") builder = saved_model_builder.SavedModelBuilder(export_dir) - # Add a graph with two float32 variables and a Complex Op composing them - # with strip_default_attrs enabled. This must remove the following - # defaults for the "Complex" Op: - # o "T" : float32. (input type) - # o "Tout" : complex64. (output type) + # Add a graph with a single variable and a test op with a defaultless + # float32 attr, "test_attr". with session.Session(graph=ops.Graph()) as sess: - real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") - imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") - math_ops.complex(real_num, imag_num, name="complex") + variables.Variable(1.0, dtype=dtypes.float64, name="var") + test_ops.test_attr(T=dtypes.float32, name="test_attr") sess.run(variables.global_variables_initializer()) - builder.add_meta_graph_and_variables( - sess, ["foo"], strip_default_attrs=True) + builder.add_meta_graph_and_variables(sess, ["foo"]) # Save the SavedModel to disk in text format. builder.save(as_text=True) - # Update the Op registry to remove defaults for all attrs("T", "Tout") from - # the "Complex" OpDef. - complex_op_def = op_def_registry.get_registered_ops()["Complex"] - original_complex_op_def = op_def_pb2.OpDef() - original_complex_op_def.CopyFrom(complex_op_def) - for attr_def in complex_op_def.attr: - attr_def.ClearField("default_value") + # Rewrite the SavedModel to remove the T attr from "test_attr". + saved_model_file = os.path.join( + export_dir, constants.SAVED_MODEL_FILENAME_PBTXT) + with open(saved_model_file) as f: + original_saved_model = f.read() + + no_attr_saved_model = original_saved_model.replace(""" + attr { + key: "T" + value { + type: DT_FLOAT + } + }""", "") + with open(saved_model_file, "w") as f: + f.write(no_attr_saved_model) # Loading the SavedModel via the loader must fail because the SavedModel - # does not have any attr values for the "Complex" node and the current - # op registry does not have have any default values for the "Complex" op. + # does not have any attr values for the "TestAttr" node, and there is no + # default specified in the TestAttr OpDef. sess = session.Session(graph=ops.Graph()) - with self.assertRaisesRegexp( - ValueError, - "Expected one attr with name .*T(out)?.* in name: \"complex\".*"): + if ops._USE_C_API: + error_message = "NodeDef missing attr 'T' from Op complex128). - complex_op_def.CopyFrom(original_complex_op_def) - for attr_def in complex_op_def.attr: - if attr_def.name == "Tout": - attr_def.default_value.type = types_pb2.DT_COMPLEX128 - - # Loading the SavedModel via the loader must set "Tout" attr_value for the - # "Complex" node according to the latest defaults (complex128). This is - # expected to fail the model import as there is no OpKernel registered to - # handle attrs "T" (float32) and "Tout" (complex128). + # Rewrite the SavedModel to change the type of the T attr in "test_attr" + bad_type_saved_model = original_saved_model.replace(""" + attr { + key: "T" + value { + type: DT_FLOAT + } + }""", """ + attr { + key: "T" + value { + type: DT_DOUBLE + } + }""") + with open(saved_model_file, "w") as f: + f.write(bad_type_saved_model) + + # Loading the SavedModel via the loader must fail because there is no + # OpKernel registered to handle T = double. sess = session.Session(graph=ops.Graph()) with self.assertRaisesRegexp( errors.InvalidArgumentError, - ".*No OpKernel was registered to support Op \'Complex\' with these " + ".*No OpKernel was registered to support Op \'TestAttr\' with these " "attrs..*"): loader.load(sess, ["foo"], export_dir) diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py index 935a124645bde509a1b5a7751a285a85acbe8cab..6461fe8a7e7bef1a2fc787879da9e3324e2655c8 100644 --- a/tensorflow/python/saved_model/signature_constants.py +++ b/tensorflow/python/saved_model/signature_constants.py @@ -20,51 +20,79 @@ from __future__ import division from __future__ import print_function from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export # Key in the signature def map for `default` serving signatures. The default # signature is used in inference requests where a specific signature was not # specified. DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default" +tf_export("saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY" + ).export_constant(__name__, "DEFAULT_SERVING_SIGNATURE_DEF_KEY") ################################################################################ # Classification API constants. # Classification inputs. CLASSIFY_INPUTS = "inputs" +tf_export("saved_model.signature_constants.CLASSIFY_INPUTS").export_constant( + __name__, "CLASSIFY_INPUTS") # Classification method name used in a SignatureDef. CLASSIFY_METHOD_NAME = "tensorflow/serving/classify" +tf_export( + "saved_model.signature_constants.CLASSIFY_METHOD_NAME").export_constant( + __name__, "CLASSIFY_METHOD_NAME") # Classification classes output. CLASSIFY_OUTPUT_CLASSES = "classes" +tf_export( + "saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES").export_constant( + __name__, "CLASSIFY_OUTPUT_CLASSES") # Classification scores output. CLASSIFY_OUTPUT_SCORES = "scores" +tf_export( + "saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES").export_constant( + __name__, "CLASSIFY_OUTPUT_SCORES") ################################################################################ # Prediction API constants. # Predict inputs. PREDICT_INPUTS = "inputs" +tf_export("saved_model.signature_constants.PREDICT_INPUTS").export_constant( + __name__, "PREDICT_INPUTS") # Prediction method name used in a SignatureDef. PREDICT_METHOD_NAME = "tensorflow/serving/predict" +tf_export( + "saved_model.signature_constants.PREDICT_METHOD_NAME").export_constant( + __name__, "PREDICT_METHOD_NAME") # Predict outputs. PREDICT_OUTPUTS = "outputs" +tf_export("saved_model.signature_constants.PREDICT_OUTPUTS").export_constant( + __name__, "PREDICT_OUTPUTS") ################################################################################ # Regression API constants. # Regression inputs. REGRESS_INPUTS = "inputs" +tf_export("saved_model.signature_constants.REGRESS_INPUTS").export_constant( + __name__, "REGRESS_INPUTS") # Regression method name used in a SignatureDef. REGRESS_METHOD_NAME = "tensorflow/serving/regress" +tf_export( + "saved_model.signature_constants.REGRESS_METHOD_NAME").export_constant( + __name__, "REGRESS_METHOD_NAME") # Regression outputs. REGRESS_OUTPUTS = "outputs" +tf_export("saved_model.signature_constants.REGRESS_OUTPUTS").export_constant( + __name__, "REGRESS_OUTPUTS") ################################################################################ diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py index 240ea61aa5f8553852044f84b61d010bfbca69d1..d0331591889110df86bdb2ac69c037bc3b968f91 100644 --- a/tensorflow/python/saved_model/signature_def_utils_impl.py +++ b/tensorflow/python/saved_model/signature_def_utils_impl.py @@ -26,8 +26,10 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import utils +from tensorflow.python.util.tf_export import tf_export +@tf_export('saved_model.signature_def_utils.build_signature_def') def build_signature_def(inputs=None, outputs=None, method_name=None): """Utility function to build a SignatureDef protocol buffer. @@ -53,6 +55,7 @@ def build_signature_def(inputs=None, outputs=None, method_name=None): return signature_def +@tf_export('saved_model.signature_def_utils.regression_signature_def') def regression_signature_def(examples, predictions): """Creates regression signature from given examples and predictions. @@ -94,6 +97,7 @@ def regression_signature_def(examples, predictions): return signature_def +@tf_export('saved_model.signature_def_utils.classification_signature_def') def classification_signature_def(examples, classes, scores): """Creates classification signature from given examples and predictions. @@ -146,6 +150,7 @@ def classification_signature_def(examples, classes, scores): return signature_def +@tf_export('saved_model.signature_def_utils.predict_signature_def') def predict_signature_def(inputs, outputs): """Creates prediction signature from given inputs and outputs. @@ -180,6 +185,7 @@ def predict_signature_def(inputs, outputs): return signature_def +@tf_export('saved_model.signature_def_utils.is_valid_signature') def is_valid_signature(signature_def): """Determine whether a SignatureDef can be served by TensorFlow Serving.""" if signature_def is None: diff --git a/tensorflow/python/saved_model/simple_save.py b/tensorflow/python/saved_model/simple_save.py index 9a81e5cd80705482865e05b040d712418a993da1..042b8fa8e22703d8ffb5e12de3f844d22fb1b1ce 100644 --- a/tensorflow/python/saved_model/simple_save.py +++ b/tensorflow/python/saved_model/simple_save.py @@ -23,8 +23,10 @@ from tensorflow.python.saved_model import builder from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants +from tensorflow.python.util.tf_export import tf_export +@tf_export('saved_model.simple_save') def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None): """Convenience function to build a SavedModel suitable for serving. @@ -40,17 +42,20 @@ def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None): - It will be treated as a graph for inference / serving (i.e. uses the tag `tag_constants.SERVING`) - The SavedModel will load in TensorFlow Serving and supports the - [Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto). + [Predict + API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto). To use the Classify, Regress, or MultiInference APIs, please use either [tf.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) or the lower level - [SavedModel APIs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md). + [SavedModel + APIs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md). - Some TensorFlow ops depend on information on disk or other information called "assets". These are generally handled automatically by adding the assets to the `GraphKeys.ASSET_FILEPATHS` collection. Only assets in that collection are exported; if you need more custom behavior, you'll need to - use the [SavedModelBuilder](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/builder.py). + use the + [SavedModelBuilder](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/builder.py). More information about SavedModel and signatures can be found here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md. diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py index e2facafda51919d3f1e0ccbe646db522ed0bc49b..d164e2c23f24469d7536f87cb431afe618ddcc06 100644 --- a/tensorflow/python/saved_model/tag_constants.py +++ b/tensorflow/python/saved_model/tag_constants.py @@ -20,19 +20,26 @@ from __future__ import division from __future__ import print_function from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export # Tag for the `serving` graph. SERVING = "serve" +tf_export("saved_model.tag_constants.SERVING").export_constant( + __name__, "SERVING") # Tag for the `training` graph. TRAINING = "train" +tf_export("saved_model.tag_constants.TRAINING").export_constant( + __name__, "TRAINING") # Tag for the `gpu` graph. GPU = "gpu" +tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU") # Tag for the `tpu` graph. TPU = "tpu" +tf_export("saved_model.tag_constants.TPU").export_constant(__name__, "TPU") _allowed_symbols = [ "SERVING", diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py index 73ca8c9c1c6d8fddc8a9c7dbee56682999281c28..cddce29a08a6c4c79a4c7c5dbfb48a86131530b2 100644 --- a/tensorflow/python/saved_model/utils_impl.py +++ b/tensorflow/python/saved_model/utils_impl.py @@ -22,11 +22,13 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.util.tf_export import tf_export # TensorInfo helpers. +@tf_export("saved_model.utils.build_tensor_info") def build_tensor_info(tensor): """Utility function to build TensorInfo proto. @@ -50,6 +52,7 @@ def build_tensor_info(tensor): return tensor_info +@tf_export("saved_model.utils.get_tensor_from_tensor_info") def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None): """Returns the Tensor or SparseTensor described by a TensorInfo proto. diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index 92c1fcadd29c7858da1d31375c209bf1b21f3103..b80ad79074e85bdeae70148b2822c319c29468bc 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -72,8 +72,10 @@ from tensorflow.python.summary.writer.writer_cache import FileWriterCache from tensorflow.python.util import compat as _compat from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export +@tf_export('summary.scalar') def scalar(name, tensor, collections=None, family=None): """Outputs a `Summary` protocol buffer containing a single scalar value. @@ -102,6 +104,7 @@ def scalar(name, tensor, collections=None, family=None): return val +@tf_export('summary.image') def image(name, tensor, max_outputs=3, collections=None, family=None): """Outputs a `Summary` protocol buffer with images. @@ -156,6 +159,7 @@ def image(name, tensor, max_outputs=3, collections=None, family=None): return val +@tf_export('summary.histogram') def histogram(name, values, collections=None, family=None): # pylint: disable=line-too-long """Outputs a `Summary` protocol buffer with a histogram. @@ -195,6 +199,7 @@ def histogram(name, values, collections=None, family=None): return val +@tf_export('summary.audio') def audio(name, tensor, sample_rate, max_outputs=3, collections=None, family=None): # pylint: disable=line-too-long @@ -242,6 +247,7 @@ def audio(name, tensor, sample_rate, max_outputs=3, collections=None, return val +@tf_export('summary.merge') def merge(inputs, collections=None, name=None): # pylint: disable=line-too-long """Merges summaries. @@ -286,6 +292,7 @@ def merge(inputs, collections=None, name=None): return val +@tf_export('summary.merge_all') def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None): """Merges all summaries collected in the default graph. @@ -318,6 +325,7 @@ def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None): return merge(summary_ops) +@tf_export('summary.get_summary_description') def get_summary_description(node_def): """Given a TensorSummary node_def, retrieve its SummaryDescription. diff --git a/tensorflow/python/summary/summary_iterator.py b/tensorflow/python/summary/summary_iterator.py index 6969c4cf1500bf4b1fda900336158e5af4395ea6..321b11ffb73487405428340df94010ed8ddbfcd4 100644 --- a/tensorflow/python/summary/summary_iterator.py +++ b/tensorflow/python/summary/summary_iterator.py @@ -21,8 +21,10 @@ from __future__ import print_function from tensorflow.core.util import event_pb2 from tensorflow.python.lib.io import tf_record +from tensorflow.python.util.tf_export import tf_export +@tf_export('train.summary_iterator') def summary_iterator(path): # pylint: disable=line-too-long """An iterator for reading `Event` protocol buffers from an event file. diff --git a/tensorflow/python/summary/text_summary.py b/tensorflow/python/summary/text_summary.py index 94a85d73e2f77388f9a29b1c135fc6046a8362d0..6418c847f3c819cf2491bb449921d15c39eae288 100644 --- a/tensorflow/python/summary/text_summary.py +++ b/tensorflow/python/summary/text_summary.py @@ -26,10 +26,12 @@ from __future__ import print_function from tensorflow.core.framework import summary_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.ops.summary_ops import tensor_summary +from tensorflow.python.util.tf_export import tf_export PLUGIN_NAME = "text" +@tf_export("summary.text") def text_summary(name, tensor, collections=None): """Summarizes textual data. diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index 12f120116f4439059f42c7212469ee835cc13ef4..1f3f2287043c021d636113b5a8807c9f4adf77aa 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -32,6 +32,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import plugin_asset from tensorflow.python.summary.writer.event_file_writer import EventFileWriter +from tensorflow.python.util.tf_export import tf_export _PLUGINS_DIR = "plugins" @@ -276,6 +277,7 @@ class SummaryToEventTransformer(object): self.event_writer.add_event(event) +@tf_export("summary.FileWriter") class FileWriter(SummaryToEventTransformer): """Writes `Summary` protocol buffers to event files. diff --git a/tensorflow/python/summary/writer/writer_cache.py b/tensorflow/python/summary/writer/writer_cache.py index bad289303c0fd0de7836b03a6762d04505521a89..645fa28a37fb125b6b1224961251bc8879d5fe6d 100644 --- a/tensorflow/python/summary/writer/writer_cache.py +++ b/tensorflow/python/summary/writer/writer_cache.py @@ -22,8 +22,10 @@ import threading from tensorflow.python.framework import ops from tensorflow.python.summary.writer.writer import FileWriter +from tensorflow.python.util.tf_export import tf_export +@tf_export('summary.FileWriterCache') class FileWriterCache(object): """Cache for file writers. diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index a2e86a1c43a9c27041d963b2b8d7af582e1054c7..f7a578e52b00b6d97b3b314c7a2a08d9071c8f73 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -101,15 +101,15 @@ def freeze_graph_with_def_protos(input_graph_def, _ = importer.import_graph_def(input_graph_def, name="") with session.Session() as sess: if input_saver_def: - saver = saver_lib.Saver(saver_def=input_saver_def, - write_version=checkpoint_version) + saver = saver_lib.Saver( + saver_def=input_saver_def, write_version=checkpoint_version) saver.restore(sess, input_checkpoint) elif input_meta_graph_def: restorer = saver_lib.import_meta_graph( input_meta_graph_def, clear_devices=True) restorer.restore(sess, input_checkpoint) if initializer_nodes: - sess.run(initializer_nodes.split(",")) + sess.run(initializer_nodes.replace(' ', '').split(",")) elif input_saved_model_dir: if saved_model_tags is None: saved_model_tags = [] @@ -126,29 +126,31 @@ def freeze_graph_with_def_protos(input_graph_def, # 'global_step' or a similar housekeeping element) so skip it. continue var_list[key] = tensor - saver = saver_lib.Saver(var_list=var_list, - write_version=checkpoint_version) + saver = saver_lib.Saver( + var_list=var_list, write_version=checkpoint_version) saver.restore(sess, input_checkpoint) if initializer_nodes: - sess.run(initializer_nodes.split(",")) + sess.run(initializer_nodes.replace(' ', '').split(",")) - variable_names_whitelist = (variable_names_whitelist.split(",") - if variable_names_whitelist else None) - variable_names_blacklist = (variable_names_blacklist.split(",") - if variable_names_blacklist else None) + variable_names_whitelist = ( + variable_names_whitelist.replace(' ', '').split(",") + if variable_names_whitelist else None) + variable_names_blacklist = ( + variable_names_blacklist.replace(' ', '').split(",") + if variable_names_blacklist else None) if input_meta_graph_def: output_graph_def = graph_util.convert_variables_to_constants( sess, input_meta_graph_def.graph_def, - output_node_names.split(","), + output_node_names.replace(' ', '').split(","), variable_names_whitelist=variable_names_whitelist, variable_names_blacklist=variable_names_blacklist) else: output_graph_def = graph_util.convert_variables_to_constants( sess, input_graph_def, - output_node_names.split(","), + output_node_names.replace(' ', '').split(","), variable_names_whitelist=variable_names_whitelist, variable_names_blacklist=variable_names_blacklist) @@ -237,21 +239,39 @@ def freeze_graph(input_graph, if input_saver: input_saver_def = _parse_input_saver_proto(input_saver, input_binary) freeze_graph_with_def_protos( - input_graph_def, input_saver_def, input_checkpoint, output_node_names, - restore_op_name, filename_tensor_name, output_graph, clear_devices, - initializer_nodes, variable_names_whitelist, variable_names_blacklist, - input_meta_graph_def, input_saved_model_dir, - saved_model_tags.split(","), checkpoint_version=checkpoint_version) + input_graph_def, + input_saver_def, + input_checkpoint, + output_node_names, + restore_op_name, + filename_tensor_name, + output_graph, + clear_devices, + initializer_nodes, + variable_names_whitelist, + variable_names_blacklist, + input_meta_graph_def, + input_saved_model_dir, + saved_model_tags.replace(' ', '').split(","), + checkpoint_version=checkpoint_version) def main(unused_args): + if FLAGS.checkpoint_version == 1: + checkpoint_version = saver_pb2.SaverDef.V1 + elif FLAGS.checkpoint_version == 2: + checkpoint_version = saver_pb2.SaverDef.V2 + else: + print("Invalid checkpoint version (must be '1' or '2'): %d" % + FLAGS.checkpoint_version) + return -1 freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary, FLAGS.input_checkpoint, FLAGS.output_node_names, FLAGS.restore_op_name, FLAGS.filename_tensor_name, FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes, FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist, FLAGS.input_meta_graph, FLAGS.input_saved_model_dir, - FLAGS.saved_model_tags, checkpoint_version=checkpoint_version) + FLAGS.saved_model_tags, checkpoint_version) if __name__ == "__main__": @@ -275,7 +295,7 @@ if __name__ == "__main__": parser.add_argument( "--checkpoint_version", type=int, - default=saver_pb2.SaverDef.V2, + default=2, help="Tensorflow variable file format") parser.add_argument( "--output_graph", diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py index 342732465d48f40a4ffeac97146fb1b6d564c568..91f0061ebccaebbdbb09f283d9d52d813459f493 100644 --- a/tensorflow/python/tools/freeze_graph_test.py +++ b/tensorflow/python/tools/freeze_graph_test.py @@ -84,9 +84,18 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): input_meta_graph = checkpoint_meta_graph_file freeze_graph.freeze_graph( - input_graph_path, input_saver_def_path, input_binary, checkpoint_path, - output_node_names, restore_op_name, filename_tensor_name, - output_graph_path, clear_devices, "", "", input_meta_graph, + input_graph_path, + input_saver_def_path, + input_binary, + checkpoint_path, + output_node_names, + restore_op_name, + filename_tensor_name, + output_graph_path, + clear_devices, + "", + "", + input_meta_graph, checkpoint_version=saver_write_version) # Now we make sure the variable is now a constant, and that the graph still diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py index 8716058e619d8e970834ec4d57e4d8ff21559d5c..dd876cbe7fcd64a8de70eb28f67996df9de1dd7d 100644 --- a/tensorflow/python/tools/inspect_checkpoint.py +++ b/tensorflow/python/tools/inspect_checkpoint.py @@ -97,8 +97,9 @@ def parse_numpy_printoption(kv_str): raise argparse.ArgumentTypeError( "Setting '%s' from the command line is not supported." % k) try: - v = (v_type(v_str) if v_type is not bool - else flags.BooleanParser().parse(v_str)) + v = ( + v_type(v_str) + if v_type is not bool else flags.BooleanParser().parse(v_str)) except ValueError as e: raise argparse.ArgumentTypeError(e.message) np.set_printoptions(**{k: v}) @@ -121,9 +122,12 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") parser.add_argument( - "--file_name", type=str, default="", help="Checkpoint filename. " - "Note, if using Checkpoint V2 format, file_name is the " - "shared prefix between all files in the checkpoint.") + "--file_name", + type=str, + default="", + help="Checkpoint filename. " + "Note, if using Checkpoint V2 format, file_name is the " + "shared prefix between all files in the checkpoint.") parser.add_argument( "--tensor_name", type=str, diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py index c2687bf557b03ff588fd369771077c92ba012a15..9c1927122252f45ddfa8092045c7589fa0f45532 100644 --- a/tensorflow/python/tools/optimize_for_inference_lib.py +++ b/tensorflow/python/tools/optimize_for_inference_lib.py @@ -349,6 +349,7 @@ def fold_batch_norms(input_graph_def): bias_add_op.op = "BiasAdd" bias_add_op.name = node.name bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"]) + bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"]) bias_add_op.input.extend([new_conv_op.name, offset_op.name]) new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op]) diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py index 6dd24c0dca1d326592e4f33eba4e6233248dac5f..084a4500f8e1eb7f75e1e01668fae655b5e06763 100644 --- a/tensorflow/python/tools/optimize_for_inference_test.py +++ b/tensorflow/python/tools/optimize_for_inference_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import image_ops @@ -38,6 +39,7 @@ from tensorflow.python.platform import test from tensorflow.python.tools import optimize_for_inference_lib +@test_util.with_c_api class OptimizeForInferenceTest(test.TestCase): def create_node_def(self, op, name, inputs): @@ -145,7 +147,7 @@ class OptimizeForInferenceTest(test.TestCase): np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32) gamma_op = constant_op.constant( np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32) - ops.get_default_graph().graph_def_versions.producer = 8 + test_util.set_producer_version(ops.get_default_graph(), 8) gen_nn_ops._batch_norm_with_global_normalization( conv_op, mean_op, @@ -171,48 +173,56 @@ class OptimizeForInferenceTest(test.TestCase): self.assertNotEqual("BatchNormWithGlobalNormalization", node.op) def testFoldFusedBatchNorms(self): - with self.test_session() as sess: - inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6] - input_op = constant_op.constant( - np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32) - weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4] - weights_op = constant_op.constant( - np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32) - conv_op = nn_ops.conv2d( - input_op, weights_op, [1, 1, 1, 1], padding="SAME", name="conv_op") - mean_op = constant_op.constant( - np.array([10, 20]), shape=[2], dtype=dtypes.float32) - variance_op = constant_op.constant( - np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32) - beta_op = constant_op.constant( - np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32) - gamma_op = constant_op.constant( - np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32) - ops.get_default_graph().graph_def_versions.producer = 9 - gen_nn_ops._fused_batch_norm( - conv_op, - gamma_op, - beta_op, - mean_op, - variance_op, - 0.00001, - is_training=False, - name="output") - original_graph_def = sess.graph_def - original_result = sess.run(["output:0"]) - optimized_graph_def = optimize_for_inference_lib.fold_batch_norms( - original_graph_def) - - with self.test_session() as sess: - _ = importer.import_graph_def( - optimized_graph_def, input_map={}, name="optimized") - optimized_result = sess.run(["optimized/output:0"]) - - self.assertAllClose( - original_result, optimized_result, rtol=1e-04, atol=1e-06) - - for node in optimized_graph_def.node: - self.assertNotEqual("FusedBatchNorm", node.op) + for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]: + with self.test_session(use_gpu=use_gpu) as sess: + inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6] + input_op = constant_op.constant( + np.array(inputs), + shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6], + dtype=dtypes.float32) + weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4] + weights_op = constant_op.constant( + np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32) + conv_op = nn_ops.conv2d( + input_op, + weights_op, [1, 1, 1, 1], + padding="SAME", + data_format=data_format, + name="conv_op") + mean_op = constant_op.constant( + np.array([10, 20]), shape=[2], dtype=dtypes.float32) + variance_op = constant_op.constant( + np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32) + beta_op = constant_op.constant( + np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32) + gamma_op = constant_op.constant( + np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32) + ops.get_default_graph().graph_def_versions.producer = 9 + gen_nn_ops._fused_batch_norm( + conv_op, + gamma_op, + beta_op, + mean_op, + variance_op, + 0.00001, + is_training=False, + data_format=data_format, + name="output") + original_graph_def = sess.graph_def + original_result = sess.run(["output:0"]) + optimized_graph_def = optimize_for_inference_lib.fold_batch_norms( + original_graph_def) + + with self.test_session(use_gpu=use_gpu) as sess: + _ = importer.import_graph_def( + optimized_graph_def, input_map={}, name="optimized") + optimized_result = sess.run(["optimized/output:0"]) + + self.assertAllClose( + original_result, optimized_result, rtol=1e-04, atol=1e-06) + + for node in optimized_graph_def.node: + self.assertNotEqual("FusedBatchNorm", node.op) def testFuseResizePadAndConv(self): with self.test_session() as sess: diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 21e8e803fcb3d12a2e41b5f9e2810742ec220be8..33f6debbcbecb652774c776be54323bbaa824822 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -31,6 +31,7 @@ import warnings import numpy as np +from six import integer_types from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.example import example_pb2 @@ -38,7 +39,7 @@ from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.framework import ops as ops_lib -from tensorflow.python.platform import app +from tensorflow.python.platform import app # pylint: disable=unused-import from tensorflow.python.saved_model import loader from tensorflow.python.tools import saved_model_utils @@ -440,7 +441,7 @@ def _create_example_string(example_dict): elif isinstance(feature_list[0], str): example.features.feature[feature_name].bytes_list.value.extend( feature_list) - elif isinstance(feature_list[0], (int, long)): + elif isinstance(feature_list[0], integer_types): example.features.feature[feature_name].int64_list.value.extend( feature_list) else: diff --git a/tensorflow/python/training/adadelta.py b/tensorflow/python/training/adadelta.py index 13c07cfd7bf4333fee3edc3c3ad9d2fb7bcbaad2..c08e3cca007dc17f1112d53bf729c1accf61b5df 100644 --- a/tensorflow/python/training/adadelta.py +++ b/tensorflow/python/training/adadelta.py @@ -22,8 +22,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.AdadeltaOptimizer") class AdadeltaOptimizer(optimizer.Optimizer): """Optimizer that implements the Adadelta algorithm. diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py index afa192f7cc6e0ecd629fd94252d26961f1407183..deb4e6f546379eff330235dbc302a30c44193830 100644 --- a/tensorflow/python/training/adagrad.py +++ b/tensorflow/python/training/adagrad.py @@ -25,8 +25,10 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.AdagradOptimizer") class AdagradOptimizer(optimizer.Optimizer): """Optimizer that implements the Adagrad algorithm. diff --git a/tensorflow/python/training/adagrad_da.py b/tensorflow/python/training/adagrad_da.py index b3f9ea323c2bb4fd9ecee93863fbc7955b47a947..5ba403554f570d9df33a5d525a40de2eb0d11138 100644 --- a/tensorflow/python/training/adagrad_da.py +++ b/tensorflow/python/training/adagrad_da.py @@ -23,8 +23,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.AdagradDAOptimizer") class AdagradDAOptimizer(optimizer.Optimizer): """Adagrad Dual Averaging algorithm for sparse linear models. diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py index 0c69f8bf3997452f0eeb71c93f4fcf98eb27d8f9..c92f6fc3015960a2b821651231bb94713e0d53dd 100644 --- a/tensorflow/python/training/adam.py +++ b/tensorflow/python/training/adam.py @@ -26,8 +26,10 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.AdamOptimizer") class AdamOptimizer(optimizer.Optimizer): """Optimizer that implements the Adam algorithm. diff --git a/tensorflow/python/training/basic_loops.py b/tensorflow/python/training/basic_loops.py index 52b0f4210612bad4a2e838153ac9cbdb1023bf66..7af821c81928e67e0f258bc064d582a4186995c1 100644 --- a/tensorflow/python/training/basic_loops.py +++ b/tensorflow/python/training/basic_loops.py @@ -18,8 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import errors +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.basic_train_loop") def basic_train_loop(supervisor, train_step_fn, args=None, kwargs=None, master=""): """Basic loop to train a model. diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index 752d585cd17e1b1a89abbae7c9e61fa966ad7f93..aae757b99aa9abb2fca112dcc781fc31e367649d 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -47,6 +47,7 @@ from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.training.session_run_hook import SessionRunArgs from tensorflow.python.training.summary_io import SummaryWriterCache +from tensorflow.python.util.tf_export import tf_export class _HookTimer(object): @@ -85,6 +86,7 @@ class _HookTimer(object): raise NotImplementedError +@tf_export("train.SecondOrStepTimer") class SecondOrStepTimer(_HookTimer): """Timer that triggers at most once every N seconds or once every N steps. """ @@ -164,6 +166,7 @@ class NeverTriggerTimer(_HookTimer): return None +@tf_export("train.LoggingTensorHook") class LoggingTensorHook(session_run_hook.SessionRunHook): """Prints the given tensors every N local steps, every N seconds, or at end. @@ -262,6 +265,7 @@ class LoggingTensorHook(session_run_hook.SessionRunHook): self._log_tensors(values) +@tf_export("train.StopAtStepHook") class StopAtStepHook(session_run_hook.SessionRunHook): """Hook that requests stop at a specified step.""" @@ -317,6 +321,7 @@ class StopAtStepHook(session_run_hook.SessionRunHook): run_context.request_stop() +@tf_export("train.CheckpointSaverListener") class CheckpointSaverListener(object): """Interface for listeners that take action before or after checkpoint save. @@ -331,7 +336,7 @@ class CheckpointSaverListener(object): `CheckpointSaverHook`, as in this example: ```python - class ExampleCheckpointSaverListerner(CheckpointSaverListener): + class ExampleCheckpointSaverListener(CheckpointSaverListener): def begin(self): # You can add ops to the graph here. print('Starting the session.') @@ -347,7 +352,7 @@ class CheckpointSaverListener(object): print('Done with the session.') ... - listener = ExampleCheckpointSaverListerner() + listener = ExampleCheckpointSaverListener() saver_hook = tf.train.CheckpointSaverHook( checkpoint_dir, listeners=[listener]) with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]): @@ -375,6 +380,7 @@ class CheckpointSaverListener(object): pass +@tf_export("train.CheckpointSaverHook") class CheckpointSaverHook(session_run_hook.SessionRunHook): """Saves checkpoints every N steps or seconds.""" @@ -497,6 +503,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): return savers[0] +@tf_export("train.StepCounterHook") class StepCounterHook(session_run_hook.SessionRunHook): """Hook that counts steps per second.""" @@ -575,12 +582,14 @@ class StepCounterHook(session_run_hook.SessionRunHook): self._last_global_step = stale_global_step +@tf_export("train.NanLossDuringTrainingError") class NanLossDuringTrainingError(RuntimeError): def __str__(self): return "NaN loss during training." +@tf_export("train.NanTensorHook") class NanTensorHook(session_run_hook.SessionRunHook): """Monitors the loss tensor and stops training if loss is NaN. @@ -612,6 +621,7 @@ class NanTensorHook(session_run_hook.SessionRunHook): run_context.request_stop() +@tf_export("train.SummarySaverHook") class SummarySaverHook(session_run_hook.SessionRunHook): """Saves summaries every N steps.""" @@ -720,6 +730,7 @@ class SummarySaverHook(session_run_hook.SessionRunHook): return summary_op +@tf_export("train.GlobalStepWaiterHook") class GlobalStepWaiterHook(session_run_hook.SessionRunHook): """Delays execution until global step reaches `wait_until_step`. @@ -767,6 +778,7 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook): time.sleep(0.5) +@tf_export("train.FinalOpsHook") class FinalOpsHook(session_run_hook.SessionRunHook): """A hook which evaluates `Tensors` at the end of a session.""" @@ -793,6 +805,7 @@ class FinalOpsHook(session_run_hook.SessionRunHook): feed_dict=self._final_ops_feed_dict) +@tf_export("train.FeedFnHook") class FeedFnHook(session_run_hook.SessionRunHook): """Runs `feed_fn` and sets the `feed_dict` accordingly.""" @@ -810,6 +823,7 @@ class FeedFnHook(session_run_hook.SessionRunHook): fetches=None, feed_dict=self.feed_fn()) +@tf_export("train.ProfilerHook") class ProfilerHook(session_run_hook.SessionRunHook): """Captures CPU/GPU profiling information every N steps or seconds. diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index 5054873bc1c7751e6164a868b91b8ef7be0a5c79..fa3de6fad27b6cc773f9f2e86e9f95395eb7c285 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver +from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -36,6 +37,7 @@ __all__ = [ ] +@tf_export("train.load_checkpoint") def load_checkpoint(ckpt_dir_or_file): """Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`. @@ -60,6 +62,7 @@ def load_checkpoint(ckpt_dir_or_file): return pywrap_tensorflow.NewCheckpointReader(filename) +@tf_export("train.load_variable") def load_variable(ckpt_dir_or_file, name): """Returns the tensor value of the given variable in the checkpoint. @@ -77,6 +80,7 @@ def load_variable(ckpt_dir_or_file, name): return reader.get_tensor(name) +@tf_export("train.list_variables") def list_variables(ckpt_dir_or_file): """Returns list of all variables in the checkpoint. @@ -95,6 +99,7 @@ def list_variables(ckpt_dir_or_file): return result +@tf_export("train.init_from_checkpoint") def init_from_checkpoint(ckpt_dir_or_file, assignment_map): """Initializes current variables with tensors loaded from given checkpoint. @@ -176,7 +181,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file) reader = load_checkpoint(ckpt_dir_or_file) variable_map = reader.get_variable_to_shape_map() - for tensor_name_in_ckpt, current_var_or_name in six.iteritems(assignment_map): + for tensor_name_in_ckpt, current_var_or_name in sorted( + six.iteritems(assignment_map)): var = None # Check if this is Variable object or list of Variable objects (in case of # partitioned variables). @@ -233,7 +239,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): if "/part_" in var_name: var_name = var_name[:var_name.index("/part_")] scope_variables.add(var_name) - for var_name in scope_variables: + for var_name in sorted(scope_variables): # Lookup name with specified prefix and suffix from current variable. # If tensor_name given is '/' (root), don't use it for full name. full_tensor_name = var_name[len(scopes):] @@ -241,6 +247,9 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): full_tensor_name = full_tensor_name[1:] if tensor_name_in_ckpt != "/": full_tensor_name = tensor_name_in_ckpt + full_tensor_name + # Remove trailing '/', if any, in the full_tensor_name + if full_tensor_name.endswith("/"): + full_tensor_name = full_tensor_name[:-1] if full_tensor_name not in variable_map: raise ValueError( "Tensor %s (%s in %s) is not found in %s checkpoint" % ( diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py new file mode 100644 index 0000000000000000000000000000000000000000..9d62c5ff91449189eeed8d017c21134a1230c3c8 --- /dev/null +++ b/tensorflow/python/training/checkpointable.py @@ -0,0 +1,582 @@ +"""An object-local variable management scheme.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import weakref + +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_io_ops as io_ops +from tensorflow.python.util import nest + +# A key indicating a variable's value in an object's checkpointed Tensors +# (Checkpointable._gather_tensors_for_checkpoint). If this is the only key and +# the object has no dependencies, then its value may be restored on object +# creation (avoiding double assignment when executing eagerly). +VARIABLE_VALUE_KEY = "VARIABLE_VALUE" + +_CheckpointableReference = collections.namedtuple( + "_CheckpointableReference", + [ + # The local name for this dependency. + "name", + # The Checkpointable object being referenced. + "ref" + ]) + + +class CheckpointInitialValue(ops.Tensor): + """Tensor wrapper for managing update UIDs in `Variables`. + + When supplied as an initial value, objects of this type let a `Variable` + (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the initial + value came from. This allows deferred restorations to be sequenced in the + order the user specified them, and lets us fall back on assignment if an + initial value is not set (e.g. due to a custom getter interfering). + + See comments in _add_variable_with_custom_getter for more information about + how `CheckpointInitialValue` is used. + """ + + def __init__(self, checkpoint_position, shape=None): + self.wrapped_value = checkpoint_position.restore_ops()[ + VARIABLE_VALUE_KEY] + if shape: + # We need to set the static shape information on the initializer if + # possible so we don't get a variable with an unknown shape. + self.wrapped_value.set_shape(shape) + self._checkpoint_position = checkpoint_position + + @property + def __class__(self): + return (self.wrapped_value.__class__, CheckpointInitialValue) + + def __getattr__(self, attr): + try: + return getattr(self.wrapped_value, attr) + except AttributeError: + return self.__getattribute__(attr) + + @property + def checkpoint_position(self): + return self._checkpoint_position + + +class _CheckpointPosition(object): + """Indicates a position within a `_Checkpoint`.""" + + def __init__(self, checkpoint, proto_id): + """Specify an object within a checkpoint. + + Args: + checkpoint: A _Checkpoint object. + proto_id: The index of this object in CheckpointableObjectGraph.nodes. + """ + self._checkpoint = checkpoint + self._proto_id = proto_id + + def restore(self, checkpointable): + """Restore this value into `checkpointable`.""" + if self.bind_object(checkpointable): + # This object's correspondence with a checkpointed object is new, so + # process deferred restorations for it and its dependencies. + restore_ops = checkpointable._restore_from_checkpoint_position(self) # pylint: disable=protected-access + if restore_ops: + self._checkpoint.restore_ops.extend(restore_ops) + + def bind_object(self, checkpointable): + """Set a checkpoint<->object correspondence and process slot variables. + + Args: + checkpointable: The object to record a correspondence for. + Returns: + True if this is a new assignment, False if this object has already been + mapped to a checkpointed `Object` proto. + Raises: + AssertionError: If another object is already bound to the `Object` proto. + """ + checkpoint = self.checkpoint + current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None) + if current_assignment is None: + checkpoint.object_by_proto_id[self._proto_id] = checkpointable + for deferred_slot_restoration in ( + checkpoint.deferred_slot_restorations.pop(self._proto_id, ())): + checkpointable._create_or_restore_slot_variable( # pylint: disable=protected-access + slot_variable_position=_CheckpointPosition( + checkpoint=checkpoint, + proto_id=deferred_slot_restoration.slot_variable_id), + variable=deferred_slot_restoration.original_variable, + slot_name=deferred_slot_restoration.slot_name) + for slot_restoration in checkpoint.slot_restorations.pop( + self._proto_id, ()): + optimizer_object = checkpoint.object_by_proto_id.get( + slot_restoration.optimizer_id, None) + if optimizer_object is None: + # The optimizer has not yet been created or tracked. Record in the + # checkpoint that the slot variables need to be restored when it is. + checkpoint.deferred_slot_restorations.setdefault( + slot_restoration.optimizer_id, []).append( + _DeferredSlotVariableRestoration( + original_variable=checkpointable, + slot_variable_id=slot_restoration.slot_variable_id, + slot_name=slot_restoration.slot_name)) + else: + optimizer_object._create_or_restore_slot_variable( # pylint: disable=protected-access + slot_variable_position=_CheckpointPosition( + checkpoint=checkpoint, + proto_id=slot_restoration.slot_variable_id), + variable=checkpointable, + slot_name=slot_restoration.slot_name) + return True # New assignment + else: + # The object was already mapped for this checkpoint load, which means + # we don't need to do anything besides check that the mapping is + # consistent (if the dependency DAG is not a tree then there are + # multiple paths to the same object). + if current_assignment is not checkpointable: + raise AssertionError( + ("Unable to load the checkpoint into this object graph. Either " + "the Checkpointable object references in the Python program " + "have changed in an incompatible way, or the checkpoint was " + "generated in an incompatible program.\n\nTwo checkpoint " + "references resolved to different objects (%s and %s).") + % (current_assignment, checkpointable)) + return False # Not a new assignment + + def is_simple_variable(self): + """Determine whether this value is restorable with a Tensor initializer.""" + attributes = self.object_proto.attributes + return (len(attributes) == 1 + and attributes[0].name == VARIABLE_VALUE_KEY + and not self.object_proto.children) + + def restore_ops(self): + """Create restore ops for this object's attributes.""" + restore_tensors = {} + for serialized_tensor in self.object_proto.attributes: + checkpoint_key = serialized_tensor.checkpoint_key + dtype = self._checkpoint.dtype_map[checkpoint_key] + base_type = dtype.base_dtype + with ops.init_scope(): + restore, = io_ops.restore_v2( + prefix=self._checkpoint.save_path, + tensor_names=[checkpoint_key], + shape_and_slices=[""], + dtypes=[base_type], + name="%s_checkpoint_read" % (serialized_tensor.name,)) + restore_tensors[serialized_tensor.name] = restore + return restore_tensors + + @property + def checkpoint(self): + return self._checkpoint + + @property + def checkpointable(self): + return self._checkpoint.object_by_proto_id[self._proto_id] + + @property + def object_proto(self): + return self._checkpoint.object_graph_proto.nodes[self._proto_id] + + @property + def restore_uid(self): + return self._checkpoint.restore_uid + + def __repr__(self): + return repr(self.object_proto) + + +_DeferredSlotVariableRestoration = collections.namedtuple( + "_DeferredSlotVariableRestoration", + [ + "original_variable", + "slot_variable_id", + "slot_name", + ] +) + +_SlotVariableRestoration = collections.namedtuple( + "_SlotVariableRestoration", + [ + # The checkpoint proto id of the optimizer object. + "optimizer_id", + # The checkpoint proto id of the slot variable. + "slot_variable_id", + "slot_name", + ]) + + +class _Checkpoint(object): + """Holds the status of an object-based checkpoint load.""" + + def __init__(self, object_graph_proto, save_path): + """Specify the checkpoint being loaded. + + Args: + object_graph_proto: The CheckpointableObjectGraph protocol buffer + associated with this checkpoint. + save_path: The path to the checkpoint, as returned by + `tf.train.latest_checkpoint`. + """ + self.object_graph_proto = object_graph_proto + self.restore_uid = ops.uid() + # Dictionary mapping from an id in the protocol buffer flat array to + # Checkpointable Python objects. This mapping may be deferred if a + # checkpoint is restored before all dependencies have been tracked. Uses + # weak references so that partial restorations don't create reference cycles + # (as objects with deferred dependencies will generally have references to + # this object). + self.object_by_proto_id = weakref.WeakValueDictionary() + self.save_path = save_path + reader = pywrap_tensorflow.NewCheckpointReader(save_path) + self.dtype_map = reader.get_variable_to_dtype_map() + # When graph building, contains a list of ops to run to restore objects from + # this checkpoint. + self.restore_ops = [] + # A mapping from optimizer proto ids to lists of slot variables to be + # restored when the optimizer is tracked. Only includes slot variables whose + # regular variables have already been created, and only for optimizer + # objects which have not yet been created/tracked. + self.deferred_slot_restorations = {} + # A mapping from variable proto ids to lists of slot variables to be + # restored when the variable is created/tracked. These get shifted over to + # deferred_slot_restorations if the optimizer hasn't been created when that + # happens. + self.slot_restorations = {} + for node_index, node in enumerate(self.object_graph_proto.nodes): + for slot_reference in node.slot_variables: + # `node` refers to an `Optimizer`, since only these have slot variables. + self.slot_restorations.setdefault( + slot_reference.original_variable_node_id, []).append( + _SlotVariableRestoration( + optimizer_id=node_index, + slot_variable_id=slot_reference.slot_variable_node_id, + slot_name=slot_reference.slot_name)) + + +class CheckpointableBase(object): + """Base class for `Checkpointable` objects without automatic dependencies. + + This class has no __setattr__ override for performance reasons. Dependencies + must be added explicitly. Unless attribute assignment is performance-critical, + use `Checkpointable` instead. Use `CheckpointableBase` for `isinstance` + checks. + """ + + def _maybe_initialize_checkpointable(self): + """Initialize dependency management. + + Not __init__, since most objects will forget to call it. + """ + if hasattr(self, "_checkpoint_dependencies"): + # __init__ already called. This check means that we don't need + # Checkpointable.__init__() in the constructor of every TensorFlow object. + return + # A list of _CheckpointableReference objects. + self._checkpoint_dependencies = [] + # Maps names -> Checkpointable objects + self._dependency_names = {} + # Restorations for other Checkpointable objects on which this object may + # eventually depend. + self._deferred_dependencies = {} # local name -> _CheckpointPosition list + # The UID of the highest assignment to this object. Used to ensure that the + # last requested assignment determines the final value of an object. + if hasattr(self, "_update_uid"): + raise AssertionError( + "Internal error: the object had an update UID set before its " + "initialization code was run.") + self._update_uid = -1 + + def _add_variable_with_custom_getter( + self, name, shape=None, dtype=dtypes.float32, + initializer=None, getter=None, **kwargs_for_getter): + """Restore-on-create for a variable be saved with this `Checkpointable`. + + If the user has requested that this object or another `Checkpointable` which + depends on this object be restored from a checkpoint (deferred loading + before variable object creation), `initializer` may be ignored and the value + from the checkpoint used instead. + + Args: + name: A name for the variable. Must be unique within this object. + shape: The shape of the variable. + dtype: The data type of the variable. + + initializer: The initializer to use. Ignored if there is a deferred + restoration left over from a call to + `_restore_from_checkpoint_position`. + + getter: The getter to wrap which actually fetches the variable. + **kwargs_for_getter: Passed to the getter. + + Returns: + The new variable object. + + Raises: + ValueError: If the variable name is not unique. + """ + self._maybe_initialize_checkpointable() + if name in self._dependency_names: + raise ValueError( + ("A variable named '%s' already exists in this Checkpointable, but " + "Checkpointable._add_variable called to create another with " + "that name. Variable names must be unique within a Checkpointable " + "object.") % (name,)) + if context.in_eager_mode(): + # If this is a variable with a single Tensor stored in the checkpoint, we + # can set that value as an initializer rather than initializing and then + # assigning (when executing eagerly). This call returns None if there is + # nothing to restore. + checkpoint_initializer = self._preload_simple_restoration( + name=name, shape=shape) + else: + checkpoint_initializer = None + if (checkpoint_initializer is not None + and not ( + isinstance(initializer, CheckpointInitialValue) + and initializer.restore_uid > checkpoint_initializer.restore_uid)): + # If multiple Checkpointable objects are "creating" the same variable via + # the magic of custom getters, the one with the highest restore UID (the + # one called last) has to make the final initializer. If another custom + # getter interrupts this process by overwriting the initializer, then + # we'll catch that when we call _track_checkpointable. So this is "best + # effort" to set the initializer with the highest restore UID. + initializer = checkpoint_initializer + shape = None + + new_variable = getter( + name=name, shape=shape, dtype=dtype, initializer=initializer, + **kwargs_for_getter) + + # If we set an initializer and the variable processed it, tracking will not + # assign again. It will add this variable to our dependencies, and if there + # is a non-trivial restoration queued, it will handle that. This also + # handles slot variables. + return self._track_checkpointable(new_variable, name=name) + + def _preload_simple_restoration(self, name, shape): + """Return a dependency's value for restore-on-create. + + Note the restoration is not deleted; if for some reason preload is called + and then not assigned to the variable (for example because a custom getter + overrides the initializer), the assignment will still happen once the + variable is tracked (determined based on checkpoint.restore_uid). + + Args: + name: The object-local name of the dependency holding the variable's + value. + shape: The shape of the variable being loaded into. + Returns: + An callable for use as a variable's initializer/initial_value, or None if + one should not be set (either because there was no variable with this name + in the checkpoint or because it needs more complex deserialization). Any + non-trivial deserialization will happen when the variable object is + tracked. + """ + deferred_dependencies_list = self._deferred_dependencies.get(name, ()) + if not deferred_dependencies_list: + # Nothing to do; we don't have a restore for this dependency queued up. + return + for checkpoint_position in deferred_dependencies_list: + if not checkpoint_position.is_simple_variable(): + # If _any_ pending restoration is too complicated to fit in an + # initializer (because it has dependencies, or because there are + # multiple Tensors to restore), bail and let the general tracking code + # handle it. + return None + checkpoint_position = max( + deferred_dependencies_list, + key=lambda restore: restore.checkpoint.restore_uid) + return CheckpointInitialValue( + checkpoint_position=checkpoint_position, shape=shape) + + def _track_checkpointable(self, checkpointable, name, overwrite=False): + """Declare a dependency on another `Checkpointable` object. + + Indicates that checkpoints for this object should include variables from + `checkpointable`. + + Variables in a checkpoint are mapped to `Checkpointable`s based on names if + provided when the checkpoint was written, but otherwise use the order those + `Checkpointable`s were declared as dependencies. + + To avoid breaking existing checkpoints when modifying a class, neither + variable names nor dependency names (the names passed to + `track_checkpointable`) may change. + + Args: + checkpointable: A `Checkpointable` which this object depends on. + name: A local name for `checkpointable`, used for loading checkpoints into + the correct objects. + overwrite: Boolean, whether silently replacing dependencies is OK. Used + for __setattr__, where throwing an error on attribute reassignment would + be inappropriate. + + Returns: + `checkpointable`, for convenience when declaring a dependency and + assigning to a member variable in one statement. + + Raises: + TypeError: If `checkpointable` does not inherit from `Checkpointable`. + ValueError: If another object is already tracked by this name. + """ + self._maybe_initialize_checkpointable() + if not isinstance(checkpointable, CheckpointableBase): + raise TypeError( + ("Checkpointable._track_checkpointable() passed type %s, not a " + "Checkpointable.") % (type(checkpointable),)) + new_reference = _CheckpointableReference(name=name, ref=checkpointable) + if (name in self._dependency_names + and self._dependency_names[name] is not checkpointable): + if not overwrite: + raise ValueError( + ("Called Checkpointable._track_checkpointable() with name='%s', " + "but a Checkpointable with this name is already declared as a " + "dependency. Names must be unique (or overwrite=True).") % (name,)) + # This is a weird thing to do, but we're not going to stop people from + # using __setattr__. + for index, (old_name, _) in enumerate(self._checkpoint_dependencies): + if name == old_name: + self._checkpoint_dependencies[index] = new_reference + else: + self._checkpoint_dependencies.append(new_reference) + + self._dependency_names[name] = checkpointable + deferred_dependency_list = self._deferred_dependencies.pop(name, None) + if deferred_dependency_list is not None: + for checkpoint_position in deferred_dependency_list: + checkpoint_position.restore(checkpointable=checkpointable) + return checkpointable + + def _restore_from_checkpoint_position(self, checkpoint_position): + """Restore this object and its dependencies (may be deferred).""" + # Attempt a breadth-first traversal, since presumably the user has more + # control over shorter paths. If we don't have all of the dependencies at + # this point, the end result is not breadth-first (since other deferred + # traversals will happen later). + visit_queue = collections.deque([checkpoint_position]) + restore_ops = [] + while visit_queue: + current_position = visit_queue.popleft() + restore_ops.extend(nest.flatten( + current_position.checkpointable # pylint: disable=protected-access + ._single_restoration_from_checkpoint_position( + checkpoint_position=current_position, + visit_queue=visit_queue))) + return restore_ops + + def _single_restoration_from_checkpoint_position( + self, checkpoint_position, visit_queue): + """Restore this object, and either queue its dependencies or defer them.""" + self._maybe_initialize_checkpointable() + checkpoint = checkpoint_position.checkpoint + # If the UID of this restore is lower than our current update UID, we don't + # need to actually restore the object. However, we should pass the + # restoration on to our dependencies. + if checkpoint.restore_uid > self._update_uid: + restore_op = self._scatter_tensors_from_checkpoint( + checkpoint_position.restore_ops()) + self._update_uid = checkpoint.restore_uid + else: + restore_op = () + for child in checkpoint_position.object_proto.children: + child_position = _CheckpointPosition( + checkpoint=checkpoint, + proto_id=child.node_id) + local_object = self._dependency_names.get(child.local_name, None) + if local_object is None: + # We don't yet have a dependency registered with this name. Save it + # in case we do. + self._deferred_dependencies.setdefault(child.local_name, []).append( + child_position) + else: + if child_position.bind_object(checkpointable=local_object): + # This object's correspondence is new, so dependencies need to be + # visited. Delay doing it so that we get a breadth-first dependency + # resolution order (shallowest paths first). The caller is responsible + # for emptying visit_queue. + visit_queue.append(child_position) + return restore_op + + def _scatter_tensors_from_checkpoint(self, attributes): + """Restores this object from a checkpoint. + + Args: + attributes: A dictionary of Tensors, with key corresponding to those + returned from _gather_tensors_for_checkpoint. + Returns: + A restore op to run (if graph building). + """ + if attributes: + raise AssertionError( + ("A Checkpointable object which was not expecting any data received " + "some from a checkpoint. (Got %s)") % (attributes,)) + return () # No restore ops + + def _gather_tensors_for_checkpoint(self): + """Returns a dictionary of Tensors to save with this object.""" + return {} + + +class Checkpointable(CheckpointableBase): + """Manages dependencies on other objects. + + `Checkpointable` objects may have dependencies: other `Checkpointable` objects + which should be saved if the object declaring the dependency is saved. A + correctly saveable program has a dependency graph such that if changing a + global variable affects an object (e.g. changes the behavior of any of its + methods) then there is a chain of dependencies from the influenced object to + the variable. + + Dependency edges have names, and are created implicitly when a + `Checkpointable` object is assigned to an attribute of another + `Checkpointable` object. For example: + + ``` + obj = Checkpointable() + obj.v = ResourceVariable(0.) + ``` + + The `Checkpointable` object `obj` now has a dependency named "v" on a + variable. + + `Checkpointable` objects may specify `Tensor`s to be saved and restored + directly (e.g. a `Variable` indicating how to save itself) rather than through + dependencies on other objects. See + `Checkpointable._scatter_tensors_from_checkpoint` and + `Checkpointable._gather_tensors_for_checkpoint` for details. + """ + + def __setattr__(self, name, value): + """Support self.foo = checkpointable syntax.""" + # Perform the attribute assignment, and potentially call other __setattr__ + # overrides such as that for tf.keras.Model. + super(Checkpointable, self).__setattr__(name, value) + if isinstance(value, CheckpointableBase): + self._track_checkpointable( + value, name=name, + # Allow the user to switch the Checkpointable which is tracked by this + # name, since assigning a new variable to an attribute has + # historically been fine (e.g. Adam did this). + # TODO(allenl): Should this be a warning once Checkpointable save/load + # is usable? + overwrite=True) diff --git a/tensorflow/python/training/checkpointable_test.py b/tensorflow/python/training/checkpointable_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e79acb49758b6a7d69dd084692d434bea808db64 --- /dev/null +++ b/tensorflow/python/training/checkpointable_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import test +from tensorflow.python.training import checkpointable + + +class InterfaceTests(test.TestCase): + + def testMultipleAssignment(self): + root = checkpointable.Checkpointable() + root.leaf = checkpointable.Checkpointable() + root.leaf = root.leaf + duplicate_name_dep = checkpointable.Checkpointable() + with self.assertRaises(ValueError): + root._track_checkpointable(duplicate_name_dep, name="leaf") + # No error; we're overriding __setattr__, so we can't really stop people + # from doing this while maintaining backward compatibility. + root.leaf = duplicate_name_dep + root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py index 0e31255b74f64657cffc4a2f58798835513f0444..0ff97d85e37e6167f1200ba56940f4a663c259a2 100644 --- a/tensorflow/python/training/coordinator.py +++ b/tensorflow/python/training/coordinator.py @@ -27,8 +27,10 @@ import six from tensorflow.python.framework import errors from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.Coordinator") class Coordinator(object): """A coordinator for threads. @@ -406,6 +408,7 @@ class Coordinator(object): # Threads for the standard services. +@tf_export("train.LooperThread") class LooperThread(threading.Thread): """A thread that runs code repeatedly, optionally on a timer. diff --git a/tensorflow/python/training/coordinator_test.py b/tensorflow/python/training/coordinator_test.py index 149d3eed414d53f46dcab403b7b4822ffa66e644..3e4ac1dfff9708fd1a5cd8bdf23f99d8f963bd16 100644 --- a/tensorflow/python/training/coordinator_test.py +++ b/tensorflow/python/training/coordinator_test.py @@ -85,8 +85,8 @@ class CoordinatorTest(test.TestCase): self.assertFalse(coord.wait_for_stop(0.1)) wait_for_stop_ev = threading.Event() has_stopped_ev = threading.Event() - t = threading.Thread(target=StopOnEvent, - args=(coord, wait_for_stop_ev, has_stopped_ev)) + t = threading.Thread( + target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev)) t.start() self.assertFalse(coord.should_stop()) self.assertFalse(coord.wait_for_stop(0.01)) @@ -100,7 +100,8 @@ class CoordinatorTest(test.TestCase): threads = [ threading.Thread(target=SleepABit, args=(0.01,)), threading.Thread(target=SleepABit, args=(0.02,)), - threading.Thread(target=SleepABit, args=(0.01,))] + threading.Thread(target=SleepABit, args=(0.01,)) + ] for t in threads: t.start() coord.join(threads) @@ -112,7 +113,8 @@ class CoordinatorTest(test.TestCase): threads = [ threading.Thread(target=SleepABit, args=(0.01, coord)), threading.Thread(target=SleepABit, args=(0.02, coord)), - threading.Thread(target=SleepABit, args=(0.01, coord))] + threading.Thread(target=SleepABit, args=(0.01, coord)) + ] for t in threads: t.start() WaitForThreadsToRegister(coord, 3) @@ -125,7 +127,8 @@ class CoordinatorTest(test.TestCase): threads = [ threading.Thread(target=SleepABit, args=(0.01, coord)), threading.Thread(target=SleepABit, args=(0.02,)), - threading.Thread(target=SleepABit, args=(0.01, coord))] + threading.Thread(target=SleepABit, args=(0.01, coord)) + ] for t in threads: t.start() WaitForThreadsToRegister(coord, 2) @@ -135,14 +138,17 @@ class CoordinatorTest(test.TestCase): self.assertFalse(t.is_alive()) def testJoinGraceExpires(self): + def TestWithGracePeriod(stop_grace_period): coord = coordinator.Coordinator() wait_for_stop_ev = threading.Event() has_stopped_ev = threading.Event() threads = [ - threading.Thread(target=StopOnEvent, - args=(coord, wait_for_stop_ev, has_stopped_ev)), - threading.Thread(target=SleepABit, args=(10.0,))] + threading.Thread( + target=StopOnEvent, + args=(coord, wait_for_stop_ev, has_stopped_ev)), + threading.Thread(target=SleepABit, args=(10.0,)) + ] for t in threads: t.daemon = True t.start() @@ -150,6 +156,7 @@ class CoordinatorTest(test.TestCase): has_stopped_ev.wait() with self.assertRaisesRegexp(RuntimeError, "threads still running"): coord.join(threads, stop_grace_period_secs=stop_grace_period) + TestWithGracePeriod(1e-10) TestWithGracePeriod(0.002) TestWithGracePeriod(1.0) @@ -159,16 +166,16 @@ class CoordinatorTest(test.TestCase): wait_for_stop_ev = threading.Event() has_stopped_ev = threading.Event() threads = [ - threading.Thread(target=StopOnEvent, - args=(coord, wait_for_stop_ev, has_stopped_ev)), - threading.Thread(target=SleepABit, args=(10.0,))] + threading.Thread( + target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev)), + threading.Thread(target=SleepABit, args=(10.0,)) + ] for t in threads: t.daemon = True t.start() wait_for_stop_ev.set() has_stopped_ev.wait() - coord.join( - threads, stop_grace_period_secs=1., ignore_live_threads=True) + coord.join(threads, stop_grace_period_secs=1., ignore_live_threads=True) def testJoinRaiseReportExcInfo(self): coord = coordinator.Coordinator() @@ -180,7 +187,8 @@ class CoordinatorTest(test.TestCase): args=(coord, ev_1, ev_2, RuntimeError("First"), False)), threading.Thread( target=RaiseOnEvent, - args=(coord, ev_2, None, RuntimeError("Too late"), False))] + args=(coord, ev_2, None, RuntimeError("Too late"), False)) + ] for t in threads: t.start() @@ -199,7 +207,8 @@ class CoordinatorTest(test.TestCase): args=(coord, ev_1, ev_2, RuntimeError("First"), True)), threading.Thread( target=RaiseOnEvent, - args=(coord, ev_2, None, RuntimeError("Too late"), True))] + args=(coord, ev_2, None, RuntimeError("Too late"), True)) + ] for t in threads: t.start() @@ -214,9 +223,8 @@ class CoordinatorTest(test.TestCase): threading.Thread( target=RaiseOnEvent, args=(coord, ev_1, None, - errors_impl.OutOfRangeError(None, None, "First"), - True)) - ] + errors_impl.OutOfRangeError(None, None, "First"), True)) + ] for t in threads: t.start() @@ -230,7 +238,7 @@ class CoordinatorTest(test.TestCase): threading.Thread( target=RaiseOnEvent, args=(coord, ev_1, None, ValueError("Clean stop"), True)) - ] + ] for t in threads: t.start() @@ -247,7 +255,8 @@ class CoordinatorTest(test.TestCase): args=(coord, ev_1, ev_2, RuntimeError("First"))), threading.Thread( target=RaiseOnEventUsingContextHandler, - args=(coord, ev_2, None, RuntimeError("Too late")))] + args=(coord, ev_2, None, RuntimeError("Too late"))) + ] for t in threads: t.start() @@ -262,7 +271,7 @@ class CoordinatorTest(test.TestCase): threading.Thread( target=RaiseOnEvent, args=(coord, ev_1, None, RuntimeError("First"), True)), - ] + ] for t in threads: t.start() @@ -274,7 +283,7 @@ class CoordinatorTest(test.TestCase): threading.Thread( target=RaiseOnEvent, args=(coord, ev_1, None, RuntimeError("Second"), True)), - ] + ] for t in threads: t.start() with self.assertRaisesRegexp(RuntimeError, "Second"): @@ -337,24 +346,29 @@ class LooperTest(test.TestCase): def testTargetArgs(self): n = [3] coord = coordinator.Coordinator() - thread = coordinator.LooperThread.loop(coord, 0, target=_StopAt0, - args=(coord, n)) + thread = coordinator.LooperThread.loop( + coord, 0, target=_StopAt0, args=(coord, n)) coord.join([thread]) self.assertEqual(0, n[0]) def testTargetKwargs(self): n = [3] coord = coordinator.Coordinator() - thread = coordinator.LooperThread.loop(coord, 0, target=_StopAt0, - kwargs={"coord": coord, "n": n}) + thread = coordinator.LooperThread.loop( + coord, 0, target=_StopAt0, kwargs={ + "coord": coord, + "n": n + }) coord.join([thread]) self.assertEqual(0, n[0]) def testTargetMixedArgs(self): n = [3] coord = coordinator.Coordinator() - thread = coordinator.LooperThread.loop(coord, 0, target=_StopAt0, - args=(coord,), kwargs={"n": n}) + thread = coordinator.LooperThread.loop( + coord, 0, target=_StopAt0, args=(coord,), kwargs={ + "n": n + }) coord.join([thread]) self.assertEqual(0, n[0]) diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py index 37ab625779f788b1b8e270a15db3244ea6f1bef3..689088bb41edfd94a1d483ed2b5f7447e9e060e7 100644 --- a/tensorflow/python/training/device_setter.py +++ b/tensorflow/python/training/device_setter.py @@ -23,6 +23,7 @@ from tensorflow.core.framework import node_def_pb2 from tensorflow.python.framework import device as pydev from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib +from tensorflow.python.util.tf_export import tf_export class _RoundRobinStrategy(object): @@ -121,6 +122,7 @@ class _ReplicaDeviceChooser(object): return worker_device.to_string() +@tf_export("train.replica_device_setter") def replica_device_setter(ps_tasks=0, ps_device="/job:ps", worker_device="/job:worker", merge_devices=True, cluster=None, ps_ops=None, ps_strategy=None): diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py index c64a1b3f799e776c7bbbbcfb691bdd97e4a34466..9d02e694db15637126f37ee5575638908b351def 100644 --- a/tensorflow/python/training/ftrl.py +++ b/tensorflow/python/training/ftrl.py @@ -22,8 +22,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.FtrlOptimizer") class FtrlOptimizer(optimizer.Optimizer): """Optimizer that implements the FTRL algorithm. @@ -265,4 +267,3 @@ class FtrlOptimizer(optimizer.Optimizer): grad.dtype), math_ops.cast(self._learning_rate_power_tensor, grad.dtype), use_locking=self._use_locking) - diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py index 5a536e27297f054671e7e44a9e5d20a8b36580b7..380e14e02497fbe3681d6bae03fe9c636c5d13aa 100644 --- a/tensorflow/python/training/gradient_descent.py +++ b/tensorflow/python/training/gradient_descent.py @@ -23,8 +23,10 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.GradientDescentOptimizer") class GradientDescentOptimizer(optimizer.Optimizer): """Optimizer that implements the gradient descent algorithm. """ diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 331a51e8bc848917967fed06632fe0d1c5bcad9c..bd9985a7c5c181c0431e0c0a91186bc36b11c787 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.summary import summary from tensorflow.python.training import queue_runner +from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access @@ -53,9 +54,12 @@ _restore_sparse = sparse_ops._take_many_sparse_from_tensors_map # pylint: enable=protected-access +@tf_export("train.match_filenames_once") def match_filenames_once(pattern, name=None): """Save the list of files matching pattern, so it is only computed once. + NOTE: The order of the files returned can be non-deterministic. + Args: pattern: A file pattern (glob), or 1D tensor of file patterns. name: A name for the operations (optional). @@ -70,6 +74,7 @@ def match_filenames_once(pattern, name=None): collections=[ops.GraphKeys.LOCAL_VARIABLES]) +@tf_export("train.limit_epochs") def limit_epochs(tensor, num_epochs=None, name=None): """Returns tensor `num_epochs` times and then raises an `OutOfRange` error. @@ -102,6 +107,7 @@ def limit_epochs(tensor, num_epochs=None, name=None): return array_ops.identity(tensor, name=name) +@tf_export("train.input_producer") def input_producer(input_tensor, element_shape=None, num_epochs=None, @@ -184,6 +190,7 @@ def input_producer(input_tensor, return q +@tf_export("train.string_input_producer") def string_input_producer(string_tensor, num_epochs=None, shuffle=True, @@ -253,6 +260,7 @@ def string_input_producer(string_tensor, cancel_op=cancel_op) +@tf_export("train.range_input_producer") def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None): """Produces the integers from 0 to limit-1 in a queue. @@ -290,6 +298,7 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None, shared_name, "fraction_of_%d_full" % capacity, name) +@tf_export("train.slice_input_producer") def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None): """Produces a slice of each `Tensor` in `tensor_list`. @@ -885,6 +894,7 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity, # Batching functions ---------------------------------------------------------- +@tf_export("train.batch") def batch(tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -979,6 +989,7 @@ def batch(tensors, batch_size, num_threads=1, capacity=32, name=name) +@tf_export("train.maybe_batch") def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -1031,6 +1042,7 @@ def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32, name=name) +@tf_export("train.batch_join") def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -1136,6 +1148,7 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False, name=name) +@tf_export("train.maybe_batch_join") def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, @@ -1188,6 +1201,7 @@ def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32, name=name) +@tf_export("train.shuffle_batch") def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -1287,6 +1301,7 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, name=name) +@tf_export("train.maybe_shuffle_batch") def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, keep_input, num_threads=1, seed=None, enqueue_many=False, shapes=None, @@ -1346,6 +1361,7 @@ def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, name=name) +@tf_export("train.shuffle_batch_join") def shuffle_batch_join(tensors_list, batch_size, capacity, min_after_dequeue, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, @@ -1439,6 +1455,7 @@ def shuffle_batch_join(tensors_list, batch_size, capacity, name=name) +@tf_export("train.maybe_shuffle_batch_join") def maybe_shuffle_batch_join(tensors_list, batch_size, capacity, min_after_dequeue, keep_input, seed=None, enqueue_many=False, shapes=None, diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py index 3ee49650e01bd31d7d34fe1e109599531626058c..10ab4c1137ff226d88902143d4f2281ad77de531 100644 --- a/tensorflow/python/training/learning_rate_decay.py +++ b/tensorflow/python/training/learning_rate_decay.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Various learning rate decay functions.""" from __future__ import absolute_import from __future__ import division @@ -26,10 +25,16 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.util.tf_export import tf_export -def exponential_decay(learning_rate, global_step, decay_steps, decay_rate, - staircase=False, name=None): +@tf_export("train.exponential_decay") +def exponential_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False, + name=None): """Applies exponential decay to the learning rate. When training a model, it is often recommended to lower the learning rate as @@ -85,9 +90,9 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate, """ if global_step is None: raise ValueError("global_step is required for exponential_decay.") - with ops.name_scope(name, "ExponentialDecay", - [learning_rate, global_step, - decay_steps, decay_rate]) as name: + with ops.name_scope( + name, "ExponentialDecay", + [learning_rate, global_step, decay_steps, decay_rate]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype global_step = math_ops.cast(global_step, dtype) @@ -96,10 +101,11 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate, p = global_step / decay_steps if staircase: p = math_ops.floor(p) - return math_ops.multiply(learning_rate, math_ops.pow(decay_rate, p), - name=name) + return math_ops.multiply( + learning_rate, math_ops.pow(decay_rate, p), name=name) +@tf_export("train.piecewise_constant") def piecewise_constant(x, boundaries, values, name=None): """Piecewise constant from boundaries and interval values. @@ -156,15 +162,15 @@ def piecewise_constant(x, boundaries, values, name=None): boundaries[i] = b else: raise ValueError( - "Boundaries (%s) must have the same dtype as x (%s)." % ( - b.dtype.base_dtype, x.dtype.base_dtype)) + "Boundaries (%s) must have the same dtype as x (%s)." % + (b.dtype.base_dtype, x.dtype.base_dtype)) # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing. values = ops.convert_n_to_tensor(values) for v in values[1:]: if v.dtype.base_dtype != values[0].dtype.base_dtype: raise ValueError( - "Values must have elements all with the same dtype (%s vs %s)." % ( - values[0].dtype.base_dtype, v.dtype.base_dtype)) + "Values must have elements all with the same dtype (%s vs %s)." % + (values[0].dtype.base_dtype, v.dtype.base_dtype)) pred_fn_pairs = [] pred_fn_pairs.append((x <= boundaries[0], lambda: values[0])) pred_fn_pairs.append((x > boundaries[-1], lambda: values[-1])) @@ -179,9 +185,14 @@ def piecewise_constant(x, boundaries, values, name=None): return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) -def polynomial_decay(learning_rate, global_step, decay_steps, - end_learning_rate=0.0001, power=1.0, - cycle=False, name=None): +@tf_export("train.polynomial_decay") +def polynomial_decay(learning_rate, + global_step, + decay_steps, + end_learning_rate=0.0001, + power=1.0, + cycle=False, + name=None): """Applies a polynomial decay to the learning rate. It is commonly observed that a monotonically decreasing learning rate, whose @@ -255,9 +266,10 @@ def polynomial_decay(learning_rate, global_step, decay_steps, """ if global_step is None: raise ValueError("global_step is required for polynomial_decay.") - with ops.name_scope(name, "PolynomialDecay", - [learning_rate, global_step, - decay_steps, end_learning_rate, power]) as name: + with ops.name_scope( + name, "PolynomialDecay", + [learning_rate, global_step, decay_steps, end_learning_rate, power + ]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype global_step = math_ops.cast(global_step, dtype) @@ -267,23 +279,29 @@ def polynomial_decay(learning_rate, global_step, decay_steps, if cycle: # Find the first multiple of decay_steps that is bigger than global_step. # If global_step is zero set the multiplier to 1 - multiplier = control_flow_ops.cond(math_ops.equal(global_step, 0), - lambda: 1.0, - lambda: math_ops.ceil( - global_step / decay_steps)) + multiplier = control_flow_ops.cond( + math_ops.equal(global_step, 0), lambda: 1.0, + lambda: math_ops.ceil(global_step / decay_steps)) decay_steps = math_ops.multiply(decay_steps, multiplier) else: # Make sure that the global_step used is not bigger than decay_steps. global_step = math_ops.minimum(global_step, decay_steps) p = math_ops.div(global_step, decay_steps) - return math_ops.add(math_ops.multiply(learning_rate - end_learning_rate, - math_ops.pow(1 - p, power)), - end_learning_rate, name=name) - - -def natural_exp_decay(learning_rate, global_step, decay_steps, decay_rate, - staircase=False, name=None): + return math_ops.add( + math_ops.multiply(learning_rate - end_learning_rate, + math_ops.pow(1 - p, power)), + end_learning_rate, + name=name) + + +@tf_export("train.natural_exp_decay") +def natural_exp_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False, + name=None): """Applies natural exponential decay to the initial learning rate. When training a model, it is often recommended to lower the learning rate as @@ -349,8 +367,13 @@ def natural_exp_decay(learning_rate, global_step, decay_steps, decay_rate, return math_ops.multiply(learning_rate, exponent, name=name) -def inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate, - staircase=False, name=None): +@tf_export("train.inverse_time_decay") +def inverse_time_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False, + name=None): """Applies inverse time decay to the initial learning rate. When training a model, it is often recommended to lower the learning rate as @@ -362,13 +385,15 @@ def inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate, The function returns the decayed learning rate. It is computed as: ```python - decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step) + decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / + decay_step) ``` or, if `staircase` is `True`, as: ```python - decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step)) + decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / + decay_step)) ``` Example: decay 1/t with a rate of 0.5: @@ -379,7 +404,8 @@ def inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate, learning_rate = 0.1 decay_steps = 1.0 decay_rate = 0.5 - learning_rate = tf.train.inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate) + learning_rate = tf.train.inverse_time_decay(learning_rate, global_step, + decay_steps, decay_rate) # Passing global_step to minimize() will increment it at each step. learning_step = ( @@ -424,8 +450,8 @@ def inverse_time_decay(learning_rate, global_step, decay_steps, decay_rate, return math_ops.div(learning_rate, denom, name=name) -def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, - name=None): +@tf_export("train.cosine_decay") +def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): """Applies cosine decay to the learning rate. See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent @@ -484,8 +510,14 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, return math_ops.multiply(learning_rate, decayed) -def cosine_decay_restarts(learning_rate, global_step, first_decay_steps, - t_mul=2.0, m_mul=1.0, alpha=0.0, name=None): +@tf_export("train.cosine_decay_restarts") +def cosine_decay_restarts(learning_rate, + global_step, + first_decay_steps, + t_mul=2.0, + m_mul=1.0, + alpha=0.0, + name=None): """Applies cosine decay with restarts to the learning rate. See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent @@ -532,10 +564,9 @@ def cosine_decay_restarts(learning_rate, global_step, first_decay_steps, """ if global_step is None: raise ValueError("cosine decay restarts requires global_step") - with ops.name_scope(name, "SGDRDecay", - [learning_rate, global_step]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, - name="initial_learning_rate") + with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step]) as name: + learning_rate = ops.convert_to_tensor( + learning_rate, name="initial_learning_rate") dtype = learning_rate.dtype global_step = math_ops.cast(global_step, dtype) first_decay_steps = math_ops.cast(first_decay_steps, dtype) @@ -547,11 +578,12 @@ def cosine_decay_restarts(learning_rate, global_step, first_decay_steps, def compute_step(completed_fraction, geometric=False): if geometric: - i_restart = math_ops.floor(math_ops.log(1.0 - completed_fraction * ( - 1.0 - t_mul)) / math_ops.log(t_mul)) + i_restart = math_ops.floor( + math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / + math_ops.log(t_mul)) - sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul) - completed_fraction = (completed_fraction - sum_r) / t_mul ** i_restart + sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) + completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart else: i_restart = math_ops.floor(completed_fraction) @@ -564,16 +596,21 @@ def cosine_decay_restarts(learning_rate, global_step, first_decay_steps, lambda: compute_step(completed_fraction, geometric=False), lambda: compute_step(completed_fraction, geometric=True)) - m_fac = m_mul ** i_restart - cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos( - constant_op.constant(math.pi) * completed_fraction)) + m_fac = m_mul**i_restart + cosine_decayed = 0.5 * m_fac * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction)) decayed = (1 - alpha) * cosine_decayed + alpha return math_ops.multiply(learning_rate, decayed, name=name) -def linear_cosine_decay(learning_rate, global_step, decay_steps, - num_periods=0.5, alpha=0.0, beta=0.001, +@tf_export("train.linear_cosine_decay") +def linear_cosine_decay(learning_rate, + global_step, + decay_steps, + num_periods=0.5, + alpha=0.0, + beta=0.001, name=None): """Applies linear cosine decay to the learning rate. @@ -651,9 +688,15 @@ def linear_cosine_decay(learning_rate, global_step, decay_steps, return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) -def noisy_linear_cosine_decay(learning_rate, global_step, decay_steps, - initial_variance=1.0, variance_decay=0.55, - num_periods=0.5, alpha=0.0, beta=0.001, +@tf_export("train.noisy_linear_cosine_decay") +def noisy_linear_cosine_decay(learning_rate, + global_step, + decay_steps, + initial_variance=1.0, + variance_decay=0.55, + num_periods=0.5, + alpha=0.0, + beta=0.001, name=None): """Applies noisy linear cosine decay to the learning rate. @@ -734,8 +777,8 @@ def noisy_linear_cosine_decay(learning_rate, global_step, decay_steps, math_ops.pow(1.0 + global_step, variance_decay)) std = math_ops.sqrt(variance) noisy_linear_decayed = ( - linear_decayed + random_ops.random_normal( - linear_decayed.shape, stddev=std)) + linear_decayed + + random_ops.random_normal(linear_decayed.shape, stddev=std)) completed_fraction = global_step / decay_steps fraction = 2.0 * num_periods * completed_fraction diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py index cf9530d87c46783b517884610b644b076bef6807..bd9fa79d8feac68c149f787ee8501bdddb173d33 100644 --- a/tensorflow/python/training/momentum.py +++ b/tensorflow/python/training/momentum.py @@ -22,8 +22,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.MomentumOptimizer") class MomentumOptimizer(optimizer.Optimizer): """Optimizer that implements the Momentum algorithm. diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py index 6865513b0e4aad18d77887770a11243642958e7a..cda421cef837fa6ab25898208a8dc94d70561048 100644 --- a/tensorflow/python/training/momentum_test.py +++ b/tensorflow/python/training/momentum_test.py @@ -247,7 +247,7 @@ class MomentumOptimizerTest(test.TestCase): # pylint: enable=cell-var-from-loop opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0) - sgd_op = opt.minimize(loss if context.in_eager_mode() else loss()) + sgd_op = opt.minimize(loss) self.evaluate(variables.global_variables_initializer()) # Run 1 step of sgd self.evaluate(sgd_op) @@ -262,7 +262,7 @@ class MomentumOptimizerTest(test.TestCase): return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]])) opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0) - sgd_op = opt.minimize(loss if context.in_eager_mode() else loss()) + sgd_op = opt.minimize(loss) self.evaluate(variables.global_variables_initializer()) self.evaluate(sgd_op) self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0)) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index fa3517db27be4581deb85f77f022406b8b30ec56..6c5c9e01a76d539b550420134b09090b89beed46 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -41,6 +41,7 @@ from tensorflow.python.training import queue_runner from tensorflow.python.training import saver as training_saver from tensorflow.python.training import session_manager as sm from tensorflow.python.training import session_run_hook +from tensorflow.python.util.tf_export import tf_export # The list of exceptions that we should recover from. Exceptions not in this @@ -52,6 +53,7 @@ _PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError) USE_DEFAULT = object() +@tf_export('train.Scaffold') class Scaffold(object): """Structure to create or gather pieces commonly needed to train a model. @@ -272,6 +274,7 @@ class Scaffold(object): resources.initialize_resources(resources.local_resources())) +@tf_export('train.MonitoredTrainingSession') def MonitoredTrainingSession(master='', # pylint: disable=invalid-name is_chief=True, checkpoint_dir=None, @@ -381,6 +384,7 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name stop_grace_period_secs=stop_grace_period_secs) +@tf_export('train.SessionCreator') class SessionCreator(object): """A factory for tf.Session.""" @@ -390,6 +394,7 @@ class SessionCreator(object): 'create_session is not implemented for {}.'.format(self)) +@tf_export('train.ChiefSessionCreator') class ChiefSessionCreator(SessionCreator): """Creates a tf.Session for a chief.""" @@ -441,6 +446,7 @@ class ChiefSessionCreator(SessionCreator): init_fn=self._scaffold.init_fn) +@tf_export('train.WorkerSessionCreator') class WorkerSessionCreator(SessionCreator): """Creates a tf.Session for a worker.""" @@ -706,6 +712,7 @@ class _MonitoredSession(object): return self._coordinated_creator.tf_sess +@tf_export('train.MonitoredSession') class MonitoredSession(_MonitoredSession): """Session-like object that handles initialization, recovery and hooks. @@ -788,6 +795,7 @@ class MonitoredSession(_MonitoredSession): stop_grace_period_secs=stop_grace_period_secs) +@tf_export('train.SingularMonitoredSession') class SingularMonitoredSession(_MonitoredSession): """Session-like object that handles initialization, restoring, and hooks. diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index e34c759e894c86a103f0228163f7bae2ffc7fb61..b9ecb27df19d051c28ec1c3fe3cd9fd86717a5ed 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import slot_creator +from tensorflow.python.util.tf_export import tf_export # TODO(touts): switch to variables.Variable. @@ -187,7 +188,7 @@ def _zero_debias(unbiased_var, value, decay): with variable_scope.variable_scope( unbiased_var.op.name, values=[unbiased_var, value, decay]) as scope: with ops.colocate_with(unbiased_var): - with ops.control_dependencies(None): + with ops.init_scope(): biased_initializer = init_ops.zeros_initializer( dtype=unbiased_var.dtype)(unbiased_var.get_shape()) local_step_initializer = init_ops.zeros_initializer() @@ -230,6 +231,7 @@ def _zero_debias(unbiased_var, value, decay): return unbiased_ema_delta +@tf_export("train.ExponentialMovingAverage") class ExponentialMovingAverage(object): """Maintains moving averages of variables by employing an exponential decay. @@ -385,7 +387,7 @@ class ExponentialMovingAverage(object): # For variables: to lower communication bandwidth across devices we keep # the moving averages on the same device as the variables. For other # tensors, we rely on the existing device allocation mechanism. - with ops.control_dependencies(None): + with ops.init_scope(): if isinstance(var, variables.Variable): avg = slot_creator.create_slot(var, var.initialized_value(), @@ -398,7 +400,9 @@ class ExponentialMovingAverage(object): avg = slot_creator.create_zeros_slot( var, self._name, - colocate_with_primary=(var.op.type in ["Variable", "VariableV2"])) + colocate_with_primary=(var.op.type in ["Variable", + "VariableV2", + "VarHandleOp"])) if self._zero_debias: zero_debias_true.add(avg) self._averages[var] = avg diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 038469b1bac9d2fabce788340278ea165f2f9249..678d6322aa5ecea0a603b6a9858f7619638eae30 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -34,8 +34,10 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.training import checkpointable from tensorflow.python.training import slot_creator from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export def _get_variable_for(v): @@ -174,20 +176,44 @@ class _StreamingModelPortProcessor(_OptimizableVariable): return g +class _TensorProcessor(_OptimizableVariable): + """Processor for ordinary Tensors. + + Even though a Tensor can't really be updated, sometimes it is useful to + compute the gradients with respect to a Tensor using the optimizer. Updating + the Tensor is, of course, unsupported. + """ + + def __init__(self, v): + self._v = v + + def target(self): + return self._v + + def update_op(self, optimizer, g): + raise NotImplementedError("Trying to update a Tensor ", self._v) + + def _get_processor(v): """The processor of v.""" if context.in_eager_mode(): - return _DenseResourceVariableProcessor(v) + if isinstance(v, ops.Tensor): + return _TensorProcessor(v) + else: + return _DenseResourceVariableProcessor(v) if v.op.type == "VarHandleOp": return _DenseResourceVariableProcessor(v) if isinstance(v, variables.Variable): return _RefVariableProcessor(v) if v.op.type == "SubmodelPort": return _StreamingModelPortProcessor(v) + if isinstance(v, ops.Tensor): + return _TensorProcessor(v) raise NotImplementedError("Trying to optimize unsupported type ", v) -class Optimizer(object): +@tf_export("train.Optimizer") +class Optimizer(checkpointable.Checkpointable): """Base class for optimizers. This class defines the API to add Ops to train a model. You never use this @@ -298,9 +324,18 @@ class Optimizer(object): self._use_locking = use_locking self._name = name # Dictionary of slots. - # {slot_name : { variable_to_train: slot_for_the_variable, ...}, ... } + # {slot_name : + # {_var_key(variable_to_train): slot_for_the_variable, ... }, + # ... } self._slots = {} self._non_slot_dict = {} + # For implementing Checkpointable. Stores information about how to restore + # slot variables which have not yet been created + # (checkpointable._CheckpointPosition objects). + # {slot_name : + # {_var_key(variable_to_train): [checkpoint_position, ... ], ... }, + # ... } + self._deferred_slot_restorations = {} def get_name(self): return self._name @@ -380,7 +415,9 @@ class Optimizer(object): given variable. Args: - loss: A Tensor containing the value to minimize. + loss: A Tensor containing the value to minimize or a callable taking + no arguments which returns the value to minimize. When eager execution + is enabled it must be a callable. var_list: Optional list or tuple of `tf.Variable` to update to minimize `loss`. Defaults to the list of variables collected in the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. @@ -399,37 +436,27 @@ class Optimizer(object): Raises: TypeError: If `var_list` contains anything else than `Variable` objects. ValueError: If some arguments are invalid. - RuntimeError: If called with eager execution enabled and if `grad_loss` - is not `None` or `loss` is not callable. + RuntimeError: If called with eager execution enabled and `loss` is + not callable. @compatibility(eager) - When eager execution is enabled, `loss` should be a Python function that - takes elements of `var_list` as arguments and computes the value to be - minimized. If `var_list` is None, `loss` should take no arguments. - Gradient computation is done with respect to the elements of `var_list` if - not None, else with respect to any trainable variables created during the - execution of the `loss` function. - `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and - `grad_loss` are ignored when eager execution is enabled. + When eager execution is enabled, `gate_gradients`, `aggregation_method`, + and `colocate_gradients_with_ops` are ignored. @end_compatibility """ - if context.in_eager_mode(): - if grad_loss is not None: - raise RuntimeError( - "`grad_loss` argument to Optimizer.compute_gradients " - "not supported when eager execution is enabled.") - if not callable(loss): - raise RuntimeError( - "`loss` passed to Optimizer.compute_gradients should " - "be a function when eager execution is enabled.") - # TODO(agarwal): consider passing parameters to the `loss` function. + if callable(loss): + with backprop.GradientTape() as tape: + if var_list is not None: + tape.watch(var_list) + loss_value = loss() if var_list is None: - return backprop.implicit_grad(loss)() - else: - var_list = nest.flatten(var_list) - grads = backprop.gradients_function(loss)(*var_list) - grads_and_vars = list(zip(grads, var_list)) - return grads_and_vars + var_list = tape.watched_variables() + grads = tape.gradient(loss_value, var_list, grad_loss) + return list(zip(grads, var_list)) + if context.in_eager_mode(): + raise RuntimeError( + "`loss` passed to Optimizer.compute_gradients should " + "be a function when eager execution is enabled.") if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP, Optimizer.GATE_GRAPH]: raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, " @@ -514,7 +541,7 @@ class Optimizer(object): if not var_list: raise ValueError("No gradients provided for any variable: %s." % ([str(v) for _, _, v in converted_grads_and_vars],)) - with ops.control_dependencies(None): + with ops.init_scope(): self._create_slots([_get_variable_for(v) for v in var_list]) update_ops = [] with ops.name_scope(name, self._name) as name: @@ -533,7 +560,15 @@ class Optimizer(object): else: with ops.control_dependencies([self._finish(update_ops, "update")]): with ops.colocate_with(global_step): - apply_updates = state_ops.assign_add(global_step, 1, name=name) + if isinstance(global_step, resource_variable_ops.ResourceVariable): + # TODO(apassos): the implicit read in assign_add is slow; consider + # making it less so. + apply_updates = resource_variable_ops.assign_add_variable_op( + global_step.handle, + ops.convert_to_tensor(1, dtype=global_step.dtype), + name=name) + else: + apply_updates = state_ops.assign_add(global_step, 1, name=name) if context.in_graph_mode(): if isinstance(apply_updates, ops.Tensor): @@ -592,7 +627,7 @@ class Optimizer(object): if executing_eagerly: # No variable.op in eager mode. We don't expect lots of eager graphs, # but behavior should be consistent with graph mode. - return variable._container_prefix == current_graph._container_prefix # pylint: disable=protected-access + return variable._graph_key == current_graph._graph_key # pylint: disable=protected-access else: return variable.op.graph is current_graph @@ -858,7 +893,11 @@ class Optimizer(object): """ named_slots = self._slot_dict(slot_name) if _var_key(var) not in named_slots: - named_slots[_var_key(var)] = slot_creator.create_slot(var, val, op_name) + new_slot_variable = slot_creator.create_slot(var, val, op_name) + self._restore_slot_variable( + slot_name=slot_name, variable=var, + slot_variable=new_slot_variable) + named_slots[_var_key(var)] = new_slot_variable return named_slots[_var_key(var)] def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype, @@ -879,8 +918,12 @@ class Optimizer(object): """ named_slots = self._slot_dict(slot_name) if _var_key(var) not in named_slots: - named_slots[_var_key(var)] = slot_creator.create_slot_with_initializer( + new_slot_variable = slot_creator.create_slot_with_initializer( var, initializer, shape, dtype, op_name) + self._restore_slot_variable( + slot_name=slot_name, variable=var, + slot_variable=new_slot_variable) + named_slots[_var_key(var)] = new_slot_variable return named_slots[_var_key(var)] def _zeros_slot(self, var, slot_name, op_name): @@ -897,5 +940,79 @@ class Optimizer(object): """ named_slots = self._slot_dict(slot_name) if _var_key(var) not in named_slots: - named_slots[_var_key(var)] = slot_creator.create_zeros_slot(var, op_name) + new_slot_variable = slot_creator.create_zeros_slot(var, op_name) + self._restore_slot_variable( + slot_name=slot_name, variable=var, + slot_variable=new_slot_variable) + named_slots[_var_key(var)] = new_slot_variable return named_slots[_var_key(var)] + + # -------------- + # For implementing the Checkpointable interface. + # -------------- + + def _restore_slot_variable(self, slot_name, variable, slot_variable): + """Restore a newly created slot variable's value.""" + variable_key = _var_key(variable) + deferred_restorations = self._deferred_slot_restorations.get( + slot_name, {}).pop(variable_key, []) + # Iterate over restores, highest restore UID first to minimize the number + # of assignments. + deferred_restorations.sort(key=lambda position: position.restore_uid, + reverse=True) + for checkpoint_position in deferred_restorations: + checkpoint_position.restore(slot_variable) + + def _create_or_restore_slot_variable( + self, slot_variable_position, slot_name, variable): + """Restore a slot variable's value, possibly creating it. + + Called when a variable which has an associated slot variable is created or + restored. When executing eagerly, we create the slot variable with a + restoring initializer. + + No new variables are created when graph building. Instead, + _restore_slot_variable catches these after normal creation and adds restore + ops to the graph. This method is nonetheless important when graph building + for the case when a slot variable has already been created but `variable` + has just been added to a dependency graph (causing us to realize that the + slot variable needs to be restored). + + Args: + slot_variable_position: A `checkpointable._CheckpointPosition` object + indicating the slot variable `Checkpointable` object to be restored. + slot_name: The name of this `Optimizer`'s slot to restore into. + variable: The variable object this slot is being created for. + """ + named_slots = self._slot_dict(slot_name) + variable_key = _var_key(variable) + slot_variable = named_slots.get(variable_key, None) + if (slot_variable is None + and context.in_eager_mode() + and slot_variable_position.is_simple_variable()): + initializer = checkpointable.CheckpointInitialValue( + checkpoint_position=slot_variable_position) + slot_variable = self._get_or_make_slot( + var=variable, + val=initializer, + slot_name=slot_name, + op_name=self._name) + # Slot variables are not owned by any one object (because we don't want to + # save the slot variable if the optimizer is saved without the non-slot + # variable, or if the non-slot variable is saved without the optimizer; + # it's a dependency hypergraph with edges of the form (optimizer, non-slot + # variable, variable)). So we don't _track_ slot variables anywhere, and + # instead special-case this dependency and otherwise pretend it's a normal + # graph. + if slot_variable is not None: + # If we've either made this slot variable, or if we've pulled out an + # existing slot variable, we should restore it. + slot_variable_position.restore(slot_variable) + else: + # We didn't make the slot variable. Defer restoring until it gets created + # normally. We keep a list rather than the one with the highest restore + # UID in case slot variables have their own dependencies, in which case + # those could differ between restores. + self._deferred_slot_restorations.setdefault( + slot_name, {}).setdefault(variable_key, []).append( + slot_variable_position) diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py index 6bdae39073d48e0bd8b757a2d5145480e92d185f..0cab6410e83ca1880a0a4a80d2cfa5c17517af95 100644 --- a/tensorflow/python/training/optimizer_test.py +++ b/tensorflow/python/training/optimizer_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -44,11 +43,10 @@ class OptimizerTest(test.TestCase): name='a_%d' % i) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, name='b_%d' % i) - def loss(v0, v1): - return 5 * v0 + 3 * v1 + def loss(): + return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop # Note that for eager execution, minimize expects a function instead of a # Tensor. - cost = loss if context.in_eager_mode() else loss(var0, var1) global_step = resource_variable_ops.ResourceVariable( array_ops.zeros([], dtypes.int64), name='global_step_%d' % i) sgd_op = gradient_descent.GradientDescentOptimizer(3.0) @@ -58,7 +56,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([1.0, 2.0], self.evaluate(var0)) self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 1 step of sgd through optimizer - opt_op = sgd_op.minimize(cost, global_step, [var0, var1]) + opt_op = sgd_op.minimize(loss, global_step, [var0, var1]) self.evaluate(opt_op) # Validate updated params self.assertAllClose([-14., -13.], self.evaluate(var0)) @@ -125,10 +123,9 @@ class OptimizerTest(test.TestCase): [3.0, 4.0], dtype=dtype, trainable=False, name='b') return 5 * var0 + var1 # pylint: enable=cell-var-from-loop - cost = loss if context.in_eager_mode() else loss() sgd_op = gradient_descent.GradientDescentOptimizer(3.0) with self.assertRaisesRegexp(ValueError, 'No.*variables'): - sgd_op.minimize(cost) + sgd_op.minimize(loss) @test_util.run_in_graph_and_eager_modes() def testNoGradients(self): @@ -140,14 +137,13 @@ class OptimizerTest(test.TestCase): var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, name='b%d' % i) # pylint: disable=cell-var-from-loop - def loss(_): + def loss(): return 5 * var0 # pylint: enable=cell-var-from-loop - cost = loss if context.in_eager_mode() else loss(var1) sgd_op = gradient_descent.GradientDescentOptimizer(3.0) with self.assertRaisesRegexp(ValueError, 'No gradients'): # var1 has no gradient - sgd_op.minimize(cost, var_list=[var1]) + sgd_op.minimize(loss, var_list=[var1]) @test_util.run_in_graph_and_eager_modes() def testNoGradientsForAnyVariables_Minimize(self): @@ -158,13 +154,12 @@ class OptimizerTest(test.TestCase): name='a_%d' % i) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, name='b_%d' % i) - def loss(unused_v1, unused_v2): + def loss(): return constant_op.constant(5.0) - cost = loss if context.in_eager_mode() else loss(var0, var1) sgd_op = gradient_descent.GradientDescentOptimizer(3.0) with self.assertRaisesRegexp(ValueError, 'No gradients provided for any variable'): - sgd_op.minimize(cost, var_list=[var0, var1]) + sgd_op.minimize(loss, var_list=[var0, var1]) @test_util.run_in_graph_and_eager_modes() def testNoGradientsForAnyVariables_ApplyGradients(self): @@ -189,11 +184,10 @@ class OptimizerTest(test.TestCase): name='a%d' % i) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, name='b%d' % i) - def loss(v0, v1): - return 5 * v0 + 3 * v1 - cost = loss if context.in_eager_mode() else loss(var0, var1) + def loss(): + return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop sgd_op = gradient_descent.GradientDescentOptimizer(3.0) - grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1]) + grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1]) # Convert gradients to tf.Variables converted_grads = [ resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype), @@ -221,6 +215,21 @@ class OptimizerTest(test.TestCase): self.assertAllClose([-14., -13.], self.evaluate(var0)) self.assertAllClose([-6., -5.], self.evaluate(var1)) + @test_util.run_in_graph_and_eager_modes() + def testComputeGradientsWithTensors(self): + x = ops.convert_to_tensor(1.0) + def f(): + return x * x + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + grads_and_vars = sgd_op.compute_gradients(f, [x]) + self.assertEqual(1, len(grads_and_vars)) + grad, x_as_var = grads_and_vars[0] + self.assertIs(x, x_as_var) + self.assertEqual(2.0, self.evaluate(grad)) + + with self.assertRaises(NotImplementedError): + sgd_op.apply_gradients(grads_and_vars) + def testTrainOp(self): with self.test_session(): var0 = variables.Variable([1.0, 2.0]) diff --git a/tensorflow/python/training/proximal_adagrad.py b/tensorflow/python/training/proximal_adagrad.py index da31ab325d5e45e1943f554c45717cceb4dc638f..9bd677b8efcd447f74ec2a3cbe94d63eeb9a4dd1 100644 --- a/tensorflow/python/training/proximal_adagrad.py +++ b/tensorflow/python/training/proximal_adagrad.py @@ -23,8 +23,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.ProximalAdagradOptimizer") class ProximalAdagradOptimizer(optimizer.Optimizer): # pylint: disable=line-too-long """Optimizer that implements the Proximal Adagrad algorithm. diff --git a/tensorflow/python/training/proximal_gradient_descent.py b/tensorflow/python/training/proximal_gradient_descent.py index 53e9dc2ef2c86a20070fdbdc690b39d2c0e9df06..369b6cbb50e5c621737c095a24eeb473f3870534 100644 --- a/tensorflow/python/training/proximal_gradient_descent.py +++ b/tensorflow/python/training/proximal_gradient_descent.py @@ -24,8 +24,10 @@ from tensorflow.python.ops import math_ops # pylint: enable=unused-import from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.ProximalGradientDescentOptimizer") class ProximalGradientDescentOptimizer(optimizer.Optimizer): # pylint: disable=line-too-long """Optimizer that implements the proximal gradient descent algorithm. diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py index 4e7c81d7b2913d71a23dcaa3751db2aaffdc67cf..07afba79abf4d636c9ec2d53bcf2641594a35733 100644 --- a/tensorflow/python/training/queue_runner_impl.py +++ b/tensorflow/python/training/queue_runner_impl.py @@ -27,8 +27,10 @@ from tensorflow.python.eager import context from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.queue_runner.QueueRunner", "train.QueueRunner") class QueueRunner(object): """Holds a list of enqueue operations for a queue, each to be run in a thread. @@ -384,6 +386,7 @@ class QueueRunner(object): import_scope=import_scope) +@tf_export("train.queue_runner.add_queue_runner", "train.add_queue_runner") def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS): """Adds a `QueueRunner` to a collection in the graph. @@ -402,6 +405,8 @@ def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS): ops.add_to_collection(collection, qr) +@tf_export("train.queue_runner.start_queue_runners", + "train.start_queue_runners") def start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection=ops.GraphKeys.QUEUE_RUNNERS): """Starts all queue runners collected in the graph. diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py index ebec725b7b98e9a078f5558af85355988e8aca67..341b970c92e42b4fe392d91f57219d713d2513e5 100644 --- a/tensorflow/python/training/rmsprop.py +++ b/tensorflow/python/training/rmsprop.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """One-line documentation for rmsprop module. rmsprop algorithm [tieleman2012rmsprop] @@ -43,16 +42,20 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import optimizer from tensorflow.python.training import training_ops +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.RMSPropOptimizer") class RMSPropOptimizer(optimizer.Optimizer): """Optimizer that implements the RMSProp algorithm. - See the [paper](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). + See the + [paper](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). """ def __init__(self, @@ -105,21 +108,24 @@ class RMSPropOptimizer(optimizer.Optimizer): def _create_slots(self, var_list): for v in var_list: - init_rms = init_ops.ones_initializer(dtype=v.dtype) + if v.get_shape().is_fully_defined(): + init_rms = init_ops.ones_initializer(dtype=v.dtype.base_dtype) + else: + init_rms = array_ops.ones_like(v) self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(), - v.dtype, "rms", self._name) + v.dtype.base_dtype, "rms", + self._name) if self._centered: self._zeros_slot(v, "mg", self._name) self._zeros_slot(v, "momentum", self._name) def _prepare(self): - self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate, - name="learning_rate") + self._learning_rate_tensor = ops.convert_to_tensor( + self._learning_rate, name="learning_rate") self._decay_tensor = ops.convert_to_tensor(self._decay, name="decay") - self._momentum_tensor = ops.convert_to_tensor(self._momentum, - name="momentum") - self._epsilon_tensor = ops.convert_to_tensor(self._epsilon, - name="epsilon") + self._momentum_tensor = ops.convert_to_tensor( + self._momentum, name="momentum") + self._epsilon_tensor = ops.convert_to_tensor(self._epsilon, name="epsilon") def _apply_dense(self, grad, var): rms = self.get_slot(var, "rms") diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 4f3773c0fc71e1f1abd8197dea94ce2a63881389..3888e9bba42dc89055638ad0abe2b7e1a9f5b548 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -53,6 +53,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export # Op names which identify variable reads which should be saved. @@ -889,6 +890,7 @@ def _GetCheckpointFilename(save_dir, latest_filename): return os.path.join(save_dir, latest_filename) +@tf_export("train.generate_checkpoint_state_proto") def generate_checkpoint_state_proto(save_dir, model_checkpoint_path, all_model_checkpoint_paths=None): @@ -933,6 +935,7 @@ def generate_checkpoint_state_proto(save_dir, return coord_checkpoint_proto +@tf_export("train.update_checkpoint_state") def update_checkpoint_state(save_dir, model_checkpoint_path, all_model_checkpoint_paths=None, @@ -1025,6 +1028,7 @@ def _update_checkpoint_state(save_dir, text_format.MessageToString(ckpt)) +@tf_export("train.get_checkpoint_state") def get_checkpoint_state(checkpoint_dir, latest_filename=None): """Returns CheckpointState proto from the "checkpoint" file. @@ -1082,6 +1086,7 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None): return ckpt +@tf_export("train.Saver") class Saver(object): """Saves and restores variables. @@ -1229,7 +1234,7 @@ class Saver(object): The `saver_def` proto should be the one returned by the `as_saver_def()` call of the `Saver` that was created for that `Graph`. builder: Optional `SaverBuilder` to use if a `saver_def` was not provided. - Defaults to `BaseSaverBuilder()`. + Defaults to `BulkSaverBuilder()`. defer_build: If `True`, defer adding the save and restore ops to the `build()` call. In that case `build()` should be called before finalizing the graph or using the saver. @@ -1309,7 +1314,7 @@ class Saver(object): if not self.saver_def or context.in_eager_mode(): if self._builder is None: - self._builder = BaseSaverBuilder(self._write_version) + self._builder = BulkSaverBuilder(self._write_version) if self._var_list is None: # pylint: disable=protected-access @@ -1592,9 +1597,9 @@ class Saver(object): [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: - A string: path prefix used for the checkpoint files. If checkpoint - format is V1 and the saver is sharded, this string ends with: - '-?????-of-nnnnn' where 'nnnnn' is the number of shards created. + A string: path prefix used for the checkpoint files. If the saver is + sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn' + is the number of shards created. If the saver is empty, returns None. Raises: @@ -1744,11 +1749,6 @@ class Saver(object): return if save_path is None: raise ValueError("Can't load save_path when it is None.") - if (os.path.isfile(save_path) and - self._write_version != saver_pb2.SaverDef.V1): - raise ValueError("The specified path: %s is a file." - " Please specify only the path prefix" - " to the checkpoint files." % save_path) logging.info("Restoring parameters from %s", save_path) if context.in_graph_mode(): sess.run(self.saver_def.restore_op_name, @@ -1788,6 +1788,7 @@ def _prefix_to_checkpoint_path(prefix, format_version): return prefix # Just the data file. +@tf_export("train.latest_checkpoint") def latest_checkpoint(checkpoint_dir, latest_filename=None): """Finds the filename of latest saved checkpoint file. @@ -1817,6 +1818,7 @@ def latest_checkpoint(checkpoint_dir, latest_filename=None): return None +@tf_export("train.import_meta_graph") def import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs): """Recreates a Graph saved in a `MetaGraphDef` proto. @@ -1918,6 +1920,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False, return None +@tf_export("train.export_meta_graph") def export_meta_graph(filename=None, meta_info_def=None, graph_def=None, @@ -1994,6 +1997,7 @@ def export_meta_graph(filename=None, return meta_graph_def +@tf_export("train.checkpoint_exists") def checkpoint_exists(checkpoint_prefix): """Checks whether a V1 or V2 checkpoint exists with the specified prefix. @@ -2018,6 +2022,7 @@ def checkpoint_exists(checkpoint_prefix): return False +@tf_export("train.get_checkpoint_mtimes") def get_checkpoint_mtimes(checkpoint_prefixes): """Returns the mtimes (modification timestamps) of the checkpoints. diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py index 29da67a30a58c1b8b8e172b2ccede340880fef58..2f421d1cc0a0190670082fabf4e25470c6a1723b 100644 --- a/tensorflow/python/training/server_lib.py +++ b/tensorflow/python/training/server_lib.py @@ -23,6 +23,7 @@ from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import errors from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export def _make_server_def(server_or_cluster_def, job_name, task_index, protocol, @@ -92,6 +93,7 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol, return server_def +@tf_export("train.Server") class Server(object): """An in-process TensorFlow server, for use in distributed training. @@ -221,6 +223,7 @@ class Server(object): start=start) +@tf_export("train.ClusterSpec") class ClusterSpec(object): """Represents a cluster as a set of "tasks", organized into "jobs". diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py index b396a1e7d0a06ec7b952ba2980e081e01e681d4d..360e02fb44c1062f71bb50449b9ef381510a9c69 100644 --- a/tensorflow/python/training/session_manager.py +++ b/tensorflow/python/training/session_manager.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as saver_mod +from tensorflow.python.util.tf_export import tf_export def _maybe_name(obj): @@ -44,6 +45,7 @@ def _maybe_name(obj): return "" % type(obj) +@tf_export("train.SessionManager") class SessionManager(object): """Training helper that restores from checkpoint and creates session. diff --git a/tensorflow/python/training/session_run_hook.py b/tensorflow/python/training/session_run_hook.py index 5b023d8a2672af5d1fab1c2566b19fca738fd1f7..89f40300650f3b6cd1ae15d946640c9df91771e2 100644 --- a/tensorflow/python/training/session_run_hook.py +++ b/tensorflow/python/training/session_run_hook.py @@ -96,8 +96,10 @@ from __future__ import division from __future__ import print_function import collections +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.SessionRunHook") class SessionRunHook(object): """Hook to extend calls to MonitoredSession.run().""" @@ -189,6 +191,7 @@ class SessionRunHook(object): pass +@tf_export("train.SessionRunArgs") class SessionRunArgs( collections.namedtuple("SessionRunArgs", ["fetches", "feed_dict", "options"])): @@ -213,6 +216,7 @@ class SessionRunArgs( return super(SessionRunArgs, cls).__new__(cls, fetches, feed_dict, options) +@tf_export("train.SessionRunContext") class SessionRunContext(object): """Provides information about the `session.run()` call being made. @@ -264,6 +268,7 @@ class SessionRunContext(object): self._stop_requested = True +@tf_export("train.SessionRunValues") class SessionRunValues( collections.namedtuple("SessionRunValues", ["results", "options", "run_metadata"])): diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index ea28b5ddfc2dbbf65ec60e86d29ff2a9988d2b97..75ef3d5976aba9f0cbe849d9f6984646d71a29ef 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -48,11 +48,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -def _is_resource(v): - """Returns true if v is something you get from a resource variable.""" - return isinstance(v, resource_variable_ops.ResourceVariable) - - def _create_slot_var(primary, val, scope, validate_shape, shape, dtype): """Helper function for creating a slot variable.""" @@ -60,9 +55,12 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype): # scope. current_partitioner = variable_scope.get_variable_scope().partitioner variable_scope.get_variable_scope().set_partitioner(None) + # When init from val instead of callable initializer, the shape is expected to + # be None, not or any fully defined shape. + shape = shape if callable(val) else None slot = variable_scope.get_variable( scope, initializer=val, trainable=False, - use_resource=_is_resource(primary), + use_resource=resource_variable_ops.is_resource_variable(primary), shape=shape, dtype=dtype, validate_shape=validate_shape) variable_scope.get_variable_scope().set_partitioner(current_partitioner) @@ -108,7 +106,8 @@ def create_slot(primary, val, name, colocate_with_primary=True): # and the same name has been previously used, the scope name will add '_N' # as suffix for unique identifications. validate_shape = val.get_shape().is_fully_defined() - with variable_scope.variable_scope(None, primary.op.name + "/" + name): + prefix = primary.op.name if context.in_graph_mode() else primary._shared_name # pylint: disable=protected-access + with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: with ops.colocate_with(primary): return _create_slot_var(primary, val, "", validate_shape, None, None) diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index e4514aaea223b6b254a7a72e11e6b70b576fd54b..d2ad34773e0615256c340826dcc312cc8a00dc23 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -37,8 +37,10 @@ from tensorflow.python.training import saver as saver_mod from tensorflow.python.training import session_manager as session_manager_mod from tensorflow.python.training import training_util from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export +@tf_export("train.Supervisor") class Supervisor(object): """A training helper that checkpoints models and computes summaries. diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py index 47702fdad05d13015e0cbf7768129b0c53b6c14c..0c6cf910d1a01dc20b15fb1cd5dbb249fbb60ef5 100644 --- a/tensorflow/python/training/sync_replicas_optimizer.py +++ b/tensorflow/python/training/sync_replicas_optimizer.py @@ -31,6 +31,7 @@ from tensorflow.python.training import optimizer from tensorflow.python.training import queue_runner from tensorflow.python.training import session_manager from tensorflow.python.training import session_run_hook +from tensorflow.python.util.tf_export import tf_export # Please note that the gradients from replicas are averaged instead of summed @@ -38,6 +39,7 @@ from tensorflow.python.training import session_run_hook # rate according to the number of replicas. This change is introduced to be # consistent with how gradients are aggregated (averaged) within a batch in a # replica. +@tf_export("train.SyncReplicasOptimizer") class SyncReplicasOptimizer(optimizer.Optimizer): """Class to synchronize, aggregate gradients and pass them to the optimizer. diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 03811fa38dd021fd5ff222bfbe32234606d6c681..78c8ce9208efc2f2fa8b5c671d3379e7ca8c70f5 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -135,7 +135,7 @@ from tensorflow.python.training.queue_runner import * # For the module level doc. from tensorflow.python.training import input as _input -from tensorflow.python.training.input import * +from tensorflow.python.training.input import * # pylint: disable=redefined-builtin # pylint: enable=wildcard-import from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer @@ -189,6 +189,7 @@ from tensorflow.python.training.training_util import create_global_step from tensorflow.python.training.training_util import get_or_create_global_step from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef from tensorflow.python.pywrap_tensorflow import NewCheckpointReader +from tensorflow.python.util.tf_export import tf_export # pylint: disable=wildcard-import # Training data protos. @@ -239,6 +240,23 @@ _allowed_symbols = [ "SequenceExample", # from example_pb2. "ServerDef", ] + +# pylint: disable=undefined-variable +tf_export("train.BytesList")(BytesList) +tf_export("train.ClusterDef")(ClusterDef) +tf_export("train.Example")(Example) +tf_export("train.Feature")(Feature) +tf_export("train.Features")(Features) +tf_export("train.FeatureList")(FeatureList) +tf_export("train.FeatureLists")(FeatureLists) +tf_export("train.FloatList")(FloatList) +tf_export("train.Int64List")(Int64List) +tf_export("train.JobDef")(JobDef) +tf_export("train.SaverDef")(SaverDef) +tf_export("train.SequenceExample")(SequenceExample) +tf_export("train.ServerDef")(ServerDef) +# pylint: enable=undefined-variable + # Include extra modules for docstrings because: # * Input methods in tf.train are documented in io_ops. # * Saver methods in tf.train are documented in state_ops. diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py index e98c32b614418224b1bc14081bc35f175d769965..d7133cfb500ef11e5b94c7c36905e039f9c0bf46 100644 --- a/tensorflow/python/training/training_ops.py +++ b/tensorflow/python/training/training_ops.py @@ -19,7 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.training import gen_training_ops +from tensorflow.python.training import gen_training_ops # pylint: disable=unused-import # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.training.gen_training_ops import * diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index 89a9e129328fe38da2ce497a7f26dc11446ea032..499f1feb2dbf8aee26314a43b0a000fb91a1c686 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export # Picked a long key value to minimize the chance of collision with user defined @@ -40,6 +41,7 @@ GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache' write_graph = graph_io.write_graph +@tf_export('train.global_step') def global_step(sess, global_step_tensor): """Small helper to get the global step. @@ -67,6 +69,7 @@ def global_step(sess, global_step_tensor): return int(sess.run(global_step_tensor)) +@tf_export('train.get_global_step') def get_global_step(graph=None): """Get the global step tensor. @@ -101,6 +104,7 @@ def get_global_step(graph=None): return global_step_tensor +@tf_export('train.create_global_step') def create_global_step(graph=None): """Create global step tensor in graph. @@ -139,6 +143,7 @@ def create_global_step(graph=None): ops.GraphKeys.GLOBAL_STEP]) +@tf_export('train.get_or_create_global_step') def get_or_create_global_step(graph=None): """Returns and create (if necessary) the global step tensor. @@ -156,6 +161,7 @@ def get_or_create_global_step(graph=None): return global_step_tensor +@tf_export('train.assert_global_step') def assert_global_step(global_step_tensor): """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`. diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py index 270d96a3c7c831d8c06dd86199cf2dc5dfc43421..4163fcac79e3d237c4c4c4303e1db2c39e5fe7c6 100644 --- a/tensorflow/python/util/compat.py +++ b/tensorflow/python/util/compat.py @@ -41,8 +41,11 @@ import numpy as _np import six as _six from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util.tf_export import tf_export +@tf_export('compat.as_bytes', 'compat.as_str') def as_bytes(bytes_or_text, encoding='utf-8'): """Converts either bytes or unicode to `bytes`, using utf-8 encoding for text. @@ -65,6 +68,7 @@ def as_bytes(bytes_or_text, encoding='utf-8'): (bytes_or_text,)) +@tf_export('compat.as_text') def as_text(bytes_or_text, encoding='utf-8'): """Returns the given argument as a unicode string. @@ -93,6 +97,7 @@ else: as_str = as_text +@tf_export('compat.as_str_any') def as_str_any(value): """Converts to `str` as `str(value)`, but use `as_str` for `bytes`. @@ -108,6 +113,7 @@ def as_str_any(value): return str(value) +@tf_export('compat.path_to_str') def path_to_str(path): """Returns the file system path representation of a `PathLike` object, else as it is. @@ -125,11 +131,16 @@ def path_to_str(path): # Numpy 1.8 scalars don't inherit from numbers.Integral in Python 3, so we # need to check them specifically. The same goes from Real and Complex. integral_types = (_numbers.Integral, _np.integer) +tf_export('compat.integral_types').export_constant(__name__, 'integral_types') real_types = (_numbers.Real, _np.integer, _np.floating) +tf_export('compat.real_types').export_constant(__name__, 'real_types') complex_types = (_numbers.Complex, _np.number) +tf_export('compat.complex_types').export_constant(__name__, 'complex_types') # Either bytes or text. bytes_or_text_types = (bytes, _six.text_type) +tf_export('compat.bytes_or_text_types').export_constant(__name__, + 'bytes_or_text_types') _allowed_symbols = [ 'as_str', diff --git a/tensorflow/python/util/compat_internal.py b/tensorflow/python/util/compat_internal.py index a299b2fc3c302705d9493904e8ac0f81e4b8d371..1905c3e3832550906c601bd4545e72b5bd135e2c 100644 --- a/tensorflow/python/util/compat_internal.py +++ b/tensorflow/python/util/compat_internal.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Functions for Python 2 vs. 3 compatibility that are private to TensorFlow.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.util.compat import as_str_any + def path_to_str(path): - """Returns the file system path representation of a `PathLike` object, else as it is. + """Returns the file system path representation of a `PathLike` object, + else as it is. Args: path: An object that can be converted to path representation. diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py index 8a66f0435a8cb3d689a6613e2fca5bab1c0a37e3..376be39978fb11463ae8a870492a359c89a9f2ce 100644 --- a/tensorflow/python/util/deprecation.py +++ b/tensorflow/python/util/deprecation.py @@ -24,6 +24,7 @@ import re from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import decorator_utils +from tensorflow.python.util import is_in_graph_mode from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -38,13 +39,14 @@ _PRINTED_WARNING = {} def _add_deprecated_function_notice_to_docstring(doc, date, instructions): """Adds a deprecation notice to a docstring for deprecated functions.""" + main_text = ['THIS FUNCTION IS DEPRECATED. It will be removed %s.' % + ('in a future version' if date is None else ('after %s' % date))] + if instructions: + main_text.append('Instructions for updating:') return decorator_utils.add_notice_to_docstring( doc, instructions, 'DEPRECATED FUNCTION', - '(deprecated)', [ - 'THIS FUNCTION IS DEPRECATED. It will be removed %s.' % ( - 'in a future version' if date is None else ('after %s' % date)), - 'Instructions for updating:']) + '(deprecated)', main_text) def _add_deprecated_arg_notice_to_docstring(doc, date, instructions): @@ -66,23 +68,135 @@ def _validate_deprecation_args(date, instructions): raise ValueError('Don\'t deprecate things without conversion instructions!') -def _call_location(): +def _call_location(outer=False): """Returns call location given level up from current call.""" frame = tf_inspect.currentframe() if frame: # CPython internals are available, use them for performance. # walk back two frames to get to deprecated function caller. - first_frame = frame.f_back - second_frame = first_frame.f_back - frame = second_frame if second_frame else first_frame + frame = frame.f_back + if frame.f_back: + frame = frame.f_back + if outer and frame.f_back: + frame = frame.f_back return '%s:%d' % (frame.f_code.co_filename, frame.f_lineno) else: # Slow fallback path stack = tf_inspect.stack(0) # 0 avoids generating unused context - entry = stack[2] + entry = stack[3 if outer else 2] return '%s:%d' % (entry[1], entry[2]) +def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True): + """Deprecate a symbol in favor of a new name with identical semantics. + + This function is meant to be used when defining a backwards-compatibility + alias for a symbol which has been moved. For example: + + module1.py: + ```python + class NewNameForClass: pass + ``` + + module2.py: + ```python + import module1 + + DeprecatedNameForClass = deprecated_alias( + deprecated_name='module2.DeprecatedNameForClass', + name='module1.NewNameForClass', + module1.NewNameForClass) + ``` + + This function works for classes and functions. + + For classes, it creates a new class which is functionally identical (it + inherits from the original, and overrides its constructor), but which prints + a deprecation warning when an instance is created. It also adds a deprecation + notice to the class' docstring. + + For functions, it returns a function wrapped by `tf_decorator.make_decorator`. + That function prints a warning when used, and has a deprecation notice in its + docstring. This is more or less equivalent (the deprecation warning has + slightly different text) to writing: + + ```python + @deprecated + def deprecated_alias(original_args): + real_function(original_args) + ``` + + Args: + deprecated_name: The name of the symbol that is being deprecated, to be used + in the warning message. This should be its fully qualified name to avoid + confusion. + name: The name of the symbol that is to be used instead of the deprecated + name. This should be a fully qualified name to avoid confusion. + func_or_class: The (non-deprecated) class or function for which a deprecated + alias should be created. + warn_once: If True (the default), only print a deprecation warning the first + time this function is used, or the class is instantiated. + + Returns: + A wrapped version of `func_or_class` which prints a deprecation warning on + use and has a modified docstring. + """ + if tf_inspect.isclass(func_or_class): + + # Make a new class with __init__ wrapped in a warning. + class NewClass(func_or_class): # pylint: disable=missing-docstring + __doc__ = decorator_utils.add_notice_to_docstring( + func_or_class.__doc__, 'Please use %s instead.' % name, + 'DEPRECATED CLASS', + '(deprecated)', ['THIS CLASS IS DEPRECATED. ' + 'It will be removed in a future version. ']) + __name__ = func_or_class.__name__ + __module__ = _call_location(outer=True) + + def __init__(self, *args, **kwargs): + if hasattr(NewClass.__init__, '__func__'): + # Python 2 + NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__ + else: + # Python 3 + NewClass.__init__.__doc__ = func_or_class.__init__.__doc__ + + if _PRINT_DEPRECATION_WARNINGS: + # We're making the alias as we speak. The original may have other + # aliases, so we cannot use it to check for whether it's already been + # warned about. + if NewClass.__init__ not in _PRINTED_WARNING: + if warn_once: + _PRINTED_WARNING[NewClass.__init__] = True + logging.warning( + 'From %s: The name %s is deprecated. Please use %s instead.\n', + _call_location(), deprecated_name, name) + super(NewClass, self).__init__(*args, **kwargs) + + return NewClass + else: + decorator_utils.validate_callable(func_or_class, 'deprecated') + + # Make a wrapper for the original + @functools.wraps(func_or_class) + def new_func(*args, **kwargs): # pylint: disable=missing-docstring + if _PRINT_DEPRECATION_WARNINGS: + # We're making the alias as we speak. The original may have other + # aliases, so we cannot use it to check for whether it's already been + # warned about. + if new_func not in _PRINTED_WARNING: + if warn_once: + _PRINTED_WARNING[new_func] = True + logging.warning( + 'From %s: The name %s is deprecated. Please use %s instead.\n', + _call_location(), deprecated_name, name) + return func_or_class(*args, **kwargs) + return tf_decorator.make_decorator( + func_or_class, new_func, 'deprecated', + _add_deprecated_function_notice_to_docstring( + func_or_class.__doc__, None, 'Please use %s instead.' % name)) + + def deprecated(date, instructions, warn_once=True): """Decorator for marking functions or methods deprecated. @@ -284,7 +398,9 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples, @functools.wraps(func) def new_func(*args, **kwargs): """Deprecation wrapper.""" - if _PRINT_DEPRECATION_WARNINGS: + # TODO(apassos) figure out a way to have reasonable performance with + # deprecation warnings and eager mode. + if is_in_graph_mode.IS_IN_GRAPH_MODE() and _PRINT_DEPRECATION_WARNINGS: invalid_args = [] named_args = tf_inspect.getcallargs(func, *args, **kwargs) for arg_name, spec in iter(deprecated_positions.items()): diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py index e61edb5cfa3f8f7676b8a77d787781abdd80f310..bdd0bc48d29319914e184ea4331a5e9d4a1c3328 100644 --- a/tensorflow/python/util/deprecation_test.py +++ b/tensorflow/python/util/deprecation_test.py @@ -24,6 +24,56 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation +class DeprecatedAliasTest(test.TestCase): + + @test.mock.patch.object(logging, "warning", autospec=True) + def test_function_alias(self, mock_warning): + deprecated_func = deprecation.deprecated_alias("deprecated.func", + "real.func", + logging.error) + + logging.error("fake error logged") + self.assertEqual(0, mock_warning.call_count) + deprecated_func("FAKE ERROR!") + self.assertEqual(1, mock_warning.call_count) + # Make sure the error points to the right file. + self.assertRegexpMatches(mock_warning.call_args[0][1], + r"deprecation_test\.py:") + deprecated_func("ANOTHER FAKE ERROR!") + self.assertEqual(1, mock_warning.call_count) + + @test.mock.patch.object(logging, "warning", autospec=True) + def test_class_alias(self, mock_warning): + class MyClass(object): + """My docstring.""" + + init_args = [] + + def __init__(self, arg): + MyClass.init_args.append(arg) + + deprecated_cls = deprecation.deprecated_alias("deprecated.cls", + "real.cls", + MyClass) + + print(deprecated_cls.__name__) + print(deprecated_cls.__module__) + print(deprecated_cls.__doc__) + + MyClass("test") + self.assertEqual(0, mock_warning.call_count) + deprecated_cls("deprecated") + self.assertEqual(1, mock_warning.call_count) + # Make sure the error points to the right file. + self.assertRegexpMatches(mock_warning.call_args[0][1], + r"deprecation_test\.py:") + deprecated_cls("deprecated again") + self.assertEqual(1, mock_warning.call_count) + + self.assertEqual(["test", "deprecated", "deprecated again"], + MyClass.init_args) + + class DeprecationTest(test.TestCase): @test.mock.patch.object(logging, "warning", autospec=True) diff --git a/tensorflow/python/util/is_in_graph_mode.py b/tensorflow/python/util/is_in_graph_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae89ecb714c25787732f0d6c671d78144bec395 --- /dev/null +++ b/tensorflow/python/util/is_in_graph_mode.py @@ -0,0 +1,22 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A function that tells you if the program is running in graph mode.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Call IS_IN_GRAPH_MODE() when you want to know whether the thread is in +# graph mode. By default, we always are. +IS_IN_GRAPH_MODE = lambda: True diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 874df3d1087e157f8bfcec12ba3495e341c14b7b..23c2c48f4b5a165bd6e356a6243b234619af1c4c 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -497,7 +497,9 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): shallow_tree: an arbitrarily nested structure. input_tree: an arbitrarily nested structure. check_types: if `True` (default) the sequence types of `shallow_tree` and - `input_tree` have to be the same. + `input_tree` have to be the same. Note that even with check_types==True, + this function will consider two different namedtuple classes with the same + name and _fields attribute to be the same class. Raises: TypeError: If `shallow_tree` is a sequence but `input_tree` is not. @@ -513,10 +515,21 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): "Input has type: %s." % type(input_tree)) if check_types and not isinstance(input_tree, type(shallow_tree)): - raise TypeError( - "The two structures don't have the same sequence type. Input " - "structure has type %s, while shallow structure has type %s." - % (type(input_tree), type(shallow_tree))) + # Duck-typing means that nest should be fine with two different + # namedtuples with identical name and fields. + shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) + input_is_namedtuple = _is_namedtuple(input_tree, False) + if shallow_is_namedtuple and input_is_namedtuple: + if not _same_namedtuples(shallow_tree, input_tree): + raise TypeError( + "The two namedtuples don't have the same sequence type. Input " + "structure has type %s, while shallow structure has type %s." + % (type(input_tree), type(shallow_tree))) + else: + raise TypeError( + "The two structures don't have the same sequence type. Input " + "structure has type %s, while shallow structure has type %s." + % (type(input_tree), type(shallow_tree))) if len(input_tree) != len(shallow_tree): raise ValueError( @@ -532,8 +545,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): (list(_six.iterkeys(input_tree)), list(_six.iterkeys(shallow_tree)))) - input_tree = list(_six.iteritems(input_tree)) - shallow_tree = list(_six.iteritems(shallow_tree)) + input_tree = list(sorted(_six.iteritems(input_tree))) + shallow_tree = list(sorted(_six.iteritems(shallow_tree))) for shallow_branch, input_branch in zip(shallow_tree, input_tree): assert_shallow_structure(shallow_branch, input_branch, diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 6bec397db577c5be5847a701ccc92367dc008fc9..4439d6241ea9607b194cbb17304dbb77dc9f57a8 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -425,6 +425,19 @@ class NestTest(test.TestCase): with self.assertRaisesRegexp(ValueError, expected_message): nest.assert_shallow_structure(inp_ab2, inp_ab1) + inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) + inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) + nest.assert_shallow_structure(inp_ab, inp_ba) + + # This assertion is expected to pass: two namedtuples with the same + # name and field names are considered to be identical. + same_name_type_0 = collections.namedtuple("same_name", ("a", "b")) + same_name_type_1 = collections.namedtuple("same_name", ("a", "b")) + inp_shallow = same_name_type_0(1, 2) + inp_deep = same_name_type_1(1, [1, 2, 3]) + nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) + nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) + def testFlattenUpTo(self): # Shallow tree ends at scalar. input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py index c4168f7b1ac80976a957e96c79c72fe3b288d622..c2fe6fc4494428693605a5a7463a9f590a2da39e 100644 --- a/tensorflow/python/util/tf_inspect.py +++ b/tensorflow/python/util/tf_inspect.py @@ -134,6 +134,11 @@ def getmembers(object, predicate=None): # pylint: disable=redefined-builtin return _inspect.getmembers(object, predicate) +def getmodule(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.getmodule.""" + return _inspect.getmodule(object) + + def getmro(cls): """TFDecorator-aware replacement for inspect.getmro.""" return _inspect.getmro(cls) diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py index a9e8ffb30c3392251c2bf7076e02aafd2338696b..8903e1156b27b3a28543eb5ecfcc6eeb1a04f6ae 100644 --- a/tensorflow/python/util/tf_inspect_test.py +++ b/tensorflow/python/util/tf_inspect_test.py @@ -124,6 +124,17 @@ class TfInspectTest(test.TestCase): inspect.getmembers(TestDecoratedClass), tf_inspect.getmembers(TestDecoratedClass)) + def testGetModule(self): + self.assertEqual( + inspect.getmodule(TestDecoratedClass), + tf_inspect.getmodule(TestDecoratedClass)) + self.assertEqual( + inspect.getmodule(test_decorated_function), + tf_inspect.getmodule(test_decorated_function)) + self.assertEqual( + inspect.getmodule(test_undecorated_function), + tf_inspect.getmodule(test_undecorated_function)) + def testGetSource(self): expected = '''@test_decorator('decorator') def test_decorated_function_with_defaults(a, b=2, c='Hello'): diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 384445e6c1629e5518459b5382aa9b92698fb6ff..58b47067662b2595f53ca648dcab7a2a194039ab 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -109,6 +109,24 @@ string ToString(cudnnStatus_t status) { } } +template +cudnnDataType_t GetCudnnDataType(); + +template <> +cudnnDataType_t GetCudnnDataType() { + return CUDNN_DATA_DOUBLE; +} + +template <> +cudnnDataType_t GetCudnnDataType() { + return CUDNN_DATA_FLOAT; +} + +template <> +cudnnDataType_t GetCudnnDataType() { + return CUDNN_DATA_HALF; +} + namespace wrap { static port::ThreadPool* InitCudnnThreadpool() { @@ -559,7 +577,7 @@ class ScopedFilterDescriptor { // A helper function to decide whether to enable the TENSOR_OP_MATH math type static bool TensorOpMathEnabled() { static bool is_enabled = [] { - bool is_disabled; + bool is_disabled = false; TF_CHECK_OK( tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH", /*default_val=*/false, &is_disabled)); @@ -568,6 +586,25 @@ static bool TensorOpMathEnabled() { return is_enabled; } +// A helper function to decide whether to use CUDNN_BATCHNORM_SPATIAL_PERSISTENT +// in batchnorm. This mode can be faster in some tasks because an optimized path +// may be selected for CUDNN_DATA_FLOAT and CUDNN_DATA_HALF data types, compute +// capability 6.0 or higher. The reason we set it to false by default is that +// this mode may use scaled atomic integer reduction that may cause a numerical +// overflow for certain input data range. +// TODO(yangzihao): Use autotune to choose between this mode and +// CUDNN_BATCHNORM_SPATIAL mode. +static bool BatchnormSpatialPersistentEnabled() { + static bool is_enabled = [] { + bool is_enabled = false; + TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar( + "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", + /*default_val=*/false, &is_enabled)); + return is_enabled; + }(); + return is_enabled; +} + // Turns a ConvolutionDescriptor structure into a cudnn convolution handle // within a scope. class ScopedConvolutionDescriptor { @@ -2106,7 +2143,6 @@ inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo( dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm( Stream* stream, CUDAExecutor* parent, void* dnn_handle, - int cudnn_type, // Actually cudnnDataType_t. const dnn::AlgorithmConfig& algorithm_config, bool is_profiling, const ScopedTensorDescriptor& input_nd, const ScopedFilterDescriptor& filter, @@ -2264,8 +2300,8 @@ cudnnDataType_t GetConvComputeType() { template bool CudnnSupport::DoConvolveImpl( - Stream* stream, int cudnn_type, // Actually cudnnDataType_t. - const BatchDescriptor& batch_descriptor, const DeviceMemory& input_data, + Stream* stream, const BatchDescriptor& batch_descriptor, + const DeviceMemory& input_data, const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const ConvolutionDescriptor& convolution_descriptor, @@ -2273,12 +2309,11 @@ bool CudnnSupport::DoConvolveImpl( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - ScopedTensorDescriptor input_nd{parent_, batch_descriptor, - static_cast(cudnn_type)}; - ScopedTensorDescriptor output_nd{parent_, output_descriptor, - static_cast(cudnn_type)}; + cudnnDataType_t cudnn_type = GetCudnnDataType(); + ScopedTensorDescriptor input_nd{parent_, batch_descriptor, cudnn_type}; + ScopedTensorDescriptor output_nd{parent_, output_descriptor, cudnn_type}; ScopedFilterDescriptor filter{parent_, filter_descriptor, batch_descriptor, - static_cast(cudnn_type)}; + cudnn_type}; ScopedConvolutionDescriptor conv{parent_, convolution_descriptor, GetConvComputeType()}; @@ -2505,9 +2540,8 @@ bool CudnnSupport::DoFusedConvolveImpl( const bool is_profiling = output_profile_result != nullptr; DeviceMemory scratch; dnn::AlgorithmDesc algotype = GetCudnnConvolutionForwardAlgorithm( - stream, parent_, dnn_handle_, cudnn_data_type, algorithm_config, - is_profiling, conv_input_nd, filter, conv, output_nd, scratch_allocator, - &scratch); + stream, parent_, dnn_handle_, algorithm_config, is_profiling, + conv_input_nd, filter, conv, output_nd, scratch_allocator, &scratch); if (algotype.is_default()) { if (!is_profiling) { LOG(ERROR) << "No suitable algorithm found"; @@ -2758,6 +2792,11 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( ScopedTensorDescriptor scale_offset_descriptor{ parent_, scale_offset_desc, ToCudnnDataType(scale_data_type)}; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; +#if CUDNN_VERSION >= 7000 + if (BatchnormSpatialPersistentEnabled()) { + mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; + } +#endif float one = 1.0; float zero = 0.0; @@ -2859,6 +2898,11 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl( parent_, scale_offset_desc, static_cast(cudnn_scale_type)}; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; +#if CUDNN_VERSION >= 7000 + if (BatchnormSpatialPersistentEnabled()) { + mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; + } +#endif float one = 1.0; float zero = 0.0; @@ -2888,9 +2932,9 @@ bool CudnnSupport::DoConvolve( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return DoConvolveImpl( - stream, CUDNN_DATA_FLOAT, batch_descriptor, input_data, filter_descriptor, - filter_data, convolution_descriptor, output_descriptor, output_data, - scratch_allocator, algorithm_config, output_profile_result); + stream, batch_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result); } bool CudnnSupport::DoConvolve( @@ -2916,9 +2960,9 @@ bool CudnnSupport::DoConvolve( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return DoConvolveImpl( - stream, CUDNN_DATA_HALF, batch_descriptor, input_data, filter_descriptor, - filter_data, convolution_descriptor, output_descriptor, output_data, - scratch_allocator, algorithm_config, output_profile_result); + stream, batch_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result); } bool CudnnSupport::DoFusedConvolve( @@ -3027,7 +3071,6 @@ bool CudnnSupport::DoFusedConvolve( template DeviceMemory CudnnSupport::MaybeTransformLayout( Stream* stream, - int cudnn_type, // Actually cudnnDataType_t. BatchDescriptor* output_descriptor, DeviceMemory backward_output_data, std::unique_ptr>* transform_scratch) { @@ -3041,11 +3084,11 @@ DeviceMemory CudnnSupport::MaybeTransformLayout( BatchDescriptor transformed_output_descriptor; transformed_output_descriptor.CloneFrom(*output_descriptor); transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX); - ScopedTensorDescriptor orig_out_back_nd{ - parent_, *output_descriptor, static_cast(cudnn_type)}; + cudnnDataType_t cudnn_type = GetCudnnDataType(); + ScopedTensorDescriptor orig_out_back_nd{parent_, *output_descriptor, + cudnn_type}; ScopedTensorDescriptor transformed_out_back_nd{ - parent_, transformed_output_descriptor, - static_cast(cudnn_type)}; + parent_, transformed_output_descriptor, cudnn_type}; float alpha = 1.0f; float beta = 0.0f; @@ -3092,7 +3135,6 @@ bool CudnnSupport::DoTransformTensor(Stream* stream, template bool CudnnSupport::DoConvolveBackwardDataImpl( Stream* stream, - int cudnn_type, // Actually cudnnDataType_t. const FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const BatchDescriptor& output_descriptor_in, @@ -3119,15 +3161,13 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( output_descriptor.CloneFrom(output_descriptor_in); std::unique_ptr> transform_scratch; backward_output_data = MaybeTransformLayout( - stream, cudnn_type, &output_descriptor, backward_output_data, - &transform_scratch); + stream, &output_descriptor, backward_output_data, &transform_scratch); - ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, - static_cast(cudnn_type)}; - ScopedTensorDescriptor in_back_nd{parent_, input_descriptor, - static_cast(cudnn_type)}; + cudnnDataType_t cudnn_type = GetCudnnDataType(); + ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type}; + ScopedTensorDescriptor in_back_nd{parent_, input_descriptor, cudnn_type}; ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor, - static_cast(cudnn_type)}; + cudnn_type}; ScopedConvolutionDescriptor conv{parent_, convolution_descriptor, GetConvComputeType()}; @@ -3315,11 +3355,11 @@ bool CudnnSupport::DoConvolveBackwardData( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardDataImpl( - stream, CUDNN_DATA_FLOAT, filter_descriptor, filter_data, - output_descriptor_in, backward_output_data, convolution_descriptor, - input_descriptor, backward_input_data, scratch_allocator, - algorithm_config, output_profile_result); + return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, + output_descriptor_in, backward_output_data, + convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, + algorithm_config, output_profile_result); } bool CudnnSupport::DoConvolveBackwardData( @@ -3333,17 +3373,16 @@ bool CudnnSupport::DoConvolveBackwardData( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardDataImpl( - stream, CUDNN_DATA_HALF, filter_descriptor, filter_data, - output_descriptor_in, backward_output_data, convolution_descriptor, - input_descriptor, backward_input_data, scratch_allocator, - algorithm_config, output_profile_result); + return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, + output_descriptor_in, backward_output_data, + convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, + algorithm_config, output_profile_result); } template bool CudnnSupport::DoConvolveBackwardFilterImpl( - Stream* stream, int cudnn_type, // Actually cudnnDataType_t. - const dnn::BatchDescriptor& input_descriptor, + Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_descriptor_in, DeviceMemory backward_output_data, @@ -3369,16 +3408,13 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( output_descriptor.CloneFrom(output_descriptor_in); std::unique_ptr> transform_scratch; backward_output_data = MaybeTransformLayout( - stream, static_cast(cudnn_type), - &output_descriptor, backward_output_data, - &transform_scratch); - - ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, - static_cast(cudnn_type)}; - ScopedTensorDescriptor input_nd{parent_, input_descriptor, - static_cast(cudnn_type)}; + stream, &output_descriptor, backward_output_data, &transform_scratch); + + cudnnDataType_t cudnn_type = GetCudnnDataType(); + ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type}; + ScopedTensorDescriptor input_nd{parent_, input_descriptor, cudnn_type}; ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor, - static_cast(cudnn_type)}; + cudnn_type}; ScopedConvolutionDescriptor conv{parent_, convolution_descriptor, GetConvComputeType()}; @@ -3568,10 +3604,10 @@ bool CudnnSupport::DoConvolveBackwardFilter( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return DoConvolveBackwardFilterImpl( - stream, CUDNN_DATA_FLOAT, input_descriptor, input_data, - output_descriptor_in, backward_output_data, convolution_descriptor, - filter_descriptor, backward_filter_data, scratch_allocator, - algorithm_config, output_profile_result); + stream, input_descriptor, input_data, output_descriptor_in, + backward_output_data, convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, algorithm_config, + output_profile_result); } bool CudnnSupport::DoConvolveBackwardFilter( @@ -3586,16 +3622,15 @@ bool CudnnSupport::DoConvolveBackwardFilter( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { return DoConvolveBackwardFilterImpl( - stream, CUDNN_DATA_HALF, input_descriptor, input_data, - output_descriptor_in, backward_output_data, convolution_descriptor, - filter_descriptor, backward_filter_data, scratch_allocator, - algorithm_config, output_profile_result); + stream, input_descriptor, input_data, output_descriptor_in, + backward_output_data, convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, algorithm_config, + output_profile_result); } template bool CudnnSupport::DoConvolveBackwardBiasImpl( - Stream* stream, int cudnn_type, // Actually cudnnDataType_t. - const dnn::BatchDescriptor& input_descriptor, + Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, const dnn::BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { @@ -3606,10 +3641,9 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl( LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); } - ScopedTensorDescriptor input_nd{parent_, input_descriptor, - static_cast(cudnn_type)}; - ScopedTensorDescriptor bias_nd{parent_, bias_descriptor, - static_cast(cudnn_type)}; + cudnnDataType_t cudnn_type = GetCudnnDataType(); + ScopedTensorDescriptor input_nd{parent_, input_descriptor, cudnn_type}; + ScopedTensorDescriptor bias_nd{parent_, bias_descriptor, cudnn_type}; // Alpha is the scaling factor for input. float alpha = 1.0; @@ -3633,9 +3667,8 @@ bool CudnnSupport::DoConvolveBackwardBias( const DeviceMemory& input_data, const BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { - return DoConvolveBackwardBiasImpl(stream, CUDNN_DATA_DOUBLE, input_descriptor, - input_data, bias_descriptor, - backward_bias_data); + return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, + bias_descriptor, backward_bias_data); } bool CudnnSupport::DoConvolveBackwardBias( @@ -3643,9 +3676,8 @@ bool CudnnSupport::DoConvolveBackwardBias( const DeviceMemory& input_data, const BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { - return DoConvolveBackwardBiasImpl(stream, CUDNN_DATA_FLOAT, input_descriptor, - input_data, bias_descriptor, - backward_bias_data); + return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, + bias_descriptor, backward_bias_data); } bool CudnnSupport::DoConvolveBackwardBias( @@ -3653,9 +3685,8 @@ bool CudnnSupport::DoConvolveBackwardBias( const DeviceMemory& input_data, const BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { - return DoConvolveBackwardBiasImpl(stream, CUDNN_DATA_HALF, input_descriptor, - input_data, bias_descriptor, - backward_bias_data); + return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, + bias_descriptor, backward_bias_data); } bool CudnnSupport::DoMatMul(Stream* stream, diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index ee28c0bf57a51a63be7ebbce5c8f80e09737bb16..40aa974dd967df50075da6f2bb34439cd238a113 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -611,7 +611,6 @@ class CudnnSupport : public dnn::DnnSupport { template DeviceMemory MaybeTransformLayout( Stream* stream, - int cudnn_type, // Actually cudnnDataType_t. dnn::BatchDescriptor* output_descriptor, DeviceMemory backward_output_data, std::unique_ptr>* transform_scratch) @@ -644,7 +643,6 @@ class CudnnSupport : public dnn::DnnSupport { template bool DoConvolveImpl(Stream* stream, - int cudnn_type, // Actually cudnnDataType_t. const dnn::BatchDescriptor& batch_descriptor, const DeviceMemory& input_data, const dnn::FilterDescriptor& filter_descriptor, @@ -675,7 +673,6 @@ class CudnnSupport : public dnn::DnnSupport { template bool DoConvolveBackwardDataImpl( Stream* stream, - int cudnn_type, // Actually cudnnDataType_t. const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const dnn::BatchDescriptor& output_descriptor, @@ -688,8 +685,7 @@ class CudnnSupport : public dnn::DnnSupport { template bool DoConvolveBackwardFilterImpl( - Stream* stream, int cudnn_type, // Actually cudnnDataType_t. - const dnn::BatchDescriptor& input_descriptor, + Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_descriptor_in, DeviceMemory backward_output_data, @@ -702,7 +698,6 @@ class CudnnSupport : public dnn::DnnSupport { template bool DoConvolveBackwardBiasImpl(Stream* stream, - int cudnn_type, // Actually cudnnDataType_t. const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, const dnn::BatchDescriptor& bias_descriptor, diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index f4162b096299ca9405e1f3045e370d0da1acf8da..aa88fe770f3596e5da5e12705c3b706365382134 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -896,7 +896,7 @@ class DnnSupport { // offset: offset parameters. // estimated_mean: population mean estimated during training. // Used for inference only; empty for training. - // estimated_variance: population variance estimated during traning, + // estimated_variance: population variance estimated during training, // used for inference only; empty for training. // x_desc: dimensions of the input data, which is the same as the dimensions // of the output. diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc index d71938634d6e6fe092d9a1e0861215bb101e824f..95168836278add5d6592ff0c3d0f7245e6f6bc5b 100644 --- a/tensorflow/stream_executor/dso_loader.cc +++ b/tensorflow/stream_executor/dso_loader.cc @@ -33,6 +33,10 @@ limitations under the License. #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" +#if !defined(PLATFORM_GOOGLE) +#include "cuda/cuda_config.h" +#endif + namespace perftools { namespace gputools { namespace internal { @@ -97,11 +101,12 @@ string GetCudnnVersion() { return TF_CUDNN_VERSION; } /* static */ port::Status DsoLoader::GetLibcuptiDsoHandle(void** dso_handle) { #if defined(ANDROID_TEGRA) - // On Android devices the CUDA version number is not added to the library name. - return GetDsoHandle(FindDsoPath(port::Env::Default()->FormatLibraryFileName( - "cupti", ""), - GetCudaCuptiLibraryPath()), - dso_handle); + // On Android devices the CUDA version number is not added to the library + // name. + return GetDsoHandle( + FindDsoPath(port::Env::Default()->FormatLibraryFileName("cupti", ""), + GetCudaCuptiLibraryPath()), + dso_handle); #else return GetDsoHandle(FindDsoPath(port::Env::Default()->FormatLibraryFileName( "cupti", GetCudaVersion()), diff --git a/tensorflow/stream_executor/dso_loader.h b/tensorflow/stream_executor/dso_loader.h index 9495f7253a1d475f0b5321b71419febd086832af..354c7b50b8209755991827b3c36afac790cb952b 100644 --- a/tensorflow/stream_executor/dso_loader.h +++ b/tensorflow/stream_executor/dso_loader.h @@ -28,10 +28,6 @@ limitations under the License. #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/mutex.h" -#if !defined(PLATFORM_GOOGLE) -#include "cuda/cuda_config.h" -#endif - namespace perftools { namespace gputools { namespace internal { diff --git a/tensorflow/stream_executor/executor_cache.cc b/tensorflow/stream_executor/executor_cache.cc index a23d6a70ba237efb2a83f8f56975173015ba9a39..d1a8aae167455a7dc728999fbbaf1a119cf6a101 100644 --- a/tensorflow/stream_executor/executor_cache.cc +++ b/tensorflow/stream_executor/executor_cache.cc @@ -23,6 +23,14 @@ namespace gputools { port::StatusOr ExecutorCache::GetOrCreate( const StreamExecutorConfig& config, const std::function& factory) { + // In the fast path case, the cache already has an entry and we can just + // return after Get() which only takes a shared lock and not a unique lock. + // If we need to create, we take a unique lock on cache_. + auto fast_result = Get(config); + if (fast_result.ok()) { + return fast_result; + } + Entry* entry = nullptr; { mutex_lock lock{mutex_}; @@ -59,12 +67,17 @@ port::StatusOr ExecutorCache::Get( const StreamExecutorConfig& config) { Entry* entry = nullptr; { - mutex_lock lock{mutex_}; - entry = &cache_[config.ordinal]; - // Release the map lock; the address of 'entry' is stable because - // std::map guarantees reference stability. + tf_shared_lock lock{mutex_}; + auto it = cache_.find(config.ordinal); + if (it != cache_.end()) { + entry = &it->second; + } else { + return port::Status(port::error::NOT_FOUND, + port::Printf("No executors registered for ordinal %d", + config.ordinal)); + } } - mutex_lock lock{entry->configurations_mutex}; + tf_shared_lock lock{entry->configurations_mutex}; if (entry->configurations.empty()) { return port::Status( port::error::NOT_FOUND, diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc index cc32a6beaa5f83d6883b02682c14327b735a1caa..f23224ae772b9c5915426feaef1155fc9711f075 100644 --- a/tensorflow/stream_executor/multi_platform_manager.cc +++ b/tensorflow/stream_executor/multi_platform_manager.cc @@ -45,7 +45,7 @@ namespace gputools { /* static */ port::StatusOr MultiPlatformManager::PlatformWithName( const string& target) { - mutex_lock lock(GetPlatformsMutex()); + tf_shared_lock lock(GetPlatformsMutex()); auto it = GetPlatformMap()->find(port::Lowercase(target)); if (it == GetPlatformMap()->end()) { @@ -59,7 +59,7 @@ namespace gputools { /* static */ port::StatusOr MultiPlatformManager::PlatformWithId( const Platform::Id& id) { - mutex_lock lock(GetPlatformsMutex()); + tf_shared_lock lock(GetPlatformsMutex()); auto it = GetPlatformByIdMap()->find(id); if (it == GetPlatformByIdMap()->end()) { return port::Status( diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index f32d4561550c0ff60511047c87821dffe736c935..818d67f7b5be1e8f2db66b24976a529b361a4990 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -10,6 +10,10 @@ load( "tf_additional_xla_deps_py", "if_static", ) +load( + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", +) load( "@local_config_cuda//cuda:build_defs.bzl", "if_cuda", @@ -197,6 +201,7 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False): "-fno-exceptions", "-ftemplate-depth=900"]) + if_cuda(["-DGOOGLE_CUDA=1"]) + + if_tensorrt(["-DGOOGLE_TENSORRT=1"]) + if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML", "-fopenmp",]) + if_android_arm(["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) @@ -214,6 +219,13 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False): "//conditions:default": ["-pthread"] })) + +def tfe_xla_copts(): + return select({ + "//tensorflow:with_xla_support": ["-DTENSORFLOW_EAGER_USE_XLA"], + "//conditions:default": [], + }) + def tf_opts_nortti_if_android(): return if_android([ "-fno-rtti", @@ -486,6 +498,9 @@ def tf_gen_op_wrappers_cc(name, # is invalid to specify both "hidden" and "op_whitelist". # cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the # specified ops. +# gen_locally: if True, the genrule to generate the Python library will be run +# without sandboxing. This would help when the genrule depends on symlinks +# which may not be supported in the sandbox. def tf_gen_op_wrapper_py(name, out=None, hidden=None, @@ -496,7 +511,8 @@ def tf_gen_op_wrapper_py(name, generated_target_name=None, op_whitelist=[], cc_linkopts=[], - api_def_srcs=[]): + api_def_srcs=[], + gen_locally=False): if (hidden or hidden_file) and op_whitelist: fail('Cannot pass specify both hidden and op_whitelist.') @@ -551,6 +567,7 @@ def tf_gen_op_wrapper_py(name, outs=[out], srcs=api_def_srcs + [hidden_file], tools=[tool_name] + tf_binary_additional_srcs(), + local = (1 if gen_locally else 0), cmd=("$(location " + tool_name + ") " + api_def_args_str + " @$(location " + hidden_file + ") " + ("1" if require_shape_functions else "0") + " > $@")) @@ -560,6 +577,7 @@ def tf_gen_op_wrapper_py(name, outs=[out], srcs=api_def_srcs, tools=[tool_name] + tf_binary_additional_srcs(), + local = (1 if gen_locally else 0), cmd=("$(location " + tool_name + ") " + api_def_args_str + " " + op_list_arg + " " + ("1" if require_shape_functions else "0") + " " + @@ -600,7 +618,7 @@ def tf_cc_test(name, srcs=srcs + tf_binary_additional_srcs(), copts=tf_copts() + extra_copts, linkopts=select({ - "//tensorflow:android": [ + clean_dep("//tensorflow:android"): [ "-pie", ], clean_dep("//tensorflow:windows"): [], @@ -666,6 +684,7 @@ def tf_cuda_cc_test(name, tags=[], data=[], size="medium", + extra_copts=[], linkstatic=0, args=[], linkopts=[]): @@ -676,6 +695,7 @@ def tf_cuda_cc_test(name, tags=tags + ["manual"], data=data, size=size, + extra_copts=extra_copts, linkstatic=linkstatic, linkopts=linkopts, args=args) @@ -696,6 +716,7 @@ def tf_cuda_cc_test(name, tags=tags + tf_cuda_tests_tags(), data=data, size=size, + extra_copts=extra_copts, linkopts=linkopts, args=args) @@ -866,9 +887,11 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs): When the library is built with --config=cuda: - - both deps and cuda_deps are used as dependencies - - the cuda runtime is added as a dependency (if necessary) - - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts + - Both deps and cuda_deps are used as dependencies. + - The cuda runtime is added as a dependency (if necessary). + - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts. + - In addition, when the library is also built with TensorRT enabled, it + additionally passes -DGOOGLE_TENSORRT=1 to the list of copts. Args: - cuda_deps: BUILD dependencies which will be linked if and only if: @@ -887,7 +910,8 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs): clean_dep("//tensorflow/core:cuda"), "@local_config_cuda//cuda:cuda_headers" ]), - copts=copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]), + copts=(copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + + if_tensorrt(["-DGOOGLE_TENSORRT=1"])), **kwargs) register_extension_info( @@ -1288,6 +1312,46 @@ def tf_extension_linkopts(): def tf_extension_copts(): return [] # No extension c opts +# In tf_py_wrap_cc generated libraries +# module init functions are not exported unless +# they contain one of the keywords in the version file +# this prevents custom python modules. +# This function attempts to append init_module_name to list of +# exported functions in version script +def _append_init_to_versionscript_impl(ctx): + mod_name = ctx.attr.module_name + if ctx.attr.is_version_script: + ctx.actions.expand_template( + template=ctx.file.template_file, + output=ctx.outputs.versionscript, + substitutions={ + "global:":"global:\n init_%s;\n PyInit_*;"%(mod_name), + }, + is_executable=False, + ) + else: + ctx.actions.expand_template( + template=ctx.file.template_file, + output=ctx.outputs.versionscript, + substitutions={ + "*tensorflow*":"*tensorflow*\ninit_%s\nPyInit_*\n"%(mod_name), + }, + is_executable=False, + ) + + +_append_init_to_versionscript= rule( + implementation=_append_init_to_versionscript_impl, + attrs={ + "module_name":attr.string(mandatory=True), + "template_file":attr.label(allow_files=True,single_file=True,mandatory=True), + "is_version_script":attr.bool(default=True, + doc='whether target is a ld version script or exported symbol list', + mandatory=False), + }, + outputs={"versionscript":"%{name}.lds"}, +) + def tf_py_wrap_cc(name, srcs, swig_includes=[], @@ -1309,26 +1373,39 @@ def tf_py_wrap_cc(name, toolchain_deps=["//tools/defaults:crosstool"], module_name=module_name, py_module_name=name) + vscriptname=name+"_versionscript" + _append_init_to_versionscript( + name=vscriptname, + module_name=module_name, + is_version_script=select({ + "@local_config_cuda//cuda:darwin":False, + "//conditions:default":True, + }), + template_file=select({ + "@local_config_cuda//cuda:darwin":clean_dep("//tensorflow:tf_exported_symbols.lds"), + "//conditions:default":clean_dep("//tensorflow:tf_version_script.lds") + }) + ) extra_linkopts = select({ "@local_config_cuda//cuda:darwin": [ "-Wl,-exported_symbols_list", - clean_dep("//tensorflow:tf_exported_symbols.lds") + "%s.lds"%vscriptname, ], clean_dep("//tensorflow:windows"): [], clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-Wl,--version-script", - clean_dep("//tensorflow:tf_version_script.lds") + "%s.lds"%vscriptname, ] }) extra_deps += select({ "@local_config_cuda//cuda:darwin": [ - clean_dep("//tensorflow:tf_exported_symbols.lds") + "%s.lds"%vscriptname, ], clean_dep("//tensorflow:windows"): [], clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ - clean_dep("//tensorflow:tf_version_script.lds") + "%s.lds"%vscriptname, ] }) diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index d11031639592aa1d3e6ce1c7f09c2f0679b29854..e731127a63d792825e15a4b95379517117edebb0 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -41,24 +41,41 @@ genrule( # every module exported using tf_export. For e.g. if an op is decorated with # @tf_export('module1.module2', 'module3'). Then, outs should include # api/module1/module2/__init__.py and api/module3/__init__.py. + # keep sorted outs = [ "api/__init__.py", + "api/app/__init__.py", "api/bitwise/__init__.py", + "api/compat/__init__.py", "api/contrib/__init__.py", "api/contrib/stat_summarizer/__init__.py", + "api/data/__init__.py", "api/distributions/__init__.py", "api/distributions/bijectors/__init__.py", "api/errors/__init__.py", - "api/image/__init__.py", - "api/linalg/__init__.py", - "api/nn/__init__.py", - "api/spectral/__init__.py", - "api/train/__init__.py", - "api/app/__init__.py", + "api/estimator/__init__.py", + "api/estimator/export/__init__.py", + "api/estimator/inputs/__init__.py", + "api/feature_column/__init__.py", "api/gfile/__init__.py", "api/graph_util/__init__.py", + "api/image/__init__.py", + "api/initializers/__init__.py", "api/keras/__init__.py", + "api/keras/activations/__init__.py", + "api/keras/applications/__init__.py", + "api/keras/applications/densenet/__init__.py", + "api/keras/applications/inception_resnet_v2/__init__.py", + "api/keras/applications/inception_v3/__init__.py", + "api/keras/applications/mobilenet/__init__.py", + "api/keras/applications/nasnet/__init__.py", + "api/keras/applications/resnet50/__init__.py", + "api/keras/applications/vgg16/__init__.py", + "api/keras/applications/vgg19/__init__.py", + "api/keras/applications/xception/__init__.py", "api/keras/backend/__init__.py", + "api/keras/callbacks/__init__.py", + "api/keras/constraints/__init__.py", "api/keras/datasets/__init__.py", "api/keras/datasets/boston_housing/__init__.py", "api/keras/datasets/cifar10/__init__.py", @@ -66,17 +83,47 @@ genrule( "api/keras/datasets/imdb/__init__.py", "api/keras/datasets/mnist/__init__.py", "api/keras/datasets/reuters/__init__.py", + "api/keras/estimator/__init__.py", + "api/keras/initializers/__init__.py", + "api/keras/layers/__init__.py", + "api/keras/losses/__init__.py", + "api/keras/metrics/__init__.py", + "api/keras/models/__init__.py", + "api/keras/optimizers/__init__.py", + "api/keras/preprocessing/__init__.py", + "api/keras/preprocessing/image/__init__.py", + "api/keras/preprocessing/sequence/__init__.py", + "api/keras/preprocessing/text/__init__.py", + "api/keras/regularizers/__init__.py", "api/keras/utils/__init__.py", + "api/keras/wrappers/__init__.py", + "api/keras/wrappers/scikit_learn/__init__.py", + "api/layers/__init__.py", + "api/linalg/__init__.py", "api/logging/__init__.py", - "api/resource_loader/__init__.py", - "api/sysconfig/__init__.py", - "api/test/__init__.py", - "api/initializers/__init__.py", - "api/keras/initializers/__init__.py", + "api/losses/__init__.py", "api/metrics/__init__.py", + "api/nn/__init__.py", "api/nn/rnn_cell/__init__.py", + "api/profiler/__init__.py", + "api/python_io/__init__.py", + "api/resource_loader/__init__.py", + "api/saved_model/__init__.py", + "api/saved_model/builder/__init__.py", + "api/saved_model/constants/__init__.py", + "api/saved_model/loader/__init__.py", + "api/saved_model/main_op/__init__.py", + "api/saved_model/signature_constants/__init__.py", + "api/saved_model/signature_def_utils/__init__.py", + "api/saved_model/tag_constants/__init__.py", + "api/saved_model/utils/__init__.py", "api/sets/__init__.py", + "api/spectral/__init__.py", "api/summary/__init__.py", + "api/sysconfig/__init__.py", + "api/test/__init__.py", + "api/train/__init__.py", + "api/train/queue_runner/__init__.py", ], cmd = "$(location create_python_api) $(OUTS)", tools = ["create_python_api"], diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt index bc7cf7267f7d23121402e63903f01ddc6caa2e04..5a02bb2175e2d6ad71722799143090f2735c1a37 100644 --- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.Variable" tf_class { is_instance: "" + is_instance: "" is_instance: "" member { name: "SaveSliceInfo" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt index ab697b1b95b15e3ac7974e7092f1d5934b088bb6..be9ba4ce85bd5b9905a39e3f45873c534594e15f 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'weighted_sum\'], " } member_method { name: "evaluate" @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt index b73f6433e226f6b570b68c6a419c53d5c808d9d6..91fca67b6b5b1187b61f398a152793362c0c6e30 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'weighted_sum\'], " } member_method { name: "evaluate" @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt index efc441ae2f2a00f663c11f84c1155bece0c8e08a..cd4f72fcf839fa89f25c7ed115ee6c61294283c3 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt index 20ce87987060d9013bd071d6fc9f1f4f33467121..303fd74a64d0c7f5a0292a4eaabec63455c29381 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt index 73211aaf8ba5f925982afe3d17c4b8f009250cb8..c97ea7969eff3e6952a604e72ce140b49d304461 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt index 27a159639d2098aace2e69718d9ac4e38a29fdc3..4b5b5bf0e3599a81e2e853ae8ba34ef12cc63097 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt index dbcc187f94509e3c9265d59cb76d0cdd01bd2333..aa6ac46613fbead7457b19e1aae5f2532afddef1 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "mode" mtype: "" } + member { + name: "prediction_hooks" + mtype: "" + } member { name: "predictions" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt index 76f527f796e95f342eb144ae3de87ff234338021..42a0d595216ad28363727b9d7c066fc37fddd02c 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt @@ -44,7 +44,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt index c45318b98a034255d32c326179813de14cf1d4c8..2de52d6c57cc70b562c3c10b7f23cd15b63e25f8 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt index 04a2aa080d0704a8b7ec98f8eafda4bd1944e567..e552f33720bb939b8a98d34ef3de78bda7db976c 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index baedf596e8fbce921ed7e0570542b8a11655dba4..bda1c2bf85977e69b0969bc8b6056710d88ca910 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -100,6 +100,10 @@ tf_module { name: "hsv_to_rgb" argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "is_jpeg" + argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "non_max_suppression" argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt index 7fe3e2db09c45f26283d0da01d313405a97d0e54..a13bfe0a920ce13dd9a91f106c9cbcbd185b0cc7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt @@ -38,6 +38,10 @@ tf_class { name: "input_spec" mtype: "" } + member { + name: "layers" + mtype: "" + } member { name: "losses" mtype: "" @@ -108,11 +112,11 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'inputs\', \'outputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_loss" - argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_update" @@ -120,7 +124,7 @@ tf_class { } member_method { name: "add_variable" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " } member_method { name: "add_weight" @@ -140,7 +144,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -160,15 +164,15 @@ tf_class { } member_method { name: "evaluate_generator" - argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " } member_method { name: "fit_generator" - argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " } member_method { name: "from_config" @@ -228,7 +232,7 @@ tf_class { } member_method { name: "predict_generator" - argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " } member_method { name: "predict_on_batch" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt index 0a6096813155d59eb1c7920f2bcd250ed9730982..fb6c8d70dd43eae60ea2fb86f0fc63c36d2b13ad 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt @@ -39,6 +39,10 @@ tf_class { name: "input_spec" mtype: "" } + member { + name: "layers" + mtype: "" + } member { name: "losses" mtype: "" @@ -125,7 +129,7 @@ tf_class { } member_method { name: "add_loss" - argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_update" @@ -133,7 +137,7 @@ tf_class { } member_method { name: "add_variable" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..42cb91445059873d9a4ed32d609129de203a764f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.keras.applications.densenet" +tf_module { + member_method { + name: "DenseNet121" + argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], " + } + member_method { + name: "DenseNet169" + argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], " + } + member_method { + name: "DenseNet201" + argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], " + } + member_method { + name: "decode_predictions" + argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], " + } + member_method { + name: "preprocess_input" + argspec: "args=[\'x\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cd75b87540533680d096853ae8645da132dd119a --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.keras.applications.nasnet" +tf_module { + member_method { + name: "NASNetLarge" + argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], " + } + member_method { + name: "NASNetMobile" + argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], " + } + member_method { + name: "decode_predictions" + argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], " + } + member_method { + name: "preprocess_input" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt index daeb5aad419156a19f929fdd455f6c208cd7390f..9fc086eb8e17ef368b38e8d51f0ac8bf0562ca4f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.keras.applications" tf_module { + member { + name: "densenet" + mtype: "" + } member { name: "inception_resnet_v2" mtype: "" @@ -12,6 +16,10 @@ tf_module { name: "mobilenet" mtype: "" } + member { + name: "nasnet" + mtype: "" + } member { name: "resnet50" mtype: "" @@ -28,6 +36,18 @@ tf_module { name: "xception" mtype: "" } + member_method { + name: "DenseNet121" + argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], " + } + member_method { + name: "DenseNet169" + argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], " + } + member_method { + name: "DenseNet201" + argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], " + } member_method { name: "InceptionResNetV2" argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], " @@ -40,6 +60,14 @@ tf_module { name: "MobileNet" argspec: "args=[\'input_shape\', \'alpha\', \'depth_multiplier\', \'dropout\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'1\', \'0.001\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], " } + member_method { + name: "NASNetLarge" + argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], " + } + member_method { + name: "NASNetMobile" + argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], " + } member_method { name: "ResNet50" argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt index 44fbe0f7a04e8573a5348d626854e3b5834381dd..ba2d083a755384d4ec2076ac0dea580a1a878f1d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt @@ -398,7 +398,7 @@ tf_module { } member_method { name: "rnn" - argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\'], " + argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "round" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt index ea4d5143540611f0585b67910cb319454b8560dc..454823fd23e72c6aa6bf6aa608707fa3b893b986 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'stateful_metrics\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "on_batch_begin" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt index 8719c07ca385d2794e5c7e77f75d6d2bc734b7cb..d4c85a4519eb922629f107ef7b61c3f11cb27163 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'schedule\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'schedule\', \'verbose\'], varargs=None, keywords=None, defaults=[\'0\'], " } member_method { name: "on_batch_begin" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt index 0e6901f28affdfc73092c2b9f3af07d17db61a9f..543de0ad48b86502fc83374e5e6d82822485f331 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'count_mode\'], varargs=None, keywords=None, defaults=[\'samples\'], " + argspec: "args=[\'self\', \'count_mode\', \'stateful_metrics\'], varargs=None, keywords=None, defaults=[\'samples\', \'None\'], " } member_method { name: "on_batch_begin" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt index ef08f9b20f4c95f3692a03be7f4220f20aae9a58..bda31751d429ca0d0544402e5c496a0597e1849e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing" tf_module { member_method { name: "load_data" - argspec: "args=[\'path\', \'seed\', \'test_split\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'113\', \'0.2\'], " + argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], " } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt index 8b1c17e9da13a76dcc2c09f3c01a0375bf0cb9fe..ff962876b66cae013de5d711dc7eac5d5c80d8c3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt @@ -6,6 +6,6 @@ tf_module { } member_method { name: "load_data" - argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=None, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], " + argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], " } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt index 6b3ed1e9af0ea7ab4fa83c07c520adf6727a93ac..2da4a13067f2b39eb06304864ea626002300a862 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt @@ -6,6 +6,6 @@ tf_module { } member_method { name: "load_data" - argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=None, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], " + argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], " } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt index a32151e22fab59e999c1e916e5c628d2e1b3f5ee..770a107b664d7ab0a8aedf292a34d4258a201859 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -127,7 +127,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt index 46b1713196fdd2470aefa6227dd19cdbf93185b9..0ce42b706ec20a8ea1cc83ec95cb64d9be2e5710 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt @@ -126,7 +126,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt index 9bfaf2756284c7d287895e8d0b22d96ff1fa1627..b371ad148cee16dd243869d929e0c1c002794682 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -127,7 +127,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt index 2b8ac4f1f4857eb437bc3d67cd68989d3c6842f7..699208a0b9b665b69f02edaa2b2d2aeed6a83b63 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt @@ -73,6 +73,10 @@ tf_class { name: "scope_name" mtype: "" } + member { + name: "trainable" + mtype: "" + } member { name: "trainable_variables" mtype: "" @@ -123,7 +127,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -131,7 +135,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" @@ -159,7 +163,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_output_at" @@ -175,7 +179,7 @@ tf_class { } member_method { name: "get_updates_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_weights" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt index c9a0b887258de2d6b5aa88280b1f7b0d3bf7f6e2..ff08def0a08e5201bc01d61be3f2d66d712c384b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -127,7 +127,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index b847e224d6baeb11135c51ee270f2daa2d52f8a4..6db22ca0320519fd9c101456c9c9c0e26a9a11e0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -116,7 +116,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -128,7 +128,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt index 86578d958e151b47b892b3ada0dbc745d32dbe59..07d3f023e54105c606b198c05750ffa78ee5d0c8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -127,7 +127,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt index 348012dcde3407dad74ea3f56842e3182098b632..92b9760d53e35d3e5066a730bb5cbda45492cc64 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt @@ -126,7 +126,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt index 0419251083f63cbd57244e76f35aee74db434eab..83c528b40117222ac2b3e85ad338459948d0aa8c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt @@ -114,7 +114,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -126,7 +126,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt index 337e85e812d8ef19e873dd49d39108ff3d452bbb..b329f1c46bb07ab7684dec6aaf45a20b98c27ed9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -114,7 +114,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt index 1357dc0f0d6455b18bef0dabe08639e0dee1ab49..c741d4d6e6cf8da9712e68f86abe64e2828823da 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt @@ -183,7 +183,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -195,7 +195,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" @@ -227,7 +227,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_output_at" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt index b71a08f6c3b5e62970ba90c1d27dde5a4067e3e6..57596badf1881950270fa6d3c074afb65daaa8eb 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt @@ -126,7 +126,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt index a01a6067efb3ce217b603da5ea9c2c17c51c8ef7..3829353cc3c195a750ad862707c5c8563e203fba 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt @@ -126,7 +126,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 0dbbdf283836e4121c925200749784abdeb0a5a8..3b171b137af699c9608494a17c5651b439fe4545 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -114,7 +114,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt index 964ef89c2e2abdf8b6f7dc3893751f56dd380e90..29d9cf78ab5ed3bdd1a488359b59cf7171e7e051 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -187,7 +187,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -199,7 +199,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" @@ -231,7 +231,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_output_at" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt index 6a7b23c5409914396f2ce10fcb593a1ca8d65c9e..8134fb738683b79764662d9ea7f721fe04751162 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt @@ -126,7 +126,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt index 324745e5a33de47ed91f1b5c037445ee01780ba3..c5d452300947d7f74e7458e2a04bfdfabb1c1da2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt @@ -114,7 +114,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -126,7 +126,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt index e12ae0505440c31068f0ac132adfd675b93e0593..bcbed9241b525a953c8b499197facaefebe8cc44 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt @@ -114,7 +114,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -126,7 +126,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt index 9e889ca8637f759d495092c9bc6862005e5e8f23..ff0db15f190675d533c50c277eb1cb60e0b95e55 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -127,7 +127,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt index 932680941d269660533e93077818c4884c6e28c4..1d3f33f04516345ee32f16befe0d7200d2cdad00 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -127,7 +127,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt index db644f958f5d781c2dcc5bbbca52e3b656230510..c86bc49b22a8cc3e004a77f4a21594aacb2c665a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt @@ -114,7 +114,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -126,7 +126,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt index 74fa1db02076f5a5cdc1feb412ea2ce5095e326d..ad539a7c4c5362500baef0a9c89d054762bbb47d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt @@ -94,7 +94,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'activity_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\'], " } member_method { name: "add_loss" @@ -118,7 +118,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -130,7 +130,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" @@ -162,7 +162,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_output_at" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..dd67b76523cc50409516e29f963f59d039455bfd --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.keras.layers.SeparableConv1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bf62c095e7cc3fbeac95919a0f9fdc545efd3d25 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.keras.layers.SeparableConvolution1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 3414810db44da6ff0e3f77b1a5db24329de7a88a..6e3cde3e3eaba4f9985411d66a220f7cdd4ee7ad 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -114,7 +114,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt index cf34034ef0abf36c0e7ff18ee8adcc8aeaeae5eb..6fafc77b947d0df11755e3136ed2e7a14c148081 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -175,7 +175,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" @@ -187,7 +187,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" @@ -219,7 +219,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_output_at" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ee4b2fa39ed34a544ee800e9370e4f34c4a17041 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt @@ -0,0 +1,183 @@ +path: "tensorflow.keras.layers.Softmax" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index b76499658da58c178728246b3199391ca064fa3e..3dde1e576918409b106649789443f910775e2f6c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -118,11 +118,11 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=kwargs, defaults=None" + argspec: "args=[\'self\', \'inputs\', \'states\', \'constants\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "compute_mask" @@ -158,7 +158,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_output_at" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt index 2376d815a6400034a51e3d17f98a030209356cf3..ef31c5443efa0c0e5a7a2e0a422d2a9c9c49baaf 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt @@ -126,7 +126,7 @@ tf_class { } member_method { name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "count_params" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt index 2a7059d9aa7ac12d8130c30622bc5f190562695c..1e176d8d4b4eb010049f267be3d0683228a7782b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt @@ -69,6 +69,10 @@ tf_class { name: "scope_name" mtype: "" } + member { + name: "trainable" + mtype: "" + } member { name: "trainable_variables" mtype: "" @@ -155,7 +159,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_output_at" @@ -171,7 +175,7 @@ tf_class { } member_method { name: "get_updates_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_weights" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt index 58bffa087521517fe7f0b5dcd6cae0a8b39a4e25..ea3bb2f8f567c648cd8b3dfa6f179a108690b0f0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt @@ -68,6 +68,10 @@ tf_class { name: "scope_name" mtype: "" } + member { + name: "trainable" + mtype: "" + } member { name: "trainable_variables" mtype: "" @@ -154,7 +158,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_output_at" @@ -170,7 +174,7 @@ tf_class { } member_method { name: "get_updates_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" } member_method { name: "get_weights" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt index fe336c4be5a84a3764b550ca5ad2fcd1d3b85b94..088c8e88e26f59f2753733252882f5e0e8287fb6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt @@ -292,10 +292,18 @@ tf_module { name: "Reshape" mtype: "" } + member { + name: "SeparableConv1D" + mtype: "" + } member { name: "SeparableConv2D" mtype: "" } + member { + name: "SeparableConvolution1D" + mtype: "" + } member { name: "SeparableConvolution2D" mtype: "" @@ -308,6 +316,10 @@ tf_module { name: "SimpleRNNCell" mtype: "" } + member { + name: "Softmax" + mtype: "" + } member { name: "SpatialDropout1D" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt index de285c1aab197ea5cae9c94048a5131f8463ebde..42729e4237685638d38301cece6e93383ddfffba 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt @@ -22,7 +22,7 @@ tf_module { } member_method { name: "deserialize" - argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "get" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt index d239098b0b2bf37fea924ed52074385acf48de96..f85b328e34e6645b0fe0ade18df86411ec0f4e1f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt @@ -38,6 +38,10 @@ tf_class { name: "input_spec" mtype: "" } + member { + name: "layers" + mtype: "" + } member { name: "losses" mtype: "" @@ -108,11 +112,11 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'inputs\', \'outputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_loss" - argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_update" @@ -120,7 +124,7 @@ tf_class { } member_method { name: "add_variable" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " } member_method { name: "add_weight" @@ -140,7 +144,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -160,15 +164,15 @@ tf_class { } member_method { name: "evaluate_generator" - argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\'], " + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], " } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " } member_method { name: "fit_generator" - argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " } member_method { name: "from_config" @@ -228,7 +232,7 @@ tf_class { } member_method { name: "predict_generator" - argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " } member_method { name: "predict_on_batch" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt index 7c1bfcb22558ec3a64c63ebbf0466f9114ef68ee..2e044d78bb2cd6c0ac817218480565c785d11ddc 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt @@ -39,6 +39,10 @@ tf_class { name: "input_spec" mtype: "" } + member { + name: "layers" + mtype: "" + } member { name: "losses" mtype: "" @@ -125,7 +129,7 @@ tf_class { } member_method { name: "add_loss" - argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_update" @@ -133,7 +137,7 @@ tf_class { } member_method { name: "add_variable" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " } member_method { name: "add_weight" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt index ed040c15864b4f4c386d2d9e1f664d35d651fa14..32667cf31e4aaacf3374ca4a434f32eec5b3e07e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'0.95\', \'1e-08\', \'0.0\'], " + argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'0.95\', \'None\', \'0.0\'], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt index a24651429a3db49a96b217259c5c6ef09efed2f2..efca59e8e427d28de36446a49ea4e1ca0bb385eb 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'lr\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'1e-08\', \'0.0\'], " + argspec: "args=[\'self\', \'lr\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'None\', \'0.0\'], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt index a0d978fded3825bafcd8d60e34677029495b1245..5546e2067ab65abce928d609b41b65bbc40246f6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-08\', \'0.0\'], " + argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\', \'amsgrad\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'None\', \'0.0\', \'False\'], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt index 1b70c93ad5f0a8fd52d65fb4b8132a87878c26dd..aaa54a106066266d0a7c19f4609e4cc7ed766d95 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'1e-08\', \'0.0\'], " + argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'None\', \'0.0\'], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt index b49dbe5cf82ea838076134a0feecc120bfb88f84..1fada7fd9c6eefbb16f1b5a042e6fea607a461a9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'schedule_decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'1e-08\', \'0.004\'], " + argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'schedule_decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'None\', \'0.004\'], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt index c8860d80d40353211df65f08fda5deb26af91d66..fd3f97f35dcb18c82188c51345c2e3276a88f23f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'1e-08\', \'0.0\'], " + argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'None\', \'0.0\'], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt index 5bc8c4012049b0414936fb56a853fc32430df3d9..ce91caa1afe081ccf05ecdd4884a3e29ea93d496 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'num_words\', \'filters\', \'lower\', \'split\', \'char_level\'], varargs=None, keywords=None, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \', \'False\'], " + argspec: "args=[\'self\', \'num_words\', \'filters\', \'lower\', \'split\', \'char_level\', \'oov_token\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \', \'False\', \'None\'], " } member_method { name: "fit_on_sequences" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt index 3adc6b6faa6f62330f9ac3d621f29adfc380a09d..16e1cbe650e1662f8694fd7137ad20a48a90675b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'target\', \'width\', \'verbose\', \'interval\'], varargs=None, keywords=None, defaults=[\'30\', \'1\', \'0.05\'], " + argspec: "args=[\'self\', \'target\', \'width\', \'verbose\', \'interval\', \'stateful_metrics\'], varargs=None, keywords=None, defaults=[\'30\', \'1\', \'0.05\', \'None\'], " } member_method { name: "add" @@ -12,6 +12,6 @@ tf_class { } member_method { name: "update" - argspec: "args=[\'self\', \'current\', \'values\', \'force\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + argspec: "args=[\'self\', \'current\', \'values\'], varargs=None, keywords=None, defaults=[\'None\'], " } } diff --git a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0b84165285102daf0a8e3dd6542bfc391e50f77b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.manip" +tf_module { + member_method { + name: "roll" + argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt index a2e728f94b41341b1a7c2a06d2c92d490f6eeb87..44536787f09fc98bba8a4eb0bc562427cfe48b8b 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.rnn_cell.BasicLSTMCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt index 4211faa1ec615da8938d9a858a19a9e9a76378cd..768565d3cacbd1313ee5a64c9b15f9ab70683772 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.rnn_cell.BasicRNNCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index 06fdc638c82b0d19b03857e33f083a94b7fd133b..6ecc134d4df866ab5d59e238a8157064421579bd 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.rnn_cell.GRUCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index ef48cff0c329a7af5009d31fda429cf649c24261..4b3ca1578ba52f30e3405ff198fb716496a462c6 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.rnn_cell.LSTMCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index db1ed4218514ad51f28703c27598eada9464511e..066c4513ff5185b50bdf193f579e71e505dbd3b6 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -396,6 +396,10 @@ tf_module { name: "losses" mtype: "" } + member { + name: "manip" + mtype: "" + } member { name: "metrics" mtype: "" @@ -2044,10 +2048,26 @@ tf_module { name: "unique_with_counts" argspec: "args=[\'x\', \'out_idx\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } + member_method { + name: "unravel_index" + argspec: "args=[\'indices\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "unsorted_segment_max" argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "unsorted_segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_sqrt_n" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "unsorted_segment_sum" argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt index 863beaea4cf05a67e572c97b556bc1eb598d9ced..c02e54adfbd9f33e661453767b517a5f0de90d57 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.AdadeltaOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt index 0a7aa9b6bc14c95e74ab05a3aeb71b770a918f60..2b619908fc6aea3f4b8e6a57d0dcf85a9854d466 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.AdagradDAOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt index 83724fea55d005e9476801feb1bf58cb004aa141..2005cf4677c06cf1f8b4207a444690fdd0c2306e 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.AdagradOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt index e285b27a0531e00d27941fe451570a5056995c17..0a2bae1d9021b20707e03ae5786e71f388266c14 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.AdamOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt index fc28577d6ed1328ae85970cf22cc458b7cf54344..847f9ad75998f1bdda8858650091c70fd0b4015b 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.FtrlOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt index bf3c1d81f877e3a8a7e24d5455e9c5bf6a41f764..13a58e0608ed269415ba78d84a03f1bae128e80c 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.GradientDescentOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt index a640c8d2c6366951cbba6a15d2000d9369cbbdbf..bfbc2357a346c7bfef0242a735ab14c5f4005b22 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.MomentumOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt index 6b33c236a35f09422a42a17b3ffddf5ba7b1595f..437efa0a2bd04c308db6186e714a5d8785541fa5 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt @@ -1,6 +1,8 @@ path: "tensorflow.train.Optimizer" tf_class { is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt index d23fcaed7b4cee397dcf9c51eb3b521e5461c9e5..72f224605f67e72dd78699b5f1a703cc3edd566b 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.ProximalAdagradOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt index b6c03e71d9ffb50bd6377b489fcc444453bd9752..316275b1fb1abd384e193994e35115a1c463f07d 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.ProximalGradientDescentOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt index 4a82db11cb8d85bd0c44135ecaf507c62fae41a1..af50a1986100d830f0809a3f4a0f01faa8821b3b 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.RMSPropOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt index e9131bf544f2e7f08928f46d2be06a00259690be..6edc516c9392fa14f23ffc2a6481ec21216f06cf 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.SyncReplicasOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 8fb6b1cdfd8981e427062e186f6ac26b24231b8b..608a34ab7b32bdc26cebbe43b383155406fb51b2 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -17,10 +17,6 @@ py_test( name = "api_compatibility_test", srcs = ["api_compatibility_test.py"], data = [ - ":convert_from_multiline", - "//tensorflow/core/api_def:base_api_def", - "//tensorflow/core/api_def:python_api_def", - "//tensorflow/python:hidden_ops", "//tensorflow/tools/api/golden:api_golden", "//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt", "//tensorflow/tools/api/tests:README.txt", @@ -29,7 +25,6 @@ py_test( deps = [ "//tensorflow:tensorflow_py", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/tools/api/lib:python_object_to_proto_visitor", diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index afcbf50944cc47b3ae3086b17279f2ce2fdc6ee7..c1e09cc531ed8e8995e3e73b87e96b72fba6c038 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -28,10 +28,8 @@ from __future__ import division from __future__ import print_function import argparse -from collections import defaultdict import os import re -import subprocess import sys import unittest @@ -39,7 +37,6 @@ import tensorflow as tf from google.protobuf import text_format -from tensorflow.core.framework import api_def_pb2 from tensorflow.python.lib.io import file_io from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -67,11 +64,6 @@ _API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden' _TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt' _UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt' -_CONVERT_FROM_MULTILINE_SCRIPT = 'tensorflow/tools/api/tests/convert_from_multiline' -_BASE_API_DIR = 'tensorflow/core/api_def/base_api' -_PYTHON_API_DIR = 'tensorflow/core/api_def/python_api' -_HIDDEN_OPS_FILE = 'tensorflow/python/ops/hidden_ops.txt' - def _KeyToFilePath(key): """From a given key, construct a filepath.""" @@ -96,55 +88,6 @@ def _FileNameToKey(filename): return api_object_key -def _GetSymbol(symbol_id): - """Get TensorFlow symbol based on the given identifier. - - Args: - symbol_id: Symbol identifier in the form module1.module2. ... .sym. - - Returns: - Symbol corresponding to the given id. - """ - # Ignore first module which should be tensorflow - symbol_id_split = symbol_id.split('.')[1:] - symbol = tf - for sym in symbol_id_split: - symbol = getattr(symbol, sym) - return symbol - - -def _IsGenModule(module_name): - if not module_name: - return False - module_name_split = module_name.split('.') - return module_name_split[-1].startswith('gen_') - - -def _GetHiddenOps(): - hidden_ops_file = file_io.FileIO(_HIDDEN_OPS_FILE, 'r') - hidden_ops = set() - for line in hidden_ops_file: - line = line.strip() - if not line: - continue - if line[0] == '#': # comment line - continue - # If line is of the form "op_name # comment", only keep the op_name. - line_split = line.split('#') - hidden_ops.add(line_split[0].strip()) - return hidden_ops - - -def _GetGoldenApiDefs(): - old_api_def_files = file_io.get_matching_files(_GetApiDefFilePath('*')) - return {file_path: file_io.read_file_to_string(file_path) - for file_path in old_api_def_files} - - -def _GetApiDefFilePath(graph_op_name): - return os.path.join(_PYTHON_API_DIR, 'api_def_%s.pbtxt' % graph_op_name) - - class ApiCompatibilityTest(test.TestCase): def __init__(self, *args, **kwargs): @@ -287,188 +230,6 @@ class ApiCompatibilityTest(test.TestCase): update_goldens=FLAGS.update_goldens) -class ApiDefTest(test.TestCase): - - def __init__(self, *args, **kwargs): - super(ApiDefTest, self).__init__(*args, **kwargs) - self._first_cap_pattern = re.compile('(.)([A-Z][a-z]+)') - self._all_cap_pattern = re.compile('([a-z0-9])([A-Z])') - - def _GenerateLowerCaseOpName(self, op_name): - lower_case_name = self._first_cap_pattern.sub(r'\1_\2', op_name) - return self._all_cap_pattern.sub(r'\1_\2', lower_case_name).lower() - - def _CreatePythonApiDef(self, base_api_def, endpoint_names): - """Creates Python ApiDef that overrides base_api_def if needed. - - Args: - base_api_def: (api_def_pb2.ApiDef) base ApiDef instance. - endpoint_names: List of Python endpoint names. - - Returns: - api_def_pb2.ApiDef instance with overrides for base_api_def - if module.name endpoint is different from any existing - endpoints in base_api_def. Otherwise, returns None. - """ - endpoint_names_set = set(endpoint_names) - - # If the only endpoint is equal to graph_op_name then - # it is equivalent to having no endpoints. - if (not base_api_def.endpoint and len(endpoint_names) == 1 - and endpoint_names[0] == - self._GenerateLowerCaseOpName(base_api_def.graph_op_name)): - return None - - base_endpoint_names_set = { - self._GenerateLowerCaseOpName(endpoint.name) - for endpoint in base_api_def.endpoint} - - if endpoint_names_set == base_endpoint_names_set: - return None # All endpoints are the same - - api_def = api_def_pb2.ApiDef() - api_def.graph_op_name = base_api_def.graph_op_name - - for endpoint_name in sorted(endpoint_names): - new_endpoint = api_def.endpoint.add() - new_endpoint.name = endpoint_name - - return api_def - - def _GetBaseApiMap(self): - """Get a map from graph op name to its base ApiDef. - - Returns: - Dictionary mapping graph op name to corresponding ApiDef. - """ - # Convert base ApiDef in Multiline format to Proto format. - converted_base_api_dir = os.path.join( - test.get_temp_dir(), 'temp_base_api_defs') - subprocess.check_call( - [os.path.join(resource_loader.get_root_dir_with_all_resources(), - _CONVERT_FROM_MULTILINE_SCRIPT), - _BASE_API_DIR, converted_base_api_dir]) - - name_to_base_api_def = {} - base_api_files = file_io.get_matching_files( - os.path.join(converted_base_api_dir, 'api_def_*.pbtxt')) - for base_api_file in base_api_files: - if file_io.file_exists(base_api_file): - api_defs = api_def_pb2.ApiDefs() - text_format.Merge( - file_io.read_file_to_string(base_api_file), api_defs) - for api_def in api_defs.op: - name_to_base_api_def[api_def.graph_op_name] = api_def - return name_to_base_api_def - - def _AddHiddenOpOverrides(self, name_to_base_api_def, api_def_map): - """Adds ApiDef overrides to api_def_map for hidden Python ops. - - Args: - name_to_base_api_def: Map from op name to base api_def_pb2.ApiDef. - api_def_map: Map from file path to api_def_pb2.ApiDefs for Python API - overrides. - """ - hidden_ops = _GetHiddenOps() - for hidden_op in hidden_ops: - if hidden_op not in name_to_base_api_def: - logging.warning('Unexpected hidden op name: %s' % hidden_op) - continue - - base_api_def = name_to_base_api_def[hidden_op] - if base_api_def.visibility != api_def_pb2.ApiDef.HIDDEN: - api_def = api_def_pb2.ApiDef() - api_def.graph_op_name = base_api_def.graph_op_name - api_def.visibility = api_def_pb2.ApiDef.HIDDEN - - file_path = _GetApiDefFilePath(base_api_def.graph_op_name) - api_def_map[file_path].op.extend([api_def]) - - @unittest.skipUnless( - sys.version_info.major == 2 and os.uname()[0] == 'Linux', - 'API compabitility test goldens are generated using python2 on Linux.') - def testAPIDefCompatibility(self): - # Get base ApiDef - name_to_base_api_def = self._GetBaseApiMap() - snake_to_camel_graph_op_names = { - self._GenerateLowerCaseOpName(name): name - for name in name_to_base_api_def.keys()} - # Extract Python API - visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor() - public_api_visitor = public_api.PublicAPIVisitor(visitor) - public_api_visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf, public_api_visitor) - proto_dict = visitor.GetProtos() - - # Map from file path to Python ApiDefs. - new_api_defs_map = defaultdict(api_def_pb2.ApiDefs) - # We need to override all endpoints even if 1 endpoint differs from base - # ApiDef. So, we first create a map from an op to all its endpoints. - op_to_endpoint_name = defaultdict(list) - - # Generate map from generated python op to endpoint names. - for public_module, value in proto_dict.items(): - module_obj = _GetSymbol(public_module) - for sym in value.tf_module.member_method: - obj = getattr(module_obj, sym.name) - - # Check if object is defined in gen_* module. That is, - # the object has been generated from OpDef. - if hasattr(obj, '__module__') and _IsGenModule(obj.__module__): - if obj.__name__ not in snake_to_camel_graph_op_names: - # Symbol might be defined only in Python and not generated from - # C++ api. - continue - relative_public_module = public_module[len('tensorflow.'):] - full_name = (relative_public_module + '.' + sym.name - if relative_public_module else sym.name) - op_to_endpoint_name[obj].append(full_name) - - # Generate Python ApiDef overrides. - for op, endpoint_names in op_to_endpoint_name.items(): - graph_op_name = snake_to_camel_graph_op_names[op.__name__] - api_def = self._CreatePythonApiDef( - name_to_base_api_def[graph_op_name], endpoint_names) - - if api_def: - file_path = _GetApiDefFilePath(graph_op_name) - api_defs = new_api_defs_map[file_path] - api_defs.op.extend([api_def]) - - self._AddHiddenOpOverrides(name_to_base_api_def, new_api_defs_map) - - old_api_defs_map = _GetGoldenApiDefs() - for file_path, new_api_defs in new_api_defs_map.items(): - # Get new ApiDef string. - new_api_defs_str = str(new_api_defs) - - # Get current ApiDef for the given file. - old_api_defs_str = ( - old_api_defs_map[file_path] if file_path in old_api_defs_map else '') - - if old_api_defs_str == new_api_defs_str: - continue - - if FLAGS.update_goldens: - logging.info('Updating %s...' % file_path) - file_io.write_string_to_file(file_path, new_api_defs_str) - else: - self.assertMultiLineEqual( - old_api_defs_str, new_api_defs_str, - 'To update golden API files, run api_compatibility_test locally ' - 'with --update_goldens=True flag.') - - for file_path in set(old_api_defs_map) - set(new_api_defs_map): - if FLAGS.update_goldens: - logging.info('Deleting %s...' % file_path) - file_io.delete_file(file_path) - else: - self.fail( - '%s file is no longer needed and should be removed.' - 'To update golden API files, run api_compatibility_test locally ' - 'with --update_goldens=True flag.' % file_path) - - if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index aa341b144cb8ef3c9a13635c62a7ae1be90b0994..aeac085d30aef746366192361f249eb01f95e8da 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -177,7 +177,17 @@ do_pylint() { echo "pylint took $((PYLINT_END_TIME - PYLINT_START_TIME)) s" echo "" - grep -E '(\[E|\[W0311|\[W0312)' ${OUTPUT_FILE} > ${ERRORS_FILE} + # Report only what we care about + # Ref https://pylint.readthedocs.io/en/latest/technical_reference/features.html + # E: all errors + # W0311 bad-indentation + # W0312 mixed-indentation + # C0330 bad-continuation + # C0301 line-too-long + # C0326 bad-whitespace + # W0611 unused-import + # W0622 redefined-builtin + grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326|\[W0611|\[W0622)' ${OUTPUT_FILE} > ${ERRORS_FILE} N_ERRORS=0 while read -r LINE; do @@ -313,7 +323,7 @@ do_external_licenses_check(){ EXTRA_LICENSES_FILE="$(mktemp)_extra_licenses.log" echo "Getting external dependencies for ${BUILD_TARGET}" - bazel query "attr('licenses', 'notice', deps(${BUILD_TARGET}))" --no_implicit_deps --no_host_deps --keep_going \ + bazel query "attr('licenses', 'notice', deps(${BUILD_TARGET}))" --keep_going \ | grep -E -v "^//tensorflow" \ | sed -e 's|:.*||' \ | sort \ @@ -322,7 +332,7 @@ do_external_licenses_check(){ echo echo "Getting list of external licenses mentioned in ${LICENSES_TARGET}." - bazel query "deps(${LICENSES_TARGET})" --no_implicit_deps --no_host_deps --keep_going \ + bazel query "deps(${LICENSES_TARGET})" --keep_going \ | grep -E -v "^//tensorflow" \ | sed -e 's|:.*||' \ | sort \ @@ -336,6 +346,18 @@ do_external_licenses_check(){ EXTERNAL_LICENSES_CHECK_END_TIME=$(date +'%s') + # Blacklist + echo ${MISSING_LICENSES_FILE} + grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -v ${MISSING_LICENSES_FILE} > temp.txt + mv temp.txt ${MISSING_LICENSES_FILE} + + # Whitelist + echo ${EXTRA_LICENSE_FILE} + grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -v ${EXTRA_LICENSES_FILE} > temp.txt + mv temp.txt ${EXTRA_LICENSES_FILE} + + + echo echo "do_external_licenses_check took $((EXTERNAL_LICENSES_CHECK_END_TIME - EXTERNAL_LICENSES_CHECK_START_TIME)) s" echo @@ -509,9 +531,14 @@ do_check_futures_test() { python check_futures_test.py } +do_check_file_name_test() { + cd "$ROOT_DIR/tensorflow/tools/test" + python file_name_test.py +} + # Supply all sanity step commands and descriptions -SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_cmake_python_sanity") -SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency") +SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_cmake_python_sanity" "do_check_file_name_test") +SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency" "Check file names for cases") INCREMENTAL_FLAG="" DEFAULT_BAZEL_CONFIGS="--config=hdfs --config=gcp" diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh index cf8737c2d8c746b6ad6c436745193290e31326ea..1df6a84d7c6f86abfb965063625ac43a3f1a57fb 100755 --- a/tensorflow/tools/ci_build/install/install_bazel.sh +++ b/tensorflow/tools/ci_build/install/install_bazel.sh @@ -15,7 +15,7 @@ # ============================================================================== # Select bazel version. -BAZEL_VERSION="0.8.0" +BAZEL_VERSION="0.10.0" set +e local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index 71744c04f2f432bc76eadfac406233ad8241a52a..d406b83a6246d18c335fb52cea1256d7809fa61a 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -43,8 +43,8 @@ pip2 install --upgrade werkzeug==0.11.10 pip3 install --upgrade werkzeug==0.11.10 # Install bleach. html5lib will be picked up as a dependency. -pip2 install --upgrade bleach==1.5.0 -pip3 install --upgrade bleach==1.5.0 +pip2 install --upgrade bleach==2.0.0 +pip3 install --upgrade bleach==2.0.0 # Install markdown. pip2 install --upgrade markdown==2.6.8 diff --git a/tensorflow/tools/ci_build/pylintrc b/tensorflow/tools/ci_build/pylintrc index e71017e621ccc8b42cdf8d4e4bd27a81791bbe4c..68fdb617166f70d2bddf0c472d23102960777de0 100644 --- a/tensorflow/tools/ci_build/pylintrc +++ b/tensorflow/tools/ci_build/pylintrc @@ -180,7 +180,17 @@ docstring-min-length=10 max-line-length=80 # Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ +ignore-long-lines=(?x) + (^\s*(import|from)\s + |\$Id:\s\/\/depot\/.+#\d+\s\$ + |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') + |^\s*\#\ LINT\.ThenChange + |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$ + |pylint + |""" + |\# + |lambda + |(https?|ftp):) # Allow the body of an if to be on the same line as the test if there is no # else. diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py index 347d0769a92cc767f2e263fce0e21d7d0bc8e586..52a0da9a14847e863d92fee9ef7e63e4af0cf068 100755 --- a/tensorflow/tools/ci_build/update_version.py +++ b/tensorflow/tools/ci_build/update_version.py @@ -261,14 +261,12 @@ def major_minor_change(old_version, new_version): def update_dockerfiles(old_version, new_version): """Update dockerfiles if there was a major change.""" if major_minor_change(old_version, new_version): - old_r_major_minor = r"r%s\.%s" % (old_version.major, old_version.minor) - old_r_major_minor_string = old_r_major_minor.replace("\\", "") - r_major_minor = r"r%s\.%s" % (new_version.major, new_version.minor) - r_major_minor_string = r_major_minor.replace("\\", "") + old_r_major_minor = "r%s.%s" % (old_version.major, old_version.minor) + r_major_minor = "r%s.%s" % (new_version.major, new_version.minor) print("Detected Major.Minor change.") print("Updating pattern %s to %s in additional files" - % (old_r_major_minor_string, r_major_minor_string)) + % (old_r_major_minor, r_major_minor)) # Update dockerfiles replace_string_in_line(old_r_major_minor, r_major_minor, DEVEL_DOCKERFILE) diff --git a/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat index 957729bb37db3ae49800c277f4090a52117c699d..c1bc71850754c5b4b42a6eb50be465ba8f98c218 100644 --- a/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat @@ -36,7 +36,7 @@ SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe" :: Run cmake to create Visual Studio Project files. -%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% +%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX :: Run msbuild in the resulting VS project files to build a pip package. %MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 tf_python_build_pip_package.vcxproj diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat index 5a362de3992156fea8a5fc6ab4c70ba67ab47f89..b87e4a9bec41264827d415a11dfa6f23aeda725d 100644 --- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat @@ -37,7 +37,7 @@ SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe" :: Run cmake to create Visual Studio Project files. -%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_ENABLE_GPU=ON -DCUDNN_HOME=%CUDNN_HOME% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% +%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_ENABLE_GPU=ON -DCUDNN_HOME=%CUDNN_HOME% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% -Dtensorflow_DISABLE_EIGEN_FORCEINLINE=%DISABLE_FORCEINLINE% -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX :: Run msbuild in the resulting VS project files to build a pip package. %MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 tf_python_build_pip_package.vcxproj diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh index fa28e3d79ca4ee5f429a41dd3e871248d5c047ca..583d1d5f09527861015458c636af2259b34d45f8 100755 --- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh @@ -41,7 +41,7 @@ run_configure_for_cpu_build # build_libtensorflow_tarball in ../builds/libtensorflow.sh # cannot be used on Windows since it relies on pkg_tar rules. # So we do something special here -bazel build -c opt \ +bazel build -c opt --copt=/arch:AVX \ tensorflow:libtensorflow.so \ tensorflow/tools/lib_package:clicenses_generate \ tensorflow/java:libtensorflow_jni.so \ diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh index 573c926203fc76b787ba08b10bd71c8effda29b6..94276c6c5c9ce897ca24f03efe3d93e1ea1e00c9 100644 --- a/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh +++ b/tensorflow/tools/ci_build/windows/libtensorflow_gpu.sh @@ -41,7 +41,7 @@ run_configure_for_gpu_build # build_libtensorflow_tarball in ../builds/libtensorflow.sh # cannot be used on Windows since it relies on pkg_tar rules. # So we do something special here -bazel build -c opt \ +bazel build -c opt --copt=/arch:AVX \ tensorflow:libtensorflow.so \ tensorflow/tools/lib_package:clicenses_generate \ tensorflow/java:libtensorflow_jni.so \ diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index fa1cc739056e7d50ace73e9ca6645b5dc04621e5..6e90b286c99f894ddd25268afc69043759571c36 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -46,8 +46,9 @@ class APIChangeSpec(object): """ -class _FileEditTuple(collections.namedtuple( - "_FileEditTuple", ["comment", "line", "start", "old", "new"])): +class _FileEditTuple( + collections.namedtuple("_FileEditTuple", + ["comment", "line", "start", "old", "new"])): """Each edit that is recorded by a _FileEditRecorder. Fields: @@ -179,8 +180,7 @@ class _ASTCallVisitor(ast.NodeVisitor): function_renames = self._api_change_spec.function_renames try: new_name = function_renames[full_name] - self._file_edit.add("Renamed function %r to %r" % (full_name, - new_name), + self._file_edit.add("Renamed function %r to %r" % (full_name, new_name), node.lineno, node.col_offset, full_name, new_name) except KeyError: pass @@ -227,7 +227,7 @@ class _ASTCallVisitor(ast.NodeVisitor): # loop over lines while 1: # Reverse the text to and regular expression search for whitespace - text = self._lines[line-1] + text = self._lines[line - 1] reversed_preceding_text = text[:col][::-1] # First find if a [ can be found with only whitespace between it and # col. @@ -236,8 +236,8 @@ class _ASTCallVisitor(ast.NodeVisitor): new_col_offset = col - m.start(1) - 1 return line, new_col_offset else: - if (reversed_preceding_text=="" or - reversed_preceding_text.isspace()): + if (reversed_preceding_text == "" or + reversed_preceding_text.isspace()): line = line - 1 prev_line = self._lines[line - 1] # TODO(aselle): @@ -248,8 +248,8 @@ class _ASTCallVisitor(ast.NodeVisitor): # node ranges to filter out spurious #'s that appear in string # literals. comment_start = prev_line.find("#") - if comment_start == -1: - col = len(prev_line) -1 + if comment_start == -1: + col = len(prev_line) - 1 elif find_string_chars.search(prev_line[comment_start:]) is None: col = comment_start else: @@ -260,7 +260,6 @@ class _ASTCallVisitor(ast.NodeVisitor): # it is not possible to use that in an argument. return node.lineno, node.col_offset - def visit_Call(self, node): # pylint: disable=invalid-name """Handle visiting a call node in the AST. @@ -268,7 +267,6 @@ class _ASTCallVisitor(ast.NodeVisitor): node: Current Node """ - # Find a simple attribute name path e.g. "tf.foo.bar" full_name = self._get_attribute_full_path(node.func) @@ -293,18 +291,21 @@ class _ASTCallVisitor(ast.NodeVisitor): lineno, col_offset = self._find_true_position(arg) if lineno is None or col_offset is None: self._file_edit.add( - "Failed to add keyword %r to reordered function %r" - % (reordered[idx], full_name), arg.lineno, arg.col_offset, - "", "", + "Failed to add keyword %r to reordered function %r" % + (reordered[idx], full_name), + arg.lineno, + arg.col_offset, + "", + "", error="A necessary keyword argument failed to be inserted.") else: keyword_arg = reordered[idx] if (full_name in function_keyword_renames and keyword_arg in function_keyword_renames[full_name]): keyword_arg = function_keyword_renames[full_name][keyword_arg] - self._file_edit.add("Added keyword %r to reordered function %r" - % (reordered[idx], full_name), lineno, - col_offset, "", keyword_arg + "=") + self._file_edit.add("Added keyword %r to reordered function %r" % + (reordered[idx], full_name), lineno, col_offset, + "", keyword_arg + "=") # Examine each keyword argument and convert it to the final renamed form renamed_keywords = ({} if full_name not in function_keyword_renames else @@ -322,11 +323,11 @@ class _ASTCallVisitor(ast.NodeVisitor): # value. key_start = argval_col_offset - len(argkey) - 1 key_end = key_start + len(argkey) + 1 - if (self._lines[argval_lineno - 1][key_start:key_end] == - argkey + "="): + if (self._lines[argval_lineno - 1][key_start:key_end] == argkey + + "="): self._file_edit.add("Renamed keyword argument from %r to %r" % - (argkey, renamed_keywords[argkey]), - argval_lineno, + (argkey, + renamed_keywords[argkey]), argval_lineno, argval_col_offset - len(argkey) - 1, argkey + "=", renamed_keywords[argkey] + "=") continue @@ -335,7 +336,8 @@ class _ASTCallVisitor(ast.NodeVisitor): (argkey, renamed_keywords[argkey]), argval.lineno, argval.col_offset - len(argkey) - 1, - "", "", + "", + "", error="Failed to find keyword lexographically. Fix manually.") ast.NodeVisitor.generic_visit(self, node) @@ -352,7 +354,7 @@ class _ASTCallVisitor(ast.NodeVisitor): if full_name in self._api_change_spec.change_to_function: if not hasattr(node, "is_function_for_call"): new_text = full_name + "()" - self._file_edit.add("Changed %r to %r"%(full_name, new_text), + self._file_edit.add("Changed %r to %r" % (full_name, new_text), node.lineno, node.col_offset, full_name, new_text) ast.NodeVisitor.generic_visit(self, node) @@ -380,8 +382,8 @@ class ASTCodeUpgrader(object): # Write to a temporary file, just in case we are doing an implace modify. with open(in_filename, "r") as in_file, \ tempfile.NamedTemporaryFile("w", delete=False) as temp_file: - ret = self.process_opened_file( - in_filename, in_file, out_filename, temp_file) + ret = self.process_opened_file(in_filename, in_file, out_filename, + temp_file) shutil.move(temp_file.name, out_filename) return ret @@ -424,6 +426,7 @@ class ASTCodeUpgrader(object): out_file.write(out_text) text += "\n" return 1, text, process_errors + # pylint: enable=broad-except def process_tree(self, root_directory, output_root_directory, @@ -444,16 +447,16 @@ class ASTCodeUpgrader(object): # make sure output directory doesn't exist if output_root_directory and os.path.exists(output_root_directory): - print("Output directory %r must not already exist." % ( - output_root_directory)) + print("Output directory %r must not already exist." % + (output_root_directory)) sys.exit(1) # make sure output directory does not overlap with root_directory norm_root = os.path.split(os.path.normpath(root_directory)) norm_output = os.path.split(os.path.normpath(output_root_directory)) if norm_root == norm_output: - print("Output directory %r same as input directory %r" % ( - root_directory, output_root_directory)) + print("Output directory %r same as input directory %r" % + (root_directory, output_root_directory)) sys.exit(1) # Collect list of files to process (we do this to correctly handle if the @@ -465,14 +468,16 @@ class ASTCodeUpgrader(object): copy_files = [f for f in file_list if not f.endswith(".py")] for filename in py_files: fullpath = os.path.join(dir_name, filename) - fullpath_output = os.path.join( - output_root_directory, os.path.relpath(fullpath, root_directory)) + fullpath_output = os.path.join(output_root_directory, + os.path.relpath(fullpath, + root_directory)) files_to_process.append((fullpath, fullpath_output)) if copy_other_files: for filename in copy_files: fullpath = os.path.join(dir_name, filename) - fullpath_output = os.path.join( - output_root_directory, os.path.relpath(fullpath, root_directory)) + fullpath_output = os.path.join(output_root_directory, + os.path.relpath( + fullpath, root_directory)) files_to_copy.append((fullpath, fullpath_output)) file_count = 0 @@ -641,18 +646,17 @@ class TFAPIChangeSpec(APIChangeSpec): "tf.concat": ["concat_dim", "values", "name"], "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"], "tf.nn.softmax_cross_entropy_with_logits": [ - "logits", "labels", "dim", "name"], + "logits", "labels", "dim", "name" + ], "tf.nn.sparse_softmax_cross_entropy_with_logits": [ - "logits", "labels", "name"], - "tf.nn.sigmoid_cross_entropy_with_logits": [ - "logits", "labels", "name"], + "logits", "labels", "name" + ], + "tf.nn.sigmoid_cross_entropy_with_logits": ["logits", "labels", "name"], "tf.op_scope": ["values", "name", "default_name"], } # Specially handled functions. - self.function_handle = { - "tf.reverse": self._reverse_handler - } + self.function_handle = {"tf.reverse": self._reverse_handler} @staticmethod def _reverse_handler(file_edit_recorder, node): @@ -661,12 +665,13 @@ class TFAPIChangeSpec(APIChangeSpec): comment = ("ERROR: tf.reverse has had its argument semantics changed\n" "significantly the converter cannot detect this reliably, so you" "need to inspect this usage manually.\n") - file_edit_recorder.add(comment, - node.lineno, - node.col_offset, - "tf.reverse", - "tf.reverse", - error="tf.reverse requires manual check.") + file_edit_recorder.add( + comment, + node.lineno, + node.col_offset, + "tf.reverse", + "tf.reverse", + error="tf.reverse requires manual check.") if __name__ == "__main__": diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py index a495f9883b284869d043441d1cfecca01296eda3..3d02eacba6e7a91e6d3c88e8297306de9782f4bf 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_test.py @@ -114,7 +114,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."]) def testListComprehension(self): - def _test(input, output): + def _test(input, output): # pylint: disable=redefined-builtin _, unused_report, errors, new_text = self._upgrade(input) self.assertEqual(new_text, output) _test("tf.concat(0, \t[x for x in y])\n", diff --git a/tensorflow/tools/dist_test/build_server.sh b/tensorflow/tools/dist_test/build_server.sh index 878fabd248f3c1dd5cb79983df5220ebf5893026..225c0347416ec8c8fef855946d18e838bd767690 100755 --- a/tensorflow/tools/dist_test/build_server.sh +++ b/tensorflow/tools/dist_test/build_server.sh @@ -16,14 +16,15 @@ # # Builds the test server for distributed (GRPC) TensorFlow # -# Usage: build_server.sh [--test] +# Usage: build_server.sh [--test] # # Arguments: # docker_image_name: Name of the docker image to build. # E.g.: tensorflow/tf_grpc_test_server:0.11.0rc1 # -# whl_url: URL from which the TensorFlow whl file will be downloaded. +# whl_file_location: URL from which the TensorFlow whl file will be downloaded. # E.g.: https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl +# E.g.: /path/to/folder/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl # # The optional flag --test lets the script to use the Dockerfile for the # testing GRPC server. Without the flag, the script will build the non-test @@ -41,11 +42,11 @@ die() { # Check arguments if [[ $# -lt 2 ]]; then - die "Usage: $0 [--test]" + die "Usage: $0 [--test]" fi DOCKER_IMG_NAME=$1 -WHL_URL=$2 +WHL_FILE_LOCATION=$2 shift 2 # Current script directory @@ -53,7 +54,7 @@ DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BUILD_DIR=$(mktemp -d) echo "" -echo "Using whl file URL: ${WHL_URL}" +echo "Using whl file URL: ${WHL_FILE_LOCATION}" echo "Building in temporary directory: ${BUILD_DIR}" cp -r ${DIR}/* "${BUILD_DIR}"/ || \ @@ -65,9 +66,15 @@ if [[ $1 == "--test" ]]; then fi echo "Using Docker file: ${DOCKER_FILE}" +if [[ $WHL_FILE_LOCATION =~ 'http://' || $WHL_FILE_LOCATION =~ 'https://' ]]; then + # Download whl file into the build context directory. + wget -P "${BUILD_DIR}" "${WHL_FILE_LOCATION}" || \ + die "Failed to download tensorflow whl file from URL: ${WHL_FILE_LOCATION}" +else + cp "${WHL_FILE_LOCATION}" "${BUILD_DIR}" +fi + # Download whl file into the build context directory. -wget -P "${BUILD_DIR}" ${WHL_URL} || \ - die "Failed to download tensorflow whl file from URL: ${WHL_URL}" if [[ ! -f "${DOCKER_FILE}" ]]; then die "ERROR: Unable to find dockerfile: ${DOCKER_FILE}" diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh index 7d7f92d246e1ca0b519ac3bf30fde673621ff755..435f9d0dc9c55a3dcfc45e7e46f279b4679a9086 100755 --- a/tensorflow/tools/dist_test/local_test.sh +++ b/tensorflow/tools/dist_test/local_test.sh @@ -24,19 +24,20 @@ # 3) Call a script to launch a k8s TensorFlow GRPC cluster inside the container # and run the distributed test suite. # -# Usage: local_test.sh +# Usage: local_test.sh # [--leave_container_running] # [--model_name ] # [--num_workers ] # [--num_parameter_servers ] # [--sync_replicas] # -# E.g., local_test.sh --model_name CENSUS_WIDENDEEP -# local_test.sh --num_workers 3 --num_parameter_servers 3 +# E.g., local_test.sh --model_name CENSUS_WIDENDEEP +# local_test.sh --num_workers 3 --num_parameter_servers 3 # # Arguments: -# -# Specify custom TensorFlow whl file URL to install in the test Docker image. +# whl_file_location: URL from which the TensorFlow whl file will be acquired. +# E.g.: https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl +# E.g.: /path/to/folder/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl # # --leave_container_running: Do not stop the docker-in-docker container after # the termination of the tests, e.g., for debugging @@ -81,9 +82,9 @@ NUM_WORKERS=2 NUM_PARAMETER_SERVERS=2 SYNC_REPLICAS_FLAG="" -WHL_URL=${1} -if [[ -z "${WHL_URL}" ]]; then - die "whl file URL is not specified" +WHL_FILE_LOCATION=${1} +if [[ -z "${WHL_FILE_LOCATION}" ]]; then + die "whl file location is not specified" fi while true; do @@ -98,8 +99,8 @@ while true; do NUM_PARAMETER_SERVERS=$2 elif [[ $1 == "--sync_replicas" ]]; then SYNC_REPLICAS_FLAG="--sync_replicas" - elif [[ $1 == "--whl_url" ]]; then - WHL_URL=$2 + elif [[ $1 == "--WHL_FILE_LOCATION" ]]; then + WHL_FILE_LOCATION=$2 fi shift @@ -130,15 +131,19 @@ fi # Create docker build context directory. BUILD_DIR=$(mktemp -d) echo "" -echo "Using whl file URL: ${WHL_URL}" +echo "Using whl file location: ${WHL_FILE_LOCATION}" echo "Building in temporary directory: ${BUILD_DIR}" cp -r ${DIR}/* "${BUILD_DIR}"/ || \ die "Failed to copy files to ${BUILD_DIR}" -# Download whl file into the build context directory. -wget -P "${BUILD_DIR}" ${WHL_URL} || \ - die "Failed to download tensorflow whl file from URL: ${WHL_URL}" +if [[ $WHL_FILE_LOCATION =~ 'http://' || $WHL_FILE_LOCATION =~ 'https://' ]]; then + # Download whl file into the build context directory. + wget -P "${BUILD_DIR}" "${WHL_FILE_LOCATION}" || \ + die "Failed to download tensorflow whl file from URL: ${WHL_FILE_LOCATION}" +else + cp "${WHL_FILE_LOCATION}" "${BUILD_DIR}" +fi # Build docker image for test. docker build ${NO_CACHE_FLAG} -t ${DOCKER_IMG_NAME} \ diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py index e40ecb43f9a00bee7309895969ff65e48b95b4e9..a2d12442c44553a287637029843021b7541fa3fa 100644 --- a/tensorflow/tools/dist_test/python/mnist_replica.py +++ b/tensorflow/tools/dist_test/python/mnist_replica.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Distributed MNIST training and validation, with model replicas. A simple softmax model with one hidden layer is defined. The parameters @@ -32,7 +31,6 @@ perform forward computation and gradient calculation in parallel, which should lead to increased training speed for the simple model. """ - from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -45,7 +43,6 @@ import time import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data - flags = tf.app.flags flags.DEFINE_string("data_dir", "/tmp/mnist-data", "Directory for storing mnist data") @@ -56,8 +53,7 @@ flags.DEFINE_integer("task_index", None, "Worker task index, should be >= 0. task_index=0 is " "the master worker task the performs the variable " "initialization ") -flags.DEFINE_integer("num_gpus", 1, - "Total number of gpus for each machine." +flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine." "If you don't use GPU, please set it to '0'") flags.DEFINE_integer("replicas_to_aggregate", None, "Number of replicas to aggregate before parameter update" @@ -69,24 +65,24 @@ flags.DEFINE_integer("train_steps", 200, "Number of (global) training steps to perform") flags.DEFINE_integer("batch_size", 100, "Training batch size") flags.DEFINE_float("learning_rate", 0.01, "Learning rate") -flags.DEFINE_boolean("sync_replicas", False, - "Use the sync_replicas (synchronized replicas) mode, " - "wherein the parameter updates from workers are aggregated " - "before applied to avoid stale gradients") +flags.DEFINE_boolean( + "sync_replicas", False, + "Use the sync_replicas (synchronized replicas) mode, " + "wherein the parameter updates from workers are aggregated " + "before applied to avoid stale gradients") flags.DEFINE_boolean( "existing_servers", False, "Whether servers already exists. If True, " "will use the worker hosts via their GRPC URLs (one client process " "per worker host). Otherwise, will create an in-process TensorFlow " "server.") -flags.DEFINE_string("ps_hosts","localhost:2222", +flags.DEFINE_string("ps_hosts", "localhost:2222", "Comma-separated list of hostname:port pairs") flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "Comma-separated list of hostname:port pairs") -flags.DEFINE_string("job_name", None,"job name: worker or ps") +flags.DEFINE_string("job_name", None, "job name: worker or ps") FLAGS = flags.FLAGS - IMAGE_PIXELS = 28 @@ -97,7 +93,7 @@ def main(unused_argv): if FLAGS.job_name is None or FLAGS.job_name == "": raise ValueError("Must specify an explicit `job_name`") - if FLAGS.task_index is None or FLAGS.task_index =="": + if FLAGS.task_index is None or FLAGS.task_index == "": raise ValueError("Must specify an explicit `task_index`") print("job name = %s" % FLAGS.job_name) @@ -110,9 +106,7 @@ def main(unused_argv): # Get the number of workers. num_workers = len(worker_spec) - cluster = tf.train.ClusterSpec({ - "ps": ps_spec, - "worker": worker_spec}) + cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec}) if not FLAGS.existing_servers: # Not using existing servers. Create an in-process server. @@ -217,7 +211,8 @@ def main(unused_argv): sess_config = tf.ConfigProto( allow_soft_placement=True, log_device_placement=False, - device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index]) + device_filters=["/job:ps", + "/job:worker/task:%d" % FLAGS.task_index]) # The chief worker (task_index==0) session will prepare the session, # while the remaining workers will wait for the preparation to complete. @@ -231,8 +226,7 @@ def main(unused_argv): server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index] print("Using existing server at: %s" % server_grpc_url) - sess = sv.prepare_or_wait_for_session(server_grpc_url, - config=sess_config) + sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config) else: sess = sv.prepare_or_wait_for_session(server.target, config=sess_config) diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index 5dc4a053fd2cae7d83739507fea31e7afc92d77c..d16761c3675942838fd2be0ea6e0b7463a3bf249 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -70,7 +70,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.5 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.6 --depth=1 https://github.com/tensorflow/tensorflow.git . # TODO(craigcitro): Don't install the pip package, since it makes it # more difficult to experiment with local changes. Instead, just add diff --git a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl index 96b260ad3aeb78622dd1ad276f7d524dd598e3bf..3690e7dfe57a4682276a90b10cb84c9a329b3f5e 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl +++ b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl @@ -3,7 +3,7 @@ FROM tensorflow/tensorflow:latest-devel LABEL maintainer="Clayne Robison" # These arguments are parameterized. Use --build-args to override. -ARG TF_BRANCH=r1.5 +ARG TF_BRANCH=r1.6 ARG WHL_DIR=/whl RUN apt-get update && apt-get install -y --no-install-recommends \ diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index 07ffd3839a32ef194100322e54b9133412e4b664..4ef37881bc91aaa58bab031c69b4a96c2a9d8ec1 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -79,7 +79,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.5 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.6 --depth=1 https://github.com/tensorflow/tensorflow.git . # Configure the build for our CUDA configuration. ENV CI_BUILD_PYTHON python diff --git a/tensorflow/tools/docker/jupyter_notebook_config.py b/tensorflow/tools/docker/jupyter_notebook_config.py index 0acbf6fcee58b3eb14794c0f3bb8d2f6ae6e5910..05dcefb099a92683e2cd4700fff54c89c018baa6 100644 --- a/tensorflow/tools/docker/jupyter_notebook_config.py +++ b/tensorflow/tools/docker/jupyter_notebook_config.py @@ -15,6 +15,7 @@ import os from IPython.lib import passwd +c = c # pylint:disable=undefined-variable c.NotebookApp.ip = '*' c.NotebookApp.port = int(os.getenv('PORT', 8888)) c.NotebookApp.open_browser = False diff --git a/tensorflow/tools/docs/generate_1_0.py b/tensorflow/tools/docs/generate_1_0.py index cdc03fdcacf44f7be49e739962b63ba84cf94896..f4384e0ced77718c80d4d146a2d72072588a0541 100644 --- a/tensorflow/tools/docs/generate_1_0.py +++ b/tensorflow/tools/docs/generate_1_0.py @@ -53,7 +53,6 @@ if __name__ == '__main__': 'factorization', 'grid_rnn', 'labeled_tensor', - 'ndlstm', 'quantization', 'session_bundle', 'slim', diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 003f972070cb05aa6f34a3748d47f019744de058..34dd419f15676babfa9a36c2c0960b01248b6f69 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -215,7 +215,6 @@ def _get_default_do_not_descend_map(): # Block contrib.keras to de-clutter the docs 'keras', 'labeled_tensor', - 'ndlstm', 'quantization', 'session_bundle', 'slim', diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 3db164c2b5b78dbcb3c408ce89c067d33c2a2af4..e758229535e7b10994a39cbafb37e116fd2a465c 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -111,8 +111,8 @@ SYMBOL_REFERENCE_RE = re.compile( r""" # Start with a literal "@{". @\{ - # Group at least 1 symbol: not "}" or "\n". - ([^}\n]+) + # Group at least 1 symbol, not "}". + ([^}]+) # Followed by a closing "}" \} """, diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py index 8a0e9af5216c881326449b3e85b94c0be331fa37..fca5436ca5fadd1fb5da07d7523bb51c871164b5 100644 --- a/tensorflow/tools/docs/parser_test.py +++ b/tensorflow/tools/docs/parser_test.py @@ -76,8 +76,9 @@ class ParserTest(googletest.TestCase): pass string = ( - 'A @{tf.reference}, another @{tf.reference}, a member ' - '@{tf.reference.foo}, and a @{tf.third$link `text` with `code` in it}.') + 'A @{tf.reference}, another @{tf.reference$with\nnewline}, a member ' + '@{tf.reference.foo}, and a @{tf.third$link `text` with `code` in ' + 'it}.') duplicate_of = {'tf.third': 'tf.fourth'} index = {'tf.reference': HasOneMember, 'tf.reference.foo': HasOneMember.foo, @@ -93,7 +94,7 @@ class ParserTest(googletest.TestCase): self.assertEqual('A ' 'tf.reference, ' 'another ' - 'tf.reference, ' + 'with\nnewline, ' 'a member ' 'tf.reference.foo, ' 'and a link ' diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py index b5df633800ae5a3ce67cf03910d472b9908d6249..543b5fa6fefcd8e8dca99ad7eac7cca76781ccd3 100644 --- a/tensorflow/tools/docs/pretty_docs.py +++ b/tensorflow/tools/docs/pretty_docs.py @@ -162,7 +162,7 @@ def _build_class_page(page_info): parts.append(h3.format(**method_info.__dict__)) if method_info.signature is not None: - parts.append(_build_signature(method_info)) + parts.append(_build_signature(method_info, use_full_name=False)) parts.append(method_info.doc.docstring) parts.append(_build_function_details(method_info.doc.function_details)) @@ -259,14 +259,14 @@ def _build_module_page(page_info): return ''.join(parts) -def _build_signature(obj_info): +def _build_signature(obj_info, use_full_name=True): """Returns a md code block showing the function signature.""" # Special case tf.range, since it has an optional first argument if obj_info.full_name == 'tf.range': return ( '``` python\n' - "range(limit, delta=1, dtype=None, name='range')\n" - "range(start, limit, delta=1, dtype=None, name='range')\n" + "tf.range(limit, delta=1, dtype=None, name='range')\n" + "tf.range(start, limit, delta=1, dtype=None, name='range')\n" '```\n\n') parts = ['``` python'] @@ -281,7 +281,11 @@ def _build_signature(obj_info): sig = ',\n'.join(' %s' % sig_item for sig_item in obj_info.signature) sig = '\n'+sig+'\n' - parts.append(signature_template.format(name=obj_info.short_name, sig=sig)) + if use_full_name: + obj_name = obj_info.full_name + else: + obj_name = obj_info.short_name + parts.append(signature_template.format(name=obj_name, sig=sig)) parts.append('```\n\n') return '\n'.join(parts) diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index b5465b7fb32856833fc2a12c8dfea58c2e8e79dd..ad3668fa02e102607c9a03ac312451a147affdda 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -99,22 +99,22 @@ cc_library( "freeze_requantization_ranges.cc", "fuse_convolutions.cc", "insert_logging.cc", - "remove_ema.cc", "obfuscate_names.cc", + "quantize_nodes.cc", + "quantize_weights.cc", "remove_attribute.cc", + "remove_control_dependencies.cc", "remove_device.cc", + "remove_ema.cc", "remove_nodes.cc", "rename_attribute.cc", "rename_op.cc", + "round_weights.cc", "set_device.cc", "sort_by_execution_order.cc", "sparsify_gather.cc", "strip_unused_nodes.cc", - ] + if_not_windows([ - "quantize_nodes.cc", - "quantize_weights.cc", - "round_weights.cc", - ]), + ], hdrs = [ "fold_constants_lib.h", ], diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md index 345d9eadb858cadebe03ecb3297aea52ba54bd37..67badb4869029b684cae05130d3c1e190dfb40d2 100644 --- a/tensorflow/tools/graph_transforms/README.md +++ b/tensorflow/tools/graph_transforms/README.md @@ -639,6 +639,13 @@ specified devices may not be available. In order to work with graphs like these, you can run this transform to wipe the slate clean and delete the device specifier from all ops. +### remove_control_dependencies + +Args: None \ +Prerequisites: None + +Removes all control dependencies from the graph. + ### remove_nodes Args: diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc index 5ccd88cfa1acfd55e90504d66417349e42fe3b50..a022f5792676c62c52fd1197b0d8c436f7161a47 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes.cc +++ b/tensorflow/tools/graph_transforms/quantize_nodes.cc @@ -183,22 +183,6 @@ Status ExtractRangeFromParams(const TransformFuncContext& context, return Status::OK(); } -bool AreAttrsEqual(const NodeDef* current_node, const NodeDef* other_node) { - if (current_node->attr_size() != other_node->attr_size()) { - return false; - } - string current_serialized; - string other_serialized; - for (const auto& attr : other_node->attr()) { - auto iter = current_node->attr().find(attr.first); - if (iter == current_node->attr().end()) return false; - iter->second.SerializeToString(¤t_serialized); - attr.second.SerializeToString(&other_serialized); - if (current_serialized != other_serialized) return false; - } - return true; -} - } // namespace // Analyzes all the nodes in the graph to figure out which ones are duplicates diff --git a/tensorflow/tools/graph_transforms/remove_control_dependencies.cc b/tensorflow/tools/graph_transforms/remove_control_dependencies.cc new file mode 100644 index 0000000000000000000000000000000000000000..cba6b78fc5c43ca17f4f30930eb74efdb12940a1 --- /dev/null +++ b/tensorflow/tools/graph_transforms/remove_control_dependencies.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Remove control depdencies in preparation for inference. +// In the tensorflow graph, control dependencies are represented as extra +// inputs which are referenced with "^tensor_name". +// See node_def.proto for more details. +Status RemoveControlDependencies(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + output_graph_def->Clear(); + for (const NodeDef& node : input_graph_def.node()) { + NodeDef* new_node = output_graph_def->mutable_node()->Add(); + *new_node = node; + new_node->clear_input(); + for (const auto& input : node.input()) { + if (input[0] != '^') { + new_node->add_input(input); + } + } + } + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("remove_control_dependencies", RemoveControlDependencies); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/remove_nodes.cc b/tensorflow/tools/graph_transforms/remove_nodes.cc index 119b44d6a4a4d066b734ae8a0e655c771087d0db..05f036a86a09b2a6a94e9c1a1220803eabc64da5 100644 --- a/tensorflow/tools/graph_transforms/remove_nodes.cc +++ b/tensorflow/tools/graph_transforms/remove_nodes.cc @@ -81,7 +81,17 @@ Status RemoveNodes(const GraphDef& input_graph_def, return Status::OK(); } const NodeDef& input_node = match.inputs[0].node; - inputs_to_rename[replace_node.name()] = input_node.name(); + string target_name = input_node.name(); + for (const string& input : replace_node.input()) { + if (!input.compare(0, target_name.size(), target_name)) { + if (input.size() == target_name.size() || + input[target_name.size()] == ':') { + target_name = input; + break; + } + } + } + inputs_to_rename[replace_node.name()] = target_name; inputs_to_rename["^" + replace_node.name()] = "^" + input_node.name(); new_nodes->push_back(input_node); diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc index 96324d0deab400078fdf388bff69001f8e2df9aa..701e350fc39d083665f5420e6b73510c182e12ce 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include "tensorflow/c/checkpoint_reader.h" #include "tensorflow/core/framework/tensor.h" @@ -28,9 +29,10 @@ limitations under the License. #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { -using strings::StrCat; using str_util::Join; using str_util::Split; +using str_util::StringReplace; +using strings::StrCat; namespace graph_transforms { @@ -84,48 +86,71 @@ void CreateConstNode(const Tensor& tensor, const string& name, SetNodeTensorAttr("value", tensor, node_def); } +string GetMonolithicTensorKey(const string& tensor_slice_name) { + std::vector names = Split(tensor_slice_name, "/"); + if (StringPiece(names[names.size() - 1]).starts_with("part_")) { + CHECK_GE(names.size(), 2); + names.pop_back(); + } + return Join(names, "/"); +} + Status ObtainTensorSlice(const GraphDef& input_graph_def, - const string& tensor_name, + const string& target_name, string* shape_slice_string) { string restore_node_name; for (const auto& node : input_graph_def.node()) { - std::vector node_name_parts = str_util::Split(node.name(), "/"); + std::vector node_name_parts = Split(node.name(), "/"); if (node_name_parts.size() == 2 && StringPiece(node_name_parts[0]).starts_with("save") && StringPiece(node_name_parts[1]).starts_with("Assign") && - node.input(0) == tensor_name) { + node.input(0) == target_name) { restore_node_name = node.input(1); break; } } + + std::vector restore_node_parts = Split(restore_node_name, ":"); + CHECK_LE(restore_node_parts.size(), 2); + string tensor_names_node; string shape_and_slices_node; for (const auto& node : input_graph_def.node()) { - if ((node.name() == restore_node_name) && (node.op() == "RestoreV2")) { + if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) { + tensor_names_node = node.input(1); shape_and_slices_node = node.input(2); break; } } + + int offset = -1; + for (const auto& node : input_graph_def.node()) { + if (node.name() == tensor_names_node) { + Tensor tensor_names_tensor; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor)); + const auto& tensor_names_value = tensor_names_tensor.flat(); + for (int i = 0; i < tensor_names_value.size(); i++) { + if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) { + offset = i; + break; + } + } + } + } + if (offset == -1) { + return errors::Internal("Unable to find RestoreV2 entry for variable: ", + target_name); + } for (const auto& node : input_graph_def.node()) { if (node.name() == shape_and_slices_node) { Tensor shape_and_slices_tensor; TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor)); const auto& shape_and_slices_value = shape_and_slices_tensor.flat(); - *shape_slice_string = shape_and_slices_value(0); + *shape_slice_string = shape_and_slices_value(offset); return Status::OK(); } } - return errors::Internal("Unable to find slice for variable: ", tensor_name); -} - -string GetMonolithicTensorKey(const string& tensor_slice_name) { - std::vector names = str_util::Split(tensor_slice_name, "/"); - CHECK_GE(names.size(), 2); - CHECK(StringPiece(names[names.size() - 1]).starts_with("part_")); - - // Remove the "part_x" suffix - names.pop_back(); - return str_util::Join(names, "/"); + return errors::Internal("Unable to find slice for variable: ", target_name); } Status ReadTensorFromCheckpoint( @@ -179,6 +204,22 @@ Status ObtainVariableInfo( return Status::OK(); } +Status RemoveInputAtIndex(NodeDef* n, int index) { + for (int i = index; i < n->input_size() - 1; i++) { + n->mutable_input()->SwapElements(i, i + 1); + } + n->mutable_input()->RemoveLast(); + return Status::OK(); +} + +Status RemoveNodeAtIndex(GraphDef* g, int index) { + for (int i = index; i < g->node_size() - 1; i++) { + g->mutable_node()->SwapElements(i, i + 1); + } + g->mutable_node()->RemoveLast(); + return Status::OK(); +} + Status SparsifyGatherInternal( const GraphDef& input_graph_def, const std::unique_ptr >& @@ -193,6 +234,15 @@ Status SparsifyGatherInternal( GraphDef current_graph_def = input_graph_def; bool any_match_found = false; + // Populate references. + std::unordered_map refs; + for (const auto& node : current_graph_def.node()) { + for (const auto& input : node.input()) { + auto parsed_input = StringReplace(input, "^", "", true); + refs[parsed_input] += 1; + } + } + // The subgraphs may have overlapping components, therefore GraphMatcher // doesn't return all subgraphs in one round -- this has to be multi-round // update. @@ -200,15 +250,15 @@ Status SparsifyGatherInternal( any_match_found = false; GraphDef replaced_graph_def = current_graph_def; std::vector init_table_node_names; - std::vector removed_variable_names; + std::vector removed_node_names; TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( current_graph_def, pattern, [&ckpt_reader, &any_match_found, &init_table_node_names, - &shapes_and_slices, &removed_variable_names]( - const NodeMatch& match, const std::set& input_nodes, - const std::set& output_nodes, - std::vector* new_nodes) { + &shapes_and_slices, &removed_node_names, + &refs](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { any_match_found = true; // The captured subgraph should be of the following pattern: @@ -290,9 +340,13 @@ Status SparsifyGatherInternal( TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint( weights_node.name(), ckpt_reader, (*shapes_and_slices)[weights_node.name()], &weight)); - // Add both both weight and identity node names. - removed_variable_names.push_back(weights_node.name()); - removed_variable_names.push_back(match.inputs[0].node.name()); + } + // Add both both weight and identity node names. + removed_node_names.push_back(weights_node.name()); + removed_node_names.push_back(match.inputs[0].node.name()); + for (auto input_node : match.inputs[0].node.input()) { + auto parsed_input = StringReplace(input_node, "^", "", true); + refs[parsed_input]--; } Tensor indices_tensor; Tensor values_tensor; @@ -362,15 +416,23 @@ Status SparsifyGatherInternal( // Connect nodes AddNodeInput(hashtable_node.name(), &init_table_node); + refs[hashtable_node.name()]++; AddNodeInput(indices_node.name(), &init_table_node); + refs[indices_node.name()]++; AddNodeInput(values_node.name(), &init_table_node); + refs[values_node.name()]++; AddNodeInput(hashtable_node.name(), &lookup_node); + refs[hashtable_node.name()]++; AddNodeInput(gather_node.input(1), &lookup_node); + refs[gather_node.input(1)]++; AddNodeInput(default_value_node.name(), &lookup_node); + refs[default_value_node.name()]++; AddNodeInput(lookup_node.name(), &expand_dims_node); + refs[lookup_node.name()]++; AddNodeInput(dim_idx_node.name(), &expand_dims_node); + refs[dim_idx_node.name()]++; // Copy 'ids' input of original 'Gather' new_nodes->push_back(match.inputs[1].node); @@ -404,47 +466,88 @@ Status SparsifyGatherInternal( for (const string& name : init_table_node_names) { // Add control dependence from init_table_node to group_deps_node AddNodeInput(StrCat("^", name), init_op); + refs[name]++; + } + + // Erase inputs and outputs as they are not considered for deletion. + for (const auto& output : context.output_names) { + refs.erase(output); } - // Remove all dependencies associated with removed variables. - while (!removed_variable_names.empty()) { - auto name = removed_variable_names.back(); - removed_variable_names.pop_back(); + for (const auto& input : context.input_names) { + refs.erase(input); + } + + // Add nodes with a reference count of 0 for deletion. + for (auto entry : refs) { + if (entry.second == 0) { + removed_node_names.push_back(entry.first); + } + } + + while (!removed_node_names.empty()) { + auto name = removed_node_names.back(); + removed_node_names.pop_back(); + int i = 0; while (i < replaced_graph_def.node_size()) { - if (!replaced_graph_def.node(i).input_size()) { - if (replaced_graph_def.node(i).name() == name) { - replaced_graph_def.mutable_node()->SwapElements( - i, replaced_graph_def.node_size() - 1); - replaced_graph_def.mutable_node()->RemoveLast(); - continue; + // Revisit this to see if we can safely remove RestoreV2 nodes. + if ((replaced_graph_def.node(i).name() == name) && + (replaced_graph_def.node(i).op() != "RestoreV2")) { + for (const auto& input : replaced_graph_def.node(i).input()) { + auto parsed_input = StringReplace(input, "^", "", true); + refs[parsed_input] -= 1; + if (refs[parsed_input] == 0) { + removed_node_names.push_back(parsed_input); + } } - i++; + TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i)); continue; } int j = 0; + bool deleted_inputs = false; while (j < replaced_graph_def.node(i).input_size()) { if (replaced_graph_def.node(i).input(j) == name || replaced_graph_def.node(i).input(j) == ("^" + name)) { - replaced_graph_def.mutable_node(i)->mutable_input()->SwapElements( - j, replaced_graph_def.node(i).input_size() - 1); - replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast(); + TF_RETURN_IF_ERROR( + RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j)); + deleted_inputs = true; continue; } j++; } - if ((replaced_graph_def.node(i).input_size() == 0) || - (replaced_graph_def.node(i).op() == "Assign" && - replaced_graph_def.node(i).input_size() == 1)) { - removed_variable_names.push_back(replaced_graph_def.node(i).name()); - if (replaced_graph_def.node(i).input_size() == 1) { - removed_variable_names.push_back( - replaced_graph_def.node(i).input(0)); + if (deleted_inputs) { + if (replaced_graph_def.node(i).op() == "ConcatV2") { + if (replaced_graph_def.node(i).input_size() > 2) { + SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1, + replaced_graph_def.mutable_node(i)); + } else if (replaced_graph_def.node(i).input_size() == 2) { + if (refs[replaced_graph_def.node(i).input(1)] != 1) { + return errors::Internal( + "Expect axis tensor of ConcatV2 node to only be referenced " + "once."); + } + refs[replaced_graph_def.node(i).input(1)] -= 1; + removed_node_names.push_back(replaced_graph_def.node(i).input(1)); + replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast(); + replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N"); + replaced_graph_def.mutable_node(i)->set_op("Identity"); + } else { + return errors::Internal( + "ConcatV2 should have at least two elements"); + } + } + if ((replaced_graph_def.node(i).op() == "Assign" || + replaced_graph_def.node(i).op() == "Reshape" || + replaced_graph_def.node(i).op() == "Equal" || + replaced_graph_def.node(i).op() == "Mean" || + replaced_graph_def.node(i).op() == "ScalarSummary") && + replaced_graph_def.node(i).input_size() == 1) { + removed_node_names.push_back(replaced_graph_def.node(i).name()); + } + if (!replaced_graph_def.node(i).input_size()) { + removed_node_names.push_back(replaced_graph_def.node(i).name()); } - replaced_graph_def.mutable_node()->SwapElements( - i, replaced_graph_def.node_size() - 1); - replaced_graph_def.mutable_node()->RemoveLast(); - continue; } i++; } @@ -485,17 +588,22 @@ Status SparsifyGather(const GraphDef& input_graph_def, }; // clang-format on + GraphDef cleaned_input_graph_def; + RemoveAttributes(input_graph_def, {"_output_shapes"}, + &cleaned_input_graph_def); + GraphDef temp_output; std::unique_ptr ckpt_reader; TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader)); std::unique_ptr > shapes_and_slices; - TF_RETURN_IF_ERROR(ObtainVariableInfo(input_graph_def, &shapes_and_slices)); + TF_RETURN_IF_ERROR( + ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices)); - TF_RETURN_IF_ERROR(SparsifyGatherInternal(input_graph_def, shapes_and_slices, - context, gather_pattern, - ckpt_reader, &temp_output)); + TF_RETURN_IF_ERROR(SparsifyGatherInternal( + cleaned_input_graph_def, shapes_and_slices, context, gather_pattern, + ckpt_reader, &temp_output)); TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices, context, gather_v2_pattern, diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc index 000568a0cc9aceffa927abb1dc56e6586030fea0..d41321c9a6df755eed099ec453f162e2132cfb57 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc @@ -71,7 +71,7 @@ class SparsifyGatherTest : public ::testing::Test { } void TestSinglePartition(bool gather_v2, bool include_shared_init, - bool test_variable, + bool test_variable, bool test_kept_concat, const string& shared_init_name = "group_deps") { GraphDef graph_def; @@ -80,6 +80,8 @@ class SparsifyGatherTest : public ::testing::Test { // Build the graph. NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def); NodeDef* w_node; + NodeDef* zeros_const; + NodeDef* zeros_shape; NodeDef* zeros_node; NodeDef* assign_node; @@ -92,19 +94,27 @@ class SparsifyGatherTest : public ::testing::Test { } else { w_node = CreateNode("w/part_1", "VariableV2", {}, &graph_def); - zeros_node = - CreateNode("w/part_1/Initializer/zeros", "Const", {}, &graph_def); + zeros_shape = CreateNode("w/part_1/Initializer/zeros/shape_as_tensor", + "Const", {}, &graph_def); + zeros_const = CreateNode("w/part_1/Initializer/zeros/Const", "Const", {}, + &graph_def); + zeros_node = CreateNode("w/part_1/Initializer/zeros", "Fill", + {zeros_shape, zeros_const}, &graph_def); assign_node = CreateNode("w/part_1/Assign", "Assign", {w_node, zeros_node}, &graph_def); NodeDef* save_const_node = CreateNode("save/Const", "Const", {}, &graph_def); + Tensor tensor_names_values(DT_STRING, TensorShape({1})); + test::FillValues(&tensor_names_values, {"w"}); NodeDef* tensor_names_node = CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def); + SetNodeTensorAttr("value", tensor_names_values, + tensor_names_node); + NodeDef* tensor_shapes_slices_node = CreateNode( "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); - Tensor shapes_slices_val(DT_STRING, TensorShape({1})); shapes_slices_val.flat()(0) = "4 1 0,4:0,1"; SetNodeTensorAttr("value", shapes_slices_val, @@ -133,6 +143,26 @@ class SparsifyGatherTest : public ::testing::Test { } } + NodeDef* concat_axis_node = + CreateNode("linear/concat/axis", "Const", {}, &graph_def); + NodeDef* concat_input_node = + CreateNode("concat/input/node", "Const", {}, &graph_def); + NodeDef* concat_node = nullptr; + if (!test_kept_concat) { + concat_node = CreateNode( + "concat/node", "ConcatV2", + {identity_node, concat_input_node, concat_axis_node}, &graph_def); + SetNodeAttr("N", 2, concat_node); + } else { + NodeDef* concat_input_node_2 = + CreateNode("concat/input/node_2", "Const", {}, &graph_def); + concat_node = CreateNode("concat/node", "ConcatV2", + {identity_node, concat_input_node, + concat_input_node_2, concat_axis_node}, + &graph_def); + SetNodeAttr("N", 3, concat_node); + } + // Run the op. GraphDef result; TransformFuncContext context; @@ -151,12 +181,32 @@ class SparsifyGatherTest : public ::testing::Test { MapNamesToNodes(result, &node_lookup); // Check nodes. + EXPECT_EQ(0, + node_lookup.count("w/part_1/Initializer/zeros/shape_as_tensor")); + EXPECT_EQ(0, node_lookup.count("w/part_1/Initializer/zeros/Const")); EXPECT_EQ(0, node_lookup.count("w/part_1/Initializer/zeros")); EXPECT_EQ(0, node_lookup.count("w/part_1/Assign")); EXPECT_EQ(1, node_lookup.count("ids")); EXPECT_EQ("Const", node_lookup.at("ids")->op()); + EXPECT_EQ(1, node_lookup.count("concat/node")); + + if (!test_kept_concat) { + EXPECT_EQ(0, node_lookup.count("linear/concat/axis")); + EXPECT_EQ("Identity", node_lookup.at("concat/node")->op()); + EXPECT_EQ(1, node_lookup.at("concat/node")->input_size()); + EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0)); + } else { + EXPECT_EQ(1, node_lookup.count("linear/concat/axis")); + EXPECT_EQ("ConcatV2", node_lookup.at("concat/node")->op()); + EXPECT_EQ(3, node_lookup.at("concat/node")->input_size()); + EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0)); + EXPECT_EQ("concat/input/node_2", node_lookup.at("concat/node")->input(1)); + EXPECT_EQ("linear/concat/axis", node_lookup.at("concat/node")->input(2)); + EXPECT_EQ(2, node_lookup.at("concat/node")->attr().at("N").i()); + } + EXPECT_EQ(1, node_lookup.count("w/part_1/indices")); EXPECT_EQ("Const", node_lookup.at("w/part_1/indices")->op()); Tensor expected_indices_tensor(DT_INT64, TensorShape({3})); @@ -247,7 +297,11 @@ class SparsifyGatherTest : public ::testing::Test { // Two partitions NodeDef* w_node1; NodeDef* w_node2; + NodeDef* zeros_const1; + NodeDef* zeros_shape1; NodeDef* zeros_node1; + NodeDef* zeros_const2; + NodeDef* zeros_shape2; NodeDef* zeros_node2; NodeDef* assign_node1; NodeDef* assign_node2; @@ -260,51 +314,53 @@ class SparsifyGatherTest : public ::testing::Test { SetNodeTensorAttr("value", weights, w_node1); SetNodeTensorAttr("value", weights, w_node2); } else { - w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def); - zeros_node1 = - CreateNode("w1/part_1/Initializer/zeros", "Const", {}, &graph_def); - assign_node1 = CreateNode("w1/part_1/Assign", "Assign", - {w_node1, zeros_node1}, &graph_def); - NodeDef* save_const_node = CreateNode("save/Const", "Const", {}, &graph_def); - NodeDef* tensor_names_node1 = + + NodeDef* tensor_names_node = CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def); - NodeDef* tensor_shapes_slices_node1 = CreateNode( - "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); + Tensor tensor_names_values(DT_STRING, TensorShape({2})); + test::FillValues(&tensor_names_values, {"w1", "w2"}); + SetNodeTensorAttr("value", tensor_names_values, + tensor_names_node); - Tensor shapes_slices_val1(DT_STRING, TensorShape({1})); - shapes_slices_val1.flat()(0) = "4 1 0,4:0,1"; - SetNodeTensorAttr("value", shapes_slices_val1, - tensor_shapes_slices_node1); + NodeDef* tensor_shapes_slices_node = CreateNode( + "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); + Tensor shapes_slices_val(DT_STRING, TensorShape({2})); + shapes_slices_val.flat()(0) = "4 1 0,4:0,1"; + shapes_slices_val.flat()(1) = "4 1 0,4:0,1"; + SetNodeTensorAttr("value", shapes_slices_val, + tensor_shapes_slices_node); - NodeDef* restore_node1 = CreateNode( + NodeDef* restore_node = CreateNode( "save/RestoreV2", "RestoreV2", - {save_const_node, tensor_names_node1, tensor_shapes_slices_node1}, + {save_const_node, tensor_names_node, tensor_shapes_slices_node}, &graph_def); - CreateNode("save/Assign", "Assign", {w_node1, restore_node1}, &graph_def); + + w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def); + + zeros_shape1 = CreateNode("w1/part_1/Initializer/zeros/shape_as_tensor", + "Const", {}, &graph_def); + zeros_const1 = CreateNode("w1/part_1/Initializer/zeros/Const", "Const", + {}, &graph_def); + zeros_node1 = CreateNode("w1/part_1/Initializer/zeros", "Fill", + {zeros_shape1, zeros_const1}, &graph_def); + assign_node1 = CreateNode("w1/part_1/Assign", "Assign", + {w_node1, zeros_node1}, &graph_def); + + CreateNode("save/Assign", "Assign", {w_node1, restore_node}, &graph_def); w_node2 = CreateNode("w2/part_1", "VariableV2", {}, &graph_def); - zeros_node2 = - CreateNode("w2/part_1/Initializer/zeros", "Const", {}, &graph_def); + zeros_shape2 = CreateNode("w2/part_1/Initializer/zeros/shape_as_tensor", + "Const", {}, &graph_def); + zeros_const2 = CreateNode("w2/part_1/Initializer/zeros/Const", "Const", + {}, &graph_def); + zeros_node2 = CreateNode("w2/part_1/Initializer/zeros", "Fill", + {zeros_shape2, zeros_const2}, &graph_def); assign_node2 = CreateNode("w2/part_1/Assign", "Assign", {w_node2, zeros_node2}, &graph_def); - NodeDef* tensor_names_node2 = - CreateNode("save/RestoreV2_1/tensor_names", "Const", {}, &graph_def); - NodeDef* tensor_shapes_slices_node2 = CreateNode( - "save/RestoreV2_1/shape_and_slices", "Const", {}, &graph_def); - - Tensor shapes_slices_val2(DT_STRING, TensorShape({1})); - shapes_slices_val2.flat()(0) = "4 1 0,4:0,1"; - SetNodeTensorAttr("value", shapes_slices_val2, - tensor_shapes_slices_node2); - - NodeDef* restore_node2 = CreateNode( - "save/RestoreV2_1", "RestoreV2", - {save_const_node, tensor_names_node2, tensor_shapes_slices_node2}, - &graph_def); - CreateNode("save/Assign_1", "Assign", {w_node2, restore_node2}, + CreateNode("save/Assign_1", "Assign", {w_node2, restore_node}, &graph_def); BundleWriter writer(Env::Default(), checkpoint_path); @@ -322,6 +378,13 @@ class SparsifyGatherTest : public ::testing::Test { MakeGather("gather1", gather_v2, identity_node1, input_node, &graph_def); MakeGather("gather2", gather_v2, identity_node2, input_node, &graph_def); + NodeDef* concat_axis_node = + CreateNode("linear/concat/axis", "Const", {}, &graph_def); + NodeDef* concat_node = CreateNode( + "concat/node", "ConcatV2", + {identity_node1, identity_node2, concat_axis_node}, &graph_def); + SetNodeAttr("N", 2, concat_node); + // Shared init node if (include_shared_init) { if (!test_variable) { @@ -350,8 +413,14 @@ class SparsifyGatherTest : public ::testing::Test { MapNamesToNodes(result, &node_lookup); // Check nodes. + EXPECT_EQ(0, + node_lookup.count("w1/part_1/Initializer/zeros/shape_as_tensor")); + EXPECT_EQ(0, node_lookup.count("w1/part_1/Initializer/zeros/Const")); EXPECT_EQ(0, node_lookup.count("w1/part_1/Initializer/zeros")); EXPECT_EQ(0, node_lookup.count("w1/part_1/Assign")); + EXPECT_EQ(0, + node_lookup.count("w2/part_1/Initializer/zeros/shape_as_tensor")); + EXPECT_EQ(0, node_lookup.count("w2/part_1/Initializer/zeros/Const")); EXPECT_EQ(0, node_lookup.count("w2/part_1/Initializer/zeros")); EXPECT_EQ(0, node_lookup.count("w2/part_1/Assign")); EXPECT_EQ(1, node_lookup.count("ids")); @@ -487,6 +556,9 @@ class SparsifyGatherTest : public ::testing::Test { node_lookup.at("gather2/LookupTableFind")->input(2)); EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0)); + EXPECT_EQ(0, node_lookup.count("linear/concat/axis")); + EXPECT_EQ(0, node_lookup.count("concat/node")); + // Check control deps. EXPECT_EQ(2, node_lookup.at(shared_init_name)->input_size()); EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(), @@ -522,18 +594,31 @@ class SparsifyGatherTest : public ::testing::Test { }; TEST_F(SparsifyGatherTest, TestSinglePartition) { - TestSinglePartition(false, false, false); - TestSinglePartition(false, true, false); - TestSinglePartition(true, false, false); - TestSinglePartition(true, true, false); - TestSinglePartition(false, false, true); - TestSinglePartition(false, true, true); - TestSinglePartition(true, false, true); - TestSinglePartition(true, true, true); - TestSinglePartition(false, true, false, "shared_inits"); - TestSinglePartition(true, true, false, "shared_inits"); - TestSinglePartition(false, true, true, "shared_inits"); - TestSinglePartition(true, true, true, "shared_inits"); + TestSinglePartition(false, false, false, false); + TestSinglePartition(false, true, false, false); + TestSinglePartition(true, false, false, false); + TestSinglePartition(true, true, false, false); + TestSinglePartition(false, false, true, false); + TestSinglePartition(false, true, true, false); + TestSinglePartition(true, false, true, false); + TestSinglePartition(true, true, true, false); + TestSinglePartition(false, true, false, false, "shared_inits"); + TestSinglePartition(true, true, false, false, "shared_inits"); + TestSinglePartition(false, true, true, false, "shared_inits"); + TestSinglePartition(true, true, true, false, "shared_inits"); + + TestSinglePartition(false, false, false, true); + TestSinglePartition(false, true, false, true); + TestSinglePartition(true, false, false, true); + TestSinglePartition(true, true, false, true); + TestSinglePartition(false, false, true, true); + TestSinglePartition(false, true, true, true); + TestSinglePartition(true, false, true, true); + TestSinglePartition(true, true, true, true); + TestSinglePartition(false, true, false, true, "shared_inits"); + TestSinglePartition(true, true, false, true, "shared_inits"); + TestSinglePartition(false, true, true, true, "shared_inits"); + TestSinglePartition(true, true, true, true, "shared_inits"); } TEST_F(SparsifyGatherTest, TestMultiPartition) { diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index dbc81599de8539ce58933f9d40bf99fcae8f8e67..614457e8996491a60d4a7df213180117bce321ad 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -99,6 +99,7 @@ genrule( "//third_party/hadoop:LICENSE.txt", "//third_party/eigen3:LICENSE", "//third_party/fft2d:LICENSE", + "@aws//:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", @@ -112,8 +113,10 @@ genrule( "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", "@libxsmm_archive//:LICENSE", + "@llvm//:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", + "@nasm//:LICENSE", "@nsync//:LICENSE", "@png_archive//:LICENSE", "@protobuf_archive//:LICENSE", @@ -134,6 +137,7 @@ genrule( "//third_party/hadoop:LICENSE.txt", "//third_party/eigen3:LICENSE", "//third_party/fft2d:LICENSE", + "@aws//:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", @@ -147,8 +151,10 @@ genrule( "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", "@libxsmm_archive//:LICENSE", + "@llvm//:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", + "@nasm//:LICENSE", "@nsync//:LICENSE", "@png_archive//:LICENSE", "@protobuf_archive//:LICENSE", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 598080ed2753b862056ebcc76c4c572ae45b46e6..fb6eaa4faa28b4f6b17e1774907c0c9ff58d6ada 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -11,6 +11,7 @@ load( ) load("//third_party/mkl:build_defs.bzl", "if_mkl") load("//tensorflow:tensorflow.bzl", "if_cuda") +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps") # This returns a list of headers of all public header libraries (e.g., @@ -70,7 +71,6 @@ py_binary( "//tensorflow/python/eager:eager_pip", "//tensorflow/contrib/summary:summary_test_util", # These targets don't build on Windows yet. Exclude them for now. - # "//tensorflow/contrib/ndlstm", # "//tensorflow/contrib/slim", # "//tensorflow/contrib/slim/python/slim/nets:nets_pip", # "//tensorflow/contrib/specs", @@ -88,13 +88,20 @@ filegroup( "//third_party/eigen3:LICENSE", "//third_party/fft2d:LICENSE", "//third_party/hadoop:LICENSE.txt", + "@absl_py//absl/flags:LICENSE", + "@arm_neon_2_x86_sse//:LICENSE", + "@astor_archive//:LICENSE", + "@aws//:LICENSE", "@boringssl//:LICENSE", + "@com_google_absl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", "@fft2d//:fft/readme.txt", + "@flatbuffers//:LICENSE.txt", + "@gast_archive//:PKG-INFO", "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", "@grpc//:LICENSE", @@ -105,11 +112,15 @@ filegroup( "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@grpc//third_party/nanopb:LICENSE.txt", + "@nasm//:LICENSE", "@nsync//:LICENSE", + "@pcre//:LICENCE", "@png_archive//:LICENSE", "@protobuf_archive//:LICENSE", "@six_archive//:LICENSE", "@snappy//:COPYING", + "@swig//:LICENSE", + "@termcolor_archive//:COPYING.txt", "@zlib_archive//:zlib.h", "@org_python_pypi_backports_weakref//:LICENSE", ] + if_mkl([ @@ -137,9 +148,9 @@ sh_binary( "//tensorflow/contrib/boosted_trees:boosted_trees_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:prefetching_py", + "//tensorflow/contrib/data/python/ops:contrib_op_loader", "//tensorflow/contrib/eager/python/examples:examples_pip", - "//tensorflow/contrib/eager/python:checkpointable", + "//tensorflow/contrib/eager/python:checkpointable_utils", "//tensorflow/contrib/eager/python:evaluator", "//tensorflow/contrib/gan:gan", "//tensorflow/contrib/graph_editor:graph_editor_pip", @@ -148,12 +159,12 @@ sh_binary( "//tensorflow/contrib/lite/toco:toco", "//tensorflow/contrib/lite/toco/python:toco_wrapper", "//tensorflow/contrib/lite/toco/python:toco_from_protos", - "//tensorflow/contrib/ndlstm:ndlstm", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/predictor:predictor_pip", - "//tensorflow/contrib/py2tf:py2tf_internal", + "//tensorflow/contrib/py2tf:py2tf", "//tensorflow/contrib/py2tf/converters:converters", "//tensorflow/contrib/py2tf/converters:test_lib", + "//tensorflow/contrib/py2tf/impl:impl", "//tensorflow/contrib/py2tf/pyct:pyct", "//tensorflow/contrib/py2tf/pyct/static_analysis:static_analysis", "//tensorflow/contrib/receptive_field:receptive_field_pip", @@ -181,7 +192,9 @@ sh_binary( "//tensorflow/python:test_ops", "//tensorflow/tools/dist_test/server:grpc_tensorflow_server", ], - }) + if_mkl(["//third_party/mkl:intel_binary_blob"]), + }) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([ + "//tensorflow/contrib/tensorrt:init_py", + ]), ) # A genrule for generating a marker file for the pip package on Windows diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index 38a900738786e2413f5b1dd914caaebeafc92e21..73d759eb130633094b402c821cc32eb76c076a44 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -65,7 +65,6 @@ BLACKLIST = [ "//tensorflow/contrib/framework:checkpoint_ops_testdata", "//tensorflow/contrib/bayesflow:reinforce_simple_example", "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long - "//tensorflow/contrib/py2tf:py2tf_internal", "//tensorflow/contrib/timeseries/examples:predict", "//tensorflow/contrib/timeseries/examples:multivariate", "//tensorflow/contrib/timeseries/examples:known_anomaly", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 62df6453fb5d39728c2985a28a70a263d79804b1..4b6f123daa7b528173234a2bffd30ead2aa9fc0e 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -29,16 +29,17 @@ from setuptools.dist import Distribution # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. -_VERSION = '1.5.0-rc1' +_VERSION = '1.6.0-rc1' REQUIRED_PACKAGES = [ 'absl-py >= 0.1.6', 'astor >= 0.6.0', 'gast >= 0.2.0', - 'numpy >= 1.12.1', + 'grpcio >= 1.8.6', + 'numpy >= 1.13.3', 'six >= 1.10.0', 'protobuf >= 3.4.0', - 'tensorflow-tensorboard >= 0.4.0', + 'tensorboard >= 1.6.0, < 1.7.0', 'termcolor >= 1.1.0', ] @@ -61,7 +62,7 @@ else: if 'tf_nightly' in project_name: for i, pkg in enumerate(REQUIRED_PACKAGES): if 'tensorboard' in pkg: - REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.5.0a0, < 1.6.0a0' + REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.7.0a0, < 1.8.0a0' break # weakref.finalize and enum were introduced in Python 3.4 @@ -79,13 +80,13 @@ CONSOLE_SCRIPTS = [ # is now declared by the tensorboard pip package. If we remove the # TensorBoard command, pip will inappropriately remove it during install, # even though the command is not removed, just moved to a different wheel. - 'tensorboard = tensorboard.main:main', + 'tensorboard = tensorboard.main:run_main', ] # pylint: enable=line-too-long # remove the tensorboard console script if building tf_nightly if 'tf_nightly' in project_name: - CONSOLE_SCRIPTS.remove('tensorboard = tensorboard.main:main') + CONSOLE_SCRIPTS.remove('tensorboard = tensorboard.main:run_main') TEST_PACKAGES = [ 'scipy >= 0.15.1', @@ -180,9 +181,10 @@ def find_files(pattern, root): matches = ['../' + x for x in find_files('*', 'external') if '.py' not in x] -so_lib_paths = [i for i in os.listdir('.') - if os.path.isdir(i) - and fnmatch.fnmatch(i, '_solib_*')] +so_lib_paths = [ + i for i in os.listdir('.') + if os.path.isdir(i) and fnmatch.fnmatch(i, '_solib_*') +] for path in so_lib_paths: matches.extend( diff --git a/tensorflow/tools/test/file_name_test.py b/tensorflow/tools/test/file_name_test.py new file mode 100644 index 0000000000000000000000000000000000000000..16fb8a822d09ed136cf79dd2473fc202ca632d83 --- /dev/null +++ b/tensorflow/tools/test/file_name_test.py @@ -0,0 +1,48 @@ +#!/usr/bin/python +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Test that checks if we have any issues with case insensitive filesystems. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) +ERROR_MESSAGE = """ +Files with same name but different case detected in directory: {} +""" + + +def main(): + # Make sure BASE_DIR ends with tensorflow. If it doesn't, we probably + # computed the wrong directory. + if os.path.split(BASE_DIR)[-1] != 'tensorflow': + raise AssertionError( + "BASE_DIR = '%s' doesn't end with tensorflow" % BASE_DIR) + + for dirpath, dirnames, filenames in os.walk(BASE_DIR, followlinks=True): + lowercase_directories = [x.lower() for x in dirnames] + lowercase_files = [x.lower() for x in filenames] + + lowercase_dir_contents = lowercase_directories + lowercase_files + if len(lowercase_dir_contents) != len(set(lowercase_dir_contents)): + raise AssertionError(ERROR_MESSAGE.format(dirpath)) + + +if __name__ == '__main__': + main() diff --git a/tensorflow/tools/test/run_and_gather_logs_lib.py b/tensorflow/tools/test/run_and_gather_logs_lib.py index a953ed1b53d13504f92d2ffeb4c1ac6bcb0b8477..3b4921bb983a72223b092d99eb3fb59332fc6345 100644 --- a/tensorflow/tools/test/run_and_gather_logs_lib.py +++ b/tensorflow/tools/test/run_and_gather_logs_lib.py @@ -136,7 +136,7 @@ def run_and_gather_logs(name, test_name, test_args, gpu_config = gpu_info_lib.gather_gpu_devices() if gpu_config: gpu_name = gpu_config[0].model - gpu_short_name_match = re.search(r"Tesla (K40|K80|P100)", gpu_name) + gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)", gpu_name) if gpu_short_name_match: gpu_short_name = gpu_short_name_match.group(0) test_adjusted_name = name + "|" + gpu_short_name.replace(" ", "_") diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 9145d9e58a3df6c074d5ac44a665a33339c45cc6..2e84d83fe49b08a4c00baa50021759aa0c47c7e3 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -1,6 +1,7 @@ # TensorFlow external dependencies that can be loaded in WORKSPACE files. load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") +load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure") load("//third_party/mkl:build_defs.bzl", "mkl_repository") load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/py:python_configure.bzl", "python_configure") @@ -68,6 +69,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): check_bazel_version_at_least("0.5.4") clang6_configure(name="local_config_clang6") cuda_configure(name="local_config_cuda") + tensorrt_configure(name="local_config_tensorrt") git_configure(name="local_config_git") sycl_configure(name="local_config_sycl") python_configure(name="local_config_python") @@ -112,16 +114,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "5996380e3e8b981f55d1c8d58e709c00dbb4806ba367be75d0925a68cc2f6478", strip_prefix = "abseil-cpp-720c017e30339fd1786ce4aac68bc8559736e53f", + build_file = str(Label("//third_party:com_google_absl.BUILD")), ) tf_http_archive( name = "eigen_archive", urls = [ - "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/14e1418fcf12.tar.gz", - "https://bitbucket.org/eigen/eigen/get/14e1418fcf12.tar.gz", + "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/2355b229ea4c.tar.gz", + "https://bitbucket.org/eigen/eigen/get/2355b229ea4c.tar.gz", ], - sha256 = "2b526c6888639025323fd4f2600533c0f982d304ea48e4f1663e8066bd9f6368", - strip_prefix = "eigen-eigen-14e1418fcf12", + sha256 = "0cadb31a35b514bf2dfd6b5d38205da94ef326ec6908fc3fd7c269948467214f", + strip_prefix = "eigen-eigen-2355b229ea4c", build_file = str(Label("//third_party:eigen.BUILD")), ) @@ -176,11 +179,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "gemmlowp", urls = [ - "https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", - "https://github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", + "https://mirror.bazel.build/github.com/google/gemmlowp/archive/7c7c744640ddc3d0af18fb245b4d23228813a71b.zip", + "https://github.com/google/gemmlowp/archive/7c7c744640ddc3d0af18fb245b4d23228813a71b.zip", ], - sha256 = "dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d", - strip_prefix = "gemmlowp-010bb3e71a26ca1d0884a167081d092b43563996", + sha256 = "b852cc90259a7357c8a323f108f2cec6e85979fc3b18b5590b99e0130044b2cf", + strip_prefix = "gemmlowp-7c7c744640ddc3d0af18fb245b4d23228813a71b", ) tf_http_archive( @@ -350,16 +353,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "protobuf_archive", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", - "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", ], - sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", - strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", - # TODO: remove patching when tensorflow stops linking same protos into - # multiple shared libraries loaded in runtime by python. - # This patch fixes a runtime crash when tensorflow is compiled - # with clang -O2 on Linux (see https://github.com/tensorflow/tensorflow/issues/8394) - patch_file = str(Label("//third_party/protobuf:add_noinlines.patch")), + sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", + strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", ) # We need to import the protobuf library under the names com_google_protobuf @@ -368,21 +366,21 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "com_google_protobuf", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", - "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", ], - sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", - strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", + sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", + strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", ) tf_http_archive( name = "com_google_protobuf_cc", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", - "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", + "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", ], - sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", - strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", + sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", + strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", ) tf_http_archive( @@ -475,11 +473,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/11a2ca6eea8a7fe240a14c0c35fd2017341279be.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/11a2ca6eea8a7fe240a14c0c35fd2017341279be.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/11b0e47b5b79bab22d27b6b2952b1f7582848063.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/11b0e47b5b79bab22d27b6b2952b1f7582848063.tar.gz", ], - sha256 = "b5429ccf8d57273cb8489714f728c997cd720ec66fc2c0292422ab8f0e729ce0", - strip_prefix = "llvm-11a2ca6eea8a7fe240a14c0c35fd2017341279be", + sha256 = "b870b6f5df94c4c0cf7c6957046fca354c37d7641e838e905279a7509b0705e9", + strip_prefix = "llvm-11b0e47b5b79bab22d27b6b2952b1f7582848063", build_file = str(Label("//third_party/llvm:llvm.BUILD")), ) @@ -558,6 +556,18 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:nccl.BUILD")), ) + tf_http_archive( + name = "kafka", + urls = [ + "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz", + "https://github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz", + ], + sha256 = "dd035d57c8f19b0b612dd6eefe6e5eebad76f506e302cccb7c2066f25a83585e", + strip_prefix = "librdkafka-0.11.1", + build_file = str(Label("//third_party:kafka/BUILD")), + patch_file = str(Label("//third_party/kafka:config.patch")), + ) + tf_http_archive( name = "aws", urls = [ @@ -660,6 +670,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31", strip_prefix = "cub-1.7.4", build_file = str(Label("//third_party:cub.BUILD")), + # TODO: remove the patch when upstream fix is accepted and released. + # PR with a fix: https://github.com/NVlabs/cub/pull/125 + patch_file = str(Label("//third_party/cub:fix_compilation_in_clang.patch")), ) tf_http_archive(

  • Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
    tensorflow-1.5.0-rc1CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
    tensorflow_gpu-1.5.0-rc1GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
    tensorflow-1.6.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
    tensorflow_gpu-1.6.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
    tensorflow-1.6.0rc0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
    tensorflow_gpu-1.6.0rc0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
    tensorflow-1.5.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
    tensorflow_gpu-1.5.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
    tensorflow-1.4.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
    tensorflow_gpu-1.4.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.368
    tensorflow-1.3.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A